Commit 6411c9ad authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Introduce DistributedBatchSampler (#2299)

Summary:
When using a customized `batch_sampler`, pytorch_lightning can't wrap a distributed sampler around it. Hence we provide a `DistributedBatchSampler` that supports `BucketizeBatchSampler` in `ddp` mode.

The `DistributedBatchSampler` assumes `BucketizeBatchSampler.iter_list` is a list of lists, where each sub-list contains a batch of indices. Setting `shuffle` to `True` shuffles these lists based on `seed` and the current `epoch`.
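
For illustration, a minimal sketch of the assumed structure and the seeded shuffle (the values below are made up; the mechanism mirrors the shuffle in `DistributedBatchSampler.__init__`):

```python
import torch

# Assumed shape of BucketizeBatchSampler.iter_list: a list of batches,
# where each sub-list is one batch of dataset indices.
iter_list = [[0, 3, 7], [1, 4], [2, 5, 6, 8]]

# seed must be identical on all processes; the sampler tracks epoch
# (0 initially, updated via set_epoch).
seed, epoch = 0, 2
g = torch.Generator()
g.manual_seed(seed + epoch)

# Shuffle the order of the batches; indices inside each batch are untouched.
perm = torch.randperm(len(iter_list), generator=g).tolist()
shuffled = [iter_list[i] for i in perm]
```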

The shuffling only happens at initialization and does not change unless the user resets the sampler. The reason is that re-shuffling `BucketizeBatchSampler` may produce a different number of batches than before, so shuffling in ``__iter__`` could cause a mismatch between ``__len__`` and the real length.
Hence users need to set `reload_dataloaders_every_n_epochs=1` in pytorch_lightning's Trainer, so that the samplers are re-created every epoch and the value of ``__len__`` matches the real length.
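
A rough sketch of the intended wiring in a pytorch_lightning module (the module name, dataset, import path, and constructor arguments below are illustrative only, not part of this commit):

```python
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import DataLoader

from hubert_dataset import BucketizeBatchSampler, DistributedBatchSampler  # import path is illustrative


class HuBERTPreTrainModule(LightningModule):  # hypothetical LightningModule
    def train_dataloader(self):
        dataset = ...   # dataset of variable-length utterances
        lengths = ...   # per-sample lengths used for bucketing
        batch_sampler = BucketizeBatchSampler(lengths, num_buckets=1000, batch_size=32, shuffle=True)
        sampler = DistributedBatchSampler(batch_sampler, shuffle=True)
        sampler.set_epoch(self.current_epoch)  # as recommended in the docstring
        return DataLoader(dataset, batch_sampler=sampler, num_workers=4)


# Re-create the dataloader (and thus the samplers) every epoch so that
# ``__len__`` always matches the real number of batches after re-shuffling.
trainer = Trainer(
    accelerator="gpu",
    devices=2,
    strategy="ddp",
    reload_dataloaders_every_n_epochs=1,
    replace_sampler_ddp=False,  # the DistributedBatchSampler already shards the batches
)
```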

Pull Request resolved: https://github.com/pytorch/audio/pull/2299

Reviewed By: hwangjeff

Differential Revision: D35781538

Pulled By: nateanl

fbshipit-source-id: 6e8396615497f1aeddab1ee5678830c0445c2b2a
parent 2acafdaf
@@ -3,9 +3,10 @@ from typing import Dict, Iterator, List, Optional, Tuple, Union
 import numpy as np
 import torch
+import torch.distributed as dist
 import torchaudio
 from torch import Tensor
-from torch.utils.data import BatchSampler, Dataset
+from torch.utils.data import BatchSampler, Dataset, DistributedSampler


 class BucketizeBatchSampler(BatchSampler):
@@ -34,6 +35,10 @@ class BucketizeBatchSampler(BatchSampler):
     Note:
         ``drop_last`` is only valid when ``batch_size`` argument is given.
+
+    Note:
+        if ``shuffle`` is True, it will only shuffle the data once. Please set ``reload_dataloaders_every_n_epochs=1``
+        in pytorch_lightning Trainer to enable shuffling every epoch.
     """

     def __init__(
@@ -72,7 +77,6 @@ class BucketizeBatchSampler(BatchSampler):
         self.shuffle = shuffle
         self.drop_last = drop_last
         self.buckets = self._get_buckets(self.lengths, num_buckets, min_len, max_len)
-        self.iter_list = []
         self._update_iter_list()

     def _get_buckets(self, lengths: List[int], num_buckets: int, min_len: int, max_len: int) -> Dict[int, Tensor]:
@@ -102,6 +106,9 @@ class BucketizeBatchSampler(BatchSampler):
         return buckets

     def _update_iter_list(self) -> None:
+        if self.shuffle:
+            for k in self.buckets:
+                self.buckets[k] = self.buckets[k][torch.randperm(self.buckets[k].size(0))]
         self.iter_list = []
         total_len = 0
         batch = []
@@ -121,11 +128,6 @@ class BucketizeBatchSampler(BatchSampler):
             self.iter_list.append(batch)

     def __iter__(self) -> Iterator[List[int]]:
-        if self.shuffle:
-            for k in self.buckets:
-                self.buckets[k] = self.buckets[k][torch.randperm(self.buckets[k].size(0))]
-            self._update_iter_list()
         return iter(self.iter_list)

     def __len__(self):
@@ -133,6 +135,81 @@ class BucketizeBatchSampler(BatchSampler):
         return len(self.iter_list)


+class DistributedBatchSampler(DistributedSampler):
+    """`BucketizeBatchSampler` wrapper that distributes across each processor.
+
+    Args:
+        batch_sampler (BucketizeBatchSampler): the initialized bucketize batch sampler.
+        num_replicas (int, optional): Number of processes participating in
+            distributed training. By default, :attr:`world_size` is retrieved from the
+            current distributed group.
+        rank (int, optional): Rank of the current process within :attr:`num_replicas`.
+            By default, :attr:`rank` is retrieved from the current distributed
+            group.
+        shuffle (bool, optional): if ``True``, the list of batch indices will be shuffled.
+            (Default: ``True``)
+        seed (int, optional): random seed used to shuffle the batch_sampler if
+            :attr:`shuffle=True`. This number should be identical across all
+            processes in the distributed group. (Default: ``0``)
+        drop_last (bool, optional): if ``True``, then the sampler will drop the
+            tail of the data to make it evenly divisible across the number of
+            replicas. If ``False``, the sampler will add extra indices to make
+            the data evenly divisible across the replicas. (Default: ``False``)
+
+    Note:
+        if ``shuffle`` is True, it will only shuffle the data once. Please set ``reload_dataloaders_every_n_epochs=1``
+        in pytorch_lightning Trainer, and set `sampler.set_epoch(self.current_epoch)` before DataLoader initialization
+        in `train_dataloader` method to enable shuffling every epoch.
+    """
+
+    def __init__(
+        self,
+        batch_sampler: BucketizeBatchSampler,
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        shuffle: bool = True,
+        seed: int = 0,
+        drop_last: bool = False,
+    ) -> None:
+        self.batch_sampler = batch_sampler
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.shuffle = shuffle
+        self.epoch = 0
+        self.seed = seed
+        self.drop_last = drop_last
+        if shuffle:
+            g = torch.Generator()
+            g.manual_seed(self.seed + self.epoch)
+            perm = torch.randperm(len(self.batch_sampler.iter_list), generator=g).tolist()
+            indices = [self.batch_sampler.iter_list[i] for i in perm]
+        else:
+            indices = self.batch_sampler.iter_list
+        if self.drop_last:
+            self.total_size = len(indices) - len(indices) % self.num_replicas
+        else:
+            padding_size = self.num_replicas - len(indices) % self.num_replicas
+            indices += indices[:padding_size]
+            self.total_size = len(indices)
+        self.num_samples = self.total_size // self.num_replicas
+        self.subset = indices[self.rank : self.total_size : self.num_replicas]
+        assert len(self.subset) == self.num_samples
+
+    def __iter__(self):
+        return iter(self.subset)
+
+    def __len__(self):
+        return self.num_samples
+
+
 class HuBERTDataSet(Dataset):
     """Create a Dataset for HuBERT model training and fine-tuning.
...
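
To see how the padding and the rank-strided slicing split the batch list, here is a small self-contained sketch (the stand-in object and import path are illustrative; in real use a `BucketizeBatchSampler` is passed and `num_replicas`/`rank` come from the process group):

```python
import types

from hubert_dataset import DistributedBatchSampler  # import path is illustrative


def make_stub():
    # Stand-in exposing only ``iter_list``; a real BucketizeBatchSampler is used in practice.
    return types.SimpleNamespace(iter_list=[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])


# Passing num_replicas/rank explicitly avoids needing an initialized process group.
rank0 = DistributedBatchSampler(make_stub(), num_replicas=2, rank=0, shuffle=False)
rank1 = DistributedBatchSampler(make_stub(), num_replicas=2, rank=1, shuffle=False)

# Five batches are padded to six (drop_last=False), then sliced with stride num_replicas:
print(list(rank0))  # [[0, 1], [4, 5], [8, 9]]
print(list(rank1))  # [[2, 3], [6, 7], [0, 1]]
print(len(rank0), len(rank1))  # 3 3
```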