Unverified Commit bce51cbd authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[BUGFIX] check the number of columns in subgraph.copy_from_parent (#903)

* check the number of columns.

* add test.
parent 3a0bbb3e
......@@ -116,10 +116,10 @@ class DGLSubGraph(DGLGraph):
All old features will be removed.
"""
if self._parent._node_frame.num_rows != 0:
if self._parent._node_frame.num_rows != 0 and self._parent._node_frame.num_columns != 0:
self._node_frame = FrameRef(Frame(
self._parent._node_frame[self._parent_nid]))
if self._parent._edge_frame.num_rows != 0:
if self._parent._edge_frame.num_rows != 0 and self._parent._edge_frame.num_columns != 0:
self._edge_frame = FrameRef(Frame(
self._parent._edge_frame[self._get_parent_eid()]))
......
......@@ -4,7 +4,7 @@ import backend as F
D = 5
def generate_graph(grad=False):
def generate_graph(grad=False, add_data=True):
g = DGLGraph()
g.add_nodes(10)
# create a graph where 0 is the source and 9 is the sink
......@@ -13,15 +13,25 @@ def generate_graph(grad=False):
g.add_edge(i, 9)
# add a back flow from 9 to 0
g.add_edge(9, 0)
ncol = F.randn((10, D))
ecol = F.randn((17, D))
if grad:
ncol = F.attach_grad(ncol)
ecol = F.attach_grad(ecol)
g.ndata['h'] = ncol
g.edata['l'] = ecol
if add_data:
ncol = F.randn((10, D))
ecol = F.randn((17, D))
if grad:
ncol = F.attach_grad(ncol)
ecol = F.attach_grad(ecol)
g.ndata['h'] = ncol
g.edata['l'] = ecol
return g
def test_basics1():
# Test when the graph has no node data and edge data.
g = generate_graph(add_data=False)
eid = [0, 2, 3, 6, 7, 9]
sg = g.edge_subgraph(eid)
sg.copy_from_parent()
sg.ndata['h'] = F.arange(0, sg.number_of_nodes())
sg.edata['h'] = F.arange(0, sg.number_of_edges())
def test_basics():
g = generate_graph()
h = g.ndata['h']
......@@ -96,4 +106,5 @@ def test_merge():
if __name__ == '__main__':
test_basics()
test_basics1()
#test_merge()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment