Unverified Commit e06e63d5 authored by Quan (Andy) Gan, committed by GitHub

[Feature] Enable UVA sampling with CPU indices (#3892)

* enable UVA sampling with CPU indices

* add docs

* add more docs

* lint

* fix

* fix

* better error message

* use mp.Barrier instead of queues

* revert

* revert

* oops

* revert dgl.multiprocessing.spawn

* Update pytorch.py
parent 0d878ff8
......@@ -12,3 +12,11 @@ additional documentation.
In addition, if your backend is PyTorch, this module will also be compatible with
the :mod:`torch.multiprocessing` module.
.. currentmodule:: dgl.multiprocessing.pytorch
.. autosummary::
:toctree: ../../generated/
spawn
call_once_and_share
shared_tensor
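A minimal end-to-end sketch of how the three helpers above fit together, assuming ``dgl.multiprocessing.spawn`` forwards to ``torch.multiprocessing.spawn`` as this commit's module does; the ``gloo`` backend, the TCP port, and the tensor sizes are illustrative only::

    import torch
    import torch.distributed as dist
    import dgl.multiprocessing as dmp

    def worker(rank, world_size):
        # Every spawned process joins the same torch.distributed group.
        dist.init_process_group('gloo', init_method='tcp://127.0.0.1:29501',
                                rank=rank, world_size=world_size)
        # Uninitialized tensor in shared host memory, visible to all ranks.
        buf = dmp.shared_tensor((4, 8))
        # The lambda runs on rank 0 only; the other ranks attach to its result.
        ones = dmp.call_once_and_share(lambda: torch.ones(4, 8), (4, 8), torch.float32)
        buf[rank] = float(rank)           # each rank writes its own row
        dist.barrier()
        if rank == 0:
            print(buf[:world_size].sum().item(), ones.sum().item())
        dist.destroy_process_group()

    if __name__ == '__main__':
        dmp.spawn(worker, args=(2,), nprocs=2)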
......@@ -6,45 +6,14 @@ import torch.distributed.optim
import torchmetrics.functional as MF
import dgl
import dgl.nn as dglnn
from dgl.utils import pin_memory_inplace, unpin_memory_inplace, \
gather_pinned_tensor_rows, create_shared_mem_array, get_shared_mem_array
from dgl.utils import pin_memory_inplace, unpin_memory_inplace
from dgl.multiprocessing import shared_tensor
import time
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset
import tqdm
def shared_tensor(*shape, device, name, dtype=torch.float32):
""" Create a tensor in shared memroy, pinned in each process's CUDA
context.
Parameters
----------
shape : int...
A sequence of integers describing the shape of the new tensor.
device : context
The device of the result tensor.
name : string
The name of the shared allocation.
dtype : dtype, optional
The datatype of the allocation. Default: torch.float32
Returns
-------
Tensor :
The shared tensor.
"""
rank = dist.get_rank()
if rank == 0:
y = create_shared_mem_array(
name, shape, dtype)
dist.barrier()
if rank != 0:
y = get_shared_mem_array(name, shape, dtype)
pin_memory_inplace(y)
return y
class SAGE(nn.Module):
def __init__(self, in_feats, n_hidden, n_classes):
super().__init__()
......@@ -105,9 +74,8 @@ class SAGE(nn.Module):
# shared output tensor 'y' in host memory, pin it to allow UVA
# access from each GPU during forward propagation.
y = shared_tensor(
g.num_nodes(),
self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
device='cpu', name='layer_{}_output'.format(l))
(g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes))
pin_memory_inplace(y)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader) \
if dist.get_rank() == 0 else dataloader:
......
import torch
import argparse
import dgl
import dgl.multiprocessing as mp
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
import os
import random
......
......@@ -4,8 +4,8 @@ import torch.nn.functional as F
from torch.nn import init
import random
import numpy as np
import dgl.multiprocessing as mp
from dgl.multiprocessing import Queue
import torch.multiprocessing as mp
from torch.multiprocessing import Queue
def init_emb2pos_index(walk_length, window_size, batch_size):
......
import torch
import argparse
import dgl
import dgl.multiprocessing as mp
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
import os
import random
......
......@@ -4,8 +4,8 @@ import torch.nn.functional as F
from torch.nn import init
import random
import numpy as np
import dgl.multiprocessing as mp
from dgl.multiprocessing import Queue
import torch.multiprocessing as mp
from torch.multiprocessing import Queue
def init_emb2neg_index(negative, batch_size):
'''select embedding of negative nodes from a batch of node embeddings
......
......@@ -6,8 +6,8 @@ import argparse
import gc
import torch as th
import torch.nn.functional as F
import torch.multiprocessing as mp
import dgl
import torch.multiprocessing as mp
from torchmetrics.functional import accuracy
from torch.nn.parallel import DistributedDataParallel
......
......@@ -4,7 +4,6 @@ from queue import Queue, Empty, Full
import itertools
import threading
from distutils.version import LooseVersion
import random
import math
import inspect
import re
......@@ -18,15 +17,15 @@ from torch.utils.data.distributed import DistributedSampler
from ..base import NID, EID, dgl_warning
from ..batch import batch as batch_graphs
from ..heterograph import DGLHeteroGraph
from .. import ndarray as nd
from ..utils import (
recursive_apply, ExceptionWrapper, recursive_apply_pair, set_num_threads,
create_shared_mem_array, get_shared_mem_array, context_of, dtype_of)
context_of, dtype_of)
from ..frame import LazyFeature
from ..storages import wrap_storage
from .base import BlockSampler, as_edge_prediction_sampler
from .. import backend as F
from ..distributed import DistGraph
from ..multiprocessing import call_once_and_share
PYTHON_EXIT_STATUS = False
def _set_python_exit_flag():
......@@ -87,12 +86,19 @@ class _TensorizedDatasetIter(object):
def _get_id_tensor_from_mapping(indices, device, keys):
dtype = dtype_of(indices)
lengths = torch.tensor(
[(indices[k].shape[0] if k in indices else 0) for k in keys],
dtype=dtype, device=device)
type_ids = torch.arange(len(keys), dtype=dtype, device=device).repeat_interleave(lengths)
all_indices = torch.cat([indices[k] for k in keys if k in indices])
return torch.stack([type_ids, all_indices], 1)
id_tensor = torch.empty(
sum(v.shape[0] for v in indices.values()), 2, dtype=dtype, device=device)
offset = 0
for i, k in enumerate(keys):
if k not in indices:
continue
index = indices[k]
length = index.shape[0]
id_tensor[offset:offset+length, 0] = i
id_tensor[offset:offset+length, 1] = index
offset += length
return id_tensor
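For illustration, a simplified re-statement of the helper's contract (not the code in this diff; the node-type names are made up): each row of the result pairs the position of a node type in ``keys`` with one node index of that type::

    import torch

    def id_tensor_from_mapping(indices, keys):
        # Mirrors the helper above, with device/dtype handling omitted.
        rows = []
        for i, k in enumerate(keys):
            if k in indices:
                idx = indices[k]
                rows.append(torch.stack([torch.full_like(idx, i), idx], 1))
        return torch.cat(rows)

    print(id_tensor_from_mapping(
        {'user': torch.tensor([0, 2, 4]), 'item': torch.tensor([1, 3])},
        ['user', 'item']))
    # -> [[0, 0], [0, 2], [0, 4], [1, 1], [1, 3]]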
def _divide_by_worker(dataset, batch_size, drop_last):
......@@ -115,7 +121,6 @@ class TensorizedDataset(torch.utils.data.IterableDataset):
When the dataset is on the GPU, this significantly reduces the overhead.
"""
def __init__(self, indices, batch_size, drop_last):
name, _ = _generate_shared_mem_name_id()
if isinstance(indices, Mapping):
self._mapping_keys = list(indices.keys())
self._device = next(iter(indices.values())).device
......@@ -128,12 +133,10 @@ class TensorizedDataset(torch.utils.data.IterableDataset):
# Use a shared memory array to permute indices for shuffling. This is to make sure that
# the worker processes can see it when persistent_workers=True, where self._indices
# would not be duplicated every epoch.
self._indices = create_shared_mem_array(name, (self._id_tensor.shape[0],), torch.int64)
self._indices = torch.empty(self._id_tensor.shape[0], dtype=torch.int64).share_memory_()
self._indices[:] = torch.arange(self._id_tensor.shape[0])
self.batch_size = batch_size
self.drop_last = drop_last
self.shared_mem_name = name
self.shared_mem_size = self._indices.shape[0]
def shuffle(self):
"""Shuffle the dataset."""
......@@ -150,17 +153,6 @@ class TensorizedDataset(torch.utils.data.IterableDataset):
num_samples = self._id_tensor.shape[0]
return (num_samples + (0 if self.drop_last else (self.batch_size - 1))) // self.batch_size
def _get_shared_mem_name(id_):
return f'ddp_{id_}'
def _generate_shared_mem_name_id():
for _ in range(3): # 3 trials
id_ = random.getrandbits(32)
name = _get_shared_mem_name(id_)
if not nd.exist_shared_mem_array(name):
return name, id_
raise DGLError('Unable to generate a shared memory array')
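The removed helpers above are superseded by PyTorch's built-in ``Tensor.share_memory_()``, which gives the same cross-process visibility (e.g. with persistent workers) without generating and tracking a shared-memory name; a standalone sketch of that pattern::

    import torch
    import torch.multiprocessing as mp

    def report(indices):
        # The child process maps the same storage, not a copy.
        print(indices[:5])

    if __name__ == '__main__':
        # Same idea as TensorizedDataset._indices above.
        indices = torch.empty(10, dtype=torch.int64).share_memory_()
        indices[:] = torch.arange(10)
        proc = mp.Process(target=report, args=(indices,))
        proc.start()
        proc.join()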
class DDPTensorizedDataset(torch.utils.data.IterableDataset):
"""Custom Dataset wrapper that returns a minibatch as tensors or dicts of tensors.
When the dataset is on the GPU, this significantly reduces the overhead.
......@@ -197,33 +189,22 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
if isinstance(indices, Mapping):
self._device = next(iter(indices.values())).device
self._id_tensor = _get_id_tensor_from_mapping(
indices, self._device, self._mapping_keys)
self._id_tensor = call_once_and_share(
lambda: _get_id_tensor_from_mapping(indices, self._device, self._mapping_keys),
(self.num_indices, 2), dtype_of(indices))
else:
self._id_tensor = indices
self._device = self._id_tensor.device
if self.rank == 0:
name, id_ = _generate_shared_mem_name_id()
self._indices = create_shared_mem_array(
name, (self.shared_mem_size,), torch.int64)
self._indices[:self._id_tensor.shape[0]] = torch.arange(self._id_tensor.shape[0])
meta_info = torch.LongTensor([id_, self._indices.shape[0]])
else:
meta_info = torch.LongTensor([0, 0])
if dist.get_backend() == 'nccl':
            # Use default CUDA device; PyTorch DDP requires users to set the CUDA
            # device for each process themselves, so calling .cuda() should be safe.
meta_info = meta_info.cuda()
dist.broadcast(meta_info, src=0)
self._indices = call_once_and_share(
self._create_shared_indices, (self.shared_mem_size,), torch.int64)
if self.rank != 0:
id_, num_samples = meta_info.tolist()
name = _get_shared_mem_name(id_)
indices_shared = get_shared_mem_array(name, (num_samples,), torch.int64)
self._indices = indices_shared
self.shared_mem_name = name
def _create_shared_indices(self):
indices = torch.empty(self.shared_mem_size, dtype=torch.int64)
num_ids = self._id_tensor.shape[0]
indices[:num_ids] = torch.arange(num_ids)
indices[num_ids:] = torch.arange(self.shared_mem_size - num_ids)
return indices
def shuffle(self):
"""Shuffles the dataset."""
......@@ -525,11 +506,16 @@ class CollateWrapper(object):
"""Wraps a collate function with :func:`remove_parent_storage_columns` for serializing
from PyTorch DataLoader workers.
"""
def __init__(self, sample_func, g):
def __init__(self, sample_func, g, use_uva, device):
self.sample_func = sample_func
self.g = g
self.use_uva = use_uva
self.device = device
def __call__(self, items):
if self.use_uva:
# Only copy the indices to the given device if in UVA mode.
items = recursive_apply(items, lambda x: x.to(self.device))
batch = self.sample_func(self.g, items)
return recursive_apply(batch, remove_parent_storage_columns, self.g)
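The ``recursive_apply`` call lets the same wrapper serve homogeneous and heterogeneous minibatches alike; a simplified stand-in for that copy step (``to_device`` is made up for illustration, and DGL's ``recursive_apply`` also handles nested containers)::

    import torch

    def to_device(items, device):
        # Seed indices may arrive as a single tensor or as a dict of tensors.
        if isinstance(items, dict):
            return {k: v.to(device) for k, v in items.items()}
        return items.to(device)

    print(to_device(torch.arange(4), 'cpu'))
    print(to_device({'user': torch.arange(3), 'item': torch.arange(2)}, 'cpu'))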
......@@ -771,10 +757,6 @@ class DataLoader(torch.utils.data.DataLoader):
if use_uva:
if self.graph.device.type != 'cpu':
raise ValueError('Graph must be on CPU if UVA sampling is enabled.')
if indices_device != self.device:
raise ValueError(
f'Indices must be on the same device as the device argument '
f'({self.device})')
if num_workers > 0:
raise ValueError('num_workers must be 0 if UVA sampling is enabled.')
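With the device check removed, UVA sampling now accepts seed indices that stay on the CPU next to the pinned CPU graph; a usage sketch assuming a CUDA build of DGL and an available GPU (the graph, fan-outs, and batch size are illustrative)::

    import torch
    import dgl

    g = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4]))
    g.create_formats_()
    g.pin_memory_()                        # graph stays in pinned host memory

    indices = torch.arange(g.num_nodes())  # CPU indices are now allowed with UVA
    sampler = dgl.dataloading.MultiLayerNeighborSampler([3, 3])
    dataloader = dgl.dataloading.NodeDataLoader(
        g, indices, sampler, device='cuda', use_uva=True,
        batch_size=2, shuffle=True, num_workers=0)

    for input_nodes, output_nodes, blocks in dataloader:
        pass   # sampling runs on the GPU, reading the pinned graph through UVA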
......@@ -845,7 +827,8 @@ class DataLoader(torch.utils.data.DataLoader):
super().__init__(
self.dataset,
collate_fn=CollateWrapper(self.graph_sampler.sample, graph),
collate_fn=CollateWrapper(
self.graph_sampler.sample, graph, self.use_uva, self.device),
batch_size=None,
worker_init_fn=worker_init_fn,
**kwargs)
......
......@@ -9,8 +9,8 @@ from .. import backend as F
if F.get_preferred_backend() == 'pytorch':
# Wrap around torch.multiprocessing...
from torch.multiprocessing import *
# ... and override the Process initializer
from .pytorch import Process
# ... and override the Process initializer and spawn function.
from .pytorch import *
else:
# Just import multiprocessing module.
from multiprocessing import * # pylint: disable=redefined-builtin
"""PyTorch multiprocessing wrapper."""
from functools import wraps
import random
import traceback
from _thread import start_new_thread
import torch
import torch.multiprocessing as mp
from ..utils import create_shared_mem_array, get_shared_mem_array
def thread_wrapped_func(func):
"""
Wraps a process entry point to make it work with OpenMP.
......@@ -35,3 +39,70 @@ class Process(mp.Process):
def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, *, daemon=None):
target = thread_wrapped_func(target)
super().__init__(group, target, name, args, kwargs, daemon=daemon)
def _get_shared_mem_name(id_):
return "shared" + str(id_)
def call_once_and_share(func, shape, dtype, rank=0):
"""Invoke the function in a single process of the process group spawned by
    :func:`spawn`, and share the result with the other processes.
Requires the subprocesses to be spawned with :func:`dgl.multiprocessing.pytorch.spawn`.
Parameters
----------
func : callable
Any callable that accepts no arguments and returns an arbitrary object.
shape : tuple[int]
The shape of the shared tensor. Must match the output of :attr:`func`.
dtype : torch.dtype
The data type of the shared tensor. Must match the output of :attr:`func`.
rank : int, optional
        The rank of the process that actually executes the function. Default: 0.
"""
current_rank = torch.distributed.get_rank()
dist_buf = torch.LongTensor([1])
if torch.distributed.get_backend() == 'nccl':
        # Use .cuda() to transfer it to the correct device. This should be OK
        # since PyTorch recommends that users call set_device() right after
        # entering torch.multiprocessing.spawn().
dist_buf = dist_buf.cuda()
# Process with the given rank creates and populates the shared memory array.
if current_rank == rank:
id_ = random.getrandbits(32)
name = _get_shared_mem_name(id_)
result = create_shared_mem_array(name, shape, dtype)
result[:] = func()
dist_buf[0] = id_
# Broadcasts the name of the shared array to other processes.
torch.distributed.broadcast(dist_buf, rank)
# If no exceptions, other processes open the same shared memory object.
if current_rank != rank:
id_ = dist_buf.item()
name = _get_shared_mem_name(id_)
result = get_shared_mem_array(name, shape, dtype)
return result
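A self-contained sketch of the intended use, with a hypothetical ``load_features`` standing in for an expensive one-off computation; a single-process ``gloo`` group is used only to keep the snippet runnable on its own, whereas real workers would be spawned with :func:`dgl.multiprocessing.pytorch.spawn` as noted above::

    import torch
    import torch.distributed as dist
    from dgl.multiprocessing import call_once_and_share

    dist.init_process_group('gloo', init_method='tcp://127.0.0.1:29502',
                            rank=0, world_size=1)

    def load_features():
        # Hypothetical stand-in for an expensive load/preprocess step.
        return torch.randn(1000, 128)

    # Executed on rank 0 only; every other rank attaches to the shared result.
    feat = call_once_and_share(load_features, (1000, 128), torch.float32)
    print(feat.shape)

    dist.destroy_process_group()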
def shared_tensor(shape, dtype=torch.float32):
"""Create a tensor in shared memory accessible by all processes within the same
    ``torch.distributed`` process group.
The content is uninitialized.
Parameters
----------
shape : tuple[int]
The shape of the tensor.
dtype : torch.dtype, optional
The dtype of the tensor.
Returns
-------
Tensor
The shared tensor.
"""
return call_once_and_share(lambda: torch.empty(*shape, dtype=dtype), shape, dtype)
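The GraphSAGE example earlier in this diff pairs ``shared_tensor`` with ``pin_memory_inplace`` so that each layer's inference output lives once in host memory yet is UVA-accessible from every GPU; a condensed sketch of that pattern (a ``torch.distributed`` process group and a CUDA build of DGL are assumed when the function is called)::

    import torch
    from dgl.multiprocessing import shared_tensor
    from dgl.utils import pin_memory_inplace

    def allocate_layer_output(num_nodes, out_dim):
        # One uninitialized host-memory buffer shared by every rank ...
        y = shared_tensor((num_nodes, out_dim))
        # ... pinned in this process's CUDA context so GPU kernels can read
        # and write it directly through UVA.
        pin_memory_inplace(y)
        return y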
......@@ -51,8 +51,8 @@ def randn(shape):
def tensor(data, dtype=None):
return copy_to(_tensor(data, dtype), _default_context)
def arange(start, stop, dtype=int64):
return copy_to(_arange(start, stop, dtype), _default_context)
def arange(start, stop, dtype=int64, ctx=None):
return _arange(start, stop, dtype, ctx if ctx is not None else _default_context)
def full(shape, fill_value, dtype, ctx=_default_context):
return _full(shape, fill_value, dtype, ctx)
......
......@@ -130,14 +130,27 @@ def _check_device(data):
@parametrize_dtype
@pytest.mark.parametrize('sampler_name', ['full', 'neighbor', 'neighbor2'])
@pytest.mark.parametrize('pin_graph', [False, True])
@pytest.mark.parametrize('pin_graph', [None, 'cuda_indices', 'cpu_indices'])
def test_node_dataloader(idtype, sampler_name, pin_graph):
g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])).astype(idtype)
if F.ctx() != F.cpu() and pin_graph:
g1.create_formats_()
g1.pin_memory_()
g1.ndata['feat'] = F.copy_to(F.randn((5, 8)), F.cpu())
g1.ndata['label'] = F.copy_to(F.randn((g1.num_nodes(),)), F.cpu())
indices = F.arange(0, g1.num_nodes(), idtype)
if F.ctx() != F.cpu():
if pin_graph:
g1.create_formats_()
g1.pin_memory_()
if pin_graph == 'cpu_indices':
indices = F.arange(0, g1.num_nodes(), idtype, F.cpu())
elif pin_graph == 'cuda_indices':
if F._default_context_str == 'gpu':
indices = F.arange(0, g1.num_nodes(), idtype, F.cuda())
else:
return # skip
else:
g1 = g1.to('cuda')
use_uva = pin_graph is not None and F.ctx() != F.cpu()
for num_workers in [0, 1, 2]:
sampler = {
......@@ -145,9 +158,10 @@ def test_node_dataloader(idtype, sampler_name, pin_graph):
'neighbor': dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
'neighbor2': dgl.dataloading.MultiLayerNeighborSampler([3, 3])}[sampler_name]
dataloader = dgl.dataloading.NodeDataLoader(
g1, g1.nodes(), sampler, device=F.ctx(),
g1, indices, sampler, device=F.ctx(),
batch_size=g1.num_nodes(),
num_workers=num_workers)
num_workers=(num_workers if (pin_graph and F.ctx() == F.cpu()) else 0),
use_uva=use_uva)
for input_nodes, output_nodes, blocks in dataloader:
_check_device(input_nodes)
_check_device(output_nodes)
......@@ -155,6 +169,8 @@ def test_node_dataloader(idtype, sampler_name, pin_graph):
_check_dtype(input_nodes, idtype, 'dtype')
_check_dtype(output_nodes, idtype, 'dtype')
_check_dtype(blocks, idtype, 'idtype')
if g1.is_pinned():
g1.unpin_memory_()
g2 = dgl.heterograph({
('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]),
......@@ -164,6 +180,21 @@ def test_node_dataloader(idtype, sampler_name, pin_graph):
}).astype(idtype)
for ntype in g2.ntypes:
g2.nodes[ntype].data['feat'] = F.copy_to(F.randn((g2.num_nodes(ntype), 8)), F.cpu())
indices = {nty: F.arange(0, g2.num_nodes(nty)) for nty in g2.ntypes}
if F.ctx() != F.cpu():
if pin_graph:
g2.create_formats_()
g2.pin_memory_()
if pin_graph == 'cpu_indices':
indices = {nty: F.arange(0, g2.num_nodes(nty), idtype, F.cpu()) for nty in g2.ntypes}
elif pin_graph == 'cuda_indices':
if F._default_context_str == 'gpu':
                indices = {nty: F.arange(0, g2.num_nodes(nty), idtype, F.cuda()) for nty in g2.ntypes}
else:
return # skip
else:
g2 = g2.to('cuda')
batch_size = max(g2.num_nodes(nty) for nty in g2.ntypes)
sampler = {
'full': dgl.dataloading.MultiLayerFullNeighborSampler(2),
......@@ -172,7 +203,9 @@ def test_node_dataloader(idtype, sampler_name, pin_graph):
dataloader = dgl.dataloading.NodeDataLoader(
g2, {nty: g2.nodes(nty) for nty in g2.ntypes},
sampler, device=F.ctx(), batch_size=batch_size)
sampler, device=F.ctx(), batch_size=batch_size,
num_workers=(num_workers if (pin_graph and F.ctx() == F.cpu()) else 0),
use_uva=use_uva)
assert isinstance(iter(dataloader), Iterator)
for input_nodes, output_nodes, blocks in dataloader:
_check_device(input_nodes)
......@@ -182,8 +215,8 @@ def test_node_dataloader(idtype, sampler_name, pin_graph):
_check_dtype(output_nodes, idtype, 'dtype')
_check_dtype(blocks, idtype, 'idtype')
if g1.is_pinned():
g1.unpin_memory_()
if g2.is_pinned():
g2.unpin_memory_()
@pytest.mark.parametrize('sampler_name', ['full', 'neighbor'])
@pytest.mark.parametrize('neg_sampler', [
......@@ -349,9 +382,4 @@ def test_edge_dataloader_excludes(exclude, always_exclude_flag):
assert not np.isin(edges_to_exclude, block_eids).any()
if __name__ == '__main__':
test_graph_dataloader()
test_cluster_gcn(0)
test_neighbor_nonuniform(0)
for exclude in [None, 'self', 'reverse_id', 'reverse_types']:
test_edge_dataloader_excludes(exclude, False)
test_edge_dataloader_excludes(exclude, True)
test_node_dataloader(F.int32, 'neighbor', None)