Unverified commit ba5d7dcd, authored by Myle Ott and committed by GitHub
Browse files

Only save most recent optimizer state in checkpoints (#53)

parent f6ac1aec
...@@ -57,10 +57,10 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler, optim_ ...@@ -57,10 +57,10 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler, optim_
'optimizer_history': optim_history + [ 'optimizer_history': optim_history + [
{ {
'criterion_name': criterion.__class__.__name__, 'criterion_name': criterion.__class__.__name__,
'optimizer': optimizer.state_dict(),
'best_loss': lr_scheduler.best, 'best_loss': lr_scheduler.best,
} }
], ],
'last_optimizer_state': optimizer.state_dict(),
'extra_state': extra_state, 'extra_state': extra_state,
} }
torch_persistent_save(state_dict, filename) torch_persistent_save(state_dict, filename)
...@@ -85,7 +85,7 @@ def load_state(filename, model, criterion, optimizer, lr_scheduler, cuda_device= ...@@ -85,7 +85,7 @@ def load_state(filename, model, criterion, optimizer, lr_scheduler, cuda_device=
optim_history = state['optimizer_history'] optim_history = state['optimizer_history']
last_optim = optim_history[-1] last_optim = optim_history[-1]
if last_optim['criterion_name'] == criterion.__class__.__name__: if last_optim['criterion_name'] == criterion.__class__.__name__:
optimizer.load_state_dict(last_optim['optimizer']) optimizer.load_state_dict(state['last_optimizer_state'])
lr_scheduler.best = last_optim['best_loss'] lr_scheduler.best = last_optim['best_loss']
return state['extra_state'], optim_history return state['extra_state'], optim_history
...@@ -98,10 +98,10 @@ def _upgrade_state_dict(state): ...@@ -98,10 +98,10 @@ def _upgrade_state_dict(state):
state['optimizer_history'] = [ state['optimizer_history'] = [
{ {
'criterion_name': criterions.CrossEntropyCriterion.__name__, 'criterion_name': criterions.CrossEntropyCriterion.__name__,
'optimizer': state['optimizer'],
'best_loss': state['best_loss'], 'best_loss': state['best_loss'],
}, },
] ]
state['last_optimizer_state'] = state['optimizer']
del state['optimizer'] del state['optimizer']
del state['best_loss'] del state['best_loss']
# move extra_state into sub-dictionary # move extra_state into sub-dictionary
...@@ -114,6 +114,11 @@ def _upgrade_state_dict(state): ...@@ -114,6 +114,11 @@ def _upgrade_state_dict(state):
del state['epoch'] del state['epoch']
del state['batch_offset'] del state['batch_offset']
del state['val_loss'] del state['val_loss']
# reduce optimizer history's memory usage (only keep the last state)
if 'optimizer' in state['optimizer_history'][-1]:
state['last_optimizer_state'] = state['optimizer_history'][-1]['optimizer']
for optim_hist in state['optimizer_history']:
del optim_hist['optimizer']
return state return state
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment