Unverified Commit 07787664 authored by Andrew's avatar Andrew Committed by GitHub
Browse files

[Bug fix] [Feature] added option for batching empty data (#2527)



* added option for batching empty data, fixes #2526

* added option for batching empty data, fixes #2526

* decreased line lengths

* removed trailing whitespace

* fixed wrong feature name

* now default behavior when all graphs are empty
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent 4d89b54e
...@@ -11,7 +11,8 @@ from . import utils ...@@ -11,7 +11,8 @@ from . import utils
__all__ = ['batch', 'unbatch', 'batch_hetero', 'unbatch_hetero'] __all__ = ['batch', 'unbatch', 'batch_hetero', 'unbatch_hetero']
def batch(graphs, ndata=ALL, edata=ALL, *, node_attrs=None, edge_attrs=None): def batch(graphs, ndata=ALL, edata=ALL, *,
node_attrs=None, edge_attrs=None):
r"""Batch a collection of :class:`DGLGraph` s into one graph for more efficient r"""Batch a collection of :class:`DGLGraph` s into one graph for more efficient
graph computation. graph computation.
...@@ -191,9 +192,10 @@ def batch(graphs, ndata=ALL, edata=ALL, *, node_attrs=None, edge_attrs=None): ...@@ -191,9 +192,10 @@ def batch(graphs, ndata=ALL, edata=ALL, *, node_attrs=None, edge_attrs=None):
# Batch node feature # Batch node feature
if ndata is not None: if ndata is not None:
for ntype_id, ntype in zip(ntype_ids, ntypes): for ntype_id, ntype in zip(ntype_ids, ntypes):
all_empty = all(g._graph.number_of_nodes(ntype_id) == 0 for g in graphs)
frames = [ frames = [
g._node_frames[ntype_id] for g in graphs g._node_frames[ntype_id] for g in graphs
if g._graph.number_of_nodes(ntype_id) > 0] if g._graph.number_of_nodes(ntype_id) > 0 or all_empty]
# TODO: do we require graphs with no nodes/edges to have the same schema? Currently # TODO: do we require graphs with no nodes/edges to have the same schema? Currently
# we allow empty graphs to have no features during batching. # we allow empty graphs to have no features during batching.
ret_feat = _batch_feat_dicts(frames, ndata, 'nodes["{}"].data'.format(ntype)) ret_feat = _batch_feat_dicts(frames, ndata, 'nodes["{}"].data'.format(ntype))
...@@ -202,9 +204,10 @@ def batch(graphs, ndata=ALL, edata=ALL, *, node_attrs=None, edge_attrs=None): ...@@ -202,9 +204,10 @@ def batch(graphs, ndata=ALL, edata=ALL, *, node_attrs=None, edge_attrs=None):
# Batch edge feature # Batch edge feature
if edata is not None: if edata is not None:
for etype_id, etype in zip(relation_ids, relations): for etype_id, etype in zip(relation_ids, relations):
all_empty = all(g._graph.number_of_edges(etype_id) == 0 for g in graphs)
frames = [ frames = [
g._edge_frames[etype_id] for g in graphs g._edge_frames[etype_id] for g in graphs
if g._graph.number_of_edges(etype_id) > 0] if g._graph.number_of_edges(etype_id) > 0 or all_empty]
# TODO: do we require graphs with no nodes/edges to have the same schema? Currently # TODO: do we require graphs with no nodes/edges to have the same schema? Currently
# we allow empty graphs to have no features during batching. # we allow empty graphs to have no features during batching.
ret_feat = _batch_feat_dicts(frames, edata, 'edges[{}].data'.format(etype)) ret_feat = _batch_feat_dicts(frames, edata, 'edges[{}].data'.format(etype))
......
...@@ -207,6 +207,18 @@ def test_batch_no_edge(idtype): ...@@ -207,6 +207,18 @@ def test_batch_no_edge(idtype):
g3.add_nodes(1) # no edges g3.add_nodes(1) # no edges
g = dgl.batch([g1, g3, g2]) # should not throw an error g = dgl.batch([g1, g3, g2]) # should not throw an error
@parametrize_dtype
def test_batch_keeps_empty_data(idtype):
g1 = dgl.graph(([], [])).astype(idtype).to(F.ctx())
g1.ndata["nh"] = F.tensor([])
g1.edata["eh"] = F.tensor([])
g2 = dgl.graph(([], [])).astype(idtype).to(F.ctx())
g2.ndata["nh"] = F.tensor([])
g2.edata["eh"] = F.tensor([])
g = dgl.batch([g1, g2])
assert "nh" in g.ndata
assert "eh" in g.edata
def _get_subgraph_batch_info(keys, induced_indices_arr, batch_num_objs): def _get_subgraph_batch_info(keys, induced_indices_arr, batch_num_objs):
"""Internal function to compute batch information for subgraphs. """Internal function to compute batch information for subgraphs.
Parameters Parameters
......
...@@ -321,6 +321,18 @@ def test_unbatch2(idtype): ...@@ -321,6 +321,18 @@ def test_unbatch2(idtype):
check_graph_equal(g2, gg2) check_graph_equal(g2, gg2)
check_graph_equal(g3, gg3) check_graph_equal(g3, gg3)
@parametrize_dtype
def test_batch_keeps_empty_data(idtype):
g1 = dgl.heterograph({("a", "to", "a"): ([], [])}).astype(idtype).to(F.ctx())
g1.nodes["a"].data["nh"] = F.tensor([])
g1.edges[("a", "to", "a")].data["eh"] = F.tensor([])
g2 = dgl.heterograph({("a", "to", "a"): ([], [])}).astype(idtype).to(F.ctx())
g2.nodes["a"].data["nh"] = F.tensor([])
g2.edges[("a", "to", "a")].data["eh"] = F.tensor([])
g = dgl.batch([g1, g2])
assert "nh" in g.nodes["a"].data
assert "eh" in g.edges[("a", "to", "a")].data
if __name__ == '__main__': if __name__ == '__main__':
#test_topology('int32') #test_topology('int32')
#test_batching_batched('int32') #test_batching_batched('int32')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment