Commit 2dc074d8 authored by alexeib, committed by Myle Ott

add flag that allows keeping optimizer config

adds --reset-optimizer, --reset-lr-scheduler, and --optimizer-overrides flags
parent 6e3685ad
```diff
@@ -81,9 +81,9 @@ class FP16Trainer(Trainer):
         extra_state['loss_scale'] = self.scaler.loss_scale
         super().save_checkpoint(filename, extra_state)
 
-    def load_checkpoint(self, filename):
+    def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
         """Load all training state from a checkpoint file."""
-        extra_state = super().load_checkpoint(filename)
+        extra_state = super().load_checkpoint(filename, reset_optimizer, reset_lr_scheduler, optimizer_overrides)
         if extra_state is not None and 'loss_scale' in extra_state:
             self.scaler.loss_scale = extra_state['loss_scale']
         return extra_state
```
```diff
@@ -52,7 +52,7 @@ class FairseqOptimizer(object):
         """Return the optimizer's state dict."""
         return self.optimizer.state_dict()
 
-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict, optimizer_overrides=None):
         """Load an optimizer state dict.
 
         In general we should prefer the configuration of the existing optimizer
@@ -62,9 +62,10 @@ class FairseqOptimizer(object):
         """
         self.optimizer.load_state_dict(state_dict)
 
-        # override learning rate, momentum, etc. with latest values
-        for group in self.optimizer.param_groups:
-            group.update(self.optimizer_config)
+        if optimizer_overrides is not None and len(optimizer_overrides) > 0:
+            # override learning rate, momentum, etc. with latest values
+            for group in self.optimizer.param_groups:
+                group.update(optimizer_overrides)
 
     def step(self, closure=None):
         """Performs a single optimization step."""
```
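For reference, a minimal standalone sketch (plain PyTorch, not fairseq code) of what `group.update(optimizer_overrides)` does to an optimizer's `param_groups`; the override dict here is only illustrative:

```python
# Minimal sketch, plain PyTorch: each entry of optimizer.param_groups is a dict
# holding hyperparameters such as 'lr' and 'momentum', so dict.update() with an
# overrides dict replaces only the listed keys and leaves the rest untouched.
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

overrides = {'lr': 0.01}  # e.g. the parsed value of --optimizer-overrides
if overrides is not None and len(overrides) > 0:  # mirrors the check above
    for group in optimizer.param_groups:
        group.update(overrides)

print(optimizer.param_groups[0]['lr'])        # 0.01 (overridden)
print(optimizer.param_groups[0]['momentum'])  # 0.9 (kept from the loaded state)
```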
```diff
@@ -228,6 +228,12 @@ def add_checkpoint_args(parser):
                        help='path to save checkpoints')
     group.add_argument('--restore-file', default='checkpoint_last.pt',
                        help='filename in save-dir from which to load checkpoint')
+    group.add_argument('--reset-optimizer', action='store_true',
+                       help='if set, does not load optimizer state from the checkpoint')
+    group.add_argument('--reset-lr-scheduler', action='store_true',
+                       help='if set, does not load lr scheduler state from the checkpoint')
+    group.add_argument('--optimizer-overrides', default="{}", type=str, metavar='DICT',
+                       help='a dictionary used to override optimizer args when loading a checkpoint')
     group.add_argument('--save-interval', type=int, default=1, metavar='N',
                        help='save a checkpoint every N epochs')
     group.add_argument('--save-interval-updates', type=int, default=0, metavar='N',
```
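A minimal argparse sketch (standalone, covering only these three arguments rather than the full fairseq option set) showing how the new flags parse; the example values are hypothetical:

```python
# Standalone sketch of the three new checkpoint flags; example values are hypothetical.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--reset-optimizer', action='store_true')
parser.add_argument('--reset-lr-scheduler', action='store_true')
parser.add_argument('--optimizer-overrides', default="{}", type=str, metavar='DICT')

args = parser.parse_args(['--reset-lr-scheduler', '--optimizer-overrides', "{'lr': 0.0001}"])
print(args.reset_optimizer)            # False (flag not passed)
print(args.reset_lr_scheduler)         # True
print(eval(args.optimizer_overrides))  # {'lr': 0.0001}, evaluated the same way train.py does below
```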
```diff
@@ -80,23 +80,28 @@ class Trainer(object):
             self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
         )
 
-    def load_checkpoint(self, filename, load_optim=True):
+    def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
         """Load all training state from a checkpoint file."""
-        extra_state, optim_history, last_optim_state = \
+        extra_state, self._optim_history, last_optim_state = \
             utils.load_model_state(filename, self.model)
 
-        if last_optim_state is not None:
+        if last_optim_state is not None and not reset_optimizer:
             # rebuild optimizer after loading model, since params may have changed
             self._build_optimizer()
 
-            if load_optim:
-                self._optim_history = optim_history
-                # only reload optimizer and lr_scheduler if they match
-                last_optim = self._optim_history[-1]
-                if last_optim['criterion_name'] == self.criterion.__class__.__name__:
-                    self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
-                    if last_optim['optimizer_name'] == self.optimizer.__class__.__name__:
-                        self.optimizer.load_state_dict(last_optim_state)
-                self._num_updates = last_optim['num_updates']
+            # only reload optimizer and lr_scheduler if they match
+            last_optim = self._optim_history[-1]
+            assert last_optim['criterion_name'] == self.criterion.__class__.__name__, \
+                'criterion does not match; please reset the optimizer (--reset-optimizer)'
+            assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
+                'optimizer does not match; please reset the optimizer (--reset-optimizer)'
+
+            if not reset_lr_scheduler:
+                self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
+
+            self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)
+
+            self._num_updates = last_optim['num_updates']
```
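Note the behavioral change: a checkpoint whose recorded criterion or optimizer no longer matches the current run now fails loudly instead of silently skipping the optimizer state. A small sketch with hypothetical history values:

```python
# Sketch with hypothetical values: resuming under a different criterion now
# raises an AssertionError that points the user at --reset-optimizer.
last_optim = {'criterion_name': 'CrossEntropyCriterion'}        # recorded in the checkpoint (hypothetical)
current_criterion_name = 'LabelSmoothedCrossEntropyCriterion'   # criterion of the current run (hypothetical)

try:
    assert last_optim['criterion_name'] == current_criterion_name, \
        'criterion does not match; please reset the optimizer (--reset-optimizer)'
except AssertionError as err:
    print(err)  # criterion does not match; please reset the optimizer (--reset-optimizer)
```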
```diff
@@ -18,11 +18,11 @@ class TestCharacterTokenEmbedder(unittest.TestCase):
         vocab.add_symbol('hello')
         vocab.add_symbol('there')
 
-        embedder = CharacterTokenEmbedder(vocab, [(2, 16), (4, 32), (8, 64), (16, 2)], 64, 5)
+        embedder = CharacterTokenEmbedder(vocab, [(2, 16), (4, 32), (8, 64), (16, 2)], 64, 5, 2)
 
         test_sents = [['hello', 'unk', 'there'], ['there'], ['hello', 'there']]
         max_len = max(len(s) for s in test_sents)
-        input = torch.LongTensor(len(test_sents), max_len + 2)
+        input = torch.LongTensor(len(test_sents), max_len + 2).fill_(vocab.pad())
         for i in range(len(test_sents)):
             input[i][0] = vocab.eos()
             for j in range(len(test_sents[i])):
```
```diff
@@ -302,7 +302,8 @@ def load_checkpoint(args, trainer, epoch_itr):
     os.makedirs(args.save_dir, exist_ok=True)
     checkpoint_path = os.path.join(args.save_dir, args.restore_file)
     if os.path.isfile(checkpoint_path):
-        extra_state = trainer.load_checkpoint(checkpoint_path)
+        extra_state = trainer.load_checkpoint(checkpoint_path, args.reset_optimizer, args.reset_lr_scheduler,
+                                              eval(args.optimizer_overrides))
         if extra_state is not None:
             # replay train iterator to match checkpoint
             epoch_itr.load_state_dict(extra_state['train_iterator'])
```
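Because train.py applies eval() to the flag's value, --optimizer-overrides must be given as a valid Python dict literal. A hedged usage sketch (the checkpoint name and override values are only examples):

```python
# Example resume invocation (hypothetical values), shown here as a comment:
#   python train.py <data-dir> ... --restore-file checkpoint_last.pt \
#       --reset-lr-scheduler --optimizer-overrides "{'lr': 0.0001}"
#
# The string is eval()'d into a dict before reaching
# FairseqOptimizer.load_state_dict(), so it must parse as a dict literal.
overrides = eval("{'lr': 0.0001}")
assert isinstance(overrides, dict) and overrides['lr'] == 0.0001
```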