".github/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "332760d4b300f00a0d862e3cfe1495db3b1a14f9"
Commit 79bbe1d8 authored by theweiho's avatar theweiho Committed by Myle Ott
Browse files

Add load_optim option to load checkpoint but not optimizer state (#229)

parent 5d99e139
...@@ -80,23 +80,25 @@ class Trainer(object): ...@@ -80,23 +80,25 @@ class Trainer(object):
self.lr_scheduler, self._num_updates, self._optim_history, extra_state, self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
) )
def load_checkpoint(self, filename): def load_checkpoint(self, filename, load_optim=True):
"""Load all training state from a checkpoint file.""" """Load all training state from a checkpoint file."""
extra_state, self._optim_history, last_optim_state = \ extra_state, optim_history, last_optim_state = \
utils.load_model_state(filename, self.model) utils.load_model_state(filename, self.model)
if last_optim_state is not None: if last_optim_state is not None:
# rebuild optimizer after loading model, since params may have changed # rebuild optimizer after loading model, since params may have changed
self._build_optimizer() self._build_optimizer()
# only reload optimizer and lr_scheduler if they match if load_optim:
last_optim = self._optim_history[-1] self._optim_history = optim_history
if last_optim['criterion_name'] == self.criterion.__class__.__name__: # only reload optimizer and lr_scheduler if they match
self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state']) last_optim = self._optim_history[-1]
if last_optim['optimizer_name'] == self.optimizer.__class__.__name__: if last_optim['criterion_name'] == self.criterion.__class__.__name__:
self.optimizer.load_state_dict(last_optim_state) self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
if last_optim['optimizer_name'] == self.optimizer.__class__.__name__:
self.optimizer.load_state_dict(last_optim_state)
self._num_updates = last_optim['num_updates'] self._num_updates = last_optim['num_updates']
if extra_state is not None and 'train_meters' in extra_state: if extra_state is not None and 'train_meters' in extra_state:
self.meters = extra_state['train_meters'] self.meters = extra_state['train_meters']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment