Unverified Commit 00edb949 authored by Quan (Andy) Gan, committed by GitHub

[Performance] Accelerate batching (#2363)

* speed up batching

* more fix

* lint

* fix
parent 58775ada
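
For context, the sketch below shows the kind of call this commit accelerates: batching a list of heterographs and concatenating their per-type node and edge features. It is illustrative only; the graph layout, feature names (`'h'`, `'w'`), and sizes are invented here and assume the DGL 0.5+ heterograph API — none of it is part of the diff itself.

```python
# Illustrative sketch only (not part of the diff); feature names and sizes are made up.
import torch
import dgl

def make_graph(num_users, num_follows):
    src = torch.randint(0, num_users, (num_follows,))
    dst = torch.randint(0, num_users, (num_follows,))
    g = dgl.heterograph({('user', 'follows', 'user'): (src, dst)},
                        num_nodes_dict={'user': num_users})
    g.nodes['user'].data['h'] = torch.randn(num_users, 16)      # node feature
    g.edges['follows'].data['w'] = torch.randn(num_follows, 1)  # edge feature
    return g

graphs = [make_graph(5, 8), make_graph(3, 4)]
bg = dgl.batch(graphs)                      # the call path touched by this commit
print(bg.num_nodes('user'))                 # 8
print(bg.nodes['user'].data['h'].shape)     # torch.Size([8, 16])
```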
@@ -168,7 +168,9 @@ def batch(graphs, ndata=ALL, edata=ALL, *, node_attrs=None, edge_attrs=None):
         raise DGLError("Batching a block is not supported.")
     relations = list(sorted(graphs[0].canonical_etypes))
+    relation_ids = [graphs[0].get_etype_id(r) for r in relations]
     ntypes = list(sorted(graphs[0].ntypes))
+    ntype_ids = [graphs[0].get_ntype_id(n) for n in ntypes]
     etypes = [etype for _, etype, _ in relations]
     gidx = disjoint_union(graphs[0]._graph.metagraph, [g._graph for g in graphs])
@@ -188,27 +190,35 @@ def batch(graphs, ndata=ALL, edata=ALL, *, node_attrs=None, edge_attrs=None):
     # Batch node feature
     if ndata is not None:
-        for ntype in ntypes:
-            feat_dicts = [g.nodes[ntype].data for g in graphs if g.number_of_nodes(ntype) > 0]
-            ret_feat = _batch_feat_dicts(feat_dicts, ndata, 'nodes["{}"].data'.format(ntype))
+        for ntype_id, ntype in zip(ntype_ids, ntypes):
+            frames = [
+                g._node_frames[ntype_id] for g in graphs
+                if g._graph.number_of_nodes(ntype_id) > 0]
+            # TODO: do we require graphs with no nodes/edges to have the same schema? Currently
+            # we allow empty graphs to have no features during batching.
+            ret_feat = _batch_feat_dicts(frames, ndata, 'nodes["{}"].data'.format(ntype))
             retg.nodes[ntype].data.update(ret_feat)
 
     # Batch edge feature
     if edata is not None:
-        for etype in relations:
-            feat_dicts = [g.edges[etype].data for g in graphs if g.number_of_edges(etype) > 0]
-            ret_feat = _batch_feat_dicts(feat_dicts, edata, 'edges[{}].data'.format(etype))
+        for etype_id, etype in zip(relation_ids, relations):
+            frames = [
+                g._edge_frames[etype_id] for g in graphs
+                if g._graph.number_of_edges(etype_id) > 0]
+            # TODO: do we require graphs with no nodes/edges to have the same schema? Currently
+            # we allow empty graphs to have no features during batching.
+            ret_feat = _batch_feat_dicts(frames, edata, 'edges[{}].data'.format(etype))
             retg.edges[etype].data.update(ret_feat)
 
     return retg
 
-def _batch_feat_dicts(feat_dicts, keys, feat_dict_name):
+def _batch_feat_dicts(frames, keys, feat_dict_name):
     """Internal function to batch feature dictionaries.
 
     Parameters
     ----------
-    feat_dicts : list[dict[str, Tensor]]
-        Feature dictionary list.
+    frames : list[Frame]
+        List of frames
     keys : list[str]
         Feature keys. Can be '__ALL__', meaning batching all features.
     feat_dict_name : str
@@ -219,17 +229,17 @@ def _batch_feat_dicts(feat_dicts, keys, feat_dict_name):
     dict[str, Tensor]
         New feature dict.
     """
-    if len(feat_dicts) == 0:
+    if len(frames) == 0:
         return {}
+    schemas = [frame.schemes for frame in frames]
     # sanity checks
     if is_all(keys):
-        utils.check_all_same_keys(feat_dicts, feat_dict_name)
-        keys = feat_dicts[0].keys()
+        utils.check_all_same_schema(schemas, feat_dict_name)
+        keys = schemas[0].keys()
     else:
-        utils.check_all_have_keys(feat_dicts, keys, feat_dict_name)
-        utils.check_all_same_schema(feat_dicts, keys, feat_dict_name)
+        utils.check_all_same_schema_for_keys(schemas, keys, feat_dict_name)
     # concat features
-    ret_feat = {k : F.cat([fd[k] for fd in feat_dicts], 0) for k in keys}
+    ret_feat = {k : F.cat([fd[k] for fd in frames], 0) for k in keys}
     return ret_feat
 
 def unbatch(g, node_split=None, edge_split=None):
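
The gist of the change to `batch()` and `_batch_feat_dicts()` above: instead of materializing `g.nodes[ntype].data` / `g.edges[etype].data` views for every input graph, the loops now index the underlying feature frames by type id (`g._node_frames[ntype_id]`, `g._edge_frames[etype_id]`) and hand them to `_batch_feat_dicts`, which validates one schema per frame rather than comparing every tensor pairwise. Below is a rough stand-alone sketch of that schema-first pattern — not DGL's internal code; `Scheme`, `infer_schema`, and `batch_feats` are hypothetical names used only for illustration.

```python
# Hypothetical stand-alone sketch of the schema-first batching pattern.
from collections import namedtuple
import torch

Scheme = namedtuple('Scheme', ['shape', 'dtype'])        # per-feature summary

def infer_schema(feat_dict):
    # One small dict per graph: feature name -> (per-row shape, dtype).
    return {k: Scheme(tuple(v.shape[1:]), v.dtype) for k, v in feat_dict.items()}

def batch_feats(feat_dicts):
    schemas = [infer_schema(fd) for fd in feat_dicts]
    # A single dict-equality test per graph replaces pairwise dtype/shape checks.
    if any(s != schemas[0] for s in schemas):
        raise ValueError('feature schemas differ across graphs')
    return {k: torch.cat([fd[k] for fd in feat_dicts], dim=0) for k in schemas[0]}

a = {'h': torch.randn(5, 16)}
b = {'h': torch.randn(3, 16)}
print(batch_feats([a, b])['h'].shape)   # torch.Size([8, 16])
```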
@@ -119,51 +119,41 @@ def check_all_same_device(glist, name):
         raise DGLError('Expect {}[{}] to be on device {}, but got {}.'.format(
             name, i, device, g.device))
 
-def check_all_same_keys(dict_list, name):
-    """Check all the dictionaries have the same set of keys."""
-    if len(dict_list) == 0:
+def check_all_same_schema(schemas, name):
+    """Check the list of schemas are the same."""
+    if len(schemas) == 0:
         return
-    keys = dict_list[0].keys()
-    for dct in dict_list:
-        if keys != dct.keys():
-            raise DGLError('Expect all {} to have the same set of keys, but got'
-                           ' {} and {}.'.format(name, keys, dct.keys()))
-
-def check_all_have_keys(dict_list, keys, name):
-    """Check the dictionaries all have the given keys."""
-    if len(dict_list) == 0:
-        return
-    keys = set(keys)
-    for dct in dict_list:
-        if not keys.issubset(dct.keys()):
-            raise DGLError('Expect all {} to include keys {}, but got {}.'.format(
-                name, keys, dct.keys()))
-
-def check_all_same_schema(feat_dict_list, keys, name):
-    """Check the features of the given keys all have the same schema.
-
-    Suggest calling ``check_all_have_keys`` first.
+    for i, schema in enumerate(schemas):
+        if schema != schemas[0]:
+            raise DGLError(
+                'Expect all graphs to have the same schema on {}, '
+                'but graph {} got\n\t{}\nwhich is different from\n\t{}.'.format(
+                    name, i, schema, schemas[0]))
 
-    Parameters
-    ----------
-    feat_dict_list : list[dict[str, Tensor]]
-        Feature dictionaries.
-    keys : list[str]
-        Keys
-    name : str
-        Name of this feature dict.
-    """
-    if len(feat_dict_list) == 0:
+def check_all_same_schema_for_keys(schemas, keys, name):
+    """Check the list of schemas are the same on the given keys."""
+    if len(schemas) == 0:
         return
-    for fdict in feat_dict_list:
-        for k in keys:
-            t1 = feat_dict_list[0][k]
-            t2 = fdict[k]
-            if F.dtype(t1) != F.dtype(t2) or F.shape(t1)[1:] != F.shape(t2)[1:]:
-                raise DGLError('Expect all features {}["{}"] to have the same data type'
-                               ' and feature size, but got\n\t{} {}\nand\n\t{} {}.'.format(
-                                   name, k, F.dtype(t1), F.shape(t1)[1:],
-                                   F.dtype(t2), F.shape(t2)[1:]))
+    head = None
+    keys = set(keys)
+    for i, schema in enumerate(schemas):
+        if not keys.issubset(schema.keys()):
+            raise DGLError(
+                'Expect all graphs to have keys {} on {}, '
+                'but graph {} got keys {}.'.format(
+                    keys, name, i, schema.keys()))
+        if head is None:
+            head = {k: schema[k] for k in keys}
+        else:
+            target = {k: schema[k] for k in keys}
+            if target != head:
+                raise DGLError(
+                    'Expect all graphs to have the same schema for keys {} on {}, '
+                    'but graph {} got \n\t{}\n which is different from\n\t{}.'.format(
+                        keys, name, i, target, head))
 
 def check_valid_idtype(idtype):
     """Check whether the value of the idtype argument is valid (int32/int64)
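
To see what the new checks guard against, the snippet below (illustrative only; the feature name `'h'` and the graph shapes are made up) batches two graphs whose `'h'` features have different per-row sizes. With these checks the mismatch should surface as a `DGLError` that names the offending feature store.

```python
# Illustrative failure case; 'h' and the graph shapes are made up.
import torch
import dgl
from dgl.base import DGLError

g1 = dgl.graph((torch.tensor([0]), torch.tensor([1])), num_nodes=2)
g2 = dgl.graph((torch.tensor([0]), torch.tensor([1])), num_nodes=2)
g1.ndata['h'] = torch.randn(2, 4)
g2.ndata['h'] = torch.randn(2, 8)    # different per-row feature size

try:
    dgl.batch([g1, g2])
except DGLError as err:
    print('schema mismatch detected:', err)
```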