Unverified commit 03ca11f5, authored by Rhett Ying and committed by GitHub
Browse files

[GraphBolt] fix random generator for shuffle among all workers (#6982)

parent 351c860a
...@@ -115,6 +115,8 @@ class ItemShufflerAndBatcher: ...@@ -115,6 +115,8 @@ class ItemShufflerAndBatcher:
rank : int rank : int
The rank of the current replica. Applies only when `distributed` is The rank of the current replica. Applies only when `distributed` is
True. True.
rng : np.random.Generator
The random number generator to use for shuffling.
""" """
def __init__( def __init__(
...@@ -128,6 +130,7 @@ class ItemShufflerAndBatcher: ...@@ -128,6 +130,7 @@ class ItemShufflerAndBatcher:
drop_uneven_inputs: Optional[bool] = False, drop_uneven_inputs: Optional[bool] = False,
world_size: Optional[int] = 1, world_size: Optional[int] = 1,
rank: Optional[int] = 0, rank: Optional[int] = 0,
rng: Optional[np.random.Generator] = None,
): ):
self._item_set = item_set self._item_set = item_set
self._shuffle = shuffle self._shuffle = shuffle
...@@ -142,6 +145,7 @@ class ItemShufflerAndBatcher: ...@@ -142,6 +145,7 @@ class ItemShufflerAndBatcher:
self._drop_uneven_inputs = drop_uneven_inputs self._drop_uneven_inputs = drop_uneven_inputs
self._num_replicas = world_size self._num_replicas = world_size
self._rank = rank self._rank = rank
self._rng = rng
def _collate_batch(self, buffer, indices, offsets=None): def _collate_batch(self, buffer, indices, offsets=None):
"""Collate a batch from the buffer. For internal use only.""" """Collate a batch from the buffer. For internal use only."""
...@@ -216,7 +220,7 @@ class ItemShufflerAndBatcher: ...@@ -216,7 +220,7 @@ class ItemShufflerAndBatcher:
buffer = self._item_set[start_offset + start : start_offset + end] buffer = self._item_set[start_offset + start : start_offset + end]
indices = torch.arange(end - start) indices = torch.arange(end - start)
if self._shuffle: if self._shuffle:
np.random.shuffle(indices.numpy()) self._rng.shuffle(indices.numpy())
offsets = self._calculate_offsets(buffer) offsets = self._calculate_offsets(buffer)
for i in range(0, len(indices), self._batch_size): for i in range(0, len(indices), self._batch_size):
if output_count <= 0: if output_count <= 0:
...@@ -494,6 +498,7 @@ class ItemSampler(IterDataPipe): ...@@ -494,6 +498,7 @@ class ItemSampler(IterDataPipe):
self._drop_uneven_inputs = False self._drop_uneven_inputs = False
self._world_size = None self._world_size = None
self._rank = None self._rank = None
self._rng = np.random.default_rng()
def _organize_items(self, data_pipe) -> None: def _organize_items(self, data_pipe) -> None:
# Shuffle before batch. # Shuffle before batch.
...@@ -529,6 +534,7 @@ class ItemSampler(IterDataPipe): ...@@ -529,6 +534,7 @@ class ItemSampler(IterDataPipe):
def __iter__(self) -> Iterator: def __iter__(self) -> Iterator:
if self._use_indexing: if self._use_indexing:
seed = self._rng.integers(0, np.iinfo(np.int32).max)
data_pipe = IterableWrapper( data_pipe = IterableWrapper(
ItemShufflerAndBatcher( ItemShufflerAndBatcher(
self._item_set, self._item_set,
...@@ -540,6 +546,7 @@ class ItemSampler(IterDataPipe): ...@@ -540,6 +546,7 @@ class ItemSampler(IterDataPipe):
drop_uneven_inputs=self._drop_uneven_inputs, drop_uneven_inputs=self._drop_uneven_inputs,
world_size=self._world_size, world_size=self._world_size,
rank=self._rank, rank=self._rank,
rng=np.random.default_rng(seed),
) )
) )
else: else:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment