Commit 98325b10 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4691)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent c24e285a
"""Negative sampling APIs""" """Negative sampling APIs"""
from numpy.polynomial import polynomial from numpy.polynomial import polynomial
from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
from .. import utils from .. import utils
from .._ffi.function import _init_api
from ..heterograph import DGLHeteroGraph from ..heterograph import DGLHeteroGraph
__all__ = [ __all__ = ["global_uniform_negative_sampling"]
'global_uniform_negative_sampling']
def _calc_redundancy(k_hat, num_edges, num_pairs, r=3): # pylint: disable=invalid-name def _calc_redundancy(
k_hat, num_edges, num_pairs, r=3
): # pylint: disable=invalid-name
# pylint: disable=invalid-name # pylint: disable=invalid-name
# Calculates the number of samples required based on a lower-bound # Calculates the number of samples required based on a lower-bound
# of the expected number of negative samples, based on N draws from # of the expected number of negative samples, based on N draws from
...@@ -24,18 +27,24 @@ def _calc_redundancy(k_hat, num_edges, num_pairs, r=3): # pylint: disable=invali ...@@ -24,18 +27,24 @@ def _calc_redundancy(k_hat, num_edges, num_pairs, r=3): # pylint: disable=invali
p_m = num_edges / num_pairs p_m = num_edges / num_pairs
p_k = 1 - p_m p_k = 1 - p_m
a = p_k ** 2 a = p_k**2
b = -p_k * (2 * k_hat + r ** 2 * p_m) b = -p_k * (2 * k_hat + r**2 * p_m)
c = k_hat ** 2 c = k_hat**2
poly = polynomial.Polynomial([c, b, a]) poly = polynomial.Polynomial([c, b, a])
N = poly.roots()[-1] N = poly.roots()[-1]
redundancy = N / k_hat - 1. redundancy = N / k_hat - 1.0
return redundancy return redundancy
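
# A sketch of where the quadratic above comes from, assuming each of the N
# draws is a negative pair with probability p_k: the expected count N * p_k
# has variance N * p_k * p_m, and setting the r-standard-deviation lower
# confidence bound equal to the target k_hat,
#
#     N * p_k - r * sqrt(N * p_k * p_m) = k_hat,
#
# then squaring and collecting powers of N gives
#
#     p_k**2 * N**2 - p_k * (2 * k_hat + r**2 * p_m) * N + k_hat**2 = 0,
#
# which matches the coefficients [c, b, a] above. Squaring introduces a
# spurious smaller root, hence taking the largest root, poly.roots()[-1].
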
def global_uniform_negative_sampling(
    g,
    num_samples,
    exclude_self_loops=True,
    replace=False,
    etype=None,
    redundancy=None,
):
    """Performs negative sampling, which generates source-destination pairs such that
    edges with the given type do not exist.
@@ -95,13 +104,24 @@
    exclude_self_loops = exclude_self_loops and (utype == vtype)
    redundancy = _calc_redundancy(
        num_samples, g.num_edges(etype), g.num_nodes(utype) * g.num_nodes(vtype)
    )
    etype_id = g.get_etype_id(etype)
    src, dst = _CAPI_DGLGlobalUniformNegativeSampling(
        g._graph,
        etype_id,
        num_samples,
        3,
        exclude_self_loops,
        replace,
        redundancy,
    )
    return F.from_dgl_nd(src), F.from_dgl_nd(dst)


DGLHeteroGraph.global_uniform_negative_sampling = utils.alias_func(
    global_uniform_negative_sampling
)

_init_api("dgl.sampling.negative", __name__)
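
A minimal usage sketch of the reformatted API, assuming the PyTorch backend
(the toy graph and sample count are illustrative):

    import dgl
    import torch

    g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))
    # Draw up to 5 source-destination pairs that do not appear as edges;
    # fewer may be returned after duplicates and collisions are filtered.
    src, dst = dgl.sampling.global_uniform_negative_sampling(g, 5)
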
"""Node2vec random walk""" """Node2vec random walk"""
from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
from .. import ndarray as nd from .. import ndarray as nd
from .. import utils from .. import utils
from .._ffi.function import _init_api
# pylint: disable=invalid-name # pylint: disable=invalid-name
__all__ = ['node2vec_random_walk'] __all__ = ["node2vec_random_walk"]
def node2vec_random_walk(g, nodes, p, q, walk_length, prob=None, return_eids=False): def node2vec_random_walk(
g, nodes, p, q, walk_length, prob=None, return_eids=False
):
""" """
Generate random walk traces from an array of starting nodes based on the node2vec model. Generate random walk traces from an array of starting nodes based on the node2vec model.
Paper: `node2vec: Scalable Feature Learning for Networks Paper: `node2vec: Scalable Feature Learning for Networks
...@@ -82,14 +85,16 @@ def node2vec_random_walk(g, nodes, p, q, walk_length, prob=None, return_eids=Fal ...@@ -82,14 +85,16 @@ def node2vec_random_walk(g, nodes, p, q, walk_length, prob=None, return_eids=Fal
assert g.device == F.cpu(), "Graph must be on CPU." assert g.device == F.cpu(), "Graph must be on CPU."
gidx = g._graph gidx = g._graph
nodes = F.to_dgl_nd(utils.prepare_tensor(g, nodes, 'nodes')) nodes = F.to_dgl_nd(utils.prepare_tensor(g, nodes, "nodes"))
if prob is None: if prob is None:
prob_nd = nd.array([], ctx=nodes.ctx) prob_nd = nd.array([], ctx=nodes.ctx)
else: else:
prob_nd = F.to_dgl_nd(g.edata[prob]) prob_nd = F.to_dgl_nd(g.edata[prob])
traces, eids = _CAPI_DGLSamplingNode2vec(gidx, nodes, p, q, walk_length, prob_nd) traces, eids = _CAPI_DGLSamplingNode2vec(
gidx, nodes, p, q, walk_length, prob_nd
)
traces = F.from_dgl_nd(traces) traces = F.from_dgl_nd(traces)
eids = F.from_dgl_nd(eids) eids = F.from_dgl_nd(eids)
...@@ -97,4 +102,4 @@ def node2vec_random_walk(g, nodes, p, q, walk_length, prob=None, return_eids=Fal ...@@ -97,4 +102,4 @@ def node2vec_random_walk(g, nodes, p, q, walk_length, prob=None, return_eids=Fal
return (traces, eids) if return_eids else traces return (traces, eids) if return_eids else traces
_init_api('dgl.sampling.randomwalks', __name__) _init_api("dgl.sampling.randomwalks", __name__)
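
A usage sketch of the reformatted walker, assuming the PyTorch backend and a
CPU graph (the toy graph and p/q values are illustrative):

    import dgl
    import torch

    g = dgl.graph((torch.tensor([0, 1, 2, 3]), torch.tensor([1, 2, 3, 0])))
    # One biased 4-step walk per starting node; each trace holds
    # walk_length + 1 entries including the starting node.
    traces = dgl.sampling.node2vec_random_walk(
        g, torch.tensor([0, 1]), p=1.0, q=0.5, walk_length=4
    )
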
"""PinSAGE sampler & related functions and classes""" """PinSAGE sampler & related functions and classes"""
import numpy as np import numpy as np
from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
from .. import convert from .. import convert, utils
from .._ffi.function import _init_api
from .randomwalks import random_walk from .randomwalks import random_walk
from .. import utils
def _select_pinsage_neighbors(src, dst, num_samples_per_node, k): def _select_pinsage_neighbors(src, dst, num_samples_per_node, k):
"""Determine the neighbors for PinSAGE algorithm from the given random walk traces. """Determine the neighbors for PinSAGE algorithm from the given random walk traces.
...@@ -16,12 +16,15 @@ def _select_pinsage_neighbors(src, dst, num_samples_per_node, k): ...@@ -16,12 +16,15 @@ def _select_pinsage_neighbors(src, dst, num_samples_per_node, k):
""" """
src = F.to_dgl_nd(src) src = F.to_dgl_nd(src)
dst = F.to_dgl_nd(dst) dst = F.to_dgl_nd(dst)
src, dst, counts = _CAPI_DGLSamplingSelectPinSageNeighbors(src, dst, num_samples_per_node, k) src, dst, counts = _CAPI_DGLSamplingSelectPinSageNeighbors(
src, dst, num_samples_per_node, k
)
src = F.from_dgl_nd(src) src = F.from_dgl_nd(src)
dst = F.from_dgl_nd(dst) dst = F.from_dgl_nd(dst)
counts = F.from_dgl_nd(counts) counts = F.from_dgl_nd(counts)
return (src, dst, counts) return (src, dst, counts)
class RandomWalkNeighborSampler(object): class RandomWalkNeighborSampler(object):
"""PinSage-like neighbor sampler extended to any heterogeneous graphs. """PinSage-like neighbor sampler extended to any heterogeneous graphs.
...@@ -72,8 +75,17 @@ class RandomWalkNeighborSampler(object): ...@@ -72,8 +75,17 @@ class RandomWalkNeighborSampler(object):
-------- --------
See examples in :any:`PinSAGESampler`. See examples in :any:`PinSAGESampler`.
""" """
def __init__(self, G, num_traversals, termination_prob,
num_random_walks, num_neighbors, metapath=None, weight_column='weights'): def __init__(
self,
G,
num_traversals,
termination_prob,
num_random_walks,
num_neighbors,
metapath=None,
weight_column="weights",
):
self.G = G self.G = G
self.weight_column = weight_column self.weight_column = weight_column
self.num_random_walks = num_random_walks self.num_random_walks = num_random_walks
...@@ -82,19 +94,25 @@ class RandomWalkNeighborSampler(object): ...@@ -82,19 +94,25 @@ class RandomWalkNeighborSampler(object):
        if metapath is None:
            if len(G.ntypes) > 1 or len(G.etypes) > 1:
                raise ValueError(
                    "Metapath must be specified if the graph is heterogeneous."
                )
            metapath = [G.canonical_etypes[0]]
        start_ntype = G.to_canonical_etype(metapath[0])[0]
        end_ntype = G.to_canonical_etype(metapath[-1])[-1]
        if start_ntype != end_ntype:
            raise ValueError(
                "The metapath must start and end at the same node type."
            )
        self.ntype = start_ntype

        self.metapath_hops = len(metapath)
        self.metapath = metapath
        self.full_metapath = metapath * num_traversals
        restart_prob = np.zeros(self.metapath_hops * num_traversals)
        restart_prob[
            self.metapath_hops :: self.metapath_hops
        ] = termination_prob
        restart_prob = F.tensor(restart_prob, dtype=F.float32)
        self.restart_prob = F.copy_to(restart_prob, G.device)
@@ -116,20 +134,30 @@
        A homogeneous graph constructed by selecting neighbors for each given node according
        to the algorithm above.
        """
        seed_nodes = utils.prepare_tensor(self.G, seed_nodes, "seed_nodes")

        self.restart_prob = F.copy_to(self.restart_prob, F.context(seed_nodes))
        seed_nodes = F.repeat(seed_nodes, self.num_random_walks, 0)
        paths, _ = random_walk(
            self.G,
            seed_nodes,
            metapath=self.full_metapath,
            restart_prob=self.restart_prob,
        )
        src = F.reshape(
            paths[:, self.metapath_hops :: self.metapath_hops], (-1,)
        )
        dst = F.repeat(paths[:, 0], self.num_traversals, 0)

        src, dst, counts = _select_pinsage_neighbors(
            src,
            dst,
            (self.num_random_walks * self.num_traversals),
            self.num_neighbors,
        )
        neighbor_graph = convert.heterograph(
            {(self.ntype, "_E", self.ntype): (src, dst)},
            {self.ntype: self.G.number_of_nodes(self.ntype)},
        )
        neighbor_graph.edata[self.weight_column] = counts
@@ -219,13 +247,30 @@
    Graph Convolutional Neural Networks for Web-Scale Recommender Systems
        Ying et al., 2018, https://arxiv.org/abs/1806.01973
    """

    def __init__(
        self,
        G,
        ntype,
        other_type,
        num_traversals,
        termination_prob,
        num_random_walks,
        num_neighbors,
        weight_column="weights",
    ):
        metagraph = G.metagraph()
        fw_etype = list(metagraph[ntype][other_type])[0]
        bw_etype = list(metagraph[other_type][ntype])[0]
        super().__init__(
            G,
            num_traversals,
            termination_prob,
            num_random_walks,
            num_neighbors,
            metapath=[fw_etype, bw_etype],
            weight_column=weight_column,
        )


_init_api("dgl.sampling.pinsage", __name__)
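
A usage sketch of PinSAGESampler on a toy item-user graph (graph, seeds, and
hyperparameters are illustrative; invoking the sampler on seed nodes assumes
the elided call method shown partially above):

    import dgl
    import torch

    g = dgl.heterograph({
        ("item", "bought-by", "user"): (torch.tensor([0, 1]), torch.tensor([0, 1])),
        ("user", "bought", "item"): (torch.tensor([0, 1]), torch.tensor([1, 0])),
    })
    # Item-to-item sampler: each random walk alternates item -> user -> item.
    sampler = dgl.sampling.PinSAGESampler(
        g, "item", "user", num_traversals=3, termination_prob=0.5,
        num_random_walks=10, num_neighbors=2,
    )
    frontier = sampler(torch.tensor([0, 1]))   # homogeneous item-item graph
    weights = frontier.edata["weights"]        # visit counts per sampled edge
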
"""Random walk routines """Random walk routines
""" """
from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
from ..base import DGLError
from .. import ndarray as nd from .. import ndarray as nd
from .. import utils from .. import utils
from .._ffi.function import _init_api
from ..base import DGLError
__all__ = [ __all__ = ["random_walk", "pack_traces"]
'random_walk',
'pack_traces']
def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob=None,
return_eids=False): def random_walk(
g,
nodes,
*,
metapath=None,
length=None,
prob=None,
restart_prob=None,
return_eids=False
):
"""Generate random walk traces from an array of starting nodes based on the given metapath. """Generate random walk traces from an array of starting nodes based on the given metapath.
Each starting node will have one trace generated, which Each starting node will have one trace generated, which
...@@ -167,15 +174,19 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob ...@@ -167,15 +174,19 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
if metapath is None: if metapath is None:
if n_etypes > 1 or n_ntypes > 1: if n_etypes > 1 or n_ntypes > 1:
raise DGLError("metapath not specified and the graph is not homogeneous.") raise DGLError(
"metapath not specified and the graph is not homogeneous."
)
if length is None: if length is None:
raise ValueError("Please specify either the metapath or the random walk length.") raise ValueError(
"Please specify either the metapath or the random walk length."
)
metapath = [0] * length metapath = [0] * length
else: else:
metapath = [g.get_etype_id(etype) for etype in metapath] metapath = [g.get_etype_id(etype) for etype in metapath]
gidx = g._graph gidx = g._graph
nodes = utils.prepare_tensor(g, nodes, 'nodes') nodes = utils.prepare_tensor(g, nodes, "nodes")
nodes = F.to_dgl_nd(nodes) nodes = F.to_dgl_nd(nodes)
# (Xin) Since metapath array is created by us, safe to skip the check # (Xin) Since metapath array is created by us, safe to skip the check
# and keep it on CPU to make max_nodes sanity check easier. # and keep it on CPU to make max_nodes sanity check easier.
...@@ -196,14 +207,18 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob ...@@ -196,14 +207,18 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
# Actual random walk # Actual random walk
if restart_prob is None: if restart_prob is None:
traces, eids, types = _CAPI_DGLSamplingRandomWalk(gidx, nodes, metapath, p_nd) traces, eids, types = _CAPI_DGLSamplingRandomWalk(
gidx, nodes, metapath, p_nd
)
elif F.is_tensor(restart_prob): elif F.is_tensor(restart_prob):
restart_prob = F.to_dgl_nd(restart_prob) restart_prob = F.to_dgl_nd(restart_prob)
traces, eids, types = _CAPI_DGLSamplingRandomWalkWithStepwiseRestart( traces, eids, types = _CAPI_DGLSamplingRandomWalkWithStepwiseRestart(
gidx, nodes, metapath, p_nd, restart_prob) gidx, nodes, metapath, p_nd, restart_prob
)
elif isinstance(restart_prob, float): elif isinstance(restart_prob, float):
traces, eids, types = _CAPI_DGLSamplingRandomWalkWithRestart( traces, eids, types = _CAPI_DGLSamplingRandomWalkWithRestart(
gidx, nodes, metapath, p_nd, restart_prob) gidx, nodes, metapath, p_nd, restart_prob
)
else: else:
raise TypeError("restart_prob should be float or Tensor.") raise TypeError("restart_prob should be float or Tensor.")
...@@ -212,6 +227,7 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob ...@@ -212,6 +227,7 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
eids = F.from_dgl_nd(eids) eids = F.from_dgl_nd(eids)
return (traces, eids, types) if return_eids else (traces, types) return (traces, eids, types) if return_eids else (traces, types)
def pack_traces(traces, types): def pack_traces(traces, types):
"""Pack the padded traces returned by ``random_walk()`` into a concatenated array. """Pack the padded traces returned by ``random_walk()`` into a concatenated array.
The padding values (-1) are removed, and the length and offset of each trace is The padding values (-1) are removed, and the length and offset of each trace is
...@@ -276,12 +292,18 @@ def pack_traces(traces, types): ...@@ -276,12 +292,18 @@ def pack_traces(traces, types):
>>> vids[1], vtypes[1] >>> vids[1], vtypes[1]
(tensor([0, 1, 1, 3, 0, 0, 0]), tensor([0, 0, 1, 0, 0, 1, 0])) (tensor([0, 1, 1, 3, 0, 0, 0]), tensor([0, 0, 1, 0, 0, 1, 0]))
""" """
assert F.is_tensor(traces) and F.context(traces) == F.cpu(), "traces must be a CPU tensor" assert (
assert F.is_tensor(types) and F.context(types) == F.cpu(), "types must be a CPU tensor" F.is_tensor(traces) and F.context(traces) == F.cpu()
), "traces must be a CPU tensor"
assert (
F.is_tensor(types) and F.context(types) == F.cpu()
), "types must be a CPU tensor"
traces = F.to_dgl_nd(traces) traces = F.to_dgl_nd(traces)
types = F.to_dgl_nd(types) types = F.to_dgl_nd(types)
concat_vids, concat_types, lengths, offsets = _CAPI_DGLSamplingPackTraces(traces, types) concat_vids, concat_types, lengths, offsets = _CAPI_DGLSamplingPackTraces(
traces, types
)
concat_vids = F.from_dgl_nd(concat_vids) concat_vids = F.from_dgl_nd(concat_vids)
concat_types = F.from_dgl_nd(concat_types) concat_types = F.from_dgl_nd(concat_types)
...@@ -290,4 +312,5 @@ def pack_traces(traces, types): ...@@ -290,4 +312,5 @@ def pack_traces(traces, types):
return concat_vids, concat_types, lengths, offsets return concat_vids, concat_types, lengths, offsets
_init_api('dgl.sampling.randomwalks', __name__)
_init_api("dgl.sampling.randomwalks", __name__)
"""Feature storage classes for DataLoading""" """Feature storage classes for DataLoading"""
from .. import backend as F from .. import backend as F
from .base import * from .base import *
from .numpy import * from .numpy import *
# Defines the name TensorStorage # Defines the name TensorStorage
if F.get_preferred_backend() == 'pytorch': if F.get_preferred_backend() == "pytorch":
from .pytorch_tensor import PyTorchTensorStorage as TensorStorage from .pytorch_tensor import PyTorchTensorStorage as TensorStorage
else: else:
from .tensor import BaseTensorStorage as TensorStorage from .tensor import BaseTensorStorage as TensorStorage
@@ -2,16 +2,19 @@
import threading

STORAGE_WRAPPERS = {}


def register_storage_wrapper(type_):
    """Decorator that associates a type to a ``FeatureStorage`` object."""

    def deco(cls):
        STORAGE_WRAPPERS[type_] = cls
        return cls

    return deco


def wrap_storage(storage):
    """Wrap an object into a FeatureStorage as specified by the ``register_storage_wrapper``
    decorators.
@@ -20,11 +23,14 @@
        if isinstance(storage, type_):
            return storage_cls(storage)

    assert isinstance(
        storage, FeatureStorage
    ), "The frame column must be a tensor or a FeatureStorage object, got {}".format(
        type(storage)
    )
    return storage


class _FuncWrapper(object):
    def __init__(self, func):
        self.func = func
@@ -32,18 +38,21 @@
    def __call__(self, buf, *args):
        buf[0] = self.func(*args)


class ThreadedFuture(object):
    """Wraps a function into a future asynchronously executed by a Python
    ``threading.Thread``. The function is being executed upon instantiation of
    this object.
    """

    def __init__(self, target, args):
        self.buf = [None]
        thread = threading.Thread(
            target=_FuncWrapper(target),
            args=[self.buf] + list(args),
            daemon=True,
        )
        thread.start()
        self.thread = thread
@@ -52,14 +61,15 @@
        self.thread.join()
        return self.buf[0]


class FeatureStorage(object):
    """Feature storage object which should support a fetch() operation. It is the
    counterpart of a tensor for homogeneous graphs, or a dict of tensor for heterogeneous
    graphs where the keys are node/edge types.
    """

    def requires_ddp(self):
        """Whether the FeatureStorage requires the DataLoader to set use_ddp."""
        return False

    def fetch(self, indices, device, pin_memory=False, **kwargs):
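A minimal sketch of the ThreadedFuture pattern defined above: the target runs
in a daemon thread at construction time, and wait() joins the thread and
returns the result (slow_add is a made-up stand-in for blocking I/O):

    import time

    def slow_add(a, b):
        time.sleep(0.1)  # stand-in for disk or network reads
        return a + b

    fut = ThreadedFuture(target=slow_add, args=(1, 2))
    # ... overlap other work here ...
    assert fut.wait() == 3  # joins the thread, returns buf[0]
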
"""Feature storage for ``numpy.memmap`` object.""" """Feature storage for ``numpy.memmap`` object."""
import numpy as np import numpy as np
from .base import FeatureStorage, ThreadedFuture, register_storage_wrapper
from .. import backend as F from .. import backend as F
from .base import FeatureStorage, ThreadedFuture, register_storage_wrapper
@register_storage_wrapper(np.memmap) @register_storage_wrapper(np.memmap)
class NumpyStorage(FeatureStorage): class NumpyStorage(FeatureStorage):
"""FeatureStorage that asynchronously reads features from a ``numpy.memmap`` object.""" """FeatureStorage that asynchronously reads features from a ``numpy.memmap`` object."""
def __init__(self, arr): def __init__(self, arr):
self.arr = arr self.arr = arr
...@@ -17,4 +20,6 @@ class NumpyStorage(FeatureStorage): ...@@ -17,4 +20,6 @@ class NumpyStorage(FeatureStorage):
# pylint: disable=unused-argument # pylint: disable=unused-argument
def fetch(self, indices, device, pin_memory=False, **kwargs): def fetch(self, indices, device, pin_memory=False, **kwargs):
return ThreadedFuture(target=self._fetch, args=(indices, device, pin_memory)) return ThreadedFuture(
target=self._fetch, args=(indices, device, pin_memory)
)
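
A usage sketch under stated assumptions: the elided _fetch helper is assumed
to slice self.arr at the given indices and return a backend tensor on the
requested device (the file path and shape below are made up):

    import numpy as np

    arr = np.memmap("feats.dat", dtype="float32", mode="w+", shape=(100, 16))
    storage = NumpyStorage(arr)
    # fetch() returns a ThreadedFuture immediately; wait() blocks for the rows.
    rows = storage.fetch(indices=[0, 5, 7], device="cpu").wait()
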
"""Feature storages for PyTorch tensors.""" """Feature storages for PyTorch tensors."""
import torch import torch
from ..utils import gather_pinned_tensor_rows
from .base import register_storage_wrapper from .base import register_storage_wrapper
from .tensor import BaseTensorStorage from .tensor import BaseTensorStorage
from ..utils import gather_pinned_tensor_rows
def _fetch_cpu(indices, tensor, feature_shape, device, pin_memory, **kwargs): def _fetch_cpu(indices, tensor, feature_shape, device, pin_memory, **kwargs):
result = torch.empty( result = torch.empty(
indices.shape[0], *feature_shape, dtype=tensor.dtype, indices.shape[0],
pin_memory=pin_memory) *feature_shape,
dtype=tensor.dtype,
pin_memory=pin_memory,
)
torch.index_select(tensor, 0, indices, out=result) torch.index_select(tensor, 0, indices, out=result)
kwargs['non_blocking'] = pin_memory kwargs["non_blocking"] = pin_memory
result = result.to(device, **kwargs) result = result.to(device, **kwargs)
return result return result
def _fetch_cuda(indices, tensor, device, **kwargs): def _fetch_cuda(indices, tensor, device, **kwargs):
return torch.index_select(tensor, 0, indices).to(device, **kwargs) return torch.index_select(tensor, 0, indices).to(device, **kwargs)
@register_storage_wrapper(torch.Tensor) @register_storage_wrapper(torch.Tensor)
class PyTorchTensorStorage(BaseTensorStorage): class PyTorchTensorStorage(BaseTensorStorage):
"""Feature storages for slicing a PyTorch tensor.""" """Feature storages for slicing a PyTorch tensor."""
def fetch(self, indices, device, pin_memory=False, **kwargs): def fetch(self, indices, device, pin_memory=False, **kwargs):
device = torch.device(device) device = torch.device(device)
storage_device_type = self.storage.device.type storage_device_type = self.storage.device.type
indices_device_type = indices.device.type indices_device_type = indices.device.type
if storage_device_type != 'cuda': if storage_device_type != "cuda":
if indices_device_type == 'cuda': if indices_device_type == "cuda":
if self.storage.is_pinned(): if self.storage.is_pinned():
return gather_pinned_tensor_rows(self.storage, indices) return gather_pinned_tensor_rows(self.storage, indices)
else: else:
raise ValueError( raise ValueError(
f'Got indices on device {indices.device} whereas the feature tensor ' f"Got indices on device {indices.device} whereas the feature tensor "
f'is on {self.storage.device}. Please either (1) move the graph ' f"is on {self.storage.device}. Please either (1) move the graph "
f'to GPU with to() method, or (2) pin the graph with ' f"to GPU with to() method, or (2) pin the graph with "
f'pin_memory_() method.') f"pin_memory_() method."
)
# CPU to CPU or CUDA - use pin_memory and async transfer if possible # CPU to CPU or CUDA - use pin_memory and async transfer if possible
else: else:
return _fetch_cpu(indices, self.storage, self.storage.shape[1:], device, return _fetch_cpu(
pin_memory, **kwargs) indices,
self.storage,
self.storage.shape[1:],
device,
pin_memory,
**kwargs,
)
else: else:
# CUDA to CUDA or CPU # CUDA to CUDA or CPU
return _fetch_cuda(indices, self.storage, device, **kwargs) return _fetch_cuda(indices, self.storage, device, **kwargs)
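
The _fetch_cpu path above is the standard pinned-staging pattern for
CPU-to-GPU feature gathering; a standalone sketch of the same idea (shapes
and indices are illustrative):

    import torch

    feats = torch.randn(1000, 128)   # CPU feature tensor
    idx = torch.tensor([3, 14, 159])

    # Gather rows into page-locked memory so the H2D copy can be asynchronous.
    pin = torch.cuda.is_available()
    staging = torch.empty(idx.shape[0], 128, dtype=feats.dtype, pin_memory=pin)
    torch.index_select(feats, 0, idx, out=staging)
    out = staging.to("cuda" if pin else "cpu", non_blocking=pin)
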
"""Feature storages for tensors across different frameworks.""" """Feature storages for tensors across different frameworks."""
from .base import FeatureStorage
from .. import backend as F from .. import backend as F
from .base import FeatureStorage
class BaseTensorStorage(FeatureStorage): class BaseTensorStorage(FeatureStorage):
"""FeatureStorage that synchronously slices features from a tensor and transfers """FeatureStorage that synchronously slices features from a tensor and transfers
it to the given device. it to the given device.
""" """
def __init__(self, tensor): def __init__(self, tensor):
self.storage = tensor self.storage = tensor
def fetch(self, indices, device, pin_memory=False, **kwargs): # pylint: disable=unused-argument def fetch(
self, indices, device, pin_memory=False, **kwargs
): # pylint: disable=unused-argument
return F.copy_to(F.gather_row(tensor, indices), device, **kwargs) return F.copy_to(F.gather_row(tensor, indices), device, **kwargs)
"""Module for graph traversal methods.""" """Module for graph traversal methods."""
from __future__ import absolute_import from __future__ import absolute_import
from ._ffi.function import _init_api
from . import backend as F from . import backend as F
from . import utils from . import utils
from ._ffi.function import _init_api
from .heterograph import DGLHeteroGraph from .heterograph import DGLHeteroGraph
__all__ = ['bfs_nodes_generator', 'bfs_edges_generator', __all__ = [
'topological_nodes_generator', "bfs_nodes_generator",
'dfs_edges_generator', 'dfs_labeled_edges_generator',] "bfs_edges_generator",
"topological_nodes_generator",
"dfs_edges_generator",
"dfs_labeled_edges_generator",
]
def bfs_nodes_generator(graph, source, reverse=False):
    """Node frontiers generator using breadth-first search.
@@ -40,10 +45,12 @@
    >>> list(dgl.bfs_nodes_generator(g, 0))
    [tensor([0]), tensor([1]), tensor([2, 3]), tensor([4, 5])]
    """
    assert isinstance(
        graph, DGLHeteroGraph
    ), "DGLGraph is deprecated, please use DGLHeteroGraph"
    assert (
        len(graph.canonical_etypes) == 1
    ), "bfs_nodes_generator only supports homogeneous graphs"
    # Workaround before support for GPU graph
    gidx = graph._graph.copy_to(utils.to_dgl_context(F.cpu()))
    source = utils.toindex(source, dtype=graph._idtype_str)
@@ -54,6 +61,7 @@
    node_frontiers = F.split(all_nodes, sections, dim=0)
    return node_frontiers


def bfs_edges_generator(graph, source, reverse=False):
    """Edges frontiers generator using breadth-first search.
@@ -85,10 +93,12 @@
    >>> list(dgl.bfs_edges_generator(g, 0))
    [tensor([0]), tensor([1, 2]), tensor([4, 5])]
    """
    assert isinstance(
        graph, DGLHeteroGraph
    ), "DGLGraph is deprecated, please use DGLHeteroGraph"
    assert (
        len(graph.canonical_etypes) == 1
    ), "bfs_edges_generator only supports homogeneous graphs"
    # Workaround before support for GPU graph
    gidx = graph._graph.copy_to(utils.to_dgl_context(F.cpu()))
    source = utils.toindex(source, dtype=graph._idtype_str)
@@ -99,6 +109,7 @@
    edge_frontiers = F.split(all_edges, sections, dim=0)
    return edge_frontiers


def topological_nodes_generator(graph, reverse=False):
    """Node frontiers generator using topological traversal.
@@ -127,10 +138,12 @@
    >>> list(dgl.topological_nodes_generator(g))
    [tensor([0]), tensor([1]), tensor([2]), tensor([3, 4]), tensor([5])]
    """
    assert isinstance(
        graph, DGLHeteroGraph
    ), "DGLGraph is deprecated, please use DGLHeteroGraph"
    assert (
        len(graph.canonical_etypes) == 1
    ), "topological_nodes_generator only supports homogeneous graphs"
    # Workaround before support for GPU graph
    gidx = graph._graph.copy_to(utils.to_dgl_context(F.cpu()))
    ret = _CAPI_DGLTopologicalNodes_v2(gidx, reverse)
@@ -139,6 +152,7 @@
    sections = utils.toindex(ret(1)).tonumpy().tolist()
    return F.split(all_nodes, sections, dim=0)


def dfs_edges_generator(graph, source, reverse=False):
    """Edge frontiers generator using depth-first-search (DFS).
@@ -176,10 +190,12 @@
    >>> list(dgl.dfs_edges_generator(g, 0))
    [tensor([0]), tensor([1]), tensor([3]), tensor([5]), tensor([4])]
    """
    assert isinstance(
        graph, DGLHeteroGraph
    ), "DGLGraph is deprecated, please use DGLHeteroGraph"
    assert (
        len(graph.canonical_etypes) == 1
    ), "dfs_edges_generator only supports homogeneous graphs"
    # Workaround before support for GPU graph
    gidx = graph._graph.copy_to(utils.to_dgl_context(F.cpu()))
    source = utils.toindex(source, dtype=graph._idtype_str)
@@ -189,13 +205,15 @@
    sections = utils.toindex(ret(1)).tonumpy().tolist()
    return F.split(all_edges, sections, dim=0)


def dfs_labeled_edges_generator(
    graph,
    source,
    reverse=False,
    has_reverse_edge=False,
    has_nontree_edge=False,
    return_labels=True,
):
    """Produce edges in a depth-first-search (DFS) labeled by type.

    There are three labels: FORWARD(0), REVERSE(1), NONTREE(2)
@@ -252,10 +270,12 @@
    (tensor([0]), tensor([1]), tensor([3]), tensor([5]), tensor([4]), tensor([2])),
    (tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([2]))
    """
    assert isinstance(
        graph, DGLHeteroGraph
    ), "DGLGraph is deprecated, please use DGLHeteroGraph"
    assert (
        len(graph.canonical_etypes) == 1
    ), "dfs_labeled_edges_generator only supports homogeneous graphs"
    # Workaround before support for GPU graph
    gidx = graph._graph.copy_to(utils.to_dgl_context(F.cpu()))
    source = utils.toindex(source, dtype=graph._idtype_str)
@@ -265,16 +285,20 @@
        reverse,
        has_reverse_edge,
        has_nontree_edge,
        return_labels,
    )
    all_edges = utils.toindex(ret(0), dtype=graph._idtype_str).tousertensor()
    # TODO(minjie): how to support directly creating python list
    if return_labels:
        all_labels = utils.toindex(ret(1)).tousertensor()
        sections = utils.toindex(ret(2)).tonumpy().tolist()
        return (
            F.split(all_edges, sections, dim=0),
            F.split(all_labels, sections, dim=0),
        )
    else:
        sections = utils.toindex(ret(1)).tonumpy().tolist()
        return F.split(all_edges, sections, dim=0)


_init_api("dgl.traversal")
"""Internal utilities.""" """Internal utilities."""
from .internal import *
from .data import *
from .checks import * from .checks import *
from .shared_mem import * from .data import *
from .filter import *
from .exception import * from .exception import *
from .filter import *
from .internal import *
from .pin_memory import * from .pin_memory import *
from .shared_mem import *
"""Checking and logging utilities.""" """Checking and logging utilities."""
# pylint: disable=invalid-name # pylint: disable=invalid-name
from __future__ import absolute_import, division from __future__ import absolute_import, division
from collections.abc import Mapping from collections.abc import Mapping
from ..base import DGLError
from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
from .._ffi.function import _init_api
from ..base import DGLError
def prepare_tensor(g, data, name): def prepare_tensor(g, data, name):
"""Convert the data to ID tensor and check its ID type and context. """Convert the data to ID tensor and check its ID type and context.
...@@ -31,27 +33,43 @@ def prepare_tensor(g, data, name): ...@@ -31,27 +33,43 @@ def prepare_tensor(g, data, name):
""" """
if F.is_tensor(data): if F.is_tensor(data):
if F.dtype(data) != g.idtype: if F.dtype(data) != g.idtype:
raise DGLError(f'Expect argument "{name}" to have data type {g.idtype}. ' raise DGLError(
f'But got {F.dtype(data)}.') f'Expect argument "{name}" to have data type {g.idtype}. '
f"But got {F.dtype(data)}."
)
if F.context(data) != g.device and not g.is_pinned(): if F.context(data) != g.device and not g.is_pinned():
raise DGLError(f'Expect argument "{name}" to have device {g.device}. ' raise DGLError(
f'But got {F.context(data)}.') f'Expect argument "{name}" to have device {g.device}. '
f"But got {F.context(data)}."
)
ret = data ret = data
else: else:
data = F.tensor(data) data = F.tensor(data)
if (not (F.ndim(data) > 0 and F.shape(data)[0] == 0) and # empty tensor if not (
F.dtype(data) not in (F.int32, F.int64)): F.ndim(data) > 0 and F.shape(data)[0] == 0
raise DGLError('Expect argument "{}" to have data type int32 or int64,' ) and F.dtype( # empty tensor
' but got {}.'.format(name, F.dtype(data))) data
) not in (
F.int32,
F.int64,
):
raise DGLError(
'Expect argument "{}" to have data type int32 or int64,'
" but got {}.".format(name, F.dtype(data))
)
ret = F.copy_to(F.astype(data, g.idtype), g.device) ret = F.copy_to(F.astype(data, g.idtype), g.device)
if F.ndim(ret) == 0: if F.ndim(ret) == 0:
ret = F.unsqueeze(ret, 0) ret = F.unsqueeze(ret, 0)
if F.ndim(ret) > 1: if F.ndim(ret) > 1:
raise DGLError('Expect a 1-D tensor for argument "{}". But got {}.'.format( raise DGLError(
name, ret)) 'Expect a 1-D tensor for argument "{}". But got {}.'.format(
name, ret
)
)
return ret return ret
def prepare_tensor_dict(g, data, name): def prepare_tensor_dict(g, data, name):
"""Convert a dictionary of data to a dictionary of ID tensors. """Convert a dictionary of data to a dictionary of ID tensors.
...@@ -70,8 +88,11 @@ def prepare_tensor_dict(g, data, name): ...@@ -70,8 +88,11 @@ def prepare_tensor_dict(g, data, name):
------- -------
dict[str, tensor] dict[str, tensor]
""" """
return {key : prepare_tensor(g, val, '{}["{}"]'.format(name, key)) return {
for key, val in data.items()} key: prepare_tensor(g, val, '{}["{}"]'.format(name, key))
for key, val in data.items()
}
def prepare_tensor_or_dict(g, data, name): def prepare_tensor_or_dict(g, data, name):
"""Convert data to either a tensor or a dictionary depending on input type. """Convert data to either a tensor or a dictionary depending on input type.
...@@ -89,10 +110,14 @@ def prepare_tensor_or_dict(g, data, name): ...@@ -89,10 +110,14 @@ def prepare_tensor_or_dict(g, data, name):
------- -------
tensor or dict[str, tensor] tensor or dict[str, tensor]
""" """
return prepare_tensor_dict(g, data, name) if isinstance(data, Mapping) \ return (
else prepare_tensor(g, data, name) prepare_tensor_dict(g, data, name)
if isinstance(data, Mapping)
else prepare_tensor(g, data, name)
)
def parse_edges_arg_to_eid(g, edges, etid, argname='edges'):
def parse_edges_arg_to_eid(g, edges, etid, argname="edges"):
"""Parse the :attr:`edges` argument and return an edge ID tensor. """Parse the :attr:`edges` argument and return an edge ID tensor.
The resulting edge ID tensor has the same ID type and device of :attr:`g`. The resulting edge ID tensor has the same ID type and device of :attr:`g`.
...@@ -115,13 +140,14 @@ def parse_edges_arg_to_eid(g, edges, etid, argname='edges'): ...@@ -115,13 +140,14 @@ def parse_edges_arg_to_eid(g, edges, etid, argname='edges'):
""" """
if isinstance(edges, tuple): if isinstance(edges, tuple):
u, v = edges u, v = edges
u = prepare_tensor(g, u, '{}[0]'.format(argname)) u = prepare_tensor(g, u, "{}[0]".format(argname))
v = prepare_tensor(g, v, '{}[1]'.format(argname)) v = prepare_tensor(g, v, "{}[1]".format(argname))
eid = g.edge_ids(u, v, etype=g.canonical_etypes[etid]) eid = g.edge_ids(u, v, etype=g.canonical_etypes[etid])
else: else:
eid = prepare_tensor(g, edges, argname) eid = prepare_tensor(g, edges, argname)
return eid return eid
def check_all_same_idtype(glist, name): def check_all_same_idtype(glist, name):
"""Check all the graphs have the same idtype.""" """Check all the graphs have the same idtype."""
if len(glist) == 0: if len(glist) == 0:
...@@ -129,8 +155,12 @@ def check_all_same_idtype(glist, name): ...@@ -129,8 +155,12 @@ def check_all_same_idtype(glist, name):
idtype = glist[0].idtype idtype = glist[0].idtype
for i, g in enumerate(glist): for i, g in enumerate(glist):
if g.idtype != idtype: if g.idtype != idtype:
raise DGLError('Expect {}[{}] to have {} type ID, but got {}.'.format( raise DGLError(
name, i, idtype, g.idtype)) "Expect {}[{}] to have {} type ID, but got {}.".format(
name, i, idtype, g.idtype
)
)
def check_device(data, device): def check_device(data, device):
"""Check if data is on the target device. """Check if data is on the target device.
...@@ -152,6 +182,7 @@ def check_device(data, device): ...@@ -152,6 +182,7 @@ def check_device(data, device):
return False return False
return True return True
def check_all_same_device(glist, name): def check_all_same_device(glist, name):
"""Check all the graphs have the same device.""" """Check all the graphs have the same device."""
if len(glist) == 0: if len(glist) == 0:
...@@ -159,8 +190,12 @@ def check_all_same_device(glist, name): ...@@ -159,8 +190,12 @@ def check_all_same_device(glist, name):
device = glist[0].device device = glist[0].device
for i, g in enumerate(glist): for i, g in enumerate(glist):
if g.device != device: if g.device != device:
raise DGLError('Expect {}[{}] to be on device {}, but got {}.'.format( raise DGLError(
name, i, device, g.device)) "Expect {}[{}] to be on device {}, but got {}.".format(
name, i, device, g.device
)
)
def check_all_same_schema(schemas, name): def check_all_same_schema(schemas, name):
"""Check the list of schemas are the same.""" """Check the list of schemas are the same."""
...@@ -170,9 +205,12 @@ def check_all_same_schema(schemas, name): ...@@ -170,9 +205,12 @@ def check_all_same_schema(schemas, name):
for i, schema in enumerate(schemas): for i, schema in enumerate(schemas):
if schema != schemas[0]: if schema != schemas[0]:
raise DGLError( raise DGLError(
'Expect all graphs to have the same schema on {}, ' "Expect all graphs to have the same schema on {}, "
'but graph {} got\n\t{}\nwhich is different from\n\t{}.'.format( "but graph {} got\n\t{}\nwhich is different from\n\t{}.".format(
name, i, schema, schemas[0])) name, i, schema, schemas[0]
)
)
def check_all_same_schema_for_keys(schemas, keys, name): def check_all_same_schema_for_keys(schemas, keys, name):
"""Check the list of schemas are the same on the given keys.""" """Check the list of schemas are the same on the given keys."""
...@@ -184,9 +222,9 @@ def check_all_same_schema_for_keys(schemas, keys, name): ...@@ -184,9 +222,9 @@ def check_all_same_schema_for_keys(schemas, keys, name):
for i, schema in enumerate(schemas): for i, schema in enumerate(schemas):
if not keys.issubset(schema.keys()): if not keys.issubset(schema.keys()):
raise DGLError( raise DGLError(
'Expect all graphs to have keys {} on {}, ' "Expect all graphs to have keys {} on {}, "
'but graph {} got keys {}.'.format( "but graph {} got keys {}.".format(keys, name, i, schema.keys())
keys, name, i, schema.keys())) )
if head is None: if head is None:
head = {k: schema[k] for k in keys} head = {k: schema[k] for k in keys}
...@@ -194,9 +232,12 @@ def check_all_same_schema_for_keys(schemas, keys, name): ...@@ -194,9 +232,12 @@ def check_all_same_schema_for_keys(schemas, keys, name):
target = {k: schema[k] for k in keys} target = {k: schema[k] for k in keys}
if target != head: if target != head:
raise DGLError( raise DGLError(
'Expect all graphs to have the same schema for keys {} on {}, ' "Expect all graphs to have the same schema for keys {} on {}, "
'but graph {} got \n\t{}\n which is different from\n\t{}.'.format( "but graph {} got \n\t{}\n which is different from\n\t{}.".format(
keys, name, i, target, head)) keys, name, i, target, head
)
)
def check_valid_idtype(idtype): def check_valid_idtype(idtype):
"""Check whether the value of the idtype argument is valid (int32/int64) """Check whether the value of the idtype argument is valid (int32/int64)
...@@ -207,8 +248,11 @@ def check_valid_idtype(idtype): ...@@ -207,8 +248,11 @@ def check_valid_idtype(idtype):
The framework object of a data type. The framework object of a data type.
""" """
if idtype not in [None, F.int32, F.int64]: if idtype not in [None, F.int32, F.int64]:
raise DGLError('Expect idtype to be a framework object of int32/int64, ' raise DGLError(
'got {}'.format(idtype)) "Expect idtype to be a framework object of int32/int64, "
"got {}".format(idtype)
)
def is_sorted_srcdst(src, dst, num_src=None, num_dst=None): def is_sorted_srcdst(src, dst, num_src=None, num_dst=None):
"""Checks whether an edge list is in ascending src-major order (e.g., first """Checks whether an edge list is in ascending src-major order (e.g., first
...@@ -234,9 +278,9 @@ def is_sorted_srcdst(src, dst, num_src=None, num_dst=None): ...@@ -234,9 +278,9 @@ def is_sorted_srcdst(src, dst, num_src=None, num_dst=None):
# for some versions of MXNET and TensorFlow, num_src and num_dst get # for some versions of MXNET and TensorFlow, num_src and num_dst get
# incorrectly marked as floats, so force them as integers here # incorrectly marked as floats, so force them as integers here
if num_src is None: if num_src is None:
num_src = int(F.as_scalar(F.max(src, dim=0)+1)) num_src = int(F.as_scalar(F.max(src, dim=0) + 1))
if num_dst is None: if num_dst is None:
num_dst = int(F.as_scalar(F.max(dst, dim=0)+1)) num_dst = int(F.as_scalar(F.max(dst, dim=0) + 1))
src = F.zerocopy_to_dgl_ndarray(src) src = F.zerocopy_to_dgl_ndarray(src)
dst = F.zerocopy_to_dgl_ndarray(dst) dst = F.zerocopy_to_dgl_ndarray(dst)
...@@ -247,4 +291,5 @@ def is_sorted_srcdst(src, dst, num_src=None, num_dst=None): ...@@ -247,4 +291,5 @@ def is_sorted_srcdst(src, dst, num_src=None, num_dst=None):
return row_sorted, col_sorted return row_sorted, col_sorted
_init_api("dgl.utils.checks") _init_api("dgl.utils.checks")
"""Data utilities.""" """Data utilities."""
from collections import namedtuple from collections import namedtuple
import scipy as sp
import networkx as nx import networkx as nx
import scipy as sp
from ..base import DGLError
from .. import backend as F from .. import backend as F
from ..base import DGLError
from . import checks from . import checks
def elist2tensor(elist, idtype): def elist2tensor(elist, idtype):
"""Function to convert an edge list to edge tensors. """Function to convert an edge list to edge tensors.
...@@ -31,6 +33,7 @@ def elist2tensor(elist, idtype): ...@@ -31,6 +33,7 @@ def elist2tensor(elist, idtype):
v = list(v) v = list(v)
return F.tensor(u, idtype), F.tensor(v, idtype) return F.tensor(u, idtype), F.tensor(v, idtype)
def scipy2tensor(spmat, idtype): def scipy2tensor(spmat, idtype):
"""Function to convert a scipy matrix to a sparse adjacency matrix tuple. """Function to convert a scipy matrix to a sparse adjacency matrix tuple.
...@@ -49,7 +52,7 @@ def scipy2tensor(spmat, idtype): ...@@ -49,7 +52,7 @@ def scipy2tensor(spmat, idtype):
A tuple containing the format as well as the list of tensors representing A tuple containing the format as well as the list of tensors representing
the sparse matrix. the sparse matrix.
""" """
if spmat.format in ['csr', 'csc']: if spmat.format in ["csr", "csc"]:
indptr = F.tensor(spmat.indptr, idtype) indptr = F.tensor(spmat.indptr, idtype)
indices = F.tensor(spmat.indices, idtype) indices = F.tensor(spmat.indices, idtype)
data = F.tensor([], idtype) data = F.tensor([], idtype)
...@@ -58,7 +61,8 @@ def scipy2tensor(spmat, idtype): ...@@ -58,7 +61,8 @@ def scipy2tensor(spmat, idtype):
spmat = spmat.tocoo() spmat = spmat.tocoo()
row = F.tensor(spmat.row, idtype) row = F.tensor(spmat.row, idtype)
col = F.tensor(spmat.col, idtype) col = F.tensor(spmat.col, idtype)
return SparseAdjTuple('coo', (row, col)) return SparseAdjTuple("coo", (row, col))
def networkx2tensor(nx_graph, idtype, edge_id_attr_name=None): def networkx2tensor(nx_graph, idtype, edge_id_attr_name=None):
"""Function to convert a networkx graph to edge tensors. """Function to convert a networkx graph to edge tensors.
...@@ -82,7 +86,7 @@ def networkx2tensor(nx_graph, idtype, edge_id_attr_name=None): ...@@ -82,7 +86,7 @@ def networkx2tensor(nx_graph, idtype, edge_id_attr_name=None):
nx_graph = nx_graph.to_directed() nx_graph = nx_graph.to_directed()
# Relabel nodes using consecutive integers # Relabel nodes using consecutive integers
nx_graph = nx.convert_node_labels_to_integers(nx_graph, ordering='sorted') nx_graph = nx.convert_node_labels_to_integers(nx_graph, ordering="sorted")
has_edge_id = edge_id_attr_name is not None has_edge_id = edge_id_attr_name is not None
if has_edge_id: if has_edge_id:
...@@ -92,8 +96,10 @@ def networkx2tensor(nx_graph, idtype, edge_id_attr_name=None): ...@@ -92,8 +96,10 @@ def networkx2tensor(nx_graph, idtype, edge_id_attr_name=None):
for u, v, attr in nx_graph.edges(data=True): for u, v, attr in nx_graph.edges(data=True):
eid = int(attr[edge_id_attr_name]) eid = int(attr[edge_id_attr_name])
if eid < 0 or eid >= nx_graph.number_of_edges(): if eid < 0 or eid >= nx_graph.number_of_edges():
raise DGLError('Expect edge IDs to be a non-negative integer smaller than {:d}, ' raise DGLError(
'got {:d}'.format(num_edges, eid)) "Expect edge IDs to be a non-negative integer smaller than {:d}, "
"got {:d}".format(num_edges, eid)
)
src[eid] = u src[eid] = u
dst[eid] = v dst[eid] = v
else: else:
...@@ -106,7 +112,9 @@ def networkx2tensor(nx_graph, idtype, edge_id_attr_name=None): ...@@ -106,7 +112,9 @@ def networkx2tensor(nx_graph, idtype, edge_id_attr_name=None):
dst = F.tensor(dst, idtype) dst = F.tensor(dst, idtype)
return src, dst return src, dst
SparseAdjTuple = namedtuple('SparseAdjTuple', ['format', 'arrays'])
SparseAdjTuple = namedtuple("SparseAdjTuple", ["format", "arrays"])
def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs): def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs):
"""Function to convert various types of data to edge tensors and infer """Function to convert various types of data to edge tensors and infer
...@@ -151,33 +159,42 @@ def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs): ...@@ -151,33 +159,42 @@ def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs):
if isinstance(data, tuple): if isinstance(data, tuple):
if not isinstance(data[0], str): if not isinstance(data[0], str):
# (row, col) format, convert to ('coo', (row, col)) # (row, col) format, convert to ('coo', (row, col))
data = ('coo', data) data = ("coo", data)
data = SparseAdjTuple(*data) data = SparseAdjTuple(*data)
if idtype is None and \ if idtype is None and not (
not (isinstance(data, SparseAdjTuple) and F.is_tensor(data.arrays[0])): isinstance(data, SparseAdjTuple) and F.is_tensor(data.arrays[0])
):
# preferred default idtype is int64 # preferred default idtype is int64
# if data is tensor and idtype is None, infer the idtype from tensor # if data is tensor and idtype is None, infer the idtype from tensor
idtype = F.int64 idtype = F.int64
checks.check_valid_idtype(idtype) checks.check_valid_idtype(idtype)
if isinstance(data, SparseAdjTuple) and (not all(F.is_tensor(a) for a in data.arrays)): if isinstance(data, SparseAdjTuple) and (
not all(F.is_tensor(a) for a in data.arrays)
):
# (Iterable, Iterable) type data, convert it to (Tensor, Tensor) # (Iterable, Iterable) type data, convert it to (Tensor, Tensor)
if len(data.arrays[0]) == 0: if len(data.arrays[0]) == 0:
# force idtype for empty list # force idtype for empty list
data = SparseAdjTuple(data.format, tuple(F.tensor(a, idtype) for a in data.arrays)) data = SparseAdjTuple(
data.format, tuple(F.tensor(a, idtype) for a in data.arrays)
)
else: else:
# convert the iterable to tensor and keep its native data type so we can check # convert the iterable to tensor and keep its native data type so we can check
# its validity later # its validity later
data = SparseAdjTuple(data.format, tuple(F.tensor(a) for a in data.arrays)) data = SparseAdjTuple(
data.format, tuple(F.tensor(a) for a in data.arrays)
)
if isinstance(data, SparseAdjTuple): if isinstance(data, SparseAdjTuple):
if idtype is not None: if idtype is not None:
data = SparseAdjTuple(data.format, tuple(F.astype(a, idtype) for a in data.arrays)) data = SparseAdjTuple(
data.format, tuple(F.astype(a, idtype) for a in data.arrays)
)
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite) num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)
elif isinstance(data, list): elif isinstance(data, list):
src, dst = elist2tensor(data, idtype) src, dst = elist2tensor(data, idtype)
data = SparseAdjTuple('coo', (src, dst)) data = SparseAdjTuple("coo", (src, dst))
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite) num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)
elif isinstance(data, sp.sparse.spmatrix): elif isinstance(data, sp.sparse.spmatrix):
# We can get scipy matrix's number of rows and columns easily. # We can get scipy matrix's number of rows and columns easily.
...@@ -186,23 +203,31 @@ def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs): ...@@ -186,23 +203,31 @@ def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs):
elif isinstance(data, nx.Graph): elif isinstance(data, nx.Graph):
# We can get networkx graph's number of sources and destinations easily. # We can get networkx graph's number of sources and destinations easily.
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite) num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)
edge_id_attr_name = kwargs.get('edge_id_attr_name', None) edge_id_attr_name = kwargs.get("edge_id_attr_name", None)
if bipartite: if bipartite:
top_map = kwargs.get('top_map') top_map = kwargs.get("top_map")
bottom_map = kwargs.get('bottom_map') bottom_map = kwargs.get("bottom_map")
src, dst = networkxbipartite2tensors( src, dst = networkxbipartite2tensors(
data, idtype, top_map=top_map, data,
bottom_map=bottom_map, edge_id_attr_name=edge_id_attr_name) idtype,
top_map=top_map,
bottom_map=bottom_map,
edge_id_attr_name=edge_id_attr_name,
)
else: else:
src, dst = networkx2tensor( src, dst = networkx2tensor(
data, idtype, edge_id_attr_name=edge_id_attr_name) data, idtype, edge_id_attr_name=edge_id_attr_name
data = SparseAdjTuple('coo', (src, dst)) )
data = SparseAdjTuple("coo", (src, dst))
else: else:
raise DGLError('Unsupported graph data type:', type(data)) raise DGLError("Unsupported graph data type:", type(data))
return data, num_src, num_dst return data, num_src, num_dst
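For reference, a minimal sketch of what this helper does with plain (row, col) input, assuming the PyTorch backend (torch stands in for the backend module F):

import torch

# A (row, col) pair is treated as ('coo', (row, col)); plain lists are
# converted with the preferred default idtype, int64.
src, dst = [0, 1, 2], [1, 2, 0]
src_t = torch.tensor(src, dtype=torch.int64)
dst_t = torch.tensor(dst, dtype=torch.int64)
# Node counts are then inferred from the largest ID seen (see infer_num_nodes).
num_src = int(src_t.max()) + 1 if len(src_t) > 0 else 0  # 3
num_dst = int(dst_t.max()) + 1 if len(dst_t) > 0 else 0  # 3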
def networkxbipartite2tensors(nx_graph, idtype, top_map, bottom_map, edge_id_attr_name=None):
def networkxbipartite2tensors(
nx_graph, idtype, top_map, bottom_map, edge_id_attr_name=None
):
"""Function to convert a networkx bipartite to edge tensors. """Function to convert a networkx bipartite to edge tensors.
Parameters Parameters
...@@ -234,15 +259,21 @@ def networkxbipartite2tensors(nx_graph, idtype, top_map, bottom_map, edge_id_att ...@@ -234,15 +259,21 @@ def networkxbipartite2tensors(nx_graph, idtype, top_map, bottom_map, edge_id_att
dst = [0] * num_edges dst = [0] * num_edges
for u, v, attr in nx_graph.edges(data=True): for u, v, attr in nx_graph.edges(data=True):
if u not in top_map: if u not in top_map:
raise DGLError('Expect the node {} to have attribute bipartite=0 ' raise DGLError(
'with edge {}'.format(u, (u, v))) "Expect the node {} to have attribute bipartite=0 "
"with edge {}".format(u, (u, v))
)
if v not in bottom_map: if v not in bottom_map:
raise DGLError('Expect the node {} to have attribute bipartite=1 ' raise DGLError(
'with edge {}'.format(v, (u, v))) "Expect the node {} to have attribute bipartite=1 "
"with edge {}".format(v, (u, v))
)
eid = int(attr[edge_id_attr_name]) eid = int(attr[edge_id_attr_name])
if eid < 0 or eid >= nx_graph.number_of_edges(): if eid < 0 or eid >= nx_graph.number_of_edges():
raise DGLError('Expect edge IDs to be non-negative integers smaller than {:d}, ' raise DGLError(
'got {:d}'.format(num_edges, eid)) "Expect edge IDs to be non-negative integers smaller than {:d}, "
"got {:d}".format(num_edges, eid)
)
src[eid] = top_map[u] src[eid] = top_map[u]
dst[eid] = bottom_map[v] dst[eid] = bottom_map[v]
else: else:
...@@ -251,17 +282,22 @@ def networkxbipartite2tensors(nx_graph, idtype, top_map, bottom_map, edge_id_att ...@@ -251,17 +282,22 @@ def networkxbipartite2tensors(nx_graph, idtype, top_map, bottom_map, edge_id_att
for e in nx_graph.edges: for e in nx_graph.edges:
u, v = e[0], e[1] u, v = e[0], e[1]
if u not in top_map: if u not in top_map:
raise DGLError('Expect the node {} to have attribute bipartite=0 ' raise DGLError(
'with edge {}'.format(u, (u, v))) "Expect the node {} to have attribute bipartite=0 "
"with edge {}".format(u, (u, v))
)
if v not in bottom_map: if v not in bottom_map:
raise DGLError('Expect the node {} to have attribute bipartite=1 ' raise DGLError(
'with edge {}'.format(v, (u, v))) "Expect the node {} to have attribute bipartite=1 "
"with edge {}".format(v, (u, v))
)
src.append(top_map[u]) src.append(top_map[u])
dst.append(bottom_map[v]) dst.append(bottom_map[v])
src = F.tensor(src, dtype=idtype) src = F.tensor(src, dtype=idtype)
dst = F.tensor(dst, dtype=idtype) dst = F.tensor(dst, dtype=idtype)
return src, dst return src, dst
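A hedged usage sketch for the converter above: the graph must mark each node with the bipartite attribute the checks expect, and the two maps translate node names to consecutive integer IDs.

import networkx as nx

nx_graph = nx.Graph()
nx_graph.add_nodes_from(["u0", "u1"], bipartite=0)  # source side
nx_graph.add_nodes_from(["v0"], bipartite=1)        # destination side
nx_graph.add_edges_from([("u0", "v0"), ("u1", "v0")])
top_map = {"u0": 0, "u1": 1}   # source name -> integer ID
bottom_map = {"v0": 0}         # destination name -> integer ID
# src, dst = networkxbipartite2tensors(nx_graph, idtype, top_map, bottom_map)
# -> src = [0, 1], dst = [0, 0] as tensors of the requested idtype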
def infer_num_nodes(data, bipartite=False): def infer_num_nodes(data, bipartite=False):
"""Function for inferring the number of nodes. """Function for inferring the number of nodes.
...@@ -292,25 +328,35 @@ def infer_num_nodes(data, bipartite=False): ...@@ -292,25 +328,35 @@ def infer_num_nodes(data, bipartite=False):
""" """
if isinstance(data, tuple) and len(data) == 2: if isinstance(data, tuple) and len(data) == 2:
if not isinstance(data[0], str): if not isinstance(data[0], str):
raise TypeError('Expected sparse format as a str, but got %s' % type(data[0])) raise TypeError(
"Expected sparse format as a str, but got %s" % type(data[0])
)
if data[0] == 'coo': if data[0] == "coo":
# ('coo', (src, dst)) format # ('coo', (src, dst)) format
u, v = data[1] u, v = data[1]
nsrc = F.as_scalar(F.max(u, dim=0)) + 1 if len(u) > 0 else 0 nsrc = F.as_scalar(F.max(u, dim=0)) + 1 if len(u) > 0 else 0
ndst = F.as_scalar(F.max(v, dim=0)) + 1 if len(v) > 0 else 0 ndst = F.as_scalar(F.max(v, dim=0)) + 1 if len(v) > 0 else 0
elif data[0] == 'csr': elif data[0] == "csr":
# ('csr', (indptr, indices, eids)) format # ('csr', (indptr, indices, eids)) format
indptr, indices, _ = data[1] indptr, indices, _ = data[1]
nsrc = F.shape(indptr)[0] - 1 nsrc = F.shape(indptr)[0] - 1
ndst = F.as_scalar(F.max(indices, dim=0)) + 1 if len(indices) > 0 else 0 ndst = (
elif data[0] == 'csc': F.as_scalar(F.max(indices, dim=0)) + 1
if len(indices) > 0
else 0
)
elif data[0] == "csc":
# ('csc', (indptr, indices, eids)) format # ('csc', (indptr, indices, eids)) format
indptr, indices, _ = data[1] indptr, indices, _ = data[1]
ndst = F.shape(indptr)[0] - 1 ndst = F.shape(indptr)[0] - 1
nsrc = F.as_scalar(F.max(indices, dim=0)) + 1 if len(indices) > 0 else 0 nsrc = (
F.as_scalar(F.max(indices, dim=0)) + 1
if len(indices) > 0
else 0
)
else: else:
raise ValueError('unknown format %s' % data[0]) raise ValueError("unknown format %s" % data[0])
elif isinstance(data, sp.sparse.spmatrix): elif isinstance(data, sp.sparse.spmatrix):
nsrc, ndst = data.shape[0], data.shape[1] nsrc, ndst = data.shape[0], data.shape[1]
elif isinstance(data, nx.Graph): elif isinstance(data, nx.Graph):
...@@ -319,7 +365,9 @@ def infer_num_nodes(data, bipartite=False): ...@@ -319,7 +365,9 @@ def infer_num_nodes(data, bipartite=False):
elif not bipartite: elif not bipartite:
nsrc = ndst = data.number_of_nodes() nsrc = ndst = data.number_of_nodes()
else: else:
nsrc = len({n for n, d in data.nodes(data=True) if d['bipartite'] == 0}) nsrc = len(
{n for n, d in data.nodes(data=True) if d["bipartite"] == 0}
)
ndst = data.number_of_nodes() - nsrc ndst = data.number_of_nodes() - nsrc
else: else:
return None return None
...@@ -327,6 +375,7 @@ def infer_num_nodes(data, bipartite=False): ...@@ -327,6 +375,7 @@ def infer_num_nodes(data, bipartite=False):
nsrc = ndst = max(nsrc, ndst) nsrc = ndst = max(nsrc, ndst)
return nsrc, ndst return nsrc, ndst
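The 'coo' branch in worked form (PyTorch standing in for the backend): each count is the largest ID plus one, and non-bipartite input unifies the two counts afterwards.

import torch

u = torch.tensor([0, 5, 2])
v = torch.tensor([1, 1, 3])
nsrc = int(u.max()) + 1 if len(u) > 0 else 0  # 6
ndst = int(v.max()) + 1 if len(v) > 0 else 0  # 4
# with bipartite=False the final step yields nsrc = ndst = max(6, 4) = 6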
def to_device(data, device): def to_device(data, device):
"""Transfer the tensor or dictionary of tensors to the given device. """Transfer the tensor or dictionary of tensors to the given device.
......
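The body of to_device is collapsed above; a minimal sketch of the transfer it describes, assuming PyTorch (the helper name _to_device_sketch is illustrative, not DGL's):

import torch

def _to_device_sketch(data, device):
    # Dictionaries are transferred value by value; bare tensors move directly.
    if isinstance(data, dict):
        return {k: v.to(device) for k, v in data.items()}
    return data.to(device)

batch = _to_device_sketch({"feat": torch.randn(4, 8)}, torch.device("cpu"))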
...@@ -16,14 +16,17 @@ import traceback ...@@ -16,14 +16,17 @@ import traceback
# and the frame (which holds reference to all the object in its temporary scope) # and the frame (which holds reference to all the object in its temporary scope)
# holding reference the traceback. # holding reference the traceback.
class KeyErrorMessage(str): class KeyErrorMessage(str):
r"""str subclass that returns itself in repr""" r"""str subclass that returns itself in repr"""
def __repr__(self): # pylint: disable=invalid-repr-returned
def __repr__(self): # pylint: disable=invalid-repr-returned
return self return self
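Why the override matters: KeyError calls repr() on its argument, which would quote the message and escape newlines; returning the string itself keeps a re-raised multi-line traceback readable.

msg = KeyErrorMessage("Caught KeyError in worker.\nOriginal Traceback: ...")
assert repr(msg) == str(msg)  # no quoting, embedded newlines survive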
class ExceptionWrapper(object): class ExceptionWrapper(object):
r"""Wraps an exception plus traceback to communicate across threads""" r"""Wraps an exception plus traceback to communicate across threads"""
def __init__(self, exc_info=None, where="in background"): def __init__(self, exc_info=None, where="in background"):
# It is important that we don't store exc_info, see # It is important that we don't store exc_info, see
# NOTE [ Python Traceback Reference Cycle Problem ] # NOTE [ Python Traceback Reference Cycle Problem ]
...@@ -38,7 +41,8 @@ class ExceptionWrapper(object): ...@@ -38,7 +41,8 @@ class ExceptionWrapper(object):
# Format a message such as: "Caught ValueError in DataLoader worker # Format a message such as: "Caught ValueError in DataLoader worker
# process 2. Original Traceback:", followed by the traceback. # process 2. Original Traceback:", followed by the traceback.
msg = "Caught {} {}.\nOriginal {}".format( msg = "Caught {} {}.\nOriginal {}".format(
self.exc_type.__name__, self.where, self.exc_msg) self.exc_type.__name__, self.where, self.exc_msg
)
if self.exc_type == KeyError: if self.exc_type == KeyError:
# KeyError calls repr() on its argument (usually a dict key). This # KeyError calls repr() on its argument (usually a dict key). This
# makes stack traces unreadable. It will not be changed in Python # makes stack traces unreadable. It will not be changed in Python
......
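A hedged round-trip sketch for ExceptionWrapper (the rest of the class is elided above; it is assumed to keep the reraise() method of the PyTorch utility it mirrors): capture sys.exc_info() in a worker, send the wrapper over a queue, and re-raise on the consumer side.

import sys

def worker_loop(result_queue):
    try:
        raise ValueError("bad sample")  # stand-in for real work
    except Exception:  # pylint: disable=broad-except
        result_queue.put(
            ExceptionWrapper(exc_info=sys.exc_info(), where="in worker process 2")
        )

# consumer side:
# item = result_queue.get()
# if isinstance(item, ExceptionWrapper):
#     item.reraise()  # raises ValueError, original traceback in the message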
"""Utilities for finding overlap or missing items in arrays.""" """Utilities for finding overlap or missing items in arrays."""
from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
from .._ffi.function import _init_api
class Filter(object): class Filter(object):
...@@ -20,6 +19,7 @@ class Filter(object): ...@@ -20,6 +19,7 @@ class Filter(object):
>>> f.find_excluded_indices(th.tensor([0,2,8,9], device=th.device('cuda'))) >>> f.find_excluded_indices(th.tensor([0,2,8,9], device=th.device('cuda')))
tensor([0,2], device='cuda') tensor([0,2], device='cuda')
""" """
def __init__(self, ids): def __init__(self, ids):
"""Create a new filter from a given set of IDs. This currently is only """Create a new filter from a given set of IDs. This currently is only
implemented for the GPU. implemented for the GPU.
...@@ -30,7 +30,8 @@ class Filter(object): ...@@ -30,7 +30,8 @@ class Filter(object):
The unique set of IDs to keep in the filter. The unique set of IDs to keep in the filter.
""" """
self._filter = _CAPI_DGLFilterCreateFromSet( self._filter = _CAPI_DGLFilterCreateFromSet(
F.zerocopy_to_dgl_ndarray(ids)) F.zerocopy_to_dgl_ndarray(ids)
)
def find_included_indices(self, test): def find_included_indices(self, test):
"""Find the index of the IDs in `test` that are in this filter. """Find the index of the IDs in `test` that are in this filter.
...@@ -45,9 +46,11 @@ class Filter(object): ...@@ -45,9 +46,11 @@ class Filter(object):
IdArray IdArray
The index of IDs in `test` that are also in this filter. The index of IDs in `test` that are also in this filter.
""" """
return F.zerocopy_from_dgl_ndarray( \ return F.zerocopy_from_dgl_ndarray(
_CAPI_DGLFilterFindIncludedIndices( \ _CAPI_DGLFilterFindIncludedIndices(
self._filter, F.zerocopy_to_dgl_ndarray(test))) self._filter, F.zerocopy_to_dgl_ndarray(test)
)
)
def find_excluded_indices(self, test): def find_excluded_indices(self, test):
"""Find the index of the IDs in `test` that are not in this filter. """Find the index of the IDs in `test` that are not in this filter.
...@@ -62,8 +65,11 @@ class Filter(object): ...@@ -62,8 +65,11 @@ class Filter(object):
IdArray IdArray
The index of IDs in `test` that are not in this filter. The index of IDs in `test` that are not in this filter.
""" """
return F.zerocopy_from_dgl_ndarray( \ return F.zerocopy_from_dgl_ndarray(
_CAPI_DGLFilterFindExcludedIndices( \ _CAPI_DGLFilterFindExcludedIndices(
self._filter, F.zerocopy_to_dgl_ndarray(test))) self._filter, F.zerocopy_to_dgl_ndarray(test)
)
)
_init_api("dgl.utils.filter") _init_api("dgl.utils.filter")
"""Utility functions related to pinned memory tensors.""" """Utility functions related to pinned memory tensors."""
from ..base import DGLError
from .. import backend as F from .. import backend as F
from .._ffi.function import _init_api from .._ffi.function import _init_api
from ..base import DGLError
def pin_memory_inplace(tensor): def pin_memory_inplace(tensor):
"""Register the tensor into pinned memory in-place (i.e. without copying). """Register the tensor into pinned memory in-place (i.e. without copying).
...@@ -19,9 +20,11 @@ def pin_memory_inplace(tensor): ...@@ -19,9 +20,11 @@ def pin_memory_inplace(tensor):
The dgl.ndarray object that holds the pinning status and shares the same The dgl.ndarray object that holds the pinning status and shares the same
underlying data with the tensor. underlying data with the tensor.
""" """
if F.backend_name in ['mxnet', 'tensorflow']: if F.backend_name in ["mxnet", "tensorflow"]:
raise DGLError("The {} backend does not support pinning " \ raise DGLError(
"tensors in-place.".format(F.backend_name)) "The {} backend does not support pinning "
"tensors in-place.".format(F.backend_name)
)
# needs to be writable to allow in-place modification # needs to be writable to allow in-place modification
try: try:
...@@ -31,6 +34,7 @@ def pin_memory_inplace(tensor): ...@@ -31,6 +34,7 @@ def pin_memory_inplace(tensor):
except Exception as e: except Exception as e:
raise DGLError("Failed to pin memory in-place due to: {}".format(e)) raise DGLError("Failed to pin memory in-place due to: {}".format(e))
def gather_pinned_tensor_rows(tensor, rows): def gather_pinned_tensor_rows(tensor, rows):
"""Directly gather rows from a CPU tensor given an indices array on CUDA devices, """Directly gather rows from a CPU tensor given an indices array on CUDA devices,
and returns the result on the same CUDA device without copying. and returns the result on the same CUDA device without copying.
...@@ -47,7 +51,10 @@ def gather_pinned_tensor_rows(tensor, rows): ...@@ -47,7 +51,10 @@ def gather_pinned_tensor_rows(tensor, rows):
Tensor Tensor
The result with the same device as :attr:`rows`. The result with the same device as :attr:`rows`.
""" """
return F.from_dgl_nd(_CAPI_DGLIndexSelectCPUFromGPU(F.to_dgl_nd(tensor), F.to_dgl_nd(rows))) return F.from_dgl_nd(
_CAPI_DGLIndexSelectCPUFromGPU(F.to_dgl_nd(tensor), F.to_dgl_nd(rows))
)
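The typical zero-copy pattern built on top (same assumptions as above): features stay pinned on the CPU, indices live on the GPU, and the gathered rows come back on the indices' device.

import torch

feats = torch.randn(1000, 128)
nd_handle = pin_memory_inplace(feats)         # see pin_memory_inplace above
rows = torch.tensor([0, 42, 7], device="cuda")
out = gather_pinned_tensor_rows(feats, rows)  # out.device == rows.device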
def scatter_pinned_tensor_rows(dest, rows, source): def scatter_pinned_tensor_rows(dest, rows, source):
"""Directly scatter rows from a GPU tensor given an indices array on CUDA devices, """Directly scatter rows from a GPU tensor given an indices array on CUDA devices,
...@@ -62,8 +69,9 @@ def scatter_pinned_tensor_rows(dest, rows, source): ...@@ -62,8 +69,9 @@ def scatter_pinned_tensor_rows(dest, rows, source):
source : Tensor source : Tensor
The tensor on the GPU to scatter rows from. The tensor on the GPU to scatter rows from.
""" """
_CAPI_DGLIndexScatterGPUToCPU(F.to_dgl_nd(dest), F.to_dgl_nd(rows), _CAPI_DGLIndexScatterGPUToCPU(
F.to_dgl_nd(source)) F.to_dgl_nd(dest), F.to_dgl_nd(rows), F.to_dgl_nd(source)
)
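And the reverse direction (same assumptions): rows of a GPU tensor are written back into the pinned CPU tensor in place.

import torch

dest = torch.zeros(1000, 128)
nd_handle = pin_memory_inplace(dest)
rows = torch.tensor([3, 11], device="cuda")
source = torch.randn(2, 128, device="cuda")
scatter_pinned_tensor_rows(dest, rows, source)  # dest rows 3 and 11 updated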
_init_api("dgl.ndarray.uvm", __name__) _init_api("dgl.ndarray.uvm", __name__)
"""Shared memory utilities. """Shared memory utilities.
Kept for compatibility with older code that uses the ``dgl.utils.shared_mem`` namespace; the Kept for compatibility with older code that uses the ``dgl.utils.shared_mem`` namespace; the
content has been moved to the ``dgl.ndarray`` module. content has been moved to the ``dgl.ndarray`` module.
""" """
from ..ndarray import get_shared_mem_array, create_shared_mem_array # pylint: disable=unused-import from ..ndarray import ( # pylint: disable=unused-import
create_shared_mem_array,
get_shared_mem_array,
)