Unverified Commit 43ba94ee authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[BugFix] fix etype check in DistGraph.edge_subgraph (#4322)

parent 463650a7
......@@ -1180,7 +1180,7 @@ class DistGraph:
if isinstance(edges, dict):
# TODO(zhengda) we need to directly generate subgraph of all relations with
# one invocation.
if isinstance(edges, tuple):
if isinstance(list(edges.keys())[0], tuple):
subg = {etype: self.find_edges(edges[etype], etype[1]) for etype in edges}
else:
subg = {}
......
......@@ -228,6 +228,11 @@ def check_dist_graph(g, num_clients, num_nodes, num_edges):
feats = F.squeeze(feats1, 1)
assert np.all(F.asnumpy(feats == eids))
# Test edge_subgraph
sg = g.edge_subgraph(eids)
assert sg.num_edges() == len(eids)
assert F.array_equal(sg.edata[dgl.EID], eids)
# Test init node data
new_shape = (g.number_of_nodes(), 2)
test1 = dgl.distributed.DistTensor(new_shape, F.int32)
......@@ -494,6 +499,14 @@ def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
feats = F.squeeze(feats1, 1)
assert np.all(F.asnumpy(feats == eids))
# Test edge_subgraph
sg = g.edge_subgraph({'r1': eids})
assert sg.num_edges() == len(eids)
assert F.array_equal(sg.edata[dgl.EID], eids)
sg = g.edge_subgraph({('n1', 'r1', 'n2'): eids})
assert sg.num_edges() == len(eids)
assert F.array_equal(sg.edata[dgl.EID], eids)
# Test init node data
new_shape = (g.number_of_nodes('n1'), 2)
g.nodes['n1'].data['test1'] = dgl.distributed.DistTensor(new_shape, F.int32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment