Commit 3bfbb49b authored by Myle Ott, committed by Facebook Github Bot

Clean up sharded train iterator

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/586

Differential Revision: D15372949

Pulled By: myleott

fbshipit-source-id: c1cf1c645e8d55fc8568f23a47c45677ac9ab1da
parent fca32e05
...
@@ -87,9 +87,9 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
             os.remove(old_chk)


-def load_checkpoint(args, trainer, epoch_itr, max_positions, task):
-    """Load a checkpoint and replay dataloader to match."""
-    # Only rank 0 should attempt to create the required dir
+def load_checkpoint(args, trainer):
+    """Load a checkpoint and restore the training iterator."""
+    # only one worker should attempt to create the required dir
     if args.distributed_rank == 0:
         os.makedirs(args.save_dir, exist_ok=True)
...
@@ -97,32 +97,26 @@ def load_checkpoint(args, trainer, epoch_itr, max_positions, task):
         checkpoint_path = args.restore_file
     else:
         checkpoint_path = os.path.join(args.save_dir, args.restore_file)
-    if os.path.isfile(checkpoint_path):
-        extra_state = trainer.load_checkpoint(checkpoint_path, args.reset_optimizer, args.reset_lr_scheduler,
-                                              eval(args.optimizer_overrides))
-        if extra_state is not None:
-            # replay train iterator to match checkpoint
-            epoch_itr_state = extra_state['train_iterator']
-
-            # If the loaded checkpoint is not at epoch 0, reload train dataset,
-            # as it could be potentially sharded.
-            if epoch_itr_state['epoch'] != 0:
-                epoch_itr = reload_train(args, epoch_itr, max_positions, task)
-
-            epoch_itr.load_state_dict(epoch_itr_state)
-
-            print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
-                checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()))
-            trainer.lr_step(epoch_itr.epoch)
-            trainer.lr_step_update(trainer.get_num_updates())
-            if 'best' in extra_state and not args.reset_optimizer:
-                save_checkpoint.best = extra_state['best']
-        return True
-    else:
-        print('| no existing checkpoint found {}'.format(checkpoint_path))
-        return False
+
+    extra_state = trainer.load_checkpoint(
+        checkpoint_path,
+        args.reset_optimizer,
+        args.reset_lr_scheduler,
+        eval(args.optimizer_overrides),
+    )
+
+    if extra_state is not None and 'best' in extra_state and not args.reset_optimizer:
+        save_checkpoint.best = extra_state['best']
+
+    if extra_state is not None:
+        # restore iterator from checkpoint
+        itr_state = extra_state['train_iterator']
+        epoch_itr = trainer.get_train_iterator(epoch=itr_state['epoch'])
+        epoch_itr.load_state_dict(itr_state)
+    else:
+        epoch_itr = trainer.get_train_iterator(epoch=0)
+
+    return extra_state, epoch_itr


 def load_checkpoint_to_cpu(path):
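With the new signature, callers no longer build an epoch iterator up front and pass it in; `load_checkpoint(args, trainer)` now returns the restored extra state together with the iterator. A minimal sketch of the new call site, mirroring the `train.py` hunk further below (construction of `args`, `task`, and `trainer` is elided):

    # assumes args and trainer were already set up as in train.py's main()
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)
    # extra_state is None when no checkpoint was found; in that case epoch_itr
    # is a fresh iterator for epoch 0, otherwise it resumes from the epoch
    # (and in-epoch position) recorded in the checkpoint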
...
@@ -165,28 +159,6 @@ def load_model_ensemble(filenames, arg_overrides=None, task=None):
     return ensemble, args


-def reload_train(args, epoch_itr, max_positions, task):
-    # nothing needs to be done when the dataset is not sharded.
-    if "data" not in args or ("data" in args and len(args.data.split(":")) == 1):
-        return epoch_itr
-    print("| Reloading shard of train data at epoch: ", epoch_itr.epoch)
-    task.load_dataset(args.train_subset, combine=True, epoch=epoch_itr.epoch)
-    epoch_itr = task.get_batch_iterator(
-        dataset=task.dataset(args.train_subset),
-        max_tokens=args.max_tokens,
-        max_sentences=args.max_sentences,
-        max_positions=max_positions,
-        ignore_invalid_inputs=True,
-        required_batch_size_multiple=args.required_batch_size_multiple,
-        seed=args.seed,
-        num_shards=args.distributed_world_size,
-        shard_id=args.distributed_rank,
-        num_workers=args.num_workers,
-        epoch=epoch_itr.epoch,
-    )
-    return epoch_itr
-
-
 def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
     """Retrieves all checkpoints found in `path` directory.
...
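For context on why `reload_train` existed at all: a sharded training set is given as colon-separated data directories, and only then does anything need to be reloaded between epochs (the same `':' in args.data` test survives in the `train.py` hunk below). A toy illustration of that convention, with hypothetical paths:

    # hypothetical shard directories; a single path means the data is unsharded
    data = '/data/bin/shard0:/data/bin/shard1:/data/bin/shard2'
    is_sharded = len(data.split(':')) > 1  # True here, so a new shard is loaded each epoch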
...
@@ -76,7 +76,8 @@ class EpochBatchIterator(object):
         num_workers (int, optional): how many subprocesses to use for data
             loading. 0 means the data will be loaded in the main process
             (default: 0).
-        epoch (int, optional): The epoch to start the iterator from.
+        epoch (int, optional): the epoch to start the iterator from
+            (default: 0).
     """

     def __init__(
...
...
@@ -118,7 +118,8 @@ class FairseqTask(object):
             num_workers (int, optional): how many subprocesses to use for data
                 loading. 0 means the data will be loaded in the main process
                 (default: 0).
-            epoch (int, optional): The epoch to start the iterator from.
+            epoch (int, optional): the epoch to start the iterator from
+                (default: 0).

         Returns:
             ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
...
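The `epoch` argument is what lets a task map an epoch onto a particular data shard. As a hedged illustration only (not code from this commit), a task whose `data` argument is a colon-separated list of directories could pick its shard round-robin by epoch:

    # illustrative sketch: choose a shard directory from the epoch index
    paths = args.data.split(':')
    shard_path = paths[epoch % len(paths)]  # this epoch's train split is loaded from here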
...
@@ -124,8 +124,9 @@ class Trainer(object):
         if distributed_utils.is_master(self.args):  # only save one checkpoint
             extra_state['train_meters'] = self.meters
             checkpoint_utils.save_state(
-                filename, self.args, self.get_model().state_dict(), self.criterion, self.optimizer,
-                self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
+                filename, self.args, self.get_model().state_dict(), self.criterion,
+                self.optimizer, self.lr_scheduler, self._num_updates,
+                self._optim_history, extra_state,
             )

     def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
...
@@ -165,17 +166,48 @@ class Trainer(object):
             self._num_updates = last_optim['num_updates']

-        if extra_state is not None and 'train_meters' in extra_state:
-            self.meters.update(extra_state['train_meters'])
-            del extra_state['train_meters']
-
-            # reset TimeMeters, since their start times don't make sense anymore
-            for meter in self.meters.values():
-                if isinstance(meter, TimeMeter):
-                    meter.reset()
+        if extra_state is not None:
+            epoch = extra_state['train_iterator']['epoch']
+            print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
+                filename, epoch, self.get_num_updates()))
+
+            self.lr_step(epoch)
+            self.lr_step_update(self.get_num_updates())
+
+            if 'train_meters' in extra_state:
+                self.meters.update(extra_state['train_meters'])
+                del extra_state['train_meters']
+
+                # reset TimeMeters, since their start times don't make sense anymore
+                for meter in self.meters.values():
+                    if isinstance(meter, TimeMeter):
+                        meter.reset()
+        else:
+            print('| no existing checkpoint found {}'.format(filename))

         return extra_state

+    def get_train_iterator(self, epoch, combine=True):
+        """Return an EpochBatchIterator over the training set for a given epoch."""
+        print('| loading train data for epoch {}'.format(epoch))
+        self.task.load_dataset(self.args.train_subset, epoch=epoch, combine=combine)
+        return self.task.get_batch_iterator(
+            dataset=self.task.dataset(self.args.train_subset),
+            max_tokens=self.args.max_tokens,
+            max_sentences=self.args.max_sentences,
+            max_positions=utils.resolve_max_positions(
+                self.task.max_positions(),
+                self.model.max_positions(),
+            ),
+            ignore_invalid_inputs=True,
+            required_batch_size_multiple=self.args.required_batch_size_multiple,
+            seed=self.args.seed,
+            num_shards=self.args.distributed_world_size,
+            shard_id=self.args.distributed_rank,
+            num_workers=self.args.num_workers,
+            epoch=epoch,
+        )
+
     def train_step(self, samples, dummy_batch=False, raise_oom=False):
         """Do forward, backward and parameter update."""
         if self._dummy_batch is None:
...
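`Trainer.get_train_iterator` together with the iterator's `load_state_dict` replaces the old replay logic in `checkpoint_utils`. A short sketch of how a run resumes, assuming `itr_state` is the dict stored under `extra_state['train_iterator']` in the checkpoint (this mirrors the checkpoint_utils hunk above and the tests below):

    # rebuild the iterator for the checkpointed epoch, then restore its position
    epoch_itr = trainer.get_train_iterator(epoch=itr_state['epoch'])
    epoch_itr.load_state_dict(itr_state)
    itr = epoch_itr.next_epoch_itr(shuffle=False)  # continues mid-epoch if needed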
...
@@ -69,10 +69,9 @@ class TestLoadCheckpoint(unittest.TestCase):
         with contextlib.redirect_stdout(StringIO()):
             trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
+            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

-            with patch('fairseq.checkpoint_utils.reload_train', return_value=epoch_itr):
-                checkpoint_utils.load_checkpoint(
-                    self.args_mock, trainer, epoch_itr, 512, None)
+            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)

             self.assertEqual(epoch_itr.epoch, 2)
             self.assertEqual(epoch_itr.iterations_in_epoch, 50)
...
@@ -99,10 +98,9 @@ class TestLoadCheckpoint(unittest.TestCase):
     def test_load_full_checkpoint(self):
         with contextlib.redirect_stdout(StringIO()):
             trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
+            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

-            with patch('fairseq.checkpoint_utils.reload_train', return_value=epoch_itr):
-                checkpoint_utils.load_checkpoint(
-                    self.args_mock, trainer, epoch_itr, 512, None)
+            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)

             itr = epoch_itr.next_epoch_itr(shuffle=False)
             self.assertEqual(epoch_itr.epoch, 3)
...
@@ -112,9 +110,10 @@ class TestLoadCheckpoint(unittest.TestCase):
     def test_load_no_checkpoint(self):
         with contextlib.redirect_stdout(StringIO()):
             trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0)
+            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
             self.patches['os.path.isfile'].return_value = False
-            checkpoint_utils.load_checkpoint(self.args_mock, trainer, epoch_itr, 512, None)
+            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)

             itr = epoch_itr.next_epoch_itr(shuffle=False)
             self.assertEqual(epoch_itr.epoch, 1)
...
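Because the iterator is now built through the trainer, the tests no longer patch `fairseq.checkpoint_utils.reload_train`; they stub `trainer.get_train_iterator` instead. A minimal sketch of that pattern, assuming `trainer`, `epoch_itr`, and `args_mock` are the suite's existing fixtures:

    from unittest.mock import MagicMock
    from fairseq import checkpoint_utils

    trainer.get_train_iterator = MagicMock(return_value=epoch_itr)  # skip real data loading
    _, epoch_itr = checkpoint_utils.load_checkpoint(args_mock, trainer)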
...
@@ -64,27 +64,9 @@ def main(args, init_distributed=False):
         args.max_sentences,
     ))

-    max_positions = utils.resolve_max_positions(
-        task.max_positions(),
-        model.max_positions(),
-    )
-
-    # Initialize dataloader
-    epoch_itr = task.get_batch_iterator(
-        dataset=task.dataset(args.train_subset),
-        max_tokens=args.max_tokens,
-        max_sentences=args.max_sentences,
-        max_positions=max_positions,
-        ignore_invalid_inputs=True,
-        required_batch_size_multiple=args.required_batch_size_multiple,
-        seed=args.seed,
-        num_shards=args.distributed_world_size,
-        shard_id=args.distributed_rank,
-        num_workers=args.num_workers,
-    )
-
-    # Load the latest checkpoint if one is available
-    checkpoint_utils.load_checkpoint(
-        args, trainer, epoch_itr, max_positions, task)
+    # Load the latest checkpoint if one is available and restore the
+    # corresponding train iterator
+    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

     # Train until the learning rate gets too small
     max_epoch = args.max_epoch or math.inf
...
@@ -106,10 +88,11 @@ def main(args, init_distributed=False):
         # save checkpoint
         if epoch_itr.epoch % args.save_interval == 0:
-            checkpoint_utils.save_checkpoint(
-                args, trainer, epoch_itr, valid_losses[0])
+            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

-        epoch_itr = checkpoint_utils.reload_train(args, epoch_itr, max_positions, task)
+        if ':' in args.data:
+            # sharded data: get train iterator for next epoch
+            epoch_itr = trainer.get_train_iterator(epoch_itr.epoch)
+
     train_meter.stop()
     print('| done training in {:.1f} seconds'.format(train_meter.sum))
...