"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3b48620f5ea204901b6b0e8f089a9ddd5022e0ee"
Commit 4dfe7547 authored by Minjie Wang's avatar Minjie Wang Committed by Zihao Ye
Browse files

Fix num_rows bug in batched_graph (#169)

parent 23e2e83b
...@@ -31,12 +31,18 @@ class BatchedDGLGraph(DGLGraph): ...@@ -31,12 +31,18 @@ class BatchedDGLGraph(DGLGraph):
# create batched graph index # create batched graph index
batched_index = gi.disjoint_union([g._graph for g in graph_list]) batched_index = gi.disjoint_union([g._graph for g in graph_list])
# create batched node and edge frames # create batched node and edge frames
if len(node_attrs) == 0:
batched_node_frame = FrameRef(Frame(num_rows=batched_index.number_of_nodes()))
else:
# NOTE: following code will materialize the columns of the input graphs. # NOTE: following code will materialize the columns of the input graphs.
cols = {key: F.cat([gr._node_frame[key] for gr in graph_list cols = {key: F.cat([gr._node_frame[key] for gr in graph_list
if gr.number_of_nodes() > 0], dim=0) if gr.number_of_nodes() > 0], dim=0)
for key in node_attrs} for key in node_attrs}
batched_node_frame = FrameRef(Frame(cols)) batched_node_frame = FrameRef(Frame(cols))
if len(edge_attrs) == 0:
batched_edge_frame = FrameRef(Frame(num_rows=batched_index.number_of_edges()))
else:
cols = {key: F.cat([gr._edge_frame[key] for gr in graph_list cols = {key: F.cat([gr._edge_frame[key] for gr in graph_list
if gr.number_of_edges() > 0], dim=0) if gr.number_of_edges() > 0], dim=0)
for key in edge_attrs} for key in edge_attrs}
......
...@@ -76,6 +76,20 @@ def test_batch_unbatch1(): ...@@ -76,6 +76,20 @@ def test_batch_unbatch1():
assert U.allclose(t2.ndata['h'], s3.ndata['h']) assert U.allclose(t2.ndata['h'], s3.ndata['h'])
assert U.allclose(t2.edata['h'], s3.edata['h']) assert U.allclose(t2.edata['h'], s3.edata['h'])
def test_batch_unbatch2():
# test setting/getting features after batch
a = dgl.DGLGraph()
a.add_nodes(4)
a.add_edges(0, [1, 2, 3])
b = dgl.DGLGraph()
b.add_nodes(3)
b.add_edges(0, [1, 2])
c = dgl.batch([a, b])
c.ndata['h'] = th.ones(7, 1)
c.edata['w'] = th.ones(5, 1)
assert U.allclose(c.ndata['h'], th.ones(7, 1))
assert U.allclose(c.edata['w'], th.ones(5, 1))
def test_batch_send_then_recv(): def test_batch_send_then_recv():
t1 = tree1() t1 = tree1()
t2 = tree2() t2 = tree2()
...@@ -166,6 +180,7 @@ def test_batch_no_edge(): ...@@ -166,6 +180,7 @@ def test_batch_no_edge():
if __name__ == '__main__': if __name__ == '__main__':
test_batch_unbatch() test_batch_unbatch()
test_batch_unbatch1() test_batch_unbatch1()
test_batch_unbatch2()
test_batched_edge_ordering() test_batched_edge_ordering()
test_batch_send_then_recv() test_batch_send_then_recv()
test_batch_send_and_recv() test_batch_send_and_recv()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment