Commit 437c2386 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Speed up saving checkpoints (#703)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/703

It's better to write one checkpoint and copy it, rather than repeatedly pickling the model via torch.save.

Differential Revision: D15213778

fbshipit-source-id: 27dad39853b09dab7f0e11c030313019f035dbb0
parent cf17068a
@@ -14,6 +14,7 @@ import itertools
 import math
 import os
 import random
+import shutil
 import torch
@@ -282,16 +283,16 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
     checkpoint_conds = collections.OrderedDict()
     checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
         end_of_epoch and not args.no_epoch_checkpoints and
         epoch % args.save_interval == 0
     )
     checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
         not end_of_epoch and args.save_interval_updates > 0 and
         updates % args.save_interval_updates == 0
     )
     checkpoint_conds['checkpoint_best.pt'] = (
         val_loss is not None and
         (not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
     )
     checkpoint_conds['checkpoint_last.pt'] = True  # keep this last so that it's a symlink
@@ -307,8 +308,9 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
     checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
     if len(checkpoints) > 0:
-        for cp in checkpoints:
-            trainer.save_checkpoint(cp, extra_state)
+        trainer.save_checkpoint(checkpoints[0], extra_state)
+        for cp in checkpoints[1:]:
+            shutil.copyfile(checkpoints[0], cp)
         write_timer.stop()
         print('| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'.format(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment