"csrc/vscode:/vscode.git/clone" did not exist on "9ccee9c051cfabcdf2919fa2c1f69c11a72bf23d"
Commit 79bbe1d8 authored by theweiho's avatar theweiho Committed by Myle Ott
Browse files

Add load_optim option to load checkpoint but not optimizer state (#229)

parent 5d99e139
...@@ -80,15 +80,17 @@ class Trainer(object): ...@@ -80,15 +80,17 @@ class Trainer(object):
self.lr_scheduler, self._num_updates, self._optim_history, extra_state, self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
) )
def load_checkpoint(self, filename): def load_checkpoint(self, filename, load_optim=True):
"""Load all training state from a checkpoint file.""" """Load all training state from a checkpoint file."""
extra_state, self._optim_history, last_optim_state = \ extra_state, optim_history, last_optim_state = \
utils.load_model_state(filename, self.model) utils.load_model_state(filename, self.model)
if last_optim_state is not None: if last_optim_state is not None:
# rebuild optimizer after loading model, since params may have changed # rebuild optimizer after loading model, since params may have changed
self._build_optimizer() self._build_optimizer()
if load_optim:
self._optim_history = optim_history
# only reload optimizer and lr_scheduler if they match # only reload optimizer and lr_scheduler if they match
last_optim = self._optim_history[-1] last_optim = self._optim_history[-1]
if last_optim['criterion_name'] == self.criterion.__class__.__name__: if last_optim['criterion_name'] == self.criterion.__class__.__name__:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment