Unverified Commit a6d5a0cb authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[BUG] Fix #717 (#946)

* [BUG] Fix #717

* fix mxnet test
parent 421b05e7
...@@ -46,6 +46,7 @@ class DGLSubGraph(DGLGraph): ...@@ -46,6 +46,7 @@ class DGLSubGraph(DGLGraph):
self._parent = parent self._parent = parent
self._parent_nid = sgi.induced_nodes self._parent_nid = sgi.induced_nodes
self._parent_eid = sgi.induced_edges self._parent_eid = sgi.induced_edges
self._subgraph_index = sgi
# override APIs # override APIs
def add_nodes(self, num, data=None): def add_nodes(self, num, data=None):
...@@ -136,4 +137,5 @@ class DGLSubGraph(DGLGraph): ...@@ -136,4 +137,5 @@ class DGLSubGraph(DGLGraph):
tensor tensor
The node ID array in the subgraph. The node ID array in the subgraph.
""" """
return map_to_subgraph_nid(self._graph, utils.toindex(parent_vids)).tousertensor() v = map_to_subgraph_nid(self._subgraph_index, utils.toindex(parent_vids))
return v.tousertensor()
...@@ -76,6 +76,14 @@ def test_basics(): ...@@ -76,6 +76,14 @@ def test_basics():
sg.ndata['h'] = F.zeros((6, D)) sg.ndata['h'] = F.zeros((6, D))
assert F.allclose(h, g.ndata['h']) assert F.allclose(h, g.ndata['h'])
def test_map_to_subgraph():
g = DGLGraph()
g.add_nodes(10)
g.add_edges(F.arange(0, 9), F.arange(1, 10))
h = g.subgraph([0, 1, 2, 5, 8])
v = h.map_to_subgraph_nid([0, 8, 2])
assert np.array_equal(F.asnumpy(v), np.array([0, 4, 2]))
def test_merge(): def test_merge():
# FIXME: current impl cannot handle this case!!! # FIXME: current impl cannot handle this case!!!
# comment out for now to test CI # comment out for now to test CI
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment