Commit fb18be00 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Validate on all sets based on --save-interval-updates

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/693

Differential Revision: D15174831

fbshipit-source-id: 98688b1269ead5694e5116659ff64507d3c0d1c0
parent 4a30a5f6
......@@ -134,7 +134,7 @@ def train(args, trainer, task, epoch_itr):
)
extra_meters = collections.defaultdict(lambda: AverageMeter())
first_valid = args.valid_subset.split(',')[0]
valid_subsets = args.valid_subset.split(',')
max_update = args.max_update or math.inf
for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
log_output = trainer.train_step(samples)
......@@ -159,7 +159,7 @@ def train(args, trainer, task, epoch_itr):
num_updates = trainer.get_num_updates()
if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0:
valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
if num_updates >= max_update:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment