Unverified commit 68275b9f authored by Chang Liu, committed by GitHub

[Feature] Decompose dataset to each process to add multi-node support for dataloader (#5617)

parent 70af8f0d
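
In practice the change means every DDP rank builds only its own shard of the seed nodes, so the dataloader no longer depends on a shared-memory index array and also works when ranks run on different machines. A minimal per-rank sketch of the intended usage, assuming torch.distributed is already initialized (for example via torchrun) and that g and train_nids are this job's graph and seed nodes (both are illustrative names, not part of the diff):

import torch
import dgl

def run_one_rank(g, train_nids):
    # Each replica creates its own DataLoader; with use_ddp=True it only
    # materializes and iterates its local slice of train_nids.
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
    dataloader = dgl.dataloading.DataLoader(
        g,
        train_nids,
        sampler,
        device=torch.device("cpu"),
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        use_ddp=True,
    )
    for input_nodes, output_nodes, blocks in dataloader:
        pass  # forward/backward of the local model replica goes here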
@@ -24,7 +24,6 @@ from ..batch import batch as batch_graphs
from ..distributed import DistGraph
from ..frame import LazyFeature
from ..heterograph import DGLGraph
-from ..multiprocessing import call_once_and_share
from ..storages import wrap_storage
from ..utils import (
    dtype_of,

@@ -121,6 +120,52 @@ def _get_id_tensor_from_mapping(indices, device, keys):
    return id_tensor

def _split_to_local_id_tensor_from_mapping(
    indices, keys, local_lower_bound, local_upper_bound
):
    dtype = dtype_of(indices)
    device = next(iter(indices.values())).device
    num_samples = local_upper_bound - local_lower_bound
    id_tensor = torch.empty(num_samples, 2, dtype=dtype, device=device)
    index_offset = 0
    split_id_offset = 0
    for i, k in enumerate(keys):
        if k not in indices:
            continue
        index = indices[k]
        length = index.shape[0]
        index_offset2 = index_offset + length
        lower = max(local_lower_bound, index_offset)
        upper = min(local_upper_bound, index_offset2)
        if upper > lower:
            split_id_offset2 = split_id_offset + (upper - lower)
            assert split_id_offset2 <= num_samples
            id_tensor[split_id_offset:split_id_offset2, 0] = i
            id_tensor[split_id_offset:split_id_offset2, 1] = index[
                lower - index_offset : upper - index_offset
            ]
            split_id_offset += upper - lower
            if split_id_offset2 == num_samples:
                break
        index_offset = index_offset2
    return id_tensor
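
To make the mapping variant concrete, here is a toy illustration (an editorial sketch, not part of the diff) of the 2-column layout it produces: column 0 holds the position of the node type in keys, column 1 holds the node ID, and each rank keeps only the rows that fall inside its [lower, upper) window of the concatenated index list.

import torch

# Toy seed-node dict for two node types; the order of `keys` defines the type index.
indices = {"user": torch.tensor([0, 1, 2]), "item": torch.tensor([10, 11])}
keys = ["user", "item"]

# Global (num_indices, 2) layout: column 0 = type index, column 1 = node ID.
global_ids = torch.cat(
    [
        torch.stack([torch.full_like(indices[k], i), indices[k]], dim=1)
        for i, k in enumerate(keys)
    ]
)
# With 2 replicas and ceil(5 / 2) = 3 samples each, rank 1 owns rows [3, 6);
# only rows 3 and 4 exist, so the helper fills two rows and leaves the rest
# to the dataset's padding logic.
print(global_ids[3:5])  # tensor([[ 1, 10], [ 1, 11]])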

def _split_to_local_id_tensor(indices, local_lower_bound, local_upper_bound):
    dtype = dtype_of(indices)
    device = indices.device
    num_samples = local_upper_bound - local_lower_bound
    id_tensor = torch.empty(num_samples, dtype=dtype, device=device)
    if local_upper_bound > len(indices):
        remainder = len(indices) - local_lower_bound
        id_tensor[0:remainder] = indices[local_lower_bound:]
    else:
        id_tensor = indices[local_lower_bound:local_upper_bound]
    return id_tensor
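
For the homogeneous case the behaviour can be paraphrased with plain slicing (a sketch with assumed toy sizes, not the library code): the rank's window is clipped to the end of indices, and any shortfall is left for the dataset's padded _indices to cover.

import torch

indices = torch.arange(10)           # 10 homogeneous seed nodes
lower, upper = 8, 12                 # rank 2 of 3 replicas, ceil(10 / 3) = 4 per rank
num_samples = upper - lower
local = torch.empty(num_samples, dtype=indices.dtype)
remainder = len(indices) - lower     # only 2 real samples are left for this rank
local[:remainder] = indices[lower:]  # rows 0-1 hold nodes 8 and 9; rows 2-3 stay unwritten
# The dataset never reads the unwritten tail: its padded `_indices` points back
# at rows 0..remainder-1 instead.
print(local[:remainder])             # tensor([8, 9])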

def _divide_by_worker(dataset, batch_size, drop_last):
    num_samples = dataset.shape[0]
    worker_info = torch.utils.data.get_worker_info()

@@ -194,6 +239,16 @@ class TensorizedDataset(torch.utils.data.IterableDataset):
        ) // self.batch_size

def _decompose_one_dimension(length, world_size, rank, drop_last):
    if drop_last:
        num_samples = math.floor(length / world_size)
    else:
        num_samples = math.ceil(length / world_size)
    sta = rank * num_samples
    end = (rank + 1) * num_samples
    return sta, end
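
A quick check of the split arithmetic, using a hypothetical length of 10 and 3 replicas (illustration only): with drop_last the per-rank count floors and a tail index is dropped; without it the count ceils and the last rank's upper bound overshoots the dataset, which is the case the dataset's padding handles.

import math

def decompose(length, world_size, rank, drop_last):
    # Same arithmetic as _decompose_one_dimension.
    per_rank = (math.floor if drop_last else math.ceil)(length / world_size)
    return rank * per_rank, (rank + 1) * per_rank

print([decompose(10, 3, r, True) for r in range(3)])
# [(0, 3), (3, 6), (6, 9)]   -> index 9 is dropped across the job
print([decompose(10, 3, r, False) for r in range(3)])
# [(0, 4), (4, 8), (8, 12)]  -> rank 2 overshoots and will pad its shard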

class DDPTensorizedDataset(torch.utils.data.IterableDataset):
    """Custom Dataset wrapper that returns a minibatch as tensors or dicts of tensors.
    When the dataset is on the GPU, this significantly reduces the overhead.
@@ -217,64 +272,54 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
        self.batch_size = batch_size
        self.drop_last = drop_last
        self._shuffle = shuffle
-        if self.drop_last and len_indices % self.num_replicas != 0:
-            self.num_samples = math.ceil(
-                (len_indices - self.num_replicas) / self.num_replicas
-            )
-        else:
-            self.num_samples = math.ceil(len_indices / self.num_replicas)
-        self.total_size = self.num_samples * self.num_replicas
-        # If drop_last is True, we create a shared memory array larger than the number
-        # of indices since we will need to pad it after shuffling to make it evenly
-        # divisible before every epoch. If drop_last is False, we create an array
-        # with the same size as the indices so we can trim it later.
-        self.shared_mem_size = (
-            self.total_size if not self.drop_last else len_indices
-        )
-        self.num_indices = len_indices
+        (
+            self.local_lower_bound,
+            self.local_upper_bound,
+        ) = _decompose_one_dimension(
+            len_indices, self.num_replicas, self.rank, drop_last
+        )
+        self.num_samples = self.local_upper_bound - self.local_lower_bound
+        self.local_num_indices = self.num_samples
+        if self.local_upper_bound > len_indices:
+            assert not drop_last
+            self.local_num_indices = len_indices - self.local_lower_bound

        if isinstance(indices, Mapping):
-            self._device = next(iter(indices.values())).device
-            self._id_tensor = call_once_and_share(
-                lambda: _get_id_tensor_from_mapping(
-                    indices, self._device, self._mapping_keys
-                ),
-                (self.num_indices, 2),
-                dtype_of(indices),
-            )
+            self._id_tensor = _split_to_local_id_tensor_from_mapping(
+                indices,
+                self._mapping_keys,
+                self.local_lower_bound,
+                self.local_upper_bound,
+            )
        else:
-            self._id_tensor = indices
-            self._device = self._id_tensor.device
-
-        self._indices = call_once_and_share(
-            self._create_shared_indices, (self.shared_mem_size,), torch.int64
-        )
-
-    def _create_shared_indices(self):
-        indices = torch.empty(self.shared_mem_size, dtype=torch.int64)
-        num_ids = self._id_tensor.shape[0]
-        torch.arange(num_ids, out=indices[:num_ids])
-        torch.arange(self.shared_mem_size - num_ids, out=indices[num_ids:])
-        return indices
+            self._id_tensor = _split_to_local_id_tensor(
+                indices, self.local_lower_bound, self.local_upper_bound
+            )
+        self._device = self._id_tensor.device
+        # padding self._indices when drop_last = False (self._indices always on cpu)
+        self._indices = torch.empty(self.num_samples, dtype=torch.int64)
+        torch.arange(
+            self.local_num_indices, out=self._indices[: self.local_num_indices]
+        )
+        if not drop_last:
+            torch.arange(
+                self.num_samples - self.local_num_indices,
+                out=self._indices[self.local_num_indices :],
+            )
+        assert len(self._id_tensor) == self.num_samples

    def shuffle(self):
        """Shuffles the dataset."""
-        # Only rank 0 does the actual shuffling. The other ranks wait for it.
-        if self.rank == 0:
-            np.random.shuffle(self._indices[: self.num_indices].numpy())
-            if not self.drop_last:
-                # pad extra
-                self._indices[self.num_indices :] = self._indices[
-                    : self.total_size - self.num_indices
-                ]
-        dist.barrier()
+        np.random.shuffle(self._indices[: self.local_num_indices].numpy())
+        if not self.drop_last:
+            # pad extra from local indices
+            self._indices[self.local_num_indices :] = self._indices[
+                : self.num_samples - self.local_num_indices
+            ]

    def __iter__(self):
-        start = self.num_samples * self.rank
-        end = self.num_samples * (self.rank + 1)
        indices = _divide_by_worker(
-            self._indices[start:end], self.batch_size, self.drop_last
+            self._indices, self.batch_size, self.drop_last
        )
        id_tensor = self._id_tensor[indices]
        return _TensorizedDatasetIter(
...
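
One consequence of the rewrite above that is easy to miss: when drop_last is False, a short last rank fills its missing slots by pointing back at the start of its own local index permutation rather than borrowing samples from another rank. A standalone paraphrase of that padding, reusing the toy numbers from the earlier sketches:

import torch

num_samples, local_num_indices = 4, 2          # rank 2 of 3 from the example above
local_indices = torch.empty(num_samples, dtype=torch.int64)
torch.arange(local_num_indices, out=local_indices[:local_num_indices])
torch.arange(num_samples - local_num_indices, out=local_indices[local_num_indices:])
print(local_indices)  # tensor([0, 1, 0, 1]) -> the padding repeats local rows 0 and 1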

@@ -11,6 +11,7 @@ import numpy as np
import pytest
import torch
import torch.distributed as dist
+import torch.multiprocessing as mp
from utils import parametrize_idtype

@@ -222,6 +223,102 @@ def _check_device(data):
    assert data.device == F.ctx()
@pytest.mark.parametrize("sampler_name", ["full", "neighbor"])
@pytest.mark.parametrize(
"mode", ["cpu", "uva_cuda_indices", "uva_cpu_indices", "pure_gpu"]
)
@pytest.mark.parametrize("nprocs", [1, 4])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ddp_dataloader_decompose_dataset(
sampler_name, mode, nprocs, drop_last
):
if torch.cuda.device_count() < nprocs and mode != "cpu":
pytest.skip(
"DDP dataloader needs sufficient GPUs for UVA and GPU sampling."
)
if mode != "cpu" and F.ctx() == F.cpu():
pytest.skip("UVA and GPU sampling require a GPU.")
if os.name == "nt":
pytest.skip("PyTorch 1.13.0+ has problems in Windows DDP...")
g, _, _, _ = _create_homogeneous()
g = g.to(F.cpu())
sampler = {
"full": dgl.dataloading.MultiLayerFullNeighborSampler(2),
"neighbor": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
}[sampler_name]
indices = F.copy_to(F.arange(0, g.num_nodes()), F.cpu())
data = indices, sampler
arguments = mode, drop_last
g.create_formats_()
os.environ["OMP_NUM_THREADS"] = str(mp.cpu_count() // 2 // nprocs)
mp.spawn(_ddp_runner, args=(nprocs, g, data, arguments), nprocs=nprocs)
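
A detail of the launcher above worth noting: torch.multiprocessing.spawn prepends the process index to the arguments, which is why _ddp_runner below takes proc_id first while args carries only the remaining values. A stripped-down sketch of that calling convention (the worker body is illustrative):

import torch.multiprocessing as mp

def _worker(proc_id, nprocs, payload):
    # spawn invokes _worker(i, *args) for i in range(nprocs)
    print(f"rank {proc_id}/{nprocs} received {payload!r}")

if __name__ == "__main__":
    mp.spawn(_worker, args=(2, "hello"), nprocs=2)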

def _ddp_runner(proc_id, nprocs, g, data, args):
    mode, drop_last = args
    indices, sampler = data
    if mode == "cpu":
        device = torch.device("cpu")
    else:
        device = torch.device(proc_id)
        torch.cuda.set_device(device)
        if mode == "pure_gpu":
            g = g.to(F.cuda())
    if mode in ("cpu", "uva_cpu_indices"):
        indices = indices.cpu()
    else:
        indices = indices.cuda()
    dist.init_process_group(
        "nccl" if mode != "cpu" else "gloo",
        "tcp://127.0.0.1:12347",
        world_size=nprocs,
        rank=proc_id,
    )
    use_uva = mode.startswith("uva")
    batch_size = g.num_nodes()
    shuffle = False
    for num_workers in [1, 4] if mode == "cpu" else [0]:
        dataloader = dgl.dataloading.DataLoader(
            g,
            indices,
            sampler,
            device=device,
            batch_size=batch_size,  # g1.num_nodes(),
            num_workers=num_workers,
            use_uva=use_uva,
            use_ddp=True,
            drop_last=drop_last,
            shuffle=shuffle,
        )
        max_nid = [0]
        for i, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
            block = blocks[-1]
            o_src, o_dst = block.edges()
            src_nodes_id = block.srcdata[dgl.NID][o_src]
            dst_nodes_id = block.dstdata[dgl.NID][o_dst]
            max_nid.append(np.max(dst_nodes_id.cpu().numpy()))
        local_max = torch.tensor(np.max(max_nid))
        if torch.distributed.get_backend() == "nccl":
            local_max = local_max.cuda()
        dist.reduce(local_max, 0, op=dist.ReduceOp.MAX)
        if proc_id == 0:
            if drop_last and not shuffle and local_max > 0:
                assert (
                    local_max.item()
                    == len(indices)
                    - len(indices) % nprocs
                    - 1
                    - (len(indices) // nprocs) % batch_size
                )
            elif not drop_last:
                assert local_max == len(indices) - 1
    dist.destroy_process_group()
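
The cross-rank assertion above hinges on a MAX reduction: each rank reports the largest destination node ID it saw, and rank 0 checks the global maximum against the expected coverage. A self-contained single-process sketch of that primitive (the port and world size here are arbitrary choices, not taken from the test):

import torch
import torch.distributed as dist

dist.init_process_group(
    "gloo", init_method="tcp://127.0.0.1:29511", world_size=1, rank=0
)
local_max = torch.tensor(41)  # largest node ID seen in this rank's shard
dist.reduce(local_max, 0, op=dist.ReduceOp.MAX)  # rank 0 holds the global max afterwards
print(local_max.item())
dist.destroy_process_group()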

@parametrize_idtype
@pytest.mark.parametrize("sampler_name", ["full", "neighbor", "neighbor2"])
@pytest.mark.parametrize(
...