Commit eea50f38 authored by Myle Ott

Refactor model saving/loading to be more reusable

parent 3f970086
@@ -100,14 +100,13 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
     def _async_get_model(self, rank, device_id):
         return self.model
 
-    def save_checkpoint(self, args, epoch, batch_offset, val_loss=None):
+    def save_checkpoint(self, filename, extra_state):
         """Save a checkpoint for the current model."""
-        self.call_async(0, '_async_save_checkpoint', args=args, epoch=epoch,
-                        batch_offset=batch_offset, val_loss=val_loss).gen()
+        self.call_async(0, '_async_save_checkpoint', filename=filename, extra_state=extra_state).gen()
 
-    def _async_save_checkpoint(self, rank, device_id, args, epoch, batch_offset, val_loss):
-        utils.save_checkpoint(args, epoch, batch_offset, self.model, self.criterion,
-                              self.optimizer, self.lr_scheduler, val_loss, self._optim_history)
+    def _async_save_checkpoint(self, rank, device_id, filename, extra_state):
+        utils.save_state(filename, self.args, self.model, self.criterion, self.optimizer,
+                         self.lr_scheduler, self._optim_history, extra_state)
 
     def load_checkpoint(self, filename):
         """Load a checkpoint into the model replicas in each process."""
@@ -115,14 +114,14 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
             self.call_async(rank, '_async_load_checkpoint', filename=filename)
            for rank in range(self.num_replicas)
         ])
-        epoch, batch_offset = results[0]
-        return epoch, batch_offset
+        extra_state = results[0]
+        return extra_state
 
     def _async_load_checkpoint(self, rank, device_id, filename):
-        epoch, batch_offset, self._optim_history = utils.load_checkpoint(
-            filename, self.model, self.criterion, self.optimizer, self.lr_scheduler,
-            cuda_device=device_id)
-        return epoch, batch_offset
+        extra_state, self._optim_history = utils.load_state(
+            filename, self.model, self.criterion, self.optimizer,
+            self.lr_scheduler, cuda_device=device_id)
+        return extra_state
 
     def train_step(self, samples):
         """Do forward, backward and gradient step in parallel."""
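
The trainer methods above no longer decide where a checkpoint goes or what bookkeeping it carries; both come from the caller and round-trip through extra_state. A minimal sketch of the new calling pattern, assuming an already constructed MultiprocessingTrainer bound to the name trainer (the path and values are illustrative, not from the commit):

    # The caller picks the file name and the bookkeeping dict to persist.
    extra_state = {'epoch': 3, 'batch_offset': 0, 'val_loss': 4.21}
    trainer.save_checkpoint('checkpoints/checkpoint_last.pt', extra_state)

    # load_checkpoint hands back exactly that dict, or None when the file
    # does not exist, so the training script owns the resume logic.
    extra_state = trainer.load_checkpoint('checkpoints/checkpoint_last.pt')
    if extra_state is not None:
        epoch, batch_offset = extra_state['epoch'], extra_state['batch_offset']
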
@@ -46,16 +46,14 @@ def torch_persistent_save(*args, **kwargs):
         logging.error(traceback.format_exc())
 
 
-def save_checkpoint(args, epoch, batch_offset, model, criterion, optimizer, lr_scheduler,
-                    val_loss=None, optim_history=None):
+def save_state(filename, args, model, criterion, optimizer, lr_scheduler, optim_history=None, extra_state=None):
     if optim_history is None:
         optim_history = []
+    if extra_state is None:
+        extra_state = {}
     state_dict = {
         'args': args,
-        'epoch': epoch,
-        'batch_offset': batch_offset,
         'model': model.state_dict(),
-        'val_loss': val_loss,
         'optimizer_history': optim_history + [
             {
                 'criterion_name': criterion.__class__.__name__,
@@ -63,26 +61,14 @@ def save_checkpoint(args, epoch, batch_offset, model, criterion, optimizer, lr_s
                 'best_loss': lr_scheduler.best,
             }
         ],
+        'extra_state': extra_state,
     }
-
-    if batch_offset == 0:
-        if not args.no_epoch_checkpoints:
-            epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
-            torch_persistent_save(state_dict, epoch_filename)
-
-        assert val_loss is not None
-        if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
-            save_checkpoint.best = val_loss
-            best_filename = os.path.join(args.save_dir, 'checkpoint_best.pt')
-            torch_persistent_save(state_dict, best_filename)
-
-    last_filename = os.path.join(args.save_dir, 'checkpoint_last.pt')
-    torch_persistent_save(state_dict, last_filename)
+    torch_persistent_save(state_dict, filename)
 
 
-def load_checkpoint(filename, model, criterion, optimizer, lr_scheduler, cuda_device=None):
+def load_state(filename, model, criterion, optimizer, lr_scheduler, cuda_device=None):
     if not os.path.exists(filename):
-        return 1, 0, []
+        return None, []
     if cuda_device is None:
         state = torch.load(filename)
     else:
@@ -92,23 +78,17 @@ def load_checkpoint(filename, model, criterion, optimizer, lr_scheduler, cuda_de
         )
     state = _upgrade_state_dict(state)
 
+    # load model parameters
     model.load_state_dict(state['model'])
-    epoch = state['epoch'] + 1
-    batch_offset = state['batch_offset']
 
     # only load optimizer and lr_scheduler if they match with the checkpoint
-    opt_str = ''
     optim_history = state['optimizer_history']
     last_optim = optim_history[-1]
     if last_optim['criterion_name'] == criterion.__class__.__name__:
         optimizer.load_state_dict(last_optim['optimizer'])
         lr_scheduler.best = last_optim['best_loss']
-        opt_str = '; criterion: {}'.format(last_optim['criterion_name'])
 
-    gpu_str = ' on GPU #{}'.format(cuda_device) if cuda_device is not None else ''
-    print('| loaded checkpoint {} (epoch {}{}){}'.format(filename, epoch, opt_str, gpu_str))
-    return epoch, batch_offset, optim_history
+    return state['extra_state'], optim_history
 
 
 def _upgrade_state_dict(state):
@@ -124,6 +104,16 @@ def _upgrade_state_dict(state):
         ]
         del state['optimizer']
         del state['best_loss']
+    # move extra_state into sub-dictionary
+    if 'epoch' in state and 'extra_state' not in state:
+        state['extra_state'] = {
+            'epoch': state['epoch'],
+            'batch_offset': state['batch_offset'],
+            'val_loss': state['val_loss'],
+        }
+        del state['epoch']
+        del state['batch_offset']
+        del state['val_loss']
     return state
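
With the best/last/epoch-numbered placement policy removed, save_state and load_state reduce to a plain serialize/deserialize pair that other entry points can reuse. A rough usage sketch, assuming the helpers above are importable as fairseq.utils and using a toy model, a stock torch optimizer, and ReduceLROnPlateau (which exposes the .best attribute the code reads) as stand-ins for the real training objects:

    import argparse
    import torch
    import torch.nn as nn
    from fairseq import utils  # assumed import path for the module patched above

    model = nn.Linear(8, 8)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)  # has .best
    args = argparse.Namespace(arch='toy')  # stored verbatim in the checkpoint

    # One call writes one file; the caller chooses the name and the extra_state.
    utils.save_state('checkpoint_last.pt', args, model, criterion, optimizer,
                     lr_scheduler, optim_history=None,
                     extra_state={'epoch': 3, 'batch_offset': 0, 'val_loss': 4.21})

    # load_state restores the model (and, when the criterion name matches, the
    # optimizer state and lr_scheduler.best) and returns the stored extra_state;
    # old-format checkpoints gain an 'extra_state' entry via _upgrade_state_dict.
    extra_state, optim_history = utils.load_state(
        'checkpoint_last.pt', model, criterion, optimizer, lr_scheduler)
    assert extra_state['epoch'] == 3
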
@@ -62,16 +62,25 @@ def main():
     print('| using {} GPUs (with max tokens per GPU = {})'.format(num_gpus, args.max_tokens))
 
-    # Build model
-    print('| model {}'.format(args.arch))
+    # Build model and criterion
     model = utils.build_model(args, dataset)
     criterion = utils.build_criterion(args, dataset)
+    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
 
     # Start multiprocessing
     trainer = MultiprocessingTrainer(args, model, criterion)
 
     # Load the latest checkpoint if one is available
-    epoch, batch_offset = trainer.load_checkpoint(os.path.join(args.save_dir, args.restore_file))
+    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
+    extra_state = trainer.load_checkpoint(checkpoint_path)
+    if extra_state is not None:
+        epoch = extra_state['epoch']
+        batch_offset = extra_state['batch_offset']
+        print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
+        if batch_offset == 0:
+            epoch += 1
+    else:
+        epoch, batch_offset = 1, 0
 
     # Train until the learning rate gets too small
     val_loss = None
@@ -89,7 +98,7 @@ def main():
             if k == 0:
                 if not args.no_save:
                     # save checkpoint
-                    trainer.save_checkpoint(args, epoch, 0, val_loss)
+                    save_checkpoint(trainer, args, epoch, 0, val_loss)
                 # only use first validation loss to update the learning schedule
                 lr = trainer.lr_step(val_loss, epoch)
@@ -151,7 +160,7 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
                 # ignore the first mini-batch in words-per-second calculation
                 wps_meter.reset()
             if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
-                trainer.save_checkpoint(args, epoch, i + 1)
+                save_checkpoint(trainer, args, epoch, i + 1)
 
         fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'.format(
             loss_meter.avg, math.pow(2, loss_meter.avg))
@@ -166,6 +175,28 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
             t.write(fmt)
 
 
+def save_checkpoint(trainer, args, epoch, batch_offset, val_loss=None):
+    extra_state = {
+        'epoch': epoch,
+        'batch_offset': batch_offset,
+        'val_loss': val_loss,
+    }
+
+    if batch_offset == 0:
+        if not args.no_epoch_checkpoints:
+            epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
+            trainer.save_checkpoint(epoch_filename, extra_state)
+
+        assert val_loss is not None
+        if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
+            save_checkpoint.best = val_loss
+            best_filename = os.path.join(args.save_dir, 'checkpoint_best.pt')
+            trainer.save_checkpoint(best_filename, extra_state)
+
+    last_filename = os.path.join(args.save_dir, 'checkpoint_last.pt')
+    trainer.save_checkpoint(last_filename, extra_state)
+
+
 def validate(args, epoch, trainer, dataset, subset, ngpus):
     """Evaluate the model on the validation set and return the average loss."""
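
The placement policy deleted from utils now lives in this save_checkpoint helper, while resume bookkeeping flows through extra_state and back into main(). A short sketch of the resulting behaviour (the concrete values are made up for illustration; file names follow the helper above and are written under args.save_dir):

    # End-of-epoch save (batch_offset == 0, val_loss known): writes
    # checkpoint3.pt (unless args.no_epoch_checkpoints), checkpoint_best.pt
    # (only when val_loss beats save_checkpoint.best), and checkpoint_last.pt.
    save_checkpoint(trainer, args, epoch=3, batch_offset=0, val_loss=4.21)

    # Mid-epoch save (args.save_interval > 0): only checkpoint_last.pt is
    # refreshed; the stored batch_offset records where the epoch stopped.
    save_checkpoint(trainer, args, epoch=4, batch_offset=500)

    # On restart, main() reads extra_state back from the checkpoint:
    #   batch_offset == 0  -> resume at the next epoch (epoch += 1)
    #   batch_offset > 0   -> re-enter the same epoch at that offset
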