"vscode:/vscode.git/clone" did not exist on "ba1549244c7c1b03c1a7194a399c81ee1c48b7bc"
Commit 2dc074d8 authored by alexeib, committed by Myle Ott

add flag that allows keeping optimizer config

adds --reset-optimizer, --reset-lr-scheduler, and --optimizer-overrides flags
parent 6e3685ad
@@ -81,9 +81,9 @@ class FP16Trainer(Trainer):
         extra_state['loss_scale'] = self.scaler.loss_scale
         super().save_checkpoint(filename, extra_state)
 
-    def load_checkpoint(self, filename):
+    def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
         """Load all training state from a checkpoint file."""
-        extra_state = super().load_checkpoint(filename)
+        extra_state = super().load_checkpoint(filename, reset_optimizer, reset_lr_scheduler, optimizer_overrides)
         if extra_state is not None and 'loss_scale' in extra_state:
             self.scaler.loss_scale = extra_state['loss_scale']
         return extra_state
@@ -52,7 +52,7 @@ class FairseqOptimizer(object):
         """Return the optimizer's state dict."""
         return self.optimizer.state_dict()
 
-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict, optimizer_overrides=None):
         """Load an optimizer state dict.
 
         In general we should prefer the configuration of the existing optimizer
@@ -62,9 +62,10 @@ class FairseqOptimizer(object):
         """
         self.optimizer.load_state_dict(state_dict)
 
+        if optimizer_overrides is not None and len(optimizer_overrides) > 0:
             # override learning rate, momentum, etc. with latest values
             for group in self.optimizer.param_groups:
-                group.update(self.optimizer_config)
+                group.update(optimizer_overrides)
 
     def step(self, closure=None):
         """Performs a single optimization step."""
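The override path above leans on the fact that a torch optimizer keeps its hyperparameters in plain dicts under param_groups, so dict.update() is enough to swap in a new learning rate after the checkpointed state has been restored. A minimal standalone sketch of that mechanism, using a vanilla torch.optim.SGD rather than a FairseqOptimizer and a made-up override dict:

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# pretend this state dict came from a checkpoint saved with lr=0.1
saved_state = optimizer.state_dict()
optimizer.load_state_dict(saved_state)

optimizer_overrides = {'lr': 0.01}          # hypothetical value a user might pass
if optimizer_overrides is not None and len(optimizer_overrides) > 0:
    for group in optimizer.param_groups:
        group.update(optimizer_overrides)   # keep checkpoint state, replace hyperparameters

print(optimizer.param_groups[0]['lr'])      # 0.01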
@@ -228,6 +228,12 @@ def add_checkpoint_args(parser):
                        help='path to save checkpoints')
     group.add_argument('--restore-file', default='checkpoint_last.pt',
                        help='filename in save-dir from which to load checkpoint')
+    group.add_argument('--reset-optimizer', action='store_true',
+                       help='if set, does not load optimizer state from the checkpoint')
+    group.add_argument('--reset-lr-scheduler', action='store_true',
+                       help='if set, does not load lr scheduler state from the checkpoint')
+    group.add_argument('--optimizer-overrides', default="{}", type=str, metavar='DICT',
+                       help='a dictionary used to override optimizer args when loading a checkpoint')
     group.add_argument('--save-interval', type=int, default=1, metavar='N',
                        help='save a checkpoint every N epochs')
     group.add_argument('--save-interval-updates', type=int, default=0, metavar='N',
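For context, the three new options behave like any other argparse flags: the two reset flags are store_true booleans that default to False, and --optimizer-overrides arrives as a raw string whose default "{}" is only turned into a dict later, in train.py. A small self-contained sketch (not fairseq's actual parser) with an illustrative command line:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--reset-optimizer', action='store_true',
                    help='if set, does not load optimizer state from the checkpoint')
parser.add_argument('--reset-lr-scheduler', action='store_true',
                    help='if set, does not load lr scheduler state from the checkpoint')
parser.add_argument('--optimizer-overrides', default="{}", type=str, metavar='DICT',
                    help='a dictionary used to override optimizer args when loading a checkpoint')

args = parser.parse_args(['--reset-lr-scheduler', '--optimizer-overrides', "{'lr': 0.0005}"])
print(args.reset_optimizer)      # False, flag not given
print(args.reset_lr_scheduler)   # True
print(args.optimizer_overrides)  # "{'lr': 0.0005}", still a string at this point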
@@ -80,23 +80,28 @@ class Trainer(object):
             self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
         )
 
-    def load_checkpoint(self, filename, load_optim=True):
+    def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
         """Load all training state from a checkpoint file."""
-        extra_state, optim_history, last_optim_state = \
+        extra_state, self._optim_history, last_optim_state = \
             utils.load_model_state(filename, self.model)
 
-        if last_optim_state is not None:
+        if last_optim_state is not None and not reset_optimizer:
             # rebuild optimizer after loading model, since params may have changed
             self._build_optimizer()
 
-            if load_optim:
-                self._optim_history = optim_history
             # only reload optimizer and lr_scheduler if they match
             last_optim = self._optim_history[-1]
-            if last_optim['criterion_name'] == self.criterion.__class__.__name__:
+            assert last_optim['criterion_name'] == self.criterion.__class__.__name__, \
+                'criterion does not match; please reset the optimizer (--reset-optimizer)'
+            assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
+                'optimizer does not match; please reset the optimizer (--reset-optimizer)'
 
+            if not reset_lr_scheduler:
                 self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
-            if last_optim['optimizer_name'] == self.optimizer.__class__.__name__:
-                self.optimizer.load_state_dict(last_optim_state)
+            self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)
 
             self._num_updates = last_optim['num_updates']
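One behavioural consequence of the Trainer change: a criterion or optimizer mismatch between the checkpoint and the current run used to skip the optimizer state silently, whereas it now fails loudly and points at --reset-optimizer. A toy illustration of the new assert, with hypothetical class names:

# checkpoint metadata vs. the current run (names are made up for illustration)
last_optim = {'criterion_name': 'LabelSmoothedCrossEntropyCriterion'}
current_criterion_name = 'CrossEntropyCriterion'

try:
    assert last_optim['criterion_name'] == current_criterion_name, \
        'criterion does not match; please reset the optimizer (--reset-optimizer)'
except AssertionError as err:
    print(err)  # criterion does not match; please reset the optimizer (--reset-optimizer)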
@@ -18,11 +18,11 @@ class TestCharacterTokenEmbedder(unittest.TestCase):
         vocab.add_symbol('hello')
         vocab.add_symbol('there')
 
-        embedder = CharacterTokenEmbedder(vocab, [(2, 16), (4, 32), (8, 64), (16, 2)], 64, 5)
+        embedder = CharacterTokenEmbedder(vocab, [(2, 16), (4, 32), (8, 64), (16, 2)], 64, 5, 2)
 
         test_sents = [['hello', 'unk', 'there'], ['there'], ['hello', 'there']]
         max_len = max(len(s) for s in test_sents)
-        input = torch.LongTensor(len(test_sents), max_len + 2)
+        input = torch.LongTensor(len(test_sents), max_len + 2).fill_(vocab.pad())
         for i in range(len(test_sents)):
             input[i][0] = vocab.eos()
             for j in range(len(test_sents[i])):
@@ -302,7 +302,8 @@ def load_checkpoint(args, trainer, epoch_itr):
     os.makedirs(args.save_dir, exist_ok=True)
     checkpoint_path = os.path.join(args.save_dir, args.restore_file)
     if os.path.isfile(checkpoint_path):
-        extra_state = trainer.load_checkpoint(checkpoint_path)
+        extra_state = trainer.load_checkpoint(checkpoint_path, args.reset_optimizer, args.reset_lr_scheduler,
+                                              eval(args.optimizer_overrides))
         if extra_state is not None:
             # replay train iterator to match checkpoint
             epoch_itr.load_state_dict(extra_state['train_iterator'])
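Finally, note how the override string is consumed here: args.optimizer_overrides goes through eval(), so the default "{}" yields an empty dict (no overrides) and a value such as "{'lr': 0.0005}" yields the dict that FairseqOptimizer.load_state_dict applies to every param group. A tiny sketch of that conversion (the non-default value is illustrative):

# what eval() does to the --optimizer-overrides string
print(eval("{}"))              # {}  -> nothing is overridden
print(eval("{'lr': 0.0005}"))  # {'lr': 0.0005}  -> lr replaced when the state dict is loaded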