Commit 89e077c3 authored by Myle Ott, committed by Facebook GitHub Bot

Add additional options for configuring writing of checkpoints

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/697

Differential Revision: D16068465

Pulled By: myleott

fbshipit-source-id: c2563c3c682e7e8406e6d7c8e895d8afbec551eb
parent c86d70cc
@@ -27,6 +27,9 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
     if args.no_save or not distributed_utils.is_master(args):
         return

+    def is_better(a, b):
+        return a > b if args.maximize_best_checkpoint_metric else a < b
+
     write_timer = meters.StopwatchMeter()
     write_timer.start()
@@ -45,13 +48,13 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
     )
     checkpoint_conds['checkpoint_best.pt'] = (
         val_loss is not None and
-        (not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
+        (not hasattr(save_checkpoint, 'best') or is_better(val_loss, save_checkpoint.best))
     )
-    checkpoint_conds['checkpoint_last.pt'] = True  # keep this last so that it's a symlink
+    checkpoint_conds['checkpoint_last.pt'] = not args.no_last_checkpoints

     prev_best = getattr(save_checkpoint, 'best', val_loss)
     if val_loss is not None:
-        save_checkpoint.best = min(val_loss, prev_best)
+        save_checkpoint.best = val_loss if is_better(val_loss, prev_best) else prev_best
     extra_state = {
         'train_iterator': epoch_itr.state_dict(),
         'val_loss': val_loss,
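
The new is_better helper above flips the comparison used to decide whether a validation result improves on the stored best. A minimal standalone sketch of that logic (illustrative only, not the fairseq code; the keyword argument name is chosen here for clarity):

# With a loss-like metric lower is better; --maximize-best-checkpoint-metric
# flips the direction for score-like metrics such as BLEU or accuracy.
def is_better(a, b, maximize=False):
    return a > b if maximize else a < b

assert is_better(0.9, 1.2)                   # lower loss wins by default
assert is_better(31.5, 29.8, maximize=True)  # higher score wins when maximizing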
@@ -235,9 +238,10 @@ def save_state(
                 'num_updates': num_updates,
             }
         ],
-        'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
         'extra_state': extra_state,
     }
+    if not args.no_save_optimizer_state:
+        state_dict['last_optimizer_state'] = convert_state_dict_type(optimizer.state_dict())
     torch_persistent_save(state_dict, filename)
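
With --no-save-optimizer-state the optimizer state is simply left out of the serialized dictionary. A rough standalone sketch of the same pattern, with assumed names rather than the actual fairseq save_state signature:

import torch

def save_state(filename, model, optimizer, save_optimizer_state=True):
    # Build the checkpoint dict; include the optimizer state only on request,
    # which can substantially shrink checkpoints for stateful optimizers
    # such as Adam.
    state_dict = {'model': model.state_dict()}
    if save_optimizer_state:
        state_dict['last_optimizer_state'] = optimizer.state_dict()
    torch.save(state_dict, filename)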
...
@@ -359,6 +359,14 @@ def add_checkpoint_args(parser):
                        help='don\'t save models or checkpoints')
     group.add_argument('--no-epoch-checkpoints', action='store_true',
                        help='only store last and best checkpoints')
+    group.add_argument('--no-last-checkpoints', action='store_true',
+                       help='don\'t store last checkpoints')
+    group.add_argument('--no-save-optimizer-state', action='store_true',
+                       help='don\'t save optimizer-state as part of checkpoint')
+    group.add_argument('--best-checkpoint-metric', type=str, default='loss',
+                       help='metric to use for saving "best" checkpoints')
+    group.add_argument('--maximize-best-checkpoint-metric', action='store_true',
+                       help='select the largest metric value for saving "best" checkpoints')
     # fmt: on
     return group
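
The four new flags can be exercised directly with argparse; a quick, self-contained sanity check of how they parse (illustrative only, outside fairseq's option groups):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-last-checkpoints', action='store_true')
parser.add_argument('--no-save-optimizer-state', action='store_true')
parser.add_argument('--best-checkpoint-metric', type=str, default='loss')
parser.add_argument('--maximize-best-checkpoint-metric', action='store_true')

# e.g. track BLEU instead of loss and keep the highest-scoring checkpoint
args = parser.parse_args(['--best-checkpoint-metric', 'bleu',
                          '--maximize-best-checkpoint-metric'])
assert args.best_checkpoint_metric == 'bleu'
assert args.maximize_best_checkpoint_metric and not args.no_last_checkpoints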
...
@@ -238,7 +238,7 @@ def validate(args, trainer, task, epoch_itr, subsets):
            stats[k] = meter.avg
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

-        valid_losses.append(stats['loss'].avg)
+        valid_losses.append(stats[args.best_checkpoint_metric].avg)
    return valid_losses
...