Unverified Commit 76bb5404 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4682)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent a208e886
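The changes in this commit are mechanical: they come from running the Black code formatter over the Python sources, so only presentation differs (quote style, line wrapping, blank lines, import ordering via the companion tooling), not behavior. As a minimal sketch of the kind of rewrite Black applies throughout the diff below, the snippet here shows single quotes normalized to double quotes and an over-long call exploded to one argument per line with a trailing comma. The function and variable names are placeholders, not DGL APIs, and the exact maximum line length configured for this repository is an assumption, since it is not visible on this page.

```python
# Sketch only: placeholder names, not DGL code. Black rewrites quoting and
# wraps calls that exceed the configured maximum line length (the precise
# limit used by this repository is an assumption, not shown in this diff).

def copy_reduce(reducer, graph, target, edge_data, out_size, edge_map, out_map):
    """Placeholder standing in for calls such as ir.COPY_REDUCE below."""
    return (reducer, graph, target, edge_data, out_size, edge_map, out_map)


graph_index, target_code_edge, edge_data = object(), 2, [1.0, 2.0]
out_size, edge_map, out_map = 4, None, None

# Before formatting (manually wrapped continuation line, single quotes):
# out = copy_reduce('sum', graph_index, target_code_edge, edge_data,
#                   out_size, edge_map, out_map)

# After running ``black``: double quotes, one argument per line, and a
# trailing comma because the call no longer fits on a single line.
out = copy_reduce(
    "sum",
    graph_index,
    target_code_edge,
    edge_data,
    out_size,
    edge_map,
    out_map,
)
```

Because Black is deterministic, re-running it over the formatted sources is a no-op, which is what makes a bulk auto-fix commit like this one reviewable by skimming the hunks below.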
"""Built-in function base class""" """Built-in function base class"""
from __future__ import absolute_import from __future__ import absolute_import
__all__ = ['BuiltinFunction', 'TargetCode'] __all__ = ["BuiltinFunction", "TargetCode"]
class TargetCode(object): class TargetCode(object):
...@@ -10,6 +10,7 @@ class TargetCode(object): ...@@ -10,6 +10,7 @@ class TargetCode(object):
Note: must be consistent with the target code definition in C++ side: Note: must be consistent with the target code definition in C++ side:
src/kernel/binary_reduce_common.h src/kernel/binary_reduce_common.h
""" """
SRC = 0 SRC = 0
DST = 1 DST = 1
EDGE = 2 EDGE = 2
...@@ -23,6 +24,7 @@ class TargetCode(object): ...@@ -23,6 +24,7 @@ class TargetCode(object):
class BuiltinFunction(object): class BuiltinFunction(object):
"""Base builtin function class.""" """Base builtin function class."""
@property @property
def name(self): def name(self):
"""Return the name of this builtin function.""" """Return the name of this builtin function."""
......
...@@ -4,16 +4,15 @@ from __future__ import absolute_import ...@@ -4,16 +4,15 @@ from __future__ import absolute_import
import sys import sys
from .base import BuiltinFunction, TargetCode
from .._deprecate.runtime import ir from .._deprecate.runtime import ir
from .._deprecate.runtime.ir import var from .._deprecate.runtime.ir import var
from .base import BuiltinFunction, TargetCode
class ReduceFunction(BuiltinFunction): class ReduceFunction(BuiltinFunction):
"""Base builtin reduce function class.""" """Base builtin reduce function class."""
def _invoke(self, graph, edge_frame, out_size, edge_map=None, def _invoke(self, graph, edge_frame, out_size, edge_map=None, out_map=None):
out_map=None):
"""Symbolic computation of this builtin function to create """Symbolic computation of this builtin function to create
runtime.executor runtime.executor
""" """
...@@ -28,21 +27,28 @@ class ReduceFunction(BuiltinFunction): ...@@ -28,21 +27,28 @@ class ReduceFunction(BuiltinFunction):
class SimpleReduceFunction(ReduceFunction): class SimpleReduceFunction(ReduceFunction):
"""Builtin reduce function that aggregates a single field into another """Builtin reduce function that aggregates a single field into another
single field.""" single field."""
def __init__(self, name, msg_field, out_field): def __init__(self, name, msg_field, out_field):
self._name = name self._name = name
self.msg_field = msg_field self.msg_field = msg_field
self.out_field = out_field self.out_field = out_field
def _invoke(self, graph, edge_frame, out_size, edge_map=None, def _invoke(self, graph, edge_frame, out_size, edge_map=None, out_map=None):
out_map=None):
"""Symbolic execution of this builtin function""" """Symbolic execution of this builtin function"""
reducer = self._name reducer = self._name
graph = var.GRAPH(graph) graph = var.GRAPH(graph)
edge_map = var.MAP(edge_map) edge_map = var.MAP(edge_map)
out_map = var.MAP(out_map) out_map = var.MAP(out_map)
edge_data = ir.READ_COL(edge_frame, var.STR(self.msg_field)) edge_data = ir.READ_COL(edge_frame, var.STR(self.msg_field))
return ir.COPY_REDUCE(reducer, graph, TargetCode.EDGE, edge_data, return ir.COPY_REDUCE(
out_size, edge_map, out_map) reducer,
graph,
TargetCode.EDGE,
edge_data,
out_size,
edge_map,
out_map,
)
@property @property
def name(self): def name(self):
...@@ -53,6 +59,7 @@ class SimpleReduceFunction(ReduceFunction): ...@@ -53,6 +59,7 @@ class SimpleReduceFunction(ReduceFunction):
# Generate all following reducer functions: # Generate all following reducer functions:
# sum, max, min, mean, prod # sum, max, min, mean, prod
def _gen_reduce_builtin(reducer): def _gen_reduce_builtin(reducer):
docstring = """Builtin reduce function that aggregates messages by {0}. docstring = """Builtin reduce function that aggregates messages by {0}.
...@@ -73,10 +80,13 @@ def _gen_reduce_builtin(reducer): ...@@ -73,10 +80,13 @@ def _gen_reduce_builtin(reducer):
>>> import torch >>> import torch
>>> def reduce_func(nodes): >>> def reduce_func(nodes):
>>> return {{'h': torch.{0}(nodes.mailbox['m'], dim=1)}} >>> return {{'h': torch.{0}(nodes.mailbox['m'], dim=1)}}
""".format(reducer) """.format(
reducer
)
def func(msg, out): def func(msg, out):
return SimpleReduceFunction(reducer, msg, out) return SimpleReduceFunction(reducer, msg, out)
func.__name__ = str(reducer) func.__name__ = str(reducer)
func.__qualname__ = str(reducer) func.__qualname__ = str(reducer)
func.__doc__ = docstring func.__doc__ = docstring
......
"""Module for various graph generator functions.""" """Module for various graph generator functions."""
from . import backend as F from . import backend as F
from . import convert from . import convert, random
from . import random
__all__ = ["rand_graph", "rand_bipartite"]
__all__ = ['rand_graph', 'rand_bipartite']
def rand_graph(num_nodes, num_edges, idtype=F.int64, device=F.cpu()): def rand_graph(num_nodes, num_edges, idtype=F.int64, device=F.cpu()):
"""Generate a random graph of the given number of nodes/edges and return. """Generate a random graph of the given number of nodes/edges and return.
...@@ -46,20 +46,28 @@ def rand_graph(num_nodes, num_edges, idtype=F.int64, device=F.cpu()): ...@@ -46,20 +46,28 @@ def rand_graph(num_nodes, num_edges, idtype=F.int64, device=F.cpu()):
ndata_schemes={} ndata_schemes={}
edata_schemes={}) edata_schemes={})
""" """
#TODO(minjie): support RNG as one of the arguments. # TODO(minjie): support RNG as one of the arguments.
eids = random.choice(num_nodes * num_nodes, num_edges, replace=False) eids = random.choice(num_nodes * num_nodes, num_edges, replace=False)
eids = F.zerocopy_to_numpy(eids) eids = F.zerocopy_to_numpy(eids)
rows = F.zerocopy_from_numpy(eids // num_nodes) rows = F.zerocopy_from_numpy(eids // num_nodes)
cols = F.zerocopy_from_numpy(eids % num_nodes) cols = F.zerocopy_from_numpy(eids % num_nodes)
rows = F.copy_to(F.astype(rows, idtype), device) rows = F.copy_to(F.astype(rows, idtype), device)
cols = F.copy_to(F.astype(cols, idtype), device) cols = F.copy_to(F.astype(cols, idtype), device)
return convert.graph((rows, cols), return convert.graph(
num_nodes=num_nodes, (rows, cols), num_nodes=num_nodes, idtype=idtype, device=device
idtype=idtype, device=device) )
def rand_bipartite(utype, etype, vtype,
num_src_nodes, num_dst_nodes, num_edges, def rand_bipartite(
idtype=F.int64, device=F.cpu()): utype,
etype,
vtype,
num_src_nodes,
num_dst_nodes,
num_edges,
idtype=F.int64,
device=F.cpu(),
):
"""Generate a random uni-directional bipartite graph and return. """Generate a random uni-directional bipartite graph and return.
It uniformly chooses ``num_edges`` from all possible node pairs and form a graph. It uniformly chooses ``num_edges`` from all possible node pairs and form a graph.
...@@ -107,13 +115,18 @@ def rand_bipartite(utype, etype, vtype, ...@@ -107,13 +115,18 @@ def rand_bipartite(utype, etype, vtype,
num_edges={('user', 'buys', 'game'): 10}, num_edges={('user', 'buys', 'game'): 10},
metagraph=[('user', 'game', 'buys')]) metagraph=[('user', 'game', 'buys')])
""" """
#TODO(minjie): support RNG as one of the arguments. # TODO(minjie): support RNG as one of the arguments.
eids = random.choice(num_src_nodes * num_dst_nodes, num_edges, replace=False) eids = random.choice(
num_src_nodes * num_dst_nodes, num_edges, replace=False
)
eids = F.zerocopy_to_numpy(eids) eids = F.zerocopy_to_numpy(eids)
rows = F.zerocopy_from_numpy(eids // num_dst_nodes) rows = F.zerocopy_from_numpy(eids // num_dst_nodes)
cols = F.zerocopy_from_numpy(eids % num_dst_nodes) cols = F.zerocopy_from_numpy(eids % num_dst_nodes)
rows = F.copy_to(F.astype(rows, idtype), device) rows = F.copy_to(F.astype(rows, idtype), device)
cols = F.copy_to(F.astype(cols, idtype), device) cols = F.copy_to(F.astype(cols, idtype), device)
return convert.heterograph({(utype, etype, vtype): (rows, cols)}, return convert.heterograph(
{(utype, etype, vtype): (rows, cols)},
{utype: num_src_nodes, vtype: num_dst_nodes}, {utype: num_src_nodes, vtype: num_dst_nodes},
idtype=idtype, device=device) idtype=idtype,
device=device,
)
...@@ -8,5 +8,5 @@ ...@@ -8,5 +8,5 @@
This package is experimental and the interfaces may be subject This package is experimental and the interfaces may be subject
to changes in future releases. to changes in future releases.
""" """
from .fps import *
from .edge_coarsening import * from .edge_coarsening import *
from .fps import *
"""Python interfaces to DGL farthest point sampler.""" """Python interfaces to DGL farthest point sampler."""
import numpy as np import numpy as np
from .._ffi.base import DGLError
from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
from .. import ndarray as nd from .. import ndarray as nd
from .._ffi.base import DGLError
from .._ffi.function import _init_api
def _farthest_point_sampler(data, batch_size, sample_points, dist, start_idx, result): def _farthest_point_sampler(
data, batch_size, sample_points, dist, start_idx, result
):
r"""Farthest Point Sampler r"""Farthest Point Sampler
Parameters Parameters
...@@ -32,14 +35,19 @@ def _farthest_point_sampler(data, batch_size, sample_points, dist, start_idx, re ...@@ -32,14 +35,19 @@ def _farthest_point_sampler(data, batch_size, sample_points, dist, start_idx, re
assert F.shape(data)[0] >= sample_points * batch_size assert F.shape(data)[0] >= sample_points * batch_size
assert F.shape(data)[0] % batch_size == 0 assert F.shape(data)[0] % batch_size == 0
_CAPI_FarthestPointSampler(F.zerocopy_to_dgl_ndarray(data), _CAPI_FarthestPointSampler(
batch_size, sample_points, F.zerocopy_to_dgl_ndarray(data),
batch_size,
sample_points,
F.zerocopy_to_dgl_ndarray(dist), F.zerocopy_to_dgl_ndarray(dist),
F.zerocopy_to_dgl_ndarray(start_idx), F.zerocopy_to_dgl_ndarray(start_idx),
F.zerocopy_to_dgl_ndarray(result)) F.zerocopy_to_dgl_ndarray(result),
)
def _neighbor_matching(graph_idx, num_nodes, edge_weights=None, relabel_idx=True): def _neighbor_matching(
graph_idx, num_nodes, edge_weights=None, relabel_idx=True
):
""" """
Description Description
----------- -----------
...@@ -82,7 +90,11 @@ def _neighbor_matching(graph_idx, num_nodes, edge_weights=None, relabel_idx=True ...@@ -82,7 +90,11 @@ def _neighbor_matching(graph_idx, num_nodes, edge_weights=None, relabel_idx=True
if edge_weights is not None: if edge_weights is not None:
edge_weight_capi = F.zerocopy_to_dgl_ndarray(edge_weights) edge_weight_capi = F.zerocopy_to_dgl_ndarray(edge_weights)
node_label = F.full_1d( node_label = F.full_1d(
num_nodes, -1, getattr(F, graph_idx.dtype), F.to_backend_ctx(graph_idx.ctx)) num_nodes,
-1,
getattr(F, graph_idx.dtype),
F.to_backend_ctx(graph_idx.ctx),
)
node_label_capi = F.zerocopy_to_dgl_ndarray_for_write(node_label) node_label_capi = F.zerocopy_to_dgl_ndarray_for_write(node_label)
_CAPI_NeighborMatching(graph_idx, edge_weight_capi, node_label_capi) _CAPI_NeighborMatching(graph_idx, edge_weight_capi, node_label_capi)
if F.reduce_sum(node_label < 0).item() != 0: if F.reduce_sum(node_label < 0).item() != 0:
...@@ -99,4 +111,4 @@ def _neighbor_matching(graph_idx, num_nodes, edge_weights=None, relabel_idx=True ...@@ -99,4 +111,4 @@ def _neighbor_matching(graph_idx, num_nodes, edge_weights=None, relabel_idx=True
return node_label return node_label
_init_api('dgl.geometry', __name__) _init_api("dgl.geometry", __name__)
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
from .. import remove_self_loop from .. import remove_self_loop
from .capi import _neighbor_matching from .capi import _neighbor_matching
__all__ = ['neighbor_matching'] __all__ = ["neighbor_matching"]
def neighbor_matching(graph, e_weights=None, relabel_idx=True): def neighbor_matching(graph, e_weights=None, relabel_idx=True):
r""" r"""
...@@ -48,13 +49,16 @@ def neighbor_matching(graph, e_weights=None, relabel_idx=True): ...@@ -48,13 +49,16 @@ def neighbor_matching(graph, e_weights=None, relabel_idx=True):
>>> res = neighbor_matching(g) >>> res = neighbor_matching(g)
tensor([0, 1, 1]) tensor([0, 1, 1])
""" """
assert graph.is_homogeneous, \ assert (
"The graph used in graph node matching must be homogeneous" graph.is_homogeneous
), "The graph used in graph node matching must be homogeneous"
if e_weights is not None: if e_weights is not None:
graph.edata['e_weights'] = e_weights graph.edata["e_weights"] = e_weights
graph = remove_self_loop(graph) graph = remove_self_loop(graph)
e_weights = graph.edata['e_weights'] e_weights = graph.edata["e_weights"]
graph.edata.pop('e_weights') graph.edata.pop("e_weights")
else: else:
graph = remove_self_loop(graph) graph = remove_self_loop(graph)
return _neighbor_matching(graph._graph, graph.num_nodes(), e_weights, relabel_idx) return _neighbor_matching(
graph._graph, graph.num_nodes(), e_weights, relabel_idx
)
"""Farthest Point Sampler for pytorch Geometry package""" """Farthest Point Sampler for pytorch Geometry package"""
#pylint: disable=no-member, invalid-name # pylint: disable=no-member, invalid-name
from .. import backend as F from .. import backend as F
from ..base import DGLError from ..base import DGLError
from .capi import _farthest_point_sampler from .capi import _farthest_point_sampler
__all__ = ['farthest_point_sampler'] __all__ = ["farthest_point_sampler"]
def farthest_point_sampler(pos, npoints, start_idx=None): def farthest_point_sampler(pos, npoints, start_idx=None):
...@@ -50,12 +49,16 @@ def farthest_point_sampler(pos, npoints, start_idx=None): ...@@ -50,12 +49,16 @@ def farthest_point_sampler(pos, npoints, start_idx=None):
pos = pos.reshape(-1, C) pos = pos.reshape(-1, C)
dist = F.zeros((B * N), dtype=pos.dtype, ctx=ctx) dist = F.zeros((B * N), dtype=pos.dtype, ctx=ctx)
if start_idx is None: if start_idx is None:
start_idx = F.randint(shape=(B, ), dtype=F.int64, start_idx = F.randint(
ctx=ctx, low=0, high=N-1) shape=(B,), dtype=F.int64, ctx=ctx, low=0, high=N - 1
)
else: else:
if start_idx >= N or start_idx < 0: if start_idx >= N or start_idx < 0:
raise DGLError("Invalid start_idx, expected 0 <= start_idx < {}, got {}".format( raise DGLError(
N, start_idx)) "Invalid start_idx, expected 0 <= start_idx < {}, got {}".format(
N, start_idx
)
)
start_idx = F.full_1d(B, start_idx, dtype=F.int64, ctx=ctx) start_idx = F.full_1d(B, start_idx, dtype=F.int64, ctx=ctx)
result = F.zeros((npoints * B), dtype=F.int64, ctx=ctx) result = F.zeros((npoints * B), dtype=F.int64, ctx=ctx)
_farthest_point_sampler(pos, B, npoints, dist, start_idx, result) _farthest_point_sampler(pos, B, npoints, dist, start_idx, result)
......
"""Module for global configuration operators.""" """Module for global configuration operators."""
from ._ffi.function import _init_api from ._ffi.function import _init_api
__all__ = ["is_libxsmm_enabled", "use_libxsmm"] __all__ = ["is_libxsmm_enabled", "use_libxsmm"]
def use_libxsmm(flag): def use_libxsmm(flag):
r"""Set whether DGL uses libxsmm at runtime. r"""Set whether DGL uses libxsmm at runtime.
...@@ -21,6 +21,7 @@ def use_libxsmm(flag): ...@@ -21,6 +21,7 @@ def use_libxsmm(flag):
""" """
_CAPI_DGLConfigSetLibxsmm(flag) _CAPI_DGLConfigSetLibxsmm(flag)
def is_libxsmm_enabled(): def is_libxsmm_enabled():
r"""Get whether the use_libxsmm flag is turned on. r"""Get whether the use_libxsmm flag is turned on.
...@@ -35,4 +36,5 @@ def is_libxsmm_enabled(): ...@@ -35,4 +36,5 @@ def is_libxsmm_enabled():
""" """
return _CAPI_DGLConfigGetLibxsmm() return _CAPI_DGLConfigGetLibxsmm()
_init_api("dgl.global_config") _init_api("dgl.global_config")
"""Module for graph index class definition.""" """Module for graph index class definition."""
from __future__ import absolute_import from __future__ import absolute_import
import numpy as np
import networkx as nx import networkx as nx
import numpy as np
import scipy import scipy
from ._ffi.object import register_object, ObjectBase
from ._ffi.function import _init_api
from .base import DGLError, dgl_warning
from . import backend as F from . import backend as F
from . import utils from . import utils
from ._ffi.function import _init_api
from ._ffi.object import ObjectBase, register_object
from .base import DGLError, dgl_warning
class BoolFlag(object): class BoolFlag(object):
"""Bool flag with unknown value""" """Bool flag with unknown value"""
BOOL_UNKNOWN = -1 BOOL_UNKNOWN = -1
BOOL_FALSE = 0 BOOL_FALSE = 0
BOOL_TRUE = 1 BOOL_TRUE = 1
@register_object('graph.Graph')
@register_object("graph.Graph")
class GraphIndex(ObjectBase): class GraphIndex(ObjectBase):
"""Graph index object. """Graph index object.
...@@ -33,6 +36,7 @@ class GraphIndex(ObjectBase): ...@@ -33,6 +36,7 @@ class GraphIndex(ObjectBase):
- `dgl.graph_index.from_csr` - `dgl.graph_index.from_csr`
- `dgl.graph_index.from_coo` - `dgl.graph_index.from_coo`
""" """
def __new__(cls): def __new__(cls):
obj = ObjectBase.__new__(cls) obj = ObjectBase.__new__(cls)
obj._readonly = None # python-side cache of the flag obj._readonly = None # python-side cache of the flag
...@@ -53,13 +57,15 @@ class GraphIndex(ObjectBase): ...@@ -53,13 +57,15 @@ class GraphIndex(ObjectBase):
# Pickle compatibility check # Pickle compatibility check
# TODO: we should store a storage version number in later releases. # TODO: we should store a storage version number in later releases.
if isinstance(state, tuple) and len(state) == 5: if isinstance(state, tuple) and len(state) == 5:
dgl_warning("The object is pickled pre-0.4.2. Multigraph flag is ignored in 0.4.3") dgl_warning(
"The object is pickled pre-0.4.2. Multigraph flag is ignored in 0.4.3"
)
num_nodes, _, readonly, src, dst = state num_nodes, _, readonly, src, dst = state
elif isinstance(state, tuple) and len(state) == 4: elif isinstance(state, tuple) and len(state) == 4:
# post-0.4.3. # post-0.4.3.
num_nodes, readonly, src, dst = state num_nodes, readonly, src, dst = state
else: else:
raise IOError('Unrecognized storage format.') raise IOError("Unrecognized storage format.")
self._cache = {} self._cache = {}
self._readonly = readonly self._readonly = readonly
...@@ -68,7 +74,8 @@ class GraphIndex(ObjectBase): ...@@ -68,7 +74,8 @@ class GraphIndex(ObjectBase):
src.todgltensor(), src.todgltensor(),
dst.todgltensor(), dst.todgltensor(),
int(num_nodes), int(num_nodes),
readonly) readonly,
)
def add_nodes(self, num): def add_nodes(self, num):
"""Add nodes. """Add nodes.
...@@ -240,7 +247,9 @@ class GraphIndex(ObjectBase): ...@@ -240,7 +247,9 @@ class GraphIndex(ObjectBase):
""" """
u_array = u.todgltensor() u_array = u.todgltensor()
v_array = v.todgltensor() v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLGraphHasEdgesBetween(self, u_array, v_array)) return utils.toindex(
_CAPI_DGLGraphHasEdgesBetween(self, u_array, v_array)
)
def predecessors(self, v, radius=1): def predecessors(self, v, radius=1):
"""Return the predecessors of the node. """Return the predecessors of the node.
...@@ -257,8 +266,9 @@ class GraphIndex(ObjectBase): ...@@ -257,8 +266,9 @@ class GraphIndex(ObjectBase):
utils.Index utils.Index
Array of predecessors Array of predecessors
""" """
return utils.toindex(_CAPI_DGLGraphPredecessors( return utils.toindex(
self, int(v), int(radius))) _CAPI_DGLGraphPredecessors(self, int(v), int(radius))
)
def successors(self, v, radius=1): def successors(self, v, radius=1):
"""Return the successors of the node. """Return the successors of the node.
...@@ -275,8 +285,9 @@ class GraphIndex(ObjectBase): ...@@ -275,8 +285,9 @@ class GraphIndex(ObjectBase):
utils.Index utils.Index
Array of successors Array of successors
""" """
return utils.toindex(_CAPI_DGLGraphSuccessors( return utils.toindex(
self, int(v), int(radius))) _CAPI_DGLGraphSuccessors(self, int(v), int(radius))
)
def edge_id(self, u, v): def edge_id(self, u, v):
"""Return the id array of all edges between u and v. """Return the id array of all edges between u and v.
...@@ -432,7 +443,7 @@ class GraphIndex(ObjectBase): ...@@ -432,7 +443,7 @@ class GraphIndex(ObjectBase):
""" """
_CAPI_DGLSortAdj(self) _CAPI_DGLSortAdj(self)
@utils.cached_member(cache='_cache', prefix='edges') @utils.cached_member(cache="_cache", prefix="edges")
def edges(self, order=None): def edges(self, order=None):
"""Return all the edges """Return all the edges
...@@ -606,7 +617,7 @@ class GraphIndex(ObjectBase): ...@@ -606,7 +617,7 @@ class GraphIndex(ObjectBase):
e_array = e.todgltensor() e_array = e.todgltensor()
return _CAPI_DGLGraphEdgeSubgraph(self, e_array, preserve_nodes) return _CAPI_DGLGraphEdgeSubgraph(self, e_array, preserve_nodes)
@utils.cached_member(cache='_cache', prefix='scipy_adj') @utils.cached_member(cache="_cache", prefix="scipy_adj")
def adjacency_matrix_scipy(self, transpose, fmt, return_edge_ids=None): def adjacency_matrix_scipy(self, transpose, fmt, return_edge_ids=None):
"""Return the scipy adjacency matrix representation of this graph. """Return the scipy adjacency matrix representation of this graph.
...@@ -631,8 +642,10 @@ class GraphIndex(ObjectBase): ...@@ -631,8 +642,10 @@ class GraphIndex(ObjectBase):
The scipy representation of adjacency matrix. The scipy representation of adjacency matrix.
""" """
if not isinstance(transpose, bool): if not isinstance(transpose, bool):
raise DGLError('Expect bool value for "transpose" arg,' raise DGLError(
' but got %s.' % (type(transpose))) 'Expect bool value for "transpose" arg,'
" but got %s." % (type(transpose))
)
if return_edge_ids is None: if return_edge_ids is None:
dgl_warning( dgl_warning(
...@@ -640,17 +653,24 @@ class GraphIndex(ObjectBase): ...@@ -640,17 +653,24 @@ class GraphIndex(ObjectBase):
" As a result there is one 0 entry which is not eliminated." " As a result there is one 0 entry which is not eliminated."
" In the next release it will return 1s by default," " In the next release it will return 1s by default,"
" and 0 will be eliminated otherwise.", " and 0 will be eliminated otherwise.",
FutureWarning) FutureWarning,
)
return_edge_ids = True return_edge_ids = True
rst = _CAPI_DGLGraphGetAdj(self, transpose, fmt) rst = _CAPI_DGLGraphGetAdj(self, transpose, fmt)
if fmt == "csr": if fmt == "csr":
indptr = utils.toindex(rst(0)).tonumpy() indptr = utils.toindex(rst(0)).tonumpy()
indices = utils.toindex(rst(1)).tonumpy() indices = utils.toindex(rst(1)).tonumpy()
data = utils.toindex(rst(2)).tonumpy() if return_edge_ids else np.ones_like(indices) data = (
utils.toindex(rst(2)).tonumpy()
if return_edge_ids
else np.ones_like(indices)
)
n = self.number_of_nodes() n = self.number_of_nodes()
return scipy.sparse.csr_matrix((data, indices, indptr), shape=(n, n)) return scipy.sparse.csr_matrix(
elif fmt == 'coo': (data, indices, indptr), shape=(n, n)
)
elif fmt == "coo":
idx = utils.toindex(rst(0)).tonumpy() idx = utils.toindex(rst(0)).tonumpy()
n = self.number_of_nodes() n = self.number_of_nodes()
m = self.number_of_edges() m = self.number_of_edges()
...@@ -660,7 +680,7 @@ class GraphIndex(ObjectBase): ...@@ -660,7 +680,7 @@ class GraphIndex(ObjectBase):
else: else:
raise Exception("unknown format") raise Exception("unknown format")
@utils.cached_member(cache='_cache', prefix='immu_gidx') @utils.cached_member(cache="_cache", prefix="immu_gidx")
def get_immutable_gidx(self, ctx): def get_immutable_gidx(self, ctx):
"""Create an immutable graph index and copy to the given device context. """Create an immutable graph index and copy to the given device context.
...@@ -717,8 +737,10 @@ class GraphIndex(ObjectBase): ...@@ -717,8 +737,10 @@ class GraphIndex(ObjectBase):
if shuffle is not required. if shuffle is not required.
""" """
if not isinstance(transpose, bool): if not isinstance(transpose, bool):
raise DGLError('Expect bool value for "transpose" arg,' raise DGLError(
' but got %s.' % (type(transpose))) 'Expect bool value for "transpose" arg,'
" but got %s." % (type(transpose))
)
fmt = F.get_preferred_sparse_format() fmt = F.get_preferred_sparse_format()
rst = _CAPI_DGLGraphGetAdj(self, transpose, fmt) rst = _CAPI_DGLGraphGetAdj(self, transpose, fmt)
if fmt == "csr": if fmt == "csr":
...@@ -726,8 +748,11 @@ class GraphIndex(ObjectBase): ...@@ -726,8 +748,11 @@ class GraphIndex(ObjectBase):
indices = F.copy_to(utils.toindex(rst(1)).tousertensor(), ctx) indices = F.copy_to(utils.toindex(rst(1)).tousertensor(), ctx)
shuffle = utils.toindex(rst(2)) shuffle = utils.toindex(rst(2))
dat = F.ones(indices.shape, dtype=F.float32, ctx=ctx) dat = F.ones(indices.shape, dtype=F.float32, ctx=ctx)
spmat = F.sparse_matrix(dat, ('csr', indices, indptr), spmat = F.sparse_matrix(
(self.number_of_nodes(), self.number_of_nodes()))[0] dat,
("csr", indices, indptr),
(self.number_of_nodes(), self.number_of_nodes()),
)[0]
return spmat, shuffle return spmat, shuffle
elif fmt == "coo": elif fmt == "coo":
## FIXME(minjie): data type ## FIXME(minjie): data type
...@@ -736,8 +761,10 @@ class GraphIndex(ObjectBase): ...@@ -736,8 +761,10 @@ class GraphIndex(ObjectBase):
idx = F.reshape(idx, (2, m)) idx = F.reshape(idx, (2, m))
dat = F.ones((m,), dtype=F.float32, ctx=ctx) dat = F.ones((m,), dtype=F.float32, ctx=ctx)
n = self.number_of_nodes() n = self.number_of_nodes()
adj, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, n)) adj, shuffle_idx = F.sparse_matrix(dat, ("coo", idx), (n, n))
shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None shuffle_idx = (
utils.toindex(shuffle_idx) if shuffle_idx is not None else None
)
return adj, shuffle_idx return adj, shuffle_idx
else: else:
raise Exception("unknown format") raise Exception("unknown format")
...@@ -783,21 +810,21 @@ class GraphIndex(ObjectBase): ...@@ -783,21 +810,21 @@ class GraphIndex(ObjectBase):
eid = eid.tousertensor(ctx) # the index of the ctx will be cached eid = eid.tousertensor(ctx) # the index of the ctx will be cached
n = self.number_of_nodes() n = self.number_of_nodes()
m = self.number_of_edges() m = self.number_of_edges()
if typestr == 'in': if typestr == "in":
row = F.unsqueeze(dst, 0) row = F.unsqueeze(dst, 0)
col = F.unsqueeze(eid, 0) col = F.unsqueeze(eid, 0)
idx = F.cat([row, col], dim=0) idx = F.cat([row, col], dim=0)
# FIXME(minjie): data type # FIXME(minjie): data type
dat = F.ones((m,), dtype=F.float32, ctx=ctx) dat = F.ones((m,), dtype=F.float32, ctx=ctx)
inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m)) inc, shuffle_idx = F.sparse_matrix(dat, ("coo", idx), (n, m))
elif typestr == 'out': elif typestr == "out":
row = F.unsqueeze(src, 0) row = F.unsqueeze(src, 0)
col = F.unsqueeze(eid, 0) col = F.unsqueeze(eid, 0)
idx = F.cat([row, col], dim=0) idx = F.cat([row, col], dim=0)
# FIXME(minjie): data type # FIXME(minjie): data type
dat = F.ones((m,), dtype=F.float32, ctx=ctx) dat = F.ones((m,), dtype=F.float32, ctx=ctx)
inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m)) inc, shuffle_idx = F.sparse_matrix(dat, ("coo", idx), (n, m))
elif typestr == 'both': elif typestr == "both":
# first remove entries for self loops # first remove entries for self loops
mask = F.logical_not(F.equal(src, dst)) mask = F.logical_not(F.equal(src, dst))
src = F.boolean_mask(src, mask) src = F.boolean_mask(src, mask)
...@@ -812,10 +839,12 @@ class GraphIndex(ObjectBase): ...@@ -812,10 +839,12 @@ class GraphIndex(ObjectBase):
x = -F.ones((n_entries,), dtype=F.float32, ctx=ctx) x = -F.ones((n_entries,), dtype=F.float32, ctx=ctx)
y = F.ones((n_entries,), dtype=F.float32, ctx=ctx) y = F.ones((n_entries,), dtype=F.float32, ctx=ctx)
dat = F.cat([x, y], dim=0) dat = F.cat([x, y], dim=0)
inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m)) inc, shuffle_idx = F.sparse_matrix(dat, ("coo", idx), (n, m))
else: else:
raise DGLError('Invalid incidence matrix type: %s' % str(typestr)) raise DGLError("Invalid incidence matrix type: %s" % str(typestr))
shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None shuffle_idx = (
utils.toindex(shuffle_idx) if shuffle_idx is not None else None
)
return inc, shuffle_idx return inc, shuffle_idx
def to_networkx(self): def to_networkx(self):
...@@ -902,7 +931,9 @@ class GraphIndex(ObjectBase): ...@@ -902,7 +931,9 @@ class GraphIndex(ObjectBase):
GraphIndex GraphIndex
The graph index on the given device context. The graph index on the given device context.
""" """
return _CAPI_DGLImmutableGraphCopyTo(self, ctx.device_type, ctx.device_id) return _CAPI_DGLImmutableGraphCopyTo(
self, ctx.device_type, ctx.device_id
)
def copyto_shared_mem(self, shared_mem_name): def copyto_shared_mem(self, shared_mem_name):
"""Copy this immutable graph index to shared memory. """Copy this immutable graph index to shared memory.
...@@ -939,7 +970,10 @@ class GraphIndex(ObjectBase): ...@@ -939,7 +970,10 @@ class GraphIndex(ObjectBase):
int int
The number of bits needed The number of bits needed
""" """
if self.number_of_edges() >= 0x80000000 or self.number_of_nodes() >= 0x80000000: if (
self.number_of_edges() >= 0x80000000
or self.number_of_nodes() >= 0x80000000
):
return 64 return 64
else: else:
return 32 return 32
...@@ -961,9 +995,11 @@ class GraphIndex(ObjectBase): ...@@ -961,9 +995,11 @@ class GraphIndex(ObjectBase):
""" """
return _CAPI_DGLImmutableGraphAsNumBits(self, int(bits)) return _CAPI_DGLImmutableGraphAsNumBits(self, int(bits))
@register_object('graph.Subgraph')
@register_object("graph.Subgraph")
class SubgraphIndex(ObjectBase): class SubgraphIndex(ObjectBase):
"""Subgraph data structure""" """Subgraph data structure"""
@property @property
def graph(self): def graph(self):
"""The subgraph structure """The subgraph structure
...@@ -1028,16 +1064,15 @@ def from_coo(num_nodes, src, dst, readonly): ...@@ -1028,16 +1064,15 @@ def from_coo(num_nodes, src, dst, readonly):
dst = utils.toindex(dst) dst = utils.toindex(dst)
if readonly: if readonly:
gidx = _CAPI_DGLGraphCreate( gidx = _CAPI_DGLGraphCreate(
src.todgltensor(), src.todgltensor(), dst.todgltensor(), int(num_nodes), readonly
dst.todgltensor(), )
int(num_nodes),
readonly)
else: else:
gidx = _CAPI_DGLGraphCreateMutable() gidx = _CAPI_DGLGraphCreateMutable()
gidx.add_nodes(num_nodes) gidx.add_nodes(num_nodes)
gidx.add_edges(src, dst) gidx.add_edges(src, dst)
return gidx return gidx
def from_csr(indptr, indices, direction): def from_csr(indptr, indices, direction):
"""Load a graph from CSR arrays. """Load a graph from CSR arrays.
...@@ -1058,11 +1093,11 @@ def from_csr(indptr, indices, direction): ...@@ -1058,11 +1093,11 @@ def from_csr(indptr, indices, direction):
indptr = utils.toindex(indptr) indptr = utils.toindex(indptr)
indices = utils.toindex(indices) indices = utils.toindex(indices)
gidx = _CAPI_DGLGraphCSRCreate( gidx = _CAPI_DGLGraphCSRCreate(
indptr.todgltensor(), indptr.todgltensor(), indices.todgltensor(), direction
indices.todgltensor(), )
direction)
return gidx return gidx
def from_shared_mem_graph_index(shared_mem_name): def from_shared_mem_graph_index(shared_mem_name):
"""Load a graph index from the shared memory. """Load a graph index from the shared memory.
...@@ -1078,6 +1113,7 @@ def from_shared_mem_graph_index(shared_mem_name): ...@@ -1078,6 +1113,7 @@ def from_shared_mem_graph_index(shared_mem_name):
""" """
return _CAPI_DGLGraphCSRCreateMMap(shared_mem_name) return _CAPI_DGLGraphCSRCreateMMap(shared_mem_name)
def from_networkx(nx_graph, readonly): def from_networkx(nx_graph, readonly):
"""Convert from networkx graph. """Convert from networkx graph.
...@@ -1107,7 +1143,7 @@ def from_networkx(nx_graph, readonly): ...@@ -1107,7 +1143,7 @@ def from_networkx(nx_graph, readonly):
# nx_graph.edges(data=True) returns src, dst, attr_dict # nx_graph.edges(data=True) returns src, dst, attr_dict
if nx_graph.number_of_edges() > 0: if nx_graph.number_of_edges() > 0:
has_edge_id = 'id' in next(iter(nx_graph.edges(data=True)))[-1] has_edge_id = "id" in next(iter(nx_graph.edges(data=True)))[-1]
else: else:
has_edge_id = False has_edge_id = False
...@@ -1116,7 +1152,7 @@ def from_networkx(nx_graph, readonly): ...@@ -1116,7 +1152,7 @@ def from_networkx(nx_graph, readonly):
src = np.zeros((num_edges,), dtype=np.int64) src = np.zeros((num_edges,), dtype=np.int64)
dst = np.zeros((num_edges,), dtype=np.int64) dst = np.zeros((num_edges,), dtype=np.int64)
for u, v, attr in nx_graph.edges(data=True): for u, v, attr in nx_graph.edges(data=True):
eid = attr['id'] eid = attr["id"]
src[eid] = u src[eid] = u
dst[eid] = v dst[eid] = v
else: else:
...@@ -1131,6 +1167,7 @@ def from_networkx(nx_graph, readonly): ...@@ -1131,6 +1167,7 @@ def from_networkx(nx_graph, readonly):
dst = utils.toindex(dst) dst = utils.toindex(dst)
return from_coo(num_nodes, src, dst, readonly) return from_coo(num_nodes, src, dst, readonly)
def from_scipy_sparse_matrix(adj, readonly): def from_scipy_sparse_matrix(adj, readonly):
"""Convert from scipy sparse matrix. """Convert from scipy sparse matrix.
...@@ -1145,7 +1182,7 @@ def from_scipy_sparse_matrix(adj, readonly): ...@@ -1145,7 +1182,7 @@ def from_scipy_sparse_matrix(adj, readonly):
GraphIndex GraphIndex
The graph index. The graph index.
""" """
if adj.getformat() != 'csr' or not readonly: if adj.getformat() != "csr" or not readonly:
num_nodes = max(adj.shape[0], adj.shape[1]) num_nodes = max(adj.shape[0], adj.shape[1])
adj_coo = adj.tocoo() adj_coo = adj.tocoo()
return from_coo(num_nodes, adj_coo.row, adj_coo.col, readonly) return from_coo(num_nodes, adj_coo.row, adj_coo.col, readonly)
...@@ -1153,6 +1190,7 @@ def from_scipy_sparse_matrix(adj, readonly): ...@@ -1153,6 +1190,7 @@ def from_scipy_sparse_matrix(adj, readonly):
# If the input matrix is csr, we still treat it as multigraph. # If the input matrix is csr, we still treat it as multigraph.
return from_csr(adj.indptr, adj.indices, "out") return from_csr(adj.indptr, adj.indices, "out")
def from_edge_list(elist, readonly): def from_edge_list(elist, readonly):
"""Convert from an edge list. """Convert from an edge list.
...@@ -1172,6 +1210,7 @@ def from_edge_list(elist, readonly): ...@@ -1172,6 +1210,7 @@ def from_edge_list(elist, readonly):
num_nodes = max(src.max(), dst.max()) + 1 num_nodes = max(src.max(), dst.max()) + 1
return from_coo(num_nodes, src_ids, dst_ids, readonly) return from_coo(num_nodes, src_ids, dst_ids, readonly)
def map_to_subgraph_nid(induced_nodes, parent_nids): def map_to_subgraph_nid(induced_nodes, parent_nids):
"""Map parent node Ids to the subgraph node Ids. """Map parent node Ids to the subgraph node Ids.
...@@ -1188,8 +1227,12 @@ def map_to_subgraph_nid(induced_nodes, parent_nids): ...@@ -1188,8 +1227,12 @@ def map_to_subgraph_nid(induced_nodes, parent_nids):
utils.Index utils.Index
Node Ids in the subgraph. Node Ids in the subgraph.
""" """
return utils.toindex(_CAPI_DGLMapSubgraphNID(induced_nodes.todgltensor(), return utils.toindex(
parent_nids.todgltensor())) _CAPI_DGLMapSubgraphNID(
induced_nodes.todgltensor(), parent_nids.todgltensor()
)
)
def transform_ids(mapping, ids): def transform_ids(mapping, ids):
"""Transform ids by the given mapping. """Transform ids by the given mapping.
...@@ -1206,8 +1249,10 @@ def transform_ids(mapping, ids): ...@@ -1206,8 +1249,10 @@ def transform_ids(mapping, ids):
utils.Index utils.Index
The new ids. The new ids.
""" """
return utils.toindex(_CAPI_DGLMapSubgraphNID( return utils.toindex(
mapping.todgltensor(), ids.todgltensor())) _CAPI_DGLMapSubgraphNID(mapping.todgltensor(), ids.todgltensor())
)
def disjoint_union(graphs): def disjoint_union(graphs):
"""Return a disjoint union of the input graphs. """Return a disjoint union of the input graphs.
...@@ -1230,6 +1275,7 @@ def disjoint_union(graphs): ...@@ -1230,6 +1275,7 @@ def disjoint_union(graphs):
""" """
return _CAPI_DGLDisjointUnion(list(graphs)) return _CAPI_DGLDisjointUnion(list(graphs))
def disjoint_partition(graph, num_or_size_splits): def disjoint_partition(graph, num_or_size_splits):
"""Partition the graph disjointly. """Partition the graph disjointly.
...@@ -1252,14 +1298,13 @@ def disjoint_partition(graph, num_or_size_splits): ...@@ -1252,14 +1298,13 @@ def disjoint_partition(graph, num_or_size_splits):
""" """
if isinstance(num_or_size_splits, utils.Index): if isinstance(num_or_size_splits, utils.Index):
rst = _CAPI_DGLDisjointPartitionBySizes( rst = _CAPI_DGLDisjointPartitionBySizes(
graph, graph, num_or_size_splits.todgltensor()
num_or_size_splits.todgltensor()) )
else: else:
rst = _CAPI_DGLDisjointPartitionByNum( rst = _CAPI_DGLDisjointPartitionByNum(graph, int(num_or_size_splits))
graph,
int(num_or_size_splits))
return rst return rst
def create_graph_index(graph_data, readonly): def create_graph_index(graph_data, readonly):
"""Create a graph index object. """Create a graph index object.
...@@ -1289,11 +1334,15 @@ def create_graph_index(graph_data, readonly): ...@@ -1289,11 +1334,15 @@ def create_graph_index(graph_data, readonly):
try: try:
gidx = from_networkx(graph_data, readonly) gidx = from_networkx(graph_data, readonly)
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
raise DGLError('Error while creating graph from input of type "%s".' raise DGLError(
% type(graph_data)) 'Error while creating graph from input of type "%s".'
% type(graph_data)
)
return gidx return gidx
def _get_halo_subgraph_inner_node(halo_subg): def _get_halo_subgraph_inner_node(halo_subg):
return _CAPI_GetHaloSubgraphInnerNodes(halo_subg) return _CAPI_GetHaloSubgraphInnerNodes(halo_subg)
_init_api("dgl.graph_index") _init_api("dgl.graph_index")
(The diff for this file is collapsed and not shown here.)
...@@ -3,9 +3,12 @@ from __future__ import absolute_import ...@@ -3,9 +3,12 @@ from __future__ import absolute_import
from . import backend as F from . import backend as F
__all__ = ['base_initializer', 'zero_initializer'] __all__ = ["base_initializer", "zero_initializer"]
def base_initializer(shape, dtype, ctx, id_range): # pylint: disable=unused-argument
def base_initializer(
shape, dtype, ctx, id_range
): # pylint: disable=unused-argument
"""The function signature for feature initializer. """The function signature for feature initializer.
Any customized feature initializer should follow this signature (see Any customized feature initializer should follow this signature (see
...@@ -44,7 +47,10 @@ def base_initializer(shape, dtype, ctx, id_range): # pylint: disable=unused-arg ...@@ -44,7 +47,10 @@ def base_initializer(shape, dtype, ctx, id_range): # pylint: disable=unused-arg
""" """
raise NotImplementedError raise NotImplementedError
def zero_initializer(shape, dtype, ctx, id_range): # pylint: disable=unused-argument
def zero_initializer(
shape, dtype, ctx, id_range
): # pylint: disable=unused-argument
"""Zero feature initializer """Zero feature initializer
Examples Examples
......
"""Utilities for merging graphs.""" """Utilities for merging graphs."""
import dgl import dgl
from . import backend as F from . import backend as F
from .base import DGLError from .base import DGLError
__all__ = ['merge'] __all__ = ["merge"]
def merge(graphs): def merge(graphs):
r"""Merge a sequence of graphs together into a single graph. r"""Merge a sequence of graphs together into a single graph.
...@@ -62,7 +64,7 @@ def merge(graphs): ...@@ -62,7 +64,7 @@ def merge(graphs):
""" """
if len(graphs) == 0: if len(graphs) == 0:
raise DGLError('The input list of graphs cannot be empty.') raise DGLError("The input list of graphs cannot be empty.")
ref = graphs[0] ref = graphs[0]
ntypes = ref.ntypes ntypes = ref.ntypes
...@@ -87,9 +89,15 @@ def merge(graphs): ...@@ -87,9 +89,15 @@ def merge(graphs):
if len(keys) == 0: if len(keys) == 0:
edges_data = None edges_data = None
else: else:
edges_data = {k: F.cat([f[k] for f in edata_frames], dim=0) for k in keys} edges_data = {
merged_us = F.copy_to(F.astype(F.cat(unmerged_us, dim=0), ref.idtype), ref.device) k: F.cat([f[k] for f in edata_frames], dim=0) for k in keys
merged_vs = F.copy_to(F.astype(F.cat(unmerged_vs, dim=0), ref.idtype), ref.device) }
merged_us = F.copy_to(
F.astype(F.cat(unmerged_us, dim=0), ref.idtype), ref.device
)
merged_vs = F.copy_to(
F.astype(F.cat(unmerged_vs, dim=0), ref.idtype), ref.device
)
merged.add_edges(merged_us, merged_vs, edges_data, etype) merged.add_edges(merged_us, merged_vs, edges_data, etype)
# Add node data and isolated nodes from next_graph to merged. # Add node data and isolated nodes from next_graph to merged.
...@@ -98,12 +106,16 @@ def merge(graphs): ...@@ -98,12 +106,16 @@ def merge(graphs):
merged_ntype_id = merged.get_ntype_id(ntype) merged_ntype_id = merged.get_ntype_id(ntype)
next_ntype_id = next_graph.get_ntype_id(ntype) next_ntype_id = next_graph.get_ntype_id(ntype)
next_ndata = next_graph._node_frames[next_ntype_id] next_ndata = next_graph._node_frames[next_ntype_id]
node_diff = (next_graph.num_nodes(ntype=ntype) - node_diff = next_graph.num_nodes(ntype=ntype) - merged.num_nodes(
merged.num_nodes(ntype=ntype)) ntype=ntype
)
n_extra_nodes = max(0, node_diff) n_extra_nodes = max(0, node_diff)
merged.add_nodes(n_extra_nodes, ntype=ntype) merged.add_nodes(n_extra_nodes, ntype=ntype)
next_nodes = F.arange( next_nodes = F.arange(
0, next_graph.num_nodes(ntype=ntype), merged.idtype, merged.device 0,
next_graph.num_nodes(ntype=ntype),
merged.idtype,
merged.device,
) )
merged._node_frames[merged_ntype_id].update_row( merged._node_frames[merged_ntype_id].update_row(
next_nodes, next_ndata next_nodes, next_ndata
......
"""dgl sparse class.""" """dgl sparse class."""
from .diag_matrix import * from .diag_matrix import *
from .sp_matrix import *
from .elementwise_op import * from .elementwise_op import *
from .sddmm import * from .matmul import *
from .reduction import * # pylint: disable=W0622 from .reduction import * # pylint: disable=W0622
from .sddmm import *
from .sp_matrix import *
from .unary_diag import * from .unary_diag import *
from .unary_sp import * from .unary_sp import *
from .matmul import *
"""DGL sparse matrix module.""" """DGL sparse matrix module."""
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
__all__ = [ __all__ = [
......
...@@ -6,9 +6,10 @@ ...@@ -6,9 +6,10 @@
# make fork() and openmp work together. # make fork() and openmp work together.
from .. import backend as F from .. import backend as F
if F.get_preferred_backend() == 'pytorch': if F.get_preferred_backend() == "pytorch":
# Wrap around torch.multiprocessing... # Wrap around torch.multiprocessing...
from torch.multiprocessing import * from torch.multiprocessing import *
# ... and override the Process initializer. # ... and override the Process initializer.
from .pytorch import * from .pytorch import *
else: else:
......
"""PyTorch multiprocessing wrapper.""" """PyTorch multiprocessing wrapper."""
from functools import wraps
import random import random
import traceback import traceback
from _thread import start_new_thread from _thread import start_new_thread
from functools import wraps
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from ..utils import create_shared_mem_array, get_shared_mem_array from ..utils import create_shared_mem_array, get_shared_mem_array
def thread_wrapped_func(func): def thread_wrapped_func(func):
""" """
Wraps a process entry point to make it work with OpenMP. Wraps a process entry point to make it work with OpenMP.
""" """
@wraps(func) @wraps(func)
def decorated_function(*args, **kwargs): def decorated_function(*args, **kwargs):
queue = mp.Queue() queue = mp.Queue()
def _queue_result(): def _queue_result():
exception, trace, res = None, None, None exception, trace, res = None, None, None
try: try:
...@@ -31,18 +35,31 @@ def thread_wrapped_func(func): ...@@ -31,18 +35,31 @@ def thread_wrapped_func(func):
else: else:
assert isinstance(exception, Exception) assert isinstance(exception, Exception)
raise exception.__class__(trace) raise exception.__class__(trace)
return decorated_function return decorated_function
# pylint: disable=missing-docstring # pylint: disable=missing-docstring
class Process(mp.Process): class Process(mp.Process):
# pylint: disable=dangerous-default-value # pylint: disable=dangerous-default-value
def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, *, daemon=None): def __init__(
self,
group=None,
target=None,
name=None,
args=(),
kwargs={},
*,
daemon=None
):
target = thread_wrapped_func(target) target = thread_wrapped_func(target)
super().__init__(group, target, name, args, kwargs, daemon=daemon) super().__init__(group, target, name, args, kwargs, daemon=daemon)
def _get_shared_mem_name(id_): def _get_shared_mem_name(id_):
return "shared" + str(id_) return "shared" + str(id_)
def call_once_and_share(func, shape, dtype, rank=0): def call_once_and_share(func, shape, dtype, rank=0):
"""Invoke the function in a single process of the PyTorch distributed process group, """Invoke the function in a single process of the PyTorch distributed process group,
and share the result with other processes. and share the result with other processes.
...@@ -61,7 +78,7 @@ def call_once_and_share(func, shape, dtype, rank=0): ...@@ -61,7 +78,7 @@ def call_once_and_share(func, shape, dtype, rank=0):
current_rank = torch.distributed.get_rank() current_rank = torch.distributed.get_rank()
dist_buf = torch.LongTensor([1]) dist_buf = torch.LongTensor([1])
if torch.distributed.get_backend() == 'nccl': if torch.distributed.get_backend() == "nccl":
# Use .cuda() to transfer it to the correct device. Should be OK since # Use .cuda() to transfer it to the correct device. Should be OK since
# PyTorch recommends the users to call set_device() after getting inside # PyTorch recommends the users to call set_device() after getting inside
# torch.multiprocessing.spawn() # torch.multiprocessing.spawn()
...@@ -88,6 +105,7 @@ def call_once_and_share(func, shape, dtype, rank=0): ...@@ -88,6 +105,7 @@ def call_once_and_share(func, shape, dtype, rank=0):
return result return result
def shared_tensor(shape, dtype=torch.float32): def shared_tensor(shape, dtype=torch.float32):
"""Create a tensor in shared memory accessible by all processes within the same """Create a tensor in shared memory accessible by all processes within the same
``torch.distributed`` process group. ``torch.distributed`` process group.
...@@ -106,4 +124,6 @@ def shared_tensor(shape, dtype=torch.float32): ...@@ -106,4 +124,6 @@ def shared_tensor(shape, dtype=torch.float32):
Tensor Tensor
The shared tensor. The shared tensor.
""" """
return call_once_and_share(lambda: torch.empty(*shape, dtype=dtype), shape, dtype) return call_once_and_share(
lambda: torch.empty(*shape, dtype=dtype), shape, dtype
)
...@@ -9,17 +9,28 @@ from __future__ import absolute_import as _abs ...@@ -9,17 +9,28 @@ from __future__ import absolute_import as _abs
import ctypes import ctypes
import functools import functools
import operator import operator
import numpy as _np import numpy as _np
from ._ffi.object import register_object, ObjectBase
from ._ffi.function import _init_api
from ._ffi.ndarray import DGLContext, DGLDataType, NDArrayBase
from ._ffi.ndarray import context, empty, empty_shared_mem, from_dlpack, numpyasarray
from ._ffi.ndarray import _set_class_ndarray
from . import backend as F from . import backend as F
from ._ffi.function import _init_api
from ._ffi.ndarray import (
DGLContext,
DGLDataType,
NDArrayBase,
_set_class_ndarray,
context,
empty,
empty_shared_mem,
from_dlpack,
numpyasarray,
)
from ._ffi.object import ObjectBase, register_object
class NDArray(NDArrayBase): class NDArray(NDArrayBase):
"""Lightweight NDArray class for DGL framework.""" """Lightweight NDArray class for DGL framework."""
def __len__(self): def __len__(self):
return functools.reduce(operator.mul, self.shape, 1) return functools.reduce(operator.mul, self.shape, 1)
...@@ -35,7 +46,10 @@ class NDArray(NDArrayBase): ...@@ -35,7 +46,10 @@ class NDArray(NDArrayBase):
------- -------
NDArray NDArray
""" """
return empty_shared_mem(name, True, self.shape, self.dtype).copyfrom(self) return empty_shared_mem(name, True, self.shape, self.dtype).copyfrom(
self
)
def cpu(dev_id=0): def cpu(dev_id=0):
"""Construct a CPU device """Construct a CPU device
...@@ -52,6 +66,7 @@ def cpu(dev_id=0): ...@@ -52,6 +66,7 @@ def cpu(dev_id=0):
""" """
return DGLContext(1, dev_id) return DGLContext(1, dev_id)
def gpu(dev_id=0): def gpu(dev_id=0):
"""Construct a CPU device """Construct a CPU device
...@@ -67,6 +82,7 @@ def gpu(dev_id=0): ...@@ -67,6 +82,7 @@ def gpu(dev_id=0):
""" """
return DGLContext(2, dev_id) return DGLContext(2, dev_id)
def array(arr, ctx=cpu(0)): def array(arr, ctx=cpu(0)):
"""Create an array from source arr. """Create an array from source arr.
...@@ -87,6 +103,7 @@ def array(arr, ctx=cpu(0)): ...@@ -87,6 +103,7 @@ def array(arr, ctx=cpu(0)):
arr = _np.array(arr) arr = _np.array(arr)
return empty(arr.shape, arr.dtype, ctx).copyfrom(arr) return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)
def zerocopy_from_numpy(np_data): def zerocopy_from_numpy(np_data):
"""Create an array that shares the given numpy data. """Create an array that shares the given numpy data.
...@@ -104,6 +121,7 @@ def zerocopy_from_numpy(np_data): ...@@ -104,6 +121,7 @@ def zerocopy_from_numpy(np_data):
handle = ctypes.pointer(arr) handle = ctypes.pointer(arr)
return NDArray(handle, is_view=True) return NDArray(handle, is_view=True)
def cast_to_signed(arr): def cast_to_signed(arr):
"""Cast this NDArray from unsigned integer to signed one. """Cast this NDArray from unsigned integer to signed one.
...@@ -124,8 +142,9 @@ def cast_to_signed(arr): ...@@ -124,8 +142,9 @@ def cast_to_signed(arr):
""" """
return _CAPI_DGLArrayCastToSigned(arr) return _CAPI_DGLArrayCastToSigned(arr)
def get_shared_mem_array(name, shape, dtype): def get_shared_mem_array(name, shape, dtype):
""" Get a tensor from shared memory with specific name """Get a tensor from shared memory with specific name
Parameters Parameters
---------- ----------
...@@ -141,12 +160,15 @@ def get_shared_mem_array(name, shape, dtype): ...@@ -141,12 +160,15 @@ def get_shared_mem_array(name, shape, dtype):
F.tensor F.tensor
The tensor got from shared memory. The tensor got from shared memory.
""" """
new_arr = empty_shared_mem(name, False, shape, F.reverse_data_type_dict[dtype]) new_arr = empty_shared_mem(
name, False, shape, F.reverse_data_type_dict[dtype]
)
dlpack = new_arr.to_dlpack() dlpack = new_arr.to_dlpack()
return F.zerocopy_from_dlpack(dlpack) return F.zerocopy_from_dlpack(dlpack)
def create_shared_mem_array(name, shape, dtype): def create_shared_mem_array(name, shape, dtype):
""" Create a tensor from shared memory with the specific name """Create a tensor from shared memory with the specific name
Parameters Parameters
---------- ----------
...@@ -162,12 +184,15 @@ def create_shared_mem_array(name, shape, dtype): ...@@ -162,12 +184,15 @@ def create_shared_mem_array(name, shape, dtype):
F.tensor F.tensor
The created tensor. The created tensor.
""" """
new_arr = empty_shared_mem(name, True, shape, F.reverse_data_type_dict[dtype]) new_arr = empty_shared_mem(
name, True, shape, F.reverse_data_type_dict[dtype]
)
dlpack = new_arr.to_dlpack() dlpack = new_arr.to_dlpack()
return F.zerocopy_from_dlpack(dlpack) return F.zerocopy_from_dlpack(dlpack)
def exist_shared_mem_array(name): def exist_shared_mem_array(name):
""" Check the existence of shared-memory array. """Check the existence of shared-memory array.
Parameters Parameters
---------- ----------
...@@ -181,23 +206,27 @@ def exist_shared_mem_array(name): ...@@ -181,23 +206,27 @@ def exist_shared_mem_array(name):
""" """
return _CAPI_DGLExistSharedMemArray(name) return _CAPI_DGLExistSharedMemArray(name)
class SparseFormat: class SparseFormat:
"""Format code""" """Format code"""
ANY = 0 ANY = 0
COO = 1 COO = 1
CSR = 2 CSR = 2
CSC = 3 CSC = 3
FORMAT2STR = { FORMAT2STR = {
0 : 'ANY', 0: "ANY",
1 : 'COO', 1: "COO",
2 : 'CSR', 2: "CSR",
3 : 'CSC', 3: "CSC",
} }
@register_object('aten.SparseMatrix')
@register_object("aten.SparseMatrix")
class SparseMatrix(ObjectBase): class SparseMatrix(ObjectBase):
"""Sparse matrix object class in C++ backend.""" """Sparse matrix object class in C++ backend."""
@property @property
def format(self): def format(self):
"""Sparse format enum """Sparse format enum
...@@ -250,17 +279,26 @@ class SparseMatrix(ObjectBase): ...@@ -250,17 +279,26 @@ class SparseMatrix(ObjectBase):
return _CAPI_DGLSparseMatrixGetFlags(self) return _CAPI_DGLSparseMatrixGetFlags(self)
def __getstate__(self): def __getstate__(self):
return self.format, self.num_rows, self.num_cols, self.indices, self.flags return (
self.format,
self.num_rows,
self.num_cols,
self.indices,
self.flags,
)
def __setstate__(self, state): def __setstate__(self, state):
fmt, nrows, ncols, indices, flags = state fmt, nrows, ncols, indices, flags = state
indices = [F.zerocopy_to_dgl_ndarray(idx) for idx in indices] indices = [F.zerocopy_to_dgl_ndarray(idx) for idx in indices]
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_CAPI_DGLCreateSparseMatrix, fmt, nrows, ncols, indices, flags) _CAPI_DGLCreateSparseMatrix, fmt, nrows, ncols, indices, flags
)
def __repr__(self): def __repr__(self):
return 'SparseMatrix(fmt="{}", shape=({},{}))'.format( return 'SparseMatrix(fmt="{}", shape=({},{}))'.format(
SparseFormat.FORMAT2STR[self.format], self.num_rows, self.num_cols) SparseFormat.FORMAT2STR[self.format], self.num_rows, self.num_cols
)
_set_class_ndarray(NDArray) _set_class_ndarray(NDArray)
_init_api("dgl.ndarray") _init_api("dgl.ndarray")
...@@ -270,5 +308,5 @@ _init_api("dgl.ndarray.uvm", __name__) ...@@ -270,5 +308,5 @@ _init_api("dgl.ndarray.uvm", __name__)
# other backend tensors. # other backend tensors.
NULL = { NULL = {
"int64": array(_np.array([], dtype=_np.int64)), "int64": array(_np.array([], dtype=_np.int64)),
"int32": array(_np.array([], dtype=_np.int32)) "int32": array(_np.array([], dtype=_np.int32)),
} }
...@@ -2,13 +2,14 @@ ...@@ -2,13 +2,14 @@
from __future__ import absolute_import from __future__ import absolute_import
import time import time
from enum import Enum
from collections import namedtuple from collections import namedtuple
from enum import Enum
import dgl.backend as F import dgl.backend as F
from ._ffi.function import _init_api
from ._deprecate.nodeflow import NodeFlow
from . import utils from . import utils
from ._deprecate.nodeflow import NodeFlow
from ._ffi.function import _init_api
_init_api("dgl.network") _init_api("dgl.network")
...@@ -19,12 +20,11 @@ _WAIT_TIME_SEC = 3 # 3 seconds ...@@ -19,12 +20,11 @@ _WAIT_TIME_SEC = 3 # 3 seconds
def _network_wait(): def _network_wait():
"""Sleep for a few seconds """Sleep for a few seconds"""
"""
time.sleep(_WAIT_TIME_SEC) time.sleep(_WAIT_TIME_SEC)
def _create_sender(net_type, msg_queue_size=2*1024*1024*1024): def _create_sender(net_type, msg_queue_size=2 * 1024 * 1024 * 1024):
"""Create a Sender communicator via C api """Create a Sender communicator via C api
Parameters Parameters
...@@ -34,11 +34,11 @@ def _create_sender(net_type, msg_queue_size=2*1024*1024*1024): ...@@ -34,11 +34,11 @@ def _create_sender(net_type, msg_queue_size=2*1024*1024*1024):
msg_queue_size : int msg_queue_size : int
message queue size (2GB by default) message queue size (2GB by default)
""" """
assert net_type in ('socket', 'mpi'), 'Unknown network type.' assert net_type in ("socket", "mpi"), "Unknown network type."
return _CAPI_DGLSenderCreate(net_type, msg_queue_size) return _CAPI_DGLSenderCreate(net_type, msg_queue_size)
def _create_receiver(net_type, msg_queue_size=2*1024*1024*1024): def _create_receiver(net_type, msg_queue_size=2 * 1024 * 1024 * 1024):
"""Create a Receiver communicator via C api """Create a Receiver communicator via C api
Parameters Parameters
...@@ -48,7 +48,7 @@ def _create_receiver(net_type, msg_queue_size=2*1024*1024*1024): ...@@ -48,7 +48,7 @@ def _create_receiver(net_type, msg_queue_size=2*1024*1024*1024):
msg_queue_size : int msg_queue_size : int
message queue size (2GB by default) message queue size (2GB by default)
""" """
assert net_type in ('socket', 'mpi'), 'Unknown network type.' assert net_type in ("socket", "mpi"), "Unknown network type."
return _CAPI_DGLReceiverCreate(net_type, msg_queue_size) return _CAPI_DGLReceiverCreate(net_type, msg_queue_size)
...@@ -64,8 +64,7 @@ def _finalize_sender(sender): ...@@ -64,8 +64,7 @@ def _finalize_sender(sender):
def _finalize_receiver(receiver): def _finalize_receiver(receiver):
"""Finalize Receiver Communicator """Finalize Receiver Communicator"""
"""
_CAPI_DGLFinalizeReceiver(receiver) _CAPI_DGLFinalizeReceiver(receiver)
...@@ -83,7 +82,7 @@ def _add_receiver_addr(sender, ip_addr, port, recv_id): ...@@ -83,7 +82,7 @@ def _add_receiver_addr(sender, ip_addr, port, recv_id):
recv_id : int recv_id : int
Receiver ID Receiver ID
""" """
assert recv_id >= 0, 'recv_id cannot be a negative number.' assert recv_id >= 0, "recv_id cannot be a negative number."
_CAPI_DGLSenderAddReceiver(sender, ip_addr, int(port), int(recv_id)) _CAPI_DGLSenderAddReceiver(sender, ip_addr, int(port), int(recv_id))
...@@ -112,7 +111,7 @@ def _receiver_wait(receiver, ip_addr, port, num_sender): ...@@ -112,7 +111,7 @@ def _receiver_wait(receiver, ip_addr, port, num_sender):
num_sender : int num_sender : int
total number of Sender total number of Sender
""" """
assert num_sender >= 0, 'num_sender cannot be a negative number.' assert num_sender >= 0, "num_sender cannot be a negative number."
_CAPI_DGLReceiverWait(receiver, ip_addr, int(port), int(num_sender)) _CAPI_DGLReceiverWait(receiver, ip_addr, int(port), int(num_sender))
...@@ -131,19 +130,22 @@ def _send_nodeflow(sender, nodeflow, recv_id): ...@@ -131,19 +130,22 @@ def _send_nodeflow(sender, nodeflow, recv_id):
recv_id : int recv_id : int
Receiver ID Receiver ID
""" """
assert recv_id >= 0, 'recv_id cannot be a negative number.' assert recv_id >= 0, "recv_id cannot be a negative number."
gidx = nodeflow._graph gidx = nodeflow._graph
node_mapping = nodeflow._node_mapping.todgltensor() node_mapping = nodeflow._node_mapping.todgltensor()
edge_mapping = nodeflow._edge_mapping.todgltensor() edge_mapping = nodeflow._edge_mapping.todgltensor()
layers_offsets = utils.toindex(nodeflow._layer_offsets).todgltensor() layers_offsets = utils.toindex(nodeflow._layer_offsets).todgltensor()
flows_offsets = utils.toindex(nodeflow._block_offsets).todgltensor() flows_offsets = utils.toindex(nodeflow._block_offsets).todgltensor()
_CAPI_SenderSendNodeFlow(sender, _CAPI_SenderSendNodeFlow(
sender,
int(recv_id), int(recv_id),
gidx, gidx,
node_mapping, node_mapping,
edge_mapping, edge_mapping,
layers_offsets, layers_offsets,
flows_offsets) flows_offsets,
)
def _send_sampler_end_signal(sender, recv_id): def _send_sampler_end_signal(sender, recv_id):
"""Send an epoch-end signal to remote Receiver. """Send an epoch-end signal to remote Receiver.
...@@ -155,9 +157,10 @@ def _send_sampler_end_signal(sender, recv_id): ...@@ -155,9 +157,10 @@ def _send_sampler_end_signal(sender, recv_id):
recv_id : int recv_id : int
Receiver ID Receiver ID
""" """
assert recv_id >= 0, 'recv_id cannot be a negative number.' assert recv_id >= 0, "recv_id cannot be a negative number."
_CAPI_SenderSendSamplerEndSignal(sender, int(recv_id)) _CAPI_SenderSendSamplerEndSignal(sender, int(recv_id))
def _recv_nodeflow(receiver, graph): def _recv_nodeflow(receiver, graph):
"""Receive sampled subgraph (NodeFlow) from remote sampler. """Receive sampled subgraph (NodeFlow) from remote sampler.
...@@ -183,8 +186,8 @@ def _recv_nodeflow(receiver, graph): ...@@ -183,8 +186,8 @@ def _recv_nodeflow(receiver, graph):
class KVMsgType(Enum): class KVMsgType(Enum):
"""Type of kvstore message """Type of kvstore message"""
"""
FINAL = 1 FINAL = 1
INIT = 2 INIT = 2
PUSH = 3 PUSH = 3
...@@ -215,6 +218,7 @@ c_ptr : void* ...@@ -215,6 +218,7 @@ c_ptr : void*
c pointer of message c pointer of message
""" """
def _send_kv_msg(sender, msg, recv_id): def _send_kv_msg(sender, msg, recv_id):
"""Send kvstore message. """Send kvstore message.
...@@ -230,12 +234,8 @@ def _send_kv_msg(sender, msg, recv_id): ...@@ -230,12 +234,8 @@ def _send_kv_msg(sender, msg, recv_id):
if msg.type == KVMsgType.PULL: if msg.type == KVMsgType.PULL:
tensor_id = F.zerocopy_to_dgl_ndarray(msg.id) tensor_id = F.zerocopy_to_dgl_ndarray(msg.id)
_CAPI_SenderSendKVMsg( _CAPI_SenderSendKVMsg(
sender, sender, int(recv_id), msg.type.value, msg.rank, msg.name, tensor_id
int(recv_id), )
msg.type.value,
msg.rank,
msg.name,
tensor_id)
elif msg.type in (KVMsgType.INIT, KVMsgType.GET_SHAPE_BACK): elif msg.type in (KVMsgType.INIT, KVMsgType.GET_SHAPE_BACK):
tensor_shape = F.zerocopy_to_dgl_ndarray(msg.shape) tensor_shape = F.zerocopy_to_dgl_ndarray(msg.shape)
_CAPI_SenderSendKVMsg( _CAPI_SenderSendKVMsg(
...@@ -244,20 +244,14 @@ def _send_kv_msg(sender, msg, recv_id): ...@@ -244,20 +244,14 @@ def _send_kv_msg(sender, msg, recv_id):
msg.type.value, msg.type.value,
msg.rank, msg.rank,
msg.name, msg.name,
tensor_shape) tensor_shape,
)
elif msg.type in (KVMsgType.IP_ID, KVMsgType.GET_SHAPE): elif msg.type in (KVMsgType.IP_ID, KVMsgType.GET_SHAPE):
_CAPI_SenderSendKVMsg( _CAPI_SenderSendKVMsg(
sender, sender, int(recv_id), msg.type.value, msg.rank, msg.name
int(recv_id), )
msg.type.value,
msg.rank,
msg.name)
elif msg.type in (KVMsgType.FINAL, KVMsgType.BARRIER): elif msg.type in (KVMsgType.FINAL, KVMsgType.BARRIER):
_CAPI_SenderSendKVMsg( _CAPI_SenderSendKVMsg(sender, int(recv_id), msg.type.value, msg.rank)
sender,
int(recv_id),
msg.type.value,
msg.rank)
else: else:
tensor_id = F.zerocopy_to_dgl_ndarray(msg.id) tensor_id = F.zerocopy_to_dgl_ndarray(msg.id)
data = F.zerocopy_to_dgl_ndarray(msg.data) data = F.zerocopy_to_dgl_ndarray(msg.data)
...@@ -268,7 +262,8 @@ def _send_kv_msg(sender, msg, recv_id): ...@@ -268,7 +262,8 @@ def _send_kv_msg(sender, msg, recv_id):
msg.rank, msg.rank,
msg.name, msg.name,
tensor_id, tensor_id,
data) data,
)
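A hedged sketch of building a PUSH message for _send_kv_msg; the field names follow the KVStoreMsg namedtuple constructed in _recv_kv_msg below, while the tensor contents, ranks, and the PyTorch backend are illustrative assumptions.

import torch
from dgl.network import KVMsgType, KVStoreMsg, _send_kv_msg

msg = KVStoreMsg(
    type=KVMsgType.PUSH,
    rank=0,                       # this client's rank
    name="embed",                 # kvstore tensor name
    id=torch.tensor([0, 1, 2]),   # global row ids to push
    data=torch.rand(3, 16),       # rows to write
    shape=None,
    c_ptr=None,
)
# _send_kv_msg(sender, msg, recv_id=0)   # sender from _create_sender above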
def _recv_kv_msg(receiver): def _recv_kv_msg(receiver):
...@@ -288,7 +283,9 @@ def _recv_kv_msg(receiver): ...@@ -288,7 +283,9 @@ def _recv_kv_msg(receiver):
rank = _CAPI_ReceiverGetKVMsgRank(msg_ptr) rank = _CAPI_ReceiverGetKVMsgRank(msg_ptr)
if msg_type == KVMsgType.PULL: if msg_type == KVMsgType.PULL:
name = _CAPI_ReceiverGetKVMsgName(msg_ptr) name = _CAPI_ReceiverGetKVMsgName(msg_ptr)
tensor_id = F.zerocopy_from_dgl_ndarray(_CAPI_ReceiverGetKVMsgID(msg_ptr)) tensor_id = F.zerocopy_from_dgl_ndarray(
_CAPI_ReceiverGetKVMsgID(msg_ptr)
)
msg = KVStoreMsg( msg = KVStoreMsg(
type=msg_type, type=msg_type,
rank=rank, rank=rank,
...@@ -296,11 +293,14 @@ def _recv_kv_msg(receiver): ...@@ -296,11 +293,14 @@ def _recv_kv_msg(receiver):
id=tensor_id, id=tensor_id,
data=None, data=None,
shape=None, shape=None,
c_ptr=msg_ptr) c_ptr=msg_ptr,
)
return msg return msg
elif msg_type in (KVMsgType.INIT, KVMsgType.GET_SHAPE_BACK): elif msg_type in (KVMsgType.INIT, KVMsgType.GET_SHAPE_BACK):
name = _CAPI_ReceiverGetKVMsgName(msg_ptr) name = _CAPI_ReceiverGetKVMsgName(msg_ptr)
tensor_shape = F.zerocopy_from_dgl_ndarray(_CAPI_ReceiverGetKVMsgShape(msg_ptr)) tensor_shape = F.zerocopy_from_dgl_ndarray(
_CAPI_ReceiverGetKVMsgShape(msg_ptr)
)
msg = KVStoreMsg( msg = KVStoreMsg(
type=msg_type, type=msg_type,
rank=rank, rank=rank,
...@@ -308,7 +308,8 @@ def _recv_kv_msg(receiver): ...@@ -308,7 +308,8 @@ def _recv_kv_msg(receiver):
id=None, id=None,
data=None, data=None,
shape=tensor_shape, shape=tensor_shape,
c_ptr=msg_ptr) c_ptr=msg_ptr,
)
return msg return msg
elif msg_type in (KVMsgType.IP_ID, KVMsgType.GET_SHAPE): elif msg_type in (KVMsgType.IP_ID, KVMsgType.GET_SHAPE):
name = _CAPI_ReceiverGetKVMsgName(msg_ptr) name = _CAPI_ReceiverGetKVMsgName(msg_ptr)
...@@ -319,7 +320,8 @@ def _recv_kv_msg(receiver): ...@@ -319,7 +320,8 @@ def _recv_kv_msg(receiver):
id=None, id=None,
data=None, data=None,
shape=None, shape=None,
c_ptr=msg_ptr) c_ptr=msg_ptr,
)
return msg return msg
elif msg_type in (KVMsgType.FINAL, KVMsgType.BARRIER): elif msg_type in (KVMsgType.FINAL, KVMsgType.BARRIER):
msg = KVStoreMsg( msg = KVStoreMsg(
...@@ -329,11 +331,14 @@ def _recv_kv_msg(receiver): ...@@ -329,11 +331,14 @@ def _recv_kv_msg(receiver):
id=None, id=None,
data=None, data=None,
shape=None, shape=None,
c_ptr=msg_ptr) c_ptr=msg_ptr,
)
return msg return msg
else: else:
name = _CAPI_ReceiverGetKVMsgName(msg_ptr) name = _CAPI_ReceiverGetKVMsgName(msg_ptr)
tensor_id = F.zerocopy_from_dgl_ndarray(_CAPI_ReceiverGetKVMsgID(msg_ptr)) tensor_id = F.zerocopy_from_dgl_ndarray(
_CAPI_ReceiverGetKVMsgID(msg_ptr)
)
data = F.zerocopy_from_dgl_ndarray(_CAPI_ReceiverGetKVMsgData(msg_ptr)) data = F.zerocopy_from_dgl_ndarray(_CAPI_ReceiverGetKVMsgData(msg_ptr))
msg = KVStoreMsg( msg = KVStoreMsg(
type=msg_type, type=msg_type,
...@@ -342,25 +347,34 @@ def _recv_kv_msg(receiver): ...@@ -342,25 +347,34 @@ def _recv_kv_msg(receiver):
id=tensor_id, id=tensor_id,
data=data, data=data,
shape=None, shape=None,
c_ptr=msg_ptr) c_ptr=msg_ptr,
)
return msg return msg
raise RuntimeError('Unknown message type: %d' % msg_type.value) raise RuntimeError("Unknown message type: %d" % msg_type.value)
def _clear_kv_msg(msg): def _clear_kv_msg(msg):
"""Clear data of kvstore message """Clear data of kvstore message"""
"""
F.sync() F.sync()
if msg.c_ptr is not None: if msg.c_ptr is not None:
_CAPI_DeleteKVMsg(msg.c_ptr) _CAPI_DeleteKVMsg(msg.c_ptr)
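A hedged receive-loop sketch combining the helpers above; breaking on the FINAL signal and releasing each message's C-side buffer is the intent implied by the code, the loop itself is illustrative.

def drain(receiver):
    """Illustrative loop: pull kvstore messages until an epoch-final signal arrives."""
    while True:
        msg = _recv_kv_msg(receiver)
        if msg.type == KVMsgType.FINAL:
            break
        # ... dispatch PUSH / PULL / INIT handling here ...
        _clear_kv_msg(msg)   # free the C-side message buffer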
def _fast_pull(name, id_tensor, def _fast_pull(
machine_count, group_count, machine_id, client_id, name,
partition_book, g2l, local_data, id_tensor,
sender, receiver): machine_count,
""" Pull message group_count,
machine_id,
client_id,
partition_book,
g2l,
local_data,
sender,
receiver,
):
"""Pull message
Parameters Parameters
---------- ----------
...@@ -393,17 +407,33 @@ def _fast_pull(name, id_tensor, ...@@ -393,17 +407,33 @@ def _fast_pull(name, id_tensor,
target tensor target tensor
""" """
if g2l is not None: if g2l is not None:
res_tensor = _CAPI_FastPull(name, machine_id, machine_count, group_count, client_id, res_tensor = _CAPI_FastPull(
name,
machine_id,
machine_count,
group_count,
client_id,
F.zerocopy_to_dgl_ndarray(id_tensor), F.zerocopy_to_dgl_ndarray(id_tensor),
F.zerocopy_to_dgl_ndarray(partition_book), F.zerocopy_to_dgl_ndarray(partition_book),
F.zerocopy_to_dgl_ndarray(local_data), F.zerocopy_to_dgl_ndarray(local_data),
sender, receiver, 'has_g2l', sender,
F.zerocopy_to_dgl_ndarray(g2l)) receiver,
"has_g2l",
F.zerocopy_to_dgl_ndarray(g2l),
)
else: else:
res_tensor = _CAPI_FastPull(name, machine_id, machine_count, group_count, client_id, res_tensor = _CAPI_FastPull(
name,
machine_id,
machine_count,
group_count,
client_id,
F.zerocopy_to_dgl_ndarray(id_tensor), F.zerocopy_to_dgl_ndarray(id_tensor),
F.zerocopy_to_dgl_ndarray(partition_book), F.zerocopy_to_dgl_ndarray(partition_book),
F.zerocopy_to_dgl_ndarray(local_data), F.zerocopy_to_dgl_ndarray(local_data),
sender, receiver, 'no_g2l') sender,
receiver,
"no_g2l",
)
return F.zerocopy_from_dgl_ndarray(res_tensor) return F.zerocopy_from_dgl_ndarray(res_tensor)
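For intuition, a hedged pure-Python restatement of the partition lookup that _fast_pull delegates to C++: ids owned by this machine are read from local_data (through g2l when present), the rest would be fetched over the network. Names mirror the parameters above; the detail is an assumption, not the actual implementation.

def _local_part(id_tensor, partition_book, g2l, local_data, machine_id):
    # ids whose partition-book entry matches this machine are served locally
    mask = partition_book[id_tensor] == machine_id
    local_ids = id_tensor[mask]
    local_rows = g2l[local_ids] if g2l is not None else local_ids
    return local_data[local_rows]   # remote ids would be pulled from other machines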
...@@ -14,20 +14,22 @@ with "[NN] XXX module". ...@@ -14,20 +14,22 @@ with "[NN] XXX module".
""" """
import importlib import importlib
import sys
import os import os
import sys
from ..backend import backend_name
from ..utils import expand_as_pair
# [BarclayII] Not sure what's going on with pylint. # [BarclayII] Not sure what's going on with pylint.
# Possible issue: https://github.com/PyCQA/pylint/issues/2648 # Possible issue: https://github.com/PyCQA/pylint/issues/2648
from . import functional # pylint: disable=import-self from . import functional # pylint: disable=import-self
from ..backend import backend_name
from ..utils import expand_as_pair
def _load_backend(mod_name): def _load_backend(mod_name):
mod = importlib.import_module('.%s' % mod_name, __name__) mod = importlib.import_module(".%s" % mod_name, __name__)
thismod = sys.modules[__name__] thismod = sys.modules[__name__]
for api, obj in mod.__dict__.items(): for api, obj in mod.__dict__.items():
setattr(thismod, api, obj) setattr(thismod, api, obj)
_load_backend(backend_name) _load_backend(backend_name)
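A short hedged usage note: once _load_backend(backend_name) has run at import time, the active backend's layers are re-exported from dgl.nn, so a backend-agnostic import works, e.g.

from dgl.nn import GraphConv   # resolves to the implementation of the backend picked above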
"""MXNet modules for graph convolutions.""" """MXNet modules for graph convolutions."""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
from .graphconv import GraphConv
from .relgraphconv import RelGraphConv
from .tagconv import TAGConv
from .gatconv import GATConv
from .sageconv import SAGEConv
from .gatedgraphconv import GatedGraphConv
from .chebconv import ChebConv
from .agnnconv import AGNNConv from .agnnconv import AGNNConv
from .appnpconv import APPNPConv from .appnpconv import APPNPConv
from .chebconv import ChebConv
from .densechebconv import DenseChebConv
from .densegraphconv import DenseGraphConv from .densegraphconv import DenseGraphConv
from .densesageconv import DenseSAGEConv from .densesageconv import DenseSAGEConv
from .densechebconv import DenseChebConv
from .edgeconv import EdgeConv from .edgeconv import EdgeConv
from .gatconv import GATConv
from .gatedgraphconv import GatedGraphConv
from .ginconv import GINConv from .ginconv import GINConv
from .gmmconv import GMMConv from .gmmconv import GMMConv
from .graphconv import GraphConv
from .nnconv import NNConv from .nnconv import NNConv
from .relgraphconv import RelGraphConv
from .sageconv import SAGEConv
from .sgconv import SGConv from .sgconv import SGConv
from .tagconv import TAGConv
__all__ = ['GraphConv', 'TAGConv', 'RelGraphConv', 'GATConv', __all__ = [
'SAGEConv', 'GatedGraphConv', 'ChebConv', 'AGNNConv', "GraphConv",
'APPNPConv', 'DenseGraphConv', 'DenseSAGEConv', 'DenseChebConv', "TAGConv",
'EdgeConv', 'GINConv', 'GMMConv', 'NNConv', 'SGConv'] "RelGraphConv",
"GATConv",
"SAGEConv",
"GatedGraphConv",
"ChebConv",
"AGNNConv",
"APPNPConv",
"DenseGraphConv",
"DenseSAGEConv",
"DenseChebConv",
"EdgeConv",
"GINConv",
"GMMConv",
"NNConv",
"SGConv",
]
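A hedged usage sketch for the exports above; the constructor arguments are illustrative.

from dgl.nn.mxnet.conv import GraphConv, SAGEConv

conv = GraphConv(in_feats=16, out_feats=32)
sage = SAGEConv(in_feats=16, out_feats=32, aggregator_type="mean")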