"vscode:/vscode.git/clone" did not exist on "756fdd8e909dfedd4b1cfbe2ad3860b1df6508ec"
Unverified Commit 701b4fcc authored by Quan (Andy) Gan, committed by GitHub

[Sampling] New sampling pipeline plus asynchronous prefetching (#3665)

* initial update

* more

* more

* multi-gpu example

* cluster gcn, finalize homogeneous

* more explanation

* fix

* bunch of fixes

* fix

* RGAT example and more fixes

* shadow-gnn sampler and some changes in unit test

* fix

* wth

* more fixes

* remove shadow+node/edge dataloader tests for possible ux changes

* lints

* add legacy dataloading import just in case

* fix

* update pylint for f-strings

* fix

* lint

* lint

* lint again

* cherry-picking commit fa9f494

* oops

* fix

* add sample_neighbors in dist_graph

* fix

* lint

* fix

* fix

* fix

* fix tutorial

* fix

* fix

* fix

* fix warning

* remove debug

* add get_foo_storage apis

* lint
parent 5152a879
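
For orientation before the diff: a minimal sketch of the reworked dataloading pipeline this commit introduces, as exercised by the updated tests below. The graph, feature shape, fanouts, and batch size are illustrative; asynchronous feature prefetching happens inside the DataLoader.

    import dgl
    import torch

    g = dgl.rand_graph(1000, 5000)                    # toy homogeneous graph
    g.ndata['feat'] = torch.randn(g.num_nodes(), 16)

    # Same neighbor sampler as before; the unified DataLoader drives it and
    # prefetches feature slices asynchronously while the model computes.
    sampler = dgl.dataloading.MultiLayerNeighborSampler([5, 10])
    dataloader = dgl.dataloading.DataLoader(
        g, torch.arange(g.num_nodes()), sampler,
        batch_size=64, shuffle=True, drop_last=False, num_workers=0)

    for input_nodes, output_nodes, blocks in dataloader:
        pass  # `blocks` holds the sampled message flow graphs, one per layer
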
@@ -10,3 +10,4 @@ from .pinsage import *
from .neighbor import *
from .node2vec_randomwalk import *
from .negative import *
from . import utils
@@ -3,6 +3,8 @@
from numpy.polynomial import polynomial
from .._ffi.function import _init_api
from .. import backend as F
from .. import utils
from ..heterograph import DGLHeteroGraph
__all__ = [
'global_uniform_negative_sampling']
@@ -99,5 +101,7 @@ def global_uniform_negative_sampling(
src, dst = _CAPI_DGLGlobalUniformNegativeSampling(
g._graph, etype_id, num_samples, 3, exclude_self_loops, replace, redundancy)
return F.from_dgl_nd(src), F.from_dgl_nd(dst)
DGLHeteroGraph.global_uniform_negative_sampling = utils.alias_func(
global_uniform_negative_sampling)
_init_api('dgl.sampling.negative', __name__)
@@ -6,6 +6,7 @@ from ..base import DGLError, EID
from ..heterograph import DGLHeteroGraph
from .. import ndarray as nd
from .. import utils
from .utils import EidExcluder
__all__ = [
'sample_etype_neighbors',
@@ -15,7 +16,7 @@ __all__ = [
def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=None,
replace=False, copy_ndata=True, copy_edata=True, etype_sorted=False,
_dist_training=False):
_dist_training=False, output_device=None):
"""Sample neighboring edges of the given nodes and return the induced subgraph.
For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
@@ -77,6 +78,8 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
A hint telling whether the etypes are already sorted.
(Default: False)
output_device : Framework-specific device context object, optional
The output device. Default is the same as the input graph.
Returns
-------
@@ -142,10 +145,13 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
for i, etype in enumerate(ret.canonical_etypes):
ret.edges[etype].data[EID] = induced_edges[i]
return ret
return ret if output_device is None else ret.to(output_device)
DGLHeteroGraph.sample_etype_neighbors = utils.alias_func(sample_etype_neighbors)
def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False,
copy_ndata=True, copy_edata=True, _dist_training=False, exclude_edges=None):
copy_ndata=True, copy_edata=True, _dist_training=False,
exclude_edges=None, output_device=None):
"""Sample neighboring edges of the given nodes and return the induced subgraph.
For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
@@ -210,12 +216,13 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False,
Internal argument. Do not use.
(Default: False)
output_device : Framework-specific device context object, optional
The output device. Default is the same as the input graph.
Returns
-------
DGLGraph
A sampled subgraph containing only the sampled neighboring edges, with the
same device as the input graph.
A sampled subgraph containing only the sampled neighboring edges.
Notes
-----
@@ -280,6 +287,22 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False,
tensor([False, False, False])
"""
if g.device == F.cpu():
frontier = _sample_neighbors(
g, nodes, fanout, edge_dir=edge_dir, prob=prob, replace=replace,
copy_ndata=copy_ndata, copy_edata=copy_edata, exclude_edges=exclude_edges)
else:
frontier = _sample_neighbors(
g, nodes, fanout, edge_dir=edge_dir, prob=prob, replace=replace,
copy_ndata=copy_ndata, copy_edata=copy_edata)
if exclude_edges is not None:
eid_excluder = EidExcluder(exclude_edges)
frontier = eid_excluder(frontier)
return frontier if output_device is None else frontier.to(output_device)
def _sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False,
copy_ndata=True, copy_edata=True, _dist_training=False,
exclude_edges=None):
if not isinstance(nodes, dict):
if len(g.ntypes) > 1:
raise DGLError("Must specify node type when the graph is not homogeneous.")
@@ -357,9 +380,11 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False,
return ret
DGLHeteroGraph.sample_neighbors = utils.alias_func(sample_neighbors)
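
A hedged usage sketch for the new ``output_device`` argument threaded through this file: sampling still runs where the graph lives, and the result is moved afterwards. Assumes a CUDA device is available; the device string is illustrative.

    import dgl
    import torch

    g = dgl.rand_graph(100, 500)        # CPU graph
    seeds = torch.arange(10)

    # Sample on the CPU but receive the frontier already on the GPU.
    frontier = dgl.sampling.sample_neighbors(
        g, seeds, 5, output_device=torch.device('cuda:0'))
    assert frontier.device == torch.device('cuda:0')
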
def sample_neighbors_biased(g, nodes, fanout, bias, edge_dir='in',
tag_offset_name='_TAG_OFFSET', replace=False,
copy_ndata=True, copy_edata=True):
copy_ndata=True, copy_edata=True, output_device=None):
r"""Sample neighboring edges of the given nodes and return the induced subgraph, where each
neighbor's probability to be picked is determined by its tag.
@@ -439,6 +464,8 @@ def sample_neighbors_biased(g, nodes, fanout, bias, edge_dir='in',
edge features.
(Default: True)
output_device : Framework-specific device context object, optional
The output device. Default is the same as the input graph.
Returns
-------
@@ -523,11 +550,12 @@ def sample_neighbors_biased(g, nodes, fanout, bias, edge_dir='in',
utils.set_new_frames(ret, edge_frames=edge_frames)
ret.edata[EID] = induced_edges[0]
return ret
return ret if output_device is None else ret.to(output_device)
DGLHeteroGraph.sample_neighbors_biased = utils.alias_func(sample_neighbors_biased)
def select_topk(g, k, weight, nodes=None, edge_dir='in', ascending=False,
copy_ndata=True, copy_edata=True):
copy_ndata=True, copy_edata=True, output_device=None):
"""Select the neighboring edges with k-largest (or k-smallest) weights of the given
nodes and return the induced subgraph.
@@ -581,6 +609,8 @@ def select_topk(g, k, weight, nodes=None, edge_dir='in', ascending=False,
edge features.
(Default: True)
output_device : Framework-specific device context object, optional
The output device. Default is the same as the input graph.
Returns
-------
@@ -655,6 +685,8 @@ def select_topk(g, k, weight, nodes=None, edge_dir='in', ascending=False,
if copy_edata:
edge_frames = utils.extract_edge_subframes(g, induced_edges)
utils.set_new_frames(ret, edge_frames=edge_frames)
return ret
return ret if output_device is None else ret.to(output_device)
DGLHeteroGraph.select_topk = utils.alias_func(select_topk)
_init_api('dgl.sampling.neighbor', __name__)
"""Sampling utilities"""
from collections.abc import Mapping
import numpy as np
from ..utils import recursive_apply, recursive_apply_pair
from ..base import EID
from .. import backend as F
from .. import transform, utils
def _locate_eids_to_exclude(frontier_parent_eids, exclude_eids):
"""Find the edges whose IDs in parent graph appeared in exclude_eids.
Note that both arguments are numpy arrays or numpy dicts.
"""
func = lambda x, y: np.isin(x, y).nonzero()[0]
result = recursive_apply_pair(frontier_parent_eids, exclude_eids, func)
return recursive_apply(result, F.zerocopy_from_numpy)
class EidExcluder(object):
"""Class that finds the edges whose IDs in parent graph appeared in exclude_eids.
The edge IDs can be both CPU and GPU tensors.
"""
def __init__(self, exclude_eids):
device = None
if isinstance(exclude_eids, Mapping):
for _, v in exclude_eids.items():
if device is None:
device = F.context(v)
break
else:
device = F.context(exclude_eids)
self._exclude_eids = None
self._filter = None
if device == F.cpu():
# TODO(nv-dlasalle): Once Filter is implemented for the CPU, we
# should just use that regardless of the device.
self._exclude_eids = (
recursive_apply(exclude_eids, F.zerocopy_to_numpy)
if exclude_eids is not None else None)
else:
self._filter = recursive_apply(exclude_eids, utils.Filter)
def _find_indices(self, parent_eids):
""" Find the set of edge indices to remove.
"""
if self._exclude_eids is not None:
parent_eids_np = recursive_apply(parent_eids, F.zerocopy_to_numpy)
return _locate_eids_to_exclude(parent_eids_np, self._exclude_eids)
else:
assert self._filter is not None
func = lambda x, y: x.find_included_indices(y)
return recursive_apply_pair(self._filter, parent_eids, func)
def __call__(self, frontier):
parent_eids = frontier.edata[EID]
located_eids = self._find_indices(parent_eids)
if not isinstance(located_eids, Mapping):
# (BarclayII) If frontier already has an EID field and located_eids is empty,
# the returned graph will keep EID intact. Otherwise, EID will change
# to the mapping from the new graph to the old frontier.
# So we need to test if located_eids is empty, and do the remapping ourselves.
if len(located_eids) > 0:
frontier = transform.remove_edges(
frontier, located_eids, store_ids=True)
frontier.edata[EID] = F.gather_row(parent_eids, frontier.edata[EID])
else:
# (BarclayII) remove_edges only accepts removing one type of edges,
# so I need to keep track of the edge IDs left one by one.
new_eids = parent_eids.copy()
for k, v in located_eids.items():
if len(v) > 0:
frontier = transform.remove_edges(
frontier, v, etype=k, store_ids=True)
new_eids[k] = F.gather_row(parent_eids[k], frontier.edges[k].data[EID])
frontier.edata[EID] = new_eids
return frontier
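
A small usage sketch mirroring how ``sample_neighbors`` above applies the excluder; the import path follows this file's location (``dgl/sampling/utils.py``) and the edge IDs are illustrative.

    import dgl
    import torch
    from dgl.sampling.utils import EidExcluder

    g = dgl.rand_graph(100, 500)
    frontier = dgl.sampling.sample_neighbors(g, torch.arange(10), 5)

    # Drop sampled edges whose parent-graph IDs appear in `exclude`.
    exclude = torch.tensor([0, 1, 2])
    frontier = EidExcluder(exclude)(frontier)
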
"""Feature storage classes for DataLoading"""
from .. import backend as F
from .base import *
from .numpy import *
if F.get_preferred_backend() == 'pytorch':
from .pytorch_tensor import *
else:
from .tensor import *
"""Base classes and functionalities for feature storages."""
import threading
STORAGE_WRAPPERS = {}
def register_storage_wrapper(type_):
"""Decorator that associates a type to a ``FeatureStorage`` object.
"""
def deco(cls):
STORAGE_WRAPPERS[type_] = cls
return cls
return deco
def wrap_storage(storage):
"""Wrap an object into a FeatureStorage as specified by the ``register_storage_wrapper``
decorators.
"""
for type_, storage_cls in STORAGE_WRAPPERS.items():
if isinstance(storage, type_):
return storage_cls(storage)
assert isinstance(storage, FeatureStorage), (
"The frame column must be a tensor or a FeatureStorage object, got {}"
.format(type(storage)))
return storage
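
For illustration, a custom backing store can be registered the same way as the built-in wrappers in the following files. Everything here except ``register_storage_wrapper``, ``wrap_storage``, and ``FeatureStorage`` is hypothetical, and the ``dgl.storages`` module path is assumed from this diff.

    import numpy as np
    from dgl.storages.base import (
        FeatureStorage, register_storage_wrapper, wrap_storage)

    @register_storage_wrapper(list)      # hypothetical: wrap plain Python lists
    class ListStorage(FeatureStorage):
        def __init__(self, data):
            self.data = np.asarray(data)

        def fetch(self, indices, device, pin_memory=False):
            # Synchronous fetch; `device` and `pin_memory` ignored for brevity.
            return self.data[np.asarray(indices)]

    storage = wrap_storage([1.0, 2.0, 3.0])   # dispatches to ListStorage
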
class _FuncWrapper(object):
def __init__(self, func):
self.func = func
def __call__(self, buf, *args):
buf[0] = self.func(*args)
class ThreadedFuture(object):
"""Wraps a function into a future asynchronously executed by a Python
``threading.Thread``. The function begins executing upon instantiation of
this object.
"""
def __init__(self, target, args):
self.buf = [None]
thread = threading.Thread(
target=_FuncWrapper(target),
args=[self.buf] + list(args),
daemon=True)
thread.start()
self.thread = thread
def wait(self):
"""Blocks the current thread until the result becomes available and returns it."""
self.thread.join()
return self.buf[0]
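
Usage follows the usual future pattern: the work starts at construction, and ``wait()`` blocks only when the result is needed. A minimal sketch; the module path is assumed from this diff.

    import time
    from dgl.storages.base import ThreadedFuture

    def slow_load(x):
        time.sleep(0.1)         # stand-in for disk or network I/O
        return x * 2

    fut = ThreadedFuture(target=slow_load, args=(21,))
    # ... overlap other work here ...
    assert fut.wait() == 42     # joins the thread, returns the buffered result
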
class FeatureStorage(object):
"""Feature storage object which should support a fetch() operation. It is the
counterpart of a tensor for homogeneous graphs, or a dict of tensors for heterogeneous
graphs where the keys are node/edge types.
"""
def requires_ddp(self):
"""Whether the FeatureStorage requires the DataLoader to set use_ddp.
"""
return False
def fetch(self, indices, device, pin_memory=False):
"""Retrieve the features at the given indices.
If :attr:`indices` is a tensor, this is equivalent to
.. code::
storage[indices]
If :attr:`indices` is a dict of tensors, this is equivalent to
.. code::
{k: storage[k][indices[k]] for k in indices.keys()}
The subclasses can choose to utilize or ignore the flag :attr:`pin_memory`
depending on the underlying framework.
"""
raise NotImplementedError
"""Feature storage for ``numpy.memmap`` object."""
import numpy as np
from .base import FeatureStorage, ThreadedFuture, register_storage_wrapper
from .. import backend as F
@register_storage_wrapper(np.memmap)
class NumpyStorage(FeatureStorage):
"""FeatureStorage that asynchronously reads features from a ``numpy.memmap`` object."""
def __init__(self, arr):
self.arr = arr
def _fetch(self, indices, device, pin_memory=False): # pylint: disable=unused-argument
result = F.zerocopy_from_numpy(self.arr[indices])
result = F.copy_to(result, device)
return result
def fetch(self, indices, device, pin_memory=False):
return ThreadedFuture(target=self._fetch, args=(indices, device, pin_memory))
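
A sketch of the asynchronous read this enables, assuming the PyTorch backend; the file name, dtype, and shapes are illustrative.

    import numpy as np
    import torch

    arr = np.memmap('feat.dat', dtype='float32', mode='w+', shape=(100, 16))
    storage = NumpyStorage(arr)

    # fetch() returns a ThreadedFuture immediately; wait() yields the tensor.
    future = storage.fetch(np.array([0, 5, 7]), device=torch.device('cpu'))
    feats = future.wait()       # tensor of shape (3, 16)
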
"""Feature storages for PyTorch tensors."""
import torch
from .base import FeatureStorage, register_storage_wrapper
def _fetch_cpu(indices, tensor, feature_shape, device, pin_memory):
result = torch.empty(
indices.shape[0], *feature_shape, dtype=tensor.dtype,
pin_memory=pin_memory)
torch.index_select(tensor, 0, indices, out=result)
result = result.to(device, non_blocking=True)
return result
def _fetch_cuda(indices, tensor, device):
return torch.index_select(tensor, 0, indices).to(device)
@register_storage_wrapper(torch.Tensor)
class TensorStorage(FeatureStorage):
"""Feature storages for slicing a PyTorch tensor."""
def __init__(self, tensor):
self.storage = tensor
self.feature_shape = tensor.shape[1:]
self.is_cuda = (tensor.device.type == 'cuda')
def fetch(self, indices, device, pin_memory=False):
device = torch.device(device)
if not self.is_cuda:
# CPU to CPU or CUDA - use pin_memory and async transfer if possible
return _fetch_cpu(indices, self.storage, self.feature_shape, device, pin_memory)
else:
# CUDA to CUDA or CPU
return _fetch_cuda(indices, self.storage, device)
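
A sketch of the pinned-memory path above; it requires a CUDA device, and with a CPU target plus ``pin_memory=False`` it degenerates to a plain ``index_select``.

    import torch

    feat = torch.randn(1000, 32)          # CPU feature tensor
    storage = TensorStorage(feat)

    # Rows are gathered into a pinned staging buffer, then copied to the GPU
    # with non_blocking=True so the transfer can overlap computation.
    idx = torch.tensor([1, 4, 9])
    out = storage.fetch(idx, device='cuda:0', pin_memory=True)
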
"""Feature storages for tensors across different frameworks."""
from .base import FeatureStorage
from .. import backend as F
from ..utils import recursive_apply_pair
def _fetch(indices, tensor, device):
return F.copy_to(F.gather_row(tensor, indices), device)
class TensorStorage(FeatureStorage):
"""FeatureStorage that synchronously slices features from a tensor and transfers
it to the given device.
"""
def __init__(self, tensor):
self.storage = tensor
def fetch(self, indices, device, pin_memory=False): # pylint: disable=unused-argument
return recursive_apply_pair(indices, self.storage, _fetch, device)
@@ -13,11 +13,12 @@ from . import heterograph_index
from . import ndarray as nd
from .heterograph import DGLHeteroGraph
from . import utils
from .utils import recursive_apply
__all__ = ['node_subgraph', 'edge_subgraph', 'node_type_subgraph', 'edge_type_subgraph',
'in_subgraph', 'out_subgraph', 'khop_in_subgraph', 'khop_out_subgraph']
def node_subgraph(graph, nodes, *, relabel_nodes=True, store_ids=True):
def node_subgraph(graph, nodes, *, relabel_nodes=True, store_ids=True, output_device=None):
"""Return a subgraph induced on the given nodes.
A node-induced subgraph is a graph with edges whose endpoints are both in the
@@ -53,6 +54,8 @@ def node_subgraph(graph, nodes, *, relabel_nodes=True, store_ids=True):
resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will
also store the raw IDs of the specified nodes in the ``ndata`` of the resulting
graph under name ``dgl.NID``.
output_device : Framework-specific device context object, optional
The output device. Default is the same as the input graph.
Returns
-------
@@ -150,11 +153,13 @@ def node_subgraph(graph, nodes, *, relabel_nodes=True, store_ids=True):
# bug in #1453.
if not relabel_nodes:
induced_nodes = None
return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids=store_ids)
subg = _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids=store_ids)
return subg if output_device is None else subg.to(output_device)
DGLHeteroGraph.subgraph = utils.alias_func(node_subgraph)
def edge_subgraph(graph, edges, *, relabel_nodes=True, store_ids=True, **deprecated_kwargs):
def edge_subgraph(graph, edges, *, relabel_nodes=True, store_ids=True, output_device=None,
**deprecated_kwargs):
"""Return a subgraph induced on the given edges.
An edge-induced subgraph is equivalent to creating a new graph using the given
@@ -190,6 +195,8 @@ def edge_subgraph(graph, edges, *, relabel_nodes=True, store_ids=True, **depreca
resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will
also store the raw IDs of the incident nodes in the ``ndata`` of the resulting
graph under name ``dgl.NID``.
output_device : Framework-specific device context object, optional
The output device. Default is the same as the input graph.
Returns
-------
@@ -301,11 +308,12 @@ def edge_subgraph(graph, edges, *, relabel_nodes=True, store_ids=True, **depreca
induced_edges.append(_process_edges(cetype, eids))
sgi = graph._graph.edge_subgraph(induced_edges, not relabel_nodes)
induced_nodes = sgi.induced_nodes if relabel_nodes else None
return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids=store_ids)
subg = _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids=store_ids)
return subg if output_device is None else subg.to(output_device)
DGLHeteroGraph.edge_subgraph = utils.alias_func(edge_subgraph)
def in_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True):
def in_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True, output_device=None):
"""Return the subgraph induced on the inbound edges of all the edge types of the
given nodes.
@@ -340,6 +348,8 @@ def in_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True):
resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will
also store the raw IDs of the extracted nodes in the ``ndata`` of the resulting
graph under name ``dgl.NID``.
output_device : Framework-specific device context object, optional
The output device. Default is the same as the input graph.
Returns
-------
@@ -426,11 +436,12 @@ def in_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True):
sgi = _CAPI_DGLInSubgraph(graph._graph, nodes_all_types, relabel_nodes)
induced_nodes = sgi.induced_nodes if relabel_nodes else None
induced_edges = sgi.induced_edges
return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids=store_ids)
subg = _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids=store_ids)
return subg if output_device is None else subg.to(output_device)
DGLHeteroGraph.in_subgraph = utils.alias_func(in_subgraph)
def out_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True):
def out_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True, output_device=None):
"""Return the subgraph induced on the outbound edges of all the edge types of the
given nodes.
@@ -465,6 +476,8 @@ def out_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True):
resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will
also store the raw IDs of the extracted nodes in the ``ndata`` of the resulting
graph under name ``dgl.NID``.
output_device : Framework-specific device context object, optional
The output device. Default is the same as the input graph.
Returns
-------
@@ -551,11 +564,12 @@ def out_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True):
sgi = _CAPI_DGLOutSubgraph(graph._graph, nodes_all_types, relabel_nodes)
induced_nodes = sgi.induced_nodes if relabel_nodes else None
induced_edges = sgi.induced_edges
return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids=store_ids)
subg = _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids=store_ids)
return subg if output_device is None else subg.to(output_device)
DGLHeteroGraph.out_subgraph = utils.alias_func(out_subgraph)
def khop_in_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True):
def khop_in_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True, output_device=None):
"""Return the subgraph induced by k-hop in-neighborhood of the specified node(s).
We can expand a set of nodes by including the predecessors of them. From a
@@ -594,6 +608,8 @@ def khop_in_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True):
resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will
also store the raw IDs of the extracted nodes in the ``ndata`` of the resulting
graph under name ``dgl.NID``.
output_device : Framework-specific device context object, optional
The output device. Default is the same as the input graph.
Returns
-------
@@ -693,6 +709,8 @@ def khop_in_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True):
for hop_nodes in k_hop_nodes_], dim=0), return_inverse=True)
sub_g = node_subgraph(graph, k_hop_nodes, relabel_nodes=relabel_nodes, store_ids=store_ids)
if output_device is not None:
sub_g = sub_g.to(output_device)
if relabel_nodes:
if is_mapping:
seed_inverse_indices = dict()
@@ -702,13 +720,16 @@ def khop_in_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True):
else:
seed_inverse_indices = F.slice_axis(
inverse_indices[nty], axis=0, begin=0, end=len(nodes[nty]))
if output_device is not None:
seed_inverse_indices = recursive_apply(
seed_inverse_indices, lambda x: F.copy_to(x, output_device))
return sub_g, seed_inverse_indices
else:
return sub_g
DGLHeteroGraph.khop_in_subgraph = utils.alias_func(khop_in_subgraph)
def khop_out_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True):
def khop_out_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True, output_device=None):
"""Return the subgraph induced by k-hop out-neighborhood of the specified node(s).
We can expand a set of nodes by including the successors of them. From a
@@ -747,6 +768,8 @@ def khop_out_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True):
resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will
also store the raw IDs of the extracted nodes in the ``ndata`` of the resulting
graph under name ``dgl.NID``.
output_device : Framework-specific device context object, optional
The output device. Default is the same as the input graph.
Returns
-------
@@ -847,6 +870,8 @@ def khop_out_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True):
for hop_nodes in k_hop_nodes_], dim=0), return_inverse=True)
sub_g = node_subgraph(graph, k_hop_nodes, relabel_nodes=relabel_nodes, store_ids=store_ids)
if output_device is not None:
sub_g = sub_g.to(output_device)
if relabel_nodes:
if is_mapping:
seed_inverse_indices = dict()
@@ -856,13 +881,16 @@ def khop_out_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True):
else:
seed_inverse_indices = F.slice_axis(
inverse_indices[nty], axis=0, begin=0, end=len(nodes[nty]))
if output_device is not None:
seed_inverse_indices = recursive_apply(
seed_inverse_indices, lambda x: F.copy_to(x, output_device))
return sub_g, seed_inverse_indices
else:
return sub_g
DGLHeteroGraph.khop_out_subgraph = utils.alias_func(khop_out_subgraph)
def node_type_subgraph(graph, ntypes):
def node_type_subgraph(graph, ntypes, output_device=None):
"""Return the subgraph induced on given node types.
A node-type-induced subgraph contains all the nodes of the given subset of
......@@ -877,6 +905,8 @@ def node_type_subgraph(graph, ntypes):
The graph to extract subgraphs from.
ntypes : list[str]
The type names of the nodes in the subgraph.
output_device : Framework-specific device context object, optional
The output device. Default is the same as the input graph.
Returns
-------
@@ -935,11 +965,11 @@ def node_type_subgraph(graph, ntypes):
etypes.append(graph.canonical_etypes[etid])
if len(etypes) == 0:
raise DGLError('There are no edges among nodes of the specified types.')
return edge_type_subgraph(graph, etypes)
return edge_type_subgraph(graph, etypes, output_device=output_device)
DGLHeteroGraph.node_type_subgraph = utils.alias_func(node_type_subgraph)
def edge_type_subgraph(graph, etypes):
def edge_type_subgraph(graph, etypes, output_device=None):
"""Return the subgraph induced on given edge types.
An edge-type-induced subgraph contains all the edges of the given subset of
@@ -960,6 +990,8 @@ def edge_type_subgraph(graph, etypes):
* ``(str, str, str)`` for source node type, edge type and destination node type.
* or one ``str`` for the edge type name if the name can uniquely identify a
triplet format in the graph.
output_device : Framework-specific device context object, optional
The output device. Default is the same as the input graph.
Returns
-------
@@ -1029,7 +1061,7 @@ def edge_type_subgraph(graph, etypes):
hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, rel_graphs, utils.toindex(num_nodes_per_induced_type, "int64"))
hg = DGLHeteroGraph(hgidx, induced_ntypes, induced_etypes, node_frames, edge_frames)
return hg
return hg if output_device is None else hg.to(output_device)
DGLHeteroGraph.edge_type_subgraph = utils.alias_func(edge_type_subgraph)
@@ -4,3 +4,4 @@ from .data import *
from .checks import *
from .shared_mem import *
from .filter import *
from .exception import *
"""Exception wrapper classes to properly display exceptions under multithreading or
multiprocessing.
"""
import sys
import traceback
# The following code is borrowed from PyTorch. When a subprocess or thread
# throws an exception, wrap it with the ExceptionWrapper class
# and put it in the queue you normally retrieve from.
# NOTE [ Python Traceback Reference Cycle Problem ]
#
# When using sys.exc_info(), it is important to **not** store the exc_info[2],
# which is the traceback, because otherwise you will run into the traceback
# reference cycle problem, i.e., the traceback holding reference to the frame,
# and the frame (which holds references to all the objects in its temporary scope)
# holding a reference to the traceback.
class KeyErrorMessage(str):
r"""str subclass that returns itself in repr"""
def __repr__(self): # pylint: disable=invalid-repr-returned
return self
class ExceptionWrapper(object):
r"""Wraps an exception plus traceback to communicate across threads"""
def __init__(self, exc_info=None, where="in background"):
# It is important that we don't store exc_info, see
# NOTE [ Python Traceback Reference Cycle Problem ]
if exc_info is None:
exc_info = sys.exc_info()
self.exc_type = exc_info[0]
self.exc_msg = "".join(traceback.format_exception(*exc_info))
self.where = where
def reraise(self):
r"""Reraises the wrapped exception in the current thread"""
# Format a message such as: "Caught ValueError in DataLoader worker
# process 2. Original Traceback:", followed by the traceback.
msg = "Caught {} {}.\nOriginal {}".format(
self.exc_type.__name__, self.where, self.exc_msg)
if self.exc_type == KeyError:
# KeyError calls repr() on its argument (usually a dict key). This
# makes stack traces unreadable. It will not be changed in Python
# (https://bugs.python.org/issue2651), so we work around it.
msg = KeyErrorMessage(msg)
elif getattr(self.exc_type, "message", None):
# Some exceptions take a non-str first argument but explicitly
# have a `message` field.
raise self.exc_type(message=msg)
try:
exception = self.exc_type(msg)
except TypeError:
# If the exception takes multiple arguments, don't try to
# instantiate since we don't know how to
raise RuntimeError(msg) from None
raise exception
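
The intended pattern, per the note above: construct the wrapper inside the worker's ``except`` block (so ``sys.exc_info()`` is populated), ship it through the queue, and reraise in the consumer. A minimal sketch; the module path is assumed from this diff.

    import queue
    import threading
    from dgl.utils.exception import ExceptionWrapper

    q = queue.Queue()

    def worker():
        try:
            raise ValueError('bad batch')
        except Exception:
            q.put(ExceptionWrapper(where='in worker thread'))

    threading.Thread(target=worker).start()
    item = q.get()
    if isinstance(item, ExceptionWrapper):
        item.reraise()   # raises ValueError with the worker's traceback attached
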
"""Internal utilities."""
from __future__ import absolute_import, division
from collections.abc import Mapping, Iterable
from collections.abc import Mapping, Iterable, Sequence
from collections import defaultdict
from functools import wraps
import numpy as np
@@ -910,4 +910,31 @@ def alias_func(func):
_fn.__doc__ = """Alias of :func:`dgl.{}`.""".format(func.__name__)
return _fn
def recursive_apply(data, fn, *args, **kwargs):
"""Recursively apply a function to every element in a container.
"""
if isinstance(data, str): # str is a Sequence
return fn(data, *args, **kwargs)
elif isinstance(data, Mapping):
return {k: recursive_apply(v, fn, *args, **kwargs) for k, v in data.items()}
elif isinstance(data, Sequence):
return [recursive_apply(v, fn, *args, **kwargs) for v in data]
else:
return fn(data, *args, **kwargs)
def recursive_apply_pair(data1, data2, fn, *args, **kwargs):
"""Recursively apply a function to every pair of elements in two containers with the
same nested structure.
"""
if isinstance(data1, str) or isinstance(data2, str):
return fn(data1, data2, *args, **kwargs)
elif isinstance(data1, Mapping) and isinstance(data2, Mapping):
return {
k: recursive_apply_pair(data1[k], data2[k], fn, *args, **kwargs)
for k in data1.keys()}
elif isinstance(data1, Sequence) and isinstance(data2, Sequence):
return [recursive_apply_pair(x, y, fn, *args, **kwargs) for x, y in zip(data1, data2)]
else:
return fn(data1, data2, *args, **kwargs)
_init_api("dgl.utils.internal")
@@ -6,6 +6,7 @@ from collections.abc import MutableMapping
from .base import ALL, DGLError
from . import backend as F
from .frame import LazyFeature
NodeSpace = namedtuple('NodeSpace', ['data'])
EdgeSpace = namedtuple('EdgeSpace', ['data'])
@@ -66,7 +67,9 @@ class HeteroNodeDataView(MutableMapping):
return self._graph._get_n_repr(self._ntid, self._nodes)[key]
def __setitem__(self, key, val):
if isinstance(self._ntype, list):
if isinstance(val, LazyFeature):
self._graph._node_frames[self._ntid][key] = val
elif isinstance(self._ntype, list):
assert isinstance(val, dict), \
'Current HeteroNodeDataView has multiple node types, ' \
'please pass the node type and the corresponding data through a dict.'
@@ -89,36 +92,33 @@ class HeteroNodeDataView(MutableMapping):
else:
self._graph._pop_n_repr(self._ntid, key)
def _transpose(self, as_dict=False):
if isinstance(self._ntype, list):
ret = defaultdict(dict)
for (i, ntype) in enumerate(self._ntype):
data = self._graph._get_n_repr(self._ntid[i], self._nodes)
for key in self._graph._node_frames[self._ntid[i]]:
ret[key][ntype] = data[key]
else:
ret = self._graph._get_n_repr(self._ntid, self._nodes)
if as_dict:
ret = {key: ret[key] for key in self._graph._node_frames[self._ntid]}
return ret
def __len__(self):
assert isinstance(self._ntype, list) is False, \
'Current HeteroNodeDataView has multiple node types, ' \
'can not support len().'
return len(self._graph._node_frames[self._ntid])
return len(self._transpose())
def __iter__(self):
assert isinstance(self._ntype, list) is False, \
'Current HeteroNodeDataView has multiple node types, ' \
'can not be iterated.'
return iter(self._graph._node_frames[self._ntid])
return iter(self._transpose())
def keys(self):
return self._graph._node_frames[self._ntid].keys()
return self._transpose().keys()
def values(self):
return self._graph._node_frames[self._ntid].values()
return self._transpose().values()
def __repr__(self):
if isinstance(self._ntype, list):
ret = defaultdict(dict)
for (i, ntype) in enumerate(self._ntype):
data = self._graph._get_n_repr(self._ntid[i], self._nodes)
for key in self._graph._node_frames[self._ntid[i]]:
ret[key][ntype] = data[key]
return repr(ret)
else:
data = self._graph._get_n_repr(self._ntid, self._nodes)
return repr({key : data[key]
for key in self._graph._node_frames[self._ntid]})
return repr(self._transpose(as_dict=True))
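
A hedged sketch of what the new ``LazyFeature`` branch in ``__setitem__`` enables: a sampler can assign a named placeholder instead of a real tensor, and the dataloader materializes it from the parent graph during asynchronous prefetching. The import path follows this diff; the constructor signature is assumed.

    import dgl
    import torch
    from dgl.frame import LazyFeature

    g = dgl.rand_graph(10, 30)
    sg = dgl.sampling.sample_neighbors(g, torch.arange(3), 2)

    # Placeholder instead of an eager feature slice; the dataloader is expected
    # to fill it in later (signature assumed).
    sg.ndata['feat'] = LazyFeature('feat')
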
class HeteroEdgeView(object):
"""A EdgeView class to act as G.edges for a DGLHeteroGraph."""
@@ -181,7 +181,9 @@ class HeteroEdgeDataView(MutableMapping):
return self._graph._get_e_repr(self._etid, self._edges)[key]
def __setitem__(self, key, val):
if isinstance(self._etype, list):
if isinstance(val, LazyFeature):
self._graph._edge_frames[self._etid][key] = val
elif isinstance(self._etype, list):
assert isinstance(val, dict), \
'Current HeteroEdgeDataView has multiple edge types, ' \
'please pass the edge type and the corresponding data through a dict.'
@@ -204,33 +206,30 @@ class HeteroEdgeDataView(MutableMapping):
else:
self._graph._pop_e_repr(self._etid, key)
def _transpose(self, as_dict=False):
if isinstance(self._etype, list):
ret = defaultdict(dict)
for (i, etype) in enumerate(self._etype):
data = self._graph._get_e_repr(self._etid[i], self._edges)
for key in self._graph._edge_frames[self._etid[i]]:
ret[key][etype] = data[key]
else:
ret = self._graph._get_e_repr(self._etid, self._edges)
if as_dict:
ret = {key: ret[key] for key in self._graph._edge_frames[self._etid]}
return ret
def __len__(self):
assert isinstance(self._etype, list) is False, \
'Current HeteroEdgeDataView has multiple edge types, ' \
'can not support len().'
return len(self._graph._edge_frames[self._etid])
return len(self._transpose())
def __iter__(self):
assert isinstance(self._etype, list) is False, \
'Current HeteroEdgeDataView has multiple edge types, ' \
'can not be iterated.'
return iter(self._graph._edge_frames[self._etid])
return iter(self._transpose())
def keys(self):
return self._graph._edge_frames[self._etid].keys()
return self._transpose().keys()
def values(self):
return self._graph._edge_frames[self._etid].values()
return self._transpose().values()
def __repr__(self):
if isinstance(self._etype, list):
ret = defaultdict(dict)
for (i, etype) in enumerate(self._etype):
data = self._graph._get_e_repr(self._etid[i], self._edges)
for key in self._graph._edge_frames[self._etid[i]]:
ret[key][etype] = data[key]
return repr(ret)
else:
data = self._graph._get_e_repr(self._etid, self._edges)
return repr({key : data[key]
for key in self._graph._edge_frames[self._etid]})
return repr(self._transpose(as_dict=True))
@@ -594,3 +594,6 @@ def test_khop_out_subgraph(idtype):
assert edge_set == {(0, 1)}
assert F.array_equal(F.astype(inv['user'], idtype), F.tensor([0], idtype))
assert F.array_equal(F.astype(inv['game'], idtype), F.tensor([0], idtype))
if __name__ == '__main__':
test_khop_out_subgraph(F.int64)
@@ -2308,3 +2308,4 @@ def test_module_add_edge(idtype):
if __name__ == '__main__':
test_partition_with_halo()
test_module_heat_kernel(F.int32)
@@ -146,14 +146,14 @@ def start_dist_neg_dataloader(rank, tmpdir, num_server, num_workers, orig_nid, g
num_negs = 5
sampler = dgl.dataloading.MultiLayerNeighborSampler([5,10])
negative_sampler=dgl.dataloading.negative_sampler.Uniform(num_negs)
dataloader = dgl.dataloading.EdgeDataLoader(dist_graph,
train_eid,
sampler,
batch_size=batch_size,
negative_sampler=negative_sampler,
shuffle=True,
drop_last=False,
num_workers=num_workers)
dataloader = dgl.dataloading.DistEdgeDataLoader(dist_graph,
train_eid,
sampler,
batch_size=batch_size,
negative_sampler=negative_sampler,
shuffle=True,
drop_last=False,
num_workers=num_workers)
for _ in range(2):
for _, (_, pos_graph, neg_graph, blocks) in zip(range(0, num_edges_to_sample, batch_size), dataloader):
block = blocks[-1]
@@ -288,7 +288,7 @@ def start_node_dataloader(rank, tmpdir, num_server, num_workers, orig_nid, orig_
# We need to test creating DistDataLoader multiple times.
for i in range(2):
# Create DataLoader for constructing blocks
dataloader = dgl.dataloading.NodeDataLoader(
dataloader = dgl.dataloading.DistNodeDataLoader(
dist_graph,
train_nid,
sampler,
@@ -339,7 +339,7 @@ def start_edge_dataloader(rank, tmpdir, num_server, num_workers, orig_nid, orig_
# We need to test creating DistDataLoader multiple times.
for i in range(2):
# Create DataLoader for constructing blocks
dataloader = dgl.dataloading.EdgeDataLoader(
dataloader = dgl.dataloading.DistEdgeDataLoader(
dist_graph,
train_eid,
sampler,
@@ -10,193 +10,6 @@ from collections.abc import Iterator
from itertools import product
import pytest
def _check_neighbor_sampling_dataloader(g, nids, dl, mode, collator):
seeds = defaultdict(list)
for item in dl:
if mode == 'node':
input_nodes, output_nodes, blocks = item
elif mode == 'edge':
input_nodes, pair_graph, blocks = item
output_nodes = pair_graph.ndata[dgl.NID]
elif mode == 'link':
input_nodes, pair_graph, neg_graph, blocks = item
output_nodes = pair_graph.ndata[dgl.NID]
for ntype in pair_graph.ntypes:
assert F.array_equal(pair_graph.nodes[ntype].data[dgl.NID], neg_graph.nodes[ntype].data[dgl.NID])
if len(g.ntypes) > 1:
for ntype in g.ntypes:
assert F.array_equal(input_nodes[ntype], blocks[0].srcnodes[ntype].data[dgl.NID])
assert F.array_equal(output_nodes[ntype], blocks[-1].dstnodes[ntype].data[dgl.NID])
else:
assert F.array_equal(input_nodes, blocks[0].srcdata[dgl.NID])
assert F.array_equal(output_nodes, blocks[-1].dstdata[dgl.NID])
prev_dst = {ntype: None for ntype in g.ntypes}
for block in blocks:
for canonical_etype in block.canonical_etypes:
utype, etype, vtype = canonical_etype
uu, vv = block.all_edges(order='eid', etype=canonical_etype)
src = block.srcnodes[utype].data[dgl.NID]
dst = block.dstnodes[vtype].data[dgl.NID]
assert F.array_equal(
block.srcnodes[utype].data['feat'], g.nodes[utype].data['feat'][src])
assert F.array_equal(
block.dstnodes[vtype].data['feat'], g.nodes[vtype].data['feat'][dst])
if prev_dst[utype] is not None:
assert F.array_equal(src, prev_dst[utype])
u = src[uu]
v = dst[vv]
assert F.asnumpy(g.has_edges_between(u, v, etype=canonical_etype)).all()
eid = block.edges[canonical_etype].data[dgl.EID]
assert F.array_equal(
block.edges[canonical_etype].data['feat'],
g.edges[canonical_etype].data['feat'][eid])
ufound, vfound = g.find_edges(eid, etype=canonical_etype)
assert F.array_equal(ufound, u)
assert F.array_equal(vfound, v)
for ntype in block.dsttypes:
src = block.srcnodes[ntype].data[dgl.NID]
dst = block.dstnodes[ntype].data[dgl.NID]
assert F.array_equal(src[:block.number_of_dst_nodes(ntype)], dst)
prev_dst[ntype] = dst
if mode == 'node':
for ntype in blocks[-1].dsttypes:
seeds[ntype].append(blocks[-1].dstnodes[ntype].data[dgl.NID])
elif mode == 'edge' or mode == 'link':
for etype in pair_graph.canonical_etypes:
seeds[etype].append(pair_graph.edges[etype].data[dgl.EID])
# Check if all nodes/edges are iterated
seeds = {k: F.cat(v, 0) for k, v in seeds.items()}
for k, v in seeds.items():
if k in nids:
seed_set = set(F.asnumpy(nids[k]))
elif isinstance(k, tuple) and k[1] in nids:
seed_set = set(F.asnumpy(nids[k[1]]))
else:
continue
v_set = set(F.asnumpy(v))
assert v_set == seed_set
def test_neighbor_sampler_dataloader():
g = dgl.heterograph({('user', 'follow', 'user'): ([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])},
{'user': 6}).long()
g = dgl.to_bidirected(g).to(F.ctx())
g.ndata['feat'] = F.randn((6, 8))
g.edata['feat'] = F.randn((10, 4))
reverse_eids = F.tensor([5, 6, 7, 8, 9, 0, 1, 2, 3, 4], dtype=F.int64)
g_sampler1 = dgl.dataloading.MultiLayerNeighborSampler([2, 2], return_eids=True)
g_sampler2 = dgl.dataloading.MultiLayerFullNeighborSampler(2, return_eids=True)
hg = dgl.heterograph({
('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]),
('user', 'followed-by', 'user'): ([1, 2, 3, 0, 2, 3, 0], [0, 0, 0, 1, 1, 1, 2]),
('user', 'play', 'game'): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
('game', 'played-by', 'user'): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5])
}).long().to(F.ctx())
for ntype in hg.ntypes:
hg.nodes[ntype].data['feat'] = F.randn((hg.number_of_nodes(ntype), 8))
for etype in hg.canonical_etypes:
hg.edges[etype].data['feat'] = F.randn((hg.number_of_edges(etype), 4))
hg_sampler1 = dgl.dataloading.MultiLayerNeighborSampler(
[{'play': 1, 'played-by': 1, 'follow': 2, 'followed-by': 1}] * 2, return_eids=True)
hg_sampler2 = dgl.dataloading.MultiLayerFullNeighborSampler(2, return_eids=True)
reverse_etypes = {'follow': 'followed-by', 'followed-by': 'follow', 'play': 'played-by', 'played-by': 'play'}
collators = []
graphs = []
nids = []
modes = []
for seeds, sampler in product(
[F.tensor([0, 1, 2, 3, 5], dtype=F.int64), F.tensor([4, 5], dtype=F.int64)],
[g_sampler1, g_sampler2]):
collators.append(dgl.dataloading.NodeCollator(g, seeds, sampler))
graphs.append(g)
nids.append({'user': seeds})
modes.append('node')
collators.append(dgl.dataloading.EdgeCollator(g, seeds, sampler))
graphs.append(g)
nids.append({'follow': seeds})
modes.append('edge')
collators.append(dgl.dataloading.EdgeCollator(
g, seeds, sampler, exclude='self'))
graphs.append(g)
nids.append({'follow': seeds})
modes.append('edge')
collators.append(dgl.dataloading.EdgeCollator(
g, seeds, sampler, exclude='reverse_id', reverse_eids=reverse_eids))
graphs.append(g)
nids.append({'follow': seeds})
modes.append('edge')
collators.append(dgl.dataloading.EdgeCollator(
g, seeds, sampler, negative_sampler=dgl.dataloading.negative_sampler.Uniform(2)))
graphs.append(g)
nids.append({'follow': seeds})
modes.append('link')
collators.append(dgl.dataloading.EdgeCollator(
g, seeds, sampler, exclude='self', negative_sampler=dgl.dataloading.negative_sampler.Uniform(2)))
graphs.append(g)
nids.append({'follow': seeds})
modes.append('link')
collators.append(dgl.dataloading.EdgeCollator(
g, seeds, sampler, exclude='reverse_id', reverse_eids=reverse_eids,
negative_sampler=dgl.dataloading.negative_sampler.Uniform(2)))
graphs.append(g)
nids.append({'follow': seeds})
modes.append('link')
for seeds, sampler in product(
[{'user': F.tensor([0, 1, 3, 5], dtype=F.int64), 'game': F.tensor([0, 1, 2], dtype=F.int64)},
{'user': F.tensor([4, 5], dtype=F.int64), 'game': F.tensor([0, 1, 2], dtype=F.int64)}],
[hg_sampler1, hg_sampler2]):
collators.append(dgl.dataloading.NodeCollator(hg, seeds, sampler))
graphs.append(hg)
nids.append(seeds)
modes.append('node')
for seeds, sampler in product(
[{'follow': F.tensor([0, 1, 3, 5], dtype=F.int64), 'play': F.tensor([1, 3], dtype=F.int64)},
{'follow': F.tensor([4, 5], dtype=F.int64), 'play': F.tensor([1, 3], dtype=F.int64)}],
[hg_sampler1, hg_sampler2]):
collators.append(dgl.dataloading.EdgeCollator(hg, seeds, sampler))
graphs.append(hg)
nids.append(seeds)
modes.append('edge')
collators.append(dgl.dataloading.EdgeCollator(
hg, seeds, sampler, exclude='reverse_types', reverse_etypes=reverse_etypes))
graphs.append(hg)
nids.append(seeds)
modes.append('edge')
collators.append(dgl.dataloading.EdgeCollator(
hg, seeds, sampler, negative_sampler=dgl.dataloading.negative_sampler.Uniform(2)))
graphs.append(hg)
nids.append(seeds)
modes.append('link')
collators.append(dgl.dataloading.EdgeCollator(
hg, seeds, sampler, exclude='reverse_types', reverse_etypes=reverse_etypes,
negative_sampler=dgl.dataloading.negative_sampler.Uniform(2)))
graphs.append(hg)
nids.append(seeds)
modes.append('link')
for _g, nid, collator, mode in zip(graphs, nids, collators, modes):
dl = DataLoader(
collator.dataset, collate_fn=collator.collate, batch_size=2, shuffle=True, drop_last=False)
assert isinstance(iter(dl), Iterator)
_check_neighbor_sampling_dataloader(_g, nid, dl, mode, collator)
def test_graph_dataloader():
batch_size = 16
@@ -213,15 +26,12 @@ def test_graph_dataloader():
def test_cluster_gcn(num_workers):
dataset = dgl.data.CoraFullDataset()
g = dataset[0]
sgiter = dgl.dataloading.ClusterGCNSubgraphIterator(g, 100, '.', refresh=True)
dataloader = dgl.dataloading.GraphDataLoader(sgiter, batch_size=4, num_workers=num_workers)
for sg in dataloader:
assert sg.batch_size == 4
sgiter = dgl.dataloading.ClusterGCNSubgraphIterator(g, 100, '.', refresh=False) # use cache
dataloader = dgl.dataloading.GraphDataLoader(sgiter, batch_size=4, num_workers=num_workers)
for sg in dataloader:
assert sg.batch_size == 4
sampler = dgl.dataloading.ClusterGCNSampler(g, 100)
dataloader = dgl.dataloading.DataLoader(
g, torch.arange(100), sampler, batch_size=4, num_workers=num_workers)
assert len(dataloader) == 25
for i, sg in enumerate(dataloader):
pass
@pytest.mark.parametrize('num_workers', [0, 4])
def test_shadow(num_workers):
@@ -230,7 +40,7 @@ def test_shadow(num_workers):
dataloader = dgl.dataloading.NodeDataLoader(
g, torch.arange(g.num_nodes()), sampler,
batch_size=5, shuffle=True, drop_last=False, num_workers=num_workers)
for i, (input_nodes, output_nodes, (subgraph,)) in enumerate(dataloader):
for i, (input_nodes, output_nodes, subgraph) in enumerate(dataloader):
assert torch.equal(input_nodes, subgraph.ndata[dgl.NID])
assert torch.equal(input_nodes[:output_nodes.shape[0]], output_nodes)
assert torch.equal(subgraph.ndata['label'], g.ndata['label'][input_nodes])
@@ -288,37 +98,25 @@ def _check_device(data):
else:
assert data.device == F.ctx()
@pytest.mark.parametrize('sampler_name', ['full', 'neighbor', 'neighbor2', 'shadow'])
@pytest.mark.parametrize('sampler_name', ['full', 'neighbor', 'neighbor2'])
def test_node_dataloader(sampler_name):
g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4]))
g1.ndata['feat'] = F.copy_to(F.randn((5, 8)), F.cpu())
g1.ndata['label'] = F.copy_to(F.randn((g1.num_nodes(),)), F.cpu())
for load_input, load_output in [(None, None), ({'feat': g1.ndata['feat']}, {'label': g1.ndata['label']})]:
for async_load in [False, True]:
for num_workers in [0, 1, 2]:
sampler = {
'full': dgl.dataloading.MultiLayerFullNeighborSampler(2),
'neighbor': dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
'neighbor2': dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
'shadow': dgl.dataloading.ShaDowKHopSampler([3, 3])}[sampler_name]
dataloader = dgl.dataloading.NodeDataLoader(
g1, g1.nodes(), sampler, device=F.ctx(),
load_input=load_input,
load_output=load_output,
async_load=async_load,
batch_size=g1.num_nodes(),
num_workers=num_workers)
for input_nodes, output_nodes, blocks in dataloader:
_check_device(input_nodes)
_check_device(output_nodes)
_check_device(blocks)
if load_input:
_check_device(blocks[0].srcdata['feat'])
OPS.copy_u_sum(blocks[0], blocks[0].srcdata['feat'])
if load_output:
_check_device(blocks[-1].dstdata['label'])
OPS.copy_u_sum(blocks[-1], blocks[-1].dstdata['label'])
for num_workers in [0, 1, 2]:
sampler = {
'full': dgl.dataloading.MultiLayerFullNeighborSampler(2),
'neighbor': dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
'neighbor2': dgl.dataloading.MultiLayerNeighborSampler([3, 3])}[sampler_name]
dataloader = dgl.dataloading.NodeDataLoader(
g1, g1.nodes(), sampler, device=F.ctx(),
batch_size=g1.num_nodes(),
num_workers=num_workers)
for input_nodes, output_nodes, blocks in dataloader:
_check_device(input_nodes)
_check_device(output_nodes)
_check_device(blocks)
g2 = dgl.heterograph({
('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]),
@@ -332,30 +130,19 @@ def test_node_dataloader(sampler_name):
sampler = {
'full': dgl.dataloading.MultiLayerFullNeighborSampler(2),
'neighbor': dgl.dataloading.MultiLayerNeighborSampler([{etype: 3 for etype in g2.etypes}] * 2),
'neighbor2': dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
'shadow': dgl.dataloading.ShaDowKHopSampler([{etype: 3 for etype in g2.etypes}] * 2)}[sampler_name]
'neighbor2': dgl.dataloading.MultiLayerNeighborSampler([3, 3])}[sampler_name]
for async_load in [False, True]:
dataloader = dgl.dataloading.NodeDataLoader(
g2, {nty: g2.nodes(nty) for nty in g2.ntypes},
sampler, device=F.ctx(), async_load=async_load, batch_size=batch_size)
assert isinstance(iter(dataloader), Iterator)
for input_nodes, output_nodes, blocks in dataloader:
_check_device(input_nodes)
_check_device(output_nodes)
_check_device(blocks)
status = False
try:
dgl.dataloading.NodeDataLoader(
g2, {nty: g2.nodes(nty) for nty in g2.ntypes},
sampler, device=F.ctx(), load_input={'feat': g1.ndata['feat']}, batch_size=batch_size)
except dgl.DGLError:
status = True
assert status
dataloader = dgl.dataloading.NodeDataLoader(
g2, {nty: g2.nodes(nty) for nty in g2.ntypes},
sampler, device=F.ctx(), batch_size=batch_size)
assert isinstance(iter(dataloader), Iterator)
for input_nodes, output_nodes, blocks in dataloader:
_check_device(input_nodes)
_check_device(output_nodes)
_check_device(blocks)
@pytest.mark.parametrize('sampler_name', ['full', 'neighbor', 'shadow'])
@pytest.mark.parametrize('sampler_name', ['full', 'neighbor'])
@pytest.mark.parametrize('neg_sampler', [
dgl.dataloading.negative_sampler.Uniform(2),
dgl.dataloading.negative_sampler.GlobalUniform(15, False, 3),
@@ -366,8 +153,7 @@ def test_edge_dataloader(sampler_name, neg_sampler):
sampler = {
'full': dgl.dataloading.MultiLayerFullNeighborSampler(2),
'neighbor': dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
'shadow': dgl.dataloading.ShaDowKHopSampler([3, 3])}[sampler_name]
'neighbor': dgl.dataloading.MultiLayerNeighborSampler([3, 3])}[sampler_name]
# no negative sampler
dataloader = dgl.dataloading.EdgeDataLoader(
@@ -399,7 +185,7 @@ def test_edge_dataloader(sampler_name, neg_sampler):
sampler = {
'full': dgl.dataloading.MultiLayerFullNeighborSampler(2),
'neighbor': dgl.dataloading.MultiLayerNeighborSampler([{etype: 3 for etype in g2.etypes}] * 2),
'shadow': dgl.dataloading.ShaDowKHopSampler([{etype: 3 for etype in g2.etypes}] * 2)}[sampler_name]
}[sampler_name]
# no negative sampler
dataloader = dgl.dataloading.EdgeDataLoader(
@@ -424,11 +210,10 @@ def test_edge_dataloader(sampler_name, neg_sampler):
_check_device(blocks)
if __name__ == '__main__':
test_neighbor_sampler_dataloader()
test_graph_dataloader()
test_cluster_gcn(0)
test_neighbor_nonuniform(0)
for sampler in ['full', 'neighbor', 'shadow']:
for sampler in ['full', 'neighbor']:
test_node_dataloader(sampler)
for neg_sampler in [
dgl.dataloading.negative_sampler.Uniform(2),
@@ -265,7 +265,8 @@ Pytorch's `DistributedDataParallel`.
Distributed mini-batch sampler
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
We can use the same `NodeDataLoader` to create a distributed mini-batch sampler for
We can use the same :class:`~dgl.dataloading.pytorch.DistNodeDataLoader`, the distributed counterpart
of :class:`~dgl.dataloading.pytorch.NodeDataLoader`, to create a distributed mini-batch sampler for
node classification.
@@ -274,10 +275,10 @@ node classification.
.. code-block:: python
sampler = dgl.dataloading.MultiLayerNeighborSampler([25,10])
train_dataloader = dgl.dataloading.NodeDataLoader(
train_dataloader = dgl.dataloading.DistNodeDataLoader(
g, train_nid, sampler, batch_size=1024,
shuffle=True, drop_last=False)
valid_dataloader = dgl.dataloading.NodeDataLoader(
valid_dataloader = dgl.dataloading.DistNodeDataLoader(
g, valid_nid, sampler, batch_size=1024,
shuffle=False, drop_last=False)
@@ -432,4 +433,4 @@ If we split the graph into four partitions as demonstrated at the beginning of t
ip_addr3
ip_addr4
'''
\ No newline at end of file
'''