Commit b6e001f6 authored by Alex Xiao, committed by Facebook GitHub Bot

Fix data loading memory issue in pyspeech

Summary:
We currently shard data when creating the batch iterator. This means we first load all indices/frame lengths/handles into memory, and only then do the sharding. This makes it impossible to train on large datasets with a large number of workers, because each worker must load the entire dataset into memory. For training on a million hours of data (i.e., semi-supervised or unsupervised approaches), this data loading makes it flat-out impossible to use 8 GPUs.

3 changes:

1. This diff modifies the data loading so that we shard while reading the handles file, rather than later. This modification is done on a task-by-task basis, since each task specifies how its data is loaded. I've tried to keep the code compatible with both sharding during handle loading and sharding during batch iteration. So far only the aligned_training task shards during handle loading (a sketch of the idea follows this list).

2. To support sharding at data loading time together with the requirement that all shards produce exactly the same number of batches, I've added a synchronization method: shards with too many batches simply truncate the extras, similar to what we already do (see the synchronization sketch after the CountingIterator hunks below).

3. In fairspeq/train.py, we were actually loading the training dataset and batch iterator twice: once in train.py itself and once when loading the checkpoint (which we always do, regardless of whether a checkpoint exists). That doubles the loading time, which can be painful for very large files. This diff removes the extraneous load as well (see the sketch after the load_checkpoint hunks below).
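
As a rough illustration of change 1, sharding while streaming the handles file could look like the sketch below. This is a minimal sketch only: `read_handles_sharded`, the tab-separated handle/frame-length layout, and the round-robin assignment are illustrative assumptions, not the actual pyspeech code.

```python
def read_handles_sharded(handles_path, shard_id, num_shards):
    """Stream the handles file and keep only this shard's rows, so a
    worker never materializes the full dataset index in memory."""
    handles, frame_lengths = [], []
    with open(handles_path) as f:
        for line_no, line in enumerate(f):
            # Shard assignment is decided while reading, instead of
            # loading everything first and sharding afterwards.
            if line_no % num_shards != shard_id:
                continue
            handle, frame_length = line.rstrip("\n").split("\t")
            handles.append(handle)
            frame_lengths.append(int(frame_length))
    return handles, frame_lengths
```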

Reviewed By: yqwangustc

Differential Revision: D17750715

fbshipit-source-id: 0e6e3d363525fa5661f1c784303390ea13f46377
parent 63b6b3f4
@@ -90,7 +90,7 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
             os.remove(old_chk)
 
-def load_checkpoint(args, trainer):
+def load_checkpoint(args, trainer, data_selector=None):
     """Load a checkpoint and restore the training iterator."""
     # only one worker should attempt to create the required dir
     if args.distributed_rank == 0:
@@ -120,10 +120,10 @@ def load_checkpoint(args, trainer):
     if extra_state is not None and not args.reset_dataloader:
         # restore iterator from checkpoint
         itr_state = extra_state['train_iterator']
-        epoch_itr = trainer.get_train_iterator(epoch=itr_state['epoch'], load_dataset=True)
+        epoch_itr = trainer.get_train_iterator(epoch=itr_state['epoch'], load_dataset=True, data_selector=data_selector)
         epoch_itr.load_state_dict(itr_state)
     else:
-        epoch_itr = trainer.get_train_iterator(epoch=0, load_dataset=True)
+        epoch_itr = trainer.get_train_iterator(epoch=0, load_dataset=True, data_selector=data_selector)
     trainer.lr_step(epoch_itr.epoch)
...
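
For change 3, the corresponding train.py call site (not shown in the hunks above) collapses to a single load path. A simplified sketch, assuming load_checkpoint returns the iterator it builds; the real fairspeq call site may differ:

```python
# Before: the dataset and batch iterator were built once in train.py...
#     epoch_itr = trainer.get_train_iterator(epoch=0, load_dataset=True)
# ...and then rebuilt inside load_checkpoint, doubling the load time.

# After: load_checkpoint is the single place that loads the dataset and
# constructs the iterator, whether or not a checkpoint file exists.
epoch_itr = load_checkpoint(args, trainer, data_selector=None)
```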
@@ -34,6 +34,8 @@ class CountingIterator(object):
     def __iter__(self):
         for x in self.iterable:
+            if self.count >= self.len:
+                return
             self.count += 1
             yield x
@@ -49,6 +51,12 @@ class CountingIterator(object):
         next(itertools.islice(self.itr, num_to_skip, num_to_skip), None)
         return self
 
+    def take(self, n):
+        """
+        Truncates the iterator to n elements at most.
+        """
+        self.len = min(self.len, n)
+
 
 class EpochBatchIterating(object):
     def __len__(self) -> int:
...
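
The new take() backs change 2: after load-time sharding, shards can disagree on batch count, so each worker truncates its iterator to the global minimum. A minimal usage sketch, assuming an initialized torch.distributed process group and that itr is a CountingIterator as above; sync_batch_counts is an illustrative name, not a method added by this diff:

```python
import torch
import torch.distributed as dist

def sync_batch_counts(itr):
    # Smallest number of batches held by any shard. With the NCCL
    # backend the tensor must be moved to the worker's GPU first.
    n = torch.tensor([len(itr)], dtype=torch.long)
    dist.all_reduce(n, op=dist.ReduceOp.MIN)
    # Truncate this shard to the minimum so every worker runs the same
    # number of steps and no rank blocks waiting for missing batches.
    itr.take(n.item())
    return itr
```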
@@ -225,11 +225,16 @@ class Trainer(object):
         return extra_state
 
-    def get_train_iterator(self, epoch, combine=True, load_dataset=True):
+    def get_train_iterator(self, epoch, combine=True, load_dataset=True, data_selector=None):
         """Return an EpochBatchIterator over the training set for a given epoch."""
         if load_dataset:
             print('| loading train data for epoch {}'.format(epoch))
-            self.task.load_dataset(self.args.train_subset, epoch=epoch, combine=combine)
+            self.task.load_dataset(
+                self.args.train_subset,
+                epoch=epoch,
+                combine=combine,
+                data_selector=data_selector,
+            )
         return self.task.get_batch_iterator(
             dataset=self.task.dataset(self.args.train_subset),
             max_tokens=self.args.max_tokens,
...