Commit 79bbe1d8 authored by theweiho's avatar theweiho Committed by Myle Ott
Browse files

Add load_optim option to load checkpoint but not optimizer state (#229)

parent 5d99e139
......@@ -80,23 +80,25 @@ class Trainer(object):
self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
)
def load_checkpoint(self, filename, load_optim=True):
    """Load all training state from a checkpoint file.

    Args:
        filename (str): path to the checkpoint file.
        load_optim (bool): if True (default), also restore the optimizer,
            lr-scheduler and update-counter state saved in the checkpoint;
            if False, only the model weights are loaded. Defaults to True
            so existing callers keep their previous behavior.
    """
    # load_model_state restores the model parameters in place and returns
    # the auxiliary training state that was saved alongside them.
    extra_state, optim_history, last_optim_state = \
        utils.load_model_state(filename, self.model)

    if last_optim_state is not None:
        # rebuild optimizer after loading model, since params may have changed
        self._build_optimizer()

        if load_optim:
            self._optim_history = optim_history
            # only reload optimizer and lr_scheduler if they match the
            # classes currently in use; a mismatched state dict would be
            # meaningless (or raise) for a different optimizer/criterion.
            last_optim = self._optim_history[-1]
            if last_optim['criterion_name'] == self.criterion.__class__.__name__:
                self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
                if last_optim['optimizer_name'] == self.optimizer.__class__.__name__:
                    self.optimizer.load_state_dict(last_optim_state)

            self._num_updates = last_optim['num_updates']

    # train_meters may be absent in checkpoints produced by older versions,
    # hence the explicit membership check before restoring them.
    if extra_state is not None and 'train_meters' in extra_state:
        self.meters = extra_state['train_meters']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment