Commit b71f8f45 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Add --disable-validation

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/592

Differential Revision: D15415499

Pulled By: myleott

fbshipit-source-id: 87ba09b9b38501daebd95bbf28815e048c78f9a3
parent 4fac3b60
@@ -234,6 +234,10 @@ def add_dataset_args(parser, train=False, gen=False):
     group.add_argument('--valid-subset', default='valid', metavar='SPLIT',
                        help='comma separated list of data subsets to use for validation'
                             ' (train, valid, valid1, test, test1)')
+    group.add_argument('--validate-interval', type=int, default=1, metavar='N',
+                       help='validate every N epochs')
+    group.add_argument('--disable-validation', action='store_true',
+                       help='disable validation')
     group.add_argument('--max-sentences-valid', type=int, metavar='N',
                        help='maximum number of sentences in a validation batch'
                             ' (defaults to --max-sentences)')
@@ -344,8 +348,6 @@ def add_checkpoint_args(parser):
                        help='don\'t save models or checkpoints')
     group.add_argument('--no-epoch-checkpoints', action='store_true',
                        help='only store last and best checkpoints')
-    group.add_argument('--validate-interval', type=int, default=1, metavar='N',
-                       help='validate every N epochs')
     # fmt: on
     return group
@@ -80,8 +80,10 @@ def main(args, init_distributed=False):
         # train for one epoch
         train(args, trainer, task, epoch_itr)
-        if epoch_itr.epoch % args.validate_interval == 0:
+        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
             valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
+        else:
+            valid_losses = [None]

         # only use first validation loss to update the learning rate
         lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])
@@ -138,7 +140,12 @@ def train(args, trainer, task, epoch_itr):
             trainer.get_meter('wps').reset()

         num_updates = trainer.get_num_updates()
-        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0:
+        if (
+            not args.disable_validation
+            and args.save_interval_updates > 0
+            and num_updates % args.save_interval_updates == 0
+            and num_updates > 0
+        ):
             valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
             checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment