Commit ffc3bb58 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Add --reset-dataloader

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/613

Differential Revision: D15541384

Pulled By: myleott

fbshipit-source-id: ef2c0b0a51cdf37af2ccff0546f524d49f87e65d
parent 9770f367
...@@ -106,10 +106,15 @@ def load_checkpoint(args, trainer): ...@@ -106,10 +106,15 @@ def load_checkpoint(args, trainer):
reset_meters=args.reset_meters, reset_meters=args.reset_meters,
) )
if extra_state is not None and 'best' in extra_state and not args.reset_optimizer: if (
extra_state is not None
and 'best' in extra_state
and not args.reset_optimizer
and not args.reset_meters
):
save_checkpoint.best = extra_state['best'] save_checkpoint.best = extra_state['best']
if extra_state is not None: if extra_state is not None and not args.reset_dataloader:
# restore iterator from checkpoint # restore iterator from checkpoint
itr_state = extra_state['train_iterator'] itr_state = extra_state['train_iterator']
epoch_itr = trainer.get_train_iterator(epoch=itr_state['epoch']) epoch_itr = trainer.get_train_iterator(epoch=itr_state['epoch'])
...@@ -117,6 +122,8 @@ def load_checkpoint(args, trainer): ...@@ -117,6 +122,8 @@ def load_checkpoint(args, trainer):
else: else:
epoch_itr = trainer.get_train_iterator(epoch=0) epoch_itr = trainer.get_train_iterator(epoch=0)
trainer.lr_step(epoch_itr.epoch)
return extra_state, epoch_itr return extra_state, epoch_itr
......
...@@ -333,14 +333,16 @@ def add_checkpoint_args(parser): ...@@ -333,14 +333,16 @@ def add_checkpoint_args(parser):
help='path to save checkpoints') help='path to save checkpoints')
group.add_argument('--restore-file', default='checkpoint_last.pt', group.add_argument('--restore-file', default='checkpoint_last.pt',
help='filename in save-dir from which to load checkpoint') help='filename in save-dir from which to load checkpoint')
group.add_argument('--reset-optimizer', action='store_true', group.add_argument('--reset-dataloader', action='store_true',
help='if set, does not load optimizer state from the checkpoint') help='if set, does not reload dataloader state from the checkpoint')
group.add_argument('--reset-lr-scheduler', action='store_true', group.add_argument('--reset-lr-scheduler', action='store_true',
help='if set, does not load lr scheduler state from the checkpoint') help='if set, does not load lr scheduler state from the checkpoint')
group.add_argument('--optimizer-overrides', default="{}", type=str, metavar='DICT',
help='a dictionary used to override optimizer args when loading a checkpoint')
group.add_argument('--reset-meters', action='store_true', group.add_argument('--reset-meters', action='store_true',
help='if set, does not load meters from the checkpoint') help='if set, does not load meters from the checkpoint')
group.add_argument('--reset-optimizer', action='store_true',
help='if set, does not load optimizer state from the checkpoint')
group.add_argument('--optimizer-overrides', default="{}", type=str, metavar='DICT',
help='a dictionary used to override optimizer args when loading a checkpoint')
group.add_argument('--save-interval', type=int, default=1, metavar='N', group.add_argument('--save-interval', type=int, default=1, metavar='N',
help='save a checkpoint every N epochs') help='save a checkpoint every N epochs')
group.add_argument('--save-interval-updates', type=int, default=0, metavar='N', group.add_argument('--save-interval-updates', type=int, default=0, metavar='N',
......
...@@ -125,7 +125,7 @@ class Trainer(object): ...@@ -125,7 +125,7 @@ class Trainer(object):
extra_state['train_meters'] = self.meters extra_state['train_meters'] = self.meters
checkpoint_utils.save_state( checkpoint_utils.save_state(
filename, self.args, self.get_model().state_dict(), self.criterion, filename, self.args, self.get_model().state_dict(), self.criterion,
self.optimizer, self.lr_scheduler, self._num_updates, self.optimizer, self.lr_scheduler, self.get_num_updates(),
self._optim_history, extra_state, self._optim_history, extra_state,
) )
...@@ -171,7 +171,7 @@ class Trainer(object): ...@@ -171,7 +171,7 @@ class Trainer(object):
self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state']) self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
self.optimizer.load_state_dict(last_optim_state, optimizer_overrides) self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)
self._num_updates = last_optim['num_updates'] self.set_num_updates(last_optim['num_updates'])
if extra_state is not None: if extra_state is not None:
epoch = extra_state['train_iterator']['epoch'] epoch = extra_state['train_iterator']['epoch']
...@@ -179,7 +179,6 @@ class Trainer(object): ...@@ -179,7 +179,6 @@ class Trainer(object):
filename, epoch, self.get_num_updates())) filename, epoch, self.get_num_updates()))
self.lr_step(epoch) self.lr_step(epoch)
self.lr_step_update(self.get_num_updates())
if 'train_meters' in extra_state: if 'train_meters' in extra_state:
self.meters.update(extra_state['train_meters']) self.meters.update(extra_state['train_meters'])
...@@ -328,10 +327,7 @@ class Trainer(object): ...@@ -328,10 +327,7 @@ class Trainer(object):
# take an optimization step # take an optimization step
self.optimizer.step() self.optimizer.step()
self._num_updates += 1 self.set_num_updates(self.get_num_updates() + 1)
# update learning rate
self.lr_scheduler.step_update(self._num_updates)
# task specific update per step # task specific update per step
self.task.update_step(self._num_updates) self.task.update_step(self._num_updates)
...@@ -449,11 +445,13 @@ class Trainer(object): ...@@ -449,11 +445,13 @@ class Trainer(object):
def lr_step(self, epoch, val_loss=None): def lr_step(self, epoch, val_loss=None):
"""Adjust the learning rate based on the validation loss.""" """Adjust the learning rate based on the validation loss."""
return self.lr_scheduler.step(epoch, val_loss) _lr = self.lr_scheduler.step(epoch, val_loss)
# prefer updating the LR based on the number of steps
return self.lr_step_update()
def lr_step_update(self, num_updates): def lr_step_update(self):
"""Update the learning rate after each update.""" """Update the learning rate after each update."""
return self.lr_scheduler.step_update(num_updates) return self.lr_scheduler.step_update(self.get_num_updates())
def get_lr(self): def get_lr(self):
"""Get the current learning rate.""" """Get the current learning rate."""
...@@ -473,6 +471,11 @@ class Trainer(object): ...@@ -473,6 +471,11 @@ class Trainer(object):
"""Get the number of parameters updates.""" """Get the number of parameters updates."""
return self._num_updates return self._num_updates
def set_num_updates(self, num_updates):
"""Set the number of parameters updates."""
self._num_updates = num_updates
self.lr_step_update()
def _prepare_sample(self, sample): def _prepare_sample(self, sample):
if sample is None or len(sample) == 0: if sample is None or len(sample) == 0:
return None return None
......
...@@ -56,6 +56,9 @@ class TestLoadCheckpoint(unittest.TestCase): ...@@ -56,6 +56,9 @@ class TestLoadCheckpoint(unittest.TestCase):
def setUp(self): def setUp(self):
self.args_mock = MagicMock() self.args_mock = MagicMock()
self.args_mock.optimizer_overrides = '{}' self.args_mock.optimizer_overrides = '{}'
self.args_mock.reset_dataloader = False
self.args_mock.reset_meters = False
self.args_mock.reset_optimizer = False
self.patches = { self.patches = {
'os.makedirs': MagicMock(), 'os.makedirs': MagicMock(),
'os.path.join': MagicMock(), 'os.path.join': MagicMock(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment