"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9c8eca702c2fa811fba1ccff82a6aee6a04a2556"
Unverified Commit d2ccd218 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[hotfix] Allow creating graph from edges with restrict formats excluding coo. (#1926)

* upd

* upd

* upd

* upd
parent ca5a13fe
......@@ -1016,8 +1016,12 @@ def create_from_edges(u, v,
else:
num_ntypes = 2
hgidx = heterograph_index.create_unitgraph_from_coo(
num_ntypes, urange, vrange, u, v, formats)
if 'coo' in formats:
hgidx = heterograph_index.create_unitgraph_from_coo(
num_ntypes, urange, vrange, u, v, formats)
else:
hgidx = heterograph_index.create_unitgraph_from_coo(
num_ntypes, urange, vrange, u, v, ['coo']).formats(formats)
if utype == vtype:
return DGLHeteroGraph(hgidx, [utype], [etype])
else:
......
......@@ -82,13 +82,13 @@ def create_test_heterograph3(idtype):
wishes_nx.add_edge('u2', 'g0', id=1)
follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows',
idtype=idtype, device=device).formats('coo')
idtype=idtype, device=device, formats='coo')
plays_g = dgl.bipartite([(0, 0), (1, 0), (2, 1), (1, 1)], 'user', 'plays', 'game',
idtype=idtype, device=device).formats('coo')
idtype=idtype, device=device, formats='coo')
wishes_g = dgl.bipartite([(0, 1), (2, 0)], 'user', 'wishes', 'game',
idtype=idtype, device=device).formats('coo')
idtype=idtype, device=device, formats='coo')
develops_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game',
idtype=idtype, device=device).formats('coo')
idtype=idtype, device=device, formats='coo')
g = dgl.hetero_from_relations([follows_g, plays_g, wishes_g, develops_g])
assert g.idtype == idtype
assert g.device == device
......@@ -1332,7 +1332,7 @@ def test_subgraph(idtype):
if F._default_context_str != 'gpu':
# TODO(minjie): enable this later
for fmt in ['csr', 'csc', 'coo']:
g = dgl.graph([(0, 1), (1, 2)]).formats(fmt)
g = dgl.graph([(0, 1), (1, 2)], formats=fmt)
sg = g.subgraph({g.ntypes[0]: [1, 0]})
nids = F.asnumpy(sg.ndata[dgl.NID])
assert np.array_equal(nids, np.array([1, 0]))
......@@ -1866,7 +1866,7 @@ def test_dtype_cast(idtype):
@parametrize_dtype
def test_format(idtype):
# single relation
g = dgl.graph([(0, 0), (1, 1), (0, 1), (2, 0)], idtype=idtype, device=F.ctx()).formats('coo')
g = dgl.graph([(0, 0), (1, 1), (0, 1), (2, 0)], idtype=idtype, device=F.ctx(), formats='coo')
assert g.formats()['created'] == ['coo']
assert len(g.formats()['not created']) == 0
try:
......@@ -1886,7 +1886,7 @@ def test_format(idtype):
('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)],
('developer', 'develops', 'game'): [(0, 0), (1, 1)],
}, idtype=idtype, device=F.ctx()).formats('csr')
}, idtype=idtype, device=F.ctx(), formats='csr')
user_feat = F.randn((g['follows'].number_of_src_nodes(), 5))
g['follows'].srcdata['h'] = user_feat
assert g.formats()['created'] == ['csr']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment