Commit 6643d525 authored by Myle Ott

Use symlinks for redundant checkpoints

parent 24d7de44
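
The change collapses what used to be several separate writes of the same checkpoint (per-epoch, per-update, best, last) into a single write plus symlinks for the redundant names. A minimal standalone sketch of that pattern (the helper name and the write_fn callback are illustrative, not fairseq's API):

import os

def save_once_with_symlinks(save_dir, names, write_fn):
    # Hypothetical helper showing the commit's approach; write_fn stands in
    # for trainer.save_checkpoint(path, extra_state).
    paths = [os.path.join(save_dir, name) for name in names]
    for path in paths:
        # clear stale files/links first; os.symlink raises FileExistsError
        # otherwise (lexists also catches dangling links; the diff uses exists)
        if os.path.lexists(path):
            os.remove(path)
    write_fn(paths[0])  # the only real write to disk
    for path in paths[1:]:
        # link by basename, so the checkpoint directory stays relocatable
        os.symlink(os.path.basename(paths[0]), path)

Called with, say, names=['checkpoint3.pt', 'checkpoint_best.pt', 'checkpoint_last.pt'], only one file is written and the other two names resolve to it. This is also why the diff below inserts checkpoint_last.pt last into the OrderedDict: the real file is always an epoch or interval checkpoint whenever one is due, and the last/best names become symlinks.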
@@ -99,8 +99,8 @@ def main(args):
         lr = trainer.lr_step(epoch, first_val_loss)
 
         # save checkpoint
-        if not args.no_save and epoch % args.save_interval == 0:
-            save_checkpoint(trainer, args, epoch, end_of_epoch=True, val_loss=first_val_loss)
+        if epoch % args.save_interval == 0:
+            save_checkpoint(args, trainer, epoch, end_of_epoch=True, val_loss=first_val_loss)
 
         epoch += 1
         next_ds = next(train_dataloader)
@@ -163,10 +163,9 @@ def train(args, trainer, itr, epoch, dataset):
         trainer.get_meter('wps').reset()
 
         num_updates = trainer.get_num_updates()
-        if not args.no_save and (args.save_interval_updates or 0) > 0 and \
-                num_updates % args.save_interval_updates == 0:
+        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0:
             first_val_loss = val_loss(args, trainer, dataset, epoch, num_updates)
-            save_checkpoint(trainer, args, epoch, end_of_epoch=False, val_loss=first_val_loss)
+            save_checkpoint(args, trainer, epoch, end_of_epoch=False, val_loss=first_val_loss)
 
         if num_updates >= max_update:
             break
@@ -280,38 +279,49 @@ def val_loss(args, trainer, dataset, epoch, num_updates=None):
     return losses[0] if len(losses) > 0 else None
 
 
-def save_checkpoint(trainer, args, epoch, end_of_epoch, val_loss):
+def save_checkpoint(args, trainer, epoch, end_of_epoch, val_loss):
+    if args.no_save or args.distributed_rank > 0:
+        return
+
+    updates = trainer.get_num_updates()
+
+    checkpoint_conds = collections.OrderedDict()
+    checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
+        end_of_epoch and not args.no_epoch_checkpoints and
+        epoch % args.save_interval == 0
+    )
+    checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
+        not end_of_epoch and args.save_interval_updates > 0 and
+        updates % args.save_interval_updates == 0
+    )
+    checkpoint_conds['checkpoint_best.pt'] = (
+        not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best
+    )
+    checkpoint_conds['checkpoint_last.pt'] = True  # keep this last so that it's a symlink
+
+    save_checkpoint.best = min(val_loss, getattr(save_checkpoint, 'best', val_loss))
     extra_state = {
+        'best': save_checkpoint.best,
+        'end_of_epoch': end_of_epoch,
         'epoch': epoch,
         'val_loss': val_loss,
         'wall_time': trainer.get_meter('wall').elapsed_time,
-        'end_of_epoch': end_of_epoch,
     }
-    if end_of_epoch and not args.no_epoch_checkpoints:
-        epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
-        trainer.save_checkpoint(epoch_filename, extra_state)
-    elif not end_of_epoch and args.keep_interval_updates > 0:
-        checkpoint_filename = os.path.join(args.save_dir,
-            'checkpoint_{}_{}.pt'.format(epoch, trainer.get_num_updates()))
-        trainer.save_checkpoint(checkpoint_filename, extra_state)
-        # remove old checkpoints
-        checkpoints = checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
-        # checkpoints are sorted in descending order
+
+    checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
+    if len(checkpoints) > 0:
+        for fn in checkpoints:
+            if os.path.exists(fn):
+                os.remove(fn)
+        trainer.save_checkpoint(checkpoints[0], extra_state)
+        for fn in checkpoints[1:]:
+            os.symlink(os.path.basename(checkpoints[0]), fn)
+
+    if not end_of_epoch and args.keep_interval_updates > 0:
+        # remove old checkpoints; checkpoints are sorted in descending order
+        checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
         for old_chk in checkpoints[args.keep_interval_updates:]:
             os.remove(old_chk)
-    assert val_loss is not None
-    if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
-        save_checkpoint.best = val_loss
-        best_filename = os.path.join(args.save_dir, 'checkpoint_best.pt')
-        extra_state['best'] = val_loss
-        trainer.save_checkpoint(best_filename, extra_state)
-    extra_state['best'] = save_checkpoint.best
-    last_filename = os.path.join(args.save_dir, 'checkpoint_last.pt')
-    trainer.save_checkpoint(last_filename, extra_state)
 
 
 def load_checkpoint(args, trainer, train_dataloader):
     os.makedirs(args.save_dir, exist_ok=True)
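
The pruning branch at the end of the new save_checkpoint relies on utils.checkpoint_paths returning the matching files sorted in descending order of the captured update count (per the inline comment), so that slicing off the first keep_interval_updates entries keeps only the newest interval checkpoints. A sketch of that contract, assuming the regex's single capture group is the update count; an illustration, not fairseq's exact implementation:

import os
import re

def checkpoint_paths(path, pattern=r'checkpoint_\d+_(\d+)\.pt'):
    # Return files in `path` matching `pattern`, sorted so the
    # largest captured integer (most recent update) comes first.
    pt_regexp = re.compile(pattern)
    matches = []
    for fname in os.listdir(path):
        m = pt_regexp.fullmatch(fname)
        if m is not None:
            matches.append((int(m.group(1)), fname))
    return [os.path.join(path, fname)
            for _, fname in sorted(matches, reverse=True)]

With that ordering, checkpoints[args.keep_interval_updates:] is exactly the set of older interval checkpoints beyond the retention window, which the diff deletes with os.remove.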