Commit 89e077c3 authored by Myle Ott, committed by Facebook Github Bot

Add additional options for configuring writing of checkpoints

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/697

Differential Revision: D16068465

Pulled By: myleott

fbshipit-source-id: c2563c3c682e7e8406e6d7c8e895d8afbec551eb
parent c86d70cc
@@ -27,6 +27,9 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
     if args.no_save or not distributed_utils.is_master(args):
         return
 
+    def is_better(a, b):
+        return a > b if args.maximize_best_checkpoint_metric else a < b
+
     write_timer = meters.StopwatchMeter()
     write_timer.start()
@@ -45,13 +48,13 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
     )
     checkpoint_conds['checkpoint_best.pt'] = (
         val_loss is not None and
-        (not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
+        (not hasattr(save_checkpoint, 'best') or is_better(val_loss, save_checkpoint.best))
     )
-    checkpoint_conds['checkpoint_last.pt'] = True  # keep this last so that it's a symlink
+    checkpoint_conds['checkpoint_last.pt'] = not args.no_last_checkpoints
 
     prev_best = getattr(save_checkpoint, 'best', val_loss)
     if val_loss is not None:
-        save_checkpoint.best = min(val_loss, prev_best)
+        save_checkpoint.best = val_loss if is_better(val_loss, prev_best) else prev_best
 
     extra_state = {
         'train_iterator': epoch_itr.state_dict(),
         'val_loss': val_loss,
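For readers skimming the hunk above, a minimal sketch (not fairseq code; the variable values are hypothetical) of how the rewritten conditions decide which checkpoint files to write:

```python
# Illustration only: hypothetical values standing in for fairseq's parsed args
# and for the running best metric tracked on save_checkpoint.
maximize_best_checkpoint_metric = False  # default: lower metric (e.g. loss) is better
no_last_checkpoints = False              # --no-last-checkpoints would flip this to True

def is_better(a, b):
    # Same comparison as the helper added in this commit.
    return a > b if maximize_best_checkpoint_metric else a < b

val_loss, prev_best = 2.1, 2.4           # current validation metric vs. best so far

checkpoint_conds = {
    'checkpoint_best.pt': val_loss is not None and is_better(val_loss, prev_best),
    'checkpoint_last.pt': not no_last_checkpoints,
}
print([name for name, write_it in checkpoint_conds.items() if write_it])
# ['checkpoint_best.pt', 'checkpoint_last.pt']
```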
@@ -235,9 +238,10 @@ def save_state(
                 'num_updates': num_updates,
             }
         ],
-        'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
         'extra_state': extra_state,
     }
+    if not args.no_save_optimizer_state:
+        state_dict['last_optimizer_state'] = convert_state_dict_type(optimizer.state_dict())
     torch_persistent_save(state_dict, filename)
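The save_state() change above makes the optimizer state an optional part of the serialized dict. A rough sketch of the effect, assuming a generic PyTorch model and optimizer rather than fairseq's actual call signature:

```python
import torch

def build_checkpoint(model, optimizer, extra_state, no_save_optimizer_state=False):
    # Hypothetical reduction of save_state(): only the keys relevant here.
    state_dict = {
        'model': model.state_dict(),
        'extra_state': extra_state,
    }
    if not no_save_optimizer_state:
        # Optimizer state (e.g. Adam moments) can dominate checkpoint size,
        # so --no-save-optimizer-state skips it entirely.
        state_dict['last_optimizer_state'] = optimizer.state_dict()
    return state_dict

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters())
ckpt = build_checkpoint(model, optimizer, {'epoch': 3}, no_save_optimizer_state=True)
print('last_optimizer_state' in ckpt)  # False
```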
@@ -359,6 +359,14 @@ def add_checkpoint_args(parser):
                        help='don\'t save models or checkpoints')
     group.add_argument('--no-epoch-checkpoints', action='store_true',
                        help='only store last and best checkpoints')
+    group.add_argument('--no-last-checkpoints', action='store_true',
+                       help='don\'t store last checkpoints')
+    group.add_argument('--no-save-optimizer-state', action='store_true',
+                       help='don\'t save optimizer-state as part of checkpoint')
+    group.add_argument('--best-checkpoint-metric', type=str, default='loss',
+                       help='metric to use for saving "best" checkpoints')
+    group.add_argument('--maximize-best-checkpoint-metric', action='store_true',
+                       help='select the largest metric value for saving "best" checkpoints')
     # fmt: on
     return group
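To see how the four new flags compose, here is a self-contained argparse sketch; the option names and defaults mirror the hunk above, but the 'accuracy' metric is only an example and assumes the task actually reports such a stat during validation:

```python
from argparse import ArgumentParser

parser = ArgumentParser()
# The four options added in this commit (help strings abbreviated).
parser.add_argument('--no-last-checkpoints', action='store_true')
parser.add_argument('--no-save-optimizer-state', action='store_true')
parser.add_argument('--best-checkpoint-metric', type=str, default='loss')
parser.add_argument('--maximize-best-checkpoint-metric', action='store_true')

# Example: track a maximized metric for checkpoint_best.pt and drop optimizer state.
args = parser.parse_args([
    '--best-checkpoint-metric', 'accuracy',
    '--maximize-best-checkpoint-metric',
    '--no-save-optimizer-state',
])
print(args.best_checkpoint_metric, args.maximize_best_checkpoint_metric)  # accuracy True
```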
@@ -238,7 +238,7 @@ def validate(args, trainer, task, epoch_itr, subsets):
            stats[k] = meter.avg
         progress.print(stats, tag=subset, step=trainer.get_num_updates())
 
-        valid_losses.append(stats['loss'].avg)
+        valid_losses.append(stats[args.best_checkpoint_metric].avg)
     return valid_losses
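Finally, the train.py change means validate() now returns whichever stat --best-checkpoint-metric names, and that value is what save_checkpoint() compares via is_better(). A toy walk-through, assuming a stats dict of plain floats for brevity (in fairseq the values are meters, hence the .avg above):

```python
stats = {'loss': 2.31, 'accuracy': 0.87}   # hypothetical validation stats

best_checkpoint_metric = 'accuracy'        # --best-checkpoint-metric accuracy
maximize = True                            # --maximize-best-checkpoint-metric

selected = stats[best_checkpoint_metric]   # value appended to valid_losses in validate()
prev_best = 0.85                           # save_checkpoint.best from an earlier epoch
is_new_best = selected > prev_best if maximize else selected < prev_best
print(is_new_best)  # True -> checkpoint_best.pt gets rewritten
```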