Unverified Commit d79701db authored by Ramon Zhou, committed by GitHub

[GraphBolt] Use indexing in DistributedItemSampler (#6508)

parent 8f1b5782
@@ -109,6 +109,10 @@ class ItemShufflerAndBatcher:
batch_size: int,
drop_last: bool,
buffer_size: Optional[int] = 10 * 1000,
distributed: Optional[bool] = False,
drop_uneven_inputs: Optional[bool] = False,
world_size: Optional[int] = 1,
rank: Optional[int] = 0,
):
self._item_set = item_set
self._shuffle = shuffle
@@ -119,6 +123,11 @@ class ItemShufflerAndBatcher:
self._buffer_size = (
(self._buffer_size + batch_size - 1) // batch_size * batch_size
)
self._distributed = distributed
self._drop_uneven_inputs = drop_uneven_inputs
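# In distributed mode, also record how many replicas participate and the
# rank of this replica; together they determine which slice of the item
# set this sampler iterates over.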
if distributed:
self._num_replicas = world_size
self._rank = rank
def _collate_batch(self, buffer, indices, offsets=None):
"""Collate a batch from the buffer. For internal use only."""
@@ -167,12 +176,95 @@ class ItemShufflerAndBatcher:
return torch.tensor(offsets)
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
num_workers = worker_info.num_workers
worker_id = worker_info.id
else:
num_workers = 1
worker_id = 0
buffer = None
num_items = len(self._item_set)
if not self._distributed:
num_items = len(self._item_set)
start_offset = 0
else:
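# Distributed partitioning: each replica reads a contiguous chunk of the
# item set. Its size is derived by grouping items into "big batches" of
# `world_size * batch_size` items: every replica owns one batch per
# complete big batch, and the leftover items of the final partial big
# batch are handed out batch by batch to the lowest-ranked replicas.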
total_count = len(self._item_set)
big_batch_size = self._num_replicas * self._batch_size
big_batch_count, big_batch_remain = divmod(
total_count, big_batch_size
)
last_batch_count, batch_remain = divmod(
big_batch_remain, self._batch_size
)
if self._rank < last_batch_count:
last_batch = self._batch_size
elif self._rank == last_batch_count:
last_batch = batch_remain
else:
last_batch = 0
num_items = big_batch_count * self._batch_size + last_batch
start_offset = (
big_batch_count * self._batch_size * self._rank
+ min(self._rank * self._batch_size, big_batch_remain)
)
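# Worked example: with 14 items, world_size=4 and batch_size=2, we get
# big_batch_count=1, big_batch_remain=6 and last_batch_count=3, so ranks
# 0-2 receive 4 items each and rank 3 receives 2, with start offsets
# 0, 4, 8 and 12 respectively.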
if not self._drop_uneven_inputs or (
not self._drop_last and last_batch_count == self._num_replicas
):
# No need to drop uneven batches.
num_evened_items = num_items
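# With multiple DataLoader workers, the replica's batches are divided
# evenly among the workers; the first `total_batch_count % num_workers`
# workers receive one extra batch.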
if num_workers > 1:
total_batch_count = (
num_items + self._batch_size - 1
) // self._batch_size
split_batch_count = total_batch_count // num_workers + (
worker_id < total_batch_count % num_workers
)
split_num_items = split_batch_count * self._batch_size
num_items = (
min(num_items, split_num_items * (worker_id + 1))
- split_num_items * worker_id
)
num_evened_items = num_items
start_offset = (
big_batch_count * self._batch_size * self._rank
+ min(self._rank * self._batch_size, big_batch_remain)
+ self._batch_size
* (
total_batch_count // num_workers * worker_id
+ min(worker_id, total_batch_count % num_workers)
)
)
else:
# Uneven batches need to be dropped. As many items as the `last_batch`
# size will be discarded; it is better not to let all of the dropped
# items come from the same worker.
num_evened_items = big_batch_count * self._batch_size
if num_workers > 1:
total_batch_count = big_batch_count
split_batch_count = total_batch_count // num_workers + (
worker_id < total_batch_count % num_workers
)
split_num_items = split_batch_count * self._batch_size
split_item_remain = last_batch // num_workers + (
worker_id < last_batch % num_workers
)
num_items = split_num_items + split_item_remain
num_evened_items = split_num_items
start_offset = (
big_batch_count * self._batch_size * self._rank
+ min(self._rank * self._batch_size, big_batch_remain)
+ self._batch_size
* (
total_batch_count // num_workers * worker_id
+ min(worker_id, total_batch_count % num_workers)
)
+ last_batch // num_workers * worker_id
+ min(worker_id, last_batch % num_workers)
)
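# At this point `start_offset` is the first item this worker reads,
# `num_items` is how many items it reads, and `num_evened_items` caps
# how many of them may be batched when uneven inputs are dropped.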
start = 0
while start < num_items:
end = min(start + self._buffer_size, num_items)
buffer = self._item_set[start:end]
buffer = self._item_set[start_offset + start : start_offset + end]
indices = torch.arange(end - start)
if self._shuffle:
np.random.shuffle(indices.numpy())
@@ -180,6 +272,12 @@ class ItemShufflerAndBatcher:
for i in range(0, len(indices), self._batch_size):
if self._drop_last and i + self._batch_size > len(indices):
break
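# When dropping uneven inputs, also stop once the evened item count
# is reached so that all replicas yield the same number of batches.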
if (
self._distributed
and self._drop_uneven_inputs
and i >= num_evened_items
):
break
batch_indices = indices[i : i + self._batch_size]
yield self._collate_batch(buffer, batch_indices, offsets)
buffer = None
@@ -416,6 +514,10 @@ class ItemSampler(IterDataPipe):
self._drop_last = drop_last
self._shuffle = shuffle
self._use_indexing = use_indexing
self._distributed = False
self._drop_uneven_inputs = False
self._world_size = None
self._rank = None
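# A plain ItemSampler always runs in non-distributed mode; the
# DistributedItemSampler subclass overrides these attributes.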
def _organize_items(self, data_pipe) -> None:
# Shuffle before batch.
@@ -461,6 +563,10 @@ class ItemSampler(IterDataPipe):
self._shuffle,
self._batch_size,
self._drop_last,
distributed=self._distributed,
drop_uneven_inputs=self._drop_uneven_inputs,
world_size=self._world_size,
rank=self._rank,
)
)
else:
@@ -483,14 +589,14 @@ class DistributedItemSampler(ItemSampler):
which can be used for training with PyTorch's Distributed Data Parallel
(DDP). The items can be node IDs, node pairs with or without labels, node
pairs with negative sources/destinations, DGLGraphs, or heterogeneous
counterparts. The original item set is sharded such that each replica
counterparts. The original item set is split such that each replica
(process) receives an exclusive subset.
Note: DistributedItemSampler may not work as expected when it is the last
datapipe before the data is fetched. Please wrap it with a
SingleProcessDataLoader or another datapipe.
Note: The items will be first sharded onto each replica, then get shuffled
Note: The items will be first split onto each replica, then get shuffled
(if needed) and batched. Therefore, each replica will always get the same set
of items.
@@ -532,6 +638,7 @@ class DistributedItemSampler(ItemSampler):
Examples
--------
TODO[Kaicheng]: Modify examples here.
0. Preparation: DistributedItemSampler needs a multi-processing environment to
work. You need to spawn subprocesses and initialize the process group before
executing the following examples. Due to randomness, the output is not always
@@ -539,7 +646,7 @@ class DistributedItemSampler(ItemSampler):
>>> import torch
>>> from dgl import graphbolt as gb
>>> item_set = gb.ItemSet(torch.arange(0, 13))
>>> item_set = gb.ItemSet(torch.arange(0, 14))
>>> num_replicas = 4
>>> batch_size = 2
>>> mp.spawn(...)
@@ -632,46 +739,21 @@ class DistributedItemSampler(ItemSampler):
minibatcher: Optional[Callable] = minibatcher_default,
drop_last: Optional[bool] = False,
shuffle: Optional[bool] = False,
num_replicas: Optional[int] = None,
drop_uneven_inputs: Optional[bool] = False,
) -> None:
# [TODO][Rui] For now, always set use_indexing to False.
super().__init__(
item_set,
batch_size,
minibatcher,
drop_last,
shuffle,
use_indexing=False,
use_indexing=True,
)
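# With use_indexing=True, ItemShufflerAndBatcher slices the item set by
# offset, which replaces the previous sharding_filter-based distribution.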
self._distributed = True
self._drop_uneven_inputs = drop_uneven_inputs
# Apply a sharding filter to distribute the items.
self._item_set = self._item_set.sharding_filter()
# Get world size.
if num_replicas is None:
assert (
dist.is_available()
), "Requires distributed package to be available."
num_replicas = dist.get_world_size()
if self._drop_uneven_inputs:
# If the len() method of the item_set is not available, it will
# throw an exception.
total_len = len(item_set)
# Calculate the number of batches after dropping uneven batches for
# each replica.
self._num_evened_batches = total_len // (
num_replicas * batch_size
) + (
(not drop_last)
and (total_len % (num_replicas * batch_size) >= num_replicas)
if not dist.is_available():
raise RuntimeError(
"Distributed item sampler requires distributed package."
)
def _organize_items(self, data_pipe) -> None:
data_pipe = super()._organize_items(data_pipe)
# If drop_uneven_inputs is True, drop the excessive inputs by limiting
# the length of the datapipe.
if self._drop_uneven_inputs:
data_pipe = data_pipe.header(self._num_evened_batches)
return data_pipe
self._world_size = dist.get_world_size()
self._rank = dist.get_rank()
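A minimal usage sketch of the sampler above, assuming each spawned process has already set up the distributed environment (e.g. MASTER_ADDR/MASTER_PORT) and that SingleProcessDataLoader is exposed under graphbolt as the docstring's note suggests; the item set and batch size are illustrative placeholders:

>>> import torch
>>> import torch.distributed as dist
>>> from dgl import graphbolt as gb
>>> # Inside each spawned worker process.
>>> dist.init_process_group(backend="gloo")
>>> item_set = gb.ItemSet(torch.arange(0, 14))
>>> item_sampler = gb.DistributedItemSampler(
...     item_set, batch_size=2, shuffle=True, drop_uneven_inputs=True
... )
>>> # Wrap the sampler in a data loader, as the docstring recommends.
>>> data_loader = gb.SingleProcessDataLoader(item_sampler)
>>> for minibatch in data_loader:
...     pass  # Each replica receives an exclusive subset of the items.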
@@ -767,31 +767,12 @@ def distributed_item_sampler_subprocess(
for i in data_loader:
# Count how many times each item is sampled.
sampled_count[i.seed_nodes] += 1
if drop_last:
assert i.seed_nodes.size(0) == batch_size
num_items += i.seed_nodes.size(0)
num_batches = len(list(item_sampler))
# Calculate expected numbers of items and batches.
expected_num_items = num_ids // nprocs + (num_ids % nprocs > proc_id)
if drop_last and expected_num_items % batch_size > 0:
expected_num_items -= expected_num_items % batch_size
expected_num_batches = expected_num_items // batch_size + (
(not drop_last) and (expected_num_items % batch_size > 0)
)
if drop_uneven_inputs:
if (
(not drop_last)
and (num_ids % (nprocs * batch_size) < nprocs)
and (num_ids % (nprocs * batch_size) > proc_id)
):
expected_num_batches -= 1
expected_num_items -= 1
elif (
drop_last
and (nprocs * batch_size - num_ids % (nprocs * batch_size) < nprocs)
and (num_ids % nprocs > proc_id)
):
expected_num_batches -= 1
expected_num_items -= batch_size
num_batches_tensor = torch.tensor(num_batches)
dist.broadcast(num_batches_tensor, 0)
# Test if the number of batches is the same for all processes.
@@ -801,10 +782,6 @@
dist.reduce(sampled_count, 0)
try:
# Check if the numbers are as expected.
assert num_items == expected_num_items
assert num_batches == expected_num_batches
# Make sure no item is sampled more than once.
assert sampled_count.max() <= 1
finally: