Unverified Commit a566b60b authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files
parent d1827488
...@@ -14,26 +14,26 @@ try: ...@@ -14,26 +14,26 @@ try:
if _FFI_MODE == "ctypes": if _FFI_MODE == "ctypes":
raise ImportError() raise ImportError()
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
from ._cy3.core import FunctionBase as _FunctionBase
from ._cy3.core import ( from ._cy3.core import (
_set_class_function, _set_class_function,
_set_class_module, _set_class_module,
convert_to_dgl_func, convert_to_dgl_func,
FunctionBase as _FunctionBase,
) )
else: else:
from ._cy2.core import FunctionBase as _FunctionBase
from ._cy2.core import ( from ._cy2.core import (
_set_class_function, _set_class_function,
_set_class_module, _set_class_module,
convert_to_dgl_func, convert_to_dgl_func,
FunctionBase as _FunctionBase,
) )
except IMPORT_EXCEPT: except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
from ._ctypes.function import FunctionBase as _FunctionBase
from ._ctypes.function import ( from ._ctypes.function import (
_set_class_function, _set_class_function,
_set_class_module, _set_class_module,
convert_to_dgl_func, convert_to_dgl_func,
FunctionBase as _FunctionBase,
) )
FunctionHandle = ctypes.c_void_p FunctionHandle = ctypes.c_void_p
......
...@@ -9,12 +9,12 @@ import numpy as np ...@@ -9,12 +9,12 @@ import numpy as np
from .base import _FFI_MODE, _LIB, c_array, c_str, check_call, string_types from .base import _FFI_MODE, _LIB, c_array, c_str, check_call, string_types
from .runtime_ctypes import ( from .runtime_ctypes import (
dgl_shape_index_t,
DGLArray, DGLArray,
DGLArrayHandle, DGLArrayHandle,
DGLContext, DGLContext,
DGLDataType, DGLDataType,
TypeCode, TypeCode,
dgl_shape_index_t,
) )
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
...@@ -24,29 +24,29 @@ try: ...@@ -24,29 +24,29 @@ try:
if _FFI_MODE == "ctypes": if _FFI_MODE == "ctypes":
raise ImportError() raise ImportError()
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
from ._cy3.core import NDArrayBase as _NDArrayBase
from ._cy3.core import ( from ._cy3.core import (
_from_dlpack, _from_dlpack,
_make_array, _make_array,
_reg_extension, _reg_extension,
_set_class_ndarray, _set_class_ndarray,
NDArrayBase as _NDArrayBase,
) )
else: else:
from ._cy2.core import NDArrayBase as _NDArrayBase
from ._cy2.core import ( from ._cy2.core import (
_from_dlpack, _from_dlpack,
_make_array, _make_array,
_reg_extension, _reg_extension,
_set_class_ndarray, _set_class_ndarray,
NDArrayBase as _NDArrayBase,
) )
except IMPORT_EXCEPT: except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
from ._ctypes.ndarray import ( from ._ctypes.ndarray import (
_from_dlpack, _from_dlpack,
_make_array, _make_array,
_reg_extension, _reg_extension,
_set_class_ndarray, _set_class_ndarray,
NDArrayBase as _NDArrayBase,
) )
......
...@@ -7,7 +7,7 @@ import sys ...@@ -7,7 +7,7 @@ import sys
from .. import _api_internal from .. import _api_internal
from .base import _FFI_MODE, _LIB, c_str, check_call, py_str from .base import _FFI_MODE, _LIB, c_str, check_call, py_str
from .object_generic import ObjectGeneric, convert_to_object from .object_generic import convert_to_object, ObjectGeneric
# pylint: disable=invalid-name # pylint: disable=invalid-name
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
...@@ -16,15 +16,12 @@ try: ...@@ -16,15 +16,12 @@ try:
if _FFI_MODE == "ctypes": if _FFI_MODE == "ctypes":
raise ImportError() raise ImportError()
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
from ._cy3.core import ObjectBase as _ObjectBase from ._cy3.core import _register_object, ObjectBase as _ObjectBase
from ._cy3.core import _register_object
else: else:
from ._cy2.core import ObjectBase as _ObjectBase from ._cy2.core import _register_object, ObjectBase as _ObjectBase
from ._cy2.core import _register_object
except IMPORT_EXCEPT: except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
from ._ctypes.object import ObjectBase as _ObjectBase from ._ctypes.object import _register_object, ObjectBase as _ObjectBase
from ._ctypes.object import _register_object
def _new_object(cls): def _new_object(cls):
......
"""Utilities for batching/unbatching graphs.""" """Utilities for batching/unbatching graphs."""
from collections.abc import Mapping from collections.abc import Mapping
from . import backend as F from . import backend as F, convert, utils
from .base import ALL, is_all, DGLError, NID, EID from .base import ALL, DGLError, EID, is_all, NID
from .heterograph_index import disjoint_union, slice_gidx
from .heterograph import DGLGraph from .heterograph import DGLGraph
from . import convert from .heterograph_index import disjoint_union, slice_gidx
from . import utils
__all__ = ["batch", "unbatch", "slice_batch"]
__all__ = ['batch', 'unbatch', 'slice_batch']
def batch(graphs, ndata=ALL, edata=ALL): def batch(graphs, ndata=ALL, edata=ALL):
r"""Batch a collection of :class:`DGLGraph` s into one graph for more efficient r"""Batch a collection of :class:`DGLGraph` s into one graph for more efficient
...@@ -149,13 +148,19 @@ def batch(graphs, ndata=ALL, edata=ALL): ...@@ -149,13 +148,19 @@ def batch(graphs, ndata=ALL, edata=ALL):
unbatch unbatch
""" """
if len(graphs) == 0: if len(graphs) == 0:
raise DGLError('The input list of graphs cannot be empty.') raise DGLError("The input list of graphs cannot be empty.")
if not (is_all(ndata) or isinstance(ndata, list) or ndata is None): if not (is_all(ndata) or isinstance(ndata, list) or ndata is None):
raise DGLError('Invalid argument ndata: must be a string list but got {}.'.format( raise DGLError(
type(ndata))) "Invalid argument ndata: must be a string list but got {}.".format(
type(ndata)
)
)
if not (is_all(edata) or isinstance(edata, list) or edata is None): if not (is_all(edata) or isinstance(edata, list) or edata is None):
raise DGLError('Invalid argument edata: must be a string list but got {}.'.format( raise DGLError(
type(edata))) "Invalid argument edata: must be a string list but got {}.".format(
type(edata)
)
)
if any(g.is_block for g in graphs): if any(g.is_block for g in graphs):
raise DGLError("Batching a MFG is not supported.") raise DGLError("Batching a MFG is not supported.")
...@@ -165,7 +170,9 @@ def batch(graphs, ndata=ALL, edata=ALL): ...@@ -165,7 +170,9 @@ def batch(graphs, ndata=ALL, edata=ALL):
ntype_ids = [graphs[0].get_ntype_id(n) for n in ntypes] ntype_ids = [graphs[0].get_ntype_id(n) for n in ntypes]
etypes = [etype for _, etype, _ in relations] etypes = [etype for _, etype, _ in relations]
gidx = disjoint_union(graphs[0]._graph.metagraph, [g._graph for g in graphs]) gidx = disjoint_union(
graphs[0]._graph.metagraph, [g._graph for g in graphs]
)
retg = DGLGraph(gidx, ntypes, etypes) retg = DGLGraph(gidx, ntypes, etypes)
# Compute batch num nodes # Compute batch num nodes
...@@ -183,29 +190,42 @@ def batch(graphs, ndata=ALL, edata=ALL): ...@@ -183,29 +190,42 @@ def batch(graphs, ndata=ALL, edata=ALL):
# Batch node feature # Batch node feature
if ndata is not None: if ndata is not None:
for ntype_id, ntype in zip(ntype_ids, ntypes): for ntype_id, ntype in zip(ntype_ids, ntypes):
all_empty = all(g._graph.number_of_nodes(ntype_id) == 0 for g in graphs) all_empty = all(
g._graph.number_of_nodes(ntype_id) == 0 for g in graphs
)
frames = [ frames = [
g._node_frames[ntype_id] for g in graphs g._node_frames[ntype_id]
if g._graph.number_of_nodes(ntype_id) > 0 or all_empty] for g in graphs
if g._graph.number_of_nodes(ntype_id) > 0 or all_empty
]
# TODO: do we require graphs with no nodes/edges to have the same schema? Currently # TODO: do we require graphs with no nodes/edges to have the same schema? Currently
# we allow empty graphs to have no features during batching. # we allow empty graphs to have no features during batching.
ret_feat = _batch_feat_dicts(frames, ndata, 'nodes["{}"].data'.format(ntype)) ret_feat = _batch_feat_dicts(
frames, ndata, 'nodes["{}"].data'.format(ntype)
)
retg.nodes[ntype].data.update(ret_feat) retg.nodes[ntype].data.update(ret_feat)
# Batch edge feature # Batch edge feature
if edata is not None: if edata is not None:
for etype_id, etype in zip(relation_ids, relations): for etype_id, etype in zip(relation_ids, relations):
all_empty = all(g._graph.number_of_edges(etype_id) == 0 for g in graphs) all_empty = all(
g._graph.number_of_edges(etype_id) == 0 for g in graphs
)
frames = [ frames = [
g._edge_frames[etype_id] for g in graphs g._edge_frames[etype_id]
if g._graph.number_of_edges(etype_id) > 0 or all_empty] for g in graphs
if g._graph.number_of_edges(etype_id) > 0 or all_empty
]
# TODO: do we require graphs with no nodes/edges to have the same schema? Currently # TODO: do we require graphs with no nodes/edges to have the same schema? Currently
# we allow empty graphs to have no features during batching. # we allow empty graphs to have no features during batching.
ret_feat = _batch_feat_dicts(frames, edata, 'edges[{}].data'.format(etype)) ret_feat = _batch_feat_dicts(
frames, edata, "edges[{}].data".format(etype)
)
retg.edges[etype].data.update(ret_feat) retg.edges[etype].data.update(ret_feat)
return retg return retg
def _batch_feat_dicts(frames, keys, feat_dict_name): def _batch_feat_dicts(frames, keys, feat_dict_name):
"""Internal function to batch feature dictionaries. """Internal function to batch feature dictionaries.
...@@ -233,9 +253,10 @@ def _batch_feat_dicts(frames, keys, feat_dict_name): ...@@ -233,9 +253,10 @@ def _batch_feat_dicts(frames, keys, feat_dict_name):
else: else:
utils.check_all_same_schema_for_keys(schemas, keys, feat_dict_name) utils.check_all_same_schema_for_keys(schemas, keys, feat_dict_name)
# concat features # concat features
ret_feat = {k : F.cat([fd[k] for fd in frames], 0) for k in keys} ret_feat = {k: F.cat([fd[k] for fd in frames], 0) for k in keys}
return ret_feat return ret_feat
def unbatch(g, node_split=None, edge_split=None): def unbatch(g, node_split=None, edge_split=None):
"""Revert the batch operation by split the given graph into a list of small ones. """Revert the batch operation by split the given graph into a list of small ones.
...@@ -339,57 +360,75 @@ def unbatch(g, node_split=None, edge_split=None): ...@@ -339,57 +360,75 @@ def unbatch(g, node_split=None, edge_split=None):
num_split = None num_split = None
# Parse node_split # Parse node_split
if node_split is None: if node_split is None:
node_split = {ntype : g.batch_num_nodes(ntype) for ntype in g.ntypes} node_split = {ntype: g.batch_num_nodes(ntype) for ntype in g.ntypes}
elif not isinstance(node_split, Mapping): elif not isinstance(node_split, Mapping):
if len(g.ntypes) != 1: if len(g.ntypes) != 1:
raise DGLError('Must provide a dictionary for argument node_split when' raise DGLError(
' there are multiple node types.') "Must provide a dictionary for argument node_split when"
node_split = {g.ntypes[0] : node_split} " there are multiple node types."
)
node_split = {g.ntypes[0]: node_split}
if node_split.keys() != set(g.ntypes): if node_split.keys() != set(g.ntypes):
raise DGLError('Must specify node_split for each node type.') raise DGLError("Must specify node_split for each node type.")
for split in node_split.values(): for split in node_split.values():
if num_split is not None and num_split != len(split): if num_split is not None and num_split != len(split):
raise DGLError('All node_split and edge_split must specify the same number' raise DGLError(
' of split sizes.') "All node_split and edge_split must specify the same number"
" of split sizes."
)
num_split = len(split) num_split = len(split)
# Parse edge_split # Parse edge_split
if edge_split is None: if edge_split is None:
edge_split = {etype : g.batch_num_edges(etype) for etype in g.canonical_etypes} edge_split = {
etype: g.batch_num_edges(etype) for etype in g.canonical_etypes
}
elif not isinstance(edge_split, Mapping): elif not isinstance(edge_split, Mapping):
if len(g.etypes) != 1: if len(g.etypes) != 1:
raise DGLError('Must provide a dictionary for argument edge_split when' raise DGLError(
' there are multiple edge types.') "Must provide a dictionary for argument edge_split when"
edge_split = {g.canonical_etypes[0] : edge_split} " there are multiple edge types."
)
edge_split = {g.canonical_etypes[0]: edge_split}
if edge_split.keys() != set(g.canonical_etypes): if edge_split.keys() != set(g.canonical_etypes):
raise DGLError('Must specify edge_split for each canonical edge type.') raise DGLError("Must specify edge_split for each canonical edge type.")
for split in edge_split.values(): for split in edge_split.values():
if num_split is not None and num_split != len(split): if num_split is not None and num_split != len(split):
raise DGLError('All edge_split and edge_split must specify the same number' raise DGLError(
' of split sizes.') "All edge_split and edge_split must specify the same number"
" of split sizes."
)
num_split = len(split) num_split = len(split)
node_split = {k : F.asnumpy(split).tolist() for k, split in node_split.items()} node_split = {
edge_split = {k : F.asnumpy(split).tolist() for k, split in edge_split.items()} k: F.asnumpy(split).tolist() for k, split in node_split.items()
}
edge_split = {
k: F.asnumpy(split).tolist() for k, split in edge_split.items()
}
# Split edges for each relation # Split edges for each relation
edge_dict_per = [{} for i in range(num_split)] edge_dict_per = [{} for i in range(num_split)]
for rel in g.canonical_etypes: for rel in g.canonical_etypes:
srctype, etype, dsttype = rel srctype, etype, dsttype = rel
srcnid_off = dstnid_off = 0 srcnid_off = dstnid_off = 0
u, v = g.edges(order='eid', etype=rel) u, v = g.edges(order="eid", etype=rel)
us = F.split(u, edge_split[rel], 0) us = F.split(u, edge_split[rel], 0)
vs = F.split(v, edge_split[rel], 0) vs = F.split(v, edge_split[rel], 0)
for i, (subu, subv) in enumerate(zip(us, vs)): for i, (subu, subv) in enumerate(zip(us, vs)):
edge_dict_per[i][rel] = (subu - srcnid_off, subv - dstnid_off) edge_dict_per[i][rel] = (subu - srcnid_off, subv - dstnid_off)
srcnid_off += node_split[srctype][i] srcnid_off += node_split[srctype][i]
dstnid_off += node_split[dsttype][i] dstnid_off += node_split[dsttype][i]
num_nodes_dict_per = [{k : split[i] for k, split in node_split.items()} num_nodes_dict_per = [
for i in range(num_split)] {k: split[i] for k, split in node_split.items()}
for i in range(num_split)
]
# Create graphs # Create graphs
gs = [convert.heterograph(edge_dict, num_nodes_dict, idtype=g.idtype) gs = [
for edge_dict, num_nodes_dict in zip(edge_dict_per, num_nodes_dict_per)] convert.heterograph(edge_dict, num_nodes_dict, idtype=g.idtype)
for edge_dict, num_nodes_dict in zip(edge_dict_per, num_nodes_dict_per)
]
# Unbatch node features # Unbatch node features
for ntype in g.ntypes: for ntype in g.ntypes:
...@@ -407,6 +446,7 @@ def unbatch(g, node_split=None, edge_split=None): ...@@ -407,6 +446,7 @@ def unbatch(g, node_split=None, edge_split=None):
return gs return gs
def slice_batch(g, gid, store_ids=False): def slice_batch(g, gid, store_ids=False):
"""Get a particular graph from a batch of graphs. """Get a particular graph from a batch of graphs.
...@@ -455,7 +495,9 @@ def slice_batch(g, gid, store_ids=False): ...@@ -455,7 +495,9 @@ def slice_batch(g, gid, store_ids=False):
if gid == 0: if gid == 0:
start_nid.append(0) start_nid.append(0)
else: else:
start_nid.append(F.as_scalar(F.sum(F.slice_axis(batch_num_nodes, 0, 0, gid), 0))) start_nid.append(
F.as_scalar(F.sum(F.slice_axis(batch_num_nodes, 0, 0, gid), 0))
)
start_eid = [] start_eid = []
num_edges = [] num_edges = []
...@@ -465,33 +507,42 @@ def slice_batch(g, gid, store_ids=False): ...@@ -465,33 +507,42 @@ def slice_batch(g, gid, store_ids=False):
if gid == 0: if gid == 0:
start_eid.append(0) start_eid.append(0)
else: else:
start_eid.append(F.as_scalar(F.sum(F.slice_axis(batch_num_edges, 0, 0, gid), 0))) start_eid.append(
F.as_scalar(F.sum(F.slice_axis(batch_num_edges, 0, 0, gid), 0))
)
# Slice graph structure # Slice graph structure
gidx = slice_gidx(g._graph, utils.toindex(num_nodes), utils.toindex(start_nid), gidx = slice_gidx(
utils.toindex(num_edges), utils.toindex(start_eid)) g._graph,
utils.toindex(num_nodes),
utils.toindex(start_nid),
utils.toindex(num_edges),
utils.toindex(start_eid),
)
retg = DGLGraph(gidx, g.ntypes, g.etypes) retg = DGLGraph(gidx, g.ntypes, g.etypes)
# Slice node features # Slice node features
for ntid, ntype in enumerate(g.ntypes): for ntid, ntype in enumerate(g.ntypes):
stnid = start_nid[ntid] stnid = start_nid[ntid]
for key, feat in g.nodes[ntype].data.items(): for key, feat in g.nodes[ntype].data.items():
subfeats = F.slice_axis(feat, 0, stnid, stnid+num_nodes[ntid]) subfeats = F.slice_axis(feat, 0, stnid, stnid + num_nodes[ntid])
retg.nodes[ntype].data[key] = subfeats retg.nodes[ntype].data[key] = subfeats
if store_ids: if store_ids:
retg.nodes[ntype].data[NID] = F.arange(stnid, stnid+num_nodes[ntid], retg.nodes[ntype].data[NID] = F.arange(
retg.idtype, retg.device) stnid, stnid + num_nodes[ntid], retg.idtype, retg.device
)
# Slice edge features # Slice edge features
for etid, etype in enumerate(g.canonical_etypes): for etid, etype in enumerate(g.canonical_etypes):
steid = start_eid[etid] steid = start_eid[etid]
for key, feat in g.edges[etype].data.items(): for key, feat in g.edges[etype].data.items():
subfeats = F.slice_axis(feat, 0, steid, steid+num_edges[etid]) subfeats = F.slice_axis(feat, 0, steid, steid + num_edges[etid])
retg.edges[etype].data[key] = subfeats retg.edges[etype].data[key] = subfeats
if store_ids: if store_ids:
retg.edges[etype].data[EID] = F.arange(steid, steid+num_edges[etid], retg.edges[etype].data[EID] = F.arange(
retg.idtype, retg.device) steid, steid + num_edges[etid], retg.idtype, retg.device
)
return retg return retg
This diff is collapsed.
...@@ -2,10 +2,8 @@ ...@@ -2,10 +2,8 @@
# pylint: disable=not-callable # pylint: disable=not-callable
import numpy as np import numpy as np
from . import backend as F from . import backend as F, function as fn, ops
from . import function as fn from .base import ALL, dgl_warning, DGLError, EID, is_all, NID
from . import ops
from .base import ALL, EID, NID, DGLError, dgl_warning, is_all
from .frame import Frame from .frame import Frame
from .udf import EdgeBatch, NodeBatch from .udf import EdgeBatch, NodeBatch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment