"examples/cpp/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "508bc1dc8d7cff5c1383068d6601ff669f69111d"
Unverified Commit 0ca52bfc authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Hetero] Create heterograph from dictionary of edge lists (#882)

* [Hetero] Create heterograph from dictionary of edge lists

* rename

* address comments
parent 1c91f460
...@@ -15,6 +15,7 @@ __all__ = [ ...@@ -15,6 +15,7 @@ __all__ = [
'graph', 'graph',
'bipartite', 'bipartite',
'hetero_from_relations', 'hetero_from_relations',
'heterograph',
'to_hetero', 'to_hetero',
'to_homo', 'to_homo',
'to_networkx', 'to_networkx',
...@@ -264,6 +265,83 @@ def hetero_from_relations(rel_graphs): ...@@ -264,6 +265,83 @@ def hetero_from_relations(rel_graphs):
retg._edge_frames[i].update(rgrh._edge_frames[0]) retg._edge_frames[i].update(rgrh._edge_frames[0])
return retg return retg
def heterograph(data_dict, num_nodes_dict=None):
"""Create a heterogeneous graph from a dictionary between edge types and edge lists.
Parameters
----------
data_dict : dict
The dictionary between edge types and edge list data.
The edge types are specified as a triplet of (source node type name, edge type
name, destination node type name).
The edge list data can be anything acceptable by :func:`dgl.graph` or
:func:`dgl.bipartite`, or objects returned by the two functions themselves.
num_nodes_dict : dict[str, int]
The number of nodes for each node type.
By default DGL infers the number of nodes for each node type from ``data_dict``
by taking the maximum node ID plus one for each node type.
Returns
-------
DGLHeteroGraph
Examples
--------
>>> g = dgl.heterograph({
... ('user', 'follows', 'user'): [(0, 1), (1, 2)],
... ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)],
... ('developer', 'develops', 'game'): [(0, 0), (1, 1)],
... })
"""
rel_graphs = []
# infer number of nodes for each node type
if num_nodes_dict is None:
num_nodes_dict = defaultdict(int)
for (srctype, etype, dsttype), data in data_dict.items():
if isinstance(data, tuple):
nsrc = max(data[0]) + 1
ndst = max(data[1]) + 1
elif isinstance(data, list):
src, dst = zip(*data)
nsrc = max(src) + 1
ndst = max(dst) + 1
elif isinstance(data, sp.sparse.spmatrix):
nsrc = data.shape[0]
ndst = data.shape[1]
elif isinstance(data, nx.Graph):
if srctype == dsttype:
nsrc = ndst = data.number_of_nodes()
else:
nsrc = len({n for n, d in data.nodes(data=True) if d['bipartite'] == 0})
ndst = data.number_of_nodes() - nsrc
elif isinstance(data, DGLHeteroGraph):
# Do nothing; handled in the next loop
pass
else:
raise DGLError('Unsupported graph data type %s for %s' % (
type(data), (srctype, etype, dsttype)))
if srctype == dsttype:
ndst = nsrc = max(nsrc, ndst)
num_nodes_dict[srctype] = max(num_nodes_dict[srctype], nsrc)
num_nodes_dict[dsttype] = max(num_nodes_dict[dsttype], ndst)
for (srctype, etype, dsttype), data in data_dict.items():
if isinstance(data, DGLHeteroGraph):
rel_graphs.append(data)
elif srctype == dsttype:
rel_graphs.append(graph(data, srctype, etype, card=num_nodes_dict[srctype]))
else:
rel_graphs.append(bipartite(
data, srctype, etype, dsttype,
card=(num_nodes_dict[srctype], num_nodes_dict[dsttype])))
return hetero_from_relations(rel_graphs)
def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph=None): def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph=None):
"""Convert the given graph to a heterogeneous graph. """Convert the given graph to a heterogeneous graph.
......
...@@ -77,6 +77,14 @@ class DGLHeteroGraph(object): ...@@ -77,6 +77,14 @@ class DGLHeteroGraph(object):
>>> dev_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game') >>> dev_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game')
>>> g = dgl.hetero_from_relations([follows_g, plays_g, dev_g]) >>> g = dgl.hetero_from_relations([follows_g, plays_g, dev_g])
Or equivalently
>>> g = dgl.heterograph({
... ('user', 'follows', 'user'): [(0, 1), (1, 2)],
... ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)],
... ('developer', 'develops', 'game'): [(0, 0), (1, 1)],
... })
:func:`dgl.graph` and :func:`dgl.bipartite` can create a graph from a variety of :func:`dgl.graph` and :func:`dgl.bipartite` can create a graph from a variety of
data types including: edge list, edge tuples, networkx graph and scipy sparse matrix. data types including: edge list, edge tuples, networkx graph and scipy sparse matrix.
Click the function name for more details. Click the function name for more details.
......
...@@ -43,14 +43,32 @@ def create_test_heterograph1(): ...@@ -43,14 +43,32 @@ def create_test_heterograph1():
g0.edata[dgl.ETYPE] = etypes g0.edata[dgl.ETYPE] = etypes
return dgl.to_hetero(g0, ['user', 'game', 'developer'], ['follows', 'plays', 'wishes', 'develops']) return dgl.to_hetero(g0, ['user', 'game', 'developer'], ['follows', 'plays', 'wishes', 'develops'])
def create_test_heterograph2():
plays_spmat = ssp.coo_matrix(([1, 1, 1, 1], ([0, 1, 2, 1], [0, 0, 1, 1])))
wishes_nx = nx.DiGraph()
wishes_nx.add_nodes_from(['u0', 'u1', 'u2'], bipartite=0)
wishes_nx.add_nodes_from(['g0', 'g1'], bipartite=1)
wishes_nx.add_edge('u0', 'g1', id=0)
wishes_nx.add_edge('u2', 'g0', id=1)
develops_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game')
g = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): plays_spmat,
('user', 'wishes', 'game'): wishes_nx,
('developer', 'develops', 'game'): develops_g,
})
return g
def get_redfn(name): def get_redfn(name):
return getattr(F, name) return getattr(F, name)
def test_create(): def test_create():
g0 = create_test_heterograph() g0 = create_test_heterograph()
g1 = create_test_heterograph1() g1 = create_test_heterograph1()
assert set(g0.ntypes) == set(g1.ntypes) g2 = create_test_heterograph2()
assert set(g0.canonical_etypes) == set(g1.canonical_etypes) assert set(g0.ntypes) == set(g1.ntypes) == set(g2.ntypes)
assert set(g0.canonical_etypes) == set(g1.canonical_etypes) == set(g2.canonical_etypes)
# create from nx complete bipartite graph # create from nx complete bipartite graph
nxg = nx.complete_bipartite_graph(3, 4) nxg = nx.complete_bipartite_graph(3, 4)
...@@ -65,6 +83,16 @@ def test_create(): ...@@ -65,6 +83,16 @@ def test_create():
assert g.number_of_nodes() == 4 assert g.number_of_nodes() == 4
assert g.number_of_edges() == 3 assert g.number_of_edges() == 3
# test inferring number of nodes for heterograph
g = dgl.heterograph({
('l0', 'e0', 'l1'): [(0, 1), (0, 2)],
('l0', 'e1', 'l2'): [(2, 2)],
('l2', 'e2', 'l2'): [(1, 1), (3, 3)],
})
assert g.number_of_nodes('l0') == 3
assert g.number_of_nodes('l1') == 3
assert g.number_of_nodes('l2') == 4
def test_query(): def test_query():
g = create_test_heterograph() g = create_test_heterograph()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment