Commit 4dfe7547 authored by Minjie Wang's avatar Minjie Wang Committed by Zihao Ye
Browse files

Fix num_rows bug in batched_graph (#169)

parent 23e2e83b
......@@ -31,16 +31,22 @@ class BatchedDGLGraph(DGLGraph):
# create batched graph index
batched_index = gi.disjoint_union([g._graph for g in graph_list])
# create batched node and edge frames
# NOTE: following code will materialize the columns of the input graphs.
cols = {key: F.cat([gr._node_frame[key] for gr in graph_list
if gr.number_of_nodes() > 0], dim=0)
for key in node_attrs}
batched_node_frame = FrameRef(Frame(cols))
cols = {key: F.cat([gr._edge_frame[key] for gr in graph_list
if gr.number_of_edges() > 0], dim=0)
for key in edge_attrs}
batched_edge_frame = FrameRef(Frame(cols))
if len(node_attrs) == 0:
batched_node_frame = FrameRef(Frame(num_rows=batched_index.number_of_nodes()))
else:
# NOTE: following code will materialize the columns of the input graphs.
cols = {key: F.cat([gr._node_frame[key] for gr in graph_list
if gr.number_of_nodes() > 0], dim=0)
for key in node_attrs}
batched_node_frame = FrameRef(Frame(cols))
if len(edge_attrs) == 0:
batched_edge_frame = FrameRef(Frame(num_rows=batched_index.number_of_edges()))
else:
cols = {key: F.cat([gr._edge_frame[key] for gr in graph_list
if gr.number_of_edges() > 0], dim=0)
for key in edge_attrs}
batched_edge_frame = FrameRef(Frame(cols))
super(BatchedDGLGraph, self).__init__(
graph_data=batched_index,
......
......@@ -76,6 +76,20 @@ def test_batch_unbatch1():
assert U.allclose(t2.ndata['h'], s3.ndata['h'])
assert U.allclose(t2.edata['h'], s3.edata['h'])
def test_batch_unbatch2():
# test setting/getting features after batch
a = dgl.DGLGraph()
a.add_nodes(4)
a.add_edges(0, [1, 2, 3])
b = dgl.DGLGraph()
b.add_nodes(3)
b.add_edges(0, [1, 2])
c = dgl.batch([a, b])
c.ndata['h'] = th.ones(7, 1)
c.edata['w'] = th.ones(5, 1)
assert U.allclose(c.ndata['h'], th.ones(7, 1))
assert U.allclose(c.edata['w'], th.ones(5, 1))
def test_batch_send_then_recv():
t1 = tree1()
t2 = tree2()
......@@ -166,6 +180,7 @@ def test_batch_no_edge():
if __name__ == '__main__':
test_batch_unbatch()
test_batch_unbatch1()
test_batch_unbatch2()
test_batched_edge_ordering()
test_batch_send_then_recv()
test_batch_send_and_recv()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment