Unverified Commit 76bb5404 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4682)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent a208e886
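
The diff below is mechanical reformatting: Black normalizes string quotes to double quotes and rewraps long calls, and the alphabetized import blocks suggest isort ran alongside it. As a rough, hedged illustration (not part of this commit; the 80-column line length is an assumption inferred from the wrapping seen in the diff), Black's Python API reproduces the same kind of rewrite:

import black

# One of the one-line changes from this diff, fed through Black's public API.
src = "__all__ = ['BuiltinFunction', 'TargetCode']\n"
formatted = black.format_str(src, mode=black.Mode(line_length=80))
print(formatted)  # __all__ = ["BuiltinFunction", "TargetCode"]
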
"""Built-in function base class"""
from __future__ import absolute_import
__all__ = ['BuiltinFunction', 'TargetCode']
__all__ = ["BuiltinFunction", "TargetCode"]
class TargetCode(object):
......@@ -10,6 +10,7 @@ class TargetCode(object):
Note: must be consistent with the target code definition in C++ side:
src/kernel/binary_reduce_common.h
"""
SRC = 0
DST = 1
EDGE = 2
......@@ -23,6 +24,7 @@ class TargetCode(object):
class BuiltinFunction(object):
"""Base builtin function class."""
@property
def name(self):
"""Return the name of this builtin function."""
......
......@@ -4,16 +4,15 @@ from __future__ import absolute_import
import sys
from .base import BuiltinFunction, TargetCode
from .._deprecate.runtime import ir
from .._deprecate.runtime.ir import var
from .base import BuiltinFunction, TargetCode
class ReduceFunction(BuiltinFunction):
"""Base builtin reduce function class."""
def _invoke(self, graph, edge_frame, out_size, edge_map=None,
out_map=None):
def _invoke(self, graph, edge_frame, out_size, edge_map=None, out_map=None):
"""Symbolic computation of this builtin function to create
runtime.executor
"""
......@@ -28,21 +27,28 @@ class ReduceFunction(BuiltinFunction):
class SimpleReduceFunction(ReduceFunction):
"""Builtin reduce function that aggregates a single field into another
single field."""
def __init__(self, name, msg_field, out_field):
self._name = name
self.msg_field = msg_field
self.out_field = out_field
def _invoke(self, graph, edge_frame, out_size, edge_map=None,
out_map=None):
def _invoke(self, graph, edge_frame, out_size, edge_map=None, out_map=None):
"""Symbolic execution of this builtin function"""
reducer = self._name
graph = var.GRAPH(graph)
edge_map = var.MAP(edge_map)
out_map = var.MAP(out_map)
edge_data = ir.READ_COL(edge_frame, var.STR(self.msg_field))
return ir.COPY_REDUCE(reducer, graph, TargetCode.EDGE, edge_data,
out_size, edge_map, out_map)
return ir.COPY_REDUCE(
reducer,
graph,
TargetCode.EDGE,
edge_data,
out_size,
edge_map,
out_map,
)
@property
def name(self):
......@@ -53,6 +59,7 @@ class SimpleReduceFunction(ReduceFunction):
# Generate all following reducer functions:
# sum, max, min, mean, prod
def _gen_reduce_builtin(reducer):
docstring = """Builtin reduce function that aggregates messages by {0}.
......@@ -73,10 +80,13 @@ def _gen_reduce_builtin(reducer):
>>> import torch
>>> def reduce_func(nodes):
>>> return {{'h': torch.{0}(nodes.mailbox['m'], dim=1)}}
""".format(reducer)
""".format(
reducer
)
def func(msg, out):
return SimpleReduceFunction(reducer, msg, out)
func.__name__ = str(reducer)
func.__qualname__ = str(reducer)
func.__doc__ = docstring
......
"""Module for various graph generator functions."""
from . import backend as F
from . import convert
from . import random
from . import convert, random
__all__ = ["rand_graph", "rand_bipartite"]
__all__ = ['rand_graph', 'rand_bipartite']
def rand_graph(num_nodes, num_edges, idtype=F.int64, device=F.cpu()):
"""Generate a random graph of the given number of nodes/edges and return.
......@@ -46,20 +46,28 @@ def rand_graph(num_nodes, num_edges, idtype=F.int64, device=F.cpu()):
ndata_schemes={}
edata_schemes={})
"""
#TODO(minjie): support RNG as one of the arguments.
# TODO(minjie): support RNG as one of the arguments.
eids = random.choice(num_nodes * num_nodes, num_edges, replace=False)
eids = F.zerocopy_to_numpy(eids)
rows = F.zerocopy_from_numpy(eids // num_nodes)
cols = F.zerocopy_from_numpy(eids % num_nodes)
rows = F.copy_to(F.astype(rows, idtype), device)
cols = F.copy_to(F.astype(cols, idtype), device)
return convert.graph((rows, cols),
num_nodes=num_nodes,
idtype=idtype, device=device)
def rand_bipartite(utype, etype, vtype,
num_src_nodes, num_dst_nodes, num_edges,
idtype=F.int64, device=F.cpu()):
return convert.graph(
(rows, cols), num_nodes=num_nodes, idtype=idtype, device=device
)
def rand_bipartite(
utype,
etype,
vtype,
num_src_nodes,
num_dst_nodes,
num_edges,
idtype=F.int64,
device=F.cpu(),
):
"""Generate a random uni-directional bipartite graph and return.
It uniformly chooses ``num_edges`` from all possible node pairs and form a graph.
......@@ -107,13 +115,18 @@ def rand_bipartite(utype, etype, vtype,
num_edges={('user', 'buys', 'game'): 10},
metagraph=[('user', 'game', 'buys')])
"""
#TODO(minjie): support RNG as one of the arguments.
eids = random.choice(num_src_nodes * num_dst_nodes, num_edges, replace=False)
# TODO(minjie): support RNG as one of the arguments.
eids = random.choice(
num_src_nodes * num_dst_nodes, num_edges, replace=False
)
eids = F.zerocopy_to_numpy(eids)
rows = F.zerocopy_from_numpy(eids // num_dst_nodes)
cols = F.zerocopy_from_numpy(eids % num_dst_nodes)
rows = F.copy_to(F.astype(rows, idtype), device)
cols = F.copy_to(F.astype(cols, idtype), device)
return convert.heterograph({(utype, etype, vtype): (rows, cols)},
{utype: num_src_nodes, vtype: num_dst_nodes},
idtype=idtype, device=device)
return convert.heterograph(
{(utype, etype, vtype): (rows, cols)},
{utype: num_src_nodes, vtype: num_dst_nodes},
idtype=idtype,
device=device,
)
......@@ -8,5 +8,5 @@
This package is experimental and the interfaces may be subject
to changes in future releases.
"""
from .fps import *
from .edge_coarsening import *
from .fps import *
"""Python interfaces to DGL farthest point sampler."""
import numpy as np
from .._ffi.base import DGLError
from .._ffi.function import _init_api
from .. import backend as F
from .. import ndarray as nd
from .._ffi.base import DGLError
from .._ffi.function import _init_api
def _farthest_point_sampler(data, batch_size, sample_points, dist, start_idx, result):
def _farthest_point_sampler(
data, batch_size, sample_points, dist, start_idx, result
):
r"""Farthest Point Sampler
Parameters
......@@ -32,14 +35,19 @@ def _farthest_point_sampler(data, batch_size, sample_points, dist, start_idx, re
assert F.shape(data)[0] >= sample_points * batch_size
assert F.shape(data)[0] % batch_size == 0
_CAPI_FarthestPointSampler(F.zerocopy_to_dgl_ndarray(data),
batch_size, sample_points,
F.zerocopy_to_dgl_ndarray(dist),
F.zerocopy_to_dgl_ndarray(start_idx),
F.zerocopy_to_dgl_ndarray(result))
_CAPI_FarthestPointSampler(
F.zerocopy_to_dgl_ndarray(data),
batch_size,
sample_points,
F.zerocopy_to_dgl_ndarray(dist),
F.zerocopy_to_dgl_ndarray(start_idx),
F.zerocopy_to_dgl_ndarray(result),
)
def _neighbor_matching(graph_idx, num_nodes, edge_weights=None, relabel_idx=True):
def _neighbor_matching(
graph_idx, num_nodes, edge_weights=None, relabel_idx=True
):
"""
Description
-----------
......@@ -82,7 +90,11 @@ def _neighbor_matching(graph_idx, num_nodes, edge_weights=None, relabel_idx=True
if edge_weights is not None:
edge_weight_capi = F.zerocopy_to_dgl_ndarray(edge_weights)
node_label = F.full_1d(
num_nodes, -1, getattr(F, graph_idx.dtype), F.to_backend_ctx(graph_idx.ctx))
num_nodes,
-1,
getattr(F, graph_idx.dtype),
F.to_backend_ctx(graph_idx.ctx),
)
node_label_capi = F.zerocopy_to_dgl_ndarray_for_write(node_label)
_CAPI_NeighborMatching(graph_idx, edge_weight_capi, node_label_capi)
if F.reduce_sum(node_label < 0).item() != 0:
......@@ -99,4 +111,4 @@ def _neighbor_matching(graph_idx, num_nodes, edge_weights=None, relabel_idx=True
return node_label
_init_api('dgl.geometry', __name__)
_init_api("dgl.geometry", __name__)
......@@ -3,7 +3,8 @@
from .. import remove_self_loop
from .capi import _neighbor_matching
__all__ = ['neighbor_matching']
__all__ = ["neighbor_matching"]
def neighbor_matching(graph, e_weights=None, relabel_idx=True):
r"""
......@@ -48,13 +49,16 @@ def neighbor_matching(graph, e_weights=None, relabel_idx=True):
>>> res = neighbor_matching(g)
tensor([0, 1, 1])
"""
assert graph.is_homogeneous, \
"The graph used in graph node matching must be homogeneous"
assert (
graph.is_homogeneous
), "The graph used in graph node matching must be homogeneous"
if e_weights is not None:
graph.edata['e_weights'] = e_weights
graph.edata["e_weights"] = e_weights
graph = remove_self_loop(graph)
e_weights = graph.edata['e_weights']
graph.edata.pop('e_weights')
e_weights = graph.edata["e_weights"]
graph.edata.pop("e_weights")
else:
graph = remove_self_loop(graph)
return _neighbor_matching(graph._graph, graph.num_nodes(), e_weights, relabel_idx)
return _neighbor_matching(
graph._graph, graph.num_nodes(), e_weights, relabel_idx
)
"""Farthest Point Sampler for pytorch Geometry package"""
#pylint: disable=no-member, invalid-name
# pylint: disable=no-member, invalid-name
from .. import backend as F
from ..base import DGLError
from .capi import _farthest_point_sampler
__all__ = ['farthest_point_sampler']
__all__ = ["farthest_point_sampler"]
def farthest_point_sampler(pos, npoints, start_idx=None):
......@@ -50,12 +49,16 @@ def farthest_point_sampler(pos, npoints, start_idx=None):
pos = pos.reshape(-1, C)
dist = F.zeros((B * N), dtype=pos.dtype, ctx=ctx)
if start_idx is None:
start_idx = F.randint(shape=(B, ), dtype=F.int64,
ctx=ctx, low=0, high=N-1)
start_idx = F.randint(
shape=(B,), dtype=F.int64, ctx=ctx, low=0, high=N - 1
)
else:
if start_idx >= N or start_idx < 0:
raise DGLError("Invalid start_idx, expected 0 <= start_idx < {}, got {}".format(
N, start_idx))
raise DGLError(
"Invalid start_idx, expected 0 <= start_idx < {}, got {}".format(
N, start_idx
)
)
start_idx = F.full_1d(B, start_idx, dtype=F.int64, ctx=ctx)
result = F.zeros((npoints * B), dtype=F.int64, ctx=ctx)
_farthest_point_sampler(pos, B, npoints, dist, start_idx, result)
......
"""Module for global configuration operators."""
from ._ffi.function import _init_api
__all__ = ["is_libxsmm_enabled", "use_libxsmm"]
def use_libxsmm(flag):
r"""Set whether DGL uses libxsmm at runtime.
......@@ -21,6 +21,7 @@ def use_libxsmm(flag):
"""
_CAPI_DGLConfigSetLibxsmm(flag)
def is_libxsmm_enabled():
r"""Get whether the use_libxsmm flag is turned on.
......@@ -35,4 +36,5 @@ def is_libxsmm_enabled():
"""
return _CAPI_DGLConfigGetLibxsmm()
_init_api("dgl.global_config")
"""Module for graph index class definition."""
from __future__ import absolute_import
import numpy as np
import networkx as nx
import numpy as np
import scipy
from ._ffi.object import register_object, ObjectBase
from ._ffi.function import _init_api
from .base import DGLError, dgl_warning
from . import backend as F
from . import utils
from ._ffi.function import _init_api
from ._ffi.object import ObjectBase, register_object
from .base import DGLError, dgl_warning
class BoolFlag(object):
"""Bool flag with unknown value"""
BOOL_UNKNOWN = -1
BOOL_FALSE = 0
BOOL_TRUE = 1
@register_object('graph.Graph')
@register_object("graph.Graph")
class GraphIndex(ObjectBase):
"""Graph index object.
......@@ -33,6 +36,7 @@ class GraphIndex(ObjectBase):
- `dgl.graph_index.from_csr`
- `dgl.graph_index.from_coo`
"""
def __new__(cls):
obj = ObjectBase.__new__(cls)
obj._readonly = None # python-side cache of the flag
......@@ -53,13 +57,15 @@ class GraphIndex(ObjectBase):
# Pickle compatibility check
# TODO: we should store a storage version number in later releases.
if isinstance(state, tuple) and len(state) == 5:
dgl_warning("The object is pickled pre-0.4.2. Multigraph flag is ignored in 0.4.3")
dgl_warning(
"The object is pickled pre-0.4.2. Multigraph flag is ignored in 0.4.3"
)
num_nodes, _, readonly, src, dst = state
elif isinstance(state, tuple) and len(state) == 4:
# post-0.4.3.
num_nodes, readonly, src, dst = state
else:
raise IOError('Unrecognized storage format.')
raise IOError("Unrecognized storage format.")
self._cache = {}
self._readonly = readonly
......@@ -68,7 +74,8 @@ class GraphIndex(ObjectBase):
src.todgltensor(),
dst.todgltensor(),
int(num_nodes),
readonly)
readonly,
)
def add_nodes(self, num):
"""Add nodes.
......@@ -240,7 +247,9 @@ class GraphIndex(ObjectBase):
"""
u_array = u.todgltensor()
v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLGraphHasEdgesBetween(self, u_array, v_array))
return utils.toindex(
_CAPI_DGLGraphHasEdgesBetween(self, u_array, v_array)
)
def predecessors(self, v, radius=1):
"""Return the predecessors of the node.
......@@ -257,8 +266,9 @@ class GraphIndex(ObjectBase):
utils.Index
Array of predecessors
"""
return utils.toindex(_CAPI_DGLGraphPredecessors(
self, int(v), int(radius)))
return utils.toindex(
_CAPI_DGLGraphPredecessors(self, int(v), int(radius))
)
def successors(self, v, radius=1):
"""Return the successors of the node.
......@@ -275,8 +285,9 @@ class GraphIndex(ObjectBase):
utils.Index
Array of successors
"""
return utils.toindex(_CAPI_DGLGraphSuccessors(
self, int(v), int(radius)))
return utils.toindex(
_CAPI_DGLGraphSuccessors(self, int(v), int(radius))
)
def edge_id(self, u, v):
"""Return the id array of all edges between u and v.
......@@ -432,7 +443,7 @@ class GraphIndex(ObjectBase):
"""
_CAPI_DGLSortAdj(self)
@utils.cached_member(cache='_cache', prefix='edges')
@utils.cached_member(cache="_cache", prefix="edges")
def edges(self, order=None):
"""Return all the edges
......@@ -606,7 +617,7 @@ class GraphIndex(ObjectBase):
e_array = e.todgltensor()
return _CAPI_DGLGraphEdgeSubgraph(self, e_array, preserve_nodes)
@utils.cached_member(cache='_cache', prefix='scipy_adj')
@utils.cached_member(cache="_cache", prefix="scipy_adj")
def adjacency_matrix_scipy(self, transpose, fmt, return_edge_ids=None):
"""Return the scipy adjacency matrix representation of this graph.
......@@ -631,8 +642,10 @@ class GraphIndex(ObjectBase):
The scipy representation of adjacency matrix.
"""
if not isinstance(transpose, bool):
raise DGLError('Expect bool value for "transpose" arg,'
' but got %s.' % (type(transpose)))
raise DGLError(
'Expect bool value for "transpose" arg,'
" but got %s." % (type(transpose))
)
if return_edge_ids is None:
dgl_warning(
......@@ -640,17 +653,24 @@ class GraphIndex(ObjectBase):
" As a result there is one 0 entry which is not eliminated."
" In the next release it will return 1s by default,"
" and 0 will be eliminated otherwise.",
FutureWarning)
FutureWarning,
)
return_edge_ids = True
rst = _CAPI_DGLGraphGetAdj(self, transpose, fmt)
if fmt == "csr":
indptr = utils.toindex(rst(0)).tonumpy()
indices = utils.toindex(rst(1)).tonumpy()
data = utils.toindex(rst(2)).tonumpy() if return_edge_ids else np.ones_like(indices)
data = (
utils.toindex(rst(2)).tonumpy()
if return_edge_ids
else np.ones_like(indices)
)
n = self.number_of_nodes()
return scipy.sparse.csr_matrix((data, indices, indptr), shape=(n, n))
elif fmt == 'coo':
return scipy.sparse.csr_matrix(
(data, indices, indptr), shape=(n, n)
)
elif fmt == "coo":
idx = utils.toindex(rst(0)).tonumpy()
n = self.number_of_nodes()
m = self.number_of_edges()
......@@ -660,7 +680,7 @@ class GraphIndex(ObjectBase):
else:
raise Exception("unknown format")
@utils.cached_member(cache='_cache', prefix='immu_gidx')
@utils.cached_member(cache="_cache", prefix="immu_gidx")
def get_immutable_gidx(self, ctx):
"""Create an immutable graph index and copy to the given device context.
......@@ -717,8 +737,10 @@ class GraphIndex(ObjectBase):
if shuffle is not required.
"""
if not isinstance(transpose, bool):
raise DGLError('Expect bool value for "transpose" arg,'
' but got %s.' % (type(transpose)))
raise DGLError(
'Expect bool value for "transpose" arg,'
" but got %s." % (type(transpose))
)
fmt = F.get_preferred_sparse_format()
rst = _CAPI_DGLGraphGetAdj(self, transpose, fmt)
if fmt == "csr":
......@@ -726,8 +748,11 @@ class GraphIndex(ObjectBase):
indices = F.copy_to(utils.toindex(rst(1)).tousertensor(), ctx)
shuffle = utils.toindex(rst(2))
dat = F.ones(indices.shape, dtype=F.float32, ctx=ctx)
spmat = F.sparse_matrix(dat, ('csr', indices, indptr),
(self.number_of_nodes(), self.number_of_nodes()))[0]
spmat = F.sparse_matrix(
dat,
("csr", indices, indptr),
(self.number_of_nodes(), self.number_of_nodes()),
)[0]
return spmat, shuffle
elif fmt == "coo":
## FIXME(minjie): data type
......@@ -736,8 +761,10 @@ class GraphIndex(ObjectBase):
idx = F.reshape(idx, (2, m))
dat = F.ones((m,), dtype=F.float32, ctx=ctx)
n = self.number_of_nodes()
adj, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, n))
shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None
adj, shuffle_idx = F.sparse_matrix(dat, ("coo", idx), (n, n))
shuffle_idx = (
utils.toindex(shuffle_idx) if shuffle_idx is not None else None
)
return adj, shuffle_idx
else:
raise Exception("unknown format")
......@@ -783,21 +810,21 @@ class GraphIndex(ObjectBase):
eid = eid.tousertensor(ctx) # the index of the ctx will be cached
n = self.number_of_nodes()
m = self.number_of_edges()
if typestr == 'in':
if typestr == "in":
row = F.unsqueeze(dst, 0)
col = F.unsqueeze(eid, 0)
idx = F.cat([row, col], dim=0)
# FIXME(minjie): data type
dat = F.ones((m,), dtype=F.float32, ctx=ctx)
inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
elif typestr == 'out':
inc, shuffle_idx = F.sparse_matrix(dat, ("coo", idx), (n, m))
elif typestr == "out":
row = F.unsqueeze(src, 0)
col = F.unsqueeze(eid, 0)
idx = F.cat([row, col], dim=0)
# FIXME(minjie): data type
dat = F.ones((m,), dtype=F.float32, ctx=ctx)
inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
elif typestr == 'both':
inc, shuffle_idx = F.sparse_matrix(dat, ("coo", idx), (n, m))
elif typestr == "both":
# first remove entries for self loops
mask = F.logical_not(F.equal(src, dst))
src = F.boolean_mask(src, mask)
......@@ -812,10 +839,12 @@ class GraphIndex(ObjectBase):
x = -F.ones((n_entries,), dtype=F.float32, ctx=ctx)
y = F.ones((n_entries,), dtype=F.float32, ctx=ctx)
dat = F.cat([x, y], dim=0)
inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
inc, shuffle_idx = F.sparse_matrix(dat, ("coo", idx), (n, m))
else:
raise DGLError('Invalid incidence matrix type: %s' % str(typestr))
shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None
raise DGLError("Invalid incidence matrix type: %s" % str(typestr))
shuffle_idx = (
utils.toindex(shuffle_idx) if shuffle_idx is not None else None
)
return inc, shuffle_idx
def to_networkx(self):
......@@ -902,7 +931,9 @@ class GraphIndex(ObjectBase):
GraphIndex
The graph index on the given device context.
"""
return _CAPI_DGLImmutableGraphCopyTo(self, ctx.device_type, ctx.device_id)
return _CAPI_DGLImmutableGraphCopyTo(
self, ctx.device_type, ctx.device_id
)
def copyto_shared_mem(self, shared_mem_name):
"""Copy this immutable graph index to shared memory.
......@@ -939,7 +970,10 @@ class GraphIndex(ObjectBase):
int
The number of bits needed
"""
if self.number_of_edges() >= 0x80000000 or self.number_of_nodes() >= 0x80000000:
if (
self.number_of_edges() >= 0x80000000
or self.number_of_nodes() >= 0x80000000
):
return 64
else:
return 32
......@@ -961,9 +995,11 @@ class GraphIndex(ObjectBase):
"""
return _CAPI_DGLImmutableGraphAsNumBits(self, int(bits))
@register_object('graph.Subgraph')
@register_object("graph.Subgraph")
class SubgraphIndex(ObjectBase):
"""Subgraph data structure"""
@property
def graph(self):
"""The subgraph structure
......@@ -1028,16 +1064,15 @@ def from_coo(num_nodes, src, dst, readonly):
dst = utils.toindex(dst)
if readonly:
gidx = _CAPI_DGLGraphCreate(
src.todgltensor(),
dst.todgltensor(),
int(num_nodes),
readonly)
src.todgltensor(), dst.todgltensor(), int(num_nodes), readonly
)
else:
gidx = _CAPI_DGLGraphCreateMutable()
gidx.add_nodes(num_nodes)
gidx.add_edges(src, dst)
return gidx
def from_csr(indptr, indices, direction):
"""Load a graph from CSR arrays.
......@@ -1058,11 +1093,11 @@ def from_csr(indptr, indices, direction):
indptr = utils.toindex(indptr)
indices = utils.toindex(indices)
gidx = _CAPI_DGLGraphCSRCreate(
indptr.todgltensor(),
indices.todgltensor(),
direction)
indptr.todgltensor(), indices.todgltensor(), direction
)
return gidx
def from_shared_mem_graph_index(shared_mem_name):
"""Load a graph index from the shared memory.
......@@ -1078,6 +1113,7 @@ def from_shared_mem_graph_index(shared_mem_name):
"""
return _CAPI_DGLGraphCSRCreateMMap(shared_mem_name)
def from_networkx(nx_graph, readonly):
"""Convert from networkx graph.
......@@ -1107,7 +1143,7 @@ def from_networkx(nx_graph, readonly):
# nx_graph.edges(data=True) returns src, dst, attr_dict
if nx_graph.number_of_edges() > 0:
has_edge_id = 'id' in next(iter(nx_graph.edges(data=True)))[-1]
has_edge_id = "id" in next(iter(nx_graph.edges(data=True)))[-1]
else:
has_edge_id = False
......@@ -1116,7 +1152,7 @@ def from_networkx(nx_graph, readonly):
src = np.zeros((num_edges,), dtype=np.int64)
dst = np.zeros((num_edges,), dtype=np.int64)
for u, v, attr in nx_graph.edges(data=True):
eid = attr['id']
eid = attr["id"]
src[eid] = u
dst[eid] = v
else:
......@@ -1131,6 +1167,7 @@ def from_networkx(nx_graph, readonly):
dst = utils.toindex(dst)
return from_coo(num_nodes, src, dst, readonly)
def from_scipy_sparse_matrix(adj, readonly):
"""Convert from scipy sparse matrix.
......@@ -1145,7 +1182,7 @@ def from_scipy_sparse_matrix(adj, readonly):
GraphIndex
The graph index.
"""
if adj.getformat() != 'csr' or not readonly:
if adj.getformat() != "csr" or not readonly:
num_nodes = max(adj.shape[0], adj.shape[1])
adj_coo = adj.tocoo()
return from_coo(num_nodes, adj_coo.row, adj_coo.col, readonly)
......@@ -1153,6 +1190,7 @@ def from_scipy_sparse_matrix(adj, readonly):
# If the input matrix is csr, we still treat it as multigraph.
return from_csr(adj.indptr, adj.indices, "out")
def from_edge_list(elist, readonly):
"""Convert from an edge list.
......@@ -1172,6 +1210,7 @@ def from_edge_list(elist, readonly):
num_nodes = max(src.max(), dst.max()) + 1
return from_coo(num_nodes, src_ids, dst_ids, readonly)
def map_to_subgraph_nid(induced_nodes, parent_nids):
"""Map parent node Ids to the subgraph node Ids.
......@@ -1188,8 +1227,12 @@ def map_to_subgraph_nid(induced_nodes, parent_nids):
utils.Index
Node Ids in the subgraph.
"""
return utils.toindex(_CAPI_DGLMapSubgraphNID(induced_nodes.todgltensor(),
parent_nids.todgltensor()))
return utils.toindex(
_CAPI_DGLMapSubgraphNID(
induced_nodes.todgltensor(), parent_nids.todgltensor()
)
)
def transform_ids(mapping, ids):
"""Transform ids by the given mapping.
......@@ -1206,8 +1249,10 @@ def transform_ids(mapping, ids):
utils.Index
The new ids.
"""
return utils.toindex(_CAPI_DGLMapSubgraphNID(
mapping.todgltensor(), ids.todgltensor()))
return utils.toindex(
_CAPI_DGLMapSubgraphNID(mapping.todgltensor(), ids.todgltensor())
)
def disjoint_union(graphs):
"""Return a disjoint union of the input graphs.
......@@ -1230,6 +1275,7 @@ def disjoint_union(graphs):
"""
return _CAPI_DGLDisjointUnion(list(graphs))
def disjoint_partition(graph, num_or_size_splits):
"""Partition the graph disjointly.
......@@ -1252,14 +1298,13 @@ def disjoint_partition(graph, num_or_size_splits):
"""
if isinstance(num_or_size_splits, utils.Index):
rst = _CAPI_DGLDisjointPartitionBySizes(
graph,
num_or_size_splits.todgltensor())
graph, num_or_size_splits.todgltensor()
)
else:
rst = _CAPI_DGLDisjointPartitionByNum(
graph,
int(num_or_size_splits))
rst = _CAPI_DGLDisjointPartitionByNum(graph, int(num_or_size_splits))
return rst
def create_graph_index(graph_data, readonly):
"""Create a graph index object.
......@@ -1289,11 +1334,15 @@ def create_graph_index(graph_data, readonly):
try:
gidx = from_networkx(graph_data, readonly)
except Exception: # pylint: disable=broad-except
raise DGLError('Error while creating graph from input of type "%s".'
% type(graph_data))
raise DGLError(
'Error while creating graph from input of type "%s".'
% type(graph_data)
)
return gidx
def _get_halo_subgraph_inner_node(halo_subg):
return _CAPI_GetHaloSubgraphInnerNodes(halo_subg)
_init_api("dgl.graph_index")
......@@ -3,9 +3,12 @@ from __future__ import absolute_import
from . import backend as F
__all__ = ['base_initializer', 'zero_initializer']
__all__ = ["base_initializer", "zero_initializer"]
def base_initializer(shape, dtype, ctx, id_range): # pylint: disable=unused-argument
def base_initializer(
shape, dtype, ctx, id_range
): # pylint: disable=unused-argument
"""The function signature for feature initializer.
Any customized feature initializer should follow this signature (see
......@@ -44,7 +47,10 @@ def base_initializer(shape, dtype, ctx, id_range): # pylint: disable=unused-arg
"""
raise NotImplementedError
def zero_initializer(shape, dtype, ctx, id_range): # pylint: disable=unused-argument
def zero_initializer(
shape, dtype, ctx, id_range
): # pylint: disable=unused-argument
"""Zero feature initializer
Examples
......
"""Utilities for merging graphs."""
import dgl
from . import backend as F
from .base import DGLError
__all__ = ['merge']
__all__ = ["merge"]
def merge(graphs):
r"""Merge a sequence of graphs together into a single graph.
......@@ -62,7 +64,7 @@ def merge(graphs):
"""
if len(graphs) == 0:
raise DGLError('The input list of graphs cannot be empty.')
raise DGLError("The input list of graphs cannot be empty.")
ref = graphs[0]
ntypes = ref.ntypes
......@@ -87,9 +89,15 @@ def merge(graphs):
if len(keys) == 0:
edges_data = None
else:
edges_data = {k: F.cat([f[k] for f in edata_frames], dim=0) for k in keys}
merged_us = F.copy_to(F.astype(F.cat(unmerged_us, dim=0), ref.idtype), ref.device)
merged_vs = F.copy_to(F.astype(F.cat(unmerged_vs, dim=0), ref.idtype), ref.device)
edges_data = {
k: F.cat([f[k] for f in edata_frames], dim=0) for k in keys
}
merged_us = F.copy_to(
F.astype(F.cat(unmerged_us, dim=0), ref.idtype), ref.device
)
merged_vs = F.copy_to(
F.astype(F.cat(unmerged_vs, dim=0), ref.idtype), ref.device
)
merged.add_edges(merged_us, merged_vs, edges_data, etype)
# Add node data and isolated nodes from next_graph to merged.
......@@ -98,12 +106,16 @@ def merge(graphs):
merged_ntype_id = merged.get_ntype_id(ntype)
next_ntype_id = next_graph.get_ntype_id(ntype)
next_ndata = next_graph._node_frames[next_ntype_id]
node_diff = (next_graph.num_nodes(ntype=ntype) -
merged.num_nodes(ntype=ntype))
node_diff = next_graph.num_nodes(ntype=ntype) - merged.num_nodes(
ntype=ntype
)
n_extra_nodes = max(0, node_diff)
merged.add_nodes(n_extra_nodes, ntype=ntype)
next_nodes = F.arange(
0, next_graph.num_nodes(ntype=ntype), merged.idtype, merged.device
0,
next_graph.num_nodes(ntype=ntype),
merged.idtype,
merged.device,
)
merged._node_frames[merged_ntype_id].update_row(
next_nodes, next_ndata
......
"""dgl sparse class."""
from .diag_matrix import *
from .sp_matrix import *
from .elementwise_op import *
from .matmul import *
from .reduction import * # pylint: disable=W0622
from .sddmm import *
from .reduction import * # pylint: disable=W0622
from .sp_matrix import *
from .unary_diag import *
from .unary_sp import *
from .matmul import *
"""DGL sparse matrix module."""
from typing import Optional, Tuple
import torch
__all__ = [
......
......@@ -6,11 +6,12 @@
# make fork() and openmp work together.
from .. import backend as F
if F.get_preferred_backend() == 'pytorch':
if F.get_preferred_backend() == "pytorch":
# Wrap around torch.multiprocessing...
from torch.multiprocessing import *
# ... and override the Process initializer.
from .pytorch import *
else:
# Just import multiprocessing module.
from multiprocessing import * # pylint: disable=redefined-builtin
from multiprocessing import * # pylint: disable=redefined-builtin
"""PyTorch multiprocessing wrapper."""
from functools import wraps
import random
import traceback
from _thread import start_new_thread
from functools import wraps
import torch
import torch.multiprocessing as mp
from ..utils import create_shared_mem_array, get_shared_mem_array
def thread_wrapped_func(func):
"""
Wraps a process entry point to make it work with OpenMP.
"""
@wraps(func)
def decorated_function(*args, **kwargs):
queue = mp.Queue()
def _queue_result():
exception, trace, res = None, None, None
try:
......@@ -31,18 +35,31 @@ def thread_wrapped_func(func):
else:
assert isinstance(exception, Exception)
raise exception.__class__(trace)
return decorated_function
# pylint: disable=missing-docstring
class Process(mp.Process):
# pylint: disable=dangerous-default-value
def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, *, daemon=None):
def __init__(
self,
group=None,
target=None,
name=None,
args=(),
kwargs={},
*,
daemon=None
):
target = thread_wrapped_func(target)
super().__init__(group, target, name, args, kwargs, daemon=daemon)
def _get_shared_mem_name(id_):
return "shared" + str(id_)
def call_once_and_share(func, shape, dtype, rank=0):
"""Invoke the function in a single process of the PyTorch distributed process group,
and share the result with other processes.
......@@ -61,7 +78,7 @@ def call_once_and_share(func, shape, dtype, rank=0):
current_rank = torch.distributed.get_rank()
dist_buf = torch.LongTensor([1])
if torch.distributed.get_backend() == 'nccl':
if torch.distributed.get_backend() == "nccl":
# Use .cuda() to transfer it to the correct device. Should be OK since
# PyTorch recommends the users to call set_device() after getting inside
# torch.multiprocessing.spawn()
......@@ -88,6 +105,7 @@ def call_once_and_share(func, shape, dtype, rank=0):
return result
def shared_tensor(shape, dtype=torch.float32):
"""Create a tensor in shared memory accessible by all processes within the same
``torch.distributed`` process group.
......@@ -106,4 +124,6 @@ def shared_tensor(shape, dtype=torch.float32):
Tensor
The shared tensor.
"""
return call_once_and_share(lambda: torch.empty(*shape, dtype=dtype), shape, dtype)
return call_once_and_share(
lambda: torch.empty(*shape, dtype=dtype), shape, dtype
)
......@@ -9,17 +9,28 @@ from __future__ import absolute_import as _abs
import ctypes
import functools
import operator
import numpy as _np
from ._ffi.object import register_object, ObjectBase
from ._ffi.function import _init_api
from ._ffi.ndarray import DGLContext, DGLDataType, NDArrayBase
from ._ffi.ndarray import context, empty, empty_shared_mem, from_dlpack, numpyasarray
from ._ffi.ndarray import _set_class_ndarray
from . import backend as F
from ._ffi.function import _init_api
from ._ffi.ndarray import (
DGLContext,
DGLDataType,
NDArrayBase,
_set_class_ndarray,
context,
empty,
empty_shared_mem,
from_dlpack,
numpyasarray,
)
from ._ffi.object import ObjectBase, register_object
class NDArray(NDArrayBase):
"""Lightweight NDArray class for DGL framework."""
def __len__(self):
return functools.reduce(operator.mul, self.shape, 1)
......@@ -35,7 +46,10 @@ class NDArray(NDArrayBase):
-------
NDArray
"""
return empty_shared_mem(name, True, self.shape, self.dtype).copyfrom(self)
return empty_shared_mem(name, True, self.shape, self.dtype).copyfrom(
self
)
def cpu(dev_id=0):
"""Construct a CPU device
......@@ -52,6 +66,7 @@ def cpu(dev_id=0):
"""
return DGLContext(1, dev_id)
def gpu(dev_id=0):
"""Construct a CPU device
......@@ -67,6 +82,7 @@ def gpu(dev_id=0):
"""
return DGLContext(2, dev_id)
def array(arr, ctx=cpu(0)):
"""Create an array from source arr.
......@@ -87,6 +103,7 @@ def array(arr, ctx=cpu(0)):
arr = _np.array(arr)
return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)
def zerocopy_from_numpy(np_data):
"""Create an array that shares the given numpy data.
......@@ -104,6 +121,7 @@ def zerocopy_from_numpy(np_data):
handle = ctypes.pointer(arr)
return NDArray(handle, is_view=True)
def cast_to_signed(arr):
"""Cast this NDArray from unsigned integer to signed one.
......@@ -124,8 +142,9 @@ def cast_to_signed(arr):
"""
return _CAPI_DGLArrayCastToSigned(arr)
def get_shared_mem_array(name, shape, dtype):
""" Get a tensor from shared memory with specific name
"""Get a tensor from shared memory with specific name
Parameters
----------
......@@ -141,12 +160,15 @@ def get_shared_mem_array(name, shape, dtype):
F.tensor
The tensor got from shared memory.
"""
new_arr = empty_shared_mem(name, False, shape, F.reverse_data_type_dict[dtype])
new_arr = empty_shared_mem(
name, False, shape, F.reverse_data_type_dict[dtype]
)
dlpack = new_arr.to_dlpack()
return F.zerocopy_from_dlpack(dlpack)
def create_shared_mem_array(name, shape, dtype):
""" Create a tensor from shared memory with the specific name
"""Create a tensor from shared memory with the specific name
Parameters
----------
......@@ -162,12 +184,15 @@ def create_shared_mem_array(name, shape, dtype):
F.tensor
The created tensor.
"""
new_arr = empty_shared_mem(name, True, shape, F.reverse_data_type_dict[dtype])
new_arr = empty_shared_mem(
name, True, shape, F.reverse_data_type_dict[dtype]
)
dlpack = new_arr.to_dlpack()
return F.zerocopy_from_dlpack(dlpack)
def exist_shared_mem_array(name):
""" Check the existence of shared-memory array.
"""Check the existence of shared-memory array.
Parameters
----------
......@@ -181,23 +206,27 @@ def exist_shared_mem_array(name):
"""
return _CAPI_DGLExistSharedMemArray(name)
class SparseFormat:
"""Format code"""
ANY = 0
COO = 1
CSR = 2
CSC = 3
FORMAT2STR = {
0 : 'ANY',
1 : 'COO',
2 : 'CSR',
3 : 'CSC',
0: "ANY",
1: "COO",
2: "CSR",
3: "CSC",
}
@register_object('aten.SparseMatrix')
@register_object("aten.SparseMatrix")
class SparseMatrix(ObjectBase):
"""Sparse matrix object class in C++ backend."""
@property
def format(self):
"""Sparse format enum
......@@ -250,17 +279,26 @@ class SparseMatrix(ObjectBase):
return _CAPI_DGLSparseMatrixGetFlags(self)
def __getstate__(self):
return self.format, self.num_rows, self.num_cols, self.indices, self.flags
return (
self.format,
self.num_rows,
self.num_cols,
self.indices,
self.flags,
)
def __setstate__(self, state):
fmt, nrows, ncols, indices, flags = state
indices = [F.zerocopy_to_dgl_ndarray(idx) for idx in indices]
self.__init_handle_by_constructor__(
_CAPI_DGLCreateSparseMatrix, fmt, nrows, ncols, indices, flags)
_CAPI_DGLCreateSparseMatrix, fmt, nrows, ncols, indices, flags
)
def __repr__(self):
return 'SparseMatrix(fmt="{}", shape=({},{}))'.format(
SparseFormat.FORMAT2STR[self.format], self.num_rows, self.num_cols)
SparseFormat.FORMAT2STR[self.format], self.num_rows, self.num_cols
)
_set_class_ndarray(NDArray)
_init_api("dgl.ndarray")
......@@ -270,5 +308,5 @@ _init_api("dgl.ndarray.uvm", __name__)
# other backend tensors.
NULL = {
"int64": array(_np.array([], dtype=_np.int64)),
"int32": array(_np.array([], dtype=_np.int32))
"int32": array(_np.array([], dtype=_np.int32)),
}
......@@ -2,13 +2,14 @@
from __future__ import absolute_import
import time
from enum import Enum
from collections import namedtuple
from enum import Enum
import dgl.backend as F
from ._ffi.function import _init_api
from ._deprecate.nodeflow import NodeFlow
from . import utils
from ._deprecate.nodeflow import NodeFlow
from ._ffi.function import _init_api
_init_api("dgl.network")
......@@ -19,12 +20,11 @@ _WAIT_TIME_SEC = 3 # 3 seconds
def _network_wait():
"""Sleep for a few seconds
"""
"""Sleep for a few seconds"""
time.sleep(_WAIT_TIME_SEC)
def _create_sender(net_type, msg_queue_size=2*1024*1024*1024):
def _create_sender(net_type, msg_queue_size=2 * 1024 * 1024 * 1024):
"""Create a Sender communicator via C api
Parameters
......@@ -34,11 +34,11 @@ def _create_sender(net_type, msg_queue_size=2*1024*1024*1024):
msg_queue_size : int
message queue size (2GB by default)
"""
assert net_type in ('socket', 'mpi'), 'Unknown network type.'
assert net_type in ("socket", "mpi"), "Unknown network type."
return _CAPI_DGLSenderCreate(net_type, msg_queue_size)
def _create_receiver(net_type, msg_queue_size=2*1024*1024*1024):
def _create_receiver(net_type, msg_queue_size=2 * 1024 * 1024 * 1024):
"""Create a Receiver communicator via C api
Parameters
......@@ -48,7 +48,7 @@ def _create_receiver(net_type, msg_queue_size=2*1024*1024*1024):
msg_queue_size : int
message queue size (2GB by default)
"""
assert net_type in ('socket', 'mpi'), 'Unknown network type.'
assert net_type in ("socket", "mpi"), "Unknown network type."
return _CAPI_DGLReceiverCreate(net_type, msg_queue_size)
......@@ -64,8 +64,7 @@ def _finalize_sender(sender):
def _finalize_receiver(receiver):
"""Finalize Receiver Communicator
"""
"""Finalize Receiver Communicator"""
_CAPI_DGLFinalizeReceiver(receiver)
......@@ -83,7 +82,7 @@ def _add_receiver_addr(sender, ip_addr, port, recv_id):
recv_id : int
Receiver ID
"""
assert recv_id >= 0, 'recv_id cannot be a negative number.'
assert recv_id >= 0, "recv_id cannot be a negative number."
_CAPI_DGLSenderAddReceiver(sender, ip_addr, int(port), int(recv_id))
......@@ -112,7 +111,7 @@ def _receiver_wait(receiver, ip_addr, port, num_sender):
num_sender : int
total number of Sender
"""
assert num_sender >= 0, 'num_sender cannot be a negative number.'
assert num_sender >= 0, "num_sender cannot be a negative number."
_CAPI_DGLReceiverWait(receiver, ip_addr, int(port), int(num_sender))
......@@ -131,19 +130,22 @@ def _send_nodeflow(sender, nodeflow, recv_id):
recv_id : int
Receiver ID
"""
assert recv_id >= 0, 'recv_id cannot be a negative number.'
assert recv_id >= 0, "recv_id cannot be a negative number."
gidx = nodeflow._graph
node_mapping = nodeflow._node_mapping.todgltensor()
edge_mapping = nodeflow._edge_mapping.todgltensor()
layers_offsets = utils.toindex(nodeflow._layer_offsets).todgltensor()
flows_offsets = utils.toindex(nodeflow._block_offsets).todgltensor()
_CAPI_SenderSendNodeFlow(sender,
int(recv_id),
gidx,
node_mapping,
edge_mapping,
layers_offsets,
flows_offsets)
_CAPI_SenderSendNodeFlow(
sender,
int(recv_id),
gidx,
node_mapping,
edge_mapping,
layers_offsets,
flows_offsets,
)
def _send_sampler_end_signal(sender, recv_id):
"""Send an epoch-end signal to remote Receiver.
......@@ -155,9 +157,10 @@ def _send_sampler_end_signal(sender, recv_id):
recv_id : int
Receiver ID
"""
assert recv_id >= 0, 'recv_id cannot be a negative number.'
assert recv_id >= 0, "recv_id cannot be a negative number."
_CAPI_SenderSendSamplerEndSignal(sender, int(recv_id))
def _recv_nodeflow(receiver, graph):
"""Receive sampled subgraph (NodeFlow) from remote sampler.
......@@ -183,8 +186,8 @@ def _recv_nodeflow(receiver, graph):
class KVMsgType(Enum):
"""Type of kvstore message
"""
"""Type of kvstore message"""
FINAL = 1
INIT = 2
PUSH = 3
......@@ -215,6 +218,7 @@ c_ptr : void*
c pointer of message
"""
def _send_kv_msg(sender, msg, recv_id):
"""Send kvstore message.
......@@ -230,12 +234,8 @@ def _send_kv_msg(sender, msg, recv_id):
if msg.type == KVMsgType.PULL:
tensor_id = F.zerocopy_to_dgl_ndarray(msg.id)
_CAPI_SenderSendKVMsg(
sender,
int(recv_id),
msg.type.value,
msg.rank,
msg.name,
tensor_id)
sender, int(recv_id), msg.type.value, msg.rank, msg.name, tensor_id
)
elif msg.type in (KVMsgType.INIT, KVMsgType.GET_SHAPE_BACK):
tensor_shape = F.zerocopy_to_dgl_ndarray(msg.shape)
_CAPI_SenderSendKVMsg(
......@@ -244,20 +244,14 @@ def _send_kv_msg(sender, msg, recv_id):
msg.type.value,
msg.rank,
msg.name,
tensor_shape)
tensor_shape,
)
elif msg.type in (KVMsgType.IP_ID, KVMsgType.GET_SHAPE):
_CAPI_SenderSendKVMsg(
sender,
int(recv_id),
msg.type.value,
msg.rank,
msg.name)
sender, int(recv_id), msg.type.value, msg.rank, msg.name
)
elif msg.type in (KVMsgType.FINAL, KVMsgType.BARRIER):
_CAPI_SenderSendKVMsg(
sender,
int(recv_id),
msg.type.value,
msg.rank)
_CAPI_SenderSendKVMsg(sender, int(recv_id), msg.type.value, msg.rank)
else:
tensor_id = F.zerocopy_to_dgl_ndarray(msg.id)
data = F.zerocopy_to_dgl_ndarray(msg.data)
......@@ -268,7 +262,8 @@ def _send_kv_msg(sender, msg, recv_id):
msg.rank,
msg.name,
tensor_id,
data)
data,
)
def _recv_kv_msg(receiver):
......@@ -288,7 +283,9 @@ def _recv_kv_msg(receiver):
rank = _CAPI_ReceiverGetKVMsgRank(msg_ptr)
if msg_type == KVMsgType.PULL:
name = _CAPI_ReceiverGetKVMsgName(msg_ptr)
tensor_id = F.zerocopy_from_dgl_ndarray(_CAPI_ReceiverGetKVMsgID(msg_ptr))
tensor_id = F.zerocopy_from_dgl_ndarray(
_CAPI_ReceiverGetKVMsgID(msg_ptr)
)
msg = KVStoreMsg(
type=msg_type,
rank=rank,
......@@ -296,11 +293,14 @@ def _recv_kv_msg(receiver):
id=tensor_id,
data=None,
shape=None,
c_ptr=msg_ptr)
c_ptr=msg_ptr,
)
return msg
elif msg_type in (KVMsgType.INIT, KVMsgType.GET_SHAPE_BACK):
name = _CAPI_ReceiverGetKVMsgName(msg_ptr)
tensor_shape = F.zerocopy_from_dgl_ndarray(_CAPI_ReceiverGetKVMsgShape(msg_ptr))
tensor_shape = F.zerocopy_from_dgl_ndarray(
_CAPI_ReceiverGetKVMsgShape(msg_ptr)
)
msg = KVStoreMsg(
type=msg_type,
rank=rank,
......@@ -308,7 +308,8 @@ def _recv_kv_msg(receiver):
id=None,
data=None,
shape=tensor_shape,
c_ptr=msg_ptr)
c_ptr=msg_ptr,
)
return msg
elif msg_type in (KVMsgType.IP_ID, KVMsgType.GET_SHAPE):
name = _CAPI_ReceiverGetKVMsgName(msg_ptr)
......@@ -319,7 +320,8 @@ def _recv_kv_msg(receiver):
id=None,
data=None,
shape=None,
c_ptr=msg_ptr)
c_ptr=msg_ptr,
)
return msg
elif msg_type in (KVMsgType.FINAL, KVMsgType.BARRIER):
msg = KVStoreMsg(
......@@ -329,11 +331,14 @@ def _recv_kv_msg(receiver):
id=None,
data=None,
shape=None,
c_ptr=msg_ptr)
c_ptr=msg_ptr,
)
return msg
else:
name = _CAPI_ReceiverGetKVMsgName(msg_ptr)
tensor_id = F.zerocopy_from_dgl_ndarray(_CAPI_ReceiverGetKVMsgID(msg_ptr))
tensor_id = F.zerocopy_from_dgl_ndarray(
_CAPI_ReceiverGetKVMsgID(msg_ptr)
)
data = F.zerocopy_from_dgl_ndarray(_CAPI_ReceiverGetKVMsgData(msg_ptr))
msg = KVStoreMsg(
type=msg_type,
......@@ -342,25 +347,34 @@ def _recv_kv_msg(receiver):
id=tensor_id,
data=data,
shape=None,
c_ptr=msg_ptr)
c_ptr=msg_ptr,
)
return msg
raise RuntimeError('Unknown message type: %d' % msg_type.value)
raise RuntimeError("Unknown message type: %d" % msg_type.value)
def _clear_kv_msg(msg):
"""Clear data of kvstore message
"""
"""Clear data of kvstore message"""
F.sync()
if msg.c_ptr is not None:
_CAPI_DeleteKVMsg(msg.c_ptr)
def _fast_pull(name, id_tensor,
machine_count, group_count, machine_id, client_id,
partition_book, g2l, local_data,
sender, receiver):
""" Pull message
def _fast_pull(
name,
id_tensor,
machine_count,
group_count,
machine_id,
client_id,
partition_book,
g2l,
local_data,
sender,
receiver,
):
"""Pull message
Parameters
----------
......@@ -393,17 +407,33 @@ def _fast_pull(name, id_tensor,
target tensor
"""
if g2l is not None:
res_tensor = _CAPI_FastPull(name, machine_id, machine_count, group_count, client_id,
F.zerocopy_to_dgl_ndarray(id_tensor),
F.zerocopy_to_dgl_ndarray(partition_book),
F.zerocopy_to_dgl_ndarray(local_data),
sender, receiver, 'has_g2l',
F.zerocopy_to_dgl_ndarray(g2l))
res_tensor = _CAPI_FastPull(
name,
machine_id,
machine_count,
group_count,
client_id,
F.zerocopy_to_dgl_ndarray(id_tensor),
F.zerocopy_to_dgl_ndarray(partition_book),
F.zerocopy_to_dgl_ndarray(local_data),
sender,
receiver,
"has_g2l",
F.zerocopy_to_dgl_ndarray(g2l),
)
else:
res_tensor = _CAPI_FastPull(name, machine_id, machine_count, group_count, client_id,
F.zerocopy_to_dgl_ndarray(id_tensor),
F.zerocopy_to_dgl_ndarray(partition_book),
F.zerocopy_to_dgl_ndarray(local_data),
sender, receiver, 'no_g2l')
res_tensor = _CAPI_FastPull(
name,
machine_id,
machine_count,
group_count,
client_id,
F.zerocopy_to_dgl_ndarray(id_tensor),
F.zerocopy_to_dgl_ndarray(partition_book),
F.zerocopy_to_dgl_ndarray(local_data),
sender,
receiver,
"no_g2l",
)
return F.zerocopy_from_dgl_ndarray(res_tensor)
......@@ -14,20 +14,22 @@ with "[NN] XXX module".
"""
import importlib
import sys
import os
import sys
from ..backend import backend_name
from ..utils import expand_as_pair
# [BarclayII] Not sure what's going on with pylint.
# Possible issue: https://github.com/PyCQA/pylint/issues/2648
from . import functional # pylint: disable=import-self
from . import functional # pylint: disable=import-self
from ..backend import backend_name
from ..utils import expand_as_pair
def _load_backend(mod_name):
mod = importlib.import_module('.%s' % mod_name, __name__)
mod = importlib.import_module(".%s" % mod_name, __name__)
thismod = sys.modules[__name__]
for api, obj in mod.__dict__.items():
setattr(thismod, api, obj)
_load_backend(backend_name)
"""MXNet modules for graph convolutions."""
# pylint: disable= no-member, arguments-differ, invalid-name
from .graphconv import GraphConv
from .relgraphconv import RelGraphConv
from .tagconv import TAGConv
from .gatconv import GATConv
from .sageconv import SAGEConv
from .gatedgraphconv import GatedGraphConv
from .chebconv import ChebConv
from .agnnconv import AGNNConv
from .appnpconv import APPNPConv
from .chebconv import ChebConv
from .densechebconv import DenseChebConv
from .densegraphconv import DenseGraphConv
from .densesageconv import DenseSAGEConv
from .densechebconv import DenseChebConv
from .edgeconv import EdgeConv
from .gatconv import GATConv
from .gatedgraphconv import GatedGraphConv
from .ginconv import GINConv
from .gmmconv import GMMConv
from .graphconv import GraphConv
from .nnconv import NNConv
from .relgraphconv import RelGraphConv
from .sageconv import SAGEConv
from .sgconv import SGConv
from .tagconv import TAGConv
__all__ = ['GraphConv', 'TAGConv', 'RelGraphConv', 'GATConv',
'SAGEConv', 'GatedGraphConv', 'ChebConv', 'AGNNConv',
'APPNPConv', 'DenseGraphConv', 'DenseSAGEConv', 'DenseChebConv',
'EdgeConv', 'GINConv', 'GMMConv', 'NNConv', 'SGConv']
__all__ = [
"GraphConv",
"TAGConv",
"RelGraphConv",
"GATConv",
"SAGEConv",
"GatedGraphConv",
"ChebConv",
"AGNNConv",
"APPNPConv",
"DenseGraphConv",
"DenseSAGEConv",
"DenseChebConv",
"EdgeConv",
"GINConv",
"GMMConv",
"NNConv",
"SGConv",
]