"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "571bc1ea118297fb60e95be5e3e162839381aa48"
Commit c49c292c authored by Wei Ho's avatar Wei Ho Committed by Facebook Github Bot
Browse files

Add CheckpointManager to keep avg checkpoint weights in memory to reduce disk...

Add CheckpointManager to keep avg checkpoint weights in memory to reduce disk read when averaging + various checkpoint refactoring

Summary: Pull Request resolved: https://github.com/pytorch/translate/pull/315

Reviewed By: akinh

Differential Revision: D13510446

fbshipit-source-id: 22a6594af9253130a93e638285a47183a974e0de
parent 829bd8ce
...@@ -119,7 +119,7 @@ class Trainer(object): ...@@ -119,7 +119,7 @@ class Trainer(object):
if distributed_utils.is_master(self.args): # only save one checkpoint if distributed_utils.is_master(self.args): # only save one checkpoint
extra_state['train_meters'] = self.meters extra_state['train_meters'] = self.meters
utils.save_state( utils.save_state(
filename, self.args, self.get_model(), self.criterion, self.optimizer, filename, self.args, self.get_model().state_dict(), self.criterion, self.optimizer,
self.lr_scheduler, self._num_updates, self._optim_history, extra_state, self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
) )
......
...@@ -39,7 +39,7 @@ def convert_state_dict_type(state_dict, ttype=torch.FloatTensor): ...@@ -39,7 +39,7 @@ def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
return state_dict return state_dict
def save_state(filename, args, model, criterion, optimizer, lr_scheduler, def save_state(filename, args, model_state_dict, criterion, optimizer, lr_scheduler,
num_updates, optim_history=None, extra_state=None): num_updates, optim_history=None, extra_state=None):
if optim_history is None: if optim_history is None:
optim_history = [] optim_history = []
...@@ -47,7 +47,7 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler, ...@@ -47,7 +47,7 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
extra_state = {} extra_state = {}
state_dict = { state_dict = {
'args': args, 'args': args,
'model': model.state_dict() if model else {}, 'model': model_state_dict if model_state_dict else {},
'optimizer_history': optim_history + [ 'optimizer_history': optim_history + [
{ {
'criterion_name': criterion.__class__.__name__, 'criterion_name': criterion.__class__.__name__,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment