Unverified Commit 46d7b1d9 authored by Ramon Zhou, committed by GitHub

[GraphBolt] Rewrite DistributeItemSampler logic (#6565)

parent 81c7781b
@@ -16,6 +16,7 @@
from ..batch import batch as dgl_batch
from ..heterograph import DGLGraph
from .itemset import ItemSet, ItemSetDict
from .minibatch import MiniBatch
+from .utils import calculate_range

__all__ = ["ItemSampler", "DistributedItemSampler", "minibatcher_default"]
@@ -125,7 +126,6 @@ class ItemShufflerAndBatcher:
        )
        self._distributed = distributed
        self._drop_uneven_inputs = drop_uneven_inputs
-        if distributed:
-            self._num_replicas = world_size
-            self._rank = rank
+        self._num_replicas = world_size
+        self._rank = rank
@@ -184,101 +184,33 @@ class ItemShufflerAndBatcher:
            num_workers = 1
            worker_id = 0
        buffer = None
-        if not self._distributed:
-            num_items = len(self._item_set)
-            start_offset = 0
-        else:
-            total_count = len(self._item_set)
-            big_batch_size = self._num_replicas * self._batch_size
-            big_batch_count, big_batch_remain = divmod(
-                total_count, big_batch_size
-            )
-            last_batch_count, batch_remain = divmod(
-                big_batch_remain, self._batch_size
-            )
-            if self._rank < last_batch_count:
-                last_batch = self._batch_size
-            elif self._rank == last_batch_count:
-                last_batch = batch_remain
-            else:
-                last_batch = 0
-            num_items = big_batch_count * self._batch_size + last_batch
-            start_offset = (
-                big_batch_count * self._batch_size * self._rank
-                + min(self._rank * self._batch_size, big_batch_remain)
-            )
-            if not self._drop_uneven_inputs or (
-                not self._drop_last and last_batch_count == self._num_replicas
-            ):
-                # No need to drop uneven batches.
-                num_evened_items = num_items
-                if num_workers > 1:
-                    total_batch_count = (
-                        num_items + self._batch_size - 1
-                    ) // self._batch_size
-                    split_batch_count = total_batch_count // num_workers + (
-                        worker_id < total_batch_count % num_workers
-                    )
-                    split_num_items = split_batch_count * self._batch_size
-                    num_items = (
-                        min(num_items, split_num_items * (worker_id + 1))
-                        - split_num_items * worker_id
-                    )
-                    num_evened_items = num_items
-                    start_offset = (
-                        big_batch_count * self._batch_size * self._rank
-                        + min(self._rank * self._batch_size, big_batch_remain)
-                        + self._batch_size
-                        * (
-                            total_batch_count // num_workers * worker_id
-                            + min(worker_id, total_batch_count % num_workers)
-                        )
-                    )
-            else:
-                # Needs to drop uneven batches. As many items as `last_batch`
-                # size will be dropped. It would be better not to let those
-                # dropped items come from the same worker.
-                num_evened_items = big_batch_count * self._batch_size
-                if num_workers > 1:
-                    total_batch_count = big_batch_count
-                    split_batch_count = total_batch_count // num_workers + (
-                        worker_id < total_batch_count % num_workers
-                    )
-                    split_num_items = split_batch_count * self._batch_size
-                    split_item_remain = last_batch // num_workers + (
-                        worker_id < last_batch % num_workers
-                    )
-                    num_items = split_num_items + split_item_remain
-                    num_evened_items = split_num_items
-                    start_offset = (
-                        big_batch_count * self._batch_size * self._rank
-                        + min(self._rank * self._batch_size, big_batch_remain)
-                        + self._batch_size
-                        * (
-                            total_batch_count // num_workers * worker_id
-                            + min(worker_id, total_batch_count % num_workers)
-                        )
-                        + last_batch // num_workers * worker_id
-                        + min(worker_id, last_batch % num_workers)
-                    )
+        total = len(self._item_set)
+        start_offset, assigned_count, output_count = calculate_range(
+            self._distributed,
+            total,
+            self._num_replicas,
+            self._rank,
+            num_workers,
+            worker_id,
+            self._batch_size,
+            self._drop_last,
+            self._drop_uneven_inputs,
+        )
        start = 0
-        while start < num_items:
-            end = min(start + self._buffer_size, num_items)
+        while start < assigned_count:
+            end = min(start + self._buffer_size, assigned_count)
            buffer = self._item_set[start_offset + start : start_offset + end]
            indices = torch.arange(end - start)
            if self._shuffle:
                np.random.shuffle(indices.numpy())
            offsets = self._calculate_offsets(buffer)
            for i in range(0, len(indices), self._batch_size):
-                if self._drop_last and i + self._batch_size > len(indices):
-                    break
-                if (
-                    self._distributed
-                    and self._drop_uneven_inputs
-                    and i >= num_evened_items
-                ):
+                if output_count <= 0:
                    break
-                batch_indices = indices[i : i + self._batch_size]
+                batch_indices = indices[
+                    i : i + min(self._batch_size, output_count)
+                ]
+                output_count -= self._batch_size
                yield self._collate_batch(buffer, batch_indices, offsets)
            buffer = None
            start = end
@@ -592,10 +524,6 @@ class DistributedItemSampler(ItemSampler):
    counterparts. The original item set is split such that each replica
    (process) receives an exclusive subset.

-    Note: DistributedItemSampler may not work as expected when it is the last
-    datapipe before the data is fetched. Please wrap a SingleProcessDataLoader
-    or another datapipe on it.
-
    Note: The items will be first split onto each replica, then get shuffled
    (if needed) and batched. Therefore, each replica will always get a same set
    of items.
@@ -638,7 +566,6 @@ class DistributedItemSampler(ItemSampler):
    Examples
    --------
-    TODO[Kaicheng]: Modify examples here.
    0. Preparation: DistributedItemSampler needs multi-processing environment to
    work. You need to spawn subprocesses and initialize processing group before
    executing following examples. Due to randomness, the output is not always
@@ -646,7 +573,7 @@ class DistributedItemSampler(ItemSampler):
    >>> import torch
    >>> from dgl import graphbolt as gb
-    >>> item_set = gb.ItemSet(torch.arange(0, 14))
+    >>> item_set = gb.ItemSet(torch.arange(15))
    >>> num_replicas = 4
    >>> batch_size = 2
    >>> mp.spawn(...)
@@ -659,10 +586,10 @@ class DistributedItemSampler(ItemSampler):
    >>> )
    >>> data_loader = gb.SingleProcessDataLoader(item_sampler)
    >>> print(f"Replica#{proc_id}: {list(data_loader)}")
-    Replica#0: [tensor([0, 4]), tensor([ 8, 12])]
-    Replica#1: [tensor([1, 5]), tensor([ 9, 13])]
-    Replica#2: [tensor([2, 6]), tensor([10])]
-    Replica#3: [tensor([3, 7]), tensor([11])]
+    Replica#0: [tensor([0, 1]), tensor([2, 3])]
+    Replica#1: [tensor([4, 5]), tensor([6, 7])]
+    Replica#2: [tensor([8, 9]), tensor([10, 11])]
+    Replica#3: [tensor([12, 13]), tensor([14])]
    2. shuffle = False, drop_last = True, drop_uneven_inputs = False.

@@ -672,10 +599,10 @@ class DistributedItemSampler(ItemSampler):
    >>> )
    >>> data_loader = gb.SingleProcessDataLoader(item_sampler)
    >>> print(f"Replica#{proc_id}: {list(data_loader)}")
-    Replica#0: [tensor([0, 4]), tensor([ 8, 12])]
-    Replica#1: [tensor([1, 5]), tensor([ 9, 13])]
-    Replica#2: [tensor([2, 6])]
-    Replica#3: [tensor([3, 7])]
+    Replica#0: [tensor([0, 1]), tensor([2, 3])]
+    Replica#1: [tensor([4, 5]), tensor([6, 7])]
+    Replica#2: [tensor([8, 9]), tensor([10, 11])]
+    Replica#3: [tensor([12, 13])]
    3. shuffle = False, drop_last = False, drop_uneven_inputs = True.

@@ -685,10 +612,10 @@ class DistributedItemSampler(ItemSampler):
    >>> )
    >>> data_loader = gb.SingleProcessDataLoader(item_sampler)
    >>> print(f"Replica#{proc_id}: {list(data_loader)}")
-    Replica#0: [tensor([0, 4]), tensor([ 8, 12])]
-    Replica#1: [tensor([1, 5]), tensor([ 9, 13])]
-    Replica#2: [tensor([2, 6]), tensor([10])]
-    Replica#3: [tensor([3, 7]), tensor([11])]
+    Replica#0: [tensor([0, 1]), tensor([2, 3])]
+    Replica#1: [tensor([4, 5]), tensor([6, 7])]
+    Replica#2: [tensor([8, 9]), tensor([10, 11])]
+    Replica#3: [tensor([12, 13]), tensor([14])]
    4. shuffle = False, drop_last = True, drop_uneven_inputs = True.

@@ -698,10 +625,10 @@ class DistributedItemSampler(ItemSampler):
    >>> )
    >>> data_loader = gb.SingleProcessDataLoader(item_sampler)
    >>> print(f"Replica#{proc_id}: {list(data_loader)}")
-    Replica#0: [tensor([0, 4])]
-    Replica#1: [tensor([1, 5])]
-    Replica#2: [tensor([2, 6])]
-    Replica#3: [tensor([3, 7])]
+    Replica#0: [tensor([0, 1])]
+    Replica#1: [tensor([4, 5])]
+    Replica#2: [tensor([8, 9])]
+    Replica#3: [tensor([12, 13])]
    5. shuffle = True, drop_last = True, drop_uneven_inputs = False.

@@ -712,10 +639,10 @@ class DistributedItemSampler(ItemSampler):
    >>> data_loader = gb.SingleProcessDataLoader(item_sampler)
    >>> print(f"Replica#{proc_id}: {list(data_loader)}")
    (One possible output:)
-    Replica#0: [tensor([0, 8]), tensor([ 4, 12])]
-    Replica#1: [tensor([ 5, 13]), tensor([9, 1])]
-    Replica#2: [tensor([ 2, 10])]
-    Replica#3: [tensor([11, 7])]
+    Replica#0: [tensor([3, 2]), tensor([0, 1])]
+    Replica#1: [tensor([6, 5]), tensor([7, 4])]
+    Replica#2: [tensor([8, 10])]
+    Replica#3: [tensor([14, 12])]
    6. shuffle = True, drop_last = True, drop_uneven_inputs = True.

@@ -726,10 +653,10 @@ class DistributedItemSampler(ItemSampler):
    >>> data_loader = gb.SingleProcessDataLoader(item_sampler)
    >>> print(f"Replica#{proc_id}: {list(data_loader)}")
    (One possible output:)
-    Replica#0: [tensor([8, 0])]
-    Replica#1: [tensor([ 1, 13])]
-    Replica#2: [tensor([10, 6])]
-    Replica#3: [tensor([ 3, 11])]
+    Replica#0: [tensor([1, 3])]
+    Replica#1: [tensor([7, 5])]
+    Replica#2: [tensor([11, 9])]
+    Replica#3: [tensor([13, 14])]
""" """
def __init__( def __init__(
......
@@ -2,3 +2,4 @@
from .internal import *
from .sample_utils import *
from .datapipe_utils import *
+from .item_sampler_utils import *
"""Utility functions for DistributedItemSampler."""
def count_split(total, num_workers, worker_id, batch_size=1):
"""Calculate the number of assigned items after splitting them by batch
size evenly. It will return the number for this worker and also a sum of
previous workers.
"""
quotient, remainder = divmod(total, num_workers * batch_size)
if batch_size == 1:
assigned = quotient + (worker_id < remainder)
else:
batch_count, last_batch = divmod(remainder, batch_size)
assigned = quotient * batch_size + (
batch_size
if worker_id < batch_count
else (last_batch if worker_id == batch_count else 0)
)
prefix_sum = quotient * worker_id * batch_size + min(
worker_id * batch_size, remainder
)
return (assigned, prefix_sum)
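
# Illustrative example (values worked out from the logic above): splitting 30
# items across 4 workers with batch_size=4 first gives every worker
# quotient * batch_size = 4 items, then hands out the 14-item remainder batch
# by batch, so workers 0-2 each receive one extra full batch and worker 3
# receives the trailing 2 items:
#
#   [count_split(30, 4, w, 4) for w in range(4)]
#   -> [(8, 0), (8, 8), (8, 16), (6, 24)]   # (assigned, prefix_sum) per worker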


def calculate_range(
    distributed,
    total,
    num_replicas,
    rank,
    num_workers,
    worker_id,
    batch_size,
    drop_last,
    drop_uneven_inputs,
):
    """Calculates the range of items to be assigned to the current worker.

    This function evenly distributes `total` items among multiple workers,
    batching them using `batch_size`. Each replica has `num_workers` workers.
    The batches generated by workers within the same replica are combined into
    the replica's output. The `drop_last` parameter determines whether
    incomplete batches should be dropped. If `drop_last` is True, incomplete
    batches are discarded. The `drop_uneven_inputs` parameter determines if the
    number of batches assigned to each replica should be the same. If
    `drop_uneven_inputs` is True, excessive batches for some replicas will be
    dropped.

    Args:
        distributed (bool): Whether it's in distributed mode.
        total (int): The total number of items.
        num_replicas (int): The total number of replicas.
        rank (int): The rank of the current replica.
        num_workers (int): The number of workers per replica.
        worker_id (int): The ID of the current worker.
        batch_size (int): The desired batch size.
        drop_last (bool): Whether to drop incomplete batches.
        drop_uneven_inputs (bool): Whether to drop excessive batches for some
            replicas.

    Returns:
        tuple: A tuple containing three numbers:
            - start_offset (int): The starting offset of the range assigned to
              the current worker.
            - assigned_count (int): The length of the range assigned to the
              current worker.
            - output_count (int): The number of items that the current worker
              will produce after dropping.
    """
    # Check if it's distributed mode.
    if not distributed:
        if not drop_last:
            return (0, total, total)
        else:
            return (0, total, total // batch_size * batch_size)
    # First, equally distribute items into all replicas.
    assigned_count, start_offset = count_split(
        total, num_replicas, rank, batch_size
    )
    # Calculate the number of outputs when drop_uneven_inputs is True.
    # `assigned_count` is the number of items distributed to the current
    # process. `output_count` is the number of items that should be output
    # by this process after dropping.
    if not drop_uneven_inputs:
        if not drop_last:
            output_count = assigned_count
        else:
            output_count = assigned_count // batch_size * batch_size
    else:
        if not drop_last:
            min_item_count, _ = count_split(
                total, num_replicas, num_replicas - 1, batch_size
            )
            min_batch_count = (min_item_count + batch_size - 1) // batch_size
            output_count = min(min_batch_count * batch_size, assigned_count)
        else:
            output_count = total // (batch_size * num_replicas) * batch_size
    # If there are multiple workers, equally distribute the batches to
    # all workers.
    if num_workers > 1:
        # Equally distribute the dropped number too.
        dropped_items, prev_dropped_items = count_split(
            assigned_count - output_count, num_workers, worker_id
        )
        output_count, prev_output_count = count_split(
            output_count,
            num_workers,
            worker_id,
            batch_size,
        )
        assigned_count = output_count + dropped_items
        start_offset += prev_output_count + prev_dropped_items
    return (start_offset, assigned_count, output_count)
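
# A quick sketch of the resulting split (values traced from the logic above):
# with 30 items, 4 replicas, batch_size=4, drop_last=True,
# drop_uneven_inputs=True and a single worker per replica (num_workers=0,
# worker_id=0), every replica keeps exactly one full output batch:
#
#   [calculate_range(True, 30, 4, rank, 0, 0, 4, True, True) for rank in range(4)]
#   -> [(0, 8, 4), (8, 8, 4), (16, 8, 4), (24, 6, 4)]
#      # (start_offset, assigned_count, output_count) per rank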
@@ -728,8 +728,8 @@ def distributed_item_sampler_subprocess(
    nprocs,
    item_set,
    num_ids,
+    num_workers,
    batch_size,
-    shuffle,
    drop_last,
    drop_uneven_inputs,
):
@@ -750,7 +750,7 @@ def distributed_item_sampler_subprocess(
    item_sampler = gb.DistributedItemSampler(
        item_set,
        batch_size=batch_size,
-        shuffle=shuffle,
+        shuffle=True,
        drop_last=drop_last,
        drop_uneven_inputs=drop_uneven_inputs,
    )
@@ -759,7 +759,9 @@ def distributed_item_sampler_subprocess(
        gb.BasicFeatureStore({}),
        [],
    )
-    data_loader = gb.SingleProcessDataLoader(feature_fetcher)
+    data_loader = gb.MultiProcessDataLoader(
+        feature_fetcher, num_workers=num_workers
+    )

    # Count the numbers of items and batches.
    num_items = 0
@@ -788,12 +790,104 @@ def distributed_item_sampler_subprocess(
    dist.destroy_process_group()
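

# Each `params` entry below pairs (total, num_replicas, num_workers,
# batch_size, drop_last, drop_uneven_inputs) with the expected
# (assigned_count, output_count) for every (rank, worker_id) combination.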
@pytest.mark.parametrize(
    "params",
    [
        ((24, 4, 0, 4, False, False), [(8, 8), (8, 8), (4, 4), (4, 4)]),
        ((30, 4, 0, 4, False, False), [(8, 8), (8, 8), (8, 8), (6, 6)]),
        ((30, 4, 0, 4, True, False), [(8, 8), (8, 8), (8, 8), (6, 4)]),
        ((30, 4, 0, 4, False, True), [(8, 8), (8, 8), (8, 8), (6, 6)]),
        ((30, 4, 0, 4, True, True), [(8, 4), (8, 4), (8, 4), (6, 4)]),
        (
            (53, 4, 2, 4, False, False),
            [(8, 8), (8, 8), (8, 8), (5, 5), (8, 8), (4, 4), (8, 8), (4, 4)],
        ),
        (
            (53, 4, 2, 4, True, False),
            [(8, 8), (8, 8), (9, 8), (4, 4), (8, 8), (4, 4), (8, 8), (4, 4)],
        ),
        (
            (53, 4, 2, 4, False, True),
            [(10, 8), (6, 4), (9, 8), (4, 4), (8, 8), (4, 4), (8, 8), (4, 4)],
        ),
        (
            (53, 4, 2, 4, True, True),
            [(10, 8), (6, 4), (9, 8), (4, 4), (8, 8), (4, 4), (8, 8), (4, 4)],
        ),
        (
            (63, 4, 2, 4, False, False),
            [(8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (7, 7)],
        ),
        (
            (63, 4, 2, 4, True, False),
            [(8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (10, 8), (5, 4)],
        ),
        (
            (63, 4, 2, 4, False, True),
            [(8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (7, 7)],
        ),
        (
            (63, 4, 2, 4, True, True),
            [
                (10, 8),
                (6, 4),
                (10, 8),
                (6, 4),
                (10, 8),
                (6, 4),
                (10, 8),
                (5, 4),
            ],
        ),
        (
            (65, 4, 2, 4, False, False),
            [(9, 9), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8)],
        ),
        (
            (65, 4, 2, 4, True, True),
            [(9, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8)],
        ),
    ],
)
def test_RangeCalculation(params):
    (
        (
            total,
            num_replicas,
            num_workers,
            batch_size,
            drop_last,
            drop_uneven_inputs,
        ),
        key,
    ) = params
    answer = []
    sum = 0
    for rank in range(num_replicas):
        for worker_id in range(max(num_workers, 1)):
            result = gb.utils.calculate_range(
                True,
                total,
                num_replicas,
                rank,
                num_workers,
                worker_id,
                batch_size,
                drop_last,
                drop_uneven_inputs,
            )
            assert sum == result[0]
            sum += result[1]
            answer.append((result[1], result[2]))
    assert key == answer

@pytest.mark.parametrize("num_ids", [24, 30, 32, 34, 36]) @pytest.mark.parametrize("num_ids", [24, 30, 32, 34, 36])
@pytest.mark.parametrize("shuffle", [False, True]) @pytest.mark.parametrize("num_workers", [0, 2])
@pytest.mark.parametrize("drop_last", [False, True]) @pytest.mark.parametrize("drop_last", [False, True])
@pytest.mark.parametrize("drop_uneven_inputs", [False, True]) @pytest.mark.parametrize("drop_uneven_inputs", [False, True])
def test_DistributedItemSampler( def test_DistributedItemSampler(
num_ids, shuffle, drop_last, drop_uneven_inputs num_ids, num_workers, drop_last, drop_uneven_inputs
): ):
nprocs = 4 nprocs = 4
batch_size = 4 batch_size = 4
@@ -813,8 +907,8 @@ def test_DistributedItemSampler(
            nprocs,
            item_set,
            num_ids,
+            num_workers,
            batch_size,
-            shuffle,
            drop_last,
            drop_uneven_inputs,
        ),
...