Commit 2c9b3e59 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Fix DDP training in HuBERT recipes (#3068)

Summary:
The `BucketizeBatchSampler` may return a different `iter_list` on different nodes if `shuffle` is `True`, which causes DDP training to hang forever.
The shuffling in `DistributedBatchSampler` only happens at initialization, which means it assigns the same subset to each replica in every training epoch. This PR fixes the two issues above.
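
In short, a minimal sketch of the idea (the helper below is illustrative, not the recipe's API): when every rank seeds its shuffle with the same value, all processes build identical batch orders, and deriving that seed from the epoch keeps the order changing between epochs.

```python
import torch

# Illustrative only: two "ranks" shuffling with the same seed get identical
# permutations, so a batch sampler seeded this way yields the same iter_list
# on every DDP process and no rank waits on a batch the others never produce.
def shuffled_order(num_items: int, seed: int) -> list:
    g = torch.Generator()
    g.manual_seed(seed)          # same seed on every rank
    return torch.randperm(num_items, generator=g).tolist()

epoch = 3
rank0 = shuffled_order(10, seed=epoch)   # seed derived from the epoch
rank1 = shuffled_order(10, seed=epoch)
assert rank0 == rank1                    # identical order across ranks
assert shuffled_order(10, seed=epoch + 1) != rank0  # but still reshuffled next epoch
```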

cc arlofaria

Pull Request resolved: https://github.com/pytorch/audio/pull/3068

Reviewed By: mthrok

Differential Revision: D43372110

Pulled By: nateanl

fbshipit-source-id: a162728406ae995e05d2a07cfc2444fb76cf345e
parent 11bdafc3
......@@ -32,6 +32,7 @@ class BucketizeBatchSampler(BatchSampler):
(Default: ``None``)
shuffle (bool, optional): Whether to shuffle buckets for non-monotonic length sampling.
(Default: True)
seed (int, optional): The seed for initializing the RNG. Only used when ``shuffle`` is ``True``. (Default: 0)
drop_last (bool, optional): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``
(Default: False)
......@@ -45,7 +46,7 @@ class BucketizeBatchSampler(BatchSampler):
Note:
if ``shuffle`` is True, it will only shuffle the data once. Please set ``reload_dataloaders_every_n_epochs=1``
in pytorch_lightning Trainer to enable shuffling every epoch.
in pytorch_lightning Trainer and set ``seed`` to ``self.trainer.current_epoch`` to enable shuffling every epoch.
"""
def __init__(
......@@ -57,6 +58,7 @@ class BucketizeBatchSampler(BatchSampler):
max_token_count: Optional[int] = None,
batch_size: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
) -> None:
if max_len is None:
......@@ -82,6 +84,10 @@ class BucketizeBatchSampler(BatchSampler):
self.max_token_count = max_token_count
self.batch_size = batch_size
self.shuffle = shuffle
self.seed = seed
if self.shuffle:
self.g = torch.Generator()
self.g.manual_seed(self.seed)
self.drop_last = drop_last
self.buckets = self._get_buckets(self.lengths, num_buckets, min_len, max_len)
self._update_iter_list()
......@@ -115,7 +121,7 @@ class BucketizeBatchSampler(BatchSampler):
def _update_iter_list(self) -> None:
if self.shuffle:
for k in self.buckets:
self.buckets[k] = self.buckets[k][torch.randperm(self.buckets[k].size(0))]
self.buckets[k] = self.buckets[k][torch.randperm(self.buckets[k].size(0), generator=self.g)]
self.iter_list = []
total_len = 0
batch = []
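
As a hedged illustration of the seeded per-bucket shuffle above (the bucket contents below are made up), only the order inside each length bucket is permuted, and the shared, manually seeded generator keeps the permutation identical across ranks:

```python
import torch

# Sketch of BucketizeBatchSampler._update_iter_list's shuffle step.
g = torch.Generator()
g.manual_seed(0)  # corresponds to self.seed in the sampler

buckets = {
    0: torch.tensor([0, 3, 7]),     # short utterances (illustrative indices)
    1: torch.tensor([1, 2, 5, 6]),  # medium utterances
    2: torch.tensor([4, 8, 9]),     # long utterances
}
for k in buckets:
    perm = torch.randperm(buckets[k].size(0), generator=g)
    buckets[k] = buckets[k][perm]   # same seed -> same order on every rank
print({k: v.tolist() for k, v in buckets.items()})
```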
......@@ -193,7 +199,18 @@ class DistributedBatchSampler(DistributedSampler):
self.epoch = 0
self.seed = seed
self.drop_last = drop_last
if shuffle:
self.shuffle = shuffle
indices = self.batch_sampler.iter_list
if self.drop_last and len(indices) % self.num_replicas != 0:
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil((len(indices) - self.num_replicas) / self.num_replicas)
else:
self.num_samples = math.ceil(len(indices) / self.num_replicas)
def __iter__(self):
if self.shuffle:
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
perm = torch.randperm(len(self.batch_sampler.iter_list), generator=g).tolist()
......@@ -210,7 +227,6 @@ class DistributedBatchSampler(DistributedSampler):
self.subset = indices[self.rank : self.total_size : self.num_replicas]
assert len(self.subset) == self.num_samples
def __iter__(self):
return iter(self.subset)
def __len__(self):
......
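
For reference, a small sketch of the rank split this enables (the numbers are made up, and ``total_size`` is assumed to equal ``num_samples * num_replicas``, mirroring torch's ``DistributedSampler``):

```python
import math

# How the distributed sampler splits the already-batched index list across ranks.
indices = list(range(10))   # e.g. 10 batches produced by BucketizeBatchSampler
num_replicas = 4
drop_last = True

if drop_last and len(indices) % num_replicas != 0:
    # Drop the tail so every rank receives the same number of batches.
    num_samples = math.ceil((len(indices) - num_replicas) / num_replicas)  # -> 2
else:
    num_samples = math.ceil(len(indices) / num_replicas)
total_size = num_samples * num_replicas                                    # -> 8

for rank in range(num_replicas):
    subset = indices[rank:total_size:num_replicas]  # strided split, rank 0 -> [0, 4]
    assert len(subset) == num_samples
```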
......@@ -287,7 +287,8 @@ class HuBERTPreTrainModule(LightningModule):
max_token_count=self.seconds_per_batch * 16000,
min_len=32000,
max_len=250000,
shuffle=False,
shuffle=True,
seed=self.trainer.current_epoch,
)
sampler = DistributedBatchSampler(sampler, shuffle=True)
sampler.set_epoch(self.current_epoch)
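
A hedged sketch of how this wiring interacts with the ``reload_dataloaders_every_n_epochs=1`` note in the sampler docstring; the import path, dataset, lengths, and bucket count below are placeholders rather than the recipe's actual code:

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader

from dataset import BucketizeBatchSampler, DistributedBatchSampler  # assumed path

class HuBERTSketchModule(pl.LightningModule):
    def __init__(self, dataset, lengths, seconds_per_batch):
        super().__init__()
        self.dataset = dataset
        self.lengths = lengths
        self.seconds_per_batch = seconds_per_batch

    def train_dataloader(self):
        sampler = BucketizeBatchSampler(
            self.lengths,
            num_buckets=1000,
            max_token_count=self.seconds_per_batch * 16000,
            shuffle=True,
            seed=self.trainer.current_epoch,  # re-seeded on every reload
        )
        sampler = DistributedBatchSampler(sampler, shuffle=True)
        sampler.set_epoch(self.trainer.current_epoch)
        return DataLoader(self.dataset, batch_sampler=sampler, num_workers=10)

# Rebuilding the dataloader each epoch is what makes the new seed take effect:
# trainer = pl.Trainer(reload_dataloaders_every_n_epochs=1)
```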
......@@ -508,7 +509,11 @@ class HuBERTFineTuneModule(LightningModule):
dataset = torchaudio.datasets.LibriLightLimited(self.dataset_path, self.subset)
lengths = _get_lengths_librilightlimited(dataset._fileids_paths, dataset._path, dataset._ext_audio)
sampler = BucketizeBatchSampler(
lengths, num_buckets=100, max_token_count=self.seconds_per_batch * 16000, shuffle=True
lengths,
num_buckets=100,
max_token_count=self.seconds_per_batch * 16000,
shuffle=True,
seed=self.global_step,
)
sampler = DistributedBatchSampler(sampler, shuffle=True)
sampler.set_epoch(self.global_step)
......
......@@ -34,6 +34,7 @@ class HuBERTDataModule(LightningDataModule):
min_len=32000,
max_len=250000,
shuffle=True,
seed=self.trainer.current_epoch,
)
sampler = DistributedBatchSampler(sampler, shuffle=self.train_shuffle)
sampler.set_epoch(self.trainer.current_epoch)
......
......@@ -28,6 +28,7 @@ class BucketizeBatchSampler(BatchSampler):
(Default: ``None``)
shuffle (bool, optional): Whether to shuffle buckets for non-monotonic length sampling.
(Default: True)
seed (int, optional): The seed for initializing the RNG. Only used when ``shuffle`` is ``True``. (Default: 0)
drop_last (bool, optional): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``
(Default: False)
......@@ -41,7 +42,7 @@ class BucketizeBatchSampler(BatchSampler):
Note:
if ``shuffle`` is True, it will only shuffle the data once. Please set ``reload_dataloaders_every_n_epochs=1``
in pytorch_lightning Trainer to enable shuffling every epoch.
in pytorch_lightning Trainer and set ``seed`` to ``self.trainer.current_epoch`` to enable shuffling every epoch.
"""
def __init__(
......@@ -53,6 +54,7 @@ class BucketizeBatchSampler(BatchSampler):
max_token_count: Optional[int] = None,
batch_size: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
) -> None:
if max_len is None:
......@@ -78,6 +80,10 @@ class BucketizeBatchSampler(BatchSampler):
self.max_token_count = max_token_count
self.batch_size = batch_size
self.shuffle = shuffle
self.seed = seed
if self.shuffle:
self.g = torch.Generator()
self.g.manual_seed(self.seed)
self.drop_last = drop_last
self.buckets = self._get_buckets(self.lengths, num_buckets, min_len, max_len)
self._update_iter_list()
......@@ -111,7 +117,7 @@ class BucketizeBatchSampler(BatchSampler):
def _update_iter_list(self) -> None:
if self.shuffle:
for k in self.buckets:
self.buckets[k] = self.buckets[k][torch.randperm(self.buckets[k].size(0))]
self.buckets[k] = self.buckets[k][torch.randperm(self.buckets[k].size(0), generator=self.g)]
self.iter_list = []
total_len = 0
batch = []
......@@ -189,7 +195,18 @@ class DistributedBatchSampler(DistributedSampler):
self.epoch = 0
self.seed = seed
self.drop_last = drop_last
if shuffle:
self.shuffle = shuffle
indices = self.batch_sampler.iter_list
if self.drop_last and len(indices) % self.num_replicas != 0:
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil((len(indices) - self.num_replicas) / self.num_replicas)
else:
self.num_samples = math.ceil(len(indices) / self.num_replicas)
def __iter__(self):
if self.shuffle:
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
perm = torch.randperm(len(self.batch_sampler.iter_list), generator=g).tolist()
......@@ -206,7 +223,6 @@ class DistributedBatchSampler(DistributedSampler):
self.subset = indices[self.rank : self.total_size : self.num_replicas]
assert len(self.subset) == self.num_samples
def __iter__(self):
return iter(self.subset)
def __len__(self):
......
......@@ -140,7 +140,7 @@ def run_train(args):
dataset_path=args.dataset_path,
dataset="librispeech",
feature_type="mfcc",
seconds_per_batch=200,
seconds_per_batch=args.seconds_per_batch,
train_shuffle=True,
num_workers=10,
)
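
The corresponding CLI wiring is implied but not shown in this hunk; a hypothetical argparse sketch (the flag name and default here are assumptions, not the recipe's code):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--seconds-per-batch",
    type=float,
    default=200,
    help="Total audio duration (in seconds) packed into one batch.",
)
args = parser.parse_args(["--seconds-per-batch", "87.5"])
assert args.seconds_per_batch == 87.5
```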
......