Commit 13aa36cf authored by Myle Ott

Small fixes

parent c778a31e
@@ -13,12 +13,11 @@ import torch
 from itertools import islice
-from fairseq import criterions, models, options, progress_bar
+from fairseq import criterions, models, options, progress_bar, utils
 from fairseq.data import data_utils, data_loaders
 from fairseq.fp16_trainer import FP16Trainer
 from fairseq.trainer import Trainer
 from fairseq.meters import AverageMeter, StopwatchMeter
-from fairseq.utils import checkpoint_paths
 
 
 def main(args):
@@ -85,9 +84,9 @@ def main(args):
     max_epoch = args.max_epoch or math.inf
     max_update = args.max_update or math.inf
     lr = trainer.get_lr()
-    first_val_loss = None
     train_meter = StopwatchMeter()
     train_meter.start()
+    valid_losses = [None]
+    valid_subsets = args.valid_subset.split(',')
     while lr > args.min_lr and epoch <= max_epoch and trainer.get_num_updates() < max_update:
         # train for one epoch
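
Replacing the single first_val_loss with a valid_losses list (initialized to [None] before any validation has run) and a valid_subsets list lets training validate on every comma-separated subset named by args.valid_subset. A minimal standalone sketch of that consumption pattern, using a hypothetical validate_subset helper rather than fairseq's actual validation code:

def validate_subset(subset, epoch):
    # Stand-in for a real validation pass: pretend the loss decays with epochs.
    return round(5.0 / epoch, 3)

def run_epochs(valid_subset='valid,valid1', max_epoch=3):
    valid_subsets = valid_subset.split(',')  # e.g. ['valid', 'valid1']
    valid_losses = [None]  # no validation has happened yet
    for epoch in range(1, max_epoch + 1):
        # ... train for one epoch ...
        valid_losses = [validate_subset(s, epoch) for s in valid_subsets]
        print('epoch', epoch, dict(zip(valid_subsets, valid_losses)))
    return valid_losses

Because valid_losses starts as [None], any loss forwarded to save_checkpoint before the first validation pass can be None, which is exactly the case the next hunk guards against.
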
@@ -290,13 +289,16 @@ def save_checkpoint(args, trainer, epoch, end_of_epoch, val_loss):
         updates % args.save_interval_updates == 0
     )
     checkpoint_conds['checkpoint_best.pt'] = (
-        not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best
+        val_loss is not None and
+        (not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
     )
     checkpoint_conds['checkpoint_last.pt'] = True  # keep this last so that it's a symlink
 
-    save_checkpoint.best = min(val_loss, getattr(save_checkpoint, 'best', val_loss))
+    prev_best = getattr(save_checkpoint, 'best', val_loss)
+    if val_loss is not None:
+        save_checkpoint.best = min(val_loss, prev_best)
     extra_state = {
-        'best': save_checkpoint.best,
+        'best': prev_best,
         'end_of_epoch': end_of_epoch,
         'epoch': epoch,
         'val_loss': val_loss,
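
These save_checkpoint changes all guard against val_loss being None. Previously, a checkpoint written before any validation pass (for instance one triggered by --save-interval-updates mid-epoch) would evaluate val_loss < save_checkpoint.best and min(val_loss, ...) with val_loss=None, and comparing None with a float raises TypeError on Python 3. A minimal standalone sketch of the guarded bookkeeping (maybe_update_best is an illustrative name, not fairseq code):

def maybe_update_best(val_loss):
    # Mirrors the commit: never treat a missing validation loss as a new best.
    is_best = (
        val_loss is not None and
        (not hasattr(maybe_update_best, 'best') or val_loss < maybe_update_best.best)
    )
    # getattr falls back to val_loss, so prev_best is None-safe too.
    prev_best = getattr(maybe_update_best, 'best', val_loss)
    if val_loss is not None:
        maybe_update_best.best = min(val_loss, prev_best)
    return is_best

for loss in (None, 2.5, 3.0, 1.75):
    print(loss, maybe_update_best(loss))  # -> False, True, False, True

Note the diff also serializes prev_best rather than the freshly updated save_checkpoint.best into extra_state; prev_best is always defined even when val_loss is None, presumably so saving never crashes before the first validation run.
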