Commit ffc3bb58 authored by Myle Ott, committed by Facebook Github Bot

Add --reset-dataloader

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/613

Differential Revision: D15541384

Pulled By: myleott

fbshipit-source-id: ef2c0b0a51cdf37af2ccff0546f524d49f87e65d
parent 9770f367
@@ -106,10 +106,15 @@ def load_checkpoint(args, trainer):
         reset_meters=args.reset_meters,
     )
-    if extra_state is not None and 'best' in extra_state and not args.reset_optimizer:
+    if (
+        extra_state is not None
+        and 'best' in extra_state
+        and not args.reset_optimizer
+        and not args.reset_meters
+    ):
         save_checkpoint.best = extra_state['best']
-    if extra_state is not None:
+    if extra_state is not None and not args.reset_dataloader:
         # restore iterator from checkpoint
         itr_state = extra_state['train_iterator']
         epoch_itr = trainer.get_train_iterator(epoch=itr_state['epoch'])
@@ -117,6 +122,8 @@ def load_checkpoint(args, trainer):
     else:
         epoch_itr = trainer.get_train_iterator(epoch=0)
+    trainer.lr_step(epoch_itr.epoch)
     return extra_state, epoch_itr
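Net effect of the two hunks above: 'best' is now restored only when neither --reset-optimizer nor --reset-meters is set, and the saved epoch iterator is reused only when --reset-dataloader is not set; otherwise training restarts from a fresh epoch-0 iterator. A condensed, illustrative sketch of that branch (pick_epoch_iterator is a hypothetical helper; trainer and extra_state are stand-ins, flag and key names follow the diff):

def pick_epoch_iterator(args, trainer, extra_state):
    # reuse the saved dataloader position unless --reset-dataloader was passed
    if extra_state is not None and not args.reset_dataloader:
        itr_state = extra_state['train_iterator']
        epoch_itr = trainer.get_train_iterator(epoch=itr_state['epoch'])
    else:
        # fresh iterator: start again from epoch 0
        epoch_itr = trainer.get_train_iterator(epoch=0)
    trainer.lr_step(epoch_itr.epoch)
    return epoch_itr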
@@ -333,14 +333,16 @@ def add_checkpoint_args(parser):
                        help='path to save checkpoints')
     group.add_argument('--restore-file', default='checkpoint_last.pt',
                        help='filename in save-dir from which to load checkpoint')
-    group.add_argument('--reset-optimizer', action='store_true',
-                       help='if set, does not load optimizer state from the checkpoint')
+    group.add_argument('--reset-dataloader', action='store_true',
+                       help='if set, does not reload dataloader state from the checkpoint')
     group.add_argument('--reset-lr-scheduler', action='store_true',
                        help='if set, does not load lr scheduler state from the checkpoint')
-    group.add_argument('--optimizer-overrides', default="{}", type=str, metavar='DICT',
-                       help='a dictionary used to override optimizer args when loading a checkpoint')
     group.add_argument('--reset-meters', action='store_true',
                        help='if set, does not load meters from the checkpoint')
+    group.add_argument('--reset-optimizer', action='store_true',
+                       help='if set, does not load optimizer state from the checkpoint')
+    group.add_argument('--optimizer-overrides', default="{}", type=str, metavar='DICT',
+                       help='a dictionary used to override optimizer args when loading a checkpoint')
     group.add_argument('--save-interval', type=int, default=1, metavar='N',
                        help='save a checkpoint every N epochs')
     group.add_argument('--save-interval-updates', type=int, default=0, metavar='N',
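For context, a self-contained sketch of how the reordered checkpoint flags parse. The parser below is a minimal stand-in for add_checkpoint_args, limited to the flags shown in this diff, and the 'pretrained.pt' restore file is hypothetical:

import argparse

# minimal stand-in for the checkpoint argument group (only the flags in this diff)
parser = argparse.ArgumentParser()
group = parser.add_argument_group('checkpoint')
group.add_argument('--restore-file', default='checkpoint_last.pt',
                   help='filename in save-dir from which to load checkpoint')
group.add_argument('--reset-dataloader', action='store_true',
                   help='if set, does not reload dataloader state from the checkpoint')
group.add_argument('--reset-lr-scheduler', action='store_true',
                   help='if set, does not load lr scheduler state from the checkpoint')
group.add_argument('--reset-meters', action='store_true',
                   help='if set, does not load meters from the checkpoint')
group.add_argument('--reset-optimizer', action='store_true',
                   help='if set, does not load optimizer state from the checkpoint')

# e.g. warm-starting from a pretrained checkpoint while discarding training state
args = parser.parse_args([
    '--restore-file', 'pretrained.pt',
    '--reset-optimizer', '--reset-lr-scheduler',
    '--reset-meters', '--reset-dataloader',
])
assert args.reset_dataloader and args.reset_optimizer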
@@ -125,7 +125,7 @@ class Trainer(object):
             extra_state['train_meters'] = self.meters
             checkpoint_utils.save_state(
                 filename, self.args, self.get_model().state_dict(), self.criterion,
-                self.optimizer, self.lr_scheduler, self._num_updates,
+                self.optimizer, self.lr_scheduler, self.get_num_updates(),
                 self._optim_history, extra_state,
             )
@@ -171,7 +171,7 @@ class Trainer(object):
                 self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
             self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)
-            self._num_updates = last_optim['num_updates']
+            self.set_num_updates(last_optim['num_updates'])
         if extra_state is not None:
             epoch = extra_state['train_iterator']['epoch']
@@ -179,7 +179,6 @@ class Trainer(object):
                 filename, epoch, self.get_num_updates()))
             self.lr_step(epoch)
-            self.lr_step_update(self.get_num_updates())
             if 'train_meters' in extra_state:
                 self.meters.update(extra_state['train_meters'])
@@ -328,10 +327,7 @@ class Trainer(object):
             # take an optimization step
             self.optimizer.step()
-            self._num_updates += 1
-            # update learning rate
-            self.lr_scheduler.step_update(self._num_updates)
+            self.set_num_updates(self.get_num_updates() + 1)
             # task specific update per step
             self.task.update_step(self._num_updates)
@@ -449,11 +445,13 @@ class Trainer(object):
     def lr_step(self, epoch, val_loss=None):
         """Adjust the learning rate based on the validation loss."""
-        return self.lr_scheduler.step(epoch, val_loss)
+        _lr = self.lr_scheduler.step(epoch, val_loss)
+        # prefer updating the LR based on the number of steps
+        return self.lr_step_update()
-    def lr_step_update(self, num_updates):
+    def lr_step_update(self):
         """Update the learning rate after each update."""
-        return self.lr_scheduler.step_update(num_updates)
+        return self.lr_scheduler.step_update(self.get_num_updates())
     def get_lr(self):
         """Get the current learning rate."""
@@ -473,6 +471,11 @@ class Trainer(object):
         """Get the number of parameters updates."""
         return self._num_updates
+    def set_num_updates(self, num_updates):
+        """Set the number of parameters updates."""
+        self._num_updates = num_updates
+        self.lr_step_update()
     def _prepare_sample(self, sample):
         if sample is None or len(sample) == 0:
             return None
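The Trainer hunks above route every write to _num_updates through the new set_num_updates, which also refreshes the learning-rate scheduler, so the per-step LR update can no longer be forgotten by a caller. A minimal, runnable sketch of that pattern (DummyScheduler and TrainerSketch are illustrative stand-ins; the method names and the call made after optimizer.step() follow the diff):

class DummyScheduler:
    def step_update(self, num_updates):
        # a real scheduler would recompute the learning rate here
        return 0.001

class TrainerSketch:
    def __init__(self):
        self.lr_scheduler = DummyScheduler()
        self._num_updates = 0

    def get_num_updates(self):
        return self._num_updates

    def set_num_updates(self, num_updates):
        # single entry point: bump the counter and keep the scheduler in sync
        self._num_updates = num_updates
        self.lr_step_update()

    def lr_step_update(self):
        return self.lr_scheduler.step_update(self.get_num_updates())

trainer = TrainerSketch()
trainer.set_num_updates(trainer.get_num_updates() + 1)  # as done after optimizer.step()
assert trainer.get_num_updates() == 1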
@@ -56,6 +56,9 @@ class TestLoadCheckpoint(unittest.TestCase):
     def setUp(self):
         self.args_mock = MagicMock()
         self.args_mock.optimizer_overrides = '{}'
+        self.args_mock.reset_dataloader = False
+        self.args_mock.reset_meters = False
+        self.args_mock.reset_optimizer = False
         self.patches = {
             'os.makedirs': MagicMock(),
             'os.path.join': MagicMock(),