Unverified Commit b08c446d authored by Rhett Ying, committed by GitHub

[GraphBolt] Improve ItemSampler via indexing (#6453)

parent 5f327ff4
@@ -4,6 +4,8 @@ from collections.abc import Mapping
from functools import partial
from typing import Callable, Iterator, Optional
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import default_collate
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
@@ -77,6 +79,72 @@ def minibatcher_default(batch, names):
return minibatch
class ItemShufflerAndBatcher:
"""A shuffler to shuffle items and create batches.
This class is used internally by :class:`ItemSampler` to shuffle items and
create batches. It is not supposed to be used directly. The intention of
this class is to avoid time-consuming iteration over :class:`ItemSet`. As
an optimization, it slices from the :class:`ItemSet` via indexing first,
then shuffle and create batches.
Parameters
----------
item_set : ItemSet
Data to be iterated.
shuffle : bool
Option to shuffle before batching.
batch_size : int
The size of each batch.
drop_last : bool
Option to drop the last batch if it's not full.
buffer_size : int
The size of the buffer to store items sliced from the :class:`ItemSet`.
"""
def __init__(
self,
item_set: ItemSet,
shuffle: bool,
batch_size: int,
drop_last: bool,
buffer_size: Optional[int] = 10 * 1000,
):
self._item_set = item_set
self._shuffle = shuffle
self._batch_size = batch_size
self._drop_last = drop_last
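        # Keep at least 20 batches' worth of items in the buffer so shuffling
        # mixes items across batch boundaries.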
self._buffer_size = max(buffer_size, 20 * batch_size)
# Round up the buffer size to the nearest multiple of batch size.
self._buffer_size = (
(self._buffer_size + batch_size - 1) // batch_size * batch_size
)
def __iter__(self):
buffer = None
num_items = len(self._item_set)
start = 0
while start < num_items:
end = min(start + self._buffer_size, num_items)
buffer = self._item_set[start:end]
indices = torch.arange(end - start)
if self._shuffle:
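                # Shuffle in place through the tensor's shared-memory NumPy
                # view, permuting indices itself.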
np.random.shuffle(indices.numpy())
for i in range(0, len(indices), self._batch_size):
if self._drop_last and i + self._batch_size > len(indices):
break
batch_indices = indices[i : i + self._batch_size]
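                # buffer holds a single sliced field when the ItemSet wraps
                # one item; otherwise it is a tuple of aligned per-field
                # slices.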
if len(self._item_set._items) == 1:
if isinstance(buffer[0], DGLGraph):
yield dgl_batch([buffer[idx] for idx in batch_indices])
else:
yield buffer[batch_indices]
else:
yield tuple(item[batch_indices] for item in buffer)
buffer = None
start = end
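For illustration, here is a minimal sketch of how the buffered slice-then-shuffle strategy behaves. ``ToyItemSet`` is a hypothetical stand-in for :class:`ItemSet` that mirrors only the pieces ``ItemShufflerAndBatcher`` touches (``__len__``, slice indexing, and the internal ``_items`` tuple); it is not part of GraphBolt.

import torch

class ToyItemSet:
    """Hypothetical tensor-backed stand-in for gb.ItemSet."""

    def __init__(self, tensor):
        # Mirrors ItemSet's internal single-field layout.
        self._items = (tensor,)

    def __len__(self):
        return len(self._items[0])

    def __getitem__(self, index):
        return self._items[0][index]

batcher = ItemShufflerAndBatcher(
    ToyItemSet(torch.arange(10)), shuffle=False, batch_size=4, drop_last=False
)
# One buffer [0..9] is sliced up front, then index-based batches are yielded:
# [tensor([0, 1, 2, 3]), tensor([4, 5, 6, 7]), tensor([8, 9])]
print(list(batcher))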
class ItemSampler(IterDataPipe):
"""A sampler to iterate over input items and create subsets.
@@ -287,14 +355,28 @@
minibatcher: Optional[Callable] = minibatcher_default,
drop_last: Optional[bool] = False,
shuffle: Optional[bool] = False,
        # [TODO][Rui] For now, this is a temporary knob to disable indexing.
        # In the future, indexing will be enabled for all item sets.
use_indexing: Optional[bool] = True,
) -> None:
super().__init__()
self._names = item_set.names
self._item_set = IterableWrapper(item_set)
# Check if the item set supports indexing.
try:
item_set[0]
except TypeError:
use_indexing = False
# [TODO][Rui] For now, we disable indexing for ItemSetDict.
use_indexing = (not isinstance(item_set, ItemSetDict)) and use_indexing
self._use_indexing = use_indexing
self._item_set = (
item_set if self._use_indexing else IterableWrapper(item_set)
)
self._batch_size = batch_size
self._minibatcher = minibatcher
self._drop_last = drop_last
self._shuffle = shuffle
def _organize_items(self, data_pipe) -> None:
# Shuffle before batch.
@@ -333,6 +415,16 @@
return default_collate(batch)
def __iter__(self) -> Iterator:
if self._use_indexing:
data_pipe = IterableWrapper(
ItemShufflerAndBatcher(
self._item_set,
self._shuffle,
self._batch_size,
self._drop_last,
)
)
else:
# Organize items.
data_pipe = self._organize_items(self._item_set)
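As a quick usage sketch (this mirrors the tests below; the tensor contents and batch size are illustrative):

import torch
import dgl.graphbolt as gb

item_set = gb.ItemSet(torch.arange(0, 10), names="seed_nodes")
# use_indexing defaults to True, so buffers are sliced, shuffled, and batched
# without per-item iteration.
item_sampler = gb.ItemSampler(item_set, batch_size=4, shuffle=True)
for minibatch in item_sampler:
    print(minibatch.seed_nodes)  # A tensor of up to 4 shuffled seed node IDs.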
@@ -504,7 +596,15 @@ class DistributedItemSampler(ItemSampler):
num_replicas: Optional[int] = None,
drop_uneven_inputs: Optional[bool] = False,
) -> None:
super().__init__(item_set, batch_size, minibatcher, drop_last, shuffle)
# [TODO][Rui] For now, always set use_indexing to False.
super().__init__(
item_set,
batch_size,
minibatcher,
drop_last,
shuffle,
use_indexing=False,
)
self._drop_uneven_inputs = drop_uneven_inputs
# Apply a sharding filter to distribute the items.
self._item_set = self._item_set.sharding_filter()
......
@@ -65,6 +65,39 @@ def test_ItemSampler_minibatcher():
assert len(minibatch.seed_nodes) == 4
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_Iterable_Only(batch_size, shuffle, drop_last):
num_ids = 103
class InvalidLength:
def __iter__(self):
return iter(torch.arange(0, num_ids))
seed_nodes = gb.ItemSet(InvalidLength())
item_set = gb.ItemSet(seed_nodes, names="seed_nodes")
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
minibatch_ids = []
for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None
assert minibatch.labels is None
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
assert len(minibatch.seed_nodes) == batch_size
else:
if not drop_last:
assert len(minibatch.seed_nodes) == num_ids % batch_size
else:
assert False
minibatch_ids.append(minibatch.seed_nodes)
minibatch_ids = torch.cat(minibatch_ids)
    # torch.all returns a tensor, so compare its boolean value rather than
    # testing identity (`is not`) against a Python bool.
    is_sorted = bool(torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]))
    assert is_sorted == (not shuffle)
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
......
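For reference, the constructor's indexability probe can be reproduced in isolation. ``IterableOnly`` is hypothetical and mirrors ``InvalidLength`` above: it is iterable but not subscriptable, so the probe falls back to iterator-based batching.

import torch

class IterableOnly:
    """Iterable without __getitem__ or __len__."""

    def __iter__(self):
        return iter(torch.arange(0, 5))

items = IterableOnly()
try:
    items[0]  # Not subscriptable, so this raises TypeError.
    use_indexing = True
except TypeError:
    use_indexing = False
assert not use_indexing  # ItemSampler falls back to IterableWrapper.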