Commit eea50f38 authored by Myle Ott

Refactor model saving/loading to be more reusable

parent 3f970086
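In short: utils now exposes a generic save_state()/load_state() pair that serializes a single checkpoint file, while the policy of which files to write and what bookkeeping to keep moves into train.py as an opaque extra_state dict. A rough sketch of the resulting checkpoint layout, inferred from the diff below (placeholder values; the concrete criterion name is only an example):

# Sketch of a checkpoint produced by utils.save_state() after this change.
# Placeholders stand in for the real objects; key names come from the diff.
state_dict = {
    'args': None,                        # the argparse.Namespace used for training
    'model': {},                         # model.state_dict()
    'optimizer_history': [
        {
            'criterion_name': 'CrossEntropyCriterion',  # criterion.__class__.__name__
            'optimizer': {},             # optimizer.state_dict()
            'best_loss': float('inf'),   # lr_scheduler.best
        },
    ],
    'extra_state': {                     # opaque bookkeeping owned by the caller (train.py)
        'epoch': 1,
        'batch_offset': 0,
        'val_loss': None,
    },
}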
@@ -100,14 +100,13 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
     def _async_get_model(self, rank, device_id):
         return self.model
 
-    def save_checkpoint(self, args, epoch, batch_offset, val_loss=None):
+    def save_checkpoint(self, filename, extra_state):
         """Save a checkpoint for the current model."""
-        self.call_async(0, '_async_save_checkpoint', args=args, epoch=epoch,
-                        batch_offset=batch_offset, val_loss=val_loss).gen()
+        self.call_async(0, '_async_save_checkpoint', filename=filename, extra_state=extra_state).gen()
 
-    def _async_save_checkpoint(self, rank, device_id, args, epoch, batch_offset, val_loss):
-        utils.save_checkpoint(args, epoch, batch_offset, self.model, self.criterion,
-                              self.optimizer, self.lr_scheduler, val_loss, self._optim_history)
+    def _async_save_checkpoint(self, rank, device_id, filename, extra_state):
+        utils.save_state(filename, self.args, self.model, self.criterion, self.optimizer,
+                         self.lr_scheduler, self._optim_history, extra_state)
 
     def load_checkpoint(self, filename):
         """Load a checkpoint into the model replicas in each process."""
@@ -115,14 +114,14 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
             self.call_async(rank, '_async_load_checkpoint', filename=filename)
             for rank in range(self.num_replicas)
         ])
-        epoch, batch_offset = results[0]
-        return epoch, batch_offset
+        extra_state = results[0]
+        return extra_state
 
     def _async_load_checkpoint(self, rank, device_id, filename):
-        epoch, batch_offset, self._optim_history = utils.load_checkpoint(
-            filename, self.model, self.criterion, self.optimizer, self.lr_scheduler,
-            cuda_device=device_id)
-        return epoch, batch_offset
+        extra_state, self._optim_history = utils.load_state(
+            filename, self.model, self.criterion, self.optimizer,
+            self.lr_scheduler, cuda_device=device_id)
+        return extra_state
 
     def train_step(self, samples):
         """Do forward, backward and gradient step in parallel."""
......
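With that, the trainer-facing API reduces to a filename plus an opaque extra_state. A hypothetical helper sketching how a driver script resumes under the new contract (it mirrors what train.py does further down; `trainer` is assumed to be a MultiprocessingTrainer):

import os

def resume_or_start(trainer, save_dir, restore_file='checkpoint_last.pt'):
    """Hypothetical helper: resume training state if a checkpoint exists.

    trainer.load_checkpoint() now returns whatever extra_state dict was passed
    to save_checkpoint(), or None when the file is missing.
    """
    checkpoint_path = os.path.join(save_dir, restore_file)
    extra_state = trainer.load_checkpoint(checkpoint_path)
    if extra_state is None:
        return 1, 0  # fresh run: epoch 1, batch offset 0
    return extra_state['epoch'], extra_state['batch_offset']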
@@ -46,16 +46,14 @@ def torch_persistent_save(*args, **kwargs):
         logging.error(traceback.format_exc())
 
 
-def save_checkpoint(args, epoch, batch_offset, model, criterion, optimizer, lr_scheduler,
-                    val_loss=None, optim_history=None):
+def save_state(filename, args, model, criterion, optimizer, lr_scheduler, optim_history=None, extra_state=None):
     if optim_history is None:
         optim_history = []
+    if extra_state is None:
+        extra_state = {}
     state_dict = {
         'args': args,
-        'epoch': epoch,
-        'batch_offset': batch_offset,
         'model': model.state_dict(),
-        'val_loss': val_loss,
         'optimizer_history': optim_history + [
             {
                 'criterion_name': criterion.__class__.__name__,
@@ -63,26 +61,14 @@ def save_checkpoint(args, epoch, batch_offset, model, criterion, optimizer, lr_s
                 'best_loss': lr_scheduler.best,
             }
         ],
+        'extra_state': extra_state,
     }
-
-    if batch_offset == 0:
-        if not args.no_epoch_checkpoints:
-            epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
-            torch_persistent_save(state_dict, epoch_filename)
-
-        assert val_loss is not None
-        if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
-            save_checkpoint.best = val_loss
-            best_filename = os.path.join(args.save_dir, 'checkpoint_best.pt')
-            torch_persistent_save(state_dict, best_filename)
-
-    last_filename = os.path.join(args.save_dir, 'checkpoint_last.pt')
-    torch_persistent_save(state_dict, last_filename)
+    torch_persistent_save(state_dict, filename)
 
 
-def load_checkpoint(filename, model, criterion, optimizer, lr_scheduler, cuda_device=None):
+def load_state(filename, model, criterion, optimizer, lr_scheduler, cuda_device=None):
     if not os.path.exists(filename):
-        return 1, 0, []
+        return None, []
     if cuda_device is None:
         state = torch.load(filename)
     else:
@@ -92,23 +78,17 @@ def load_checkpoint(filename, model, criterion, optimizer, lr_scheduler, cuda_de
         )
     state = _upgrade_state_dict(state)
 
     # load model parameters
     model.load_state_dict(state['model'])
-    epoch = state['epoch'] + 1
-    batch_offset = state['batch_offset']
 
     # only load optimizer and lr_scheduler if they match with the checkpoint
-    opt_str = ''
     optim_history = state['optimizer_history']
     last_optim = optim_history[-1]
     if last_optim['criterion_name'] == criterion.__class__.__name__:
         optimizer.load_state_dict(last_optim['optimizer'])
         lr_scheduler.best = last_optim['best_loss']
-        opt_str = '; criterion: {}'.format(last_optim['criterion_name'])
 
-    gpu_str = ' on GPU #{}'.format(cuda_device) if cuda_device is not None else ''
-    print('| loaded checkpoint {} (epoch {}{}){}'.format(filename, epoch, opt_str, gpu_str))
-    return epoch, batch_offset, optim_history
+    return state['extra_state'], optim_history
 
 
 def _upgrade_state_dict(state):
@@ -124,6 +104,16 @@ def _upgrade_state_dict(state):
         ]
         del state['optimizer']
         del state['best_loss']
+    # move extra_state into sub-dictionary
+    if 'epoch' in state and 'extra_state' not in state:
+        state['extra_state'] = {
+            'epoch': state['epoch'],
+            'batch_offset': state['batch_offset'],
+            'val_loss': state['val_loss'],
+        }
+        del state['epoch']
+        del state['batch_offset']
+        del state['val_loss']
     return state
......
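The new branch in _upgrade_state_dict() keeps pre-refactor checkpoints loadable by folding the flat 'epoch'/'batch_offset'/'val_loss' keys into the 'extra_state' sub-dictionary. A toy, self-contained illustration of that migration (placeholder values, no real model weights):

# Toy mirror of the migration added to _upgrade_state_dict(); values are fake.
def upgrade(state):
    if 'epoch' in state and 'extra_state' not in state:
        state['extra_state'] = {
            'epoch': state['epoch'],
            'batch_offset': state['batch_offset'],
            'val_loss': state['val_loss'],
        }
        del state['epoch']
        del state['batch_offset']
        del state['val_loss']
    return state

old_style = {'model': {}, 'optimizer_history': [], 'epoch': 7, 'batch_offset': 0, 'val_loss': 4.2}
new_style = upgrade(old_style)
assert new_style['extra_state'] == {'epoch': 7, 'batch_offset': 0, 'val_loss': 4.2}
assert 'epoch' not in new_style               # the flat keys are gone after the move
assert upgrade(dict(new_style)) == new_style  # running the upgrade again is a no-op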
@@ -62,16 +62,25 @@ def main():
     print('| using {} GPUs (with max tokens per GPU = {})'.format(num_gpus, args.max_tokens))
 
-    # Build model
-    print('| model {}'.format(args.arch))
+    # Build model and criterion
     model = utils.build_model(args, dataset)
     criterion = utils.build_criterion(args, dataset)
+    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
 
     # Start multiprocessing
     trainer = MultiprocessingTrainer(args, model, criterion)
 
     # Load the latest checkpoint if one is available
-    epoch, batch_offset = trainer.load_checkpoint(os.path.join(args.save_dir, args.restore_file))
+    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
+    extra_state = trainer.load_checkpoint(checkpoint_path)
+    if extra_state is not None:
+        epoch = extra_state['epoch']
+        batch_offset = extra_state['batch_offset']
+        print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
+        if batch_offset == 0:
+            epoch += 1
+    else:
+        epoch, batch_offset = 1, 0
 
     # Train until the learning rate gets too small
     val_loss = None
@@ -89,7 +98,7 @@
             if k == 0:
                 if not args.no_save:
                     # save checkpoint
-                    trainer.save_checkpoint(args, epoch, 0, val_loss)
+                    save_checkpoint(trainer, args, epoch, 0, val_loss)
 
                 # only use first validation loss to update the learning schedule
                 lr = trainer.lr_step(val_loss, epoch)
@@ -151,7 +160,7 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
                 # ignore the first mini-batch in words-per-second calculation
                 wps_meter.reset()
             if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
-                trainer.save_checkpoint(args, epoch, i + 1)
+                save_checkpoint(trainer, args, epoch, i + 1)
 
         fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'.format(
             loss_meter.avg, math.pow(2, loss_meter.avg))
@@ -166,6 +175,28 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
         t.write(fmt)
 
 
+def save_checkpoint(trainer, args, epoch, batch_offset, val_loss=None):
+    extra_state = {
+        'epoch': epoch,
+        'batch_offset': batch_offset,
+        'val_loss': val_loss,
+    }
+
+    if batch_offset == 0:
+        if not args.no_epoch_checkpoints:
+            epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
+            trainer.save_checkpoint(epoch_filename, extra_state)
+
+        assert val_loss is not None
+        if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
+            save_checkpoint.best = val_loss
+            best_filename = os.path.join(args.save_dir, 'checkpoint_best.pt')
+            trainer.save_checkpoint(best_filename, extra_state)
+
+    last_filename = os.path.join(args.save_dir, 'checkpoint_last.pt')
+    trainer.save_checkpoint(last_filename, extra_state)
+
+
 def validate(args, epoch, trainer, dataset, subset, ngpus):
     """Evaluate the model on the validation set and return the average loss."""
......
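For reference, the filename policy now lives entirely in train.py's save_checkpoint() above. A hypothetical summary of which files a single call (re)writes, with the decision logic paraphrased from that code (the real function calls trainer.save_checkpoint() on each path):

import os

def checkpoint_files_to_write(save_dir, epoch, batch_offset, val_loss=None,
                              no_epoch_checkpoints=False, best_so_far=None):
    """Hypothetical paraphrase of the filename policy in train.py's save_checkpoint()."""
    paths = []
    if batch_offset == 0:  # end-of-epoch save
        if not no_epoch_checkpoints:
            paths.append(os.path.join(save_dir, 'checkpoint{}.pt'.format(epoch)))
        if best_so_far is None or val_loss < best_so_far:
            paths.append(os.path.join(save_dir, 'checkpoint_best.pt'))
    # checkpoint_last.pt is refreshed on every save, including mid-epoch ones
    paths.append(os.path.join(save_dir, 'checkpoint_last.pt'))
    return paths

# e.g. end of epoch 3 with a new best validation loss:
print(checkpoint_files_to_write('checkpoints', epoch=3, batch_offset=0, val_loss=4.1))
# ['checkpoints/checkpoint3.pt', 'checkpoints/checkpoint_best.pt', 'checkpoints/checkpoint_last.pt']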