Commit a919570b authored by Myle Ott

Merge validate and val_loss functions (simplify train.py)

parent 6643d525
@@ -88,19 +88,20 @@ def main(args):
-    first_val_loss = None
     train_meter = StopwatchMeter()
     train_meter.start()
+    valid_subsets = args.valid_subset.split(',')
     while lr > args.min_lr and epoch <= max_epoch and trainer.get_num_updates() < max_update:
         # train for one epoch
         train(args, trainer, next_ds, epoch, dataset)
         if epoch % args.validate_interval == 0:
-            first_val_loss = val_loss(args, trainer, dataset, epoch)
+            valid_losses = validate(args, trainer, dataset, valid_subsets, epoch)
 
-        lr = trainer.lr_step(epoch, first_val_loss)
+        # only use first validation loss to update the learning rate
+        lr = trainer.lr_step(epoch, valid_losses[0])
 
         # save checkpoint
         if epoch % args.save_interval == 0:
-            save_checkpoint(args, trainer, epoch, end_of_epoch=True, val_loss=first_val_loss)
+            save_checkpoint(args, trainer, epoch, end_of_epoch=True, val_loss=valid_losses[0])
 
         epoch += 1
         next_ds = next(train_dataloader)
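For reference, the merged contract: validate now takes a list of subset names and returns one average loss per subset, in the same order, and the epoch loop consumes only element 0. A sketch of the call site, assuming a made-up args.valid_subset value of 'valid,valid1':

    # illustrative only: assumes args.valid_subset = 'valid,valid1'
    valid_subsets = args.valid_subset.split(',')      # ['valid', 'valid1']
    valid_losses = validate(args, trainer, dataset, valid_subsets, epoch)
    lr = trainer.lr_step(epoch, valid_losses[0])      # first subset drives the LR schedule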
@@ -135,6 +136,7 @@ def train(args, trainer, itr, epoch, dataset):
     update_freq = args.update_freq[-1]
 
     extra_meters = collections.defaultdict(lambda: AverageMeter())
+    first_valid = args.valid_subset.split(',')[0]
     max_update = args.max_update or math.inf
     num_batches = len(itr)
     progress = progress_bar.build_progress_bar(args, itr, epoch, no_progress_bar='simple')
@@ -164,8 +166,8 @@ def train(args, trainer, itr, epoch, dataset):
         num_updates = trainer.get_num_updates()
         if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0:
-            first_val_loss = val_loss(args, trainer, dataset, epoch, num_updates)
-            save_checkpoint(args, trainer, epoch, end_of_epoch=False, val_loss=first_val_loss)
+            valid_losses = validate(args, trainer, dataset, [first_valid], epoch)
+            save_checkpoint(args, trainer, epoch, end_of_epoch=False, val_loss=valid_losses[0])
 
         if num_updates >= max_update:
             break
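Mid-epoch saves reuse the same merged function by wrapping the single subset name in a one-element list, so there is no separate validation code path. Roughly (the subset value is illustrative):

    first_valid = args.valid_subset.split(',')[0]   # e.g. 'valid'
    valid_losses = validate(args, trainer, dataset, [first_valid], epoch)
    # len(valid_losses) == 1; its only element feeds the mid-epoch checkpoint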
@@ -201,52 +203,54 @@ def get_training_stats(trainer):
     return stats
 
 
-def validate(args, trainer, dataset, subset, epoch, num_updates):
-    """Evaluate the model on the validation set and return the average loss."""
-    # Initialize dataloader
-    max_positions_valid = (
-        trainer.get_model().max_encoder_positions(),
-        trainer.get_model().max_decoder_positions(),
-    )
-    itr = dataset.eval_dataloader(
-        subset,
-        max_tokens=args.max_tokens,
-        max_sentences=args.max_sentences_valid,
-        max_positions=max_positions_valid,
-        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
-        descending=True,  # largest batch first to warm the caching allocator
-        shard_id=args.distributed_rank,
-        num_shards=args.distributed_world_size,
-    )
-    progress = progress_bar.build_progress_bar(
-        args, itr, epoch,
-        prefix='valid on \'{}\' subset'.format(subset),
-        no_progress_bar='simple'
-    )
-
-    # reset validation loss meters
-    for k in ['valid_loss', 'valid_nll_loss']:
-        meter = trainer.get_meter(k)
-        if meter is not None:
-            meter.reset()
-
-    extra_meters = collections.defaultdict(lambda: AverageMeter())
-    for sample in progress:
-        log_output = trainer.valid_step(sample)
-        for k, v in log_output.items():
-            if k in ['loss', 'nll_loss', 'sample_size']:
-                continue
-            extra_meters[k].update(v)
-
-    # log validation stats
-    stats = get_valid_stats(trainer)
-    for k, meter in extra_meters.items():
-        stats[k] = meter.avg
-    progress.print(stats)
-
-    return stats['valid_loss']
+def validate(args, trainer, dataset, subsets, epoch):
+    """Evaluate the model on the validation set(s) and return the losses."""
+    valid_losses = []
+    for subset in subsets:
+        # Initialize dataloader
+        max_positions_valid = (
+            trainer.get_model().max_encoder_positions(),
+            trainer.get_model().max_decoder_positions(),
+        )
+        itr = dataset.eval_dataloader(
+            subset,
+            max_tokens=args.max_tokens,
+            max_sentences=args.max_sentences_valid,
+            max_positions=max_positions_valid,
+            skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
+            descending=True,  # largest batch first to warm the caching allocator
+            shard_id=args.distributed_rank,
+            num_shards=args.distributed_world_size,
+        )
+        progress = progress_bar.build_progress_bar(
+            args, itr, epoch,
+            prefix='valid on \'{}\' subset'.format(subset),
+            no_progress_bar='simple'
+        )
+
+        # reset validation loss meters
+        for k in ['valid_loss', 'valid_nll_loss']:
+            meter = trainer.get_meter(k)
+            if meter is not None:
+                meter.reset()
+
+        extra_meters = collections.defaultdict(lambda: AverageMeter())
+        for sample in progress:
+            log_output = trainer.valid_step(sample)
+            for k, v in log_output.items():
+                if k in ['loss', 'nll_loss', 'sample_size']:
+                    continue
+                extra_meters[k].update(v)
+
+        # log validation stats
+        stats = get_valid_stats(trainer)
+        for k, meter in extra_meters.items():
+            stats[k] = meter.avg
+        progress.print(stats)
+
+        valid_losses.append(stats['valid_loss'])
+    return valid_losses
 
 
 def get_valid_stats(trainer):
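One detail of the merged loop: the valid_loss/valid_nll_loss meters are reset at the top of each subset iteration, so each appended loss reflects only the current subset. A self-contained toy of the reset-per-iteration pattern (this AverageMeter is a simplified stand-in, not fairseq's implementation):

    class AverageMeter:
        """Tracks a running average; simplified stand-in for illustration."""
        def __init__(self):
            self.reset()
        def reset(self):
            self.sum, self.count = 0.0, 0
        def update(self, val, n=1):
            self.sum += val * n
            self.count += n
        @property
        def avg(self):
            return self.sum / max(self.count, 1)

    meter = AverageMeter()
    losses = []
    for subset_batches in [[2.0, 3.0], [4.0]]:  # two fake validation subsets
        meter.reset()  # without this, the second subset would mix in the first
        for loss in subset_batches:
            meter.update(loss)
        losses.append(meter.avg)
    print(losses)  # [2.5, 4.0]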
@@ -271,14 +275,6 @@ def get_perplexity(loss):
         return float('inf')
 
 
-def val_loss(args, trainer, dataset, epoch, num_updates=None):
-    # evaluate on validate set
-    subsets = args.valid_subset.split(',')
-    # we want to validate all subsets so the results get printed out, but return only the first
-    losses = [validate(args, trainer, dataset, subset, epoch, num_updates) for subset in subsets]
-    return losses[0] if len(losses) > 0 else None
-
-
 def save_checkpoint(args, trainer, epoch, end_of_epoch, val_loss):
     if args.no_save or args.distributed_rank > 0:
         return
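To see why val_loss becomes redundant: its behavior (validate every subset, return only the first loss) is now just validate(...)[0]. A self-contained toy of the before/after pattern, with a stubbed evaluation step (all names and numbers here are hypothetical):

    def evaluate_subset(name):
        """Stand-in for one validation pass; returns a fake average loss."""
        return {'valid': 2.5, 'valid1': 2.7}[name]

    # New single-function form: one loop, one loss per subset, in order.
    def validate_all(subsets):
        return [evaluate_subset(s) for s in subsets]

    # The old val_loss() wrapper reduces to indexing the new result:
    losses = validate_all(['valid', 'valid1'])
    first_loss = losses[0] if losses else None  # what val_loss() used to return
    print(losses, first_loss)                   # [2.5, 2.7] 2.5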
......