Unverified Commit a566b60b authored by Hongzhi (Steve), Chen; committed by GitHub
Browse files
parent d1827488
......@@ -14,26 +14,26 @@ try:
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import FunctionBase as _FunctionBase
from ._cy3.core import (
_set_class_function,
_set_class_module,
convert_to_dgl_func,
FunctionBase as _FunctionBase,
)
else:
from ._cy2.core import FunctionBase as _FunctionBase
from ._cy2.core import (
_set_class_function,
_set_class_module,
convert_to_dgl_func,
FunctionBase as _FunctionBase,
)
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.function import FunctionBase as _FunctionBase
from ._ctypes.function import (
_set_class_function,
_set_class_module,
convert_to_dgl_func,
FunctionBase as _FunctionBase,
)
FunctionHandle = ctypes.c_void_p
......
......@@ -9,12 +9,12 @@ import numpy as np
from .base import _FFI_MODE, _LIB, c_array, c_str, check_call, string_types
from .runtime_ctypes import (
dgl_shape_index_t,
DGLArray,
DGLArrayHandle,
DGLContext,
DGLDataType,
TypeCode,
dgl_shape_index_t,
)
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
......@@ -24,29 +24,29 @@ try:
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import NDArrayBase as _NDArrayBase
from ._cy3.core import (
_from_dlpack,
_make_array,
_reg_extension,
_set_class_ndarray,
NDArrayBase as _NDArrayBase,
)
else:
from ._cy2.core import NDArrayBase as _NDArrayBase
from ._cy2.core import (
_from_dlpack,
_make_array,
_reg_extension,
_set_class_ndarray,
NDArrayBase as _NDArrayBase,
)
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
from ._ctypes.ndarray import (
_from_dlpack,
_make_array,
_reg_extension,
_set_class_ndarray,
NDArrayBase as _NDArrayBase,
)
......
......@@ -7,7 +7,7 @@ import sys
from .. import _api_internal
from .base import _FFI_MODE, _LIB, c_str, check_call, py_str
from .object_generic import ObjectGeneric, convert_to_object
from .object_generic import convert_to_object, ObjectGeneric
# pylint: disable=invalid-name
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
......@@ -16,15 +16,12 @@ try:
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import _register_object
from ._cy3.core import _register_object, ObjectBase as _ObjectBase
else:
from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import _register_object
from ._cy2.core import _register_object, ObjectBase as _ObjectBase
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.object import ObjectBase as _ObjectBase
from ._ctypes.object import _register_object
from ._ctypes.object import _register_object, ObjectBase as _ObjectBase
def _new_object(cls):
......
"""Utilities for batching/unbatching graphs."""
from collections.abc import Mapping
from . import backend as F
from .base import ALL, is_all, DGLError, NID, EID
from .heterograph_index import disjoint_union, slice_gidx
from . import backend as F, convert, utils
from .base import ALL, DGLError, EID, is_all, NID
from .heterograph import DGLGraph
from . import convert
from . import utils
from .heterograph_index import disjoint_union, slice_gidx
__all__ = ["batch", "unbatch", "slice_batch"]
__all__ = ['batch', 'unbatch', 'slice_batch']
def batch(graphs, ndata=ALL, edata=ALL):
r"""Batch a collection of :class:`DGLGraph` s into one graph for more efficient
......@@ -149,13 +148,19 @@ def batch(graphs, ndata=ALL, edata=ALL):
unbatch
"""
if len(graphs) == 0:
raise DGLError('The input list of graphs cannot be empty.')
raise DGLError("The input list of graphs cannot be empty.")
if not (is_all(ndata) or isinstance(ndata, list) or ndata is None):
raise DGLError('Invalid argument ndata: must be a string list but got {}.'.format(
type(ndata)))
raise DGLError(
"Invalid argument ndata: must be a string list but got {}.".format(
type(ndata)
)
)
if not (is_all(edata) or isinstance(edata, list) or edata is None):
raise DGLError('Invalid argument edata: must be a string list but got {}.'.format(
type(edata)))
raise DGLError(
"Invalid argument edata: must be a string list but got {}.".format(
type(edata)
)
)
if any(g.is_block for g in graphs):
raise DGLError("Batching a MFG is not supported.")
......@@ -165,7 +170,9 @@ def batch(graphs, ndata=ALL, edata=ALL):
ntype_ids = [graphs[0].get_ntype_id(n) for n in ntypes]
etypes = [etype for _, etype, _ in relations]
gidx = disjoint_union(graphs[0]._graph.metagraph, [g._graph for g in graphs])
gidx = disjoint_union(
graphs[0]._graph.metagraph, [g._graph for g in graphs]
)
retg = DGLGraph(gidx, ntypes, etypes)
# Compute batch num nodes
......@@ -183,29 +190,42 @@ def batch(graphs, ndata=ALL, edata=ALL):
# Batch node feature
if ndata is not None:
for ntype_id, ntype in zip(ntype_ids, ntypes):
all_empty = all(g._graph.number_of_nodes(ntype_id) == 0 for g in graphs)
all_empty = all(
g._graph.number_of_nodes(ntype_id) == 0 for g in graphs
)
frames = [
g._node_frames[ntype_id] for g in graphs
if g._graph.number_of_nodes(ntype_id) > 0 or all_empty]
g._node_frames[ntype_id]
for g in graphs
if g._graph.number_of_nodes(ntype_id) > 0 or all_empty
]
# TODO: do we require graphs with no nodes/edges to have the same schema? Currently
# we allow empty graphs to have no features during batching.
ret_feat = _batch_feat_dicts(frames, ndata, 'nodes["{}"].data'.format(ntype))
ret_feat = _batch_feat_dicts(
frames, ndata, 'nodes["{}"].data'.format(ntype)
)
retg.nodes[ntype].data.update(ret_feat)
# Batch edge feature
if edata is not None:
for etype_id, etype in zip(relation_ids, relations):
all_empty = all(g._graph.number_of_edges(etype_id) == 0 for g in graphs)
all_empty = all(
g._graph.number_of_edges(etype_id) == 0 for g in graphs
)
frames = [
g._edge_frames[etype_id] for g in graphs
if g._graph.number_of_edges(etype_id) > 0 or all_empty]
g._edge_frames[etype_id]
for g in graphs
if g._graph.number_of_edges(etype_id) > 0 or all_empty
]
# TODO: do we require graphs with no nodes/edges to have the same schema? Currently
# we allow empty graphs to have no features during batching.
ret_feat = _batch_feat_dicts(frames, edata, 'edges[{}].data'.format(etype))
ret_feat = _batch_feat_dicts(
frames, edata, "edges[{}].data".format(etype)
)
retg.edges[etype].data.update(ret_feat)
return retg
def _batch_feat_dicts(frames, keys, feat_dict_name):
"""Internal function to batch feature dictionaries.
......@@ -233,9 +253,10 @@ def _batch_feat_dicts(frames, keys, feat_dict_name):
else:
utils.check_all_same_schema_for_keys(schemas, keys, feat_dict_name)
# concat features
ret_feat = {k : F.cat([fd[k] for fd in frames], 0) for k in keys}
ret_feat = {k: F.cat([fd[k] for fd in frames], 0) for k in keys}
return ret_feat
def unbatch(g, node_split=None, edge_split=None):
"""Revert the batch operation by split the given graph into a list of small ones.
......@@ -339,57 +360,75 @@ def unbatch(g, node_split=None, edge_split=None):
num_split = None
# Parse node_split
if node_split is None:
node_split = {ntype : g.batch_num_nodes(ntype) for ntype in g.ntypes}
node_split = {ntype: g.batch_num_nodes(ntype) for ntype in g.ntypes}
elif not isinstance(node_split, Mapping):
if len(g.ntypes) != 1:
raise DGLError('Must provide a dictionary for argument node_split when'
' there are multiple node types.')
node_split = {g.ntypes[0] : node_split}
raise DGLError(
"Must provide a dictionary for argument node_split when"
" there are multiple node types."
)
node_split = {g.ntypes[0]: node_split}
if node_split.keys() != set(g.ntypes):
raise DGLError('Must specify node_split for each node type.')
raise DGLError("Must specify node_split for each node type.")
for split in node_split.values():
if num_split is not None and num_split != len(split):
raise DGLError('All node_split and edge_split must specify the same number'
' of split sizes.')
raise DGLError(
"All node_split and edge_split must specify the same number"
" of split sizes."
)
num_split = len(split)
# Parse edge_split
if edge_split is None:
edge_split = {etype : g.batch_num_edges(etype) for etype in g.canonical_etypes}
edge_split = {
etype: g.batch_num_edges(etype) for etype in g.canonical_etypes
}
elif not isinstance(edge_split, Mapping):
if len(g.etypes) != 1:
raise DGLError('Must provide a dictionary for argument edge_split when'
' there are multiple edge types.')
edge_split = {g.canonical_etypes[0] : edge_split}
raise DGLError(
"Must provide a dictionary for argument edge_split when"
" there are multiple edge types."
)
edge_split = {g.canonical_etypes[0]: edge_split}
if edge_split.keys() != set(g.canonical_etypes):
raise DGLError('Must specify edge_split for each canonical edge type.')
raise DGLError("Must specify edge_split for each canonical edge type.")
for split in edge_split.values():
if num_split is not None and num_split != len(split):
raise DGLError('All edge_split and edge_split must specify the same number'
' of split sizes.')
raise DGLError(
"All edge_split and edge_split must specify the same number"
" of split sizes."
)
num_split = len(split)
node_split = {k : F.asnumpy(split).tolist() for k, split in node_split.items()}
edge_split = {k : F.asnumpy(split).tolist() for k, split in edge_split.items()}
node_split = {
k: F.asnumpy(split).tolist() for k, split in node_split.items()
}
edge_split = {
k: F.asnumpy(split).tolist() for k, split in edge_split.items()
}
# Split edges for each relation
edge_dict_per = [{} for i in range(num_split)]
for rel in g.canonical_etypes:
srctype, etype, dsttype = rel
srcnid_off = dstnid_off = 0
u, v = g.edges(order='eid', etype=rel)
u, v = g.edges(order="eid", etype=rel)
us = F.split(u, edge_split[rel], 0)
vs = F.split(v, edge_split[rel], 0)
for i, (subu, subv) in enumerate(zip(us, vs)):
edge_dict_per[i][rel] = (subu - srcnid_off, subv - dstnid_off)
srcnid_off += node_split[srctype][i]
dstnid_off += node_split[dsttype][i]
num_nodes_dict_per = [{k : split[i] for k, split in node_split.items()}
for i in range(num_split)]
num_nodes_dict_per = [
{k: split[i] for k, split in node_split.items()}
for i in range(num_split)
]
# Create graphs
gs = [convert.heterograph(edge_dict, num_nodes_dict, idtype=g.idtype)
for edge_dict, num_nodes_dict in zip(edge_dict_per, num_nodes_dict_per)]
gs = [
convert.heterograph(edge_dict, num_nodes_dict, idtype=g.idtype)
for edge_dict, num_nodes_dict in zip(edge_dict_per, num_nodes_dict_per)
]
# Unbatch node features
for ntype in g.ntypes:
......@@ -407,6 +446,7 @@ def unbatch(g, node_split=None, edge_split=None):
return gs
def slice_batch(g, gid, store_ids=False):
"""Get a particular graph from a batch of graphs.
......@@ -455,7 +495,9 @@ def slice_batch(g, gid, store_ids=False):
if gid == 0:
start_nid.append(0)
else:
start_nid.append(F.as_scalar(F.sum(F.slice_axis(batch_num_nodes, 0, 0, gid), 0)))
start_nid.append(
F.as_scalar(F.sum(F.slice_axis(batch_num_nodes, 0, 0, gid), 0))
)
start_eid = []
num_edges = []
......@@ -465,33 +507,42 @@ def slice_batch(g, gid, store_ids=False):
if gid == 0:
start_eid.append(0)
else:
start_eid.append(F.as_scalar(F.sum(F.slice_axis(batch_num_edges, 0, 0, gid), 0)))
start_eid.append(
F.as_scalar(F.sum(F.slice_axis(batch_num_edges, 0, 0, gid), 0))
)
# Slice graph structure
gidx = slice_gidx(g._graph, utils.toindex(num_nodes), utils.toindex(start_nid),
utils.toindex(num_edges), utils.toindex(start_eid))
gidx = slice_gidx(
g._graph,
utils.toindex(num_nodes),
utils.toindex(start_nid),
utils.toindex(num_edges),
utils.toindex(start_eid),
)
retg = DGLGraph(gidx, g.ntypes, g.etypes)
# Slice node features
for ntid, ntype in enumerate(g.ntypes):
stnid = start_nid[ntid]
for key, feat in g.nodes[ntype].data.items():
subfeats = F.slice_axis(feat, 0, stnid, stnid+num_nodes[ntid])
subfeats = F.slice_axis(feat, 0, stnid, stnid + num_nodes[ntid])
retg.nodes[ntype].data[key] = subfeats
if store_ids:
retg.nodes[ntype].data[NID] = F.arange(stnid, stnid+num_nodes[ntid],
retg.idtype, retg.device)
retg.nodes[ntype].data[NID] = F.arange(
stnid, stnid + num_nodes[ntid], retg.idtype, retg.device
)
# Slice edge features
for etid, etype in enumerate(g.canonical_etypes):
steid = start_eid[etid]
for key, feat in g.edges[etype].data.items():
subfeats = F.slice_axis(feat, 0, steid, steid+num_edges[etid])
subfeats = F.slice_axis(feat, 0, steid, steid + num_edges[etid])
retg.edges[etype].data[key] = subfeats
if store_ids:
retg.edges[etype].data[EID] = F.arange(steid, steid+num_edges[etid],
retg.idtype, retg.device)
retg.edges[etype].data[EID] = F.arange(
steid, steid + num_edges[etid], retg.idtype, retg.device
)
return retg
This diff is collapsed.
......@@ -2,10 +2,8 @@
# pylint: disable=not-callable
import numpy as np
from . import backend as F
from . import function as fn
from . import ops
from .base import ALL, EID, NID, DGLError, dgl_warning, is_all
from . import backend as F, function as fn, ops
from .base import ALL, dgl_warning, DGLError, EID, is_all, NID
from .frame import Frame
from .udf import EdgeBatch, NodeBatch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment