"src/vscode:/vscode.git/clone" did not exist on "325a5de3a9acc97534a4446ce9dd4147efcd61a0"
Commit 295ccee9 authored by Alexei Baevski's avatar Alexei Baevski Committed by Myle Ott
Browse files

save best val loss in checkpoint

save best val loss in checkpoint and also print best so far

This way, when training continues from an existing checkpoint, we don't immediately override checkpoint_best with a worse loss.
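For context, here is a minimal sketch of the pattern the commit relies on. It is not the fairseq implementation: the function names, the `extra_state` dict, and the `checkpoint_best.pt` / `checkpoint_last.pt` filenames mirror the diff below, but the bodies are simplified and the torch-based serialization is an assumption made for the example.

```python
# Minimal sketch (NOT the fairseq code): the best validation loss lives on a
# function attribute (save_checkpoint.best), is written into the checkpoint's
# extra_state, and is restored on load so that the first validation after a
# restart cannot silently overwrite checkpoint_best.pt.
import os
import torch


def save_checkpoint(save_dir, model, val_loss):
    extra_state = {'val_loss': val_loss}
    if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
        # new best: remember it and refresh checkpoint_best.pt
        save_checkpoint.best = val_loss
        extra_state['best'] = val_loss
        torch.save({'model': model.state_dict(), 'extra_state': extra_state},
                   os.path.join(save_dir, 'checkpoint_best.pt'))
    # checkpoint_last.pt always records the best value seen so far
    extra_state['best'] = save_checkpoint.best
    torch.save({'model': model.state_dict(), 'extra_state': extra_state},
               os.path.join(save_dir, 'checkpoint_last.pt'))


def load_checkpoint(save_dir, model):
    state = torch.load(os.path.join(save_dir, 'checkpoint_last.pt'))
    model.load_state_dict(state['model'])
    extra_state = state['extra_state']
    if 'best' in extra_state:
        # seed the tracker; without this line the pre-restart best is forgotten
        save_checkpoint.best = extra_state['best']
    return extra_state
```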
parent 316744d6
@@ -243,6 +243,9 @@ def validate(args, trainer, dataset, subset, epoch, num_updates):
     if num_updates is not None:
         stats['num_updates'] = num_updates
+    if hasattr(save_checkpoint, 'best'):
+        stats['best'] = min(save_checkpoint.best, stats['valid_loss'])
     progress.print(stats)
     return stats['valid_loss']
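Taking min(save_checkpoint.best, stats['valid_loss']) means the printed 'best' already reflects the current epoch's result, even though save_checkpoint.best itself is only updated later, inside save_checkpoint().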
@@ -300,8 +303,10 @@ def save_checkpoint(trainer, args, epoch, end_of_epoch, val_loss):
     if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
         save_checkpoint.best = val_loss
         best_filename = os.path.join(args.save_dir, 'checkpoint_best.pt')
+        extra_state['best'] = val_loss
         trainer.save_checkpoint(best_filename, extra_state)

+    extra_state['best'] = save_checkpoint.best
     last_filename = os.path.join(args.save_dir, 'checkpoint_last.pt')
     trainer.save_checkpoint(last_filename, extra_state)
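Writing 'best' into extra_state for checkpoint_last.pt as well as checkpoint_best.pt matters because training is normally resumed from checkpoint_last.pt; storing the running best there is what lets load_checkpoint (below) re-seed save_checkpoint.best.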
@@ -318,6 +323,9 @@ def load_checkpoint(args, trainer, train_dataloader):
     end_of_epoch = extra_state.get('end_of_epoch', True)
     trainer_updates = trainer.get_num_updates()
+    if 'best' in extra_state:
+        save_checkpoint.best = extra_state['best']
     print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
     trainer.lr_step(epoch)
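To make the effect concrete, here is a hypothetical resume scenario using the simplified functions from the sketch above; the loss values and the `checkpoints` directory are made up for illustration.

```python
# Resuming: load_checkpoint seeds save_checkpoint.best from the checkpoint,
# so a worse post-restart loss does not rewrite checkpoint_best.pt.
extra_state = load_checkpoint('checkpoints', model)    # e.g. restores best = 0.42
save_checkpoint('checkpoints', model, val_loss=0.50)   # 0.50 > 0.42 -> checkpoint_best.pt untouched
save_checkpoint('checkpoints', model, val_loss=0.40)   # new best -> checkpoint_best.pt updated
```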