"vscode:/vscode.git/clone" did not exist on "a353c46ec094d35e0d9863db1c9d41de25744a88"
Commit 99fbd317 authored by Alex Xiao, committed by Facebook Github Bot

Refactor data sharding to be specified by the caller of a task rather than by the task itself

Summary: Modifying the number of shards internally to disable data sharding for batch iteration is dangerous, because the callers of these tasks are not limited to fairspeq/train. Therefore the onus of sharding data properly should be on the caller rather than on the task itself.
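
To make the intent concrete, here is a minimal caller-side sketch (only trainer.get_train_iterator and shard_batch_itr come from this change; the surrounding setup is assumed): instead of a task silently forcing the shard count to 1, a caller that wants an unsharded iterator now says so explicitly.

# Hedged sketch, not code from this commit: `trainer` is assumed to be an
# already-constructed fairseq Trainer.
epoch_itr = trainer.get_train_iterator(
    epoch=0,
    load_dataset=True,
    shard_batch_itr=False,  # caller opts out of sharding; every worker sees all batches
)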

Reviewed By: myleott

Differential Revision: D18456424

fbshipit-source-id: d46be16c441c50082f9a768d0b259e6c28a4b67b
parent 51eb9802
@@ -94,8 +94,13 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
                 os.remove(old_chk)
 
 
-def load_checkpoint(args, trainer, data_selector=None):
-    """Load a checkpoint and restore the training iterator."""
+def load_checkpoint(args, trainer, **passthrough_args):
+    """
+    Load a checkpoint and restore the training iterator.
+
+    *passthrough_args* will be passed through to
+    ``trainer.get_train_iterator``.
+    """
     # only one worker should attempt to create the required dir
     if args.distributed_rank == 0:
         os.makedirs(args.save_dir, exist_ok=True)
@@ -124,10 +129,18 @@ def load_checkpoint(args, trainer, data_selector=None):
     if extra_state is not None and not args.reset_dataloader:
         # restore iterator from checkpoint
         itr_state = extra_state['train_iterator']
-        epoch_itr = trainer.get_train_iterator(epoch=itr_state['epoch'], load_dataset=True, data_selector=data_selector)
+        epoch_itr = trainer.get_train_iterator(
+            epoch=itr_state['epoch'],
+            load_dataset=True,
+            **passthrough_args
+        )
         epoch_itr.load_state_dict(itr_state)
     else:
-        epoch_itr = trainer.get_train_iterator(epoch=0, load_dataset=True, data_selector=data_selector)
+        epoch_itr = trainer.get_train_iterator(
+            epoch=0,
+            load_dataset=True,
+            **passthrough_args
+        )
 
     trainer.lr_step(epoch_itr.epoch)
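
A hedged usage sketch of the new load_checkpoint signature (argument values are illustrative; only the names load_checkpoint, data_selector, and shard_batch_itr appear in this diff): any keyword argument that load_checkpoint does not consume itself is forwarded verbatim to trainer.get_train_iterator.

# Hedged sketch: `args` and `trainer` are assumed to be set up as in a fairseq
# training script; handling of the return value is omitted for brevity.
load_checkpoint(
    args, trainer,
    data_selector=None,      # still supported, now routed via **passthrough_args
    shard_batch_itr=False,   # new knob, forwarded to trainer.get_train_iterator
)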
@@ -231,7 +231,7 @@ class Trainer(object):
         return extra_state
 
-    def get_train_iterator(self, epoch, combine=True, load_dataset=True, data_selector=None):
+    def get_train_iterator(self, epoch, combine=True, load_dataset=True, data_selector=None, shard_batch_itr=True):
         """Return an EpochBatchIterator over the training set for a given epoch."""
         if load_dataset:
             print('| loading train data for epoch {}'.format(epoch))
@@ -252,8 +252,8 @@ class Trainer(object):
             ignore_invalid_inputs=True,
             required_batch_size_multiple=self.args.required_batch_size_multiple,
             seed=self.args.seed,
-            num_shards=self.args.distributed_world_size,
-            shard_id=self.args.distributed_rank,
+            num_shards=self.args.distributed_world_size if shard_batch_itr else 1,
+            shard_id=self.args.distributed_rank if shard_batch_itr else 0,
             num_workers=self.args.num_workers,
             epoch=epoch,
         )
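
For intuition, a small self-contained sketch of what the two settings mean for batch iteration (this mimics round-robin sharding and is not the actual fairseq iterator code): with sharding enabled, the worker with a given shard_id sees every num_shards-th batch, while shard_batch_itr=False (num_shards=1, shard_id=0) makes every worker iterate over all batches.

# Hedged, standalone illustration of batch-iterator sharding (not fairseq code).
def shard(batches, num_shards, shard_id):
    # round-robin assignment of batches to shards
    return batches[shard_id::num_shards]

batches = list(range(8))
print(shard(batches, num_shards=4, shard_id=1))  # [1, 5] -> each worker gets a disjoint slice
print(shard(batches, num_shards=1, shard_id=0))  # [0, 1, 2, 3, 4, 5, 6, 7] -> every worker sees everything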