Unverified Commit d2ccd218 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[hotfix] Allow creating graph from edges with restrict formats excluding coo. (#1926)

* upd

* upd

* upd

* upd
parent ca5a13fe
...@@ -1016,8 +1016,12 @@ def create_from_edges(u, v, ...@@ -1016,8 +1016,12 @@ def create_from_edges(u, v,
else: else:
num_ntypes = 2 num_ntypes = 2
if 'coo' in formats:
hgidx = heterograph_index.create_unitgraph_from_coo( hgidx = heterograph_index.create_unitgraph_from_coo(
num_ntypes, urange, vrange, u, v, formats) num_ntypes, urange, vrange, u, v, formats)
else:
hgidx = heterograph_index.create_unitgraph_from_coo(
num_ntypes, urange, vrange, u, v, ['coo']).formats(formats)
if utype == vtype: if utype == vtype:
return DGLHeteroGraph(hgidx, [utype], [etype]) return DGLHeteroGraph(hgidx, [utype], [etype])
else: else:
......
...@@ -82,13 +82,13 @@ def create_test_heterograph3(idtype): ...@@ -82,13 +82,13 @@ def create_test_heterograph3(idtype):
wishes_nx.add_edge('u2', 'g0', id=1) wishes_nx.add_edge('u2', 'g0', id=1)
follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows', follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows',
idtype=idtype, device=device).formats('coo') idtype=idtype, device=device, formats='coo')
plays_g = dgl.bipartite([(0, 0), (1, 0), (2, 1), (1, 1)], 'user', 'plays', 'game', plays_g = dgl.bipartite([(0, 0), (1, 0), (2, 1), (1, 1)], 'user', 'plays', 'game',
idtype=idtype, device=device).formats('coo') idtype=idtype, device=device, formats='coo')
wishes_g = dgl.bipartite([(0, 1), (2, 0)], 'user', 'wishes', 'game', wishes_g = dgl.bipartite([(0, 1), (2, 0)], 'user', 'wishes', 'game',
idtype=idtype, device=device).formats('coo') idtype=idtype, device=device, formats='coo')
develops_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game', develops_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game',
idtype=idtype, device=device).formats('coo') idtype=idtype, device=device, formats='coo')
g = dgl.hetero_from_relations([follows_g, plays_g, wishes_g, develops_g]) g = dgl.hetero_from_relations([follows_g, plays_g, wishes_g, develops_g])
assert g.idtype == idtype assert g.idtype == idtype
assert g.device == device assert g.device == device
...@@ -1332,7 +1332,7 @@ def test_subgraph(idtype): ...@@ -1332,7 +1332,7 @@ def test_subgraph(idtype):
if F._default_context_str != 'gpu': if F._default_context_str != 'gpu':
# TODO(minjie): enable this later # TODO(minjie): enable this later
for fmt in ['csr', 'csc', 'coo']: for fmt in ['csr', 'csc', 'coo']:
g = dgl.graph([(0, 1), (1, 2)]).formats(fmt) g = dgl.graph([(0, 1), (1, 2)], formats=fmt)
sg = g.subgraph({g.ntypes[0]: [1, 0]}) sg = g.subgraph({g.ntypes[0]: [1, 0]})
nids = F.asnumpy(sg.ndata[dgl.NID]) nids = F.asnumpy(sg.ndata[dgl.NID])
assert np.array_equal(nids, np.array([1, 0])) assert np.array_equal(nids, np.array([1, 0]))
...@@ -1866,7 +1866,7 @@ def test_dtype_cast(idtype): ...@@ -1866,7 +1866,7 @@ def test_dtype_cast(idtype):
@parametrize_dtype @parametrize_dtype
def test_format(idtype): def test_format(idtype):
# single relation # single relation
g = dgl.graph([(0, 0), (1, 1), (0, 1), (2, 0)], idtype=idtype, device=F.ctx()).formats('coo') g = dgl.graph([(0, 0), (1, 1), (0, 1), (2, 0)], idtype=idtype, device=F.ctx(), formats='coo')
assert g.formats()['created'] == ['coo'] assert g.formats()['created'] == ['coo']
assert len(g.formats()['not created']) == 0 assert len(g.formats()['not created']) == 0
try: try:
...@@ -1886,7 +1886,7 @@ def test_format(idtype): ...@@ -1886,7 +1886,7 @@ def test_format(idtype):
('user', 'follows', 'user'): [(0, 1), (1, 2)], ('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)], ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)],
('developer', 'develops', 'game'): [(0, 0), (1, 1)], ('developer', 'develops', 'game'): [(0, 0), (1, 1)],
}, idtype=idtype, device=F.ctx()).formats('csr') }, idtype=idtype, device=F.ctx(), formats='csr')
user_feat = F.randn((g['follows'].number_of_src_nodes(), 5)) user_feat = F.randn((g['follows'].number_of_src_nodes(), 5))
g['follows'].srcdata['h'] = user_feat g['follows'].srcdata['h'] = user_feat
assert g.formats()['created'] == ['csr'] assert g.formats()['created'] == ['csr']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment