Unverified Commit bce51cbd authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[BUGFIX] check the number of columns in subgraph.copy_from_parent (#903)

* check the number of columns.

* add test.
parent 3a0bbb3e
...@@ -116,10 +116,10 @@ class DGLSubGraph(DGLGraph): ...@@ -116,10 +116,10 @@ class DGLSubGraph(DGLGraph):
All old features will be removed. All old features will be removed.
""" """
if self._parent._node_frame.num_rows != 0: if self._parent._node_frame.num_rows != 0 and self._parent._node_frame.num_columns != 0:
self._node_frame = FrameRef(Frame( self._node_frame = FrameRef(Frame(
self._parent._node_frame[self._parent_nid])) self._parent._node_frame[self._parent_nid]))
if self._parent._edge_frame.num_rows != 0: if self._parent._edge_frame.num_rows != 0 and self._parent._edge_frame.num_columns != 0:
self._edge_frame = FrameRef(Frame( self._edge_frame = FrameRef(Frame(
self._parent._edge_frame[self._get_parent_eid()])) self._parent._edge_frame[self._get_parent_eid()]))
......
...@@ -4,7 +4,7 @@ import backend as F ...@@ -4,7 +4,7 @@ import backend as F
D = 5 D = 5
def generate_graph(grad=False): def generate_graph(grad=False, add_data=True):
g = DGLGraph() g = DGLGraph()
g.add_nodes(10) g.add_nodes(10)
# create a graph where 0 is the source and 9 is the sink # create a graph where 0 is the source and 9 is the sink
...@@ -13,6 +13,7 @@ def generate_graph(grad=False): ...@@ -13,6 +13,7 @@ def generate_graph(grad=False):
g.add_edge(i, 9) g.add_edge(i, 9)
# add a back flow from 9 to 0 # add a back flow from 9 to 0
g.add_edge(9, 0) g.add_edge(9, 0)
if add_data:
ncol = F.randn((10, D)) ncol = F.randn((10, D))
ecol = F.randn((17, D)) ecol = F.randn((17, D))
if grad: if grad:
...@@ -22,6 +23,15 @@ def generate_graph(grad=False): ...@@ -22,6 +23,15 @@ def generate_graph(grad=False):
g.edata['l'] = ecol g.edata['l'] = ecol
return g return g
def test_basics1():
# Test when the graph has no node data and edge data.
g = generate_graph(add_data=False)
eid = [0, 2, 3, 6, 7, 9]
sg = g.edge_subgraph(eid)
sg.copy_from_parent()
sg.ndata['h'] = F.arange(0, sg.number_of_nodes())
sg.edata['h'] = F.arange(0, sg.number_of_edges())
def test_basics(): def test_basics():
g = generate_graph() g = generate_graph()
h = g.ndata['h'] h = g.ndata['h']
...@@ -96,4 +106,5 @@ def test_merge(): ...@@ -96,4 +106,5 @@ def test_merge():
if __name__ == '__main__': if __name__ == '__main__':
test_basics() test_basics()
test_basics1()
#test_merge() #test_merge()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment