Unverified Commit e06e63d5 authored by Quan (Andy) Gan, committed by GitHub

[Feature] Enable UVA sampling with CPU indices (#3892)

* enable UVA sampling with CPU indices

* add docs

* add more docs

* lint

* fix

* fix

* better error message

* use mp.Barrier instead of queues

* revert

* revert

* oops

* revert dgl.multiprocessing.spawn

* Update pytorch.py
parent 0d878ff8
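
To illustrate what this change enables, here is a minimal sketch that is not taken from the commit; the graph, ID tensor, and sizes are made up. With use_uva=True the graph stays in pinned host memory, sampling runs on the GPU, and after this PR the seed indices may be a plain CPU tensor; the collate wrapper copies them to the target device for each batch.

import torch
import dgl

# Hypothetical graph and training IDs; both stay on the CPU.
g = dgl.rand_graph(10000, 200000)
train_nids = torch.arange(g.num_nodes())      # CPU indices now work with UVA sampling

g.create_formats_()      # materialize sampling formats before pinning
g.pin_memory_()          # keep the structure in pinned host memory for UVA access

sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 10])
dataloader = dgl.dataloading.NodeDataLoader(
    g, train_nids, sampler,
    device='cuda',        # sampled blocks are produced on the GPU
    use_uva=True,         # graph structure is accessed from the GPU via UVA
    batch_size=1024, shuffle=True,
    num_workers=0)        # UVA sampling requires num_workers == 0

for input_nodes, output_nodes, blocks in dataloader:
    pass                  # indices are moved to `device` per batch by the collate wrapper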
...@@ -12,3 +12,11 @@ additional documentation.
In addition, if your backend is PyTorch, this module will also be compatible with
:mod:`torch.multiprocessing` module.

+.. currentmodule:: dgl.multiprocessing.pytorch
+
+.. autosummary::
+    :toctree: ../../generated/
+
+    spawn
+    call_once_and_share
+    shared_tensor
...@@ -6,45 +6,14 @@ import torch.distributed.optim
import torchmetrics.functional as MF
import dgl
import dgl.nn as dglnn
-from dgl.utils import pin_memory_inplace, unpin_memory_inplace, \
-    gather_pinned_tensor_rows, create_shared_mem_array, get_shared_mem_array
+from dgl.utils import pin_memory_inplace, unpin_memory_inplace
+from dgl.multiprocessing import shared_tensor
import time
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset
import tqdm

-def shared_tensor(*shape, device, name, dtype=torch.float32):
-    """ Create a tensor in shared memroy, pinned in each process's CUDA
-    context.
-
-    Parameters
-    ----------
-    shape : int...
-        A sequence of integers describing the shape of the new tensor.
-    device : context
-        The device of the result tensor.
-    name : string
-        The name of the shared allocation.
-    dtype : dtype, optional
-        The datatype of the allocation.  Default: torch.float32
-
-    Returns
-    -------
-    Tensor :
-        The shared tensor.
-    """
-    rank = dist.get_rank()
-    if rank == 0:
-        y = create_shared_mem_array(
-            name, shape, dtype)
-    dist.barrier()
-    if rank != 0:
-        y = get_shared_mem_array(name, shape, dtype)
-    pin_memory_inplace(y)
-    return y

class SAGE(nn.Module):
    def __init__(self, in_feats, n_hidden, n_classes):
        super().__init__()
...@@ -105,9 +74,8 @@ class SAGE(nn.Module):
            # shared output tensor 'y' in host memory, pin it to allow UVA
            # access from each GPU during forward propagation.
            y = shared_tensor(
-                g.num_nodes(),
-                self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
-                device='cpu', name='layer_{}_output'.format(l))
+                (g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes))
+            pin_memory_inplace(y)
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader) \
                    if dist.get_rank() == 0 else dataloader:
......
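
The example above now allocates the layer output with dgl.multiprocessing.shared_tensor and pins it explicitly. Below is a minimal sketch of the same pattern in isolation, assuming torch.distributed is already initialized with one process per GPU (for instance via dgl.multiprocessing.spawn); the sizes are hypothetical.

import torch.distributed as dist
from dgl.multiprocessing import shared_tensor
from dgl.utils import pin_memory_inplace

num_nodes, hidden_size = 100000, 256           # hypothetical sizes
# One uninitialized, host-resident buffer shared by every rank.
y = shared_tensor((num_nodes, hidden_size))
# Pin it so each GPU can read and write it directly via UVA.
pin_memory_inplace(y)
# ... each rank fills its own rows during layer-wise inference ...
dist.barrier()                                 # make sure all ranks finished writing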
import torch
import argparse
import dgl
-import dgl.multiprocessing as mp
+import torch.multiprocessing as mp
from torch.utils.data import DataLoader
import os
import random
......
...@@ -4,8 +4,8 @@ import torch.nn.functional as F
from torch.nn import init
import random
import numpy as np
-import dgl.multiprocessing as mp
-from dgl.multiprocessing import Queue
+import torch.multiprocessing as mp
+from torch.multiprocessing import Queue

def init_emb2pos_index(walk_length, window_size, batch_size):
......
import torch
import argparse
import dgl
-import dgl.multiprocessing as mp
+import torch.multiprocessing as mp
from torch.utils.data import DataLoader
import os
import random
......
...@@ -4,8 +4,8 @@ import torch.nn.functional as F
from torch.nn import init
import random
import numpy as np
-import dgl.multiprocessing as mp
-from dgl.multiprocessing import Queue
+import torch.multiprocessing as mp
+from torch.multiprocessing import Queue

def init_emb2neg_index(negative, batch_size):
    '''select embedding of negative nodes from a batch of node embeddings
......
...@@ -6,8 +6,8 @@ import argparse
import gc
import torch as th
import torch.nn.functional as F
-import torch.multiprocessing as mp
import dgl
+import torch.multiprocessing as mp
from torchmetrics.functional import accuracy
from torch.nn.parallel import DistributedDataParallel
......
...@@ -4,7 +4,6 @@ from queue import Queue, Empty, Full
import itertools
import threading
from distutils.version import LooseVersion
-import random
import math
import inspect
import re
...@@ -18,15 +17,15 @@ from torch.utils.data.distributed import DistributedSampler
from ..base import NID, EID, dgl_warning
from ..batch import batch as batch_graphs
from ..heterograph import DGLHeteroGraph
-from .. import ndarray as nd
from ..utils import (
    recursive_apply, ExceptionWrapper, recursive_apply_pair, set_num_threads,
-    create_shared_mem_array, get_shared_mem_array, context_of, dtype_of)
+    context_of, dtype_of)
from ..frame import LazyFeature
from ..storages import wrap_storage
from .base import BlockSampler, as_edge_prediction_sampler
from .. import backend as F
from ..distributed import DistGraph
+from ..multiprocessing import call_once_and_share

PYTHON_EXIT_STATUS = False

def _set_python_exit_flag():
...@@ -87,12 +86,19 @@ class _TensorizedDatasetIter(object):

def _get_id_tensor_from_mapping(indices, device, keys):
    dtype = dtype_of(indices)
-    lengths = torch.tensor(
-        [(indices[k].shape[0] if k in indices else 0) for k in keys],
-        dtype=dtype, device=device)
-    type_ids = torch.arange(len(keys), dtype=dtype, device=device).repeat_interleave(lengths)
-    all_indices = torch.cat([indices[k] for k in keys if k in indices])
-    return torch.stack([type_ids, all_indices], 1)
+    id_tensor = torch.empty(
+        sum(v.shape[0] for v in indices.values()), 2, dtype=dtype, device=device)
+    offset = 0
+    for i, k in enumerate(keys):
+        if k not in indices:
+            continue
+        index = indices[k]
+        length = index.shape[0]
+        id_tensor[offset:offset+length, 0] = i
+        id_tensor[offset:offset+length, 1] = index
+        offset += length
+    return id_tensor

def _divide_by_worker(dataset, batch_size, drop_last):
...@@ -115,7 +121,6 @@ class TensorizedDataset(torch.utils.data.IterableDataset):
    When the dataset is on the GPU, this significantly reduces the overhead.
    """
    def __init__(self, indices, batch_size, drop_last):
-        name, _ = _generate_shared_mem_name_id()
        if isinstance(indices, Mapping):
            self._mapping_keys = list(indices.keys())
            self._device = next(iter(indices.values())).device
...@@ -128,12 +133,10 @@ class TensorizedDataset(torch.utils.data.IterableDataset):
        # Use a shared memory array to permute indices for shuffling. This is to make sure that
        # the worker processes can see it when persistent_workers=True, where self._indices
        # would not be duplicated every epoch.
-        self._indices = create_shared_mem_array(name, (self._id_tensor.shape[0],), torch.int64)
+        self._indices = torch.empty(self._id_tensor.shape[0], dtype=torch.int64).share_memory_()
        self._indices[:] = torch.arange(self._id_tensor.shape[0])
        self.batch_size = batch_size
        self.drop_last = drop_last
-        self.shared_mem_name = name
-        self.shared_mem_size = self._indices.shape[0]

    def shuffle(self):
        """Shuffle the dataset."""
...@@ -150,17 +153,6 @@ class TensorizedDataset(torch.utils.data.IterableDataset):
        num_samples = self._id_tensor.shape[0]
        return (num_samples + (0 if self.drop_last else (self.batch_size - 1))) // self.batch_size

-def _get_shared_mem_name(id_):
-    return f'ddp_{id_}'
-
-def _generate_shared_mem_name_id():
-    for _ in range(3):     # 3 trials
-        id_ = random.getrandbits(32)
-        name = _get_shared_mem_name(id_)
-        if not nd.exist_shared_mem_array(name):
-            return name, id_
-    raise DGLError('Unable to generate a shared memory array')

class DDPTensorizedDataset(torch.utils.data.IterableDataset):
    """Custom Dataset wrapper that returns a minibatch as tensors or dicts of tensors.
    When the dataset is on the GPU, this significantly reduces the overhead.
...@@ -197,33 +189,22 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
        if isinstance(indices, Mapping):
            self._device = next(iter(indices.values())).device
-            self._id_tensor = _get_id_tensor_from_mapping(
-                indices, self._device, self._mapping_keys)
+            self._id_tensor = call_once_and_share(
+                lambda: _get_id_tensor_from_mapping(indices, self._device, self._mapping_keys),
+                (self.num_indices, 2), dtype_of(indices))
        else:
            self._id_tensor = indices
            self._device = self._id_tensor.device
-        if self.rank == 0:
-            name, id_ = _generate_shared_mem_name_id()
-            self._indices = create_shared_mem_array(
-                name, (self.shared_mem_size,), torch.int64)
-            self._indices[:self._id_tensor.shape[0]] = torch.arange(self._id_tensor.shape[0])
-            meta_info = torch.LongTensor([id_, self._indices.shape[0]])
-        else:
-            meta_info = torch.LongTensor([0, 0])
-        if dist.get_backend() == 'nccl':
-            # Use default CUDA device; PyTorch DDP required the users to set the CUDA
-            # device for each process themselves so calling .cuda() should be safe.
-            meta_info = meta_info.cuda()
-        dist.broadcast(meta_info, src=0)
-        if self.rank != 0:
-            id_, num_samples = meta_info.tolist()
-            name = _get_shared_mem_name(id_)
-            indices_shared = get_shared_mem_array(name, (num_samples,), torch.int64)
-            self._indices = indices_shared
-            self.shared_mem_name = name
+        self._indices = call_once_and_share(
+            self._create_shared_indices, (self.shared_mem_size,), torch.int64)
+
+    def _create_shared_indices(self):
+        indices = torch.empty(self.shared_mem_size, dtype=torch.int64)
+        num_ids = self._id_tensor.shape[0]
+        indices[:num_ids] = torch.arange(num_ids)
+        indices[num_ids:] = torch.arange(self.shared_mem_size - num_ids)
+        return indices

    def shuffle(self):
        """Shuffles the dataset."""
...@@ -525,11 +506,16 @@ class CollateWrapper(object):
    """Wraps a collate function with :func:`remove_parent_storage_columns` for serializing
    from PyTorch DataLoader workers.
    """
-    def __init__(self, sample_func, g):
+    def __init__(self, sample_func, g, use_uva, device):
        self.sample_func = sample_func
        self.g = g
+        self.use_uva = use_uva
+        self.device = device

    def __call__(self, items):
+        if self.use_uva:
+            # Only copy the indices to the given device if in UVA mode.
+            items = recursive_apply(items, lambda x: x.to(self.device))
        batch = self.sample_func(self.g, items)
        return recursive_apply(batch, remove_parent_storage_columns, self.g)
...@@ -771,10 +757,6 @@ class DataLoader(torch.utils.data.DataLoader):
        if use_uva:
            if self.graph.device.type != 'cpu':
                raise ValueError('Graph must be on CPU if UVA sampling is enabled.')
-            if indices_device != self.device:
-                raise ValueError(
-                    f'Indices must be on the same device as the device argument '
-                    f'({self.device})')
            if num_workers > 0:
                raise ValueError('num_workers must be 0 if UVA sampling is enabled.')
...@@ -845,7 +827,8 @@ class DataLoader(torch.utils.data.DataLoader):
        super().__init__(
            self.dataset,
-            collate_fn=CollateWrapper(self.graph_sampler.sample, graph),
+            collate_fn=CollateWrapper(
+                self.graph_sampler.sample, graph, self.use_uva, self.device),
            batch_size=None,
            worker_init_fn=worker_init_fn,
            **kwargs)
......
...@@ -9,8 +9,8 @@ from .. import backend as F
if F.get_preferred_backend() == 'pytorch':
    # Wrap around torch.multiprocessing...
    from torch.multiprocessing import *
-    # ... and override the Process initializer
-    from .pytorch import Process
+    # ... and override the Process initializer and spawn function.
+    from .pytorch import *
else:
    # Just import multiprocessing module.
    from multiprocessing import *      # pylint: disable=redefined-builtin
"""PyTorch multiprocessing wrapper.""" """PyTorch multiprocessing wrapper."""
from functools import wraps from functools import wraps
import random
import traceback import traceback
from _thread import start_new_thread from _thread import start_new_thread
import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from ..utils import create_shared_mem_array, get_shared_mem_array
def thread_wrapped_func(func): def thread_wrapped_func(func):
""" """
Wraps a process entry point to make it work with OpenMP. Wraps a process entry point to make it work with OpenMP.
...@@ -35,3 +39,70 @@ class Process(mp.Process):
    def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, *, daemon=None):
        target = thread_wrapped_func(target)
        super().__init__(group, target, name, args, kwargs, daemon=daemon)
+def _get_shared_mem_name(id_):
+    return "shared" + str(id_)
+
+def call_once_and_share(func, shape, dtype, rank=0):
+    """Invoke the function in a single process of the process group spawned by
+    :func:`spawn`, and share the result with the other processes.
+
+    Requires the subprocesses to be spawned with :func:`dgl.multiprocessing.pytorch.spawn`.
+
+    Parameters
+    ----------
+    func : callable
+        Any callable that accepts no arguments and returns an arbitrary object.
+    shape : tuple[int]
+        The shape of the shared tensor.  Must match the output of :attr:`func`.
+    dtype : torch.dtype
+        The data type of the shared tensor.  Must match the output of :attr:`func`.
+    rank : int, optional
+        The process ID to actually execute the function.
+    """
+    current_rank = torch.distributed.get_rank()
+    dist_buf = torch.LongTensor([1])
+    if torch.distributed.get_backend() == 'nccl':
+        # Use .cuda() to transfer it to the correct device.  Should be OK since
+        # PyTorch recommends the users to call set_device() after getting inside
+        # torch.multiprocessing.spawn()
+        dist_buf = dist_buf.cuda()
+
+    # Process with the given rank creates and populates the shared memory array.
+    if current_rank == rank:
+        id_ = random.getrandbits(32)
+        name = _get_shared_mem_name(id_)
+        result = create_shared_mem_array(name, shape, dtype)
+        result[:] = func()
+        dist_buf[0] = id_
+
+    # Broadcasts the name of the shared array to other processes.
+    torch.distributed.broadcast(dist_buf, rank)
+
+    # If no exceptions, other processes open the same shared memory object.
+    if current_rank != rank:
+        id_ = dist_buf.item()
+        name = _get_shared_mem_name(id_)
+        result = get_shared_mem_array(name, shape, dtype)
+    return result
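
A usage sketch for call_once_and_share, assuming the workers are started with dgl.multiprocessing.spawn; the expensive_init helper, tensor shape, and rendezvous address below are hypothetical.

import torch
import dgl.multiprocessing as mp
from dgl.multiprocessing import call_once_and_share

def expensive_init():
    # Hypothetical one-off preprocessing whose result every rank needs to read.
    return torch.randn(1000, 16)

def worker(rank, world_size):
    torch.distributed.init_process_group(
        'gloo', rank=rank, world_size=world_size,
        init_method='tcp://127.0.0.1:12345')       # hypothetical rendezvous address
    # expensive_init() runs only on rank 0; every rank receives the same shared tensor.
    feats = call_once_and_share(expensive_init, (1000, 16), torch.float32)
    print(rank, feats.shape)

if __name__ == '__main__':
    mp.spawn(worker, args=(2,), nprocs=2)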
+def shared_tensor(shape, dtype=torch.float32):
+    """Create a tensor in shared memory accessible by all processes within the same
+    ``torch.distributed`` process group.
+
+    The content is uninitialized.
+
+    Parameters
+    ----------
+    shape : tuple[int]
+        The shape of the tensor.
+    dtype : torch.dtype, optional
+        The dtype of the tensor.
+
+    Returns
+    -------
+    Tensor
+        The shared tensor.
+    """
+    return call_once_and_share(lambda: torch.empty(*shape, dtype=dtype), shape, dtype)
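
Similarly, a hedged sketch of shared_tensor from inside an already-initialized process group; the shape and values are arbitrary.

import torch.distributed as dist
from dgl.multiprocessing import shared_tensor

# Inside each spawned worker (process group already initialized):
rank = dist.get_rank()
out = shared_tensor((4, 8))          # hypothetical shape; contents are uninitialized
out[rank] = rank                     # every rank writes its own row
dist.barrier()                       # after the barrier, all rows are visible to all ranks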
...@@ -51,8 +51,8 @@ def randn(shape):
def tensor(data, dtype=None):
    return copy_to(_tensor(data, dtype), _default_context)

-def arange(start, stop, dtype=int64):
-    return copy_to(_arange(start, stop, dtype), _default_context)
+def arange(start, stop, dtype=int64, ctx=None):
+    return _arange(start, stop, dtype, ctx if ctx is not None else _default_context)

def full(shape, fill_value, dtype, ctx=_default_context):
    return _full(shape, fill_value, dtype, ctx)
......
...@@ -130,14 +130,27 @@ def _check_device(data):

@parametrize_dtype
@pytest.mark.parametrize('sampler_name', ['full', 'neighbor', 'neighbor2'])
-@pytest.mark.parametrize('pin_graph', [False, True])
+@pytest.mark.parametrize('pin_graph', [None, 'cuda_indices', 'cpu_indices'])
def test_node_dataloader(idtype, sampler_name, pin_graph):
    g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])).astype(idtype)
-    if F.ctx() != F.cpu() and pin_graph:
-        g1.create_formats_()
-        g1.pin_memory_()
    g1.ndata['feat'] = F.copy_to(F.randn((5, 8)), F.cpu())
    g1.ndata['label'] = F.copy_to(F.randn((g1.num_nodes(),)), F.cpu())
+    indices = F.arange(0, g1.num_nodes(), idtype)
+    if F.ctx() != F.cpu():
+        if pin_graph:
+            g1.create_formats_()
+            g1.pin_memory_()
+            if pin_graph == 'cpu_indices':
+                indices = F.arange(0, g1.num_nodes(), idtype, F.cpu())
+            elif pin_graph == 'cuda_indices':
+                if F._default_context_str == 'gpu':
+                    indices = F.arange(0, g1.num_nodes(), idtype, F.cuda())
+                else:
+                    return      # skip
+        else:
+            g1 = g1.to('cuda')
+    use_uva = pin_graph is not None and F.ctx() != F.cpu()

    for num_workers in [0, 1, 2]:
        sampler = {
...@@ -145,9 +158,10 @@ def test_node_dataloader(idtype, sampler_name, pin_graph):
            'neighbor': dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
            'neighbor2': dgl.dataloading.MultiLayerNeighborSampler([3, 3])}[sampler_name]
        dataloader = dgl.dataloading.NodeDataLoader(
-            g1, g1.nodes(), sampler, device=F.ctx(),
+            g1, indices, sampler, device=F.ctx(),
            batch_size=g1.num_nodes(),
-            num_workers=num_workers)
+            num_workers=(num_workers if (pin_graph and F.ctx() == F.cpu()) else 0),
+            use_uva=use_uva)
        for input_nodes, output_nodes, blocks in dataloader:
            _check_device(input_nodes)
            _check_device(output_nodes)
...@@ -155,6 +169,8 @@ def test_node_dataloader(idtype, sampler_name, pin_graph):
            _check_dtype(input_nodes, idtype, 'dtype')
            _check_dtype(output_nodes, idtype, 'dtype')
            _check_dtype(blocks, idtype, 'idtype')
+    if g1.is_pinned():
+        g1.unpin_memory_()

    g2 = dgl.heterograph({
        ('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]),
...@@ -164,6 +180,21 @@ def test_node_dataloader(idtype, sampler_name, pin_graph):
    }).astype(idtype)
    for ntype in g2.ntypes:
        g2.nodes[ntype].data['feat'] = F.copy_to(F.randn((g2.num_nodes(ntype), 8)), F.cpu())
+    indices = {nty: F.arange(0, g2.num_nodes(nty)) for nty in g2.ntypes}
+    if F.ctx() != F.cpu():
+        if pin_graph:
+            g2.create_formats_()
+            g2.pin_memory_()
+            if pin_graph == 'cpu_indices':
+                indices = {nty: F.arange(0, g2.num_nodes(nty), idtype, F.cpu()) for nty in g2.ntypes}
+            elif pin_graph == 'cuda_indices':
+                if F._default_context_str == 'gpu':
+                    indices = {nty: F.arange(0, g2.num_nodes(), idtype, F.cuda()) for nty in g2.ntypes}
+                else:
+                    return      # skip
+        else:
+            g2 = g2.to('cuda')

    batch_size = max(g2.num_nodes(nty) for nty in g2.ntypes)
    sampler = {
        'full': dgl.dataloading.MultiLayerFullNeighborSampler(2),
...@@ -172,7 +203,9 @@ def test_node_dataloader(idtype, sampler_name, pin_graph):
    dataloader = dgl.dataloading.NodeDataLoader(
        g2, {nty: g2.nodes(nty) for nty in g2.ntypes},
-        sampler, device=F.ctx(), batch_size=batch_size)
+        sampler, device=F.ctx(), batch_size=batch_size,
+        num_workers=(num_workers if (pin_graph and F.ctx() == F.cpu()) else 0),
+        use_uva=use_uva)
    assert isinstance(iter(dataloader), Iterator)
    for input_nodes, output_nodes, blocks in dataloader:
        _check_device(input_nodes)
...@@ -182,8 +215,8 @@ def test_node_dataloader(idtype, sampler_name, pin_graph):
        _check_dtype(output_nodes, idtype, 'dtype')
        _check_dtype(blocks, idtype, 'idtype')
-    if g1.is_pinned():
-        g1.unpin_memory_()
+    if g2.is_pinned():
+        g2.unpin_memory_()

@pytest.mark.parametrize('sampler_name', ['full', 'neighbor'])
@pytest.mark.parametrize('neg_sampler', [
...@@ -349,9 +382,4 @@ def test_edge_dataloader_excludes(exclude, always_exclude_flag):
    assert not np.isin(edges_to_exclude, block_eids).any()

if __name__ == '__main__':
-    test_graph_dataloader()
-    test_cluster_gcn(0)
-    test_neighbor_nonuniform(0)
-    for exclude in [None, 'self', 'reverse_id', 'reverse_types']:
-        test_edge_dataloader_excludes(exclude, False)
-        test_edge_dataloader_excludes(exclude, True)
+    test_node_dataloader(F.int32, 'neighbor', None)