Unverified Commit 421763fb authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bugfix] Fix #1641 (#1678)

* fix #1641

* lint
parent fc7cd275
......@@ -376,10 +376,12 @@ class UnitGraph::COO : public BaseHeteroGraph {
} else {
IdArray new_src = aten::IndexSelect(adj_.row, eids[0]);
IdArray new_dst = aten::IndexSelect(adj_.col, eids[0]);
subg.induced_vertices.emplace_back(aten::Range(0, NumVertices(0), NumBits(), Context()));
subg.induced_vertices.emplace_back(aten::Range(0, NumVertices(1), NumBits(), Context()));
subg.induced_vertices.emplace_back(
aten::Range(0, NumVertices(SrcType()), NumBits(), Context()));
subg.induced_vertices.emplace_back(
aten::Range(0, NumVertices(DstType()), NumBits(), Context()));
subg.graph = std::make_shared<COO>(
meta_graph(), NumVertices(0), NumVertices(1), new_src, new_dst);
meta_graph(), NumVertices(SrcType()), NumVertices(DstType()), new_src, new_dst);
subg.induced_edges = eids;
}
return subg;
......
......@@ -933,32 +933,52 @@ def test_subgraph(index_dtype):
sg2 = g.edge_subgraph({'follows': [1], 'plays': [1], 'wishes': [1]})
_check_subgraph(g, sg2)
def _check_subgraph_single_ntype(g, sg):
def _check_subgraph_single_ntype(g, sg, preserve_nodes=False):
assert sg.ntypes == g.ntypes
assert sg.etypes == g.etypes
assert sg.canonical_etypes == g.canonical_etypes
assert F.array_equal(F.tensor(sg.nodes['user'].data[dgl.NID]),
F.tensor([1, 2], F.int64))
if not preserve_nodes:
assert F.array_equal(F.tensor(sg.nodes['user'].data[dgl.NID]),
F.tensor([1, 2], F.int64))
else:
for ntype in sg.ntypes:
assert g.number_of_nodes(ntype) == sg.number_of_nodes(ntype)
assert F.array_equal(F.tensor(sg.edges['follows'].data[dgl.EID]),
F.tensor([1], F.int64))
assert F.array_equal(sg.nodes['user'].data['h'], g.nodes['user'].data['h'][1:3])
if not preserve_nodes:
assert F.array_equal(sg.nodes['user'].data['h'], g.nodes['user'].data['h'][1:3])
assert F.array_equal(sg.edges['follows'].data['h'], g.edges['follows'].data['h'][1:2])
def _check_subgraph_single_etype(g, sg):
def _check_subgraph_single_etype(g, sg, preserve_nodes=False):
assert sg.ntypes == g.ntypes
assert sg.etypes == g.etypes
assert sg.canonical_etypes == g.canonical_etypes
assert F.array_equal(F.tensor(sg.nodes['user'].data[dgl.NID]),
F.tensor([0, 1], F.int64))
assert F.array_equal(F.tensor(sg.nodes['game'].data[dgl.NID]),
F.tensor([0], F.int64))
if not preserve_nodes:
assert F.array_equal(F.tensor(sg.nodes['user'].data[dgl.NID]),
F.tensor([0, 1], F.int64))
assert F.array_equal(F.tensor(sg.nodes['game'].data[dgl.NID]),
F.tensor([0], F.int64))
else:
for ntype in sg.ntypes:
assert g.number_of_nodes(ntype) == sg.number_of_nodes(ntype)
assert F.array_equal(F.tensor(sg.edges['plays'].data[dgl.EID]),
F.tensor([0, 1], F.int64))
sg1_graph = g_graph.subgraph([1, 2])
_check_subgraph_single_ntype(g_graph, sg1_graph)
sg1_graph = g_graph.edge_subgraph([1])
_check_subgraph_single_ntype(g_graph, sg1_graph)
sg1_graph = g_graph.edge_subgraph([1], preserve_nodes=True)
_check_subgraph_single_ntype(g_graph, sg1_graph, True)
sg2_bipartite = g_bipartite.edge_subgraph([0, 1])
_check_subgraph_single_etype(g_bipartite, sg2_bipartite)
sg2_bipartite = g_bipartite.edge_subgraph([0, 1], preserve_nodes=True)
_check_subgraph_single_etype(g_bipartite, sg2_bipartite, True)
def _check_typed_subgraph1(g, sg):
assert set(sg.ntypes) == {'user', 'game'}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment