Commit 891f2307 authored by Mehdi Drissi's avatar Mehdi Drissi
Browse files

Merge branch 'master' of https://github.com/pytorch/fairseq into minor_fixes

parents 762956a5 b458977a
......@@ -108,6 +108,9 @@ class TestEncoder(FairseqEncoder):
def forward(self, src_tokens, src_lengths):
return src_tokens
def reorder_encoder_out(self, encoder_out, new_order):
return encoder_out.index_select(0, new_order)
class TestIncrementalDecoder(FairseqIncrementalDecoder):
def __init__(self, args, dictionary):
......
......@@ -287,16 +287,8 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
if len(checkpoints) > 0:
for fn in checkpoints:
if os.path.exists(fn):
os.remove(fn)
if not end_of_epoch and args.keep_interval_updates > 0:
for cp in checkpoints:
trainer.save_checkpoint(cp, extra_state)
else:
trainer.save_checkpoint(checkpoints[0], extra_state)
for fn in checkpoints[1:]:
os.symlink(os.path.basename(checkpoints[0]), fn)
for cp in checkpoints:
trainer.save_checkpoint(cp, extra_state)
if not end_of_epoch and args.keep_interval_updates > 0:
# remove old checkpoints; checkpoints are sorted in descending order
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment