Commit 6643d525 authored by Myle Ott's avatar Myle Ott
Browse files

Use symlinks for redundant checkpoints

parent 24d7de44
......@@ -99,8 +99,8 @@ def main(args):
lr = trainer.lr_step(epoch, first_val_loss)
# save checkpoint
if not args.no_save and epoch % args.save_interval == 0:
save_checkpoint(trainer, args, epoch, end_of_epoch=True, val_loss=first_val_loss)
if epoch % args.save_interval == 0:
save_checkpoint(args, trainer, epoch, end_of_epoch=True, val_loss=first_val_loss)
epoch += 1
next_ds = next(train_dataloader)
......@@ -163,10 +163,9 @@ def train(args, trainer, itr, epoch, dataset):
trainer.get_meter('wps').reset()
num_updates = trainer.get_num_updates()
if not args.no_save and (args.save_interval_updates or 0) > 0 and \
num_updates % args.save_interval_updates == 0:
if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0:
first_val_loss = val_loss(args, trainer, dataset, epoch, num_updates)
save_checkpoint(trainer, args, epoch, end_of_epoch=False, val_loss=first_val_loss)
save_checkpoint(args, trainer, epoch, end_of_epoch=False, val_loss=first_val_loss)
if num_updates >= max_update:
break
......@@ -280,38 +279,49 @@ def val_loss(args, trainer, dataset, epoch, num_updates=None):
return losses[0] if len(losses) > 0 else None
def save_checkpoint(trainer, args, epoch, end_of_epoch, val_loss):
def save_checkpoint(args, trainer, epoch, end_of_epoch, val_loss):
if args.no_save or args.distributed_rank > 0:
return
updates = trainer.get_num_updates()
checkpoint_conds = collections.OrderedDict()
checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
end_of_epoch and not args.no_epoch_checkpoints and
epoch % args.save_interval == 0
)
checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
not end_of_epoch and args.save_interval_updates > 0 and
updates % args.save_interval_updates == 0
)
checkpoint_conds['checkpoint_best.pt'] = (
not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best
)
checkpoint_conds['checkpoint_last.pt'] = True # keep this last so that it's a symlink
save_checkpoint.best = min(val_loss, getattr(save_checkpoint, 'best', val_loss))
extra_state = {
'best': save_checkpoint.best,
'end_of_epoch': end_of_epoch,
'epoch': epoch,
'val_loss': val_loss,
'wall_time': trainer.get_meter('wall').elapsed_time,
'end_of_epoch': end_of_epoch,
}
if end_of_epoch and not args.no_epoch_checkpoints:
epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
trainer.save_checkpoint(epoch_filename, extra_state)
elif not end_of_epoch and args.keep_interval_updates > 0:
checkpoint_filename = os.path.join(args.save_dir,
'checkpoint_{}_{}.pt'.format(epoch, trainer.get_num_updates()))
trainer.save_checkpoint(checkpoint_filename, extra_state)
# remove old checkpoints
checkpoints = checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
# checkpoints are sorted in descending order
checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
if len(checkpoints) > 0:
for fn in checkpoints:
if os.path.exists(fn):
os.remove(fn)
trainer.save_checkpoint(checkpoints[0], extra_state)
for fn in checkpoints[1:]:
os.symlink(os.path.basename(checkpoints[0]), fn)
if not end_of_epoch and args.keep_interval_updates > 0:
# remove old checkpoints; checkpoints are sorted in descending order
checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
for old_chk in checkpoints[args.keep_interval_updates:]:
os.remove(old_chk)
assert val_loss is not None
if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
save_checkpoint.best = val_loss
best_filename = os.path.join(args.save_dir, 'checkpoint_best.pt')
extra_state['best'] = val_loss
trainer.save_checkpoint(best_filename, extra_state)
extra_state['best'] = save_checkpoint.best
last_filename = os.path.join(args.save_dir, 'checkpoint_last.pt')
trainer.save_checkpoint(last_filename, extra_state)
def load_checkpoint(args, trainer, train_dataloader):
os.makedirs(args.save_dir, exist_ok=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment