Unverified Commit 1af0d806 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[bugfix] Fix the behavior of dgl.batch when batch size is 1 (#1483)



* fix

* upd

* upd

* upd

* upd

* clone

* upd
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent ef7e4750
...@@ -4144,7 +4144,13 @@ def batch(graph_list, node_attrs=ALL, edge_attrs=ALL): ...@@ -4144,7 +4144,13 @@ def batch(graph_list, node_attrs=ALL, edge_attrs=ALL):
unbatch unbatch
""" """
if len(graph_list) == 1: if len(graph_list) == 1:
return graph_list[0] # Need to deepcopy the node/edge frame of original graph.
graph = graph_list[0]
return DGLGraph(graph_data=graph._graph,
node_frame=graph._node_frame.deepclone(),
edge_frame=graph._edge_frame.deepclone(),
batch_num_nodes=graph.batch_num_nodes,
batch_num_edges=graph.batch_num_edges)
def _init_attrs(attrs, mode): def _init_attrs(attrs, mode):
"""Collect attributes of given mode (node/edge) from graph_list. """Collect attributes of given mode (node/edge) from graph_list.
...@@ -4263,7 +4269,10 @@ def unbatch(graph): ...@@ -4263,7 +4269,10 @@ def unbatch(graph):
batch batch
""" """
if graph.batch_size == 1: if graph.batch_size == 1:
return [graph] # Like dgl.batch, unbatch also deep copies data frame.
return [DGLGraph(graph_data=graph._graph,
node_frame=graph._node_frame.deepclone(),
edge_frame=graph._edge_frame.deepclone())]
bsize = graph.batch_size bsize = graph.batch_size
bnn = graph.batch_num_nodes bnn = graph.batch_num_nodes
......
...@@ -102,6 +102,43 @@ def test_batch_unbatch1(): ...@@ -102,6 +102,43 @@ def test_batch_unbatch1():
assert F.allclose(t2.ndata['h'], rs2.ndata['h']) assert F.allclose(t2.ndata['h'], rs2.ndata['h'])
assert F.allclose(t2.edata['h'], rs2.edata['h']) assert F.allclose(t2.edata['h'], rs2.edata['h'])
def test_batch_unbatch_frame():
"""Test module of node/edge frames of batched/unbatched DGLGraphs.
Also address the bug mentioned in https://github.com/dmlc/dgl/issues/1475.
"""
t1 = tree1()
t2 = tree2()
N1 = t1.number_of_nodes()
E1 = t1.number_of_edges()
N2 = t2.number_of_nodes()
E2 = t2.number_of_edges()
D = 10
t1.ndata['h'] = F.randn((N1, D))
t1.edata['h'] = F.randn((E1, D))
t2.ndata['h'] = F.randn((N2, D))
t2.edata['h'] = F.randn((E2, D))
if F.backend_name != 'tensorflow': # tf's tensor is immutable
b1 = dgl.batch([t1, t2])
b2 = dgl.batch([t2])
b1.ndata['h'][:N1] = F.zeros((N1, D))
b1.edata['h'][:E1] = F.zeros((E1, D))
b2.ndata['h'][:N2] = F.zeros((N2, D))
b2.edata['h'][:E2] = F.zeros((E2, D))
assert not F.allclose(t1.ndata['h'], F.zeros((N1, D)))
assert not F.allclose(t1.edata['h'], F.zeros((E1, D)))
assert not F.allclose(t2.ndata['h'], F.zeros((N2, D)))
assert not F.allclose(t2.edata['h'], F.zeros((E2, D)))
g1, g2 = dgl.unbatch(b1)
_g2, = dgl.unbatch(b2)
assert F.allclose(g1.ndata['h'], F.zeros((N1, D)))
assert F.allclose(g1.edata['h'], F.zeros((E1, D)))
assert F.allclose(g2.ndata['h'], t2.ndata['h'])
assert F.allclose(g2.edata['h'], t2.edata['h'])
assert F.allclose(_g2.ndata['h'], F.zeros((N2, D)))
assert F.allclose(_g2.edata['h'], F.zeros((E2, D)))
def test_batch_unbatch2(): def test_batch_unbatch2():
# test setting/getting features after batch # test setting/getting features after batch
a = dgl.DGLGraph() a = dgl.DGLGraph()
...@@ -206,6 +243,7 @@ def test_batch_no_edge(): ...@@ -206,6 +243,7 @@ def test_batch_no_edge():
if __name__ == '__main__': if __name__ == '__main__':
test_batch_unbatch() test_batch_unbatch()
test_batch_unbatch1() test_batch_unbatch1()
test_batch_unbatch_frame()
#test_batch_unbatch2() #test_batch_unbatch2()
#test_batched_edge_ordering() #test_batched_edge_ordering()
#test_batch_send_then_recv() #test_batch_send_then_recv()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment