Commit a3e4c4c3 authored by alexeib's avatar alexeib Committed by Myle Ott
Browse files

remove unused verbose option & make arguments to averaging script nicer

parent 7c07e87c
...@@ -100,22 +100,32 @@ def main(): ...@@ -100,22 +100,32 @@ def main():
help='Write the new checkpoint containing the averaged weights to this ' help='Write the new checkpoint containing the averaged weights to this '
'path.', 'path.',
) )
parser.add_argument( num_group = parser.add_mutually_exclusive_group()
'--num', num_group.add_argument(
'--num-epoch-checkpoints',
type=int, type=int,
help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, ' help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
'and average last num of those', 'and average last this many of them.',
) )
parser.add_argument( num_group.add_argument(
'--update-based-checkpoints', '--num-update-checkpoints',
action='store_true', type=int,
help='if set and used together with --num, averages update-based checkpoints instead of epoch-based checkpoints' help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, '
'and average last this many of them.',
) )
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
if args.num is not None: num = None
args.inputs = last_n_checkpoints(args.inputs, args.num, args.update_based_checkpoints) is_update_based = False
if args.num_update_checkpoints is not None:
num = args.num_update_checkpoints
is_update_based = True
elif args.num_epoch_checkpoints is not None:
num = args.num_epoch_checkpoints
if num is not None:
args.inputs = last_n_checkpoints(args.inputs, num, is_update_based)
print('averaging checkpoints: ', args.inputs) print('averaging checkpoints: ', args.inputs)
new_state = average_checkpoints(args.inputs) new_state = average_checkpoints(args.inputs)
......
...@@ -203,7 +203,7 @@ def get_training_stats(trainer): ...@@ -203,7 +203,7 @@ def get_training_stats(trainer):
return stats return stats
def validate(args, trainer, dataset, subset, epoch, num_updates, verbose): def validate(args, trainer, dataset, subset, epoch, num_updates):
"""Evaluate the model on the validation set and return the average loss.""" """Evaluate the model on the validation set and return the average loss."""
# Initialize dataloader # Initialize dataloader
...@@ -237,16 +237,6 @@ def validate(args, trainer, dataset, subset, epoch, num_updates, verbose): ...@@ -237,16 +237,6 @@ def validate(args, trainer, dataset, subset, epoch, num_updates, verbose):
for sample in progress: for sample in progress:
log_output = trainer.valid_step(sample) log_output = trainer.valid_step(sample)
if verbose:
# log mid-validation stats
stats = get_valid_stats(trainer)
for k, v in log_output.items():
if k in ['loss', 'nll_loss', 'sample_size']:
continue
extra_meters[k].update(v)
stats[k] = extra_meters[k].avg
progress.log(stats)
# log validation stats # log validation stats
stats = get_valid_stats(trainer) stats = get_valid_stats(trainer)
for k, meter in extra_meters.items(): for k, meter in extra_meters.items():
...@@ -283,7 +273,7 @@ def val_loss(args, trainer, dataset, epoch, num_updates=None): ...@@ -283,7 +273,7 @@ def val_loss(args, trainer, dataset, epoch, num_updates=None):
# evaluate on validate set # evaluate on validate set
subsets = args.valid_subset.split(',') subsets = args.valid_subset.split(',')
# we want to validate all subsets so the results get printed out, but return only the first # we want to validate all subsets so the results get printed out, but return only the first
losses = [validate(args, trainer, dataset, subset, epoch, num_updates, verbose=False) for subset in subsets] losses = [validate(args, trainer, dataset, subset, epoch, num_updates) for subset in subsets]
return losses[0] if len(losses) > 0 else None return losses[0] if len(losses) > 0 else None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment