Commit 76bb5404 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4682)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent a208e886
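
Everything below is a mechanical reformat: Black normalizes string quoting to double quotes, re-wraps calls that exceed the configured line length, and adds the blank lines its style requires, without changing behavior. As a minimal sketch of the mechanism (assuming the black package is installed; black.Mode() below uses Black's defaults, which need not match this repository's configuration):

    import black

    # One of the pre-format lines from the diff below; format_str parses the
    # snippet and re-prints it in Black's style (undefined names are fine).
    src = "__all__ = ['rand_graph', 'rand_bipartite']\n"
    print(black.format_str(src, mode=black.Mode()), end="")
    # -> __all__ = ["rand_graph", "rand_bipartite"]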
"""Built-in function base class""" """Built-in function base class"""
from __future__ import absolute_import from __future__ import absolute_import
__all__ = ['BuiltinFunction', 'TargetCode'] __all__ = ["BuiltinFunction", "TargetCode"]
class TargetCode(object): class TargetCode(object):
...@@ -10,6 +10,7 @@ class TargetCode(object): ...@@ -10,6 +10,7 @@ class TargetCode(object):
Note: must be consistent with the target code definition in C++ side: Note: must be consistent with the target code definition in C++ side:
src/kernel/binary_reduce_common.h src/kernel/binary_reduce_common.h
""" """
SRC = 0 SRC = 0
DST = 1 DST = 1
EDGE = 2 EDGE = 2
...@@ -23,6 +24,7 @@ class TargetCode(object): ...@@ -23,6 +24,7 @@ class TargetCode(object):
class BuiltinFunction(object): class BuiltinFunction(object):
"""Base builtin function class.""" """Base builtin function class."""
@property @property
def name(self): def name(self):
"""Return the name of this builtin function.""" """Return the name of this builtin function."""
......
@@ -4,16 +4,15 @@ from __future__ import absolute_import

 import sys

-from .base import BuiltinFunction, TargetCode
 from .._deprecate.runtime import ir
 from .._deprecate.runtime.ir import var
+from .base import BuiltinFunction, TargetCode


 class ReduceFunction(BuiltinFunction):
     """Base builtin reduce function class."""

-    def _invoke(self, graph, edge_frame, out_size, edge_map=None,
-                out_map=None):
+    def _invoke(self, graph, edge_frame, out_size, edge_map=None, out_map=None):
         """Symbolic computation of this builtin function to create
         runtime.executor
         """
@@ -28,21 +27,28 @@ class ReduceFunction(BuiltinFunction):

 class SimpleReduceFunction(ReduceFunction):
     """Builtin reduce function that aggregates a single field into another
     single field."""
+
     def __init__(self, name, msg_field, out_field):
         self._name = name
         self.msg_field = msg_field
         self.out_field = out_field

-    def _invoke(self, graph, edge_frame, out_size, edge_map=None,
-                out_map=None):
+    def _invoke(self, graph, edge_frame, out_size, edge_map=None, out_map=None):
         """Symbolic execution of this builtin function"""
         reducer = self._name
         graph = var.GRAPH(graph)
         edge_map = var.MAP(edge_map)
         out_map = var.MAP(out_map)
         edge_data = ir.READ_COL(edge_frame, var.STR(self.msg_field))
-        return ir.COPY_REDUCE(reducer, graph, TargetCode.EDGE, edge_data,
-                              out_size, edge_map, out_map)
+        return ir.COPY_REDUCE(
+            reducer,
+            graph,
+            TargetCode.EDGE,
+            edge_data,
+            out_size,
+            edge_map,
+            out_map,
+        )

     @property
     def name(self):
@@ -53,6 +59,7 @@ class SimpleReduceFunction(ReduceFunction):

+
 # Generate all following reducer functions:
 # sum, max, min, mean, prod
 def _gen_reduce_builtin(reducer):
     docstring = """Builtin reduce function that aggregates messages by {0}.
@@ -73,10 +80,13 @@ def _gen_reduce_builtin(reducer):
     >>> import torch
     >>> def reduce_func(nodes):
     >>>     return {{'h': torch.{0}(nodes.mailbox['m'], dim=1)}}
-    """.format(reducer)
+    """.format(
+        reducer
+    )
+
     def func(msg, out):
         return SimpleReduceFunction(reducer, msg, out)

     func.__name__ = str(reducer)
     func.__qualname__ = str(reducer)
     func.__doc__ = docstring
...
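
The factory above is what populates dgl.function with the sum/max/min/mean/prod reducers. A short usage sketch, following the code as shown (the printed value assumes the name property simply returns the reducer string):

    import dgl.function as fn

    # fn.sum("m", "h") is the func(msg, out) generated for reducer="sum",
    # i.e. SimpleReduceFunction("sum", "m", "h").
    reduce_func = fn.sum("m", "h")
    print(reduce_func.name)  # expected: "sum"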
"""Module for various graph generator functions.""" """Module for various graph generator functions."""
from . import backend as F from . import backend as F
from . import convert from . import convert, random
from . import random
__all__ = ["rand_graph", "rand_bipartite"]
__all__ = ['rand_graph', 'rand_bipartite']
def rand_graph(num_nodes, num_edges, idtype=F.int64, device=F.cpu()): def rand_graph(num_nodes, num_edges, idtype=F.int64, device=F.cpu()):
"""Generate a random graph of the given number of nodes/edges and return. """Generate a random graph of the given number of nodes/edges and return.
...@@ -46,20 +46,28 @@ def rand_graph(num_nodes, num_edges, idtype=F.int64, device=F.cpu()): ...@@ -46,20 +46,28 @@ def rand_graph(num_nodes, num_edges, idtype=F.int64, device=F.cpu()):
ndata_schemes={} ndata_schemes={}
edata_schemes={}) edata_schemes={})
""" """
#TODO(minjie): support RNG as one of the arguments. # TODO(minjie): support RNG as one of the arguments.
eids = random.choice(num_nodes * num_nodes, num_edges, replace=False) eids = random.choice(num_nodes * num_nodes, num_edges, replace=False)
eids = F.zerocopy_to_numpy(eids) eids = F.zerocopy_to_numpy(eids)
rows = F.zerocopy_from_numpy(eids // num_nodes) rows = F.zerocopy_from_numpy(eids // num_nodes)
cols = F.zerocopy_from_numpy(eids % num_nodes) cols = F.zerocopy_from_numpy(eids % num_nodes)
rows = F.copy_to(F.astype(rows, idtype), device) rows = F.copy_to(F.astype(rows, idtype), device)
cols = F.copy_to(F.astype(cols, idtype), device) cols = F.copy_to(F.astype(cols, idtype), device)
return convert.graph((rows, cols), return convert.graph(
num_nodes=num_nodes, (rows, cols), num_nodes=num_nodes, idtype=idtype, device=device
idtype=idtype, device=device) )
def rand_bipartite(utype, etype, vtype,
num_src_nodes, num_dst_nodes, num_edges, def rand_bipartite(
idtype=F.int64, device=F.cpu()): utype,
etype,
vtype,
num_src_nodes,
num_dst_nodes,
num_edges,
idtype=F.int64,
device=F.cpu(),
):
"""Generate a random uni-directional bipartite graph and return. """Generate a random uni-directional bipartite graph and return.
It uniformly chooses ``num_edges`` from all possible node pairs and form a graph. It uniformly chooses ``num_edges`` from all possible node pairs and form a graph.
...@@ -107,13 +115,18 @@ def rand_bipartite(utype, etype, vtype, ...@@ -107,13 +115,18 @@ def rand_bipartite(utype, etype, vtype,
num_edges={('user', 'buys', 'game'): 10}, num_edges={('user', 'buys', 'game'): 10},
metagraph=[('user', 'game', 'buys')]) metagraph=[('user', 'game', 'buys')])
""" """
#TODO(minjie): support RNG as one of the arguments. # TODO(minjie): support RNG as one of the arguments.
eids = random.choice(num_src_nodes * num_dst_nodes, num_edges, replace=False) eids = random.choice(
num_src_nodes * num_dst_nodes, num_edges, replace=False
)
eids = F.zerocopy_to_numpy(eids) eids = F.zerocopy_to_numpy(eids)
rows = F.zerocopy_from_numpy(eids // num_dst_nodes) rows = F.zerocopy_from_numpy(eids // num_dst_nodes)
cols = F.zerocopy_from_numpy(eids % num_dst_nodes) cols = F.zerocopy_from_numpy(eids % num_dst_nodes)
rows = F.copy_to(F.astype(rows, idtype), device) rows = F.copy_to(F.astype(rows, idtype), device)
cols = F.copy_to(F.astype(cols, idtype), device) cols = F.copy_to(F.astype(cols, idtype), device)
return convert.heterograph({(utype, etype, vtype): (rows, cols)}, return convert.heterograph(
{(utype, etype, vtype): (rows, cols)},
{utype: num_src_nodes, vtype: num_dst_nodes}, {utype: num_src_nodes, vtype: num_dst_nodes},
idtype=idtype, device=device) idtype=idtype,
device=device,
)
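
A quick sketch of the two public generators reformatted above (assumes a working DGL install with the default PyTorch backend; exact reprs vary by version):

    import dgl

    # 100-node homogeneous graph with 30 uniformly sampled directed edges.
    g = dgl.rand_graph(100, 30)
    # Bipartite graph: 50 "user" nodes, 40 "game" nodes, 10 "buys" edges.
    bg = dgl.rand_bipartite("user", "buys", "game", 50, 40, 10)
    print(g.num_edges(), bg.num_edges())  # expected: 30 10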
@@ -8,5 +8,5 @@
 This package is experimental and the interfaces may be subject
 to changes in future releases.
 """
-from .fps import *
 from .edge_coarsening import *
+from .fps import *
"""Python interfaces to DGL farthest point sampler.""" """Python interfaces to DGL farthest point sampler."""
import numpy as np import numpy as np
from .._ffi.base import DGLError
from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
from .. import ndarray as nd from .. import ndarray as nd
from .._ffi.base import DGLError
from .._ffi.function import _init_api
def _farthest_point_sampler(data, batch_size, sample_points, dist, start_idx, result): def _farthest_point_sampler(
data, batch_size, sample_points, dist, start_idx, result
):
r"""Farthest Point Sampler r"""Farthest Point Sampler
Parameters Parameters
...@@ -32,14 +35,19 @@ def _farthest_point_sampler(data, batch_size, sample_points, dist, start_idx, re ...@@ -32,14 +35,19 @@ def _farthest_point_sampler(data, batch_size, sample_points, dist, start_idx, re
assert F.shape(data)[0] >= sample_points * batch_size assert F.shape(data)[0] >= sample_points * batch_size
assert F.shape(data)[0] % batch_size == 0 assert F.shape(data)[0] % batch_size == 0
_CAPI_FarthestPointSampler(F.zerocopy_to_dgl_ndarray(data), _CAPI_FarthestPointSampler(
batch_size, sample_points, F.zerocopy_to_dgl_ndarray(data),
batch_size,
sample_points,
F.zerocopy_to_dgl_ndarray(dist), F.zerocopy_to_dgl_ndarray(dist),
F.zerocopy_to_dgl_ndarray(start_idx), F.zerocopy_to_dgl_ndarray(start_idx),
F.zerocopy_to_dgl_ndarray(result)) F.zerocopy_to_dgl_ndarray(result),
)
def _neighbor_matching(graph_idx, num_nodes, edge_weights=None, relabel_idx=True): def _neighbor_matching(
graph_idx, num_nodes, edge_weights=None, relabel_idx=True
):
""" """
Description Description
----------- -----------
...@@ -82,7 +90,11 @@ def _neighbor_matching(graph_idx, num_nodes, edge_weights=None, relabel_idx=True ...@@ -82,7 +90,11 @@ def _neighbor_matching(graph_idx, num_nodes, edge_weights=None, relabel_idx=True
if edge_weights is not None: if edge_weights is not None:
edge_weight_capi = F.zerocopy_to_dgl_ndarray(edge_weights) edge_weight_capi = F.zerocopy_to_dgl_ndarray(edge_weights)
node_label = F.full_1d( node_label = F.full_1d(
num_nodes, -1, getattr(F, graph_idx.dtype), F.to_backend_ctx(graph_idx.ctx)) num_nodes,
-1,
getattr(F, graph_idx.dtype),
F.to_backend_ctx(graph_idx.ctx),
)
node_label_capi = F.zerocopy_to_dgl_ndarray_for_write(node_label) node_label_capi = F.zerocopy_to_dgl_ndarray_for_write(node_label)
_CAPI_NeighborMatching(graph_idx, edge_weight_capi, node_label_capi) _CAPI_NeighborMatching(graph_idx, edge_weight_capi, node_label_capi)
if F.reduce_sum(node_label < 0).item() != 0: if F.reduce_sum(node_label < 0).item() != 0:
...@@ -99,4 +111,4 @@ def _neighbor_matching(graph_idx, num_nodes, edge_weights=None, relabel_idx=True ...@@ -99,4 +111,4 @@ def _neighbor_matching(graph_idx, num_nodes, edge_weights=None, relabel_idx=True
return node_label return node_label
_init_api('dgl.geometry', __name__) _init_api("dgl.geometry", __name__)
@@ -3,7 +3,8 @@
 from .. import remove_self_loop
 from .capi import _neighbor_matching

-__all__ = ['neighbor_matching']
+__all__ = ["neighbor_matching"]

+
 def neighbor_matching(graph, e_weights=None, relabel_idx=True):
     r"""
@@ -48,13 +49,16 @@ def neighbor_matching(graph, e_weights=None, relabel_idx=True):
     >>> res = neighbor_matching(g)
     tensor([0, 1, 1])
     """
-    assert graph.is_homogeneous, \
-        "The graph used in graph node matching must be homogeneous"
+    assert (
+        graph.is_homogeneous
+    ), "The graph used in graph node matching must be homogeneous"
     if e_weights is not None:
-        graph.edata['e_weights'] = e_weights
+        graph.edata["e_weights"] = e_weights
         graph = remove_self_loop(graph)
-        e_weights = graph.edata['e_weights']
-        graph.edata.pop('e_weights')
+        e_weights = graph.edata["e_weights"]
+        graph.edata.pop("e_weights")
     else:
         graph = remove_self_loop(graph)
-    return _neighbor_matching(graph._graph, graph.num_nodes(), e_weights, relabel_idx)
+    return _neighbor_matching(
+        graph._graph, graph.num_nodes(), e_weights, relabel_idx
+    )
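
For reference, the doctest in the docstring above exercises this entry point roughly as follows (PyTorch backend assumed; cluster ids may differ across versions):

    import dgl
    import torch
    from dgl.geometry import neighbor_matching

    # A 3-node path graph; matching merges one adjacent pair of nodes.
    g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
    print(neighbor_matching(g))  # e.g. tensor([0, 1, 1])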
"""Farthest Point Sampler for pytorch Geometry package""" """Farthest Point Sampler for pytorch Geometry package"""
#pylint: disable=no-member, invalid-name # pylint: disable=no-member, invalid-name
from .. import backend as F from .. import backend as F
from ..base import DGLError from ..base import DGLError
from .capi import _farthest_point_sampler from .capi import _farthest_point_sampler
__all__ = ['farthest_point_sampler'] __all__ = ["farthest_point_sampler"]
def farthest_point_sampler(pos, npoints, start_idx=None): def farthest_point_sampler(pos, npoints, start_idx=None):
...@@ -50,12 +49,16 @@ def farthest_point_sampler(pos, npoints, start_idx=None): ...@@ -50,12 +49,16 @@ def farthest_point_sampler(pos, npoints, start_idx=None):
pos = pos.reshape(-1, C) pos = pos.reshape(-1, C)
dist = F.zeros((B * N), dtype=pos.dtype, ctx=ctx) dist = F.zeros((B * N), dtype=pos.dtype, ctx=ctx)
if start_idx is None: if start_idx is None:
start_idx = F.randint(shape=(B, ), dtype=F.int64, start_idx = F.randint(
ctx=ctx, low=0, high=N-1) shape=(B,), dtype=F.int64, ctx=ctx, low=0, high=N - 1
)
else: else:
if start_idx >= N or start_idx < 0: if start_idx >= N or start_idx < 0:
raise DGLError("Invalid start_idx, expected 0 <= start_idx < {}, got {}".format( raise DGLError(
N, start_idx)) "Invalid start_idx, expected 0 <= start_idx < {}, got {}".format(
N, start_idx
)
)
start_idx = F.full_1d(B, start_idx, dtype=F.int64, ctx=ctx) start_idx = F.full_1d(B, start_idx, dtype=F.int64, ctx=ctx)
result = F.zeros((npoints * B), dtype=F.int64, ctx=ctx) result = F.zeros((npoints * B), dtype=F.int64, ctx=ctx)
_farthest_point_sampler(pos, B, npoints, dist, start_idx, result) _farthest_point_sampler(pos, B, npoints, dist, start_idx, result)
......
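
A usage sketch for the public sampler above (PyTorch backend assumed; pos follows the (B, N, C) convention used in the code):

    import torch
    from dgl.geometry import farthest_point_sampler

    # Two point clouds of 100 random 3-D points; pick 4 indices per cloud.
    pos = torch.rand(2, 100, 3)
    idx = farthest_point_sampler(pos, 4)
    print(idx.shape)  # expected: torch.Size([2, 4])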
"""Module for global configuration operators.""" """Module for global configuration operators."""
from ._ffi.function import _init_api from ._ffi.function import _init_api
__all__ = ["is_libxsmm_enabled", "use_libxsmm"] __all__ = ["is_libxsmm_enabled", "use_libxsmm"]
def use_libxsmm(flag): def use_libxsmm(flag):
r"""Set whether DGL uses libxsmm at runtime. r"""Set whether DGL uses libxsmm at runtime.
...@@ -21,6 +21,7 @@ def use_libxsmm(flag): ...@@ -21,6 +21,7 @@ def use_libxsmm(flag):
""" """
_CAPI_DGLConfigSetLibxsmm(flag) _CAPI_DGLConfigSetLibxsmm(flag)
def is_libxsmm_enabled(): def is_libxsmm_enabled():
r"""Get whether the use_libxsmm flag is turned on. r"""Get whether the use_libxsmm flag is turned on.
...@@ -35,4 +36,5 @@ def is_libxsmm_enabled(): ...@@ -35,4 +36,5 @@ def is_libxsmm_enabled():
""" """
return _CAPI_DGLConfigGetLibxsmm() return _CAPI_DGLConfigGetLibxsmm()
_init_api("dgl.global_config") _init_api("dgl.global_config")
"""Module for graph index class definition.""" """Module for graph index class definition."""
from __future__ import absolute_import from __future__ import absolute_import
import numpy as np
import networkx as nx import networkx as nx
import numpy as np
import scipy import scipy
from ._ffi.object import register_object, ObjectBase
from ._ffi.function import _init_api
from .base import DGLError, dgl_warning
from . import backend as F from . import backend as F
from . import utils from . import utils
from ._ffi.function import _init_api
from ._ffi.object import ObjectBase, register_object
from .base import DGLError, dgl_warning
class BoolFlag(object): class BoolFlag(object):
"""Bool flag with unknown value""" """Bool flag with unknown value"""
BOOL_UNKNOWN = -1 BOOL_UNKNOWN = -1
BOOL_FALSE = 0 BOOL_FALSE = 0
BOOL_TRUE = 1 BOOL_TRUE = 1
@register_object('graph.Graph')
@register_object("graph.Graph")
class GraphIndex(ObjectBase): class GraphIndex(ObjectBase):
"""Graph index object. """Graph index object.
...@@ -33,6 +36,7 @@ class GraphIndex(ObjectBase): ...@@ -33,6 +36,7 @@ class GraphIndex(ObjectBase):
- `dgl.graph_index.from_csr` - `dgl.graph_index.from_csr`
- `dgl.graph_index.from_coo` - `dgl.graph_index.from_coo`
""" """
def __new__(cls): def __new__(cls):
obj = ObjectBase.__new__(cls) obj = ObjectBase.__new__(cls)
obj._readonly = None # python-side cache of the flag obj._readonly = None # python-side cache of the flag
...@@ -53,13 +57,15 @@ class GraphIndex(ObjectBase): ...@@ -53,13 +57,15 @@ class GraphIndex(ObjectBase):
# Pickle compatibility check # Pickle compatibility check
# TODO: we should store a storage version number in later releases. # TODO: we should store a storage version number in later releases.
if isinstance(state, tuple) and len(state) == 5: if isinstance(state, tuple) and len(state) == 5:
dgl_warning("The object is pickled pre-0.4.2. Multigraph flag is ignored in 0.4.3") dgl_warning(
"The object is pickled pre-0.4.2. Multigraph flag is ignored in 0.4.3"
)
num_nodes, _, readonly, src, dst = state num_nodes, _, readonly, src, dst = state
elif isinstance(state, tuple) and len(state) == 4: elif isinstance(state, tuple) and len(state) == 4:
# post-0.4.3. # post-0.4.3.
num_nodes, readonly, src, dst = state num_nodes, readonly, src, dst = state
else: else:
raise IOError('Unrecognized storage format.') raise IOError("Unrecognized storage format.")
self._cache = {} self._cache = {}
self._readonly = readonly self._readonly = readonly
...@@ -68,7 +74,8 @@ class GraphIndex(ObjectBase): ...@@ -68,7 +74,8 @@ class GraphIndex(ObjectBase):
src.todgltensor(), src.todgltensor(),
dst.todgltensor(), dst.todgltensor(),
int(num_nodes), int(num_nodes),
readonly) readonly,
)
def add_nodes(self, num): def add_nodes(self, num):
"""Add nodes. """Add nodes.
...@@ -240,7 +247,9 @@ class GraphIndex(ObjectBase): ...@@ -240,7 +247,9 @@ class GraphIndex(ObjectBase):
""" """
u_array = u.todgltensor() u_array = u.todgltensor()
v_array = v.todgltensor() v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLGraphHasEdgesBetween(self, u_array, v_array)) return utils.toindex(
_CAPI_DGLGraphHasEdgesBetween(self, u_array, v_array)
)
def predecessors(self, v, radius=1): def predecessors(self, v, radius=1):
"""Return the predecessors of the node. """Return the predecessors of the node.
...@@ -257,8 +266,9 @@ class GraphIndex(ObjectBase): ...@@ -257,8 +266,9 @@ class GraphIndex(ObjectBase):
utils.Index utils.Index
Array of predecessors Array of predecessors
""" """
return utils.toindex(_CAPI_DGLGraphPredecessors( return utils.toindex(
self, int(v), int(radius))) _CAPI_DGLGraphPredecessors(self, int(v), int(radius))
)
def successors(self, v, radius=1): def successors(self, v, radius=1):
"""Return the successors of the node. """Return the successors of the node.
...@@ -275,8 +285,9 @@ class GraphIndex(ObjectBase): ...@@ -275,8 +285,9 @@ class GraphIndex(ObjectBase):
utils.Index utils.Index
Array of successors Array of successors
""" """
return utils.toindex(_CAPI_DGLGraphSuccessors( return utils.toindex(
self, int(v), int(radius))) _CAPI_DGLGraphSuccessors(self, int(v), int(radius))
)
def edge_id(self, u, v): def edge_id(self, u, v):
"""Return the id array of all edges between u and v. """Return the id array of all edges between u and v.
...@@ -432,7 +443,7 @@ class GraphIndex(ObjectBase): ...@@ -432,7 +443,7 @@ class GraphIndex(ObjectBase):
""" """
_CAPI_DGLSortAdj(self) _CAPI_DGLSortAdj(self)
@utils.cached_member(cache='_cache', prefix='edges') @utils.cached_member(cache="_cache", prefix="edges")
def edges(self, order=None): def edges(self, order=None):
"""Return all the edges """Return all the edges
...@@ -606,7 +617,7 @@ class GraphIndex(ObjectBase): ...@@ -606,7 +617,7 @@ class GraphIndex(ObjectBase):
e_array = e.todgltensor() e_array = e.todgltensor()
return _CAPI_DGLGraphEdgeSubgraph(self, e_array, preserve_nodes) return _CAPI_DGLGraphEdgeSubgraph(self, e_array, preserve_nodes)
@utils.cached_member(cache='_cache', prefix='scipy_adj') @utils.cached_member(cache="_cache", prefix="scipy_adj")
def adjacency_matrix_scipy(self, transpose, fmt, return_edge_ids=None): def adjacency_matrix_scipy(self, transpose, fmt, return_edge_ids=None):
"""Return the scipy adjacency matrix representation of this graph. """Return the scipy adjacency matrix representation of this graph.
...@@ -631,8 +642,10 @@ class GraphIndex(ObjectBase): ...@@ -631,8 +642,10 @@ class GraphIndex(ObjectBase):
The scipy representation of adjacency matrix. The scipy representation of adjacency matrix.
""" """
if not isinstance(transpose, bool): if not isinstance(transpose, bool):
raise DGLError('Expect bool value for "transpose" arg,' raise DGLError(
' but got %s.' % (type(transpose))) 'Expect bool value for "transpose" arg,'
" but got %s." % (type(transpose))
)
if return_edge_ids is None: if return_edge_ids is None:
dgl_warning( dgl_warning(
...@@ -640,17 +653,24 @@ class GraphIndex(ObjectBase): ...@@ -640,17 +653,24 @@ class GraphIndex(ObjectBase):
" As a result there is one 0 entry which is not eliminated." " As a result there is one 0 entry which is not eliminated."
" In the next release it will return 1s by default," " In the next release it will return 1s by default,"
" and 0 will be eliminated otherwise.", " and 0 will be eliminated otherwise.",
FutureWarning) FutureWarning,
)
return_edge_ids = True return_edge_ids = True
rst = _CAPI_DGLGraphGetAdj(self, transpose, fmt) rst = _CAPI_DGLGraphGetAdj(self, transpose, fmt)
if fmt == "csr": if fmt == "csr":
indptr = utils.toindex(rst(0)).tonumpy() indptr = utils.toindex(rst(0)).tonumpy()
indices = utils.toindex(rst(1)).tonumpy() indices = utils.toindex(rst(1)).tonumpy()
data = utils.toindex(rst(2)).tonumpy() if return_edge_ids else np.ones_like(indices) data = (
utils.toindex(rst(2)).tonumpy()
if return_edge_ids
else np.ones_like(indices)
)
n = self.number_of_nodes() n = self.number_of_nodes()
return scipy.sparse.csr_matrix((data, indices, indptr), shape=(n, n)) return scipy.sparse.csr_matrix(
elif fmt == 'coo': (data, indices, indptr), shape=(n, n)
)
elif fmt == "coo":
idx = utils.toindex(rst(0)).tonumpy() idx = utils.toindex(rst(0)).tonumpy()
n = self.number_of_nodes() n = self.number_of_nodes()
m = self.number_of_edges() m = self.number_of_edges()
...@@ -660,7 +680,7 @@ class GraphIndex(ObjectBase): ...@@ -660,7 +680,7 @@ class GraphIndex(ObjectBase):
else: else:
raise Exception("unknown format") raise Exception("unknown format")
@utils.cached_member(cache='_cache', prefix='immu_gidx') @utils.cached_member(cache="_cache", prefix="immu_gidx")
def get_immutable_gidx(self, ctx): def get_immutable_gidx(self, ctx):
"""Create an immutable graph index and copy to the given device context. """Create an immutable graph index and copy to the given device context.
...@@ -717,8 +737,10 @@ class GraphIndex(ObjectBase): ...@@ -717,8 +737,10 @@ class GraphIndex(ObjectBase):
if shuffle is not required. if shuffle is not required.
""" """
if not isinstance(transpose, bool): if not isinstance(transpose, bool):
raise DGLError('Expect bool value for "transpose" arg,' raise DGLError(
' but got %s.' % (type(transpose))) 'Expect bool value for "transpose" arg,'
" but got %s." % (type(transpose))
)
fmt = F.get_preferred_sparse_format() fmt = F.get_preferred_sparse_format()
rst = _CAPI_DGLGraphGetAdj(self, transpose, fmt) rst = _CAPI_DGLGraphGetAdj(self, transpose, fmt)
if fmt == "csr": if fmt == "csr":
...@@ -726,8 +748,11 @@ class GraphIndex(ObjectBase): ...@@ -726,8 +748,11 @@ class GraphIndex(ObjectBase):
indices = F.copy_to(utils.toindex(rst(1)).tousertensor(), ctx) indices = F.copy_to(utils.toindex(rst(1)).tousertensor(), ctx)
shuffle = utils.toindex(rst(2)) shuffle = utils.toindex(rst(2))
dat = F.ones(indices.shape, dtype=F.float32, ctx=ctx) dat = F.ones(indices.shape, dtype=F.float32, ctx=ctx)
spmat = F.sparse_matrix(dat, ('csr', indices, indptr), spmat = F.sparse_matrix(
(self.number_of_nodes(), self.number_of_nodes()))[0] dat,
("csr", indices, indptr),
(self.number_of_nodes(), self.number_of_nodes()),
)[0]
return spmat, shuffle return spmat, shuffle
elif fmt == "coo": elif fmt == "coo":
## FIXME(minjie): data type ## FIXME(minjie): data type
...@@ -736,8 +761,10 @@ class GraphIndex(ObjectBase): ...@@ -736,8 +761,10 @@ class GraphIndex(ObjectBase):
idx = F.reshape(idx, (2, m)) idx = F.reshape(idx, (2, m))
dat = F.ones((m,), dtype=F.float32, ctx=ctx) dat = F.ones((m,), dtype=F.float32, ctx=ctx)
n = self.number_of_nodes() n = self.number_of_nodes()
adj, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, n)) adj, shuffle_idx = F.sparse_matrix(dat, ("coo", idx), (n, n))
shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None shuffle_idx = (
utils.toindex(shuffle_idx) if shuffle_idx is not None else None
)
return adj, shuffle_idx return adj, shuffle_idx
else: else:
raise Exception("unknown format") raise Exception("unknown format")
...@@ -783,21 +810,21 @@ class GraphIndex(ObjectBase): ...@@ -783,21 +810,21 @@ class GraphIndex(ObjectBase):
eid = eid.tousertensor(ctx) # the index of the ctx will be cached eid = eid.tousertensor(ctx) # the index of the ctx will be cached
n = self.number_of_nodes() n = self.number_of_nodes()
m = self.number_of_edges() m = self.number_of_edges()
if typestr == 'in': if typestr == "in":
row = F.unsqueeze(dst, 0) row = F.unsqueeze(dst, 0)
col = F.unsqueeze(eid, 0) col = F.unsqueeze(eid, 0)
idx = F.cat([row, col], dim=0) idx = F.cat([row, col], dim=0)
# FIXME(minjie): data type # FIXME(minjie): data type
dat = F.ones((m,), dtype=F.float32, ctx=ctx) dat = F.ones((m,), dtype=F.float32, ctx=ctx)
inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m)) inc, shuffle_idx = F.sparse_matrix(dat, ("coo", idx), (n, m))
elif typestr == 'out': elif typestr == "out":
row = F.unsqueeze(src, 0) row = F.unsqueeze(src, 0)
col = F.unsqueeze(eid, 0) col = F.unsqueeze(eid, 0)
idx = F.cat([row, col], dim=0) idx = F.cat([row, col], dim=0)
# FIXME(minjie): data type # FIXME(minjie): data type
dat = F.ones((m,), dtype=F.float32, ctx=ctx) dat = F.ones((m,), dtype=F.float32, ctx=ctx)
inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m)) inc, shuffle_idx = F.sparse_matrix(dat, ("coo", idx), (n, m))
elif typestr == 'both': elif typestr == "both":
# first remove entries for self loops # first remove entries for self loops
mask = F.logical_not(F.equal(src, dst)) mask = F.logical_not(F.equal(src, dst))
src = F.boolean_mask(src, mask) src = F.boolean_mask(src, mask)
...@@ -812,10 +839,12 @@ class GraphIndex(ObjectBase): ...@@ -812,10 +839,12 @@ class GraphIndex(ObjectBase):
x = -F.ones((n_entries,), dtype=F.float32, ctx=ctx) x = -F.ones((n_entries,), dtype=F.float32, ctx=ctx)
y = F.ones((n_entries,), dtype=F.float32, ctx=ctx) y = F.ones((n_entries,), dtype=F.float32, ctx=ctx)
dat = F.cat([x, y], dim=0) dat = F.cat([x, y], dim=0)
inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m)) inc, shuffle_idx = F.sparse_matrix(dat, ("coo", idx), (n, m))
else: else:
raise DGLError('Invalid incidence matrix type: %s' % str(typestr)) raise DGLError("Invalid incidence matrix type: %s" % str(typestr))
shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None shuffle_idx = (
utils.toindex(shuffle_idx) if shuffle_idx is not None else None
)
return inc, shuffle_idx return inc, shuffle_idx
def to_networkx(self): def to_networkx(self):
...@@ -902,7 +931,9 @@ class GraphIndex(ObjectBase): ...@@ -902,7 +931,9 @@ class GraphIndex(ObjectBase):
GraphIndex GraphIndex
The graph index on the given device context. The graph index on the given device context.
""" """
return _CAPI_DGLImmutableGraphCopyTo(self, ctx.device_type, ctx.device_id) return _CAPI_DGLImmutableGraphCopyTo(
self, ctx.device_type, ctx.device_id
)
def copyto_shared_mem(self, shared_mem_name): def copyto_shared_mem(self, shared_mem_name):
"""Copy this immutable graph index to shared memory. """Copy this immutable graph index to shared memory.
...@@ -939,7 +970,10 @@ class GraphIndex(ObjectBase): ...@@ -939,7 +970,10 @@ class GraphIndex(ObjectBase):
int int
The number of bits needed The number of bits needed
""" """
if self.number_of_edges() >= 0x80000000 or self.number_of_nodes() >= 0x80000000: if (
self.number_of_edges() >= 0x80000000
or self.number_of_nodes() >= 0x80000000
):
return 64 return 64
else: else:
return 32 return 32
...@@ -961,9 +995,11 @@ class GraphIndex(ObjectBase): ...@@ -961,9 +995,11 @@ class GraphIndex(ObjectBase):
""" """
return _CAPI_DGLImmutableGraphAsNumBits(self, int(bits)) return _CAPI_DGLImmutableGraphAsNumBits(self, int(bits))
@register_object('graph.Subgraph')
@register_object("graph.Subgraph")
class SubgraphIndex(ObjectBase): class SubgraphIndex(ObjectBase):
"""Subgraph data structure""" """Subgraph data structure"""
@property @property
def graph(self): def graph(self):
"""The subgraph structure """The subgraph structure
...@@ -1028,16 +1064,15 @@ def from_coo(num_nodes, src, dst, readonly): ...@@ -1028,16 +1064,15 @@ def from_coo(num_nodes, src, dst, readonly):
dst = utils.toindex(dst) dst = utils.toindex(dst)
if readonly: if readonly:
gidx = _CAPI_DGLGraphCreate( gidx = _CAPI_DGLGraphCreate(
src.todgltensor(), src.todgltensor(), dst.todgltensor(), int(num_nodes), readonly
dst.todgltensor(), )
int(num_nodes),
readonly)
else: else:
gidx = _CAPI_DGLGraphCreateMutable() gidx = _CAPI_DGLGraphCreateMutable()
gidx.add_nodes(num_nodes) gidx.add_nodes(num_nodes)
gidx.add_edges(src, dst) gidx.add_edges(src, dst)
return gidx return gidx
def from_csr(indptr, indices, direction): def from_csr(indptr, indices, direction):
"""Load a graph from CSR arrays. """Load a graph from CSR arrays.
...@@ -1058,11 +1093,11 @@ def from_csr(indptr, indices, direction): ...@@ -1058,11 +1093,11 @@ def from_csr(indptr, indices, direction):
indptr = utils.toindex(indptr) indptr = utils.toindex(indptr)
indices = utils.toindex(indices) indices = utils.toindex(indices)
gidx = _CAPI_DGLGraphCSRCreate( gidx = _CAPI_DGLGraphCSRCreate(
indptr.todgltensor(), indptr.todgltensor(), indices.todgltensor(), direction
indices.todgltensor(), )
direction)
return gidx return gidx
def from_shared_mem_graph_index(shared_mem_name): def from_shared_mem_graph_index(shared_mem_name):
"""Load a graph index from the shared memory. """Load a graph index from the shared memory.
...@@ -1078,6 +1113,7 @@ def from_shared_mem_graph_index(shared_mem_name): ...@@ -1078,6 +1113,7 @@ def from_shared_mem_graph_index(shared_mem_name):
""" """
return _CAPI_DGLGraphCSRCreateMMap(shared_mem_name) return _CAPI_DGLGraphCSRCreateMMap(shared_mem_name)
def from_networkx(nx_graph, readonly): def from_networkx(nx_graph, readonly):
"""Convert from networkx graph. """Convert from networkx graph.
...@@ -1107,7 +1143,7 @@ def from_networkx(nx_graph, readonly): ...@@ -1107,7 +1143,7 @@ def from_networkx(nx_graph, readonly):
# nx_graph.edges(data=True) returns src, dst, attr_dict # nx_graph.edges(data=True) returns src, dst, attr_dict
if nx_graph.number_of_edges() > 0: if nx_graph.number_of_edges() > 0:
has_edge_id = 'id' in next(iter(nx_graph.edges(data=True)))[-1] has_edge_id = "id" in next(iter(nx_graph.edges(data=True)))[-1]
else: else:
has_edge_id = False has_edge_id = False
...@@ -1116,7 +1152,7 @@ def from_networkx(nx_graph, readonly): ...@@ -1116,7 +1152,7 @@ def from_networkx(nx_graph, readonly):
src = np.zeros((num_edges,), dtype=np.int64) src = np.zeros((num_edges,), dtype=np.int64)
dst = np.zeros((num_edges,), dtype=np.int64) dst = np.zeros((num_edges,), dtype=np.int64)
for u, v, attr in nx_graph.edges(data=True): for u, v, attr in nx_graph.edges(data=True):
eid = attr['id'] eid = attr["id"]
src[eid] = u src[eid] = u
dst[eid] = v dst[eid] = v
else: else:
...@@ -1131,6 +1167,7 @@ def from_networkx(nx_graph, readonly): ...@@ -1131,6 +1167,7 @@ def from_networkx(nx_graph, readonly):
dst = utils.toindex(dst) dst = utils.toindex(dst)
return from_coo(num_nodes, src, dst, readonly) return from_coo(num_nodes, src, dst, readonly)
def from_scipy_sparse_matrix(adj, readonly): def from_scipy_sparse_matrix(adj, readonly):
"""Convert from scipy sparse matrix. """Convert from scipy sparse matrix.
...@@ -1145,7 +1182,7 @@ def from_scipy_sparse_matrix(adj, readonly): ...@@ -1145,7 +1182,7 @@ def from_scipy_sparse_matrix(adj, readonly):
GraphIndex GraphIndex
The graph index. The graph index.
""" """
if adj.getformat() != 'csr' or not readonly: if adj.getformat() != "csr" or not readonly:
num_nodes = max(adj.shape[0], adj.shape[1]) num_nodes = max(adj.shape[0], adj.shape[1])
adj_coo = adj.tocoo() adj_coo = adj.tocoo()
return from_coo(num_nodes, adj_coo.row, adj_coo.col, readonly) return from_coo(num_nodes, adj_coo.row, adj_coo.col, readonly)
...@@ -1153,6 +1190,7 @@ def from_scipy_sparse_matrix(adj, readonly): ...@@ -1153,6 +1190,7 @@ def from_scipy_sparse_matrix(adj, readonly):
# If the input matrix is csr, we still treat it as multigraph. # If the input matrix is csr, we still treat it as multigraph.
return from_csr(adj.indptr, adj.indices, "out") return from_csr(adj.indptr, adj.indices, "out")
def from_edge_list(elist, readonly): def from_edge_list(elist, readonly):
"""Convert from an edge list. """Convert from an edge list.
...@@ -1172,6 +1210,7 @@ def from_edge_list(elist, readonly): ...@@ -1172,6 +1210,7 @@ def from_edge_list(elist, readonly):
num_nodes = max(src.max(), dst.max()) + 1 num_nodes = max(src.max(), dst.max()) + 1
return from_coo(num_nodes, src_ids, dst_ids, readonly) return from_coo(num_nodes, src_ids, dst_ids, readonly)
def map_to_subgraph_nid(induced_nodes, parent_nids): def map_to_subgraph_nid(induced_nodes, parent_nids):
"""Map parent node Ids to the subgraph node Ids. """Map parent node Ids to the subgraph node Ids.
...@@ -1188,8 +1227,12 @@ def map_to_subgraph_nid(induced_nodes, parent_nids): ...@@ -1188,8 +1227,12 @@ def map_to_subgraph_nid(induced_nodes, parent_nids):
utils.Index utils.Index
Node Ids in the subgraph. Node Ids in the subgraph.
""" """
return utils.toindex(_CAPI_DGLMapSubgraphNID(induced_nodes.todgltensor(), return utils.toindex(
parent_nids.todgltensor())) _CAPI_DGLMapSubgraphNID(
induced_nodes.todgltensor(), parent_nids.todgltensor()
)
)
def transform_ids(mapping, ids): def transform_ids(mapping, ids):
"""Transform ids by the given mapping. """Transform ids by the given mapping.
...@@ -1206,8 +1249,10 @@ def transform_ids(mapping, ids): ...@@ -1206,8 +1249,10 @@ def transform_ids(mapping, ids):
utils.Index utils.Index
The new ids. The new ids.
""" """
return utils.toindex(_CAPI_DGLMapSubgraphNID( return utils.toindex(
mapping.todgltensor(), ids.todgltensor())) _CAPI_DGLMapSubgraphNID(mapping.todgltensor(), ids.todgltensor())
)
def disjoint_union(graphs): def disjoint_union(graphs):
"""Return a disjoint union of the input graphs. """Return a disjoint union of the input graphs.
...@@ -1230,6 +1275,7 @@ def disjoint_union(graphs): ...@@ -1230,6 +1275,7 @@ def disjoint_union(graphs):
""" """
return _CAPI_DGLDisjointUnion(list(graphs)) return _CAPI_DGLDisjointUnion(list(graphs))
def disjoint_partition(graph, num_or_size_splits): def disjoint_partition(graph, num_or_size_splits):
"""Partition the graph disjointly. """Partition the graph disjointly.
...@@ -1252,14 +1298,13 @@ def disjoint_partition(graph, num_or_size_splits): ...@@ -1252,14 +1298,13 @@ def disjoint_partition(graph, num_or_size_splits):
""" """
if isinstance(num_or_size_splits, utils.Index): if isinstance(num_or_size_splits, utils.Index):
rst = _CAPI_DGLDisjointPartitionBySizes( rst = _CAPI_DGLDisjointPartitionBySizes(
graph, graph, num_or_size_splits.todgltensor()
num_or_size_splits.todgltensor()) )
else: else:
rst = _CAPI_DGLDisjointPartitionByNum( rst = _CAPI_DGLDisjointPartitionByNum(graph, int(num_or_size_splits))
graph,
int(num_or_size_splits))
return rst return rst
def create_graph_index(graph_data, readonly): def create_graph_index(graph_data, readonly):
"""Create a graph index object. """Create a graph index object.
...@@ -1289,11 +1334,15 @@ def create_graph_index(graph_data, readonly): ...@@ -1289,11 +1334,15 @@ def create_graph_index(graph_data, readonly):
try: try:
gidx = from_networkx(graph_data, readonly) gidx = from_networkx(graph_data, readonly)
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
raise DGLError('Error while creating graph from input of type "%s".' raise DGLError(
% type(graph_data)) 'Error while creating graph from input of type "%s".'
% type(graph_data)
)
return gidx return gidx
def _get_halo_subgraph_inner_node(halo_subg): def _get_halo_subgraph_inner_node(halo_subg):
return _CAPI_GetHaloSubgraphInnerNodes(halo_subg) return _CAPI_GetHaloSubgraphInnerNodes(halo_subg)
_init_api("dgl.graph_index") _init_api("dgl.graph_index")
"""Module for heterogeneous graph index class definition.""" """Module for heterogeneous graph index class definition."""
from __future__ import absolute_import from __future__ import absolute_import
import sys
import itertools import itertools
import sys
import numpy as np import numpy as np
import scipy import scipy
from ._ffi.object import register_object, ObjectBase from . import backend as F
from . import utils
from ._ffi.function import _init_api from ._ffi.function import _init_api
from ._ffi.object import ObjectBase, register_object
from ._ffi.streams import to_dgl_stream_handle from ._ffi.streams import to_dgl_stream_handle
from .base import DGLError, dgl_warning from .base import DGLError, dgl_warning
from .graph_index import from_coo from .graph_index import from_coo
from . import backend as F
from . import utils
@register_object('graph.HeteroGraph')
@register_object("graph.HeteroGraph")
class HeteroGraphIndex(ObjectBase): class HeteroGraphIndex(ObjectBase):
"""HeteroGraph index object. """HeteroGraph index object.
...@@ -22,6 +24,7 @@ class HeteroGraphIndex(ObjectBase): ...@@ -22,6 +24,7 @@ class HeteroGraphIndex(ObjectBase):
---- ----
Do not create GraphIndex directly. Do not create GraphIndex directly.
""" """
def __new__(cls): def __new__(cls):
obj = ObjectBase.__new__(cls) obj = ObjectBase.__new__(cls)
obj._cache = {} obj._cache = {}
...@@ -55,11 +58,19 @@ class HeteroGraphIndex(ObjectBase): ...@@ -55,11 +58,19 @@ class HeteroGraphIndex(ObjectBase):
num_src = number_of_nodes[src_ntype] num_src = number_of_nodes[src_ntype]
num_dst = number_of_nodes[dst_ntype] num_dst = number_of_nodes[dst_ntype]
src_id, dst_id, _ = edges_per_type src_id, dst_id, _ = edges_per_type
rel_graphs.append(create_unitgraph_from_coo( rel_graphs.append(
1 if src_ntype == dst_ntype else 2, num_src, num_dst, src_id, dst_id, create_unitgraph_from_coo(
['coo', 'csr', ' csc'])) 1 if src_ntype == dst_ntype else 2,
num_src,
num_dst,
src_id,
dst_id,
["coo", "csr", " csc"],
)
)
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_CAPI_DGLHeteroCreateHeteroGraph, metagraph, rel_graphs) _CAPI_DGLHeteroCreateHeteroGraph, metagraph, rel_graphs
)
@property @property
def metagraph(self): def metagraph(self):
...@@ -155,7 +166,9 @@ class HeteroGraphIndex(ObjectBase): ...@@ -155,7 +166,9 @@ class HeteroGraphIndex(ObjectBase):
v : utils.Index v : utils.Index
The dst nodes. The dst nodes.
""" """
_CAPI_DGLHeteroAddEdges(self, int(etype), u.todgltensor(), v.todgltensor()) _CAPI_DGLHeteroAddEdges(
self, int(etype), u.todgltensor(), v.todgltensor()
)
self.clear_cache() self.clear_cache()
def clear(self): def clear(self):
...@@ -199,9 +212,11 @@ class HeteroGraphIndex(ObjectBase): ...@@ -199,9 +212,11 @@ class HeteroGraphIndex(ObjectBase):
The number of bits needed. The number of bits needed.
""" """
stype, dtype = self.metagraph.find_edge(etype) stype, dtype = self.metagraph.find_edge(etype)
if (self.number_of_edges(etype) >= 0x80000000 or if (
self.number_of_nodes(stype) >= 0x80000000 or self.number_of_edges(etype) >= 0x80000000
self.number_of_nodes(dtype) >= 0x80000000): or self.number_of_nodes(stype) >= 0x80000000
or self.number_of_nodes(dtype) >= 0x80000000
):
return 64 return 64
else: else:
return 32 return 32
...@@ -293,7 +308,9 @@ class HeteroGraphIndex(ObjectBase): ...@@ -293,7 +308,9 @@ class HeteroGraphIndex(ObjectBase):
""" """
return _CAPI_DGLHeteroRecordStream(self, to_dgl_stream_handle(stream)) return _CAPI_DGLHeteroRecordStream(self, to_dgl_stream_handle(stream))
def shared_memory(self, name, ntypes=None, etypes=None, formats=('coo', 'csr', 'csc')): def shared_memory(
self, name, ntypes=None, etypes=None, formats=("coo", "csr", "csc")
):
"""Return a copy of this graph in shared memory """Return a copy of this graph in shared memory
Parameters Parameters
...@@ -318,7 +335,9 @@ class HeteroGraphIndex(ObjectBase): ...@@ -318,7 +335,9 @@ class HeteroGraphIndex(ObjectBase):
assert fmt in ("coo", "csr", "csc") assert fmt in ("coo", "csr", "csc")
ntypes = [] if ntypes is None else ntypes ntypes = [] if ntypes is None else ntypes
etypes = [] if etypes is None else etypes etypes = [] if etypes is None else etypes
return _CAPI_DGLHeteroCopyToSharedMem(self, name, ntypes, etypes, formats) return _CAPI_DGLHeteroCopyToSharedMem(
self, name, ntypes, etypes, formats
)
def is_multigraph(self): def is_multigraph(self):
"""Return whether the graph is a multigraph """Return whether the graph is a multigraph
...@@ -386,8 +405,9 @@ class HeteroGraphIndex(ObjectBase): ...@@ -386,8 +405,9 @@ class HeteroGraphIndex(ObjectBase):
Tensor Tensor
0-1 array indicating existence 0-1 array indicating existence
""" """
return F.from_dgl_nd(_CAPI_DGLHeteroHasVertices( return F.from_dgl_nd(
self, int(ntype), F.to_dgl_nd(vids))) _CAPI_DGLHeteroHasVertices(self, int(ntype), F.to_dgl_nd(vids))
)
def has_edges_between(self, etype, u, v): def has_edges_between(self, etype, u, v):
"""Return true if the edge exists. """Return true if the edge exists.
...@@ -406,8 +426,11 @@ class HeteroGraphIndex(ObjectBase): ...@@ -406,8 +426,11 @@ class HeteroGraphIndex(ObjectBase):
Tensor Tensor
0-1 array indicating existence 0-1 array indicating existence
""" """
return F.from_dgl_nd(_CAPI_DGLHeteroHasEdgesBetween( return F.from_dgl_nd(
self, int(etype), F.to_dgl_nd(u), F.to_dgl_nd(v))) _CAPI_DGLHeteroHasEdgesBetween(
self, int(etype), F.to_dgl_nd(u), F.to_dgl_nd(v)
)
)
def predecessors(self, etype, v): def predecessors(self, etype, v):
"""Return the predecessors of the node. """Return the predecessors of the node.
...@@ -426,8 +449,9 @@ class HeteroGraphIndex(ObjectBase): ...@@ -426,8 +449,9 @@ class HeteroGraphIndex(ObjectBase):
Tensor Tensor
Array of predecessors Array of predecessors
""" """
return F.from_dgl_nd(_CAPI_DGLHeteroPredecessors( return F.from_dgl_nd(
self, int(etype), int(v))) _CAPI_DGLHeteroPredecessors(self, int(etype), int(v))
)
def successors(self, etype, v): def successors(self, etype, v):
"""Return the successors of the node. """Return the successors of the node.
...@@ -446,8 +470,9 @@ class HeteroGraphIndex(ObjectBase): ...@@ -446,8 +470,9 @@ class HeteroGraphIndex(ObjectBase):
Tensor Tensor
Array of successors Array of successors
""" """
return F.from_dgl_nd(_CAPI_DGLHeteroSuccessors( return F.from_dgl_nd(
self, int(etype), int(v))) _CAPI_DGLHeteroSuccessors(self, int(etype), int(v))
)
def edge_ids_all(self, etype, u, v): def edge_ids_all(self, etype, u, v):
"""Return a triplet of arrays that contains the edge IDs. """Return a triplet of arrays that contains the edge IDs.
...@@ -471,7 +496,8 @@ class HeteroGraphIndex(ObjectBase): ...@@ -471,7 +496,8 @@ class HeteroGraphIndex(ObjectBase):
The edge ids. The edge ids.
""" """
edge_array = _CAPI_DGLHeteroEdgeIdsAll( edge_array = _CAPI_DGLHeteroEdgeIdsAll(
self, int(etype), F.to_dgl_nd(u), F.to_dgl_nd(v)) self, int(etype), F.to_dgl_nd(u), F.to_dgl_nd(v)
)
src = F.from_dgl_nd(edge_array(0)) src = F.from_dgl_nd(edge_array(0))
dst = F.from_dgl_nd(edge_array(1)) dst = F.from_dgl_nd(edge_array(1))
...@@ -496,8 +522,11 @@ class HeteroGraphIndex(ObjectBase): ...@@ -496,8 +522,11 @@ class HeteroGraphIndex(ObjectBase):
Tensor Tensor
The edge ids. The edge ids.
""" """
eid = F.from_dgl_nd(_CAPI_DGLHeteroEdgeIdsOne( eid = F.from_dgl_nd(
self, int(etype), F.to_dgl_nd(u), F.to_dgl_nd(v))) _CAPI_DGLHeteroEdgeIdsOne(
self, int(etype), F.to_dgl_nd(u), F.to_dgl_nd(v)
)
)
return eid return eid
def find_edges(self, etype, eid): def find_edges(self, etype, eid):
...@@ -520,7 +549,8 @@ class HeteroGraphIndex(ObjectBase): ...@@ -520,7 +549,8 @@ class HeteroGraphIndex(ObjectBase):
The edge ids. The edge ids.
""" """
edge_array = _CAPI_DGLHeteroFindEdges( edge_array = _CAPI_DGLHeteroFindEdges(
self, int(etype), F.to_dgl_nd(eid)) self, int(etype), F.to_dgl_nd(eid)
)
src = F.from_dgl_nd(edge_array(0)) src = F.from_dgl_nd(edge_array(0))
dst = F.from_dgl_nd(edge_array(1)) dst = F.from_dgl_nd(edge_array(1))
...@@ -607,9 +637,11 @@ class HeteroGraphIndex(ObjectBase): ...@@ -607,9 +637,11 @@ class HeteroGraphIndex(ObjectBase):
""" """
if order is None: if order is None:
order = "" order = ""
elif order not in ['srcdst', 'eid']: elif order not in ["srcdst", "eid"]:
raise DGLError("Expect order to be one of None, 'srcdst', 'eid', " raise DGLError(
"got {}".format(order)) "Expect order to be one of None, 'srcdst', 'eid', "
"got {}".format(order)
)
edge_array = _CAPI_DGLHeteroEdges(self, int(etype), order) edge_array = _CAPI_DGLHeteroEdges(self, int(etype), order)
src = F.from_dgl_nd(edge_array(0)) src = F.from_dgl_nd(edge_array(0))
dst = F.from_dgl_nd(edge_array(1)) dst = F.from_dgl_nd(edge_array(1))
...@@ -633,8 +665,9 @@ class HeteroGraphIndex(ObjectBase): ...@@ -633,8 +665,9 @@ class HeteroGraphIndex(ObjectBase):
Tensor Tensor
The in degree array. The in degree array.
""" """
return F.from_dgl_nd(_CAPI_DGLHeteroInDegrees( return F.from_dgl_nd(
self, int(etype), F.to_dgl_nd(v))) _CAPI_DGLHeteroInDegrees(self, int(etype), F.to_dgl_nd(v))
)
def out_degrees(self, etype, v): def out_degrees(self, etype, v):
"""Return the out degrees of the nodes. """Return the out degrees of the nodes.
...@@ -653,8 +686,9 @@ class HeteroGraphIndex(ObjectBase): ...@@ -653,8 +686,9 @@ class HeteroGraphIndex(ObjectBase):
Tensor Tensor
The out degree array. The out degree array.
""" """
return F.from_dgl_nd(_CAPI_DGLHeteroOutDegrees( return F.from_dgl_nd(
self, int(etype), F.to_dgl_nd(v))) _CAPI_DGLHeteroOutDegrees(self, int(etype), F.to_dgl_nd(v))
)
def adjacency_matrix(self, etype, transpose, ctx): def adjacency_matrix(self, etype, transpose, ctx):
"""Return the adjacency matrix representation of this graph. """Return the adjacency matrix representation of this graph.
...@@ -683,28 +717,43 @@ class HeteroGraphIndex(ObjectBase): ...@@ -683,28 +717,43 @@ class HeteroGraphIndex(ObjectBase):
if shuffle is not required. if shuffle is not required.
""" """
if not isinstance(transpose, bool): if not isinstance(transpose, bool):
raise DGLError('Expect bool value for "transpose" arg,' raise DGLError(
' but got %s.' % (type(transpose))) 'Expect bool value for "transpose" arg,'
" but got %s." % (type(transpose))
)
fmt = F.get_preferred_sparse_format() fmt = F.get_preferred_sparse_format()
rst = _CAPI_DGLHeteroGetAdj(self, int(etype), transpose, fmt) rst = _CAPI_DGLHeteroGetAdj(self, int(etype), transpose, fmt)
# convert to framework-specific sparse matrix # convert to framework-specific sparse matrix
srctype, dsttype = self.metagraph.find_edge(etype) srctype, dsttype = self.metagraph.find_edge(etype)
nrows = self.number_of_nodes(dsttype) if transpose else self.number_of_nodes(srctype) nrows = (
ncols = self.number_of_nodes(srctype) if transpose else self.number_of_nodes(dsttype) self.number_of_nodes(dsttype)
if transpose
else self.number_of_nodes(srctype)
)
ncols = (
self.number_of_nodes(srctype)
if transpose
else self.number_of_nodes(dsttype)
)
nnz = self.number_of_edges(etype) nnz = self.number_of_edges(etype)
if fmt == "csr": if fmt == "csr":
indptr = F.copy_to(F.from_dgl_nd(rst(0)), ctx) indptr = F.copy_to(F.from_dgl_nd(rst(0)), ctx)
indices = F.copy_to(F.from_dgl_nd(rst(1)), ctx) indices = F.copy_to(F.from_dgl_nd(rst(1)), ctx)
shuffle = F.copy_to(F.from_dgl_nd(rst(2)), ctx) shuffle = F.copy_to(F.from_dgl_nd(rst(2)), ctx)
dat = F.ones(nnz, dtype=F.float32, ctx=ctx) # FIXME(minjie): data type dat = F.ones(
spmat = F.sparse_matrix(dat, ('csr', indices, indptr), (nrows, ncols))[0] nnz, dtype=F.float32, ctx=ctx
) # FIXME(minjie): data type
spmat = F.sparse_matrix(
dat, ("csr", indices, indptr), (nrows, ncols)
)[0]
return spmat, shuffle return spmat, shuffle
elif fmt == "coo": elif fmt == "coo":
idx = F.copy_to(F.from_dgl_nd(rst(0)), ctx) idx = F.copy_to(F.from_dgl_nd(rst(0)), ctx)
idx = F.reshape(idx, (2, nnz)) idx = F.reshape(idx, (2, nnz))
dat = F.ones((nnz,), dtype=F.float32, ctx=ctx) dat = F.ones((nnz,), dtype=F.float32, ctx=ctx)
adj, shuffle_idx = F.sparse_matrix( adj, shuffle_idx = F.sparse_matrix(
dat, ('coo', idx), (nrows, ncols)) dat, ("coo", idx), (nrows, ncols)
)
return adj, shuffle_idx return adj, shuffle_idx
else: else:
raise Exception("unknown format") raise Exception("unknown format")
...@@ -743,27 +792,39 @@ class HeteroGraphIndex(ObjectBase): ...@@ -743,27 +792,39 @@ class HeteroGraphIndex(ObjectBase):
equivalent to a consecutive array from zero to the number of edges minus one. equivalent to a consecutive array from zero to the number of edges minus one.
""" """
if not isinstance(transpose, bool): if not isinstance(transpose, bool):
raise DGLError('Expect bool value for "transpose" arg,' raise DGLError(
' but got %s.' % (type(transpose))) 'Expect bool value for "transpose" arg,'
" but got %s." % (type(transpose))
)
rst = _CAPI_DGLHeteroGetAdj(self, int(etype), transpose, fmt) rst = _CAPI_DGLHeteroGetAdj(self, int(etype), transpose, fmt)
srctype, dsttype = self.metagraph.find_edge(etype) srctype, dsttype = self.metagraph.find_edge(etype)
nrows = self.number_of_nodes(dsttype) if transpose else self.number_of_nodes(srctype) nrows = (
ncols = self.number_of_nodes(srctype) if transpose else self.number_of_nodes(dsttype) self.number_of_nodes(dsttype)
if transpose
else self.number_of_nodes(srctype)
)
ncols = (
self.number_of_nodes(srctype)
if transpose
else self.number_of_nodes(dsttype)
)
nnz = self.number_of_edges(etype) nnz = self.number_of_edges(etype)
if fmt == "csr": if fmt == "csr":
indptr = F.from_dgl_nd(rst(0)) indptr = F.from_dgl_nd(rst(0))
indices = F.from_dgl_nd(rst(1)) indices = F.from_dgl_nd(rst(1))
data = F.from_dgl_nd(rst(2)) data = F.from_dgl_nd(rst(2))
return nrows, ncols, indptr, indices, data return nrows, ncols, indptr, indices, data
elif fmt == 'coo': elif fmt == "coo":
idx = F.from_dgl_nd(rst(0)) idx = F.from_dgl_nd(rst(0))
row, col = F.reshape(idx, (2, nnz)) row, col = F.reshape(idx, (2, nnz))
return nrows, ncols, row, col return nrows, ncols, row, col
else: else:
raise ValueError("unknown format") raise ValueError("unknown format")
def adjacency_matrix_scipy(self, etype, transpose, fmt, return_edge_ids=None): def adjacency_matrix_scipy(
self, etype, transpose, fmt, return_edge_ids=None
):
"""Return the scipy adjacency matrix representation of this graph. """Return the scipy adjacency matrix representation of this graph.
By default, a row of returned adjacency matrix represents the destination By default, a row of returned adjacency matrix represents the destination
...@@ -794,12 +855,14 @@ class HeteroGraphIndex(ObjectBase): ...@@ -794,12 +855,14 @@ class HeteroGraphIndex(ObjectBase):
" As a result there is one 0 entry which is not eliminated." " As a result there is one 0 entry which is not eliminated."
" In the next release it will return 1s by default," " In the next release it will return 1s by default,"
" and 0 will be eliminated otherwise.", " and 0 will be eliminated otherwise.",
FutureWarning) FutureWarning,
)
return_edge_ids = True return_edge_ids = True
if fmt == 'csr': if fmt == "csr":
nrows, ncols, indptr, indices, data = \ nrows, ncols, indptr, indices, data = self.adjacency_matrix_tensors(
self.adjacency_matrix_tensors(etype, transpose, fmt) etype, transpose, fmt
)
indptr = F.asnumpy(indptr) indptr = F.asnumpy(indptr)
indices = F.asnumpy(indices) indices = F.asnumpy(indices)
data = F.asnumpy(data) data = F.asnumpy(data)
...@@ -810,15 +873,23 @@ class HeteroGraphIndex(ObjectBase): ...@@ -810,15 +873,23 @@ class HeteroGraphIndex(ObjectBase):
else: else:
data = np.ones_like(indices) data = np.ones_like(indices)
return scipy.sparse.csr_matrix((data, indices, indptr), shape=(nrows, ncols)) return scipy.sparse.csr_matrix(
elif fmt == 'coo': (data, indices, indptr), shape=(nrows, ncols)
nrows, ncols, row, col = \ )
self.adjacency_matrix_tensors(etype, transpose, fmt) elif fmt == "coo":
nrows, ncols, row, col = self.adjacency_matrix_tensors(
etype, transpose, fmt
)
row = F.asnumpy(row) row = F.asnumpy(row)
col = F.asnumpy(col) col = F.asnumpy(col)
data = np.arange(self.number_of_edges(etype)) if return_edge_ids \ data = (
np.arange(self.number_of_edges(etype))
if return_edge_ids
else np.ones_like(row) else np.ones_like(row)
return scipy.sparse.coo_matrix((data, (row, col)), shape=(nrows, ncols)) )
return scipy.sparse.coo_matrix(
(data, (row, col)), shape=(nrows, ncols)
)
else: else:
raise ValueError("unknown format") raise ValueError("unknown format")
...@@ -863,25 +934,26 @@ class HeteroGraphIndex(ObjectBase): ...@@ -863,25 +934,26 @@ class HeteroGraphIndex(ObjectBase):
srctype, dsttype = self.metagraph.find_edge(etype) srctype, dsttype = self.metagraph.find_edge(etype)
m = self.number_of_edges(etype) m = self.number_of_edges(etype)
if typestr == 'in': if typestr == "in":
n = self.number_of_nodes(dsttype) n = self.number_of_nodes(dsttype)
row = F.unsqueeze(dst, 0) row = F.unsqueeze(dst, 0)
col = F.unsqueeze(eid, 0) col = F.unsqueeze(eid, 0)
idx = F.copy_to(F.cat([row, col], dim=0), ctx) idx = F.copy_to(F.cat([row, col], dim=0), ctx)
# FIXME(minjie): data type # FIXME(minjie): data type
dat = F.ones((m,), dtype=F.float32, ctx=ctx) dat = F.ones((m,), dtype=F.float32, ctx=ctx)
inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m)) inc, shuffle_idx = F.sparse_matrix(dat, ("coo", idx), (n, m))
elif typestr == 'out': elif typestr == "out":
n = self.number_of_nodes(srctype) n = self.number_of_nodes(srctype)
row = F.unsqueeze(src, 0) row = F.unsqueeze(src, 0)
col = F.unsqueeze(eid, 0) col = F.unsqueeze(eid, 0)
idx = F.copy_to(F.cat([row, col], dim=0), ctx) idx = F.copy_to(F.cat([row, col], dim=0), ctx)
# FIXME(minjie): data type # FIXME(minjie): data type
dat = F.ones((m,), dtype=F.float32, ctx=ctx) dat = F.ones((m,), dtype=F.float32, ctx=ctx)
inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m)) inc, shuffle_idx = F.sparse_matrix(dat, ("coo", idx), (n, m))
elif typestr == 'both': elif typestr == "both":
assert srctype == dsttype, \ assert (
"'both' is supported only if source and destination type are the same" srctype == dsttype
), "'both' is supported only if source and destination type are the same"
n = self.number_of_nodes(srctype) n = self.number_of_nodes(srctype)
# first remove entries for self loops # first remove entries for self loops
mask = F.logical_not(F.equal(src, dst)) mask = F.logical_not(F.equal(src, dst))
...@@ -897,9 +969,9 @@ class HeteroGraphIndex(ObjectBase): ...@@ -897,9 +969,9 @@ class HeteroGraphIndex(ObjectBase):
x = -F.ones((n_entries,), dtype=F.float32, ctx=ctx) x = -F.ones((n_entries,), dtype=F.float32, ctx=ctx)
y = F.ones((n_entries,), dtype=F.float32, ctx=ctx) y = F.ones((n_entries,), dtype=F.float32, ctx=ctx)
dat = F.cat([x, y], dim=0) dat = F.cat([x, y], dim=0)
inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m)) inc, shuffle_idx = F.sparse_matrix(dat, ("coo", idx), (n, m))
else: else:
raise DGLError('Invalid incidence matrix type: %s' % str(typestr)) raise DGLError("Invalid incidence matrix type: %s" % str(typestr))
return inc, shuffle_idx return inc, shuffle_idx
    def node_subgraph(self, induced_nodes, relabel_nodes):
@@ -982,7 +1054,9 @@ class HeteroGraphIndex(ObjectBase):
        order = csr(2)
        rev_csr = _CAPI_DGLHeteroGetAdj(self, int(etype), True, "csr")
        rev_order = rev_csr(2)
        return utils.toindex(order, self.dtype), utils.toindex(
            rev_order, self.dtype
        )

    def formats(self, formats=None):
        """Get a graph index with the specified sparse format(s) or query
@@ -1014,16 +1088,13 @@ class HeteroGraphIndex(ObjectBase):
        created = []
        not_created = []
        if formats is None:
            for fmt in ["coo", "csr", "csc"]:
                if fmt in formats_allowed:
                    if fmt in formats_created:
                        created.append(fmt)
                    else:
                        not_created.append(fmt)
            return {"created": created, "not created": not_created}
        else:
            if isinstance(formats, str):
                formats = [formats]
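A usage sketch of the query path above (assuming the public DGLGraph.formats()
wrapper, which returns this dict when called without arguments):

>>> import dgl, torch
>>> g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
>>> g.formats()
{'created': ['coo'], 'not created': ['csr', 'csc']}
>>> g.formats("csr").formats()
{'created': ['csr'], 'not created': []}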
@@ -1044,9 +1115,11 @@ class HeteroGraphIndex(ObjectBase):
        """
        return _CAPI_DGLHeteroReverse(self)


@register_object("graph.HeteroSubgraph")
class HeteroSubgraphIndex(ObjectBase):
    """Hetero-subgraph data structure"""

    @property
    def graph(self):
        """The subgraph structure
@@ -1089,6 +1162,7 @@ class HeteroSubgraphIndex(ObjectBase):
# Creators
#################################################################


def create_metagraph_index(ntypes, canonical_etypes):
    """Return a GraphIndex instance for a metagraph given the node types and canonical
    edge types.
@@ -1129,8 +1203,17 @@ def create_metagraph_index(ntypes, canonical_etypes):
    metagraph = from_coo(len(ntypes), meta_edges_src, meta_edges_dst, True)
    return metagraph, ntypes, etypes, relations
def create_unitgraph_from_coo(
    num_ntypes,
    num_src,
    num_dst,
    row,
    col,
    formats,
    row_sorted=False,
    col_sorted=False,
):
    """Create a unitgraph graph index from COO format

    Parameters
@@ -1160,12 +1243,27 @@ def create_unitgraph_from_coo(num_ntypes, num_src, num_dst, row, col,
    if isinstance(formats, str):
        formats = [formats]
    return _CAPI_DGLHeteroCreateUnitGraphFromCOO(
        int(num_ntypes),
        int(num_src),
        int(num_dst),
        F.to_dgl_nd(row),
        F.to_dgl_nd(col),
        formats,
        row_sorted,
        col_sorted,
    )


def create_unitgraph_from_csr(
    num_ntypes,
    num_src,
    num_dst,
    indptr,
    indices,
    edge_ids,
    formats,
    transpose=False,
):
    """Create a unitgraph graph index from CSR format

    Parameters
@@ -1194,11 +1292,20 @@ def create_unitgraph_from_csr(num_ntypes, num_src, num_dst, indptr, indices, edg
    if isinstance(formats, str):
        formats = [formats]
    return _CAPI_DGLHeteroCreateUnitGraphFromCSR(
        int(num_ntypes),
        int(num_src),
        int(num_dst),
        F.to_dgl_nd(indptr),
        F.to_dgl_nd(indices),
        F.to_dgl_nd(edge_ids),
        formats,
        transpose,
    )


def create_heterograph_from_relations(
    metagraph, rel_graphs, num_nodes_per_type
):
    """Create a heterograph from metagraph and graphs of every relation.

    Parameters
@@ -1218,7 +1325,9 @@ def create_heterograph_from_relations(metagraph, rel_graphs, num_nodes_per_type)
        return _CAPI_DGLHeteroCreateHeteroGraph(metagraph, rel_graphs)
    else:
        return _CAPI_DGLHeteroCreateHeteroGraphWithNumNodes(
            metagraph, rel_graphs, num_nodes_per_type.todgltensor()
        )


def create_heterograph_from_shared_memory(name):
    """Create a heterograph from shared memory with the given name.
@@ -1239,6 +1348,7 @@ def create_heterograph_from_shared_memory(name):
    g, ntypes, etypes = _CAPI_DGLHeteroCreateFromSharedMem(name)
    return g, list(ntypes), list(etypes)


def joint_union(metagraph, gidx_list):
    """Return a joint union of the input heterographs.
@@ -1256,6 +1366,7 @@ def joint_union(metagraph, gidx_list):
""" """
return _CAPI_DGLHeteroJointUnion(metagraph, gidx_list) return _CAPI_DGLHeteroJointUnion(metagraph, gidx_list)
def disjoint_union(metagraph, graphs): def disjoint_union(metagraph, graphs):
"""Return a disjoint union of the input heterographs. """Return a disjoint union of the input heterographs.
...@@ -1273,6 +1384,7 @@ def disjoint_union(metagraph, graphs): ...@@ -1273,6 +1384,7 @@ def disjoint_union(metagraph, graphs):
""" """
return _CAPI_DGLHeteroDisjointUnion_v2(metagraph, graphs) return _CAPI_DGLHeteroDisjointUnion_v2(metagraph, graphs)
def disjoint_partition(graph, bnn_all_types, bne_all_types): def disjoint_partition(graph, bnn_all_types, bne_all_types):
"""Partition the graph disjointly. """Partition the graph disjointly.
...@@ -1290,10 +1402,16 @@ def disjoint_partition(graph, bnn_all_types, bne_all_types): ...@@ -1290,10 +1402,16 @@ def disjoint_partition(graph, bnn_all_types, bne_all_types):
    list of HeteroGraphIndex
        Heterographs unbatched.
    """
    bnn_all_types = utils.toindex(
        list(itertools.chain.from_iterable(bnn_all_types))
    )
    bne_all_types = utils.toindex(
        list(itertools.chain.from_iterable(bne_all_types))
    )
    return _CAPI_DGLHeteroDisjointPartitionBySizes_v2(
        graph, bnn_all_types.todgltensor(), bne_all_types.todgltensor()
    )
def slice_gidx(graph, num_nodes, start_nid, num_edges, start_eid):
    """Slice a chunk of the graph.
@@ -1317,20 +1435,28 @@ def slice_gidx(graph, num_nodes, start_nid, num_edges, start_eid):
        The sliced graph.
    """
    return _CAPI_DGLHeteroSlice(
        graph,
        num_nodes.todgltensor(),
        start_nid.todgltensor(),
        num_edges.todgltensor(),
        start_eid.todgltensor(),
    )
#################################################################
# Data structure used by C APIs
#################################################################


@register_object("graph.FlattenedHeteroGraph")
class FlattenedHeteroGraph(ObjectBase):
    """FlattenedHeteroGraph object class in C++ backend."""


@register_object("graph.HeteroPickleStates")
class HeteroPickleStates(ObjectBase):
    """Pickle states object class in C++ backend."""

    @property
    def version(self):
        """Version number
@@ -1371,7 +1497,9 @@ class HeteroPickleStates(ObjectBase):
        Need to set the tensor created in the __getstate__ function
        as object attribute to avoid potential bugs
        """
        self._pk_arrays = [
            F.zerocopy_from_dgl_ndarray(arr) for arr in self.arrays
        ]
        return self.version, self.meta, self._pk_arrays

    def __setstate__(self, state):
@@ -1379,12 +1507,18 @@ class HeteroPickleStates(ObjectBase):
            version, meta, arrays = state
            arrays = [F.zerocopy_to_dgl_ndarray(arr) for arr in arrays]
            self.__init_handle_by_constructor__(
                _CAPI_DGLCreateHeteroPickleStates, version, meta, arrays
            )
        else:
            metagraph, num_nodes_per_type, adjs = state
            num_nodes_per_type = F.zerocopy_to_dgl_ndarray(num_nodes_per_type)
            self.__init_handle_by_constructor__(
                _CAPI_DGLCreateHeteroPickleStatesOld,
                metagraph,
                num_nodes_per_type,
                adjs,
            )


def _forking_rebuild(pk_state):
    version, meta, arrays = pk_state
@@ -1394,6 +1528,7 @@ def _forking_rebuild(pk_state):
    graph_index._forking_pk_state = pk_state
    return graph_index


def _forking_reduce(graph_index):
    # Because F.from_dgl_nd(F.to_dgl_nd(x)) loses the information of shared memory
    # file descriptor (because DLPack does not keep it), without caching the tensors
@@ -1403,7 +1538,7 @@ def _forking_reduce(graph_index):
    # should be rare though because (1) DataLoader will create all the formats if num_workers > 0
    # anyway, and (2) we require the users to explicitly create all formats before calling
    # mp.spawn().
    if hasattr(graph_index, "_forking_pk_state"):
        return _forking_rebuild, (graph_index._forking_pk_state,)
    states = _CAPI_DGLHeteroForkingPickle(graph_index)
    arrays = [F.from_dgl_nd(arr) for arr in states.arrays]
@@ -1415,10 +1550,11 @@ def _forking_reduce(graph_index):
    return _forking_rebuild, (graph_index._forking_pk_state,)


if not (F.get_preferred_backend() == "mxnet" and sys.version_info.minor <= 6):
    # Python 3.6 MXNet crashes with the following statement; remove until we no longer support
    # 3.6 (which is EOL anyway).
    from multiprocessing.reduction import ForkingPickler

    ForkingPickler.register(HeteroGraphIndex, _forking_reduce)

_init_api("dgl.heterograph_index")
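As the comment in _forking_reduce() notes, all sparse formats should exist
before workers fork; a sketch with the public helper (assuming
DGLGraph.create_formats_()):

>>> import dgl, torch
>>> import torch.multiprocessing as mp
>>> g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
>>> g.create_formats_()   # materialize coo/csr/csc once, in the parent process
>>> # mp.spawn(worker_fn, args=(g,), nprocs=4)  # workers then reuse shared memory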
@@ -3,9 +3,12 @@ from __future__ import absolute_import
from . import backend as F

__all__ = ["base_initializer", "zero_initializer"]


def base_initializer(
    shape, dtype, ctx, id_range
):  # pylint: disable=unused-argument
    """The function signature for feature initializer.

    Any customized feature initializer should follow this signature (see
@@ -44,7 +47,10 @@ def base_initializer(shape, dtype, ctx, id_range):  # pylint: disable=unused-arg
""" """
raise NotImplementedError raise NotImplementedError
def zero_initializer(shape, dtype, ctx, id_range): # pylint: disable=unused-argument
def zero_initializer(
shape, dtype, ctx, id_range
): # pylint: disable=unused-argument
"""Zero feature initializer """Zero feature initializer
Examples Examples
......
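A usage sketch for initializers with this signature (assuming the public
set_n_initializer() hook on DGLGraph):

>>> import dgl, torch
>>> g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
>>> g.set_n_initializer(dgl.init.zero_initializer)
>>> g.ndata["h"] = torch.ones(g.num_nodes(), 2)
>>> g.add_nodes(1)   # the new row of ndata["h"] is filled by the initializer
>>> g.ndata["h"][-1]
tensor([0., 0.])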
"""Utilities for merging graphs.""" """Utilities for merging graphs."""
import dgl import dgl
from . import backend as F from . import backend as F
from .base import DGLError from .base import DGLError
__all__ = ['merge'] __all__ = ["merge"]
def merge(graphs): def merge(graphs):
r"""Merge a sequence of graphs together into a single graph. r"""Merge a sequence of graphs together into a single graph.
...@@ -62,7 +64,7 @@ def merge(graphs): ...@@ -62,7 +64,7 @@ def merge(graphs):
""" """
if len(graphs) == 0: if len(graphs) == 0:
raise DGLError('The input list of graphs cannot be empty.') raise DGLError("The input list of graphs cannot be empty.")
ref = graphs[0] ref = graphs[0]
ntypes = ref.ntypes ntypes = ref.ntypes
...@@ -87,9 +89,15 @@ def merge(graphs): ...@@ -87,9 +89,15 @@ def merge(graphs):
        if len(keys) == 0:
            edges_data = None
        else:
            edges_data = {
                k: F.cat([f[k] for f in edata_frames], dim=0) for k in keys
            }
        merged_us = F.copy_to(
            F.astype(F.cat(unmerged_us, dim=0), ref.idtype), ref.device
        )
        merged_vs = F.copy_to(
            F.astype(F.cat(unmerged_vs, dim=0), ref.idtype), ref.device
        )
        merged.add_edges(merged_us, merged_vs, edges_data, etype)

    # Add node data and isolated nodes from next_graph to merged.
@@ -98,12 +106,16 @@ def merge(graphs):
        merged_ntype_id = merged.get_ntype_id(ntype)
        next_ntype_id = next_graph.get_ntype_id(ntype)
        next_ndata = next_graph._node_frames[next_ntype_id]
        node_diff = next_graph.num_nodes(ntype=ntype) - merged.num_nodes(
            ntype=ntype
        )
        n_extra_nodes = max(0, node_diff)
        merged.add_nodes(n_extra_nodes, ntype=ntype)
        next_nodes = F.arange(
            0,
            next_graph.num_nodes(ntype=ntype),
            merged.idtype,
            merged.device,
        )
        merged._node_frames[merged_ntype_id].update_row(
            next_nodes, next_ndata
...
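A usage sketch of the public dgl.merge() assembled from the pieces above: node
i of the result carries the union of the edges incident to node i in each input.

>>> import dgl, torch
>>> g1 = dgl.graph((torch.tensor([0]), torch.tensor([1])))
>>> g2 = dgl.graph((torch.tensor([1]), torch.tensor([2])))
>>> merged = dgl.merge([g1, g2])
>>> merged.num_nodes(), merged.num_edges()
(3, 2)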
"""dgl sparse class.""" """dgl sparse class."""
from .diag_matrix import * from .diag_matrix import *
from .sp_matrix import *
from .elementwise_op import * from .elementwise_op import *
from .sddmm import * from .matmul import *
from .reduction import * # pylint: disable=W0622 from .reduction import * # pylint: disable=W0622
from .sddmm import *
from .sp_matrix import *
from .unary_diag import * from .unary_diag import *
from .unary_sp import * from .unary_sp import *
from .matmul import *
"""DGL sparse matrix module.""" """DGL sparse matrix module."""
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
__all__ = [ __all__ = [
......
@@ -6,9 +6,10 @@
# make fork() and openmp work together.
from .. import backend as F

if F.get_preferred_backend() == "pytorch":
    # Wrap around torch.multiprocessing...
    from torch.multiprocessing import *

    # ... and override the Process initializer.
    from .pytorch import *
else:
...
"""PyTorch multiprocessing wrapper.""" """PyTorch multiprocessing wrapper."""
from functools import wraps
import random import random
import traceback import traceback
from _thread import start_new_thread from _thread import start_new_thread
from functools import wraps
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from ..utils import create_shared_mem_array, get_shared_mem_array from ..utils import create_shared_mem_array, get_shared_mem_array
def thread_wrapped_func(func): def thread_wrapped_func(func):
""" """
Wraps a process entry point to make it work with OpenMP. Wraps a process entry point to make it work with OpenMP.
""" """
@wraps(func) @wraps(func)
def decorated_function(*args, **kwargs): def decorated_function(*args, **kwargs):
queue = mp.Queue() queue = mp.Queue()
def _queue_result(): def _queue_result():
exception, trace, res = None, None, None exception, trace, res = None, None, None
try: try:
...@@ -31,18 +35,31 @@ def thread_wrapped_func(func): ...@@ -31,18 +35,31 @@ def thread_wrapped_func(func):
        else:
            assert isinstance(exception, Exception)
            raise exception.__class__(trace)

    return decorated_function


# pylint: disable=missing-docstring
class Process(mp.Process):
    # pylint: disable=dangerous-default-value
    def __init__(
        self,
        group=None,
        target=None,
        name=None,
        args=(),
        kwargs={},
        *,
        daemon=None
    ):
        target = thread_wrapped_func(target)
        super().__init__(group, target, name, args, kwargs, daemon=daemon)


def _get_shared_mem_name(id_):
    return "shared" + str(id_)


def call_once_and_share(func, shape, dtype, rank=0):
    """Invoke the function in a single process of the PyTorch distributed process group,
    and share the result with other processes.
@@ -61,7 +78,7 @@ def call_once_and_share(func, shape, dtype, rank=0):
    current_rank = torch.distributed.get_rank()
    dist_buf = torch.LongTensor([1])
    if torch.distributed.get_backend() == "nccl":
        # Use .cuda() to transfer it to the correct device. Should be OK since
        # PyTorch recommends the users to call set_device() after getting inside
        # torch.multiprocessing.spawn()
@@ -88,6 +105,7 @@ def call_once_and_share(func, shape, dtype, rank=0):
    return result


def shared_tensor(shape, dtype=torch.float32):
    """Create a tensor in shared memory accessible by all processes within the same
    ``torch.distributed`` process group.
@@ -106,4 +124,6 @@ def shared_tensor(shape, dtype=torch.float32):
    Tensor
        The shared tensor.
    """
    return call_once_and_share(
        lambda: torch.empty(*shape, dtype=dtype), shape, dtype
    )
@@ -9,17 +9,28 @@ from __future__ import absolute_import as _abs
import ctypes
import functools
import operator

import numpy as _np

from . import backend as F
from ._ffi.function import _init_api
from ._ffi.ndarray import (
    DGLContext,
    DGLDataType,
    NDArrayBase,
    _set_class_ndarray,
    context,
    empty,
    empty_shared_mem,
    from_dlpack,
    numpyasarray,
)
from ._ffi.object import ObjectBase, register_object


class NDArray(NDArrayBase):
    """Lightweight NDArray class for DGL framework."""

    def __len__(self):
        return functools.reduce(operator.mul, self.shape, 1)
@@ -35,7 +46,10 @@ class NDArray(NDArrayBase):
        -------
        NDArray
        """
        return empty_shared_mem(name, True, self.shape, self.dtype).copyfrom(
            self
        )


def cpu(dev_id=0):
    """Construct a CPU device
@@ -52,6 +66,7 @@ def cpu(dev_id=0):
""" """
return DGLContext(1, dev_id) return DGLContext(1, dev_id)
def gpu(dev_id=0): def gpu(dev_id=0):
"""Construct a CPU device """Construct a CPU device
...@@ -67,6 +82,7 @@ def gpu(dev_id=0): ...@@ -67,6 +82,7 @@ def gpu(dev_id=0):
""" """
return DGLContext(2, dev_id) return DGLContext(2, dev_id)
def array(arr, ctx=cpu(0)): def array(arr, ctx=cpu(0)):
"""Create an array from source arr. """Create an array from source arr.
...@@ -87,6 +103,7 @@ def array(arr, ctx=cpu(0)): ...@@ -87,6 +103,7 @@ def array(arr, ctx=cpu(0)):
        arr = _np.array(arr)
    return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)


def zerocopy_from_numpy(np_data):
    """Create an array that shares the given numpy data.
@@ -104,6 +121,7 @@ def zerocopy_from_numpy(np_data):
    handle = ctypes.pointer(arr)
    return NDArray(handle, is_view=True)


def cast_to_signed(arr):
    """Cast this NDArray from unsigned integer to signed one.
@@ -124,8 +142,9 @@ def cast_to_signed(arr):
""" """
return _CAPI_DGLArrayCastToSigned(arr) return _CAPI_DGLArrayCastToSigned(arr)
def get_shared_mem_array(name, shape, dtype): def get_shared_mem_array(name, shape, dtype):
""" Get a tensor from shared memory with specific name """Get a tensor from shared memory with specific name
Parameters Parameters
---------- ----------
...@@ -141,12 +160,15 @@ def get_shared_mem_array(name, shape, dtype): ...@@ -141,12 +160,15 @@ def get_shared_mem_array(name, shape, dtype):
    F.tensor
        The tensor got from shared memory.
    """
    new_arr = empty_shared_mem(
        name, False, shape, F.reverse_data_type_dict[dtype]
    )
    dlpack = new_arr.to_dlpack()
    return F.zerocopy_from_dlpack(dlpack)


def create_shared_mem_array(name, shape, dtype):
    """Create a tensor from shared memory with the specific name

    Parameters
    ----------
@@ -162,12 +184,15 @@ def create_shared_mem_array(name, shape, dtype):
    F.tensor
        The created tensor.
    """
    new_arr = empty_shared_mem(
        name, True, shape, F.reverse_data_type_dict[dtype]
    )
    dlpack = new_arr.to_dlpack()
    return F.zerocopy_from_dlpack(dlpack)
def exist_shared_mem_array(name):
    """Check the existence of shared-memory array.

    Parameters
    ----------
@@ -181,23 +206,27 @@ def exist_shared_mem_array(name):
""" """
return _CAPI_DGLExistSharedMemArray(name) return _CAPI_DGLExistSharedMemArray(name)
class SparseFormat: class SparseFormat:
"""Format code""" """Format code"""
ANY = 0 ANY = 0
COO = 1 COO = 1
CSR = 2 CSR = 2
CSC = 3 CSC = 3
FORMAT2STR = { FORMAT2STR = {
0 : 'ANY', 0: "ANY",
1 : 'COO', 1: "COO",
2 : 'CSR', 2: "CSR",
3 : 'CSC', 3: "CSC",
} }
@register_object('aten.SparseMatrix')
@register_object("aten.SparseMatrix")
class SparseMatrix(ObjectBase): class SparseMatrix(ObjectBase):
"""Sparse matrix object class in C++ backend.""" """Sparse matrix object class in C++ backend."""
@property @property
def format(self): def format(self):
"""Sparse format enum """Sparse format enum
...@@ -250,17 +279,26 @@ class SparseMatrix(ObjectBase): ...@@ -250,17 +279,26 @@ class SparseMatrix(ObjectBase):
        return _CAPI_DGLSparseMatrixGetFlags(self)

    def __getstate__(self):
        return (
            self.format,
            self.num_rows,
            self.num_cols,
            self.indices,
            self.flags,
        )

    def __setstate__(self, state):
        fmt, nrows, ncols, indices, flags = state
        indices = [F.zerocopy_to_dgl_ndarray(idx) for idx in indices]
        self.__init_handle_by_constructor__(
            _CAPI_DGLCreateSparseMatrix, fmt, nrows, ncols, indices, flags
        )

    def __repr__(self):
        return 'SparseMatrix(fmt="{}", shape=({},{}))'.format(
            SparseFormat.FORMAT2STR[self.format], self.num_rows, self.num_cols
        )


_set_class_ndarray(NDArray)
_init_api("dgl.ndarray")
@@ -270,5 +308,5 @@ _init_api("dgl.ndarray.uvm", __name__)
# other backend tensors.
NULL = {
    "int64": array(_np.array([], dtype=_np.int64)),
    "int32": array(_np.array([], dtype=_np.int32)),
}
@@ -2,13 +2,14 @@
from __future__ import absolute_import

import time
from collections import namedtuple
from enum import Enum

import dgl.backend as F

from . import utils
from ._deprecate.nodeflow import NodeFlow
from ._ffi.function import _init_api

_init_api("dgl.network")
@@ -19,12 +20,11 @@ _WAIT_TIME_SEC = 3  # 3 seconds
def _network_wait():
    """Sleep for a few seconds"""
    time.sleep(_WAIT_TIME_SEC)


def _create_sender(net_type, msg_queue_size=2 * 1024 * 1024 * 1024):
    """Create a Sender communicator via C api

    Parameters
@@ -34,11 +34,11 @@ def _create_sender(net_type, msg_queue_size=2*1024*1024*1024):
    msg_queue_size : int
        message queue size (2GB by default)
    """
    assert net_type in ("socket", "mpi"), "Unknown network type."
    return _CAPI_DGLSenderCreate(net_type, msg_queue_size)


def _create_receiver(net_type, msg_queue_size=2 * 1024 * 1024 * 1024):
    """Create a Receiver communicator via C api

    Parameters
@@ -48,7 +48,7 @@ def _create_receiver(net_type, msg_queue_size=2*1024*1024*1024):
    msg_queue_size : int
        message queue size (2GB by default)
    """
    assert net_type in ("socket", "mpi"), "Unknown network type."
    return _CAPI_DGLReceiverCreate(net_type, msg_queue_size)
@@ -64,8 +64,7 @@ def _finalize_sender(sender):
def _finalize_receiver(receiver):
    """Finalize Receiver Communicator"""
    _CAPI_DGLFinalizeReceiver(receiver)
@@ -83,7 +82,7 @@ def _add_receiver_addr(sender, ip_addr, port, recv_id):
    recv_id : int
        Receiver ID
    """
    assert recv_id >= 0, "recv_id cannot be a negative number."
    _CAPI_DGLSenderAddReceiver(sender, ip_addr, int(port), int(recv_id))
@@ -112,7 +111,7 @@ def _receiver_wait(receiver, ip_addr, port, num_sender):
    num_sender : int
        total number of Sender
    """
    assert num_sender >= 0, "num_sender cannot be a negative number."
    _CAPI_DGLReceiverWait(receiver, ip_addr, int(port), int(num_sender))
@@ -131,19 +130,22 @@ def _send_nodeflow(sender, nodeflow, recv_id):
    recv_id : int
        Receiver ID
    """
    assert recv_id >= 0, "recv_id cannot be a negative number."
    gidx = nodeflow._graph
    node_mapping = nodeflow._node_mapping.todgltensor()
    edge_mapping = nodeflow._edge_mapping.todgltensor()
    layers_offsets = utils.toindex(nodeflow._layer_offsets).todgltensor()
    flows_offsets = utils.toindex(nodeflow._block_offsets).todgltensor()
    _CAPI_SenderSendNodeFlow(
        sender,
        int(recv_id),
        gidx,
        node_mapping,
        edge_mapping,
        layers_offsets,
        flows_offsets,
    )


def _send_sampler_end_signal(sender, recv_id):
    """Send an epoch-end signal to remote Receiver.
@@ -155,9 +157,10 @@ def _send_sampler_end_signal(sender, recv_id):
    recv_id : int
        Receiver ID
    """
    assert recv_id >= 0, "recv_id cannot be a negative number."
    _CAPI_SenderSendSamplerEndSignal(sender, int(recv_id))


def _recv_nodeflow(receiver, graph):
    """Receive sampled subgraph (NodeFlow) from remote sampler.
@@ -183,8 +186,8 @@ def _recv_nodeflow(receiver, graph):
class KVMsgType(Enum):
    """Type of kvstore message"""

    FINAL = 1
    INIT = 2
    PUSH = 3
@@ -215,6 +218,7 @@ c_ptr : void*
    c pointer of message
"""


def _send_kv_msg(sender, msg, recv_id):
    """Send kvstore message.
@@ -230,12 +234,8 @@ def _send_kv_msg(sender, msg, recv_id):
    if msg.type == KVMsgType.PULL:
        tensor_id = F.zerocopy_to_dgl_ndarray(msg.id)
        _CAPI_SenderSendKVMsg(
            sender, int(recv_id), msg.type.value, msg.rank, msg.name, tensor_id
        )
    elif msg.type in (KVMsgType.INIT, KVMsgType.GET_SHAPE_BACK):
        tensor_shape = F.zerocopy_to_dgl_ndarray(msg.shape)
        _CAPI_SenderSendKVMsg(
@@ -244,20 +244,14 @@ def _send_kv_msg(sender, msg, recv_id):
            msg.type.value,
            msg.rank,
            msg.name,
            tensor_shape,
        )
    elif msg.type in (KVMsgType.IP_ID, KVMsgType.GET_SHAPE):
        _CAPI_SenderSendKVMsg(
            sender, int(recv_id), msg.type.value, msg.rank, msg.name
        )
    elif msg.type in (KVMsgType.FINAL, KVMsgType.BARRIER):
        _CAPI_SenderSendKVMsg(sender, int(recv_id), msg.type.value, msg.rank)
    else:
        tensor_id = F.zerocopy_to_dgl_ndarray(msg.id)
        data = F.zerocopy_to_dgl_ndarray(msg.data)
@@ -268,7 +262,8 @@ def _send_kv_msg(sender, msg, recv_id):
            msg.rank,
            msg.name,
            tensor_id,
            data,
        )


def _recv_kv_msg(receiver):
@@ -288,7 +283,9 @@ def _recv_kv_msg(receiver):
    rank = _CAPI_ReceiverGetKVMsgRank(msg_ptr)
    if msg_type == KVMsgType.PULL:
        name = _CAPI_ReceiverGetKVMsgName(msg_ptr)
        tensor_id = F.zerocopy_from_dgl_ndarray(
            _CAPI_ReceiverGetKVMsgID(msg_ptr)
        )
        msg = KVStoreMsg(
            type=msg_type,
            rank=rank,
@@ -296,11 +293,14 @@ def _recv_kv_msg(receiver):
            id=tensor_id,
            data=None,
            shape=None,
            c_ptr=msg_ptr,
        )
        return msg
    elif msg_type in (KVMsgType.INIT, KVMsgType.GET_SHAPE_BACK):
        name = _CAPI_ReceiverGetKVMsgName(msg_ptr)
        tensor_shape = F.zerocopy_from_dgl_ndarray(
            _CAPI_ReceiverGetKVMsgShape(msg_ptr)
        )
        msg = KVStoreMsg(
            type=msg_type,
            rank=rank,
@@ -308,7 +308,8 @@ def _recv_kv_msg(receiver):
            id=None,
            data=None,
            shape=tensor_shape,
            c_ptr=msg_ptr,
        )
        return msg
    elif msg_type in (KVMsgType.IP_ID, KVMsgType.GET_SHAPE):
        name = _CAPI_ReceiverGetKVMsgName(msg_ptr)
@@ -319,7 +320,8 @@ def _recv_kv_msg(receiver):
            id=None,
            data=None,
            shape=None,
            c_ptr=msg_ptr,
        )
        return msg
    elif msg_type in (KVMsgType.FINAL, KVMsgType.BARRIER):
        msg = KVStoreMsg(
@@ -329,11 +331,14 @@ def _recv_kv_msg(receiver):
            id=None,
            data=None,
            shape=None,
            c_ptr=msg_ptr,
        )
        return msg
    else:
        name = _CAPI_ReceiverGetKVMsgName(msg_ptr)
        tensor_id = F.zerocopy_from_dgl_ndarray(
            _CAPI_ReceiverGetKVMsgID(msg_ptr)
        )
        data = F.zerocopy_from_dgl_ndarray(_CAPI_ReceiverGetKVMsgData(msg_ptr))
        msg = KVStoreMsg(
            type=msg_type,
@@ -342,25 +347,34 @@ def _fast_pull(name, id_tensor,
            id=tensor_id,
            data=data,
            shape=None,
            c_ptr=msg_ptr,
        )
        return msg

    raise RuntimeError("Unknown message type: %d" % msg_type.value)


def _clear_kv_msg(msg):
    """Clear data of kvstore message"""
    F.sync()
    if msg.c_ptr is not None:
        _CAPI_DeleteKVMsg(msg.c_ptr)
def _fast_pull(
    name,
    id_tensor,
    machine_count,
    group_count,
    machine_id,
    client_id,
    partition_book,
    g2l,
    local_data,
    sender,
    receiver,
):
    """Pull message

    Parameters
    ----------
@@ -393,17 +407,33 @@ def _fast_pull(name, id_tensor,
...@@ -393,17 +407,33 @@ def _fast_pull(name, id_tensor, ...@@ -393,17 +407,33 @@ def _fast_pull(name, id_tensor,
        target tensor
    """
    if g2l is not None:
        res_tensor = _CAPI_FastPull(
            name,
            machine_id,
            machine_count,
            group_count,
            client_id,
            F.zerocopy_to_dgl_ndarray(id_tensor),
            F.zerocopy_to_dgl_ndarray(partition_book),
            F.zerocopy_to_dgl_ndarray(local_data),
            sender,
            receiver,
            "has_g2l",
            F.zerocopy_to_dgl_ndarray(g2l),
        )
    else:
        res_tensor = _CAPI_FastPull(
            name,
            machine_id,
            machine_count,
            group_count,
            client_id,
            F.zerocopy_to_dgl_ndarray(id_tensor),
            F.zerocopy_to_dgl_ndarray(partition_book),
            F.zerocopy_to_dgl_ndarray(local_data),
            sender,
            receiver,
            "no_g2l",
        )
    return F.zerocopy_from_dgl_ndarray(res_tensor)
@@ -14,20 +14,22 @@ with "[NN] XXX module".
""" """
import importlib import importlib
import sys
import os import os
import sys
from ..backend import backend_name
from ..utils import expand_as_pair
# [BarclayII] Not sure what's going on with pylint. # [BarclayII] Not sure what's going on with pylint.
# Possible issue: https://github.com/PyCQA/pylint/issues/2648 # Possible issue: https://github.com/PyCQA/pylint/issues/2648
from . import functional # pylint: disable=import-self from . import functional # pylint: disable=import-self
from ..backend import backend_name
from ..utils import expand_as_pair
def _load_backend(mod_name): def _load_backend(mod_name):
mod = importlib.import_module('.%s' % mod_name, __name__) mod = importlib.import_module(".%s" % mod_name, __name__)
thismod = sys.modules[__name__] thismod = sys.modules[__name__]
for api, obj in mod.__dict__.items(): for api, obj in mod.__dict__.items():
setattr(thismod, api, obj) setattr(thismod, api, obj)
_load_backend(backend_name) _load_backend(backend_name)
"""MXNet modules for graph convolutions.""" """MXNet modules for graph convolutions."""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
from .graphconv import GraphConv
from .relgraphconv import RelGraphConv
from .tagconv import TAGConv
from .gatconv import GATConv
from .sageconv import SAGEConv
from .gatedgraphconv import GatedGraphConv
from .chebconv import ChebConv
from .agnnconv import AGNNConv from .agnnconv import AGNNConv
from .appnpconv import APPNPConv from .appnpconv import APPNPConv
from .chebconv import ChebConv
from .densechebconv import DenseChebConv
from .densegraphconv import DenseGraphConv from .densegraphconv import DenseGraphConv
from .densesageconv import DenseSAGEConv from .densesageconv import DenseSAGEConv
from .densechebconv import DenseChebConv
from .edgeconv import EdgeConv from .edgeconv import EdgeConv
from .gatconv import GATConv
from .gatedgraphconv import GatedGraphConv
from .ginconv import GINConv from .ginconv import GINConv
from .gmmconv import GMMConv from .gmmconv import GMMConv
from .graphconv import GraphConv
from .nnconv import NNConv from .nnconv import NNConv
from .relgraphconv import RelGraphConv
from .sageconv import SAGEConv
from .sgconv import SGConv from .sgconv import SGConv
from .tagconv import TAGConv
__all__ = ['GraphConv', 'TAGConv', 'RelGraphConv', 'GATConv', __all__ = [
'SAGEConv', 'GatedGraphConv', 'ChebConv', 'AGNNConv', "GraphConv",
'APPNPConv', 'DenseGraphConv', 'DenseSAGEConv', 'DenseChebConv', "TAGConv",
'EdgeConv', 'GINConv', 'GMMConv', 'NNConv', 'SGConv'] "RelGraphConv",
"GATConv",
"SAGEConv",
"GatedGraphConv",
"ChebConv",
"AGNNConv",
"APPNPConv",
"DenseGraphConv",
"DenseSAGEConv",
"DenseChebConv",
"EdgeConv",
"GINConv",
"GMMConv",
"NNConv",
"SGConv",
]