Unverified Commit 421763fb authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bugfix] Fix #1641 (#1678)

* fix #1641

* lint
parent fc7cd275
...@@ -376,10 +376,12 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -376,10 +376,12 @@ class UnitGraph::COO : public BaseHeteroGraph {
} else { } else {
IdArray new_src = aten::IndexSelect(adj_.row, eids[0]); IdArray new_src = aten::IndexSelect(adj_.row, eids[0]);
IdArray new_dst = aten::IndexSelect(adj_.col, eids[0]); IdArray new_dst = aten::IndexSelect(adj_.col, eids[0]);
subg.induced_vertices.emplace_back(aten::Range(0, NumVertices(0), NumBits(), Context())); subg.induced_vertices.emplace_back(
subg.induced_vertices.emplace_back(aten::Range(0, NumVertices(1), NumBits(), Context())); aten::Range(0, NumVertices(SrcType()), NumBits(), Context()));
subg.induced_vertices.emplace_back(
aten::Range(0, NumVertices(DstType()), NumBits(), Context()));
subg.graph = std::make_shared<COO>( subg.graph = std::make_shared<COO>(
meta_graph(), NumVertices(0), NumVertices(1), new_src, new_dst); meta_graph(), NumVertices(SrcType()), NumVertices(DstType()), new_src, new_dst);
subg.induced_edges = eids; subg.induced_edges = eids;
} }
return subg; return subg;
......
...@@ -933,32 +933,52 @@ def test_subgraph(index_dtype): ...@@ -933,32 +933,52 @@ def test_subgraph(index_dtype):
sg2 = g.edge_subgraph({'follows': [1], 'plays': [1], 'wishes': [1]}) sg2 = g.edge_subgraph({'follows': [1], 'plays': [1], 'wishes': [1]})
_check_subgraph(g, sg2) _check_subgraph(g, sg2)
def _check_subgraph_single_ntype(g, sg): def _check_subgraph_single_ntype(g, sg, preserve_nodes=False):
assert sg.ntypes == g.ntypes assert sg.ntypes == g.ntypes
assert sg.etypes == g.etypes assert sg.etypes == g.etypes
assert sg.canonical_etypes == g.canonical_etypes assert sg.canonical_etypes == g.canonical_etypes
assert F.array_equal(F.tensor(sg.nodes['user'].data[dgl.NID]),
F.tensor([1, 2], F.int64)) if not preserve_nodes:
assert F.array_equal(F.tensor(sg.nodes['user'].data[dgl.NID]),
F.tensor([1, 2], F.int64))
else:
for ntype in sg.ntypes:
assert g.number_of_nodes(ntype) == sg.number_of_nodes(ntype)
assert F.array_equal(F.tensor(sg.edges['follows'].data[dgl.EID]), assert F.array_equal(F.tensor(sg.edges['follows'].data[dgl.EID]),
F.tensor([1], F.int64)) F.tensor([1], F.int64))
assert F.array_equal(sg.nodes['user'].data['h'], g.nodes['user'].data['h'][1:3])
if not preserve_nodes:
assert F.array_equal(sg.nodes['user'].data['h'], g.nodes['user'].data['h'][1:3])
assert F.array_equal(sg.edges['follows'].data['h'], g.edges['follows'].data['h'][1:2]) assert F.array_equal(sg.edges['follows'].data['h'], g.edges['follows'].data['h'][1:2])
def _check_subgraph_single_etype(g, sg): def _check_subgraph_single_etype(g, sg, preserve_nodes=False):
assert sg.ntypes == g.ntypes assert sg.ntypes == g.ntypes
assert sg.etypes == g.etypes assert sg.etypes == g.etypes
assert sg.canonical_etypes == g.canonical_etypes assert sg.canonical_etypes == g.canonical_etypes
assert F.array_equal(F.tensor(sg.nodes['user'].data[dgl.NID]),
F.tensor([0, 1], F.int64)) if not preserve_nodes:
assert F.array_equal(F.tensor(sg.nodes['game'].data[dgl.NID]), assert F.array_equal(F.tensor(sg.nodes['user'].data[dgl.NID]),
F.tensor([0], F.int64)) F.tensor([0, 1], F.int64))
assert F.array_equal(F.tensor(sg.nodes['game'].data[dgl.NID]),
F.tensor([0], F.int64))
else:
for ntype in sg.ntypes:
assert g.number_of_nodes(ntype) == sg.number_of_nodes(ntype)
assert F.array_equal(F.tensor(sg.edges['plays'].data[dgl.EID]), assert F.array_equal(F.tensor(sg.edges['plays'].data[dgl.EID]),
F.tensor([0, 1], F.int64)) F.tensor([0, 1], F.int64))
sg1_graph = g_graph.subgraph([1, 2]) sg1_graph = g_graph.subgraph([1, 2])
_check_subgraph_single_ntype(g_graph, sg1_graph) _check_subgraph_single_ntype(g_graph, sg1_graph)
sg1_graph = g_graph.edge_subgraph([1])
_check_subgraph_single_ntype(g_graph, sg1_graph)
sg1_graph = g_graph.edge_subgraph([1], preserve_nodes=True)
_check_subgraph_single_ntype(g_graph, sg1_graph, True)
sg2_bipartite = g_bipartite.edge_subgraph([0, 1]) sg2_bipartite = g_bipartite.edge_subgraph([0, 1])
_check_subgraph_single_etype(g_bipartite, sg2_bipartite) _check_subgraph_single_etype(g_bipartite, sg2_bipartite)
sg2_bipartite = g_bipartite.edge_subgraph([0, 1], preserve_nodes=True)
_check_subgraph_single_etype(g_bipartite, sg2_bipartite, True)
def _check_typed_subgraph1(g, sg): def _check_typed_subgraph1(g, sg):
assert set(sg.ntypes) == {'user', 'game'} assert set(sg.ntypes) == {'user', 'game'}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment