Unverified commit 0a87bc6a, authored by Ramon Zhou and committed by GitHub

[GraphBolt] Add DistributedItemSampler to support multi-gpu training (#6341)

parent e22bd78f
@@ -4,6 +4,7 @@ from collections.abc import Mapping
 from functools import partial
 from typing import Callable, Iterator, Optional
+import torch.distributed as dist
 from torch.utils.data import default_collate
 from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
@@ -14,7 +15,7 @@ from ..heterograph import DGLGraph
 from .itemset import ItemSet, ItemSetDict
 from .minibatch import MiniBatch

-__all__ = ["ItemSampler", "minibatcher_default"]
+__all__ = ["ItemSampler", "DistributedItemSampler", "minibatcher_default"]


 def minibatcher_default(batch, names):
@@ -280,12 +281,30 @@ class ItemSampler(IterDataPipe):
         shuffle: Optional[bool] = False,
     ) -> None:
         super().__init__()
-        self._item_set = item_set
+        self._names = item_set.names
+        self._item_set = IterableWrapper(item_set)
         self._batch_size = batch_size
         self._minibatcher = minibatcher
         self._drop_last = drop_last
         self._shuffle = shuffle

+    def _organize_items(self, data_pipe) -> None:
+        # Shuffle before batch.
+        if self._shuffle:
+            # `torchdata.datapipes.iter.Shuffler` works with stream too.
+            # To ensure randomness, make sure the buffer size is at least 10
+            # times the batch size.
+            buffer_size = max(10000, 10 * self._batch_size)
+            data_pipe = data_pipe.shuffle(buffer_size=buffer_size)
+        # Batch.
+        data_pipe = data_pipe.batch(
+            batch_size=self._batch_size,
+            drop_last=self._drop_last,
+        )
+        return data_pipe
+
     @staticmethod
     def _collate(batch):
         """Collate items into a batch. For internal use only."""
@@ -306,27 +325,153 @@ class ItemSampler(IterDataPipe):
         return default_collate(batch)

     def __iter__(self) -> Iterator:
-        data_pipe = IterableWrapper(self._item_set)
-        # Shuffle before batch.
-        if self._shuffle:
-            # `torchdata.datapipes.iter.Shuffler` works with stream too.
-            # To ensure randomness, make sure the buffer size is at least 10
-            # times the batch size.
-            buffer_size = max(10000, 10 * self._batch_size)
-            data_pipe = data_pipe.shuffle(buffer_size=buffer_size)
-        # Batch.
-        data_pipe = data_pipe.batch(
-            batch_size=self._batch_size,
-            drop_last=self._drop_last,
-        )
+        # Organize items.
+        data_pipe = self._organize_items(self._item_set)
         # Collate.
         data_pipe = data_pipe.collate(collate_fn=self._collate)
         # Map to minibatch.
-        data_pipe = data_pipe.map(
-            partial(self._minibatcher, names=self._item_set.names)
-        )
+        data_pipe = data_pipe.map(partial(self._minibatcher, names=self._names))

         return iter(data_pipe)
+
+
+class DistributedItemSampler(ItemSampler):
+    """Distributed Item Sampler.
+
+    This sampler creates a distributed subset of items from the given data set,
+    which can be used for training with PyTorch's Distributed Data Parallel
+    (DDP). The items can be node IDs, node pairs with or without labels, node
+    pairs with negative sources/destinations, DGLGraphs, or heterogeneous
+    counterparts. The original item set is sharded such that each replica
+    (process) receives an exclusive subset.
+
+    Note: The items are first sharded onto each replica, then shuffled (if
+    enabled) and batched. Therefore, each replica always receives the same
+    set of items.
+
+    Note: `DistributedItemSampler` is deliberately not decorated with
+    `torchdata.datapipes.functional_datapipe`, so it cannot be invoked as a
+    functional (method-style) datapipe. However, any iterable datapipe from
+    `torchdata` can still be appended to it.
+
+    Parameters
+    ----------
+    item_set : ItemSet or ItemSetDict
+        Data to be sampled.
+    batch_size : int
+        The size of each batch.
+    minibatcher : Optional[Callable]
+        A callable that takes in a list of items and returns a `MiniBatch`.
+    drop_last : bool
+        Option to drop the last batch if it's not full.
+    shuffle : bool
+        Option to shuffle before sampling.
+    num_replicas : int
+        The number of model replicas that will be created during Distributed
+        Data Parallel (DDP) training. It should be the same as the actual
+        world size, otherwise it could cause errors. By default, it is
+        retrieved from the current distributed group.
+    drop_uneven_inputs : bool
+        Option to make sure the numbers of batches for each replica are the
+        same. If some of the replicas have more batches than the others, the
+        redundant batches of those replicas will be dropped. If the drop_last
+        parameter is also set to True, the last batch will be dropped before
+        the redundant batches are dropped.
+        Note: When using Distributed Data Parallel (DDP) training, the program
+        may hang or error out if a replica has fewer inputs than the others.
+        It is recommended to use the Join Context Manager provided by PyTorch
+        to solve this problem; please refer to
+        https://pytorch.org/tutorials/advanced/generic_join.html. However, this
+        option can be used if the Join Context Manager is not helpful for any
+        reason.
+
+    Examples
+    --------
+    1. num_replicas = 4, batch_size = 2, shuffle = False, drop_last = False,
+       drop_uneven_inputs = False, item_set = [0, 1, 2, ..., 7, 8, 9].
+       - Replica#0 gets [[0, 4], [8]]
+       - Replica#1 gets [[1, 5], [9]]
+       - Replica#2 gets [[2, 6]]
+       - Replica#3 gets [[3, 7]]
+
+    2. num_replicas = 4, batch_size = 2, shuffle = False, drop_last = True,
+       drop_uneven_inputs = False, item_set = [0, 1, 2, ..., 7, 8, 9].
+       - Replica#0 gets [[0, 4]]
+       - Replica#1 gets [[1, 5]]
+       - Replica#2 gets [[2, 6]]
+       - Replica#3 gets [[3, 7]]
+
+    3. num_replicas = 4, batch_size = 2, shuffle = False, drop_last = True,
+       drop_uneven_inputs = False, item_set = [0, 1, 2, ..., 11, 12, 13].
+       - Replica#0 gets [[0, 4], [8, 12]]
+       - Replica#1 gets [[1, 5], [9, 13]]
+       - Replica#2 gets [[2, 6]]
+       - Replica#3 gets [[3, 7]]
+
+    4. num_replicas = 4, batch_size = 2, shuffle = False, drop_last = False,
+       drop_uneven_inputs = True, item_set = [0, 1, 2, ..., 11, 12, 13].
+       - Replica#0 gets [[0, 4], [8, 12]]
+       - Replica#1 gets [[1, 5], [9, 13]]
+       - Replica#2 gets [[2, 6], [10]]
+       - Replica#3 gets [[3, 7], [11]]
+
+    5. num_replicas = 4, batch_size = 2, shuffle = False, drop_last = True,
+       drop_uneven_inputs = True, item_set = [0, 1, 2, ..., 11, 12, 13].
+       - Replica#0 gets [[0, 4]]
+       - Replica#1 gets [[1, 5]]
+       - Replica#2 gets [[2, 6]]
+       - Replica#3 gets [[3, 7]]
+
+    6. num_replicas = 4, batch_size = 2, shuffle = True, drop_last = True,
+       drop_uneven_inputs = False, item_set = [0, 1, 2, ..., 11, 12, 13].
+       One possible output:
+       - Replica#0 gets [[8, 0], [12, 4]]
+       - Replica#1 gets [[13, 1], [9, 5]]
+       - Replica#2 gets [[10, 2]]
+       - Replica#3 gets [[7, 11]]
+    """
+
+    def __init__(
+        self,
+        item_set: ItemSet or ItemSetDict,
+        batch_size: int,
+        minibatcher: Optional[Callable] = minibatcher_default,
+        drop_last: Optional[bool] = False,
+        shuffle: Optional[bool] = False,
+        num_replicas: Optional[int] = None,
+        drop_uneven_inputs: Optional[bool] = False,
+    ) -> None:
+        super().__init__(item_set, batch_size, minibatcher, drop_last, shuffle)
+        self._drop_uneven_inputs = drop_uneven_inputs
+        # Apply a sharding filter to distribute the items.
+        self._item_set = self._item_set.sharding_filter()
+        # Get world size.
+        if num_replicas is None:
+            assert (
+                dist.is_available()
+            ), "Requires distributed package to be available."
+            num_replicas = dist.get_world_size()
+        if self._drop_uneven_inputs:
+            # If the len() method of the item_set is not available, it will
+            # throw an exception.
+            total_len = len(item_set)
+            # Calculate the number of batches after dropping uneven batches
+            # for each replica.
+            self._num_evened_batches = total_len // (
+                num_replicas * batch_size
+            ) + (
+                (not drop_last)
+                and (total_len % (num_replicas * batch_size) >= num_replicas)
+            )
+
+    def _organize_items(self, data_pipe) -> None:
+        data_pipe = super()._organize_items(data_pipe)
+        # If drop_uneven_inputs is True, drop the excessive inputs by limiting
+        # the length of the datapipe.
+        if self._drop_uneven_inputs:
+            data_pipe = data_pipe.header(self._num_evened_batches)
+        return data_pipe
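As an illustration (not part of the diff): a minimal sketch of how the new sampler might be driven from a DDP worker, following the class docstring and the test below. The worker function, port, batch size, and the assumption that iterating the sampler yields `MiniBatch` objects carrying the IDs as `seed_nodes` are illustrative choices, not an official recipe.

```python
# Hypothetical usage sketch; assumes a CPU/Gloo process group and that each
# minibatch exposes the sampled IDs as `seed_nodes` (as in the test below).
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from dgl import graphbolt as gb


def ddp_worker(rank, world_size, item_set):
    # Every replica joins the same process group before sampling.
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:12345",
        world_size=world_size,
        rank=rank,
    )
    # Items are sharded onto this replica first, then shuffled and batched,
    # so the replica's item subset stays fixed across epochs.
    sampler = gb.DistributedItemSampler(
        item_set,
        batch_size=2,
        shuffle=True,
        drop_last=False,
        drop_uneven_inputs=True,  # keep per-replica batch counts equal
    )
    for minibatch in sampler:
        # Graph sampling, feature fetching, and model steps would follow here.
        print(rank, minibatch.seed_nodes)
    dist.destroy_process_group()


if __name__ == "__main__":
    seeds = gb.ItemSet(torch.arange(16), names="seed_nodes")
    mp.spawn(ddp_worker, args=(4, seeds), nprocs=4, join=True)
```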
+import os
 import re
+from sys import platform
 import dgl
 import pytest
 import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
 from dgl import graphbolt as gb
 from torch.testing import assert_close
@@ -610,3 +614,126 @@ def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
     assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
     assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle
     assert torch.all(negs_ids[:-1] <= negs_ids[1:]) is not shuffle
+
+
+def distributed_item_sampler_subprocess(
+    proc_id,
+    nprocs,
+    item_set,
+    num_ids,
+    batch_size,
+    shuffle,
+    drop_last,
+    drop_uneven_inputs,
+):
+    # On Windows, the init method can only be file.
+    init_method = (
+        f"file:///{os.path.join(os.getcwd(), 'dis_tempfile')}"
+        if platform == "win32"
+        else "tcp://127.0.0.1:12345"
+    )
+    dist.init_process_group(
+        backend="gloo",  # Use Gloo backend for CPU multiprocessing.
+        init_method=init_method,
+        world_size=nprocs,
+        rank=proc_id,
+    )
+
+    # Create a DistributedItemSampler.
+    item_sampler = gb.DistributedItemSampler(
+        item_set,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        drop_last=drop_last,
+        drop_uneven_inputs=drop_uneven_inputs,
+    )
+    feature_fetcher = gb.FeatureFetcher(
+        item_sampler,
+        gb.BasicFeatureStore({}),
+        [],
+    )
+    data_loader = gb.SingleProcessDataLoader(feature_fetcher)
+
+    # Count the numbers of items and batches.
+    num_items = 0
+    sampled_count = torch.zeros(num_ids, dtype=torch.int32)
+    for i in data_loader:
+        # Count how many times each item is sampled.
+        sampled_count[i.seed_nodes] += 1
+        num_items += i.seed_nodes.size(0)
+    num_batches = len(list(item_sampler))
+
+    # Calculate expected numbers of items and batches.
+    expected_num_items = num_ids // nprocs + (num_ids % nprocs > proc_id)
+    if drop_last and expected_num_items % batch_size > 0:
+        expected_num_items -= expected_num_items % batch_size
+    expected_num_batches = expected_num_items // batch_size + (
+        (not drop_last) and (expected_num_items % batch_size > 0)
+    )
+    if drop_uneven_inputs:
+        if (
+            (not drop_last)
+            and (num_ids % (nprocs * batch_size) < nprocs)
+            and (num_ids % (nprocs * batch_size) > proc_id)
+        ):
+            expected_num_batches -= 1
+            expected_num_items -= 1
+        elif (
+            drop_last
+            and (nprocs * batch_size - num_ids % (nprocs * batch_size) < nprocs)
+            and (num_ids % nprocs > proc_id)
+        ):
+            expected_num_batches -= 1
+            expected_num_items -= batch_size
+
+    num_batches_tensor = torch.tensor(num_batches)
+    dist.broadcast(num_batches_tensor, 0)
+    # Test if the number of batches is the same for all processes.
+    assert num_batches_tensor == num_batches
+
+    # Add up results from all processes.
+    dist.reduce(sampled_count, 0)
+
+    try:
+        # Check if the numbers are as expected.
+        assert num_items == expected_num_items
+        assert num_batches == expected_num_batches
+        # Make sure no item is sampled more than once.
+        assert sampled_count.max() <= 1
+    finally:
+        dist.destroy_process_group()
+
+
+@pytest.mark.parametrize("num_ids", [24, 30, 32, 34, 36])
+@pytest.mark.parametrize("shuffle", [False, True])
+@pytest.mark.parametrize("drop_last", [False, True])
+@pytest.mark.parametrize("drop_uneven_inputs", [False, True])
+def test_DistributedItemSampler(
+    num_ids, shuffle, drop_last, drop_uneven_inputs
+):
+    nprocs = 4
+    batch_size = 4
+    item_set = gb.ItemSet(torch.arange(0, num_ids), names="seed_nodes")
+
+    # On Windows, if the process group initialization file already exists,
+    # the program may hang. So we need to delete it if it exists.
+    if platform == "win32":
+        try:
+            os.remove(os.path.join(os.getcwd(), "dis_tempfile"))
+        except FileNotFoundError:
+            pass
+
+    mp.spawn(
+        distributed_item_sampler_subprocess,
+        args=(
+            nprocs,
+            item_set,
+            num_ids,
+            batch_size,
+            shuffle,
+            drop_last,
+            drop_uneven_inputs,
+        ),
+        nprocs=nprocs,
+        join=True,
+    )
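The docstring above points to PyTorch's Join context manager as the preferred way to cope with replicas that run out of batches when `drop_uneven_inputs` is left off. A rough sketch of that pattern under the same Gloo/CPU assumptions; the toy linear model, optimizer, loss, and port are placeholders and not part of this PR.

```python
# Hypothetical sketch of the Join pattern from the docstring's note; the model,
# optimizer, and loss below are stand-ins, not part of the GraphBolt change.
import torch
import torch.distributed as dist
from dgl import graphbolt as gb
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP


def train(rank, world_size, item_set):
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:12346",
        world_size=world_size,
        rank=rank,
    )
    model = DDP(torch.nn.Linear(1, 1))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # drop_uneven_inputs stays False, so replicas may see different batch counts.
    sampler = gb.DistributedItemSampler(item_set, batch_size=2, shuffle=True)
    with Join([model]):
        # Replicas that exhaust their shard early shadow the collectives of
        # the ones still training, so uneven batch counts do not hang DDP.
        for minibatch in sampler:
            x = minibatch.seed_nodes.float().unsqueeze(1)
            loss = model(x).sum()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
    dist.destroy_process_group()
```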