"docs/source/vscode:/vscode.git/clone" did not exist on "650f6ee1e0b3c2888a2c6d7db9c3d159cae5a583"
Unverified Commit b08c446d authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] Improve ItemSampler via indexing (#6453)

parent 5f327ff4
@@ -4,6 +4,8 @@ from collections.abc import Mapping
from functools import partial
from typing import Callable, Iterator, Optional
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import default_collate
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe

@@ -77,6 +79,72 @@ def minibatcher_default(batch, names):
    return minibatch
class ItemShufflerAndBatcher:
    """A shuffler to shuffle items and create batches.

    This class is used internally by :class:`ItemSampler` to shuffle items and
    create batches. It is not supposed to be used directly. The intention of
    this class is to avoid time-consuming iteration over :class:`ItemSet`. As
    an optimization, it slices from the :class:`ItemSet` via indexing first,
    then shuffles and creates batches.

    Parameters
    ----------
    item_set : ItemSet
        Data to be iterated.
    shuffle : bool
        Option to shuffle before batching.
    batch_size : int
        The size of each batch.
    drop_last : bool
        Option to drop the last batch if it's not full.
    buffer_size : int
        The size of the buffer to store items sliced from the :class:`ItemSet`.
    """
    def __init__(
        self,
        item_set: ItemSet,
        shuffle: bool,
        batch_size: int,
        drop_last: bool,
        buffer_size: Optional[int] = 10 * 1000,
    ):
        self._item_set = item_set
        self._shuffle = shuffle
        self._batch_size = batch_size
        self._drop_last = drop_last
        self._buffer_size = max(buffer_size, 20 * batch_size)
        # Round up the buffer size to the nearest multiple of batch size.
        self._buffer_size = (
            (self._buffer_size + batch_size - 1) // batch_size * batch_size
        )
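        # Worked example (illustrative numbers): buffer_size=10,000 with
        # batch_size=256 gives max(10000, 5120) = 10000, which rounds up to
        # 40 * 256 = 10240, so every full buffer yields only complete batches.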
    def __iter__(self):
        buffer = None
        num_items = len(self._item_set)
        start = 0
        while start < num_items:
            end = min(start + self._buffer_size, num_items)
            buffer = self._item_set[start:end]
            indices = torch.arange(end - start)
            if self._shuffle:
                np.random.shuffle(indices.numpy())
            for i in range(0, len(indices), self._batch_size):
                if self._drop_last and i + self._batch_size > len(indices):
                    break
                batch_indices = indices[i : i + self._batch_size]
                if len(self._item_set._items) == 1:
                    if isinstance(buffer[0], DGLGraph):
                        yield dgl_batch([buffer[idx] for idx in batch_indices])
                    else:
                        yield buffer[batch_indices]
                else:
                    yield tuple(item[batch_indices] for item in buffer)
            buffer = None
            start = end
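For illustration, a minimal sketch of driving the shuffler directly with a tensor-backed ItemSet; ItemSampler normally constructs it internally, and the import path from dgl.graphbolt.item_sampler is an assumption:

import torch
from dgl import graphbolt as gb
from dgl.graphbolt.item_sampler import ItemShufflerAndBatcher  # assumed path

# An indexable item set: slicing item_set[0:4] works, so the shuffler can
# fill its buffer in one slice instead of iterating item by item.
item_set = gb.ItemSet(torch.arange(0, 10), names="seed_nodes")
batcher = ItemShufflerAndBatcher(
    item_set, shuffle=False, batch_size=4, drop_last=False
)
print(list(batcher))
# Without shuffling:
# [tensor([0, 1, 2, 3]), tensor([4, 5, 6, 7]), tensor([8, 9])]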
class ItemSampler(IterDataPipe):
    """A sampler to iterate over input items and create subsets.

@@ -287,14 +355,28 @@ class ItemSampler(IterDataPipe):
        minibatcher: Optional[Callable] = minibatcher_default,
        drop_last: Optional[bool] = False,
        shuffle: Optional[bool] = False,
        # [TODO][Rui] For now, it's a temporary knob to disable indexing. In
        # the future, we will enable indexing for all the item sets.
        use_indexing: Optional[bool] = True,
    ) -> None:
        super().__init__()
        self._names = item_set.names
        # Check if the item set supports indexing.
        try:
            item_set[0]
        except TypeError:
            use_indexing = False
        # [TODO][Rui] For now, we disable indexing for ItemSetDict.
        use_indexing = (not isinstance(item_set, ItemSetDict)) and use_indexing
        self._use_indexing = use_indexing
        self._item_set = (
            item_set if self._use_indexing else IterableWrapper(item_set)
        )
        self._batch_size = batch_size
        self._minibatcher = minibatcher
        self._drop_last = drop_last
        self._shuffle = shuffle
    def _organize_items(self, data_pipe) -> None:
        # Shuffle before batch.

@@ -333,6 +415,16 @@ class ItemSampler(IterDataPipe):
        return default_collate(batch)
    def __iter__(self) -> Iterator:
        if self._use_indexing:
            data_pipe = IterableWrapper(
                ItemShufflerAndBatcher(
                    self._item_set,
                    self._shuffle,
                    self._batch_size,
                    self._drop_last,
                )
            )
        else:
            # Organize items.
            data_pipe = self._organize_items(self._item_set)
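The effect of this dispatch, sketched below: an indexable item set takes the new slicing path, while an iterable-only one makes item_set[0] raise TypeError in __init__ and falls back to the iterable pipeline. The IterableOnly helper is hypothetical, mirroring the new test:

import torch
from dgl import graphbolt as gb

class IterableOnly:  # hypothetical: supports iteration but not indexing
    def __iter__(self):
        return iter(torch.arange(0, 8))

indexable = gb.ItemSet(torch.arange(0, 8), names="seed_nodes")
iterable_only = gb.ItemSet(IterableOnly(), names="seed_nodes")

for item_set in (indexable, iterable_only):
    sampler = gb.ItemSampler(item_set, batch_size=4)
    # Both paths produce the same minibatches; only the machinery differs.
    print([minibatch.seed_nodes for minibatch in sampler])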
@@ -504,7 +596,15 @@ class DistributedItemSampler(ItemSampler):
        num_replicas: Optional[int] = None,
        drop_uneven_inputs: Optional[bool] = False,
    ) -> None:
        # [TODO][Rui] For now, always set use_indexing to False.
        super().__init__(
            item_set,
            batch_size,
            minibatcher,
            drop_last,
            shuffle,
            use_indexing=False,
        )
        self._drop_uneven_inputs = drop_uneven_inputs
        # Apply a sharding filter to distribute the items.
        self._item_set = self._item_set.sharding_filter()
......
@@ -65,6 +65,39 @@ def test_ItemSampler_minibatcher():
    assert len(minibatch.seed_nodes) == 4
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_Iterable_Only(batch_size, shuffle, drop_last):
num_ids = 103
class InvalidLength:
def __iter__(self):
return iter(torch.arange(0, num_ids))
seed_nodes = gb.ItemSet(InvalidLength())
item_set = gb.ItemSet(seed_nodes, names="seed_nodes")
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
minibatch_ids = []
for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None
assert minibatch.labels is None
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
assert len(minibatch.seed_nodes) == batch_size
else:
if not drop_last:
assert len(minibatch.seed_nodes) == num_ids % batch_size
else:
assert False
minibatch_ids.append(minibatch.seed_nodes)
minibatch_ids = torch.cat(minibatch_ids)
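    # Note: torch.all returns a 0-dim bool tensor; .item() converts it to a
    # Python bool so the identity comparison with `shuffle` is meaningful.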
    assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]).item() is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False]) @pytest.mark.parametrize("drop_last", [True, False])
......