Commit 10bf4074 authored by Myle Ott's avatar Myle Ott
Browse files

Rebuild optimizer when loading checkpoints

parent 9f3ccaa6
...@@ -157,11 +157,24 @@ class MultiprocessingTrainer(MultiprocessingEventLoop): ...@@ -157,11 +157,24 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
return extra_state return extra_state
def _async_load_checkpoint(self, rank, device_id, filename):
    """Load a checkpoint on this worker and restore trainer state.

    Loads model parameters and optimizer history from ``filename`` via
    ``utils.load_model_state``. If the checkpoint carried optimizer
    state, the optimizer and LR scheduler are rebuilt from scratch
    (model params may have changed during loading), and their saved
    state is restored only when the checkpoint was produced with the
    same criterion class as the current run.

    Args:
        rank: worker rank (unused here; part of the async-call signature).
        device_id: CUDA device to map the checkpoint tensors onto.
        filename: path to the checkpoint file.

    Returns:
        The checkpoint's ``extra_state`` (``None`` when no checkpoint
        file exists — see ``utils.load_model_state``).
    """
    extra_state, self._optim_history, last_optim_state = utils.load_model_state(
        filename, self.model, cuda_device=device_id)

    if last_optim_state is not None:
        # rebuild optimizer after loading model, since params may have changed
        self.optimizer = self._build_optimizer()
        self.lr_scheduler = self._build_lr_scheduler()

        # only load optimizer and lr_scheduler if they match the checkpoint
        last_optim = self._optim_history[-1]
        if last_optim['criterion_name'] == self.criterion.__class__.__name__:
            self.optimizer.load_state_dict(last_optim_state)
            self.lr_scheduler.best = last_optim['best_loss']

        # override learning rate, momentum, etc. with latest values
        for group in self.optimizer.param_groups:
            group.update(self._override_optim_state)

    return extra_state
def set_seed(self, seed): def set_seed(self, seed):
......
...@@ -83,9 +83,9 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler, optim_ ...@@ -83,9 +83,9 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler, optim_
torch_persistent_save(state_dict, filename) torch_persistent_save(state_dict, filename)
def load_state(filename, model, criterion, optimizer, lr_scheduler, cuda_device=None): def load_model_state(filename, model, cuda_device=None):
if not os.path.exists(filename): if not os.path.exists(filename):
return None, [] return None, [], None
if cuda_device is None: if cuda_device is None:
state = torch.load(filename) state = torch.load(filename)
else: else:
...@@ -103,14 +103,7 @@ def load_state(filename, model, criterion, optimizer, lr_scheduler, cuda_device= ...@@ -103,14 +103,7 @@ def load_state(filename, model, criterion, optimizer, lr_scheduler, cuda_device=
raise Exception('Cannot load model parameters from checkpoint, ' raise Exception('Cannot load model parameters from checkpoint, '
'please ensure that the architectures match') 'please ensure that the architectures match')
# only load optimizer and lr_scheduler if they match with the checkpoint return state['extra_state'], state['optimizer_history'], state['last_optimizer_state']
optim_history = state['optimizer_history']
last_optim = optim_history[-1]
if last_optim['criterion_name'] == criterion.__class__.__name__:
optimizer.load_state_dict(state['last_optimizer_state'])
lr_scheduler.best = last_optim['best_loss']
return state['extra_state'], optim_history
def _upgrade_state_dict(state): def _upgrade_state_dict(state):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment