Commit bd46c5ec authored by Myle Ott

Prefer command-line configuration over checkpoint for optimizer state

parent 19fafae6
@@ -77,21 +77,39 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self._max_bsz_seen = 0
 
     def _build_optimizer(self):
+        # When resuming training from a checkpoint, we load the old optimizer
+        # state that includes things like learning rate, momentum factor, etc.
+        # We use this dictionary to override values stored in the checkpoint,
+        # e.g., we might prefer the values specified on the command line.
+        self._override_optim_state = {}
+
         if self.args.optimizer == 'adagrad':
-            return torch.optim.Adagrad(self.model.parameters(), lr=self.args.lr[0],
-                                       weight_decay=self.args.weight_decay)
+            self._override_optim_state = {
+                'lr': self.args.lr[0],
+                'weight_decay': self.args.weight_decay,
+            }
+            return torch.optim.Adagrad(self.model.parameters(), **self._override_optim_state)
         elif self.args.optimizer == 'adam':
-            return torch.optim.Adam(self.model.parameters(), lr=self.args.lr[0],
-                                    betas=eval(self.args.adam_betas),
-                                    weight_decay=self.args.weight_decay)
+            self._override_optim_state = {
+                'lr': self.args.lr[0],
+                'betas': eval(self.args.adam_betas),
+                'weight_decay': self.args.weight_decay,
+            }
+            return torch.optim.Adam(self.model.parameters(), **self._override_optim_state)
         elif self.args.optimizer == 'nag':
-            return NAG(self.model.parameters(), lr=self.args.lr[0],
-                       momentum=self.args.momentum,
-                       weight_decay=self.args.weight_decay)
+            self._override_optim_state = {
+                'lr': self.args.lr[0],
+                'momentum': self.args.momentum,
+                'weight_decay': self.args.weight_decay,
+            }
+            return NAG(self.model.parameters(), **self._override_optim_state)
         elif self.args.optimizer == 'sgd':
-            return torch.optim.SGD(self.model.parameters(), lr=self.args.lr[0],
-                                   momentum=self.args.momentum,
-                                   weight_decay=self.args.weight_decay)
+            self._override_optim_state = {
+                'lr': self.args.lr[0],
+                'momentum': self.args.momentum,
+                'weight_decay': self.args.weight_decay,
+            }
+            return torch.optim.SGD(self.model.parameters(), **self._override_optim_state)
         else:
             raise ValueError('Unknown optimizer: {}'.format(self.args.optimizer))
@@ -142,6 +160,8 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         extra_state, self._optim_history = utils.load_state(
             filename, self.model, self.criterion, self.optimizer,
             self.lr_scheduler, cuda_device=device_id)
+        for group in self.optimizer.param_groups:
+            group.update(self._override_optim_state)
         return extra_state
 
     def set_seed(self, seed):
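For context, the pattern this commit introduces can be reproduced outside of fairseq: keep the command-line hyperparameters in a dict when the optimizer is built, and after the checkpoint's optimizer state has been restored with load_state_dict(), write that dict back into every param group so the command-line values take precedence. The sketch below is a minimal, self-contained illustration of that idea using a plain torch.optim.SGD optimizer; the build_optimizer helper and the in-memory "checkpoint" are stand-ins for illustration, not part of the commit.

# Minimal sketch of the override pattern (illustration only, not fairseq code).
import torch

def build_optimizer(params, lr, momentum, weight_decay):
    # Remember the command-line values so they can later take precedence
    # over whatever optimizer state is stored in a checkpoint.
    override_optim_state = {
        'lr': lr,
        'momentum': momentum,
        'weight_decay': weight_decay,
    }
    return torch.optim.SGD(params, **override_optim_state), override_optim_state

model = torch.nn.Linear(4, 2)
optimizer, override_optim_state = build_optimizer(
    model.parameters(), lr=0.1, momentum=0.99, weight_decay=0.0)

# Stand-in for torch.load(checkpoint_path): restoring the optimizer state
# brings back the lr/momentum/weight_decay that were saved in it ...
saved = {'optimizer': optimizer.state_dict()}
optimizer.load_state_dict(saved['optimizer'])

# ... so re-apply the command-line values to every param group afterwards,
# mirroring the loop this commit adds after utils.load_state().
for group in optimizer.param_groups:
    group.update(override_optim_state)

Updating param_groups in place is enough because PyTorch optimizers read hyperparameters such as 'lr' and 'momentum' from these dicts on every step; the per-parameter buffers restored from the checkpoint (e.g., momentum buffers) are left untouched.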