Commit 13aa36cf authored by Myle Ott
Browse files

Small fixes

parent c778a31e
...@@ -13,12 +13,11 @@ import torch ...@@ -13,12 +13,11 @@ import torch
from itertools import islice from itertools import islice
from fairseq import criterions, models, options, progress_bar from fairseq import criterions, models, options, progress_bar, utils
from fairseq.data import data_utils, data_loaders from fairseq.data import data_utils, data_loaders
from fairseq.fp16_trainer import FP16Trainer from fairseq.fp16_trainer import FP16Trainer
from fairseq.trainer import Trainer from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter from fairseq.meters import AverageMeter, StopwatchMeter
from fairseq.utils import checkpoint_paths
def main(args): def main(args):
...@@ -85,9 +84,9 @@ def main(args): ...@@ -85,9 +84,9 @@ def main(args):
max_epoch = args.max_epoch or math.inf max_epoch = args.max_epoch or math.inf
max_update = args.max_update or math.inf max_update = args.max_update or math.inf
lr = trainer.get_lr() lr = trainer.get_lr()
first_val_loss = None
train_meter = StopwatchMeter() train_meter = StopwatchMeter()
train_meter.start() train_meter.start()
valid_losses = [None]
valid_subsets = args.valid_subset.split(',') valid_subsets = args.valid_subset.split(',')
while lr > args.min_lr and epoch <= max_epoch and trainer.get_num_updates() < max_update: while lr > args.min_lr and epoch <= max_epoch and trainer.get_num_updates() < max_update:
# train for one epoch # train for one epoch
...@@ -290,13 +289,16 @@ def save_checkpoint(args, trainer, epoch, end_of_epoch, val_loss): ...@@ -290,13 +289,16 @@ def save_checkpoint(args, trainer, epoch, end_of_epoch, val_loss):
updates % args.save_interval_updates == 0 updates % args.save_interval_updates == 0
) )
checkpoint_conds['checkpoint_best.pt'] = ( checkpoint_conds['checkpoint_best.pt'] = (
not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best val_loss is not None and
(not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
) )
checkpoint_conds['checkpoint_last.pt'] = True # keep this last so that it's a symlink checkpoint_conds['checkpoint_last.pt'] = True # keep this last so that it's a symlink
save_checkpoint.best = min(val_loss, getattr(save_checkpoint, 'best', val_loss)) prev_best = getattr(save_checkpoint, 'best', val_loss)
if val_loss is not None:
save_checkpoint.best = min(val_loss, prev_best)
extra_state = { extra_state = {
'best': save_checkpoint.best, 'best': prev_best,
'end_of_epoch': end_of_epoch, 'end_of_epoch': end_of_epoch,
'epoch': epoch, 'epoch': epoch,
'val_loss': val_loss, 'val_loss': val_loss,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment