Unverified Commit 98325b10 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4691)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent c24e285a
"""Negative sampling APIs""" """Negative sampling APIs"""
from numpy.polynomial import polynomial from numpy.polynomial import polynomial
from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
from .. import utils from .. import utils
from .._ffi.function import _init_api
from ..heterograph import DGLHeteroGraph from ..heterograph import DGLHeteroGraph
__all__ = [ __all__ = ["global_uniform_negative_sampling"]
'global_uniform_negative_sampling']
def _calc_redundancy(k_hat, num_edges, num_pairs, r=3): # pylint: disable=invalid-name def _calc_redundancy(
k_hat, num_edges, num_pairs, r=3
): # pylint: disable=invalid-name
# pylint: disable=invalid-name # pylint: disable=invalid-name
# Calculates the number of samples required based on a lower-bound # Calculates the number of samples required based on a lower-bound
# of the expected number of negative samples, based on N draws from # of the expected number of negative samples, based on N draws from
...@@ -24,18 +27,24 @@ def _calc_redundancy(k_hat, num_edges, num_pairs, r=3): # pylint: disable=invali ...@@ -24,18 +27,24 @@ def _calc_redundancy(k_hat, num_edges, num_pairs, r=3): # pylint: disable=invali
p_m = num_edges / num_pairs p_m = num_edges / num_pairs
p_k = 1 - p_m p_k = 1 - p_m
a = p_k ** 2 a = p_k**2
b = -p_k * (2 * k_hat + r ** 2 * p_m) b = -p_k * (2 * k_hat + r**2 * p_m)
c = k_hat ** 2 c = k_hat**2
poly = polynomial.Polynomial([c, b, a]) poly = polynomial.Polynomial([c, b, a])
N = poly.roots()[-1] N = poly.roots()[-1]
redundancy = N / k_hat - 1. redundancy = N / k_hat - 1.0
return redundancy return redundancy
def global_uniform_negative_sampling(
    g,
    num_samples,
    exclude_self_loops=True,
    replace=False,
    etype=None,
    redundancy=None,
):
    """Performs negative sampling, which generate source-destination pairs such that
    edges with the given type do not exist.

@@ -95,13 +104,24 @@ def global_uniform_negative_sampling(
    exclude_self_loops = exclude_self_loops and (utype == vtype)
    redundancy = _calc_redundancy(
        num_samples, g.num_edges(etype), g.num_nodes(utype) * g.num_nodes(vtype)
    )
    etype_id = g.get_etype_id(etype)
    src, dst = _CAPI_DGLGlobalUniformNegativeSampling(
        g._graph,
        etype_id,
        num_samples,
        3,
        exclude_self_loops,
        replace,
        redundancy,
    )
    return F.from_dgl_nd(src), F.from_dgl_nd(dst)


DGLHeteroGraph.global_uniform_negative_sampling = utils.alias_func(
    global_uniform_negative_sampling
)

_init_api("dgl.sampling.negative", __name__)
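A minimal sketch of how the public API wrapped here is typically called (the toy graph below is purely illustrative):

    import dgl
    import torch

    g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))
    # Ask for up to 5 source-destination pairs that do not appear as edges of g.
    neg_src, neg_dst = dgl.sampling.global_uniform_negative_sampling(g, 5)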
"""Node2vec random walk""" """Node2vec random walk"""
from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
from .. import ndarray as nd from .. import ndarray as nd
from .. import utils from .. import utils
from .._ffi.function import _init_api
# pylint: disable=invalid-name # pylint: disable=invalid-name
__all__ = ['node2vec_random_walk'] __all__ = ["node2vec_random_walk"]
def node2vec_random_walk(g, nodes, p, q, walk_length, prob=None, return_eids=False): def node2vec_random_walk(
g, nodes, p, q, walk_length, prob=None, return_eids=False
):
""" """
Generate random walk traces from an array of starting nodes based on the node2vec model. Generate random walk traces from an array of starting nodes based on the node2vec model.
Paper: `node2vec: Scalable Feature Learning for Networks Paper: `node2vec: Scalable Feature Learning for Networks
...@@ -82,14 +85,16 @@ def node2vec_random_walk(g, nodes, p, q, walk_length, prob=None, return_eids=Fal ...@@ -82,14 +85,16 @@ def node2vec_random_walk(g, nodes, p, q, walk_length, prob=None, return_eids=Fal
assert g.device == F.cpu(), "Graph must be on CPU." assert g.device == F.cpu(), "Graph must be on CPU."
gidx = g._graph gidx = g._graph
nodes = F.to_dgl_nd(utils.prepare_tensor(g, nodes, 'nodes')) nodes = F.to_dgl_nd(utils.prepare_tensor(g, nodes, "nodes"))
if prob is None: if prob is None:
prob_nd = nd.array([], ctx=nodes.ctx) prob_nd = nd.array([], ctx=nodes.ctx)
else: else:
prob_nd = F.to_dgl_nd(g.edata[prob]) prob_nd = F.to_dgl_nd(g.edata[prob])
traces, eids = _CAPI_DGLSamplingNode2vec(gidx, nodes, p, q, walk_length, prob_nd) traces, eids = _CAPI_DGLSamplingNode2vec(
gidx, nodes, p, q, walk_length, prob_nd
)
traces = F.from_dgl_nd(traces) traces = F.from_dgl_nd(traces)
eids = F.from_dgl_nd(eids) eids = F.from_dgl_nd(eids)
...@@ -97,4 +102,4 @@ def node2vec_random_walk(g, nodes, p, q, walk_length, prob=None, return_eids=Fal ...@@ -97,4 +102,4 @@ def node2vec_random_walk(g, nodes, p, q, walk_length, prob=None, return_eids=Fal
return (traces, eids) if return_eids else traces return (traces, eids) if return_eids else traces
_init_api('dgl.sampling.randomwalks', __name__) _init_api("dgl.sampling.randomwalks", __name__)
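For context, a small example of invoking the wrapper defined above (graph, start nodes and hyperparameters are made up; p and q are the return and in-out biases from the node2vec paper):

    import dgl
    import torch

    g = dgl.graph((torch.tensor([0, 1, 2, 3]), torch.tensor([1, 2, 3, 0])))
    traces = dgl.sampling.node2vec_random_walk(
        g, torch.tensor([0, 1]), p=1.0, q=0.5, walk_length=4
    )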
"""PinSAGE sampler & related functions and classes""" """PinSAGE sampler & related functions and classes"""
import numpy as np import numpy as np
from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
from .. import convert from .. import convert, utils
from .._ffi.function import _init_api
from .randomwalks import random_walk from .randomwalks import random_walk
from .. import utils
def _select_pinsage_neighbors(src, dst, num_samples_per_node, k): def _select_pinsage_neighbors(src, dst, num_samples_per_node, k):
"""Determine the neighbors for PinSAGE algorithm from the given random walk traces. """Determine the neighbors for PinSAGE algorithm from the given random walk traces.
...@@ -16,12 +16,15 @@ def _select_pinsage_neighbors(src, dst, num_samples_per_node, k): ...@@ -16,12 +16,15 @@ def _select_pinsage_neighbors(src, dst, num_samples_per_node, k):
""" """
src = F.to_dgl_nd(src) src = F.to_dgl_nd(src)
dst = F.to_dgl_nd(dst) dst = F.to_dgl_nd(dst)
src, dst, counts = _CAPI_DGLSamplingSelectPinSageNeighbors(src, dst, num_samples_per_node, k) src, dst, counts = _CAPI_DGLSamplingSelectPinSageNeighbors(
src, dst, num_samples_per_node, k
)
src = F.from_dgl_nd(src) src = F.from_dgl_nd(src)
dst = F.from_dgl_nd(dst) dst = F.from_dgl_nd(dst)
counts = F.from_dgl_nd(counts) counts = F.from_dgl_nd(counts)
return (src, dst, counts) return (src, dst, counts)
class RandomWalkNeighborSampler(object): class RandomWalkNeighborSampler(object):
"""PinSage-like neighbor sampler extended to any heterogeneous graphs. """PinSage-like neighbor sampler extended to any heterogeneous graphs.
...@@ -72,8 +75,17 @@ class RandomWalkNeighborSampler(object): ...@@ -72,8 +75,17 @@ class RandomWalkNeighborSampler(object):
-------- --------
See examples in :any:`PinSAGESampler`. See examples in :any:`PinSAGESampler`.
""" """
def __init__(self, G, num_traversals, termination_prob,
num_random_walks, num_neighbors, metapath=None, weight_column='weights'): def __init__(
self,
G,
num_traversals,
termination_prob,
num_random_walks,
num_neighbors,
metapath=None,
weight_column="weights",
):
self.G = G self.G = G
self.weight_column = weight_column self.weight_column = weight_column
self.num_random_walks = num_random_walks self.num_random_walks = num_random_walks
...@@ -82,19 +94,25 @@ class RandomWalkNeighborSampler(object): ...@@ -82,19 +94,25 @@ class RandomWalkNeighborSampler(object):
if metapath is None: if metapath is None:
if len(G.ntypes) > 1 or len(G.etypes) > 1: if len(G.ntypes) > 1 or len(G.etypes) > 1:
                raise ValueError(
                    "Metapath must be specified if the graph is heterogeneous."
                )
            metapath = [G.canonical_etypes[0]]
        start_ntype = G.to_canonical_etype(metapath[0])[0]
        end_ntype = G.to_canonical_etype(metapath[-1])[-1]
        if start_ntype != end_ntype:
            raise ValueError(
                "The metapath must start and end at the same node type."
            )
        self.ntype = start_ntype
        self.metapath_hops = len(metapath)
        self.metapath = metapath
        self.full_metapath = metapath * num_traversals
        restart_prob = np.zeros(self.metapath_hops * num_traversals)
        restart_prob[
            self.metapath_hops :: self.metapath_hops
        ] = termination_prob
        restart_prob = F.tensor(restart_prob, dtype=F.float32)
        self.restart_prob = F.copy_to(restart_prob, G.device)

@@ -116,20 +134,30 @@ class RandomWalkNeighborSampler(object):
        A homogeneous graph constructed by selecting neighbors for each given node according
        to the algorithm above.
        """
        seed_nodes = utils.prepare_tensor(self.G, seed_nodes, "seed_nodes")
        self.restart_prob = F.copy_to(self.restart_prob, F.context(seed_nodes))
        seed_nodes = F.repeat(seed_nodes, self.num_random_walks, 0)
        paths, _ = random_walk(
            self.G,
            seed_nodes,
            metapath=self.full_metapath,
            restart_prob=self.restart_prob,
        )
        src = F.reshape(
            paths[:, self.metapath_hops :: self.metapath_hops], (-1,)
        )
        dst = F.repeat(paths[:, 0], self.num_traversals, 0)

        src, dst, counts = _select_pinsage_neighbors(
            src,
            dst,
            (self.num_random_walks * self.num_traversals),
            self.num_neighbors,
        )
        neighbor_graph = convert.heterograph(
            {(self.ntype, "_E", self.ntype): (src, dst)},
            {self.ntype: self.G.number_of_nodes(self.ntype)},
        )
        neighbor_graph.edata[self.weight_column] = counts

@@ -219,13 +247,30 @@ class PinSAGESampler(RandomWalkNeighborSampler):
    Graph Convolutional Neural Networks for Web-Scale Recommender Systems
    Ying et al., 2018, https://arxiv.org/abs/1806.01973
    """

    def __init__(
        self,
        G,
        ntype,
        other_type,
        num_traversals,
        termination_prob,
        num_random_walks,
        num_neighbors,
        weight_column="weights",
    ):
        metagraph = G.metagraph()
        fw_etype = list(metagraph[ntype][other_type])[0]
        bw_etype = list(metagraph[other_type][ntype])[0]
        super().__init__(
            G,
            num_traversals,
            termination_prob,
            num_random_walks,
            num_neighbors,
            metapath=[fw_etype, bw_etype],
            weight_column=weight_column,
        )


_init_api("dgl.sampling.pinsage", __name__)
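To make the sampler classes above concrete, a sketch of typical usage on a hypothetical item-user bipartite graph (all numbers are illustrative):

    import dgl
    import torch

    g = dgl.heterograph({
        ("item", "bought-by", "user"): (torch.tensor([0, 1, 2]), torch.tensor([0, 0, 1])),
        ("user", "bought", "item"): (torch.tensor([0, 0, 1]), torch.tensor([0, 1, 2])),
    })
    sampler = dgl.sampling.PinSAGESampler(
        g, "item", "user",
        num_traversals=3, termination_prob=0.5,
        num_random_walks=10, num_neighbors=2,
    )
    # Returns a homogeneous item-item graph; edata["weights"] counts how often
    # each sampled neighbor was visited by the random walks.
    frontier = sampler(torch.tensor([0, 1, 2]))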
"""Random walk routines """Random walk routines
""" """
from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
from ..base import DGLError
from .. import ndarray as nd from .. import ndarray as nd
from .. import utils from .. import utils
from .._ffi.function import _init_api
from ..base import DGLError
__all__ = [ __all__ = ["random_walk", "pack_traces"]
'random_walk',
'pack_traces']
def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob=None,
return_eids=False): def random_walk(
g,
nodes,
*,
metapath=None,
length=None,
prob=None,
restart_prob=None,
return_eids=False
):
"""Generate random walk traces from an array of starting nodes based on the given metapath. """Generate random walk traces from an array of starting nodes based on the given metapath.
Each starting node will have one trace generated, which Each starting node will have one trace generated, which
...@@ -167,15 +174,19 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob ...@@ -167,15 +174,19 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
if metapath is None: if metapath is None:
if n_etypes > 1 or n_ntypes > 1: if n_etypes > 1 or n_ntypes > 1:
raise DGLError("metapath not specified and the graph is not homogeneous.") raise DGLError(
"metapath not specified and the graph is not homogeneous."
)
if length is None: if length is None:
raise ValueError("Please specify either the metapath or the random walk length.") raise ValueError(
"Please specify either the metapath or the random walk length."
)
metapath = [0] * length metapath = [0] * length
else: else:
metapath = [g.get_etype_id(etype) for etype in metapath] metapath = [g.get_etype_id(etype) for etype in metapath]
gidx = g._graph gidx = g._graph
nodes = utils.prepare_tensor(g, nodes, 'nodes') nodes = utils.prepare_tensor(g, nodes, "nodes")
nodes = F.to_dgl_nd(nodes) nodes = F.to_dgl_nd(nodes)
# (Xin) Since metapath array is created by us, safe to skip the check # (Xin) Since metapath array is created by us, safe to skip the check
# and keep it on CPU to make max_nodes sanity check easier. # and keep it on CPU to make max_nodes sanity check easier.
...@@ -196,14 +207,18 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob ...@@ -196,14 +207,18 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
# Actual random walk # Actual random walk
if restart_prob is None: if restart_prob is None:
traces, eids, types = _CAPI_DGLSamplingRandomWalk(gidx, nodes, metapath, p_nd) traces, eids, types = _CAPI_DGLSamplingRandomWalk(
gidx, nodes, metapath, p_nd
)
elif F.is_tensor(restart_prob): elif F.is_tensor(restart_prob):
restart_prob = F.to_dgl_nd(restart_prob) restart_prob = F.to_dgl_nd(restart_prob)
traces, eids, types = _CAPI_DGLSamplingRandomWalkWithStepwiseRestart( traces, eids, types = _CAPI_DGLSamplingRandomWalkWithStepwiseRestart(
gidx, nodes, metapath, p_nd, restart_prob) gidx, nodes, metapath, p_nd, restart_prob
)
elif isinstance(restart_prob, float): elif isinstance(restart_prob, float):
traces, eids, types = _CAPI_DGLSamplingRandomWalkWithRestart( traces, eids, types = _CAPI_DGLSamplingRandomWalkWithRestart(
gidx, nodes, metapath, p_nd, restart_prob) gidx, nodes, metapath, p_nd, restart_prob
)
else: else:
raise TypeError("restart_prob should be float or Tensor.") raise TypeError("restart_prob should be float or Tensor.")
...@@ -212,6 +227,7 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob ...@@ -212,6 +227,7 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
eids = F.from_dgl_nd(eids) eids = F.from_dgl_nd(eids)
return (traces, eids, types) if return_eids else (traces, types) return (traces, eids, types) if return_eids else (traces, types)
def pack_traces(traces, types): def pack_traces(traces, types):
"""Pack the padded traces returned by ``random_walk()`` into a concatenated array. """Pack the padded traces returned by ``random_walk()`` into a concatenated array.
The padding values (-1) are removed, and the length and offset of each trace is The padding values (-1) are removed, and the length and offset of each trace is
...@@ -276,12 +292,18 @@ def pack_traces(traces, types): ...@@ -276,12 +292,18 @@ def pack_traces(traces, types):
>>> vids[1], vtypes[1] >>> vids[1], vtypes[1]
(tensor([0, 1, 1, 3, 0, 0, 0]), tensor([0, 0, 1, 0, 0, 1, 0])) (tensor([0, 1, 1, 3, 0, 0, 0]), tensor([0, 0, 1, 0, 0, 1, 0]))
""" """
assert F.is_tensor(traces) and F.context(traces) == F.cpu(), "traces must be a CPU tensor" assert (
assert F.is_tensor(types) and F.context(types) == F.cpu(), "types must be a CPU tensor" F.is_tensor(traces) and F.context(traces) == F.cpu()
), "traces must be a CPU tensor"
assert (
F.is_tensor(types) and F.context(types) == F.cpu()
), "types must be a CPU tensor"
traces = F.to_dgl_nd(traces) traces = F.to_dgl_nd(traces)
types = F.to_dgl_nd(types) types = F.to_dgl_nd(types)
concat_vids, concat_types, lengths, offsets = _CAPI_DGLSamplingPackTraces(traces, types) concat_vids, concat_types, lengths, offsets = _CAPI_DGLSamplingPackTraces(
traces, types
)
concat_vids = F.from_dgl_nd(concat_vids) concat_vids = F.from_dgl_nd(concat_vids)
concat_types = F.from_dgl_nd(concat_types) concat_types = F.from_dgl_nd(concat_types)
...@@ -290,4 +312,5 @@ def pack_traces(traces, types): ...@@ -290,4 +312,5 @@ def pack_traces(traces, types):
return concat_vids, concat_types, lengths, offsets return concat_vids, concat_types, lengths, offsets
_init_api('dgl.sampling.randomwalks', __name__)
_init_api("dgl.sampling.randomwalks", __name__)
"""Module for sparse matrix operators.""" """Module for sparse matrix operators."""
# pylint: disable= invalid-name # pylint: disable= invalid-name
from __future__ import absolute_import from __future__ import absolute_import
from . import backend as F
from . import ndarray as nd from . import ndarray as nd
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .base import DGLError from .base import DGLError
from . import backend as F
def infer_broadcast_shape(op, shp1, shp2): def infer_broadcast_shape(op, shp1, shp2):
...@@ -34,9 +35,12 @@ def infer_broadcast_shape(op, shp1, shp2): ...@@ -34,9 +35,12 @@ def infer_broadcast_shape(op, shp1, shp2):
pad_shp1, pad_shp2 = shp1, shp2 pad_shp1, pad_shp2 = shp1, shp2
if op == "dot": if op == "dot":
if shp1[-1] != shp2[-1]: if shp1[-1] != shp2[-1]:
raise DGLError("Dot operator is only available for arrays with the " raise DGLError(
"same size on last dimension, but got {} and {}." "Dot operator is only available for arrays with the "
.format(shp1, shp2)) "same size on last dimension, but got {} and {}.".format(
shp1, shp2
)
)
if op == "copy_lhs": if op == "copy_lhs":
return shp1 return shp1
if op == "copy_rhs": if op == "copy_rhs":
...@@ -48,43 +52,44 @@ def infer_broadcast_shape(op, shp1, shp2): ...@@ -48,43 +52,44 @@ def infer_broadcast_shape(op, shp1, shp2):
pad_shp1 = (1,) * (len(shp2) - len(shp1)) + shp1 pad_shp1 = (1,) * (len(shp2) - len(shp1)) + shp1
for d1, d2 in zip(pad_shp1, pad_shp2): for d1, d2 in zip(pad_shp1, pad_shp2):
if d1 != d2 and d1 != 1 and d2 != 1: if d1 != d2 and d1 != 1 and d2 != 1:
raise DGLError("Feature shapes {} and {} are not valid for broadcasting." raise DGLError(
.format(shp1, shp2)) "Feature shapes {} and {} are not valid for broadcasting.".format(
shp1, shp2
)
)
rst = tuple(max(d1, d2) for d1, d2 in zip(pad_shp1, pad_shp2)) rst = tuple(max(d1, d2) for d1, d2 in zip(pad_shp1, pad_shp2))
return rst[:-1] + (1,) if op == "dot" else rst return rst[:-1] + (1,) if op == "dot" else rst
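Two quick examples of the broadcasting rule implemented above (feature shapes exclude the leading node/edge dimension):

    infer_broadcast_shape("mul", (3, 1), (1, 4))   # -> (3, 4), ordinary broadcasting
    infer_broadcast_shape("dot", (5, 8), (5, 8))   # -> (5, 1), last dim reduced by "dot"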
def to_dgl_nd(x):
    """Convert framework-specific tensor/None to dgl ndarray."""
    return nd.NULL["int64"] if x is None else F.zerocopy_to_dgl_ndarray(x)


def to_dgl_nd_for_write(x):
    """Convert framework-specific tensor/None to dgl ndarray for write."""
    return (
        nd.NULL["int64"]
        if x is None
        else F.zerocopy_to_dgl_ndarray_for_write(x)
    )


def get_typeid_by_target(gidx, etid, target):
    """Find the src/dst/etype id based on the target 'u', 'v' or 'e'."""
    src_id, dst_id = gidx.metagraph.find_edge(etid)
    if target in [0, "u"]:
        return src_id
    if target in [2, "v"]:
        return dst_id
    return etid


target_mapping = {"u": 0, "e": 1, "v": 2, "src": 0, "edge": 1, "dst": 2}


def _edge_softmax_backward(gidx, out, sds):
    r"""Edge_softmax backward interface.

    Parameters
    ----------

@@ -103,17 +108,21 @@ def _edge_softmax_backward(gidx, out, sds):
    -----
    This function does not support gpu op.
    """
    op = "copy_rhs"
    back_out = F.zeros_like(out)
    _CAPI_DGLKernelEdge_softmax_backward(
        gidx,
        op,
        to_dgl_nd(out),
        to_dgl_nd(sds),
        to_dgl_nd_for_write(back_out),
        to_dgl_nd(None),
    )
    return back_out


def _edge_softmax_forward(gidx, e, op):
    r"""Edge_softmax forward interface.

    Parameters
    ----------

@@ -138,15 +147,15 @@ def _edge_softmax_forward(gidx, e, op):
    else:
        expand = False
    myout = F.zeros_like(e)
    _CAPI_DGLKernelEdge_softmax_forward(
        gidx, op, to_dgl_nd(None), to_dgl_nd(e), to_dgl_nd_for_write(myout)
    )
    myout = F.squeeze(myout, -1) if expand else myout
    return myout
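For intuition, a rough NumPy reference of the edge-softmax semantics wrapped by the two C-API calls above, assuming the usual convention of normalizing over edges that share a destination node (toy edge list, not the real kernel):

    import numpy as np

    dst = np.array([1, 2, 2, 0])            # destination node of each edge
    e = np.array([0.5, 1.0, 2.0, -1.0])     # one score per edge

    exp_e = np.exp(e)
    denom = np.zeros(3)
    np.add.at(denom, dst, exp_e)            # sum of exp(scores) per destination node
    softmax_e = exp_e / denom[dst]          # each edge normalized within its group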
def _gspmm(gidx, op, reduce_op, u, e):
    r"""Generalized Sparse Matrix Multiplication interface. It takes the result of
    :attr:`op` on source node feature and edge feature, leads to a message on edge.
    Then aggregates the message by :attr:`reduce_op` on destination nodes.

@@ -188,13 +197,15 @@ def _gspmm(gidx, op, reduce_op, u, e):
    """
    if gidx.number_of_etypes() != 1:
        raise DGLError("We only support gspmm on graph with one edge type")
    use_u = op != "copy_rhs"
    use_e = op != "copy_lhs"
    if use_u and use_e:
        if F.dtype(u) != F.dtype(e):
            raise DGLError(
                "The node features' data type {} doesn't match edge"
                " features' data type {}, please convert them to the"
                " same type.".format(F.dtype(u), F.dtype(e))
            )
    # deal with scalar features.
    expand_u, expand_e = False, False
    if use_u:

@@ -211,10 +222,11 @@ def _gspmm(gidx, op, reduce_op, u, e):
    u_shp = F.shape(u) if use_u else (0,)
    e_shp = F.shape(e) if use_e else (0,)
    _, dsttype = gidx.metagraph.find_edge(0)
    v_shp = (gidx.number_of_nodes(dsttype),) + infer_broadcast_shape(
        op, u_shp[1:], e_shp[1:]
    )
    v = F.zeros(v_shp, dtype, ctx)
    use_cmp = reduce_op in ["max", "min"]
    arg_u, arg_e = None, None
    idtype = getattr(F, gidx.dtype)
    if use_cmp:

@@ -225,12 +237,16 @@ def _gspmm(gidx, op, reduce_op, u, e):
    arg_u_nd = to_dgl_nd_for_write(arg_u)
    arg_e_nd = to_dgl_nd_for_write(arg_e)
    if gidx.number_of_edges(0) > 0:
        _CAPI_DGLKernelSpMM(
            gidx,
            op,
            reduce_op,
            to_dgl_nd(u if use_u else None),
            to_dgl_nd(e if use_e else None),
            to_dgl_nd_for_write(v),
            arg_u_nd,
            arg_e_nd,
        )
    # NOTE(zihao): actually we can avoid the following step, because arg_*_nd
    # refers to the data that stores arg_*. After we call _CAPI_DGLKernelSpMM,
    # arg_* should have already been changed. But we found this doesn't work

@@ -251,7 +267,7 @@ def _gspmm(gidx, op, reduce_op, u, e):
def _gspmm_hetero(gidx, op, reduce_op, u_len, u_and_e_tuple):
    r"""Generalized Sparse Matrix Multiplication interface on heterogeneous graphs.

    It handles multiple node and edge types of the graph. For each edge type, it takes
    the result of :attr:`op` on source node feature and edge feature, and leads to a
    message on edge. Then it aggregates the message by :attr:`reduce_op` on the destination

@@ -298,8 +314,8 @@ def _gspmm_hetero(gidx, op, reduce_op, u_len, u_and_e_tuple):
    This function does not handle gradients.
    """
    u_tuple, e_tuple = u_and_e_tuple[:u_len], u_and_e_tuple[u_len:]
    use_u = op != "copy_rhs"
    use_e = op != "copy_lhs"
    # TODO (Israt): Add check - F.dtype(u) != F.dtype(e):
    # deal with scalar features.

@@ -319,7 +335,7 @@ def _gspmm_hetero(gidx, op, reduce_op, u_len, u_and_e_tuple):
    list_arg_e_etype_nd = [None] * num_ntypes
    list_arg_e_etype = [None] * num_ntypes
    use_cmp = reduce_op in ["max", "min"]
    idtype = getattr(F, gidx.dtype)
    for etid in range(num_etypes):

@@ -336,12 +352,17 @@ def _gspmm_hetero(gidx, op, reduce_op, u_len, u_and_e_tuple):
                e = F.unsqueeze(e, -1)
                expand_e = True
        list_e[etid] = e if use_e else None
        ctx = (
            F.context(u) if use_u else F.context(e)
        )  # TODO(Israt): Put outside of loop
        dtype = (
            F.dtype(u) if use_u else F.dtype(e)
        )  # TODO(Israt): Put outside of loop
        u_shp = F.shape(u) if use_u else (0,)
        e_shp = F.shape(e) if use_e else (0,)
        v_shp = (gidx.number_of_nodes(dst_id),) + infer_broadcast_shape(
            op, u_shp[1:], e_shp[1:]
        )
        list_v[dst_id] = F.zeros(v_shp, dtype, ctx)
        if use_cmp:
            if use_u:

@@ -351,38 +372,62 @@ def _gspmm_hetero(gidx, op, reduce_op, u_len, u_and_e_tuple):
                list_arg_e[dst_id] = F.zeros(v_shp, idtype, ctx)
                list_arg_e_etype[dst_id] = F.zeros(v_shp, idtype, ctx)
        list_arg_u_nd[dst_id] = to_dgl_nd_for_write(list_arg_u[dst_id])
        list_arg_u_ntype_nd[dst_id] = to_dgl_nd_for_write(
            list_arg_u_ntype[dst_id]
        )
        list_arg_e_nd[dst_id] = to_dgl_nd_for_write(list_arg_e[dst_id])
        list_arg_e_etype_nd[dst_id] = to_dgl_nd_for_write(
            list_arg_e_etype[dst_id]
        )
    if gidx.number_of_edges(0) > 0:
        _CAPI_DGLKernelSpMMHetero(
            gidx,
            op,
            reduce_op,
            [to_dgl_nd(u_i) for u_i in list_u],
            [to_dgl_nd(e_i) for e_i in list_e],
            [to_dgl_nd_for_write(v_i) for v_i in list_v],
            list_arg_u_nd,
            list_arg_e_nd,
            list_arg_u_ntype_nd,
            list_arg_e_etype_nd,
        )
    for l, arg_u_nd in enumerate(list_arg_u_nd):
        # TODO(Israt): l or src_id as index of lhs
        list_arg_u[l] = (
            None
            if list_arg_u[l] is None
            else F.zerocopy_from_dgl_ndarray(arg_u_nd)
        )
        if list_arg_u[l] is not None and expand_u and use_cmp:
            list_arg_u[l] = F.squeeze(list_arg_u[l], -1)
    for l, arg_e_nd in enumerate(list_arg_e_nd):
        list_arg_e[l] = (
            None
            if list_arg_e[l] is None
            else F.zerocopy_from_dgl_ndarray(arg_e_nd)
        )
        if list_arg_e[l] is not None and expand_e and use_cmp:
            list_arg_e[l] = F.squeeze(list_arg_e[l], -1)
    for l, arg_u_ntype_nd in enumerate(list_arg_u_ntype_nd):
        list_arg_u_ntype[l] = (
            None
            if arg_u_ntype_nd is None
            else F.zerocopy_from_dgl_ndarray(arg_u_ntype_nd)
        )
    for l, arg_e_etype_nd in enumerate(list_arg_e_etype_nd):
        list_arg_e_etype[l] = (
            None
            if arg_e_etype_nd is None
            else F.zerocopy_from_dgl_ndarray(arg_e_etype_nd)
        )
    # To deal with scalar node/edge features.
    for l in range(num_ntypes):
        # replace None by empty tensor. Forward func doesn't accept None in tuple.
        v = list_v[l]
        v = F.tensor([]) if v is None else v
        if (expand_u or not use_u) and (expand_e or not use_e):
            v = F.squeeze(v, -1)  # To deal with scalar node/edge features.
        list_v[l] = v
    out = tuple(list_v)

@@ -391,29 +436,34 @@ def _gspmm_hetero(gidx, op, reduce_op, u_len, u_and_e_tuple):
def _segment_mm(A, B, out, seglen_A, b_trans=False):
    """Invoke the C API of segment_mm."""
    _CAPI_DGLKernelSEGMENTMM(
        to_dgl_nd(A),
        to_dgl_nd(B),
        to_dgl_nd_for_write(out),
        to_dgl_nd(seglen_A),
        False,
        b_trans,
    )
    return out
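A plain-NumPy sketch of what segment matrix multiplication computes (shapes and segment lengths are illustrative; the real kernel runs in C++):

    import numpy as np

    def segment_mm_reference(A, B, seglen_A):
        # rows in segment i are multiplied by B[i]: out[segment i] = A[segment i] @ B[i]
        out, start = [], 0
        for i, length in enumerate(seglen_A):
            out.append(A[start:start + length] @ B[i])
            start += length
        return np.concatenate(out, axis=0)

    A = np.random.randn(5, 4)          # 5 rows, split into segments of 2 and 3
    B = np.random.randn(2, 4, 3)       # one 4x3 matrix per segment
    print(segment_mm_reference(A, B, [2, 3]).shape)  # (5, 3)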
def _segment_mm_backward_B(A, dC, dB, seglen):
    """Invoke the C API of the backward of segment_mm on B."""
    _CAPI_DGLKernelSEGMENTMMBackwardB(
        to_dgl_nd(A), to_dgl_nd(dC), to_dgl_nd_for_write(dB), to_dgl_nd(seglen)
    )
    return dB


def _gather_mm(A, B, out, idx_a=None, idx_b=None):
    r"""Invoke the C API of the gather_mm operator."""
    _CAPI_DGLKernelGATHERMM(
        to_dgl_nd(A),
        to_dgl_nd(B),
        to_dgl_nd_for_write(out),
        to_dgl_nd(idx_a),
        to_dgl_nd(idx_b),
    )
    return out
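Similarly, a NumPy sketch of the gather-mm idea, where each row of A selects the matrix it is multiplied with via an index array (names and shapes are illustrative):

    import numpy as np

    def gather_mm_reference(A, B, idx_b):
        # out[i] = A[i] @ B[idx_b[i]]
        return np.stack([A[i] @ B[idx_b[i]] for i in range(A.shape[0])])

    A = np.random.randn(4, 5)
    B = np.random.randn(3, 5, 2)        # 3 candidate weight matrices
    idx_b = np.array([0, 2, 1, 0])      # which matrix each row uses
    print(gather_mm_reference(A, B, idx_b).shape)  # (4, 2)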
@@ -425,12 +475,13 @@ def _gather_mm_scatter(A, B, out, idx_a=None, idx_b=None, idx_c=None):
        to_dgl_nd_for_write(out),
        to_dgl_nd(idx_a),
        to_dgl_nd(idx_b),
        to_dgl_nd(idx_c),
    )
    return out


def _gsddmm(gidx, op, lhs, rhs, lhs_target="u", rhs_target="v"):
    r"""Generalized Sampled-Dense-Dense Matrix Multiplication interface. It
    takes the result of :attr:`op` on source node feature and destination node
    feature, leads to a feature on edge.

@@ -471,12 +522,14 @@ def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
    """
    if gidx.number_of_etypes() != 1:
        raise DGLError("We only support gsddmm on graph with one edge type")
    use_lhs = op != "copy_rhs"
    use_rhs = op != "copy_lhs"
    if use_lhs and use_rhs:
        if F.dtype(lhs) != F.dtype(rhs):
            raise DGLError(
                "The operands data type don't match: {} and {}, please convert them"
                " to the same type.".format(F.dtype(lhs), F.dtype(rhs))
            )
    # deal with scalar features.
    expand_lhs, expand_rhs = False, False
    if use_lhs:

@@ -494,35 +547,48 @@ def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
    dtype = F.dtype(lhs) if use_lhs else F.dtype(rhs)
    lhs_shp = F.shape(lhs) if use_lhs else (0,)
    rhs_shp = F.shape(rhs) if use_rhs else (0,)
    out_shp = (gidx.number_of_edges(0),) + infer_broadcast_shape(
        op, lhs_shp[1:], rhs_shp[1:]
    )
    out = F.zeros(out_shp, dtype, ctx)
    if gidx.number_of_edges(0) > 0:
        _CAPI_DGLKernelSDDMM(
            gidx,
            op,
            to_dgl_nd(lhs if use_lhs else None),
            to_dgl_nd(rhs if use_rhs else None),
            to_dgl_nd_for_write(out),
            lhs_target,
            rhs_target,
        )
    if (expand_lhs or not use_lhs) and (expand_rhs or not use_rhs):
        out = F.squeeze(out, -1)
    return out
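Taken together, _gspmm and _gsddmm are the two message-passing primitives their docstrings describe. A rough edge-list reference in NumPy, restricted to 'add'/'sum' and 'mul' for brevity (the real kernels are fused C++/CUDA implementations behind the C API calls above):

    import numpy as np

    # Toy graph: 3 nodes, 4 edges src -> dst.
    src = np.array([0, 0, 1, 2])
    dst = np.array([1, 2, 2, 0])
    u = np.random.randn(3, 4)    # node features
    e = np.random.randn(4, 4)    # edge features

    # SpMM: message = op(u[src], e), then reduce messages at each destination node.
    msg = u[src] + e             # op = 'add'
    v = np.zeros_like(u)
    np.add.at(v, dst, msg)       # reduce_op = 'sum'

    # SDDMM: per-edge feature computed from source and destination node features.
    edge_feat = u[src] * u[dst]  # op = 'mul', lhs_target='u', rhs_target='v'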
def _gsddmm_hetero(
    gidx, op, lhs_len, lhs_target="u", rhs_target="v", lhs_and_rhs_tuple=None
):
    r"""Generalized Sampled-Dense-Dense Matrix Multiplication interface."""
    lhs_tuple, rhs_tuple = (
        lhs_and_rhs_tuple[:lhs_len],
        lhs_and_rhs_tuple[lhs_len:],
    )

    use_lhs = op != "copy_rhs"
    use_rhs = op != "copy_lhs"
    # TODO (Israt): Add check - F.dtype(u) != F.dtype(e):
    # deal with scalar features.
    expand_lhs, expand_rhs = False, False
    num_ntype = gidx.number_of_ntypes()
    num_etype = gidx.number_of_etypes()
    lhs_list = (
        [None] * num_ntype if lhs_target in ["u", "v"] else [None] * num_etype
    )
    rhs_list = (
        [None] * num_ntype if rhs_target in ["u", "v"] else [None] * num_etype
    )
    out_list = [None] * gidx.number_of_etypes()
    lhs_target = target_mapping[lhs_target]

@@ -547,15 +613,20 @@ def _gsddmm_hetero(gidx, op, lhs_len, lhs_target='u', rhs_target='v', lhs_and_rhs_tuple=None):
        rhs_shp = F.shape(rhs) if use_rhs else (0,)
        lhs_list[lhs_id] = lhs if use_lhs else None
        rhs_list[rhs_id] = rhs if use_rhs else None
        out_shp = (gidx.number_of_edges(etid),) + infer_broadcast_shape(
            op, lhs_shp[1:], rhs_shp[1:]
        )
        out_list[etid] = F.zeros(out_shp, dtype, ctx)
    if gidx.number_of_edges(0) > 0:
        _CAPI_DGLKernelSDDMMHetero(
            gidx,
            op,
            [to_dgl_nd(lhs) for lhs in lhs_list],
            [to_dgl_nd(rhs) for rhs in rhs_list],
            [to_dgl_nd_for_write(out) for out in out_list],
            lhs_target,
            rhs_target,
        )
    for l in range(gidx.number_of_etypes()):
        # Replace None by empty tensor. Forward func doesn't accept None in tuple.
@@ -607,20 +678,22 @@ def _segment_reduce(op, feat, offsets):
    idtype = F.dtype(offsets)
    out = F.zeros(out_shp, dtype, ctx)
    arg = None
    if op in ["min", "max"]:
        arg = F.zeros(out_shp, idtype, ctx)
    arg_nd = to_dgl_nd_for_write(arg)
    _CAPI_DGLKernelSegmentReduce(
        op,
        to_dgl_nd(feat),
        to_dgl_nd(offsets),
        to_dgl_nd_for_write(out),
        arg_nd,
    )
    arg = None if arg is None else F.zerocopy_from_dgl_ndarray(arg_nd)
    return out, arg
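A NumPy reference of the offset-based segment reduction wrapped above, assuming the offsets delimit half-open row ranges [offsets[i], offsets[i+1]); this is only a sketch of the semantics:

    import numpy as np

    def segment_reduce_reference(op, feat, offsets):
        reducer = {"sum": np.sum, "max": np.max, "min": np.min, "mean": np.mean}[op]
        return np.stack([
            reducer(feat[offsets[i]:offsets[i + 1]], axis=0)
            for i in range(len(offsets) - 1)
        ])

    feat = np.arange(12.0).reshape(6, 2)
    print(segment_reduce_reference("sum", feat, np.array([0, 2, 6])))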
def _scatter_add(x, idx, m):
    r"""Scatter add operator (on first dimension) implementation.

    Math: y[idx[i], *] += x[i, *]

@@ -642,14 +715,16 @@ def _scatter_add(x, idx, m):
    ctx = F.context(x)
    dtype = F.dtype(x)
    out = F.zeros(out_shp, dtype, ctx)
    _CAPI_DGLKernelScatterAdd(
        to_dgl_nd(x), to_dgl_nd(idx), to_dgl_nd_for_write(out)
    )
    return out
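The Math line in the docstring corresponds directly to NumPy's np.add.at (toy data for illustration):

    import numpy as np

    x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    idx = np.array([0, 2, 0])
    m = 3                        # number of output rows
    y = np.zeros((m, x.shape[1]))
    np.add.at(y, idx, x)         # y[idx[i], :] += x[i, :]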
def _update_grad_minmax_hetero(
    gidx, op, list_x, list_idx, list_idx_etype, list_dX
):
    r"""Update gradients for reduce operator max and min (on first dimension) implementation.

    Parameters
    ----------

@@ -669,11 +744,11 @@ def _update_grad_minmax_hetero(gidx, op, list_x, list_idx, list_idx_etype, list_dX):
    Tensor
        The output tensor.
    """
    use_u = op != "copy_rhs"
    use_e = op != "copy_lhs"
    list_out = [None] * len(list_dX)
    for etid in range(gidx.number_of_etypes()):
        src_id, dst_id = gidx.metagraph.find_edge(etid)  # gidx is reversed
        x = list_x[src_id]
        ctx = F.context(x)
        dtype = F.dtype(x)

@@ -684,16 +759,19 @@ def _update_grad_minmax_hetero(gidx, op, list_x, list_idx, list_idx_etype, list_dX):
        out_shp = (len(list_dX[etid]),) + F.shape(x)[1:]
        list_out[etid] = F.zeros(out_shp, dtype, ctx)
        _CAPI_DGLKernelUpdateGradMinMaxHetero(
            gidx,
            op,
            [to_dgl_nd(x) for x in list_x],
            [to_dgl_nd(idx) for idx in list_idx],
            [to_dgl_nd(idx_etype) for idx_etype in list_idx_etype],
            [to_dgl_nd_for_write(out) for out in list_out],
        )
    return tuple(list_out)
def _bwd_segment_cmp(feat, arg, m):
    r"""Backward phase of segment reduction (for 'min'/'max' reduction).

    It computes the gradient of input feature given output gradient of
    the segment reduction result.

@@ -716,11 +794,12 @@ def _bwd_segment_cmp(feat, arg, m):
    ctx = F.context(feat)
    dtype = F.dtype(feat)
    out = F.zeros(out_shp, dtype, ctx)
    _CAPI_DGLKernelBwdSegmentCmp(
        to_dgl_nd(feat), to_dgl_nd(arg), to_dgl_nd_for_write(out)
    )
    return out
def _csrmm(A, A_weights, B, B_weights, num_vtypes):
    """Return a graph whose adjacency matrix is the sparse matrix multiplication
    of those of two given graphs.

@@ -749,9 +828,11 @@ def _csrmm(A, A_weights, B, B_weights, num_vtypes):
        The edge weights of the output graph.
    """
    C, C_weights = _CAPI_DGLCSRMM(
        A, F.to_dgl_nd(A_weights), B, F.to_dgl_nd(B_weights), num_vtypes
    )
    return C, F.from_dgl_nd(C_weights)


def _csrsum(As, A_weights):
    """Return a graph whose adjacency matrix is the sparse matrix summation
    of the given list of graphs.

@@ -776,6 +857,7 @@ def _csrsum(As, A_weights):
    C, C_weights = _CAPI_DGLCSRSum(As, [F.to_dgl_nd(w) for w in A_weights])
    return C, F.from_dgl_nd(C_weights)


def _csrmask(A, A_weights, B):
    """Return the weights of A at the locations identical to the sparsity pattern
    of B.

@@ -805,30 +887,54 @@ def _csrmask(A, A_weights, B):
    return F.from_dgl_nd(_CAPI_DGLCSRMask(A, F.to_dgl_nd(A_weights), B))
###################################################################################################
## Libra Graph Partition
def libra_vertex_cut(
    nc,
    node_degree,
    edgenum_unassigned,
    community_weights,
    u,
    v,
    w,
    out,
    N,
    N_e,
    dataset,
):
    """
    This function invokes C/C++ code for Libra based graph partitioning.
    Parameter details are present in dgl/src/array/libra_partition.cc
    """
    _CAPI_DGLLibraVertexCut(
        nc,
        to_dgl_nd_for_write(node_degree),
        to_dgl_nd_for_write(edgenum_unassigned),
        to_dgl_nd_for_write(community_weights),
        to_dgl_nd(u),
        to_dgl_nd(v),
        to_dgl_nd(w),
        to_dgl_nd_for_write(out),
        N,
        N_e,
        dataset,
    )
def libra2dgl_build_dict(
    a,
    b,
    indices,
    ldt_key,
    gdt_key,
    gdt_value,
    node_map,
    offset,
    nc,
    c,
    fsize,
    dataset,
):
    """
    This function invokes C/C++ code for pre-processing Libra output.
    After graph partitioning using Libra, during conversion from Libra output to DGL/DistGNN input,

@@ -836,25 +942,48 @@ def libra2dgl_build_dict(a, b, indices, ldt_key, gdt_key, gdt_value, node_map,
    and also to create a database of the split nodes.
    Parameter details are present in dgl/src/array/libra_partition.cc
    """
    ret = _CAPI_DGLLibra2dglBuildDict(
        to_dgl_nd_for_write(a),
        to_dgl_nd_for_write(b),
        to_dgl_nd_for_write(indices),
        to_dgl_nd_for_write(ldt_key),
        to_dgl_nd_for_write(gdt_key),
        to_dgl_nd_for_write(gdt_value),
        to_dgl_nd_for_write(node_map),
        to_dgl_nd_for_write(offset),
        nc,
        c,
        fsize,
        dataset,
    )
    return ret
def libra2dgl_build_adjlist(
    feat,
    gfeat,
    adj,
    inner_node,
    ldt,
    gdt_key,
    gdt_value,
    node_map,
    lr,
    lrtensor,
    num_nodes,
    nc,
    c,
    feat_size,
    labels,
    trainm,
    testm,
    valm,
    glabels,
    gtrainm,
    gtestm,
    gvalm,
    feat_shape,
):
    """
    This function invokes C/C++ code for pre-processing Libra output.
    After graph partitioning using Libra, once the local and global dictionaries are built,

@@ -863,30 +992,31 @@ def libra2dgl_build_adjlist(feat, gfeat, adj, inner_node, ldt, gdt_key,
    for each node from the input graph to the corresponding partitions.
    Parameter details are present in dgl/src/array/libra_partition.cc
    """
    _CAPI_DGLLibra2dglBuildAdjlist(
        to_dgl_nd(feat),
        to_dgl_nd_for_write(gfeat),
        to_dgl_nd_for_write(adj),
        to_dgl_nd_for_write(inner_node),
        to_dgl_nd(ldt),
        to_dgl_nd(gdt_key),
        to_dgl_nd(gdt_value),
        to_dgl_nd(node_map),
        to_dgl_nd_for_write(lr),
        to_dgl_nd(lrtensor),
        num_nodes,
        nc,
        c,
        feat_size,
        to_dgl_nd(labels),
        to_dgl_nd(trainm),
        to_dgl_nd(testm),
        to_dgl_nd(valm),
        to_dgl_nd_for_write(glabels),
        to_dgl_nd_for_write(gtrainm),
        to_dgl_nd_for_write(gtestm),
        to_dgl_nd_for_write(gvalm),
        feat_shape,
    )
def libra2dgl_set_lr(gdt_key, gdt_value, lrtensor, nc, Nn):

@@ -897,11 +1027,13 @@ def libra2dgl_set_lr(gdt_key, gdt_value, lrtensor, nc, Nn):
    of a node from input graph.
    Parameter details are present in dgl/src/array/libra_partition.cc
    """
    _CAPI_DGLLibra2dglSetLR(
        to_dgl_nd(gdt_key),
        to_dgl_nd(gdt_value),
        to_dgl_nd_for_write(lrtensor),
        nc,
        Nn,
    )


_init_api("dgl.sparse")
"""Feature storage classes for DataLoading""" """Feature storage classes for DataLoading"""
from .. import backend as F from .. import backend as F
from .base import * from .base import *
from .numpy import * from .numpy import *
# Defines the name TensorStorage # Defines the name TensorStorage
if F.get_preferred_backend() == 'pytorch': if F.get_preferred_backend() == "pytorch":
from .pytorch_tensor import PyTorchTensorStorage as TensorStorage from .pytorch_tensor import PyTorchTensorStorage as TensorStorage
else: else:
from .tensor import BaseTensorStorage as TensorStorage from .tensor import BaseTensorStorage as TensorStorage
@@ -2,16 +2,19 @@
import threading

STORAGE_WRAPPERS = {}


def register_storage_wrapper(type_):
    """Decorator that associates a type to a ``FeatureStorage`` object."""

    def deco(cls):
        STORAGE_WRAPPERS[type_] = cls
        return cls

    return deco


def wrap_storage(storage):
    """Wrap an object into a FeatureStorage as specified by the ``register_storage_wrapper``
    decorators.

@@ -20,11 +23,14 @@ def wrap_storage(storage):
        if isinstance(storage, type_):
            return storage_cls(storage)
    assert isinstance(
        storage, FeatureStorage
    ), "The frame column must be a tensor or a FeatureStorage object, got {}".format(
        type(storage)
    )
    return storage
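To illustrate how the registry above is meant to be used, a sketch of registering a hypothetical wrapper (ListStorage and the plain list type are made up for this example; the names from this module are assumed to be in scope):

    # Hypothetical example: map plain Python lists to a custom FeatureStorage.
    @register_storage_wrapper(list)
    class ListStorage(FeatureStorage):
        def __init__(self, data):
            self.data = data

        def fetch(self, indices, device, pin_memory=False, **kwargs):
            return [self.data[i] for i in indices]

    # wrap_storage now routes plain lists through ListStorage and passes
    # existing FeatureStorage objects through unchanged.
    storage = wrap_storage([10, 20, 30])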
class _FuncWrapper(object): class _FuncWrapper(object):
def __init__(self, func): def __init__(self, func):
self.func = func self.func = func
...@@ -32,18 +38,21 @@ class _FuncWrapper(object): ...@@ -32,18 +38,21 @@ class _FuncWrapper(object):
def __call__(self, buf, *args): def __call__(self, buf, *args):
buf[0] = self.func(*args) buf[0] = self.func(*args)
class ThreadedFuture(object):
    """Wraps a function into a future asynchronously executed by a Python
    ``threading.Thread``. The function is executed upon instantiation of
    this object.
    """

    def __init__(self, target, args):
        self.buf = [None]
        thread = threading.Thread(
            target=_FuncWrapper(target),
            args=[self.buf] + list(args),
            daemon=True,
        )
        thread.start()
        self.thread = thread
@@ -52,14 +61,15 @@ class ThreadedFuture(object):
        self.thread.join()
        return self.buf[0]
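
# A minimal usage sketch: ThreadedFuture starts the wrapped call in a daemon thread at
# construction time, and the collapsed method above joins the thread and returns the
# buffered result (assumed here to be the usual ``wait()``). ``slow_read`` and the demo
# helper are made up for illustration.
def _threaded_future_demo():
    import time

    def slow_read(path):
        time.sleep(0.1)  # simulate blocking I/O
        return path.upper()

    fut = ThreadedFuture(target=slow_read, args=("feat.bin",))
    # ... other work can overlap with the background read here ...
    return fut.wait()  # blocks until slow_read finishes, then returns its result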
class FeatureStorage(object):
    """Feature storage object which should support a fetch() operation. It is the
    counterpart of a tensor for homogeneous graphs, or a dict of tensors for heterogeneous
    graphs where the keys are node/edge types.
    """

    def requires_ddp(self):
        """Whether the FeatureStorage requires the DataLoader to set use_ddp."""
        return False

    def fetch(self, indices, device, pin_memory=False, **kwargs):
        ...
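
# A minimal sketch of a custom FeatureStorage, assuming only the fetch() contract shown
# above; ``ListStorage`` and its registration for plain lists are made up for
# illustration, and it simply materialises the requested rows on the target device.
@register_storage_wrapper(list)
class ListStorage(FeatureStorage):
    """Hypothetical storage backed by a plain Python list of feature rows."""

    def __init__(self, rows):
        self.rows = rows

    def fetch(self, indices, device, pin_memory=False, **kwargs):
        rows = [self.rows[int(i)] for i in F.asnumpy(indices)]
        return F.copy_to(F.tensor(rows), device)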
"""Feature storage for ``numpy.memmap`` object.""" """Feature storage for ``numpy.memmap`` object."""
import numpy as np import numpy as np
from .base import FeatureStorage, ThreadedFuture, register_storage_wrapper
from .. import backend as F from .. import backend as F
from .base import FeatureStorage, ThreadedFuture, register_storage_wrapper
@register_storage_wrapper(np.memmap) @register_storage_wrapper(np.memmap)
class NumpyStorage(FeatureStorage): class NumpyStorage(FeatureStorage):
"""FeatureStorage that asynchronously reads features from a ``numpy.memmap`` object.""" """FeatureStorage that asynchronously reads features from a ``numpy.memmap`` object."""
def __init__(self, arr): def __init__(self, arr):
self.arr = arr self.arr = arr
...@@ -17,4 +20,6 @@ class NumpyStorage(FeatureStorage): ...@@ -17,4 +20,6 @@ class NumpyStorage(FeatureStorage):
# pylint: disable=unused-argument # pylint: disable=unused-argument
def fetch(self, indices, device, pin_memory=False, **kwargs): def fetch(self, indices, device, pin_memory=False, **kwargs):
return ThreadedFuture(target=self._fetch, args=(indices, device, pin_memory)) return ThreadedFuture(
target=self._fetch, args=(indices, device, pin_memory)
)
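
# A minimal usage sketch: reading rows of a memory-mapped array through NumpyStorage.
# ``_fetch`` is collapsed in this diff, so the sketch only assumes that fetch() returns a
# ThreadedFuture whose wait() yields the sliced rows; the file name and demo helper are
# made up.
def _numpy_storage_demo():
    arr = np.memmap("demo_feat.dat", dtype="float32", mode="w+", shape=(100, 16))
    storage = NumpyStorage(arr)
    fut = storage.fetch(F.tensor([0, 5, 7]), device=F.cpu())
    return fut.wait()  # a (3, 16) tensor, read in a background thread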
"""Feature storages for PyTorch tensors.""" """Feature storages for PyTorch tensors."""
import torch import torch
from ..utils import gather_pinned_tensor_rows
from .base import register_storage_wrapper from .base import register_storage_wrapper
from .tensor import BaseTensorStorage from .tensor import BaseTensorStorage
from ..utils import gather_pinned_tensor_rows
def _fetch_cpu(indices, tensor, feature_shape, device, pin_memory, **kwargs): def _fetch_cpu(indices, tensor, feature_shape, device, pin_memory, **kwargs):
result = torch.empty( result = torch.empty(
indices.shape[0], *feature_shape, dtype=tensor.dtype, indices.shape[0],
pin_memory=pin_memory) *feature_shape,
dtype=tensor.dtype,
pin_memory=pin_memory,
)
torch.index_select(tensor, 0, indices, out=result) torch.index_select(tensor, 0, indices, out=result)
kwargs['non_blocking'] = pin_memory kwargs["non_blocking"] = pin_memory
result = result.to(device, **kwargs) result = result.to(device, **kwargs)
return result return result
def _fetch_cuda(indices, tensor, device, **kwargs): def _fetch_cuda(indices, tensor, device, **kwargs):
return torch.index_select(tensor, 0, indices).to(device, **kwargs) return torch.index_select(tensor, 0, indices).to(device, **kwargs)
@register_storage_wrapper(torch.Tensor) @register_storage_wrapper(torch.Tensor)
class PyTorchTensorStorage(BaseTensorStorage): class PyTorchTensorStorage(BaseTensorStorage):
"""Feature storages for slicing a PyTorch tensor.""" """Feature storages for slicing a PyTorch tensor."""
def fetch(self, indices, device, pin_memory=False, **kwargs): def fetch(self, indices, device, pin_memory=False, **kwargs):
device = torch.device(device) device = torch.device(device)
storage_device_type = self.storage.device.type storage_device_type = self.storage.device.type
indices_device_type = indices.device.type indices_device_type = indices.device.type
if storage_device_type != 'cuda': if storage_device_type != "cuda":
if indices_device_type == 'cuda': if indices_device_type == "cuda":
if self.storage.is_pinned(): if self.storage.is_pinned():
return gather_pinned_tensor_rows(self.storage, indices) return gather_pinned_tensor_rows(self.storage, indices)
else: else:
raise ValueError( raise ValueError(
f'Got indices on device {indices.device} whereas the feature tensor ' f"Got indices on device {indices.device} whereas the feature tensor "
f'is on {self.storage.device}. Please either (1) move the graph ' f"is on {self.storage.device}. Please either (1) move the graph "
f'to GPU with to() method, or (2) pin the graph with ' f"to GPU with to() method, or (2) pin the graph with "
f'pin_memory_() method.') f"pin_memory_() method."
)
# CPU to CPU or CUDA - use pin_memory and async transfer if possible # CPU to CPU or CUDA - use pin_memory and async transfer if possible
else: else:
return _fetch_cpu(indices, self.storage, self.storage.shape[1:], device, return _fetch_cpu(
pin_memory, **kwargs) indices,
self.storage,
self.storage.shape[1:],
device,
pin_memory,
**kwargs,
)
else: else:
# CUDA to CUDA or CPU # CUDA to CUDA or CPU
return _fetch_cuda(indices, self.storage, device, **kwargs) return _fetch_cuda(indices, self.storage, device, **kwargs)
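
# A minimal usage sketch of the CPU-to-CPU path of fetch() above; with a CUDA build,
# passing device="cuda" (and pin_memory=True) takes the pinned/async branch instead.
# Only the classes defined in this file are assumed; the demo helper is made up.
def _pytorch_storage_demo():
    feat = torch.arange(20, dtype=torch.float32).reshape(10, 2)
    storage = PyTorchTensorStorage(feat)
    rows = storage.fetch(torch.tensor([0, 3, 9]), device="cpu")
    return rows  # tensor of shape (3, 2) gathered with index_select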
"""Feature storages for tensors across different frameworks.""" """Feature storages for tensors across different frameworks."""
from .base import FeatureStorage
from .. import backend as F from .. import backend as F
from .base import FeatureStorage
class BaseTensorStorage(FeatureStorage): class BaseTensorStorage(FeatureStorage):
"""FeatureStorage that synchronously slices features from a tensor and transfers """FeatureStorage that synchronously slices features from a tensor and transfers
it to the given device. it to the given device.
""" """
def __init__(self, tensor): def __init__(self, tensor):
self.storage = tensor self.storage = tensor
def fetch(self, indices, device, pin_memory=False, **kwargs): # pylint: disable=unused-argument def fetch(
self, indices, device, pin_memory=False, **kwargs
): # pylint: disable=unused-argument
return F.copy_to(F.gather_row(tensor, indices), device, **kwargs) return F.copy_to(F.gather_row(tensor, indices), device, **kwargs)
@@ -5,19 +5,28 @@ For stochastic subgraph extraction, please see functions under :mod:`dgl.sampling`
"""
from collections.abc import Mapping

from . import backend as F
from . import graph_index, heterograph_index, utils
from ._ffi.function import _init_api
from .base import DGLError, dgl_warning
from .heterograph import DGLHeteroGraph
from .utils import context_of, recursive_apply

__all__ = [
    "node_subgraph",
    "edge_subgraph",
    "node_type_subgraph",
    "edge_type_subgraph",
    "in_subgraph",
    "out_subgraph",
    "khop_in_subgraph",
    "khop_out_subgraph",
]


def node_subgraph(
    graph, nodes, *, relabel_nodes=True, store_ids=True, output_device=None
):
"""Return a subgraph induced on the given nodes. """Return a subgraph induced on the given nodes.
A node-induced subgraph is a graph with edges whose endpoints are both in the A node-induced subgraph is a graph with edges whose endpoints are both in the
...@@ -131,36 +140,51 @@ def node_subgraph(graph, nodes, *, relabel_nodes=True, store_ids=True, output_de ...@@ -131,36 +140,51 @@ def node_subgraph(graph, nodes, *, relabel_nodes=True, store_ids=True, output_de
edge_subgraph edge_subgraph
""" """
if graph.is_block: if graph.is_block:
raise DGLError('Extracting subgraph from a block graph is not allowed.') raise DGLError("Extracting subgraph from a block graph is not allowed.")
if not isinstance(nodes, Mapping): if not isinstance(nodes, Mapping):
assert len(graph.ntypes) == 1, \ assert (
'need a dict of node type and IDs for graph with multiple node types' len(graph.ntypes) == 1
), "need a dict of node type and IDs for graph with multiple node types"
nodes = {graph.ntypes[0]: nodes} nodes = {graph.ntypes[0]: nodes}
def _process_nodes(ntype, v): def _process_nodes(ntype, v):
if F.is_tensor(v) and F.dtype(v) == F.bool: if F.is_tensor(v) and F.dtype(v) == F.bool:
return F.astype(F.nonzero_1d(F.copy_to(v, graph.device)), graph.idtype) return F.astype(
F.nonzero_1d(F.copy_to(v, graph.device)), graph.idtype
)
else: else:
return utils.prepare_tensor(graph, v, 'nodes["{}"]'.format(ntype)) return utils.prepare_tensor(graph, v, 'nodes["{}"]'.format(ntype))
nodes = {ntype: _process_nodes(ntype, v) for ntype, v in nodes.items()} nodes = {ntype: _process_nodes(ntype, v) for ntype, v in nodes.items()}
device = context_of(nodes) device = context_of(nodes)
induced_nodes = [ induced_nodes = [
nodes.get(ntype, F.copy_to(F.tensor([], graph.idtype), device)) nodes.get(ntype, F.copy_to(F.tensor([], graph.idtype), device))
for ntype in graph.ntypes] for ntype in graph.ntypes
]
sgi = graph._graph.node_subgraph(induced_nodes, relabel_nodes) sgi = graph._graph.node_subgraph(induced_nodes, relabel_nodes)
induced_edges = sgi.induced_edges induced_edges = sgi.induced_edges
# (BarclayII) should not write induced_nodes = sgi.induced_nodes due to the same # (BarclayII) should not write induced_nodes = sgi.induced_nodes due to the same
# bug in #1453. # bug in #1453.
induced_nodes_or_device = induced_nodes if relabel_nodes else device induced_nodes_or_device = induced_nodes if relabel_nodes else device
subg = _create_hetero_subgraph( subg = _create_hetero_subgraph(
graph, sgi, induced_nodes_or_device, induced_edges, store_ids=store_ids) graph, sgi, induced_nodes_or_device, induced_edges, store_ids=store_ids
)
return subg if output_device is None else subg.to(output_device) return subg if output_device is None else subg.to(output_device)
DGLHeteroGraph.subgraph = utils.alias_func(node_subgraph) DGLHeteroGraph.subgraph = utils.alias_func(node_subgraph)
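
# A minimal usage sketch of node-induced subgraphing on a small homogeneous graph,
# assuming only the public dgl.graph / dgl.node_subgraph API; the edge list and demo
# helper are made up.
def _node_subgraph_demo():
    import torch
    import dgl

    g = dgl.graph((torch.tensor([0, 1, 2, 3]), torch.tensor([1, 2, 3, 0])))
    sg = dgl.node_subgraph(g, [0, 1, 2])  # keeps only the edges 0->1 and 1->2
    # With store_ids=True (the default), the original IDs are kept in ndata/edata.
    return sg.ndata[dgl.NID], sg.edata[dgl.EID]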
def edge_subgraph(graph, edges, *, relabel_nodes=True, store_ids=True, output_device=None,
**deprecated_kwargs): def edge_subgraph(
graph,
edges,
*,
relabel_nodes=True,
store_ids=True,
output_device=None,
**deprecated_kwargs
):
"""Return a subgraph induced on the given edges. """Return a subgraph induced on the given edges.
An edge-induced subgraph is equivalent to creating a new graph using the given An edge-induced subgraph is equivalent to creating a new graph using the given
...@@ -287,36 +311,47 @@ def edge_subgraph(graph, edges, *, relabel_nodes=True, store_ids=True, output_de ...@@ -287,36 +311,47 @@ def edge_subgraph(graph, edges, *, relabel_nodes=True, store_ids=True, output_de
""" """
    if len(deprecated_kwargs) != 0:
        dgl_warning(
            "Keyword argument preserve_nodes is deprecated. Use relabel_nodes instead."
        )
        relabel_nodes = not deprecated_kwargs.get("preserve_nodes")
if graph.is_block and relabel_nodes: if graph.is_block and relabel_nodes:
raise DGLError('Extracting subgraph from a block graph is not allowed.') raise DGLError("Extracting subgraph from a block graph is not allowed.")
if not isinstance(edges, Mapping): if not isinstance(edges, Mapping):
assert len(graph.canonical_etypes) == 1, \ assert (
'need a dict of edge type and IDs for graph with multiple edge types' len(graph.canonical_etypes) == 1
), "need a dict of edge type and IDs for graph with multiple edge types"
edges = {graph.canonical_etypes[0]: edges} edges = {graph.canonical_etypes[0]: edges}
def _process_edges(etype, e): def _process_edges(etype, e):
if F.is_tensor(e) and F.dtype(e) == F.bool: if F.is_tensor(e) and F.dtype(e) == F.bool:
return F.astype(F.nonzero_1d(F.copy_to(e, graph.device)), graph.idtype) return F.astype(
F.nonzero_1d(F.copy_to(e, graph.device)), graph.idtype
)
else: else:
return utils.prepare_tensor(graph, e, 'edges["{}"]'.format(etype)) return utils.prepare_tensor(graph, e, 'edges["{}"]'.format(etype))
edges = {graph.to_canonical_etype(etype): e for etype, e in edges.items()} edges = {graph.to_canonical_etype(etype): e for etype, e in edges.items()}
edges = {etype: _process_edges(etype, e) for etype, e in edges.items()} edges = {etype: _process_edges(etype, e) for etype, e in edges.items()}
device = context_of(edges) device = context_of(edges)
induced_edges = [ induced_edges = [
edges.get(cetype, F.copy_to(F.tensor([], graph.idtype), device)) edges.get(cetype, F.copy_to(F.tensor([], graph.idtype), device))
for cetype in graph.canonical_etypes] for cetype in graph.canonical_etypes
]
sgi = graph._graph.edge_subgraph(induced_edges, not relabel_nodes) sgi = graph._graph.edge_subgraph(induced_edges, not relabel_nodes)
induced_nodes_or_device = sgi.induced_nodes if relabel_nodes else device induced_nodes_or_device = sgi.induced_nodes if relabel_nodes else device
subg = _create_hetero_subgraph( subg = _create_hetero_subgraph(
graph, sgi, induced_nodes_or_device, induced_edges, store_ids=store_ids) graph, sgi, induced_nodes_or_device, induced_edges, store_ids=store_ids
)
return subg if output_device is None else subg.to(output_device) return subg if output_device is None else subg.to(output_device)
DGLHeteroGraph.edge_subgraph = utils.alias_func(edge_subgraph) DGLHeteroGraph.edge_subgraph = utils.alias_func(edge_subgraph)
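
# A minimal usage sketch of edge-induced subgraphing with node relabeling, assuming only
# the public dgl.graph / dgl.edge_subgraph API; the edge list and demo helper are made up.
def _edge_subgraph_demo():
    import torch
    import dgl

    g = dgl.graph((torch.tensor([0, 1, 2, 3]), torch.tensor([1, 2, 3, 0])))
    sg = dgl.edge_subgraph(g, [0, 3])  # keep edges 0->1 and 3->0; relabel_nodes=True
    return sg.ndata[dgl.NID]  # maps the relabeled nodes back to 0, 1, 3 in the parent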
def in_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True, output_device=None):
def in_subgraph(
graph, nodes, *, relabel_nodes=False, store_ids=True, output_device=None
):
"""Return the subgraph induced on the inbound edges of all the edge types of the """Return the subgraph induced on the inbound edges of all the edge types of the
given nodes. given nodes.
...@@ -424,27 +459,37 @@ def in_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True, output_dev ...@@ -424,27 +459,37 @@ def in_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True, output_dev
out_subgraph out_subgraph
""" """
if graph.is_block: if graph.is_block:
raise DGLError('Extracting subgraph of a block graph is not allowed.') raise DGLError("Extracting subgraph of a block graph is not allowed.")
if not isinstance(nodes, dict): if not isinstance(nodes, dict):
if len(graph.ntypes) > 1: if len(graph.ntypes) > 1:
raise DGLError("Must specify node type when the graph is not homogeneous.") raise DGLError(
nodes = {graph.ntypes[0] : nodes} "Must specify node type when the graph is not homogeneous."
nodes = utils.prepare_tensor_dict(graph, nodes, 'nodes') )
nodes = {graph.ntypes[0]: nodes}
nodes = utils.prepare_tensor_dict(graph, nodes, "nodes")
device = context_of(nodes) device = context_of(nodes)
nodes_all_types = [ nodes_all_types = [
F.to_dgl_nd(nodes.get(ntype, F.copy_to(F.tensor([], graph.idtype), device))) F.to_dgl_nd(
for ntype in graph.ntypes] nodes.get(ntype, F.copy_to(F.tensor([], graph.idtype), device))
)
for ntype in graph.ntypes
]
sgi = _CAPI_DGLInSubgraph(graph._graph, nodes_all_types, relabel_nodes) sgi = _CAPI_DGLInSubgraph(graph._graph, nodes_all_types, relabel_nodes)
induced_nodes_or_device = sgi.induced_nodes if relabel_nodes else device induced_nodes_or_device = sgi.induced_nodes if relabel_nodes else device
induced_edges = sgi.induced_edges induced_edges = sgi.induced_edges
subg = _create_hetero_subgraph( subg = _create_hetero_subgraph(
graph, sgi, induced_nodes_or_device, induced_edges, store_ids=store_ids) graph, sgi, induced_nodes_or_device, induced_edges, store_ids=store_ids
)
return subg if output_device is None else subg.to(output_device) return subg if output_device is None else subg.to(output_device)
DGLHeteroGraph.in_subgraph = utils.alias_func(in_subgraph) DGLHeteroGraph.in_subgraph = utils.alias_func(in_subgraph)
def out_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True, output_device=None):
def out_subgraph(
graph, nodes, *, relabel_nodes=False, store_ids=True, output_device=None
):
"""Return the subgraph induced on the outbound edges of all the edge types of the """Return the subgraph induced on the outbound edges of all the edge types of the
given nodes. given nodes.
...@@ -552,27 +597,37 @@ def out_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True, output_de ...@@ -552,27 +597,37 @@ def out_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True, output_de
in_subgraph in_subgraph
""" """
if graph.is_block: if graph.is_block:
raise DGLError('Extracting subgraph of a block graph is not allowed.') raise DGLError("Extracting subgraph of a block graph is not allowed.")
if not isinstance(nodes, dict): if not isinstance(nodes, dict):
if len(graph.ntypes) > 1: if len(graph.ntypes) > 1:
raise DGLError("Must specify node type when the graph is not homogeneous.") raise DGLError(
nodes = {graph.ntypes[0] : nodes} "Must specify node type when the graph is not homogeneous."
nodes = utils.prepare_tensor_dict(graph, nodes, 'nodes') )
nodes = {graph.ntypes[0]: nodes}
nodes = utils.prepare_tensor_dict(graph, nodes, "nodes")
device = context_of(nodes) device = context_of(nodes)
nodes_all_types = [ nodes_all_types = [
F.to_dgl_nd(nodes.get(ntype, F.copy_to(F.tensor([], graph.idtype), device))) F.to_dgl_nd(
for ntype in graph.ntypes] nodes.get(ntype, F.copy_to(F.tensor([], graph.idtype), device))
)
for ntype in graph.ntypes
]
sgi = _CAPI_DGLOutSubgraph(graph._graph, nodes_all_types, relabel_nodes) sgi = _CAPI_DGLOutSubgraph(graph._graph, nodes_all_types, relabel_nodes)
induced_nodes_or_device = sgi.induced_nodes if relabel_nodes else device induced_nodes_or_device = sgi.induced_nodes if relabel_nodes else device
induced_edges = sgi.induced_edges induced_edges = sgi.induced_edges
subg = _create_hetero_subgraph( subg = _create_hetero_subgraph(
graph, sgi, induced_nodes_or_device, induced_edges, store_ids=store_ids) graph, sgi, induced_nodes_or_device, induced_edges, store_ids=store_ids
)
return subg if output_device is None else subg.to(output_device) return subg if output_device is None else subg.to(output_device)
DGLHeteroGraph.out_subgraph = utils.alias_func(out_subgraph) DGLHeteroGraph.out_subgraph = utils.alias_func(out_subgraph)
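
# A minimal usage sketch: in_subgraph / out_subgraph keep the full node set by default
# (relabel_nodes=False) and only restrict the edge set. Only the public dgl.graph /
# dgl.in_subgraph / dgl.out_subgraph API is assumed; the demo helper is made up.
def _in_out_subgraph_demo():
    import torch
    import dgl

    g = dgl.graph((torch.tensor([0, 1, 2, 3]), torch.tensor([1, 2, 3, 0])))
    in_sg = dgl.in_subgraph(g, [2])  # only the inbound edge 1->2 survives
    out_sg = dgl.out_subgraph(g, [2])  # only the outbound edge 2->3 survives
    return in_sg.num_edges(), out_sg.num_edges()  # (1, 1); both keep all 4 nodes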
def khop_in_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True, output_device=None):
def khop_in_subgraph(
graph, nodes, k, *, relabel_nodes=True, store_ids=True, output_device=None
):
"""Return the subgraph induced by k-hop in-neighborhood of the specified node(s). """Return the subgraph induced by k-hop in-neighborhood of the specified node(s).
We can expand a set of nodes by including the predecessors of them. From a We can expand a set of nodes by including the predecessors of them. From a
...@@ -677,16 +732,19 @@ def khop_in_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True, out ...@@ -677,16 +732,19 @@ def khop_in_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True, out
khop_out_subgraph khop_out_subgraph
""" """
if graph.is_block: if graph.is_block:
raise DGLError('Extracting subgraph of a block graph is not allowed.') raise DGLError("Extracting subgraph of a block graph is not allowed.")
is_mapping = isinstance(nodes, Mapping) is_mapping = isinstance(nodes, Mapping)
if not is_mapping: if not is_mapping:
assert len(graph.ntypes) == 1, \ assert (
'need a dict of node type and IDs for graph with multiple node types' len(graph.ntypes) == 1
), "need a dict of node type and IDs for graph with multiple node types"
nodes = {graph.ntypes[0]: nodes} nodes = {graph.ntypes[0]: nodes}
for nty, nty_nodes in nodes.items(): for nty, nty_nodes in nodes.items():
nodes[nty] = utils.prepare_tensor(graph, nty_nodes, 'nodes["{}"]'.format(nty)) nodes[nty] = utils.prepare_tensor(
graph, nty_nodes, 'nodes["{}"]'.format(nty)
)
last_hop_nodes = nodes last_hop_nodes = nodes
k_hop_nodes_ = [last_hop_nodes] k_hop_nodes_ = [last_hop_nodes]
...@@ -696,24 +754,37 @@ def khop_in_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True, out ...@@ -696,24 +754,37 @@ def khop_in_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True, out
current_hop_nodes = {nty: [] for nty in graph.ntypes} current_hop_nodes = {nty: [] for nty in graph.ntypes}
for cetype in graph.canonical_etypes: for cetype in graph.canonical_etypes:
srctype, _, dsttype = cetype srctype, _, dsttype = cetype
in_nbrs, _ = graph.in_edges(last_hop_nodes.get(dsttype, place_holder), etype=cetype) in_nbrs, _ = graph.in_edges(
last_hop_nodes.get(dsttype, place_holder), etype=cetype
)
current_hop_nodes[srctype].append(in_nbrs) current_hop_nodes[srctype].append(in_nbrs)
for nty in graph.ntypes: for nty in graph.ntypes:
if len(current_hop_nodes[nty]) == 0: if len(current_hop_nodes[nty]) == 0:
current_hop_nodes[nty] = place_holder current_hop_nodes[nty] = place_holder
continue continue
current_hop_nodes[nty] = F.unique(F.cat(current_hop_nodes[nty], dim=0)) current_hop_nodes[nty] = F.unique(
F.cat(current_hop_nodes[nty], dim=0)
)
k_hop_nodes_.append(current_hop_nodes) k_hop_nodes_.append(current_hop_nodes)
last_hop_nodes = current_hop_nodes last_hop_nodes = current_hop_nodes
k_hop_nodes = dict() k_hop_nodes = dict()
inverse_indices = dict() inverse_indices = dict()
for nty in graph.ntypes: for nty in graph.ntypes:
k_hop_nodes[nty], inverse_indices[nty] = F.unique(F.cat([ k_hop_nodes[nty], inverse_indices[nty] = F.unique(
hop_nodes.get(nty, place_holder) F.cat(
for hop_nodes in k_hop_nodes_], dim=0), return_inverse=True) [
hop_nodes.get(nty, place_holder)
sub_g = node_subgraph(graph, k_hop_nodes, relabel_nodes=relabel_nodes, store_ids=store_ids) for hop_nodes in k_hop_nodes_
],
dim=0,
),
return_inverse=True,
)
sub_g = node_subgraph(
graph, k_hop_nodes, relabel_nodes=relabel_nodes, store_ids=store_ids
)
if output_device is not None: if output_device is not None:
sub_g = sub_g.to(output_device) sub_g = sub_g.to(output_device)
if relabel_nodes: if relabel_nodes:
...@@ -721,20 +792,27 @@ def khop_in_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True, out ...@@ -721,20 +792,27 @@ def khop_in_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True, out
seed_inverse_indices = dict() seed_inverse_indices = dict()
for nty in nodes: for nty in nodes:
seed_inverse_indices[nty] = F.slice_axis( seed_inverse_indices[nty] = F.slice_axis(
inverse_indices[nty], axis=0, begin=0, end=len(nodes[nty])) inverse_indices[nty], axis=0, begin=0, end=len(nodes[nty])
)
else: else:
seed_inverse_indices = F.slice_axis( seed_inverse_indices = F.slice_axis(
inverse_indices[nty], axis=0, begin=0, end=len(nodes[nty])) inverse_indices[nty], axis=0, begin=0, end=len(nodes[nty])
)
if output_device is not None: if output_device is not None:
seed_inverse_indices = recursive_apply( seed_inverse_indices = recursive_apply(
seed_inverse_indices, lambda x: F.copy_to(x, output_device)) seed_inverse_indices, lambda x: F.copy_to(x, output_device)
)
return sub_g, seed_inverse_indices return sub_g, seed_inverse_indices
else: else:
return sub_g return sub_g
DGLHeteroGraph.khop_in_subgraph = utils.alias_func(khop_in_subgraph) DGLHeteroGraph.khop_in_subgraph = utils.alias_func(khop_in_subgraph)
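
# A minimal usage sketch of the 2-hop in-neighborhood of a seed node; with
# relabel_nodes=True (the default) the function also returns the position of the seed
# inside the relabeled subgraph. Only the public dgl API is assumed; the demo helper is
# made up.
def _khop_in_subgraph_demo():
    import torch
    import dgl

    g = dgl.graph((torch.tensor([0, 1, 2, 3]), torch.tensor([1, 2, 3, 4])))
    sg, inv = dgl.khop_in_subgraph(g, 3, k=2)  # nodes {1, 2, 3} reach node 3 in <= 2 hops
    return sg.num_nodes(), inv  # 3 nodes; inv gives node 3's new ID in sg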
def khop_out_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True, output_device=None):
def khop_out_subgraph(
graph, nodes, k, *, relabel_nodes=True, store_ids=True, output_device=None
):
"""Return the subgraph induced by k-hop out-neighborhood of the specified node(s). """Return the subgraph induced by k-hop out-neighborhood of the specified node(s).
We can expand a set of nodes by including the successors of them. From a We can expand a set of nodes by including the successors of them. From a
...@@ -839,16 +917,19 @@ def khop_out_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True, ou ...@@ -839,16 +917,19 @@ def khop_out_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True, ou
khop_in_subgraph khop_in_subgraph
""" """
if graph.is_block: if graph.is_block:
raise DGLError('Extracting subgraph of a block graph is not allowed.') raise DGLError("Extracting subgraph of a block graph is not allowed.")
is_mapping = isinstance(nodes, Mapping) is_mapping = isinstance(nodes, Mapping)
if not is_mapping: if not is_mapping:
assert len(graph.ntypes) == 1, \ assert (
'need a dict of node type and IDs for graph with multiple node types' len(graph.ntypes) == 1
), "need a dict of node type and IDs for graph with multiple node types"
nodes = {graph.ntypes[0]: nodes} nodes = {graph.ntypes[0]: nodes}
for nty, nty_nodes in nodes.items(): for nty, nty_nodes in nodes.items():
nodes[nty] = utils.prepare_tensor(graph, nty_nodes, 'nodes["{}"]'.format(nty)) nodes[nty] = utils.prepare_tensor(
graph, nty_nodes, 'nodes["{}"]'.format(nty)
)
last_hop_nodes = nodes last_hop_nodes = nodes
k_hop_nodes_ = [last_hop_nodes] k_hop_nodes_ = [last_hop_nodes]
...@@ -858,25 +939,37 @@ def khop_out_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True, ou ...@@ -858,25 +939,37 @@ def khop_out_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True, ou
current_hop_nodes = {nty: [] for nty in graph.ntypes} current_hop_nodes = {nty: [] for nty in graph.ntypes}
for cetype in graph.canonical_etypes: for cetype in graph.canonical_etypes:
srctype, _, dsttype = cetype srctype, _, dsttype = cetype
_, out_nbrs = graph.out_edges(last_hop_nodes.get( _, out_nbrs = graph.out_edges(
srctype, place_holder), etype=cetype) last_hop_nodes.get(srctype, place_holder), etype=cetype
)
current_hop_nodes[dsttype].append(out_nbrs) current_hop_nodes[dsttype].append(out_nbrs)
for nty in graph.ntypes: for nty in graph.ntypes:
if len(current_hop_nodes[nty]) == 0: if len(current_hop_nodes[nty]) == 0:
current_hop_nodes[nty] = place_holder current_hop_nodes[nty] = place_holder
continue continue
current_hop_nodes[nty] = F.unique(F.cat(current_hop_nodes[nty], dim=0)) current_hop_nodes[nty] = F.unique(
F.cat(current_hop_nodes[nty], dim=0)
)
k_hop_nodes_.append(current_hop_nodes) k_hop_nodes_.append(current_hop_nodes)
last_hop_nodes = current_hop_nodes last_hop_nodes = current_hop_nodes
k_hop_nodes = dict() k_hop_nodes = dict()
inverse_indices = dict() inverse_indices = dict()
for nty in graph.ntypes: for nty in graph.ntypes:
k_hop_nodes[nty], inverse_indices[nty] = F.unique(F.cat([ k_hop_nodes[nty], inverse_indices[nty] = F.unique(
hop_nodes.get(nty, place_holder) F.cat(
for hop_nodes in k_hop_nodes_], dim=0), return_inverse=True) [
hop_nodes.get(nty, place_holder)
sub_g = node_subgraph(graph, k_hop_nodes, relabel_nodes=relabel_nodes, store_ids=store_ids) for hop_nodes in k_hop_nodes_
],
dim=0,
),
return_inverse=True,
)
sub_g = node_subgraph(
graph, k_hop_nodes, relabel_nodes=relabel_nodes, store_ids=store_ids
)
if output_device is not None: if output_device is not None:
sub_g = sub_g.to(output_device) sub_g = sub_g.to(output_device)
if relabel_nodes: if relabel_nodes:
...@@ -884,19 +977,24 @@ def khop_out_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True, ou ...@@ -884,19 +977,24 @@ def khop_out_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True, ou
seed_inverse_indices = dict() seed_inverse_indices = dict()
for nty in nodes: for nty in nodes:
seed_inverse_indices[nty] = F.slice_axis( seed_inverse_indices[nty] = F.slice_axis(
inverse_indices[nty], axis=0, begin=0, end=len(nodes[nty])) inverse_indices[nty], axis=0, begin=0, end=len(nodes[nty])
)
else: else:
seed_inverse_indices = F.slice_axis( seed_inverse_indices = F.slice_axis(
inverse_indices[nty], axis=0, begin=0, end=len(nodes[nty])) inverse_indices[nty], axis=0, begin=0, end=len(nodes[nty])
)
if output_device is not None: if output_device is not None:
seed_inverse_indices = recursive_apply( seed_inverse_indices = recursive_apply(
seed_inverse_indices, lambda x: F.copy_to(x, output_device)) seed_inverse_indices, lambda x: F.copy_to(x, output_device)
)
return sub_g, seed_inverse_indices return sub_g, seed_inverse_indices
else: else:
return sub_g return sub_g
DGLHeteroGraph.khop_out_subgraph = utils.alias_func(khop_out_subgraph) DGLHeteroGraph.khop_out_subgraph = utils.alias_func(khop_out_subgraph)
def node_type_subgraph(graph, ntypes, output_device=None): def node_type_subgraph(graph, ntypes, output_device=None):
"""Return the subgraph induced on given node types. """Return the subgraph induced on given node types.
...@@ -964,18 +1062,20 @@ def node_type_subgraph(graph, ntypes, output_device=None): ...@@ -964,18 +1062,20 @@ def node_type_subgraph(graph, ntypes, output_device=None):
edge_type_subgraph edge_type_subgraph
""" """
ntid = [graph.get_ntype_id(ntype) for ntype in ntypes] ntid = [graph.get_ntype_id(ntype) for ntype in ntypes]
stids, dtids, etids = graph._graph.metagraph.edges('eid') stids, dtids, etids = graph._graph.metagraph.edges("eid")
stids, dtids, etids = stids.tonumpy(), dtids.tonumpy(), etids.tonumpy() stids, dtids, etids = stids.tonumpy(), dtids.tonumpy(), etids.tonumpy()
etypes = [] etypes = []
for stid, dtid, etid in zip(stids, dtids, etids): for stid, dtid, etid in zip(stids, dtids, etids):
if stid in ntid and dtid in ntid: if stid in ntid and dtid in ntid:
etypes.append(graph.canonical_etypes[etid]) etypes.append(graph.canonical_etypes[etid])
if len(etypes) == 0: if len(etypes) == 0:
raise DGLError('There are no edges among nodes of the specified types.') raise DGLError("There are no edges among nodes of the specified types.")
return edge_type_subgraph(graph, etypes, output_device=output_device) return edge_type_subgraph(graph, etypes, output_device=output_device)
DGLHeteroGraph.node_type_subgraph = utils.alias_func(node_type_subgraph) DGLHeteroGraph.node_type_subgraph = utils.alias_func(node_type_subgraph)
def edge_type_subgraph(graph, etypes, output_device=None): def edge_type_subgraph(graph, etypes, output_device=None):
"""Return the subgraph induced on given edge types. """Return the subgraph induced on given edge types.
...@@ -1050,7 +1150,9 @@ def edge_type_subgraph(graph, etypes, output_device=None): ...@@ -1050,7 +1150,9 @@ def edge_type_subgraph(graph, etypes, output_device=None):
""" """
etype_ids = [graph.get_etype_id(etype) for etype in etypes] etype_ids = [graph.get_etype_id(etype) for etype in etypes]
# meta graph is homogeneous graph, still using int64 # meta graph is homogeneous graph, still using int64
meta_src, meta_dst, _ = graph._graph.metagraph.find_edges(utils.toindex(etype_ids, "int64")) meta_src, meta_dst, _ = graph._graph.metagraph.find_edges(
utils.toindex(etype_ids, "int64")
)
rel_graphs = [graph._graph.get_relation_graph(i) for i in etype_ids] rel_graphs = [graph._graph.get_relation_graph(i) for i in etype_ids]
meta_src = meta_src.tonumpy() meta_src = meta_src.tonumpy()
meta_dst = meta_dst.tonumpy() meta_dst = meta_dst.tonumpy()
...@@ -1060,22 +1162,40 @@ def edge_type_subgraph(graph, etypes, output_device=None): ...@@ -1060,22 +1162,40 @@ def edge_type_subgraph(graph, etypes, output_device=None):
node_frames = [graph._node_frames[i] for i in ntypes_invmap] node_frames = [graph._node_frames[i] for i in ntypes_invmap]
edge_frames = [graph._edge_frames[i] for i in etype_ids] edge_frames = [graph._edge_frames[i] for i in etype_ids]
induced_ntypes = [graph._ntypes[i] for i in ntypes_invmap] induced_ntypes = [graph._ntypes[i] for i in ntypes_invmap]
induced_etypes = [graph._etypes[i] for i in etype_ids] # get the "name" of edge type induced_etypes = [
num_nodes_per_induced_type = [graph.number_of_nodes(ntype) for ntype in induced_ntypes] graph._etypes[i] for i in etype_ids
] # get the "name" of edge type
metagraph = graph_index.from_edge_list((mapped_meta_src, mapped_meta_dst), True) num_nodes_per_induced_type = [
graph.number_of_nodes(ntype) for ntype in induced_ntypes
]
metagraph = graph_index.from_edge_list(
(mapped_meta_src, mapped_meta_dst), True
)
# num_nodes_per_type should be int64 # num_nodes_per_type should be int64
hgidx = heterograph_index.create_heterograph_from_relations( hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, rel_graphs, utils.toindex(num_nodes_per_induced_type, "int64")) metagraph,
hg = DGLHeteroGraph(hgidx, induced_ntypes, induced_etypes, node_frames, edge_frames) rel_graphs,
utils.toindex(num_nodes_per_induced_type, "int64"),
)
hg = DGLHeteroGraph(
hgidx, induced_ntypes, induced_etypes, node_frames, edge_frames
)
return hg if output_device is None else hg.to(output_device) return hg if output_device is None else hg.to(output_device)
DGLHeteroGraph.edge_type_subgraph = utils.alias_func(edge_type_subgraph) DGLHeteroGraph.edge_type_subgraph = utils.alias_func(edge_type_subgraph)
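
# A minimal usage sketch of type-induced subgraphs on a small heterograph, assuming only
# the public dgl.heterograph / dgl.edge_type_subgraph / dgl.node_type_subgraph API; node
# counts, edges, and the demo helper are made up.
def _type_subgraph_demo():
    import torch
    import dgl

    g = dgl.heterograph({
        ("user", "follows", "user"): (torch.tensor([0, 1]), torch.tensor([1, 2])),
        ("user", "plays", "game"): (torch.tensor([0, 2]), torch.tensor([0, 1])),
    })
    fg = dgl.edge_type_subgraph(g, ["follows"])  # keeps only user->user edges
    ug = dgl.node_type_subgraph(g, ["user"])  # keeps all relations among user nodes
    return fg.canonical_etypes, ug.canonical_etypes  # both: [('user', 'follows', 'user')]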
#################### Internal functions #################### #################### Internal functions ####################
def _create_hetero_subgraph(parent, sgi, induced_nodes_or_device, induced_edges_or_device,
store_ids=True): def _create_hetero_subgraph(
parent,
sgi,
induced_nodes_or_device,
induced_edges_or_device,
store_ids=True,
):
"""Internal function to create a subgraph. """Internal function to create a subgraph.
Parameters Parameters
...@@ -1107,10 +1227,15 @@ def _create_hetero_subgraph(parent, sgi, induced_nodes_or_device, induced_edges_ ...@@ -1107,10 +1227,15 @@ def _create_hetero_subgraph(parent, sgi, induced_nodes_or_device, induced_edges_
# UVA subgraphing, where the node features are not sliced but the device changed. # UVA subgraphing, where the node features are not sliced but the device changed.
# Not having this will give us a subgraph on GPU but node features on CPU if we don't # Not having this will give us a subgraph on GPU but node features on CPU if we don't
# relabel the nodes. # relabel the nodes.
node_frames = utils.extract_node_subframes(parent, induced_nodes_or_device, store_ids) node_frames = utils.extract_node_subframes(
edge_frames = utils.extract_edge_subframes(parent, induced_edges_or_device, store_ids) parent, induced_nodes_or_device, store_ids
)
edge_frames = utils.extract_edge_subframes(
parent, induced_edges_or_device, store_ids
)
hsg = DGLHeteroGraph(sgi.graph, parent.ntypes, parent.etypes) hsg = DGLHeteroGraph(sgi.graph, parent.ntypes, parent.etypes)
utils.set_new_frames(hsg, node_frames=node_frames, edge_frames=edge_frames) utils.set_new_frames(hsg, node_frames=node_frames, edge_frames=edge_frames)
return hsg return hsg
_init_api("dgl.subgraph") _init_api("dgl.subgraph")
"""Module for graph traversal methods.""" """Module for graph traversal methods."""
from __future__ import absolute_import from __future__ import absolute_import
from ._ffi.function import _init_api
from . import backend as F from . import backend as F
from . import utils from . import utils
from ._ffi.function import _init_api
from .heterograph import DGLHeteroGraph from .heterograph import DGLHeteroGraph
__all__ = ['bfs_nodes_generator', 'bfs_edges_generator', __all__ = [
'topological_nodes_generator', "bfs_nodes_generator",
'dfs_edges_generator', 'dfs_labeled_edges_generator',] "bfs_edges_generator",
"topological_nodes_generator",
"dfs_edges_generator",
"dfs_labeled_edges_generator",
]
def bfs_nodes_generator(graph, source, reverse=False): def bfs_nodes_generator(graph, source, reverse=False):
"""Node frontiers generator using breadth-first search. """Node frontiers generator using breadth-first search.
...@@ -40,10 +45,12 @@ def bfs_nodes_generator(graph, source, reverse=False): ...@@ -40,10 +45,12 @@ def bfs_nodes_generator(graph, source, reverse=False):
>>> list(dgl.bfs_nodes_generator(g, 0)) >>> list(dgl.bfs_nodes_generator(g, 0))
[tensor([0]), tensor([1]), tensor([2, 3]), tensor([4, 5])] [tensor([0]), tensor([1]), tensor([2, 3]), tensor([4, 5])]
""" """
    assert isinstance(
        graph, DGLHeteroGraph
    ), "DGLGraph is deprecated. Please use DGLHeteroGraph."
    assert (
        len(graph.canonical_etypes) == 1
    ), "bfs_nodes_generator only supports homogeneous graphs"
# Workaround before support for GPU graph # Workaround before support for GPU graph
gidx = graph._graph.copy_to(utils.to_dgl_context(F.cpu())) gidx = graph._graph.copy_to(utils.to_dgl_context(F.cpu()))
source = utils.toindex(source, dtype=graph._idtype_str) source = utils.toindex(source, dtype=graph._idtype_str)
...@@ -54,6 +61,7 @@ def bfs_nodes_generator(graph, source, reverse=False): ...@@ -54,6 +61,7 @@ def bfs_nodes_generator(graph, source, reverse=False):
node_frontiers = F.split(all_nodes, sections, dim=0) node_frontiers = F.split(all_nodes, sections, dim=0)
return node_frontiers return node_frontiers
def bfs_edges_generator(graph, source, reverse=False): def bfs_edges_generator(graph, source, reverse=False):
"""Edges frontiers generator using breadth-first search. """Edges frontiers generator using breadth-first search.
...@@ -85,10 +93,12 @@ def bfs_edges_generator(graph, source, reverse=False): ...@@ -85,10 +93,12 @@ def bfs_edges_generator(graph, source, reverse=False):
>>> list(dgl.bfs_edges_generator(g, 0)) >>> list(dgl.bfs_edges_generator(g, 0))
[tensor([0]), tensor([1, 2]), tensor([4, 5])] [tensor([0]), tensor([1, 2]), tensor([4, 5])]
""" """
    assert isinstance(
        graph, DGLHeteroGraph
    ), "DGLGraph is deprecated. Please use DGLHeteroGraph."
    assert (
        len(graph.canonical_etypes) == 1
    ), "bfs_edges_generator only supports homogeneous graphs"
# Workaround before support for GPU graph # Workaround before support for GPU graph
gidx = graph._graph.copy_to(utils.to_dgl_context(F.cpu())) gidx = graph._graph.copy_to(utils.to_dgl_context(F.cpu()))
source = utils.toindex(source, dtype=graph._idtype_str) source = utils.toindex(source, dtype=graph._idtype_str)
...@@ -99,6 +109,7 @@ def bfs_edges_generator(graph, source, reverse=False): ...@@ -99,6 +109,7 @@ def bfs_edges_generator(graph, source, reverse=False):
edge_frontiers = F.split(all_edges, sections, dim=0) edge_frontiers = F.split(all_edges, sections, dim=0)
return edge_frontiers return edge_frontiers
def topological_nodes_generator(graph, reverse=False): def topological_nodes_generator(graph, reverse=False):
"""Node frontiers generator using topological traversal. """Node frontiers generator using topological traversal.
...@@ -127,10 +138,12 @@ def topological_nodes_generator(graph, reverse=False): ...@@ -127,10 +138,12 @@ def topological_nodes_generator(graph, reverse=False):
>>> list(dgl.topological_nodes_generator(g)) >>> list(dgl.topological_nodes_generator(g))
[tensor([0]), tensor([1]), tensor([2]), tensor([3, 4]), tensor([5])] [tensor([0]), tensor([1]), tensor([2]), tensor([3, 4]), tensor([5])]
""" """
    assert isinstance(
        graph, DGLHeteroGraph
    ), "DGLGraph is deprecated. Please use DGLHeteroGraph."
    assert (
        len(graph.canonical_etypes) == 1
    ), "topological_nodes_generator only supports homogeneous graphs"
# Workaround before support for GPU graph # Workaround before support for GPU graph
gidx = graph._graph.copy_to(utils.to_dgl_context(F.cpu())) gidx = graph._graph.copy_to(utils.to_dgl_context(F.cpu()))
ret = _CAPI_DGLTopologicalNodes_v2(gidx, reverse) ret = _CAPI_DGLTopologicalNodes_v2(gidx, reverse)
...@@ -139,6 +152,7 @@ def topological_nodes_generator(graph, reverse=False): ...@@ -139,6 +152,7 @@ def topological_nodes_generator(graph, reverse=False):
sections = utils.toindex(ret(1)).tonumpy().tolist() sections = utils.toindex(ret(1)).tonumpy().tolist()
return F.split(all_nodes, sections, dim=0) return F.split(all_nodes, sections, dim=0)
def dfs_edges_generator(graph, source, reverse=False): def dfs_edges_generator(graph, source, reverse=False):
"""Edge frontiers generator using depth-first-search (DFS). """Edge frontiers generator using depth-first-search (DFS).
...@@ -176,10 +190,12 @@ def dfs_edges_generator(graph, source, reverse=False): ...@@ -176,10 +190,12 @@ def dfs_edges_generator(graph, source, reverse=False):
>>> list(dgl.dfs_edges_generator(g, 0)) >>> list(dgl.dfs_edges_generator(g, 0))
[tensor([0]), tensor([1]), tensor([3]), tensor([5]), tensor([4])] [tensor([0]), tensor([1]), tensor([3]), tensor([5]), tensor([4])]
""" """
    assert isinstance(
        graph, DGLHeteroGraph
    ), "DGLGraph is deprecated. Please use DGLHeteroGraph."
    assert (
        len(graph.canonical_etypes) == 1
    ), "dfs_edges_generator only supports homogeneous graphs"
# Workaround before support for GPU graph # Workaround before support for GPU graph
gidx = graph._graph.copy_to(utils.to_dgl_context(F.cpu())) gidx = graph._graph.copy_to(utils.to_dgl_context(F.cpu()))
source = utils.toindex(source, dtype=graph._idtype_str) source = utils.toindex(source, dtype=graph._idtype_str)
...@@ -189,13 +205,15 @@ def dfs_edges_generator(graph, source, reverse=False): ...@@ -189,13 +205,15 @@ def dfs_edges_generator(graph, source, reverse=False):
sections = utils.toindex(ret(1)).tonumpy().tolist() sections = utils.toindex(ret(1)).tonumpy().tolist()
return F.split(all_edges, sections, dim=0) return F.split(all_edges, sections, dim=0)
def dfs_labeled_edges_generator( def dfs_labeled_edges_generator(
graph, graph,
source, source,
reverse=False, reverse=False,
has_reverse_edge=False, has_reverse_edge=False,
has_nontree_edge=False, has_nontree_edge=False,
return_labels=True): return_labels=True,
):
"""Produce edges in a depth-first-search (DFS) labeled by type. """Produce edges in a depth-first-search (DFS) labeled by type.
There are three labels: FORWARD(0), REVERSE(1), NONTREE(2) There are three labels: FORWARD(0), REVERSE(1), NONTREE(2)
...@@ -252,10 +270,12 @@ def dfs_labeled_edges_generator( ...@@ -252,10 +270,12 @@ def dfs_labeled_edges_generator(
(tensor([0]), tensor([1]), tensor([3]), tensor([5]), tensor([4]), tensor([2])), (tensor([0]), tensor([1]), tensor([3]), tensor([5]), tensor([4]), tensor([2])),
(tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([2])) (tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([2]))
""" """
    assert isinstance(
        graph, DGLHeteroGraph
    ), "DGLGraph is deprecated. Please use DGLHeteroGraph."
    assert (
        len(graph.canonical_etypes) == 1
    ), "dfs_labeled_edges_generator only supports homogeneous graphs"
# Workaround before support for GPU graph # Workaround before support for GPU graph
gidx = graph._graph.copy_to(utils.to_dgl_context(F.cpu())) gidx = graph._graph.copy_to(utils.to_dgl_context(F.cpu()))
source = utils.toindex(source, dtype=graph._idtype_str) source = utils.toindex(source, dtype=graph._idtype_str)
...@@ -265,16 +285,20 @@ def dfs_labeled_edges_generator( ...@@ -265,16 +285,20 @@ def dfs_labeled_edges_generator(
reverse, reverse,
has_reverse_edge, has_reverse_edge,
has_nontree_edge, has_nontree_edge,
return_labels) return_labels,
)
all_edges = utils.toindex(ret(0), dtype=graph._idtype_str).tousertensor() all_edges = utils.toindex(ret(0), dtype=graph._idtype_str).tousertensor()
# TODO(minjie): how to support directly creating python list # TODO(minjie): how to support directly creating python list
if return_labels: if return_labels:
all_labels = utils.toindex(ret(1)).tousertensor() all_labels = utils.toindex(ret(1)).tousertensor()
sections = utils.toindex(ret(2)).tonumpy().tolist() sections = utils.toindex(ret(2)).tonumpy().tolist()
return (F.split(all_edges, sections, dim=0), return (
F.split(all_labels, sections, dim=0)) F.split(all_edges, sections, dim=0),
F.split(all_labels, sections, dim=0),
)
else: else:
sections = utils.toindex(ret(1)).tonumpy().tolist() sections = utils.toindex(ret(1)).tonumpy().tolist()
return F.split(all_edges, sections, dim=0) return F.split(all_edges, sections, dim=0)
_init_api("dgl.traversal") _init_api("dgl.traversal")
"""Internal utilities.""" """Internal utilities."""
from .internal import *
from .data import *
from .checks import * from .checks import *
from .shared_mem import * from .data import *
from .filter import *
from .exception import * from .exception import *
from .filter import *
from .internal import *
from .pin_memory import * from .pin_memory import *
from .shared_mem import *
"""Checking and logging utilities.""" """Checking and logging utilities."""
# pylint: disable=invalid-name # pylint: disable=invalid-name
from __future__ import absolute_import, division from __future__ import absolute_import, division
from collections.abc import Mapping from collections.abc import Mapping
from ..base import DGLError
from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
from .._ffi.function import _init_api
from ..base import DGLError
def prepare_tensor(g, data, name): def prepare_tensor(g, data, name):
"""Convert the data to ID tensor and check its ID type and context. """Convert the data to ID tensor and check its ID type and context.
...@@ -31,27 +33,43 @@ def prepare_tensor(g, data, name): ...@@ -31,27 +33,43 @@ def prepare_tensor(g, data, name):
""" """
if F.is_tensor(data): if F.is_tensor(data):
if F.dtype(data) != g.idtype: if F.dtype(data) != g.idtype:
raise DGLError(f'Expect argument "{name}" to have data type {g.idtype}. ' raise DGLError(
f'But got {F.dtype(data)}.') f'Expect argument "{name}" to have data type {g.idtype}. '
f"But got {F.dtype(data)}."
)
if F.context(data) != g.device and not g.is_pinned(): if F.context(data) != g.device and not g.is_pinned():
raise DGLError(f'Expect argument "{name}" to have device {g.device}. ' raise DGLError(
f'But got {F.context(data)}.') f'Expect argument "{name}" to have device {g.device}. '
f"But got {F.context(data)}."
)
ret = data ret = data
else: else:
data = F.tensor(data) data = F.tensor(data)
if (not (F.ndim(data) > 0 and F.shape(data)[0] == 0) and # empty tensor if not (
F.dtype(data) not in (F.int32, F.int64)): F.ndim(data) > 0 and F.shape(data)[0] == 0
raise DGLError('Expect argument "{}" to have data type int32 or int64,' ) and F.dtype( # empty tensor
' but got {}.'.format(name, F.dtype(data))) data
) not in (
F.int32,
F.int64,
):
raise DGLError(
'Expect argument "{}" to have data type int32 or int64,'
" but got {}.".format(name, F.dtype(data))
)
ret = F.copy_to(F.astype(data, g.idtype), g.device) ret = F.copy_to(F.astype(data, g.idtype), g.device)
if F.ndim(ret) == 0: if F.ndim(ret) == 0:
ret = F.unsqueeze(ret, 0) ret = F.unsqueeze(ret, 0)
if F.ndim(ret) > 1: if F.ndim(ret) > 1:
raise DGLError('Expect a 1-D tensor for argument "{}". But got {}.'.format( raise DGLError(
name, ret)) 'Expect a 1-D tensor for argument "{}". But got {}.'.format(
name, ret
)
)
return ret return ret
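
# A minimal sketch of what prepare_tensor() does with a plain Python list. It is an
# internal helper, so the sketch only assumes the public dgl.graph constructor plus the
# function defined above; the demo helper is made up.
def _prepare_tensor_demo():
    import torch
    import dgl

    g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))  # idtype is int64
    ids = prepare_tensor(g, [0, 2], "nodes")
    # Returned as a 1-D int64 tensor on g.device; float inputs raise DGLError.
    return ids.dtype, ids.device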
def prepare_tensor_dict(g, data, name): def prepare_tensor_dict(g, data, name):
"""Convert a dictionary of data to a dictionary of ID tensors. """Convert a dictionary of data to a dictionary of ID tensors.
...@@ -70,8 +88,11 @@ def prepare_tensor_dict(g, data, name): ...@@ -70,8 +88,11 @@ def prepare_tensor_dict(g, data, name):
------- -------
dict[str, tensor] dict[str, tensor]
""" """
return {key : prepare_tensor(g, val, '{}["{}"]'.format(name, key)) return {
for key, val in data.items()} key: prepare_tensor(g, val, '{}["{}"]'.format(name, key))
for key, val in data.items()
}
def prepare_tensor_or_dict(g, data, name): def prepare_tensor_or_dict(g, data, name):
"""Convert data to either a tensor or a dictionary depending on input type. """Convert data to either a tensor or a dictionary depending on input type.
...@@ -89,10 +110,14 @@ def prepare_tensor_or_dict(g, data, name): ...@@ -89,10 +110,14 @@ def prepare_tensor_or_dict(g, data, name):
------- -------
tensor or dict[str, tensor] tensor or dict[str, tensor]
""" """
return prepare_tensor_dict(g, data, name) if isinstance(data, Mapping) \ return (
else prepare_tensor(g, data, name) prepare_tensor_dict(g, data, name)
if isinstance(data, Mapping)
else prepare_tensor(g, data, name)
)
def parse_edges_arg_to_eid(g, edges, etid, argname='edges'):
def parse_edges_arg_to_eid(g, edges, etid, argname="edges"):
"""Parse the :attr:`edges` argument and return an edge ID tensor. """Parse the :attr:`edges` argument and return an edge ID tensor.
The resulting edge ID tensor has the same ID type and device of :attr:`g`. The resulting edge ID tensor has the same ID type and device of :attr:`g`.
...@@ -115,13 +140,14 @@ def parse_edges_arg_to_eid(g, edges, etid, argname='edges'): ...@@ -115,13 +140,14 @@ def parse_edges_arg_to_eid(g, edges, etid, argname='edges'):
""" """
if isinstance(edges, tuple): if isinstance(edges, tuple):
u, v = edges u, v = edges
u = prepare_tensor(g, u, '{}[0]'.format(argname)) u = prepare_tensor(g, u, "{}[0]".format(argname))
v = prepare_tensor(g, v, '{}[1]'.format(argname)) v = prepare_tensor(g, v, "{}[1]".format(argname))
eid = g.edge_ids(u, v, etype=g.canonical_etypes[etid]) eid = g.edge_ids(u, v, etype=g.canonical_etypes[etid])
else: else:
eid = prepare_tensor(g, edges, argname) eid = prepare_tensor(g, edges, argname)
return eid return eid
def check_all_same_idtype(glist, name): def check_all_same_idtype(glist, name):
"""Check all the graphs have the same idtype.""" """Check all the graphs have the same idtype."""
if len(glist) == 0: if len(glist) == 0:
...@@ -129,8 +155,12 @@ def check_all_same_idtype(glist, name): ...@@ -129,8 +155,12 @@ def check_all_same_idtype(glist, name):
idtype = glist[0].idtype idtype = glist[0].idtype
for i, g in enumerate(glist): for i, g in enumerate(glist):
if g.idtype != idtype: if g.idtype != idtype:
raise DGLError('Expect {}[{}] to have {} type ID, but got {}.'.format( raise DGLError(
name, i, idtype, g.idtype)) "Expect {}[{}] to have {} type ID, but got {}.".format(
name, i, idtype, g.idtype
)
)
def check_device(data, device): def check_device(data, device):
"""Check if data is on the target device. """Check if data is on the target device.
...@@ -152,6 +182,7 @@ def check_device(data, device): ...@@ -152,6 +182,7 @@ def check_device(data, device):
return False return False
return True return True
def check_all_same_device(glist, name): def check_all_same_device(glist, name):
"""Check all the graphs have the same device.""" """Check all the graphs have the same device."""
if len(glist) == 0: if len(glist) == 0:
...@@ -159,8 +190,12 @@ def check_all_same_device(glist, name): ...@@ -159,8 +190,12 @@ def check_all_same_device(glist, name):
device = glist[0].device device = glist[0].device
for i, g in enumerate(glist): for i, g in enumerate(glist):
if g.device != device: if g.device != device:
raise DGLError('Expect {}[{}] to be on device {}, but got {}.'.format( raise DGLError(
name, i, device, g.device)) "Expect {}[{}] to be on device {}, but got {}.".format(
name, i, device, g.device
)
)
def check_all_same_schema(schemas, name): def check_all_same_schema(schemas, name):
"""Check the list of schemas are the same.""" """Check the list of schemas are the same."""
...@@ -170,9 +205,12 @@ def check_all_same_schema(schemas, name): ...@@ -170,9 +205,12 @@ def check_all_same_schema(schemas, name):
for i, schema in enumerate(schemas): for i, schema in enumerate(schemas):
if schema != schemas[0]: if schema != schemas[0]:
raise DGLError( raise DGLError(
'Expect all graphs to have the same schema on {}, ' "Expect all graphs to have the same schema on {}, "
'but graph {} got\n\t{}\nwhich is different from\n\t{}.'.format( "but graph {} got\n\t{}\nwhich is different from\n\t{}.".format(
name, i, schema, schemas[0])) name, i, schema, schemas[0]
)
)
def check_all_same_schema_for_keys(schemas, keys, name): def check_all_same_schema_for_keys(schemas, keys, name):
"""Check the list of schemas are the same on the given keys.""" """Check the list of schemas are the same on the given keys."""
...@@ -184,9 +222,9 @@ def check_all_same_schema_for_keys(schemas, keys, name): ...@@ -184,9 +222,9 @@ def check_all_same_schema_for_keys(schemas, keys, name):
for i, schema in enumerate(schemas): for i, schema in enumerate(schemas):
if not keys.issubset(schema.keys()): if not keys.issubset(schema.keys()):
raise DGLError( raise DGLError(
'Expect all graphs to have keys {} on {}, ' "Expect all graphs to have keys {} on {}, "
'but graph {} got keys {}.'.format( "but graph {} got keys {}.".format(keys, name, i, schema.keys())
keys, name, i, schema.keys())) )
if head is None: if head is None:
head = {k: schema[k] for k in keys} head = {k: schema[k] for k in keys}
...@@ -194,9 +232,12 @@ def check_all_same_schema_for_keys(schemas, keys, name): ...@@ -194,9 +232,12 @@ def check_all_same_schema_for_keys(schemas, keys, name):
target = {k: schema[k] for k in keys} target = {k: schema[k] for k in keys}
if target != head: if target != head:
raise DGLError( raise DGLError(
'Expect all graphs to have the same schema for keys {} on {}, ' "Expect all graphs to have the same schema for keys {} on {}, "
'but graph {} got \n\t{}\n which is different from\n\t{}.'.format( "but graph {} got \n\t{}\n which is different from\n\t{}.".format(
keys, name, i, target, head)) keys, name, i, target, head
)
)
def check_valid_idtype(idtype): def check_valid_idtype(idtype):
"""Check whether the value of the idtype argument is valid (int32/int64) """Check whether the value of the idtype argument is valid (int32/int64)
...@@ -207,8 +248,11 @@ def check_valid_idtype(idtype): ...@@ -207,8 +248,11 @@ def check_valid_idtype(idtype):
The framework object of a data type. The framework object of a data type.
""" """
if idtype not in [None, F.int32, F.int64]: if idtype not in [None, F.int32, F.int64]:
raise DGLError('Expect idtype to be a framework object of int32/int64, ' raise DGLError(
'got {}'.format(idtype)) "Expect idtype to be a framework object of int32/int64, "
"got {}".format(idtype)
)
def is_sorted_srcdst(src, dst, num_src=None, num_dst=None): def is_sorted_srcdst(src, dst, num_src=None, num_dst=None):
"""Checks whether an edge list is in ascending src-major order (e.g., first """Checks whether an edge list is in ascending src-major order (e.g., first
...@@ -234,9 +278,9 @@ def is_sorted_srcdst(src, dst, num_src=None, num_dst=None): ...@@ -234,9 +278,9 @@ def is_sorted_srcdst(src, dst, num_src=None, num_dst=None):
# for some versions of MXNET and TensorFlow, num_src and num_dst get # for some versions of MXNET and TensorFlow, num_src and num_dst get
# incorrectly marked as floats, so force them as integers here # incorrectly marked as floats, so force them as integers here
if num_src is None: if num_src is None:
num_src = int(F.as_scalar(F.max(src, dim=0)+1)) num_src = int(F.as_scalar(F.max(src, dim=0) + 1))
if num_dst is None: if num_dst is None:
num_dst = int(F.as_scalar(F.max(dst, dim=0)+1)) num_dst = int(F.as_scalar(F.max(dst, dim=0) + 1))
src = F.zerocopy_to_dgl_ndarray(src) src = F.zerocopy_to_dgl_ndarray(src)
dst = F.zerocopy_to_dgl_ndarray(dst) dst = F.zerocopy_to_dgl_ndarray(dst)
...@@ -247,4 +291,5 @@ def is_sorted_srcdst(src, dst, num_src=None, num_dst=None): ...@@ -247,4 +291,5 @@ def is_sorted_srcdst(src, dst, num_src=None, num_dst=None):
return row_sorted, col_sorted return row_sorted, col_sorted
_init_api("dgl.utils.checks") _init_api("dgl.utils.checks")
"""Data utilities.""" """Data utilities."""
from collections import namedtuple from collections import namedtuple
import scipy as sp
import networkx as nx import networkx as nx
import scipy as sp
from ..base import DGLError
from .. import backend as F from .. import backend as F
from ..base import DGLError
from . import checks from . import checks
def elist2tensor(elist, idtype): def elist2tensor(elist, idtype):
"""Function to convert an edge list to edge tensors. """Function to convert an edge list to edge tensors.
...@@ -31,6 +33,7 @@ def elist2tensor(elist, idtype): ...@@ -31,6 +33,7 @@ def elist2tensor(elist, idtype):
v = list(v) v = list(v)
return F.tensor(u, idtype), F.tensor(v, idtype) return F.tensor(u, idtype), F.tensor(v, idtype)
def scipy2tensor(spmat, idtype): def scipy2tensor(spmat, idtype):
"""Function to convert a scipy matrix to a sparse adjacency matrix tuple. """Function to convert a scipy matrix to a sparse adjacency matrix tuple.
...@@ -49,7 +52,7 @@ def scipy2tensor(spmat, idtype): ...@@ -49,7 +52,7 @@ def scipy2tensor(spmat, idtype):
A tuple containing the format as well as the list of tensors representing A tuple containing the format as well as the list of tensors representing
the sparse matrix. the sparse matrix.
""" """
if spmat.format in ['csr', 'csc']: if spmat.format in ["csr", "csc"]:
indptr = F.tensor(spmat.indptr, idtype) indptr = F.tensor(spmat.indptr, idtype)
indices = F.tensor(spmat.indices, idtype) indices = F.tensor(spmat.indices, idtype)
data = F.tensor([], idtype) data = F.tensor([], idtype)
...@@ -58,7 +61,8 @@ def scipy2tensor(spmat, idtype): ...@@ -58,7 +61,8 @@ def scipy2tensor(spmat, idtype):
spmat = spmat.tocoo() spmat = spmat.tocoo()
row = F.tensor(spmat.row, idtype) row = F.tensor(spmat.row, idtype)
col = F.tensor(spmat.col, idtype) col = F.tensor(spmat.col, idtype)
return SparseAdjTuple('coo', (row, col)) return SparseAdjTuple("coo", (row, col))
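# Illustrative sketch (not part of the original commit): how a SciPy sparse
# matrix maps onto the ('csr'/'csc', ...) and ('coo', ...) tuples produced by
# scipy2tensor. Plain NumPy arrays stand in for backend tensors here.
import numpy as np
import scipy.sparse as sp

mat = sp.csr_matrix(np.array([[0, 1], [1, 0]]))
if mat.format in ("csr", "csc"):
    # CSR/CSC keep (indptr, indices, data); edge IDs are left empty.
    adj = (mat.format, (mat.indptr, mat.indices, np.array([], dtype=np.int64)))
else:
    coo = mat.tocoo()
    adj = ("coo", (coo.row, coo.col))
print(adj[0])  # 'csr'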
def networkx2tensor(nx_graph, idtype, edge_id_attr_name=None): def networkx2tensor(nx_graph, idtype, edge_id_attr_name=None):
"""Function to convert a networkx graph to edge tensors. """Function to convert a networkx graph to edge tensors.
...@@ -82,7 +86,7 @@ def networkx2tensor(nx_graph, idtype, edge_id_attr_name=None): ...@@ -82,7 +86,7 @@ def networkx2tensor(nx_graph, idtype, edge_id_attr_name=None):
nx_graph = nx_graph.to_directed() nx_graph = nx_graph.to_directed()
# Relabel nodes using consecutive integers # Relabel nodes using consecutive integers
nx_graph = nx.convert_node_labels_to_integers(nx_graph, ordering='sorted') nx_graph = nx.convert_node_labels_to_integers(nx_graph, ordering="sorted")
has_edge_id = edge_id_attr_name is not None has_edge_id = edge_id_attr_name is not None
if has_edge_id: if has_edge_id:
...@@ -92,8 +96,10 @@ def networkx2tensor(nx_graph, idtype, edge_id_attr_name=None): ...@@ -92,8 +96,10 @@ def networkx2tensor(nx_graph, idtype, edge_id_attr_name=None):
for u, v, attr in nx_graph.edges(data=True): for u, v, attr in nx_graph.edges(data=True):
eid = int(attr[edge_id_attr_name]) eid = int(attr[edge_id_attr_name])
if eid < 0 or eid >= nx_graph.number_of_edges(): if eid < 0 or eid >= nx_graph.number_of_edges():
raise DGLError(
    "Expect edge IDs to be a non-negative integer smaller than {:d}, "
    "got {:d}".format(num_edges, eid)
)
src[eid] = u src[eid] = u
dst[eid] = v dst[eid] = v
else: else:
...@@ -106,7 +112,9 @@ def networkx2tensor(nx_graph, idtype, edge_id_attr_name=None): ...@@ -106,7 +112,9 @@ def networkx2tensor(nx_graph, idtype, edge_id_attr_name=None):
dst = F.tensor(dst, idtype) dst = F.tensor(dst, idtype)
return src, dst return src, dst
SparseAdjTuple = namedtuple("SparseAdjTuple", ["format", "arrays"])
def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs): def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs):
"""Function to convert various types of data to edge tensors and infer """Function to convert various types of data to edge tensors and infer
...@@ -151,33 +159,42 @@ def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs): ...@@ -151,33 +159,42 @@ def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs):
if isinstance(data, tuple): if isinstance(data, tuple):
if not isinstance(data[0], str): if not isinstance(data[0], str):
# (row, col) format, convert to ('coo', (row, col)) # (row, col) format, convert to ('coo', (row, col))
data = ('coo', data) data = ("coo", data)
data = SparseAdjTuple(*data) data = SparseAdjTuple(*data)
if idtype is None and not (
    isinstance(data, SparseAdjTuple) and F.is_tensor(data.arrays[0])
):
    # preferred default idtype is int64
    # if data is tensor and idtype is None, infer the idtype from tensor
    idtype = F.int64
checks.check_valid_idtype(idtype)
if isinstance(data, SparseAdjTuple) and (
    not all(F.is_tensor(a) for a in data.arrays)
):
    # (Iterable, Iterable) type data, convert it to (Tensor, Tensor)
    if len(data.arrays[0]) == 0:
        # force idtype for empty list
        data = SparseAdjTuple(
            data.format, tuple(F.tensor(a, idtype) for a in data.arrays)
        )
    else:
        # convert the iterable to tensor and keep its native data type so we can check
        # its validity later
        data = SparseAdjTuple(
            data.format, tuple(F.tensor(a) for a in data.arrays)
        )
if isinstance(data, SparseAdjTuple):
    if idtype is not None:
        data = SparseAdjTuple(
            data.format, tuple(F.astype(a, idtype) for a in data.arrays)
        )
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite) num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)
elif isinstance(data, list): elif isinstance(data, list):
src, dst = elist2tensor(data, idtype) src, dst = elist2tensor(data, idtype)
data = SparseAdjTuple('coo', (src, dst)) data = SparseAdjTuple("coo", (src, dst))
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite) num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)
elif isinstance(data, sp.sparse.spmatrix): elif isinstance(data, sp.sparse.spmatrix):
# We can get scipy matrix's number of rows and columns easily. # We can get scipy matrix's number of rows and columns easily.
...@@ -186,23 +203,31 @@ def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs): ...@@ -186,23 +203,31 @@ def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs):
elif isinstance(data, nx.Graph): elif isinstance(data, nx.Graph):
# We can get networkx graph's number of sources and destinations easily. # We can get networkx graph's number of sources and destinations easily.
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite) num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)
edge_id_attr_name = kwargs.get("edge_id_attr_name", None)
if bipartite:
    top_map = kwargs.get("top_map")
    bottom_map = kwargs.get("bottom_map")
    src, dst = networkxbipartite2tensors(
        data,
        idtype,
        top_map=top_map,
        bottom_map=bottom_map,
        edge_id_attr_name=edge_id_attr_name,
    )
else:
    src, dst = networkx2tensor(
        data, idtype, edge_id_attr_name=edge_id_attr_name
    )
data = SparseAdjTuple("coo", (src, dst))
else: else:
raise DGLError('Unsupported graph data type:', type(data)) raise DGLError("Unsupported graph data type:", type(data))
return data, num_src, num_dst return data, num_src, num_dst
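# Illustrative sketch (not part of the original commit): the normalization that
# graphdata2tensors performs, reduced to plain Python. Every supported input is
# funneled into a ('coo', (src, dst)) pair plus inferred node counts; the helper
# below is hypothetical and only mirrors that flow for two of the accepted
# input types (an edge list and a SciPy sparse matrix).
import numpy as np
import scipy.sparse as sp

def _normalize(data):
    if isinstance(data, list):                      # [(u, v), ...] edge list
        src, dst = zip(*data) if data else ((), ())
        src, dst = np.asarray(src), np.asarray(dst)
    elif isinstance(data, sp.spmatrix):             # SciPy sparse adjacency
        coo = data.tocoo()
        src, dst = coo.row, coo.col
    else:
        raise TypeError("unsupported graph data type: %s" % type(data))
    num_src = int(src.max()) + 1 if len(src) else 0
    num_dst = int(dst.max()) + 1 if len(dst) else 0
    return ("coo", (src, dst)), num_src, num_dst

print(_normalize([(0, 1), (1, 2)]))  # (('coo', (array([0, 1]), array([1, 2]))), 2, 3)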
def networkxbipartite2tensors(
    nx_graph, idtype, top_map, bottom_map, edge_id_attr_name=None
):
"""Function to convert a networkx bipartite to edge tensors. """Function to convert a networkx bipartite to edge tensors.
Parameters Parameters
...@@ -234,15 +259,21 @@ def networkxbipartite2tensors(nx_graph, idtype, top_map, bottom_map, edge_id_att ...@@ -234,15 +259,21 @@ def networkxbipartite2tensors(nx_graph, idtype, top_map, bottom_map, edge_id_att
dst = [0] * num_edges dst = [0] * num_edges
for u, v, attr in nx_graph.edges(data=True): for u, v, attr in nx_graph.edges(data=True):
if u not in top_map:
    raise DGLError(
        "Expect the node {} to have attribute bipartite=0 "
        "with edge {}".format(u, (u, v))
    )
if v not in bottom_map:
    raise DGLError(
        "Expect the node {} to have attribute bipartite=1 "
        "with edge {}".format(v, (u, v))
    )
eid = int(attr[edge_id_attr_name])
if eid < 0 or eid >= nx_graph.number_of_edges():
    raise DGLError(
        "Expect edge IDs to be a non-negative integer smaller than {:d}, "
        "got {:d}".format(num_edges, eid)
    )
src[eid] = top_map[u] src[eid] = top_map[u]
dst[eid] = bottom_map[v] dst[eid] = bottom_map[v]
else: else:
...@@ -251,17 +282,22 @@ def networkxbipartite2tensors(nx_graph, idtype, top_map, bottom_map, edge_id_att ...@@ -251,17 +282,22 @@ def networkxbipartite2tensors(nx_graph, idtype, top_map, bottom_map, edge_id_att
for e in nx_graph.edges: for e in nx_graph.edges:
u, v = e[0], e[1] u, v = e[0], e[1]
if u not in top_map:
    raise DGLError(
        "Expect the node {} to have attribute bipartite=0 "
        "with edge {}".format(u, (u, v))
    )
if v not in bottom_map:
    raise DGLError(
        "Expect the node {} to have attribute bipartite=1 "
        "with edge {}".format(v, (u, v))
    )
src.append(top_map[u]) src.append(top_map[u])
dst.append(bottom_map[v]) dst.append(bottom_map[v])
src = F.tensor(src, dtype=idtype) src = F.tensor(src, dtype=idtype)
dst = F.tensor(dst, dtype=idtype) dst = F.tensor(dst, dtype=idtype)
return src, dst return src, dst
def infer_num_nodes(data, bipartite=False): def infer_num_nodes(data, bipartite=False):
"""Function for inferring the number of nodes. """Function for inferring the number of nodes.
...@@ -292,25 +328,35 @@ def infer_num_nodes(data, bipartite=False): ...@@ -292,25 +328,35 @@ def infer_num_nodes(data, bipartite=False):
""" """
if isinstance(data, tuple) and len(data) == 2: if isinstance(data, tuple) and len(data) == 2:
if not isinstance(data[0], str):
    raise TypeError(
        "Expected sparse format as a str, but got %s" % type(data[0])
    )
if data[0] == 'coo': if data[0] == "coo":
# ('coo', (src, dst)) format # ('coo', (src, dst)) format
u, v = data[1] u, v = data[1]
nsrc = F.as_scalar(F.max(u, dim=0)) + 1 if len(u) > 0 else 0 nsrc = F.as_scalar(F.max(u, dim=0)) + 1 if len(u) > 0 else 0
ndst = F.as_scalar(F.max(v, dim=0)) + 1 if len(v) > 0 else 0 ndst = F.as_scalar(F.max(v, dim=0)) + 1 if len(v) > 0 else 0
elif data[0] == 'csr': elif data[0] == "csr":
# ('csr', (indptr, indices, eids)) format # ('csr', (indptr, indices, eids)) format
indptr, indices, _ = data[1] indptr, indices, _ = data[1]
nsrc = F.shape(indptr)[0] - 1 nsrc = F.shape(indptr)[0] - 1
    ndst = (
        F.as_scalar(F.max(indices, dim=0)) + 1
        if len(indices) > 0
        else 0
    )
elif data[0] == "csc":
    # ('csc', (indptr, indices, eids)) format
    indptr, indices, _ = data[1]
    ndst = F.shape(indptr)[0] - 1
    nsrc = (
        F.as_scalar(F.max(indices, dim=0)) + 1
        if len(indices) > 0
        else 0
    )
else: else:
raise ValueError('unknown format %s' % data[0]) raise ValueError("unknown format %s" % data[0])
elif isinstance(data, sp.sparse.spmatrix): elif isinstance(data, sp.sparse.spmatrix):
nsrc, ndst = data.shape[0], data.shape[1] nsrc, ndst = data.shape[0], data.shape[1]
elif isinstance(data, nx.Graph): elif isinstance(data, nx.Graph):
...@@ -319,7 +365,9 @@ def infer_num_nodes(data, bipartite=False): ...@@ -319,7 +365,9 @@ def infer_num_nodes(data, bipartite=False):
elif not bipartite: elif not bipartite:
nsrc = ndst = data.number_of_nodes() nsrc = ndst = data.number_of_nodes()
else: else:
nsrc = len({n for n, d in data.nodes(data=True) if d['bipartite'] == 0}) nsrc = len(
{n for n, d in data.nodes(data=True) if d["bipartite"] == 0}
)
ndst = data.number_of_nodes() - nsrc ndst = data.number_of_nodes() - nsrc
else: else:
return None return None
...@@ -327,6 +375,7 @@ def infer_num_nodes(data, bipartite=False): ...@@ -327,6 +375,7 @@ def infer_num_nodes(data, bipartite=False):
nsrc = ndst = max(nsrc, ndst) nsrc = ndst = max(nsrc, ndst)
return nsrc, ndst return nsrc, ndst
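# Illustrative sketch (not part of the original commit): node-count inference
# from a ('csr', (indptr, indices, eids)) triple, assuming the same
# max-id-plus-one convention as the COO branch above.
import numpy as np

indptr = np.array([0, 2, 3])        # 2 source nodes
indices = np.array([0, 4, 2])       # destination ids
nsrc = len(indptr) - 1
ndst = int(indices.max()) + 1 if len(indices) else 0
print(nsrc, ndst)  # 2 5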
def to_device(data, device): def to_device(data, device):
"""Transfer the tensor or dictionary of tensors to the given device. """Transfer the tensor or dictionary of tensors to the given device.
......
...@@ -16,14 +16,17 @@ import traceback ...@@ -16,14 +16,17 @@ import traceback
# and the frame (which holds reference to all the object in its temporary scope) # and the frame (which holds reference to all the object in its temporary scope)
# holding reference the traceback. # holding reference the traceback.
class KeyErrorMessage(str): class KeyErrorMessage(str):
r"""str subclass that returns itself in repr""" r"""str subclass that returns itself in repr"""
def __repr__(self):  # pylint: disable=invalid-repr-returned
return self return self
class ExceptionWrapper(object): class ExceptionWrapper(object):
r"""Wraps an exception plus traceback to communicate across threads""" r"""Wraps an exception plus traceback to communicate across threads"""
def __init__(self, exc_info=None, where="in background"): def __init__(self, exc_info=None, where="in background"):
# It is important that we don't store exc_info, see # It is important that we don't store exc_info, see
# NOTE [ Python Traceback Reference Cycle Problem ] # NOTE [ Python Traceback Reference Cycle Problem ]
...@@ -38,7 +41,8 @@ class ExceptionWrapper(object): ...@@ -38,7 +41,8 @@ class ExceptionWrapper(object):
# Format a message such as: "Caught ValueError in DataLoader worker # Format a message such as: "Caught ValueError in DataLoader worker
# process 2. Original Traceback:", followed by the traceback. # process 2. Original Traceback:", followed by the traceback.
msg = "Caught {} {}.\nOriginal {}".format( msg = "Caught {} {}.\nOriginal {}".format(
self.exc_type.__name__, self.where, self.exc_msg) self.exc_type.__name__, self.where, self.exc_msg
)
if self.exc_type == KeyError: if self.exc_type == KeyError:
# KeyError calls repr() on its argument (usually a dict key). This # KeyError calls repr() on its argument (usually a dict key). This
# makes stack traces unreadable. It will not be changed in Python # makes stack traces unreadable. It will not be changed in Python
......
"""Utilities for finding overlap or missing items in arrays.""" """Utilities for finding overlap or missing items in arrays."""
from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
from .._ffi.function import _init_api
class Filter(object): class Filter(object):
...@@ -20,6 +19,7 @@ class Filter(object): ...@@ -20,6 +19,7 @@ class Filter(object):
>>> f.find_excluded_indices(th.tensor([0,2,8,9], device=th.device('cuda'))) >>> f.find_excluded_indices(th.tensor([0,2,8,9], device=th.device('cuda')))
tensor([0,2], device='cuda') tensor([0,2], device='cuda')
""" """
def __init__(self, ids): def __init__(self, ids):
"""Create a new filter from a given set of IDs. This currently is only """Create a new filter from a given set of IDs. This currently is only
implemented for the GPU. implemented for the GPU.
...@@ -30,7 +30,8 @@ class Filter(object): ...@@ -30,7 +30,8 @@ class Filter(object):
The unique set of IDs to keep in the filter. The unique set of IDs to keep in the filter.
""" """
self._filter = _CAPI_DGLFilterCreateFromSet( self._filter = _CAPI_DGLFilterCreateFromSet(
F.zerocopy_to_dgl_ndarray(ids)) F.zerocopy_to_dgl_ndarray(ids)
)
def find_included_indices(self, test): def find_included_indices(self, test):
"""Find the index of the IDs in `test` that are in this filter. """Find the index of the IDs in `test` that are in this filter.
...@@ -45,9 +46,11 @@ class Filter(object): ...@@ -45,9 +46,11 @@ class Filter(object):
IdArray IdArray
The index of IDs in `test` that are also in this filter. The index of IDs in `test` that are also in this filter.
""" """
return F.zerocopy_from_dgl_ndarray(
    _CAPI_DGLFilterFindIncludedIndices(
        self._filter, F.zerocopy_to_dgl_ndarray(test)
    )
)
def find_excluded_indices(self, test): def find_excluded_indices(self, test):
"""Find the index of the IDs in `test` that are not in this filter. """Find the index of the IDs in `test` that are not in this filter.
...@@ -62,8 +65,11 @@ class Filter(object): ...@@ -62,8 +65,11 @@ class Filter(object):
IdArray IdArray
The index of IDs in `test` that are not in this filter. The index of IDs in `test` that are not in this filter.
""" """
return F.zerocopy_from_dgl_ndarray(
    _CAPI_DGLFilterFindExcludedIndices(
        self._filter, F.zerocopy_to_dgl_ndarray(test)
    )
)
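# Illustrative sketch (not part of the original commit): the include/exclude
# semantics of Filter, expressed with NumPy on the CPU. The real class delegates
# to a GPU C API; this only mirrors what the two methods return, i.e. positions
# in `test`, not the IDs themselves.
import numpy as np

kept = np.array([3, 5, 8, 9])            # IDs stored in the filter
test = np.array([0, 2, 8, 9])            # IDs to probe
mask = np.isin(test, kept)
included = np.nonzero(mask)[0]           # positions of test IDs found in the filter
excluded = np.nonzero(~mask)[0]          # positions of test IDs not in the filter
print(included, excluded)                # [2 3] [0 1]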
_init_api("dgl.utils.filter") _init_api("dgl.utils.filter")
"""Internal utilities.""" """Internal utilities."""
from __future__ import absolute_import, division from __future__ import absolute_import, division
from collections.abc import Mapping, Iterable, Sequence
from collections import defaultdict
from functools import wraps
import glob import glob
import os import os
from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence
from functools import wraps
import numpy as np import numpy as np
from ..base import DGLError, dgl_warning, NID, EID
from .. import backend as F from .. import backend as F
from .. import ndarray as nd from .. import ndarray as nd
from .._ffi.function import _init_api from .._ffi.function import _init_api
from ..base import EID, NID, DGLError, dgl_warning
def is_listlike(data): def is_listlike(data):
"""Return if the data is a sequence but not a string.""" """Return if the data is a sequence but not a string."""
return isinstance(data, Sequence) and not isinstance(data, str) return isinstance(data, Sequence) and not isinstance(data, str)
class InconsistentDtypeException(DGLError): class InconsistentDtypeException(DGLError):
"""Exception class for inconsistent dtype between graph and tensor""" """Exception class for inconsistent dtype between graph and tensor"""
def __init__(self, msg="", *args, **kwargs):  # pylint: disable=W1113
    prefix_message = "DGL now requires the input tensor to have\
    the same dtype as the graph index's dtype (which you can get by g.idtype). "
    super().__init__(prefix_message + msg, *args, **kwargs)
class Index(object): class Index(object):
"""Index class that can be easily converted to list/tensor.""" """Index class that can be easily converted to list/tensor."""
def __init__(self, data, dtype="int64"): def __init__(self, data, dtype="int64"):
assert dtype in ['int32', 'int64'] assert dtype in ["int32", "int64"]
self.dtype = dtype self.dtype = dtype
self._initialize_data(data) self._initialize_data(data)
def _initialize_data(self, data): def _initialize_data(self, data):
self._pydata = None # a numpy type data self._pydata = None # a numpy type data
self._user_tensor_data = dict() # dictionary of user tensors self._user_tensor_data = dict() # dictionary of user tensors
self._dgl_tensor_data = None # a dgl ndarray self._dgl_tensor_data = None # a dgl ndarray
self._slice_data = None # a slice type data self._slice_data = None # a slice type data
self._dispatch(data) self._dispatch(data)
def __iter__(self): def __iter__(self):
...@@ -61,12 +67,16 @@ class Index(object): ...@@ -61,12 +67,16 @@ class Index(object):
"""Store data based on its type.""" """Store data based on its type."""
if F.is_tensor(data): if F.is_tensor(data):
if F.dtype(data) != F.data_type_dict[self.dtype]: if F.dtype(data) != F.data_type_dict[self.dtype]:
raise InconsistentDtypeException('Index data specified as %s, but got: %s' % raise InconsistentDtypeException(
(self.dtype, "Index data specified as %s, but got: %s"
F.reverse_data_type_dict[F.dtype(data)])) % (self.dtype, F.reverse_data_type_dict[F.dtype(data)])
)
if len(F.shape(data)) > 1: if len(F.shape(data)) > 1:
raise InconsistentDtypeException('Index data must be 1D int32/int64 vector,\ raise InconsistentDtypeException(
but got shape: %s' % str(F.shape(data))) "Index data must be 1D int32/int64 vector,\
but got shape: %s"
% str(F.shape(data))
)
if len(F.shape(data)) == 0: if len(F.shape(data)) == 0:
# a tensor of one int # a tensor of one int
self._dispatch(int(data)) self._dispatch(int(data))
...@@ -74,26 +84,33 @@ class Index(object): ...@@ -74,26 +84,33 @@ class Index(object):
self._user_tensor_data[F.context(data)] = data self._user_tensor_data[F.context(data)] = data
elif isinstance(data, nd.NDArray): elif isinstance(data, nd.NDArray):
if not (data.dtype == self.dtype and len(data.shape) == 1): if not (data.dtype == self.dtype and len(data.shape) == 1):
raise InconsistentDtypeException('Index data must be 1D %s vector, but got: %s' % raise InconsistentDtypeException(
(self.dtype, data.dtype)) "Index data must be 1D %s vector, but got: %s"
% (self.dtype, data.dtype)
)
self._dgl_tensor_data = data self._dgl_tensor_data = data
elif isinstance(data, slice): elif isinstance(data, slice):
# save it in the _pydata temporarily; materialize it if `tonumpy` is called # save it in the _pydata temporarily; materialize it if `tonumpy` is called
assert data.step == 1 or data.step is None, \ assert (
"step for slice type must be 1" data.step == 1 or data.step is None
), "step for slice type must be 1"
self._slice_data = slice(data.start, data.stop) self._slice_data = slice(data.start, data.stop)
else: else:
try: try:
data = np.asarray(data, dtype=self.dtype) data = np.asarray(data, dtype=self.dtype)
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
raise DGLError('Error index data: %s' % str(data)) raise DGLError("Error index data: %s" % str(data))
if data.ndim == 0: # scalar array if data.ndim == 0: # scalar array
data = np.expand_dims(data, 0) data = np.expand_dims(data, 0)
elif data.ndim != 1: elif data.ndim != 1:
raise DGLError('Index data must be 1D int64 vector,' raise DGLError(
' but got: %s' % str(data)) "Index data must be 1D int64 vector,"
" but got: %s" % str(data)
)
self._pydata = data self._pydata = data
self._user_tensor_data[F.cpu()] = F.zerocopy_from_numpy(self._pydata) self._user_tensor_data[F.cpu()] = F.zerocopy_from_numpy(
self._pydata
)
def tonumpy(self): def tonumpy(self):
"""Convert to a numpy ndarray.""" """Convert to a numpy ndarray."""
...@@ -119,7 +136,9 @@ class Index(object): ...@@ -119,7 +136,9 @@ class Index(object):
self._user_tensor_data[F.cpu()] = F.zerocopy_from_dlpack(dlpack) self._user_tensor_data[F.cpu()] = F.zerocopy_from_dlpack(dlpack)
else: else:
# zero copy from numpy array # zero copy from numpy array
self._user_tensor_data[F.cpu()] = F.zerocopy_from_numpy(self.tonumpy()) self._user_tensor_data[F.cpu()] = F.zerocopy_from_numpy(
self.tonumpy()
)
if ctx not in self._user_tensor_data: if ctx not in self._user_tensor_data:
# copy from cpu to another device # copy from cpu to another device
data = next(iter(self._user_tensor_data.values())) data = next(iter(self._user_tensor_data.values()))
...@@ -162,8 +181,10 @@ class Index(object): ...@@ -162,8 +181,10 @@ class Index(object):
self._initialize_data(data) self._initialize_data(data)
else: else:
# pre-0.4.3 # pre-0.4.3
dgl_warning("The object is pickled before 0.4.3. Setting dtype of graph to int64") dgl_warning(
self.dtype = 'int64' "The object is pickled before 0.4.3. Setting dtype of graph to int64"
)
self.dtype = "int64"
self._initialize_data(state) self._initialize_data(state)
def get_items(self, index): def get_items(self, index):
...@@ -193,15 +214,21 @@ class Index(object): ...@@ -193,15 +214,21 @@ class Index(object):
tensor = self.tousertensor() tensor = self.tousertensor()
index = index._slice_data index = index._slice_data
# TODO(Allen): Change F.narrow_row to dgl operation # TODO(Allen): Change F.narrow_row to dgl operation
return Index(F.astype(F.narrow_row(tensor, index.start, index.stop), return Index(
F.data_type_dict[self.dtype]), F.astype(
self.dtype) F.narrow_row(tensor, index.start, index.stop),
F.data_type_dict[self.dtype],
),
self.dtype,
)
else: else:
# both self and index wrap a slice object, then return another # both self and index wrap a slice object, then return another
# Index wrapping a slice # Index wrapping a slice
start = self._slice_data.start start = self._slice_data.start
index = index._slice_data index = index._slice_data
return Index(slice(start + index.start, start + index.stop), self.dtype) return Index(
slice(start + index.start, start + index.stop), self.dtype
)
def set_items(self, index, value): def set_items(self, index, value):
"""Set values at given positions of an Index. Set is not done in place, """Set values at given positions of an Index. Set is not done in place,
...@@ -257,7 +284,8 @@ class Index(object): ...@@ -257,7 +284,8 @@ class Index(object):
tensor = self.tousertensor() tensor = self.tousertensor()
return F.sum(tensor, 0) > 0 return F.sum(tensor, 0) > 0
def toindex(data, dtype='int64'):
def toindex(data, dtype="int64"):
"""Convert the given data to Index object. """Convert the given data to Index object.
Parameters Parameters
...@@ -276,6 +304,7 @@ def toindex(data, dtype='int64'): ...@@ -276,6 +304,7 @@ def toindex(data, dtype='int64'):
""" """
return data if isinstance(data, Index) else Index(data, dtype) return data if isinstance(data, Index) else Index(data, dtype)
def zero_index(size, dtype="int64"): def zero_index(size, dtype="int64"):
"""Create a index with provided size initialized to zero """Create a index with provided size initialized to zero
...@@ -283,8 +312,11 @@ def zero_index(size, dtype="int64"): ...@@ -283,8 +312,11 @@ def zero_index(size, dtype="int64"):
---------- ----------
size: int size: int
""" """
return Index(F.zeros((size,), dtype=F.data_type_dict[dtype], ctx=F.cpu()), return Index(
dtype=dtype) F.zeros((size,), dtype=F.data_type_dict[dtype], ctx=F.cpu()),
dtype=dtype,
)
def set_diff(ar1, ar2): def set_diff(ar1, ar2):
"""Find the set difference of two index arrays. """Find the set difference of two index arrays.
...@@ -309,8 +341,10 @@ def set_diff(ar1, ar2): ...@@ -309,8 +341,10 @@ def set_diff(ar1, ar2):
setdiff = toindex(setdiff) setdiff = toindex(setdiff)
return setdiff return setdiff
class LazyDict(Mapping): class LazyDict(Mapping):
"""A readonly dictionary that does not materialize the storage.""" """A readonly dictionary that does not materialize the storage."""
def __init__(self, fn, keys): def __init__(self, fn, keys):
self._fn = fn self._fn = fn
self._keys = keys self._keys = keys
...@@ -332,11 +366,13 @@ class LazyDict(Mapping): ...@@ -332,11 +366,13 @@ class LazyDict(Mapping):
def keys(self): def keys(self):
return self._keys return self._keys
class HybridDict(Mapping): class HybridDict(Mapping):
"""A readonly dictonary that merges several dict-like (python dict, LazyDict). """A readonly dictonary that merges several dict-like (python dict, LazyDict).
If there are duplicate keys, early keys have priority over latter ones. If there are duplicate keys, early keys have priority over latter ones.
""" """
def __init__(self, *dict_like_list): def __init__(self, *dict_like_list):
self._dict_like_list = dict_like_list self._dict_like_list = dict_like_list
self._keys = set() self._keys = set()
...@@ -361,8 +397,10 @@ class HybridDict(Mapping): ...@@ -361,8 +397,10 @@ class HybridDict(Mapping):
def __len__(self): def __len__(self):
return len(self.keys()) return len(self.keys())
class ReadOnlyDict(Mapping): class ReadOnlyDict(Mapping):
"""A readonly dictionary wrapper.""" """A readonly dictionary wrapper."""
def __init__(self, dict_like): def __init__(self, dict_like):
self._dict_like = dict_like self._dict_like = dict_like
...@@ -381,6 +419,7 @@ class ReadOnlyDict(Mapping): ...@@ -381,6 +419,7 @@ class ReadOnlyDict(Mapping):
def __len__(self): def __len__(self):
return len(self._dict_like) return len(self._dict_like)
def build_relabel_map(x, is_sorted=False): def build_relabel_map(x, is_sorted=False):
"""Relabel the input ids to continuous ids that starts from zero. """Relabel the input ids to continuous ids that starts from zero.
...@@ -423,6 +462,7 @@ def build_relabel_map(x, is_sorted=False): ...@@ -423,6 +462,7 @@ def build_relabel_map(x, is_sorted=False):
old_to_new = F.scatter_row(old_to_new, unique_x, F.arange(0, len(unique_x))) old_to_new = F.scatter_row(old_to_new, unique_x, F.arange(0, len(unique_x)))
return unique_x, old_to_new return unique_x, old_to_new
def build_relabel_dict(x): def build_relabel_dict(x):
"""Relabel the input ids to continuous ids that starts from zero. """Relabel the input ids to continuous ids that starts from zero.
...@@ -443,6 +483,7 @@ def build_relabel_dict(x): ...@@ -443,6 +483,7 @@ def build_relabel_dict(x):
relabel_dict[v] = i relabel_dict[v] = i
return relabel_dict return relabel_dict
class CtxCachedObject(object): class CtxCachedObject(object):
"""A wrapper to cache object generated by different context. """A wrapper to cache object generated by different context.
...@@ -453,6 +494,7 @@ class CtxCachedObject(object): ...@@ -453,6 +494,7 @@ class CtxCachedObject(object):
generator : callable generator : callable
A callable function that can create the object given ctx as the only argument. A callable function that can create the object given ctx as the only argument.
""" """
def __init__(self, generator): def __init__(self, generator):
self._generator = generator self._generator = generator
self._ctx_dict = {} self._ctx_dict = {}
...@@ -462,6 +504,7 @@ class CtxCachedObject(object): ...@@ -462,6 +504,7 @@ class CtxCachedObject(object):
self._ctx_dict[ctx] = self._generator(ctx) self._ctx_dict[ctx] = self._generator(ctx)
return self._ctx_dict[ctx] return self._ctx_dict[ctx]
def cached_member(cache, prefix): def cached_member(cache, prefix):
"""A member function decorator to memorize the result. """A member function decorator to memorize the result.
...@@ -476,24 +519,30 @@ def cached_member(cache, prefix): ...@@ -476,24 +519,30 @@ def cached_member(cache, prefix):
prefix : str prefix : str
The key prefix to save the result of the function. The key prefix to save the result of the function.
""" """
def _creator(func): def _creator(func):
@wraps(func) @wraps(func)
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
dic = getattr(self, cache) dic = getattr(self, cache)
key = '%s-%s-%s' % ( key = "%s-%s-%s" % (
prefix, prefix,
'-'.join([str(a) for a in args]), "-".join([str(a) for a in args]),
'-'.join([str(k) + ':' + str(v) for k, v in kwargs.items()])) "-".join([str(k) + ":" + str(v) for k, v in kwargs.items()]),
)
if key not in dic: if key not in dic:
dic[key] = func(self, *args, **kwargs) dic[key] = func(self, *args, **kwargs)
return dic[key] return dic[key]
return wrapper return wrapper
return _creator return _creator
def is_dict_like(obj): def is_dict_like(obj):
"""Return true if the object can be treated as a dictionary.""" """Return true if the object can be treated as a dictionary."""
return isinstance(obj, Mapping) return isinstance(obj, Mapping)
def reorder(dict_like, index): def reorder(dict_like, index):
"""Reorder each column in the dict according to the index. """Reorder each column in the dict according to the index.
...@@ -510,6 +559,7 @@ def reorder(dict_like, index): ...@@ -510,6 +559,7 @@ def reorder(dict_like, index):
new_dict[key] = F.gather_row(val, idx_ctx) new_dict[key] = F.gather_row(val, idx_ctx)
return new_dict return new_dict
def reorder_index(idx, order): def reorder_index(idx, order):
"""Reorder the idx according to the given order """Reorder the idx according to the given order
...@@ -525,26 +575,30 @@ def reorder_index(idx, order): ...@@ -525,26 +575,30 @@ def reorder_index(idx, order):
new_idx = F.gather_row(idx, order) new_idx = F.gather_row(idx, order)
return toindex(new_idx) return toindex(new_idx)
def is_iterable(obj): def is_iterable(obj):
"""Return true if the object is an iterable.""" """Return true if the object is an iterable."""
return isinstance(obj, Iterable) return isinstance(obj, Iterable)
def to_dgl_context(ctx): def to_dgl_context(ctx):
"""Convert a backend context to DGLContext""" """Convert a backend context to DGLContext"""
device_type = nd.DGLContext.STR2MASK[F.device_type(ctx)] device_type = nd.DGLContext.STR2MASK[F.device_type(ctx)]
device_id = F.device_id(ctx) device_id = F.device_id(ctx)
return nd.DGLContext(device_type, device_id) return nd.DGLContext(device_type, device_id)
def to_nbits_int(tensor, nbits): def to_nbits_int(tensor, nbits):
"""Change the dtype of integer tensor """Change the dtype of integer tensor
The dtype of returned tensor uses nbits, nbits can only be 32 or 64 The dtype of returned tensor uses nbits, nbits can only be 32 or 64
""" """
assert(nbits in (32, 64)), "nbits can either be 32 or 64" assert nbits in (32, 64), "nbits can either be 32 or 64"
if nbits == 32: if nbits == 32:
return F.astype(tensor, F.int32) return F.astype(tensor, F.int32)
else: else:
return F.astype(tensor, F.int64) return F.astype(tensor, F.int64)
def make_invmap(array, use_numpy=True): def make_invmap(array, use_numpy=True):
"""Find the unique elements of the array and return another array with indices """Find the unique elements of the array and return another array with indices
to the array of unique elements.""" to the array of unique elements."""
...@@ -556,6 +610,7 @@ def make_invmap(array, use_numpy=True): ...@@ -556,6 +610,7 @@ def make_invmap(array, use_numpy=True):
remapped = np.asarray([invmap[x] for x in array]) remapped = np.asarray([invmap[x] for x in array])
return uniques, invmap, remapped return uniques, invmap, remapped
def expand_as_pair(input_, g=None): def expand_as_pair(input_, g=None):
"""Return a pair of same element if the input is not a pair. """Return a pair of same element if the input is not a pair.
...@@ -581,13 +636,15 @@ def expand_as_pair(input_, g=None): ...@@ -581,13 +636,15 @@ def expand_as_pair(input_, g=None):
if isinstance(input_, Mapping): if isinstance(input_, Mapping):
input_dst = { input_dst = {
k: F.narrow_row(v, 0, g.number_of_dst_nodes(k)) k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))
for k, v in input_.items()} for k, v in input_.items()
}
else: else:
input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes()) input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())
return input_, input_dst return input_, input_dst
else: else:
return input_, input_ return input_, input_
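# Illustrative sketch (not part of the original commit): what expand_as_pair
# returns for a plain (non-pair, non-block) input, reduced to pure Python. With
# no graph argument a single feature tensor is simply duplicated into a
# (feat_src, feat_dst) pair, so downstream code can always unpack two values.
def _expand_as_pair(input_):
    if isinstance(input_, tuple):
        return input_                    # already a (src, dst) pair
    return input_, input_                # reuse the same features for both ends

feat = [1.0, 2.0, 3.0]                   # stand-in for a feature tensor
feat_src, feat_dst = _expand_as_pair(feat)
assert feat_src is feat and feat_dst is feat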
def check_eq_shape(input_): def check_eq_shape(input_):
"""If input_ is a pair of features, check if the feature shape of source """If input_ is a pair of features, check if the feature shape of source
nodes is equal to the feature shape of destination nodes. nodes is equal to the feature shape of destination nodes.
...@@ -596,9 +653,14 @@ def check_eq_shape(input_): ...@@ -596,9 +653,14 @@ def check_eq_shape(input_):
src_feat_shape = tuple(F.shape(srcdata))[1:] src_feat_shape = tuple(F.shape(srcdata))[1:]
dst_feat_shape = tuple(F.shape(dstdata))[1:] dst_feat_shape = tuple(F.shape(dstdata))[1:]
if src_feat_shape != dst_feat_shape: if src_feat_shape != dst_feat_shape:
raise DGLError("The feature shape of source nodes: {} \ raise DGLError(
"The feature shape of source nodes: {} \
should be equal to the feature shape of destination \ should be equal to the feature shape of destination \
nodes: {}.".format(src_feat_shape, dst_feat_shape)) nodes: {}.".format(
src_feat_shape, dst_feat_shape
)
)
def retry_method_with_fix(fix_method): def retry_method_with_fix(fix_method):
"""Decorator that executes a fix method before retrying again when the decorated method """Decorator that executes a fix method before retrying again when the decorated method
...@@ -617,6 +679,7 @@ def retry_method_with_fix(fix_method): ...@@ -617,6 +679,7 @@ def retry_method_with_fix(fix_method):
The fix method to execute. It should not accept any arguments. Its return values are The fix method to execute. It should not accept any arguments. Its return values are
ignored. ignored.
""" """
def _creator(func): def _creator(func):
@wraps(func) @wraps(func)
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
...@@ -628,8 +691,10 @@ def retry_method_with_fix(fix_method): ...@@ -628,8 +691,10 @@ def retry_method_with_fix(fix_method):
return func(self, *args, **kwargs) return func(self, *args, **kwargs)
return wrapper return wrapper
return _creator return _creator
def group_as_dict(pairs): def group_as_dict(pairs):
"""Combines a list of key-value pairs to a dictionary of keys and value lists. """Combines a list of key-value pairs to a dictionary of keys and value lists.
...@@ -650,6 +715,7 @@ def group_as_dict(pairs): ...@@ -650,6 +715,7 @@ def group_as_dict(pairs):
dic[key].append(value) dic[key].append(value)
return dic return dic
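# Illustrative sketch (not part of the original commit): the grouping that
# group_as_dict performs, written out with defaultdict.
from collections import defaultdict

pairs = [("a", 1), ("b", 2), ("a", 3)]
dic = defaultdict(list)
for key, value in pairs:
    dic[key].append(value)
print(dict(dic))  # {'a': [1, 3], 'b': [2]}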
class FlattenedDict(object): class FlattenedDict(object):
"""Iterates over each item in a dictionary of groups. """Iterates over each item in a dictionary of groups.
...@@ -668,6 +734,7 @@ class FlattenedDict(object): ...@@ -668,6 +734,7 @@ class FlattenedDict(object):
>>> len(groups) >>> len(groups)
6 6
""" """
def __init__(self, groups): def __init__(self, groups):
self._groups = groups self._groups = groups
group_sizes = {k: len(v) for k, v in groups.items()} group_sizes = {k: len(v) for k, v in groups.items()}
...@@ -675,9 +742,11 @@ class FlattenedDict(object): ...@@ -675,9 +742,11 @@ class FlattenedDict(object):
self._group_offsets = np.insert(np.cumsum(self._group_sizes), 0, 0) self._group_offsets = np.insert(np.cumsum(self._group_sizes), 0, 0)
# TODO: this is faster (37s -> 21s per epoch compared to searchsorted in GCMC) but takes # TODO: this is faster (37s -> 21s per epoch compared to searchsorted in GCMC) but takes
# O(E) memory. # O(E) memory.
self._idx_to_group = np.zeros(self._group_offsets[-1], dtype='int32') self._idx_to_group = np.zeros(self._group_offsets[-1], dtype="int32")
for i in range(len(self._groups)): for i in range(len(self._groups)):
self._idx_to_group[self._group_offsets[i]:self._group_offsets[i + 1]] = i self._idx_to_group[
self._group_offsets[i] : self._group_offsets[i + 1]
] = i
def __len__(self): def __len__(self):
"""Return the total number of items.""" """Return the total number of items."""
...@@ -697,11 +766,12 @@ class FlattenedDict(object): ...@@ -697,11 +766,12 @@ class FlattenedDict(object):
g = self._groups[k] g = self._groups[k]
return k, g[j] return k, g[j]
def maybe_flatten_dict(data): def maybe_flatten_dict(data):
"""Return a FlattenedDict if the input is a Mapping, or the data itself otherwise. """Return a FlattenedDict if the input is a Mapping, or the data itself otherwise."""
"""
return FlattenedDict(data) if isinstance(data, Mapping) else data return FlattenedDict(data) if isinstance(data, Mapping) else data
def compensate(ids, origin_ids): def compensate(ids, origin_ids):
"""computing the compensate set of ids from origin_ids """computing the compensate set of ids from origin_ids
...@@ -716,16 +786,19 @@ def compensate(ids, origin_ids): ...@@ -716,16 +786,19 @@ def compensate(ids, origin_ids):
th.Tensor([1, 5]) th.Tensor([1, 5])
""" """
# trick here, eid_0 or nid_0 can be 0. # trick here, eid_0 or nid_0 can be 0.
mask = F.scatter_row(
    origin_ids,
    F.copy_to(F.tensor(0, dtype=F.int64), F.context(origin_ids)),
    F.copy_to(
        F.tensor(1, dtype=F.dtype(origin_ids)), F.context(origin_ids)
    ),
)
mask = F.scatter_row(
    mask, ids, F.full_1d(len(ids), 0, F.dtype(ids), F.context(ids))
)
return F.tensor(F.nonzero_1d(mask), dtype=F.dtype(ids)) return F.tensor(F.nonzero_1d(mask), dtype=F.dtype(ids))
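# Illustrative sketch (not part of the original commit): the same complement can
# be computed on the CPU with numpy.setdiff1d, which is a handy reference when
# checking the mask-based implementation above.
import numpy as np

ids = np.array([0, 2, 3, 4])
origin_ids = np.array([0, 1, 2, 3, 4, 5])
print(np.setdiff1d(origin_ids, ids))  # [1 5]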
def relabel(x): def relabel(x):
"""Relabel the input ids to continuous ids that starts from zero. """Relabel the input ids to continuous ids that starts from zero.
...@@ -761,10 +834,12 @@ def relabel(x): ...@@ -761,10 +834,12 @@ def relabel(x):
ctx = F.context(x) ctx = F.context(x)
dtype = F.dtype(x) dtype = F.dtype(x)
old_to_new = F.zeros((map_len,), dtype=dtype, ctx=ctx) old_to_new = F.zeros((map_len,), dtype=dtype, ctx=ctx)
old_to_new = F.scatter_row(old_to_new, unique_x, old_to_new = F.scatter_row(
F.copy_to(F.arange(0, len(unique_x), dtype), ctx)) old_to_new, unique_x, F.copy_to(F.arange(0, len(unique_x), dtype), ctx)
)
return unique_x, old_to_new return unique_x, old_to_new
def extract_node_subframes(graph, nodes_or_device, store_ids=True): def extract_node_subframes(graph, nodes_or_device, store_ids=True):
"""Extract node features of the given nodes from :attr:`graph` """Extract node features of the given nodes from :attr:`graph`
and return them in frames on the given device. and return them in frames on the given device.
...@@ -801,10 +876,11 @@ def extract_node_subframes(graph, nodes_or_device, store_ids=True): ...@@ -801,10 +876,11 @@ def extract_node_subframes(graph, nodes_or_device, store_ids=True):
if store_ids: if store_ids:
subf[NID] = ind_nodes subf[NID] = ind_nodes
node_frames.append(subf) node_frames.append(subf)
else: # device object else: # device object
node_frames = [nf.to(nodes_or_device) for nf in graph._node_frames] node_frames = [nf.to(nodes_or_device) for nf in graph._node_frames]
return node_frames return node_frames
def extract_node_subframes_for_block(graph, srcnodes, dstnodes): def extract_node_subframes_for_block(graph, srcnodes, dstnodes):
"""Extract the input node features and output node features of the given nodes from """Extract the input node features and output node features of the given nodes from
:attr:`graph` and return them in frames ready for a block. :attr:`graph` and return them in frames ready for a block.
...@@ -841,6 +917,7 @@ def extract_node_subframes_for_block(graph, srcnodes, dstnodes): ...@@ -841,6 +917,7 @@ def extract_node_subframes_for_block(graph, srcnodes, dstnodes):
node_frames.append(subf) node_frames.append(subf)
return node_frames return node_frames
def extract_edge_subframes(graph, edges_or_device, store_ids=True): def extract_edge_subframes(graph, edges_or_device, store_ids=True):
"""Extract edge features of the given edges from :attr:`graph` """Extract edge features of the given edges from :attr:`graph`
and return them in frames. and return them in frames.
...@@ -877,10 +954,11 @@ def extract_edge_subframes(graph, edges_or_device, store_ids=True): ...@@ -877,10 +954,11 @@ def extract_edge_subframes(graph, edges_or_device, store_ids=True):
if store_ids: if store_ids:
subf[EID] = ind_edges subf[EID] = ind_edges
edge_frames.append(subf) edge_frames.append(subf)
else: # device object else: # device object
edge_frames = [nf.to(device) for nf in graph._edge_frames] edge_frames = [nf.to(device) for nf in graph._edge_frames]
return edge_frames return edge_frames
def set_new_frames(graph, *, node_frames=None, edge_frames=None): def set_new_frames(graph, *, node_frames=None, edge_frames=None):
"""Set the node and edge frames of a given graph to new ones. """Set the node and edge frames of a given graph to new ones.
...@@ -898,14 +976,17 @@ def set_new_frames(graph, *, node_frames=None, edge_frames=None): ...@@ -898,14 +976,17 @@ def set_new_frames(graph, *, node_frames=None, edge_frames=None):
Default is None, where the edge frames are not updated. Default is None, where the edge frames are not updated.
""" """
if node_frames is not None: if node_frames is not None:
assert len(node_frames) == len(graph.ntypes), \ assert len(node_frames) == len(
"[BUG] number of node frames different from number of node types" graph.ntypes
), "[BUG] number of node frames different from number of node types"
graph._node_frames = node_frames graph._node_frames = node_frames
if edge_frames is not None: if edge_frames is not None:
assert len(edge_frames) == len(graph.etypes), \ assert len(edge_frames) == len(
"[BUG] number of edge frames different from number of edge types" graph.etypes
), "[BUG] number of edge frames different from number of edge types"
graph._edge_frames = edge_frames graph._edge_frames = edge_frames
def set_num_threads(num_threads): def set_num_threads(num_threads):
"""Set the number of OMP threads in the process. """Set the number of OMP threads in the process.
...@@ -916,18 +997,20 @@ def set_num_threads(num_threads): ...@@ -916,18 +997,20 @@ def set_num_threads(num_threads):
""" """
_CAPI_DGLSetOMPThreads(num_threads) _CAPI_DGLSetOMPThreads(num_threads)
def get_num_threads(): def get_num_threads():
"""Get the number of OMP threads in the process""" """Get the number of OMP threads in the process"""
return _CAPI_DGLGetOMPThreads() return _CAPI_DGLGetOMPThreads()
def get_numa_nodes_cores(): def get_numa_nodes_cores():
""" Returns numa nodes info, format: """Returns numa nodes info, format:
{<node_id>: [(<core_id>, [<sibling_thread_id_0>, <sibling_thread_id_1>, ...]), ...], ...} {<node_id>: [(<core_id>, [<sibling_thread_id_0>, <sibling_thread_id_1>, ...]), ...], ...}
E.g.: {0: [(0, [0, 4]), (1, [1, 5])], 1: [(2, [2, 6]), (3, [3, 7])]} E.g.: {0: [(0, [0, 4]), (1, [1, 5])], 1: [(2, [2, 6]), (3, [3, 7])]}
If not available, returns {} If not available, returns {}
""" """
numa_node_paths = glob.glob('/sys/devices/system/node/node[0-9]*') numa_node_paths = glob.glob("/sys/devices/system/node/node[0-9]*")
if not numa_node_paths: if not numa_node_paths:
return {} return {}
...@@ -938,32 +1021,40 @@ def get_numa_nodes_cores(): ...@@ -938,32 +1021,40 @@ def get_numa_nodes_cores():
numa_node_id = int(os.path.basename(node_path)[4:]) numa_node_id = int(os.path.basename(node_path)[4:])
thread_siblings = {} thread_siblings = {}
for cpu_dir in glob.glob(os.path.join(node_path, 'cpu[0-9]*')): for cpu_dir in glob.glob(os.path.join(node_path, "cpu[0-9]*")):
cpu_id = int(os.path.basename(cpu_dir)[3:]) cpu_id = int(os.path.basename(cpu_dir)[3:])
with open(os.path.join(cpu_dir, 'topology', 'core_id')) as core_id_file: with open(
os.path.join(cpu_dir, "topology", "core_id")
) as core_id_file:
core_id = int(core_id_file.read().strip()) core_id = int(core_id_file.read().strip())
if core_id in thread_siblings: if core_id in thread_siblings:
thread_siblings[core_id].append(cpu_id) thread_siblings[core_id].append(cpu_id)
else: else:
thread_siblings[core_id] = [cpu_id] thread_siblings[core_id] = [cpu_id]
nodes[numa_node_id] = sorted([(k, sorted(v)) for k, v in thread_siblings.items()]) nodes[numa_node_id] = sorted(
[(k, sorted(v)) for k, v in thread_siblings.items()]
)
except (OSError, ValueError, IndexError, IOError): except (OSError, ValueError, IndexError, IOError):
dgl_warning('Failed to read NUMA info') dgl_warning("Failed to read NUMA info")
return {} return {}
return nodes return nodes
def alias_func(func): def alias_func(func):
"""Return an alias function with proper docstring.""" """Return an alias function with proper docstring."""
@wraps(func) @wraps(func)
def _fn(*args, **kwargs): def _fn(*args, **kwargs):
return func(*args, **kwargs) return func(*args, **kwargs)
_fn.__doc__ = """Alias of :func:`dgl.{}`.""".format(func.__name__) _fn.__doc__ = """Alias of :func:`dgl.{}`.""".format(func.__name__)
return _fn return _fn
def apply_each(data, fn, *args, **kwargs): def apply_each(data, fn, *args, **kwargs):
"""Apply a function to every element in a container. """Apply a function to every element in a container.
...@@ -1000,6 +1091,7 @@ def apply_each(data, fn, *args, **kwargs): ...@@ -1000,6 +1091,7 @@ def apply_each(data, fn, *args, **kwargs):
else: else:
return fn(data, *args, **kwargs) return fn(data, *args, **kwargs)
def recursive_apply(data, fn, *args, **kwargs): def recursive_apply(data, fn, *args, **kwargs):
"""Recursively apply a function to every element in a container. """Recursively apply a function to every element in a container.
...@@ -1033,12 +1125,15 @@ def recursive_apply(data, fn, *args, **kwargs): ...@@ -1033,12 +1125,15 @@ def recursive_apply(data, fn, *args, **kwargs):
>>> assert all((v >= 0).all() for v in h.values()) >>> assert all((v >= 0).all() for v in h.values())
""" """
if isinstance(data, Mapping): if isinstance(data, Mapping):
return {k: recursive_apply(v, fn, *args, **kwargs) for k, v in data.items()} return {
k: recursive_apply(v, fn, *args, **kwargs) for k, v in data.items()
}
elif is_listlike(data): elif is_listlike(data):
return [recursive_apply(v, fn, *args, **kwargs) for v in data] return [recursive_apply(v, fn, *args, **kwargs) for v in data]
else: else:
return fn(data, *args, **kwargs) return fn(data, *args, **kwargs)
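# Illustrative sketch (not part of the original commit): recursive_apply walks
# nested dicts and lists and applies the function only at the leaves. A
# pure-Python equivalent with a tiny usage example:
def _recursive_apply(data, fn):
    if isinstance(data, dict):
        return {k: _recursive_apply(v, fn) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return [_recursive_apply(v, fn) for v in data]
    return fn(data)

nested = {"x": [1, 2], "y": {"z": 3}}
print(_recursive_apply(nested, lambda v: v * 10))
# {'x': [10, 20], 'y': {'z': 30}}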
def recursive_apply_pair(data1, data2, fn, *args, **kwargs): def recursive_apply_pair(data1, data2, fn, *args, **kwargs):
"""Recursively apply a function to every pair of elements in two containers with the """Recursively apply a function to every pair of elements in two containers with the
same nested structure. same nested structure.
...@@ -1046,12 +1141,17 @@ def recursive_apply_pair(data1, data2, fn, *args, **kwargs): ...@@ -1046,12 +1141,17 @@ def recursive_apply_pair(data1, data2, fn, *args, **kwargs):
if isinstance(data1, Mapping) and isinstance(data2, Mapping): if isinstance(data1, Mapping) and isinstance(data2, Mapping):
return { return {
k: recursive_apply_pair(data1[k], data2[k], fn, *args, **kwargs) k: recursive_apply_pair(data1[k], data2[k], fn, *args, **kwargs)
for k in data1.keys()} for k in data1.keys()
}
elif is_listlike(data1) and is_listlike(data2): elif is_listlike(data1) and is_listlike(data2):
return [recursive_apply_pair(x, y, fn, *args, **kwargs) for x, y in zip(data1, data2)] return [
recursive_apply_pair(x, y, fn, *args, **kwargs)
for x, y in zip(data1, data2)
]
else: else:
return fn(data1, data2, *args, **kwargs) return fn(data1, data2, *args, **kwargs)
def context_of(data): def context_of(data):
"""Return the device of the data which can be either a tensor or a list/dict of tensors.""" """Return the device of the data which can be either a tensor or a list/dict of tensors."""
if isinstance(data, Mapping): if isinstance(data, Mapping):
...@@ -1061,8 +1161,12 @@ def context_of(data): ...@@ -1061,8 +1161,12 @@ def context_of(data):
else: else:
return F.context(data) return F.context(data)
def dtype_of(data): def dtype_of(data):
"""Return the dtype of the data which can be either a tensor or a dict of tensors.""" """Return the dtype of the data which can be either a tensor or a dict of tensors."""
return F.dtype(next(iter(data.values())) if isinstance(data, Mapping) else data) return F.dtype(
next(iter(data.values())) if isinstance(data, Mapping) else data
)
_init_api("dgl.utils.internal") _init_api("dgl.utils.internal")
"""Utility functions related to pinned memory tensors.""" """Utility functions related to pinned memory tensors."""
from ..base import DGLError
from .. import backend as F from .. import backend as F
from .._ffi.function import _init_api from .._ffi.function import _init_api
from ..base import DGLError
def pin_memory_inplace(tensor): def pin_memory_inplace(tensor):
"""Register the tensor into pinned memory in-place (i.e. without copying). """Register the tensor into pinned memory in-place (i.e. without copying).
...@@ -19,9 +20,11 @@ def pin_memory_inplace(tensor): ...@@ -19,9 +20,11 @@ def pin_memory_inplace(tensor):
The dgl.ndarray object that holds the pinning status and shares the same The dgl.ndarray object that holds the pinning status and shares the same
underlying data with the tensor. underlying data with the tensor.
""" """
if F.backend_name in ['mxnet', 'tensorflow']: if F.backend_name in ["mxnet", "tensorflow"]:
raise DGLError("The {} backend does not support pinning " \ raise DGLError(
"tensors in-place.".format(F.backend_name)) "The {} backend does not support pinning "
"tensors in-place.".format(F.backend_name)
)
# needs to be writable to allow in-place modification # needs to be writable to allow in-place modification
try: try:
...@@ -31,6 +34,7 @@ def pin_memory_inplace(tensor): ...@@ -31,6 +34,7 @@ def pin_memory_inplace(tensor):
except Exception as e: except Exception as e:
raise DGLError("Failed to pin memory in-place due to: {}".format(e)) raise DGLError("Failed to pin memory in-place due to: {}".format(e))
def gather_pinned_tensor_rows(tensor, rows): def gather_pinned_tensor_rows(tensor, rows):
"""Directly gather rows from a CPU tensor given an indices array on CUDA devices, """Directly gather rows from a CPU tensor given an indices array on CUDA devices,
and returns the result on the same CUDA device without copying. and returns the result on the same CUDA device without copying.
...@@ -47,7 +51,10 @@ def gather_pinned_tensor_rows(tensor, rows): ...@@ -47,7 +51,10 @@ def gather_pinned_tensor_rows(tensor, rows):
Tensor Tensor
The result with the same device as :attr:`rows`. The result with the same device as :attr:`rows`.
""" """
return F.from_dgl_nd(_CAPI_DGLIndexSelectCPUFromGPU(F.to_dgl_nd(tensor), F.to_dgl_nd(rows))) return F.from_dgl_nd(
_CAPI_DGLIndexSelectCPUFromGPU(F.to_dgl_nd(tensor), F.to_dgl_nd(rows))
)
def scatter_pinned_tensor_rows(dest, rows, source): def scatter_pinned_tensor_rows(dest, rows, source):
"""Directly scatter rows from a GPU tensor given an indices array on CUDA devices, """Directly scatter rows from a GPU tensor given an indices array on CUDA devices,
...@@ -62,8 +69,9 @@ def scatter_pinned_tensor_rows(dest, rows, source): ...@@ -62,8 +69,9 @@ def scatter_pinned_tensor_rows(dest, rows, source):
source : Tensor source : Tensor
The tensor on the GPU to scatter rows from. The tensor on the GPU to scatter rows from.
""" """
_CAPI_DGLIndexScatterGPUToCPU(F.to_dgl_nd(dest), F.to_dgl_nd(rows), _CAPI_DGLIndexScatterGPUToCPU(
F.to_dgl_nd(source)) F.to_dgl_nd(dest), F.to_dgl_nd(rows), F.to_dgl_nd(source)
)
_init_api("dgl.ndarray.uvm", __name__) _init_api("dgl.ndarray.uvm", __name__)
"""Shared memory utilities. """Shared memory utilities.
For compatibility with older code that uses ``dgl.utils.shared_mem`` namespace; the For compatibility with older code that uses ``dgl.utils.shared_mem`` namespace; the
content has been moved to ``dgl.ndarray`` module. content has been moved to ``dgl.ndarray`` module.
""" """
from ..ndarray import (  # pylint: disable=unused-import
    create_shared_mem_array,
    get_shared_mem_array,
)