Unverified Commit d79701db authored by Ramon Zhou, committed by GitHub

[GraphBolt] Use indexing in DistributedItemSampler (#6508)

parent 8f1b5782
@@ -109,6 +109,10 @@ class ItemShufflerAndBatcher:
         batch_size: int,
         drop_last: bool,
         buffer_size: Optional[int] = 10 * 1000,
+        distributed: Optional[bool] = False,
+        drop_uneven_inputs: Optional[bool] = False,
+        world_size: Optional[int] = 1,
+        rank: Optional[int] = 0,
     ):
         self._item_set = item_set
         self._shuffle = shuffle
@@ -119,6 +123,11 @@ class ItemShufflerAndBatcher:
         self._buffer_size = (
             (self._buffer_size + batch_size - 1) // batch_size * batch_size
         )
+        self._distributed = distributed
+        self._drop_uneven_inputs = drop_uneven_inputs
+        if distributed:
+            self._num_replicas = world_size
+            self._rank = rank
 
     def _collate_batch(self, buffer, indices, offsets=None):
         """Collate a batch from the buffer. For internal use only."""
@@ -167,12 +176,95 @@ class ItemShufflerAndBatcher:
         return torch.tensor(offsets)
 
     def __iter__(self):
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is not None:
+            num_workers = worker_info.num_workers
+            worker_id = worker_info.id
+        else:
+            num_workers = 1
+            worker_id = 0
         buffer = None
-        num_items = len(self._item_set)
+        if not self._distributed:
+            num_items = len(self._item_set)
+            start_offset = 0
+        else:
+            total_count = len(self._item_set)
+            big_batch_size = self._num_replicas * self._batch_size
+            big_batch_count, big_batch_remain = divmod(
+                total_count, big_batch_size
+            )
+            last_batch_count, batch_remain = divmod(
+                big_batch_remain, self._batch_size
+            )
+            if self._rank < last_batch_count:
+                last_batch = self._batch_size
+            elif self._rank == last_batch_count:
+                last_batch = batch_remain
+            else:
+                last_batch = 0
+            num_items = big_batch_count * self._batch_size + last_batch
+            start_offset = (
+                big_batch_count * self._batch_size * self._rank
+                + min(self._rank * self._batch_size, big_batch_remain)
+            )
+            if not self._drop_uneven_inputs or (
+                not self._drop_last and last_batch_count == self._num_replicas
+            ):
+                # No need to drop uneven batches.
+                num_evened_items = num_items
+                if num_workers > 1:
+                    total_batch_count = (
+                        num_items + self._batch_size - 1
+                    ) // self._batch_size
+                    split_batch_count = total_batch_count // num_workers + (
+                        worker_id < total_batch_count % num_workers
+                    )
+                    split_num_items = split_batch_count * self._batch_size
+                    num_items = (
+                        min(num_items, split_num_items * (worker_id + 1))
+                        - split_num_items * worker_id
+                    )
+                    num_evened_items = num_items
+                    start_offset = (
+                        big_batch_count * self._batch_size * self._rank
+                        + min(self._rank * self._batch_size, big_batch_remain)
+                        + self._batch_size
+                        * (
+                            total_batch_count // num_workers * worker_id
+                            + min(worker_id, total_batch_count % num_workers)
+                        )
+                    )
+            else:
+                # Needs to drop uneven batches. As many items as `last_batch`
+                # size will be dropped. It would be better not to let those
+                # dropped items come from the same worker.
+                num_evened_items = big_batch_count * self._batch_size
+                if num_workers > 1:
+                    total_batch_count = big_batch_count
+                    split_batch_count = total_batch_count // num_workers + (
+                        worker_id < total_batch_count % num_workers
+                    )
+                    split_num_items = split_batch_count * self._batch_size
+                    split_item_remain = last_batch // num_workers + (
+                        worker_id < last_batch % num_workers
+                    )
+                    num_items = split_num_items + split_item_remain
+                    num_evened_items = split_num_items
+                    start_offset = (
+                        big_batch_count * self._batch_size * self._rank
+                        + min(self._rank * self._batch_size, big_batch_remain)
+                        + self._batch_size
+                        * (
+                            total_batch_count // num_workers * worker_id
+                            + min(worker_id, total_batch_count % num_workers)
+                        )
+                        + last_batch // num_workers * worker_id
+                        + min(worker_id, last_batch % num_workers)
+                    )
         start = 0
         while start < num_items:
             end = min(start + self._buffer_size, num_items)
-            buffer = self._item_set[start:end]
+            buffer = self._item_set[start_offset + start : start_offset + end]
             indices = torch.arange(end - start)
             if self._shuffle:
                 np.random.shuffle(indices.numpy())
@@ -180,6 +272,12 @@ class ItemShufflerAndBatcher:
             for i in range(0, len(indices), self._batch_size):
                 if self._drop_last and i + self._batch_size > len(indices):
                     break
+                if (
+                    self._distributed
+                    and self._drop_uneven_inputs
+                    and i >= num_evened_items
+                ):
+                    break
                 batch_indices = indices[i : i + self._batch_size]
                 yield self._collate_batch(buffer, batch_indices, offsets)
             buffer = None
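For readers following the new `__iter__` logic, here is a standalone sketch of the rank-level slicing arithmetic it introduces, reduced to the single-worker case (`num_workers == 1`). The helper name and the concrete figures are illustrative only, not part of the patch:

```python
# Illustrative sketch: mirrors the rank-level slicing in the new __iter__,
# ignoring the additional per-worker split (num_workers == 1).
def rank_slice(total_count, batch_size, world_size, rank):
    # Group items into "big batches" holding one batch per replica.
    big_batch_size = world_size * batch_size
    big_batch_count, big_batch_remain = divmod(total_count, big_batch_size)
    # The leftover items form at most one partial "big batch".
    last_batch_count, batch_remain = divmod(big_batch_remain, batch_size)
    if rank < last_batch_count:
        last_batch = batch_size      # this rank gets one more full batch
    elif rank == last_batch_count:
        last_batch = batch_remain    # this rank gets the partial batch
    else:
        last_batch = 0               # this rank gets nothing extra
    num_items = big_batch_count * batch_size + last_batch
    start_offset = big_batch_count * batch_size * rank + min(
        rank * batch_size, big_batch_remain
    )
    return start_offset, num_items

# 14 items, 4 replicas, batch size 2: ranks 0-3 receive 4, 4, 4 and 2 items,
# starting at offsets 0, 4, 8 and 12 respectively.
for rank in range(4):
    print(rank, rank_slice(14, 2, 4, rank))
```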
@@ -416,6 +514,10 @@ class ItemSampler(IterDataPipe):
         self._drop_last = drop_last
         self._shuffle = shuffle
         self._use_indexing = use_indexing
+        self._distributed = False
+        self._drop_uneven_inputs = False
+        self._world_size = None
+        self._rank = None
 
     def _organize_items(self, data_pipe) -> None:
         # Shuffle before batch.
@@ -461,6 +563,10 @@ class ItemSampler(IterDataPipe):
                     self._shuffle,
                     self._batch_size,
                     self._drop_last,
+                    distributed=self._distributed,
+                    drop_uneven_inputs=self._drop_uneven_inputs,
+                    world_size=self._world_size,
+                    rank=self._rank,
                 )
             )
         else:
@@ -483,14 +589,14 @@ class DistributedItemSampler(ItemSampler):
     which can be used for training with PyTorch's Distributed Data Parallel
     (DDP). The items can be node IDs, node pairs with or without labels, node
     pairs with negative sources/destinations, DGLGraphs, or heterogeneous
-    counterparts. The original item set is sharded such that each replica
+    counterparts. The original item set is split such that each replica
     (process) receives an exclusive subset.
 
     Note: DistributedItemSampler may not work as expected when it is the last
     datapipe before the data is fetched. Please wrap a SingleProcessDataLoader
     or another datapipe around it.
 
-    Note: The items will be first sharded onto each replica, then get shuffled
+    Note: The items will be first split onto each replica, then get shuffled
     (if needed) and batched. Therefore, each replica will always get the same
     set of items.
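To make the note above concrete, here is a small sketch of the split-before-shuffle behaviour. It is illustrative only; the slice boundaries are those computed in the arithmetic sketch earlier for 14 items, 4 replicas and batch size 2:

```python
import numpy as np

# Conceptual sketch, not the library API: the split happens before shuffling,
# so a replica's subset stays fixed across epochs.
items = np.arange(14)                      # the full item set
per_rank_slices = [(0, 4), (4, 8), (8, 12), (12, 14)]
start, end = per_rank_slices[1]            # what rank 1 is assigned
for epoch in range(2):
    local = items[start:end].copy()
    np.random.shuffle(local)               # shuffling permutes within the slice only
    print(epoch, sorted(local.tolist()))   # always [4, 5, 6, 7]
```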
@@ -532,6 +638,7 @@ class DistributedItemSampler(ItemSampler):
 
     Examples
     --------
+    TODO[Kaicheng]: Modify examples here.
     0. Preparation: DistributedItemSampler needs a multi-processing environment to
     work. You need to spawn subprocesses and initialize the process group before
     executing the following examples. Due to randomness, the output is not always

@@ -539,7 +646,7 @@ class DistributedItemSampler(ItemSampler):
 
     >>> import torch
     >>> from dgl import graphbolt as gb
-    >>> item_set = gb.ItemSet(torch.arange(0, 13))
+    >>> item_set = gb.ItemSet(torch.arange(0, 14))
     >>> num_replicas = 4
     >>> batch_size = 2
     >>> mp.spawn(...)
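Since the example elides the multi-process setup behind `mp.spawn(...)`, here is a minimal sketch of one way to drive the sampler in each spawned worker. The worker function, backend, address/port, and the use of `SingleProcessDataLoader` as the wrapping datapipe are assumptions for illustration, not part of the patch:

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from dgl import graphbolt as gb


def run(rank, world_size):
    # Assumption: gloo backend and an arbitrary free local port for rendezvous.
    dist.init_process_group(
        "gloo",
        init_method="tcp://127.0.0.1:29500",
        rank=rank,
        world_size=world_size,
    )
    item_set = gb.ItemSet(torch.arange(0, 14))
    # world_size and rank are taken from the initialized process group.
    item_sampler = gb.DistributedItemSampler(
        item_set, batch_size=2, shuffle=True, drop_last=False
    )
    # Per the note above, do not make the sampler the last datapipe.
    data_loader = gb.SingleProcessDataLoader(item_sampler)
    for minibatch in data_loader:
        print(rank, minibatch)
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(4,), nprocs=4)
```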
@@ -632,46 +739,21 @@ class DistributedItemSampler(ItemSampler):
         minibatcher: Optional[Callable] = minibatcher_default,
         drop_last: Optional[bool] = False,
         shuffle: Optional[bool] = False,
-        num_replicas: Optional[int] = None,
         drop_uneven_inputs: Optional[bool] = False,
     ) -> None:
-        # [TODO][Rui] For now, always set use_indexing to False.
         super().__init__(
             item_set,
             batch_size,
             minibatcher,
             drop_last,
             shuffle,
-            use_indexing=False,
+            use_indexing=True,
         )
+        self._distributed = True
         self._drop_uneven_inputs = drop_uneven_inputs
-        # Apply a sharding filter to distribute the items.
-        self._item_set = self._item_set.sharding_filter()
-        # Get world size.
-        if num_replicas is None:
-            assert (
-                dist.is_available()
-            ), "Requires distributed package to be available."
-            num_replicas = dist.get_world_size()
-        if self._drop_uneven_inputs:
-            # If the len() method of the item_set is not available, it will
-            # throw an exception.
-            total_len = len(item_set)
-            # Calculate the number of batches after dropping uneven batches for
-            # each replica.
-            self._num_evened_batches = total_len // (
-                num_replicas * batch_size
-            ) + (
-                (not drop_last)
-                and (total_len % (num_replicas * batch_size) >= num_replicas)
-            )
-
-    def _organize_items(self, data_pipe) -> None:
-        data_pipe = super()._organize_items(data_pipe)
-        # If drop_uneven_inputs is True, drop the excessive inputs by limiting
-        # the length of the datapipe.
-        if self._drop_uneven_inputs:
-            data_pipe = data_pipe.header(self._num_evened_batches)
-        return data_pipe
+        if not dist.is_available():
+            raise RuntimeError(
+                "Distributed item sampler requires distributed package."
+            )
+        self._world_size = dist.get_world_size()
+        self._rank = dist.get_rank()
@@ -767,31 +767,12 @@ def distributed_item_sampler_subprocess(
     for i in data_loader:
         # Count how many times each item is sampled.
         sampled_count[i.seed_nodes] += 1
+        if drop_last:
+            assert i.seed_nodes.size(0) == batch_size
         num_items += i.seed_nodes.size(0)
     num_batches = len(list(item_sampler))
 
-    # Calculate expected numbers of items and batches.
-    expected_num_items = num_ids // nprocs + (num_ids % nprocs > proc_id)
-    if drop_last and expected_num_items % batch_size > 0:
-        expected_num_items -= expected_num_items % batch_size
-    expected_num_batches = expected_num_items // batch_size + (
-        (not drop_last) and (expected_num_items % batch_size > 0)
-    )
     if drop_uneven_inputs:
-        if (
-            (not drop_last)
-            and (num_ids % (nprocs * batch_size) < nprocs)
-            and (num_ids % (nprocs * batch_size) > proc_id)
-        ):
-            expected_num_batches -= 1
-            expected_num_items -= 1
-        elif (
-            drop_last
-            and (nprocs * batch_size - num_ids % (nprocs * batch_size) < nprocs)
-            and (num_ids % nprocs > proc_id)
-        ):
-            expected_num_batches -= 1
-            expected_num_items -= batch_size
         num_batches_tensor = torch.tensor(num_batches)
         dist.broadcast(num_batches_tensor, 0)
         # Test if the number of batches are the same for all processes.
@@ -801,10 +782,6 @@ def distributed_item_sampler_subprocess(
     dist.reduce(sampled_count, 0)
 
     try:
-        # Check if the numbers are as expected.
-        assert num_items == expected_num_items
-        assert num_batches == expected_num_batches
-
         # Make sure no item is sampled more than once.
         assert sampled_count.max() <= 1
     finally: