Commit a919570b authored by Myle Ott

Merge validate and val_loss functions (simplify train.py)

parent 6643d525
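In short: before this commit, validate() evaluated a single subset and returned a scalar loss, and val_loss() wrapped it to loop over args.valid_subset while handing back only the first loss. The commit folds that loop into validate() itself, which now takes a list of subsets and returns one loss per subset. A self-contained toy sketch of the before/after shape (evaluate_subset() and the hard-coded losses are stand-ins for the real per-subset evaluation, not fairseq code):

# Toy analogue of the refactor; evaluate_subset() stands in for the real
# per-subset evaluation (dataloader, meters, trainer.valid_step(), ...).
def evaluate_subset(subset):
    return {'valid': 3.21, 'valid1': 3.45}[subset]

# Before: validate() returned a scalar and val_loss() wrapped it.
def old_val_loss(valid_subset='valid,valid1'):
    losses = [evaluate_subset(s) for s in valid_subset.split(',')]
    return losses[0] if len(losses) > 0 else None

# After: a single validate() loops over the subsets and returns every loss.
def new_validate(subsets):
    valid_losses = []
    for subset in subsets:
        valid_losses.append(evaluate_subset(subset))
    return valid_losses

# Callers keep the old behaviour by taking the first element.
assert old_val_loss() == new_validate(['valid', 'valid1'])[0] == 3.21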
@@ -88,19 +88,20 @@ def main(args):
     first_val_loss = None
     train_meter = StopwatchMeter()
     train_meter.start()
+    valid_subsets = args.valid_subset.split(',')
     while lr > args.min_lr and epoch <= max_epoch and trainer.get_num_updates() < max_update:
         # train for one epoch
         train(args, trainer, next_ds, epoch, dataset)

         if epoch % args.validate_interval == 0:
-            first_val_loss = val_loss(args, trainer, dataset, epoch)
+            valid_losses = validate(args, trainer, dataset, valid_subsets, epoch)

         # only use first validation loss to update the learning rate
-        lr = trainer.lr_step(epoch, first_val_loss)
+        lr = trainer.lr_step(epoch, valid_losses[0])

         # save checkpoint
         if epoch % args.save_interval == 0:
-            save_checkpoint(args, trainer, epoch, end_of_epoch=True, val_loss=first_val_loss)
+            save_checkpoint(args, trainer, epoch, end_of_epoch=True, val_loss=valid_losses[0])

         epoch += 1
         next_ds = next(train_dataloader)
@@ -135,6 +136,7 @@ def train(args, trainer, itr, epoch, dataset):
     update_freq = args.update_freq[-1]

     extra_meters = collections.defaultdict(lambda: AverageMeter())
+    first_valid = args.valid_subset.split(',')[0]
     max_update = args.max_update or math.inf
     num_batches = len(itr)
     progress = progress_bar.build_progress_bar(args, itr, epoch, no_progress_bar='simple')
@@ -164,8 +166,8 @@ def train(args, trainer, itr, epoch, dataset):

         num_updates = trainer.get_num_updates()
         if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0:
-            first_val_loss = val_loss(args, trainer, dataset, epoch, num_updates)
-            save_checkpoint(args, trainer, epoch, end_of_epoch=False, val_loss=first_val_loss)
+            valid_losses = validate(args, trainer, dataset, [first_valid], epoch)
+            save_checkpoint(args, trainer, epoch, end_of_epoch=False, val_loss=valid_losses[0])

         if num_updates >= max_update:
             break
@@ -201,52 +203,54 @@ def get_training_stats(trainer):
     return stats


-def validate(args, trainer, dataset, subset, epoch, num_updates):
-    """Evaluate the model on the validation set and return the average loss."""
-
-    # Initialize dataloader
-    max_positions_valid = (
-        trainer.get_model().max_encoder_positions(),
-        trainer.get_model().max_decoder_positions(),
-    )
-    itr = dataset.eval_dataloader(
-        subset,
-        max_tokens=args.max_tokens,
-        max_sentences=args.max_sentences_valid,
-        max_positions=max_positions_valid,
-        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
-        descending=True,  # largest batch first to warm the caching allocator
-        shard_id=args.distributed_rank,
-        num_shards=args.distributed_world_size,
-    )
-    progress = progress_bar.build_progress_bar(
-        args, itr, epoch,
-        prefix='valid on \'{}\' subset'.format(subset),
-        no_progress_bar='simple'
-    )
-
-    # reset validation loss meters
-    for k in ['valid_loss', 'valid_nll_loss']:
-        meter = trainer.get_meter(k)
-        if meter is not None:
-            meter.reset()
-
-    extra_meters = collections.defaultdict(lambda: AverageMeter())
-    for sample in progress:
-        log_output = trainer.valid_step(sample)
-
-        for k, v in log_output.items():
-            if k in ['loss', 'nll_loss', 'sample_size']:
-                continue
-            extra_meters[k].update(v)
-
-    # log validation stats
-    stats = get_valid_stats(trainer)
-    for k, meter in extra_meters.items():
-        stats[k] = meter.avg
-    progress.print(stats)
-
-    return stats['valid_loss']
+def validate(args, trainer, dataset, subsets, epoch):
+    """Evaluate the model on the validation set(s) and return the losses."""
+    valid_losses = []
+    for subset in subsets:
+        # Initialize dataloader
+        max_positions_valid = (
+            trainer.get_model().max_encoder_positions(),
+            trainer.get_model().max_decoder_positions(),
+        )
+        itr = dataset.eval_dataloader(
+            subset,
+            max_tokens=args.max_tokens,
+            max_sentences=args.max_sentences_valid,
+            max_positions=max_positions_valid,
+            skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
+            descending=True,  # largest batch first to warm the caching allocator
+            shard_id=args.distributed_rank,
+            num_shards=args.distributed_world_size,
+        )
+        progress = progress_bar.build_progress_bar(
+            args, itr, epoch,
+            prefix='valid on \'{}\' subset'.format(subset),
+            no_progress_bar='simple'
+        )
+
+        # reset validation loss meters
+        for k in ['valid_loss', 'valid_nll_loss']:
+            meter = trainer.get_meter(k)
+            if meter is not None:
+                meter.reset()
+
+        extra_meters = collections.defaultdict(lambda: AverageMeter())
+        for sample in progress:
+            log_output = trainer.valid_step(sample)
+
+            for k, v in log_output.items():
+                if k in ['loss', 'nll_loss', 'sample_size']:
+                    continue
+                extra_meters[k].update(v)
+
+        # log validation stats
+        stats = get_valid_stats(trainer)
+        for k, meter in extra_meters.items():
+            stats[k] = meter.avg
+        progress.print(stats)
+
+        valid_losses.append(stats['valid_loss'])
+    return valid_losses


 def get_valid_stats(trainer):
@@ -271,14 +275,6 @@ def get_perplexity(loss):
         return float('inf')


-def val_loss(args, trainer, dataset, epoch, num_updates=None):
-    # evaluate on validate set
-    subsets = args.valid_subset.split(',')
-    # we want to validate all subsets so the results get printed out, but return only the first
-    losses = [validate(args, trainer, dataset, subset, epoch, num_updates) for subset in subsets]
-    return losses[0] if len(losses) > 0 else None
-
-
 def save_checkpoint(args, trainer, epoch, end_of_epoch, val_loss):
     if args.no_save or args.distributed_rank > 0:
         return
......
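Summarizing how the two call sites use the merged function (paraphrased from the hunks above; args, trainer, dataset, epoch and the helper functions are the ones defined in train.py, so this is an excerpt-style sketch rather than a standalone script):

# End of epoch, in main(): validate every subset, drive the LR step and checkpoint with the first loss.
valid_subsets = args.valid_subset.split(',')
valid_losses = validate(args, trainer, dataset, valid_subsets, epoch)
lr = trainer.lr_step(epoch, valid_losses[0])
save_checkpoint(args, trainer, epoch, end_of_epoch=True, val_loss=valid_losses[0])

# Mid-epoch, in train(), every args.save_interval_updates updates: validate only the first subset.
first_valid = args.valid_subset.split(',')[0]
valid_losses = validate(args, trainer, dataset, [first_valid], epoch)
save_checkpoint(args, trainer, epoch, end_of_epoch=False, val_loss=valid_losses[0])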