"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "fa71fb447803d269d1050fe5081e2cb577b04b94"
Commit bccfddbb authored by Xutai Ma, committed by Facebook Github Bot
Browse files

add --max-tokens-valid option for validation

Summary: Add the --max-tokens-valid option. Sometimes a separate maximum number of batch tokens for validation is helpful, for example when the validation set contains a long sequence that is larger than max_tokens (this is rare in MT but can happen in ASR or AST).

Reviewed By: myleott

Differential Revision: D16076951

fbshipit-source-id: ae7f4218594580b9450a8196d7afa1e7e2018aee
parent 39bbc9a5
...@@ -114,6 +114,8 @@ def parse_args_and_arch(parser, input_args=None, parse_known=False): ...@@ -114,6 +114,8 @@ def parse_args_and_arch(parser, input_args=None, parse_known=False):
# Post-process args. # Post-process args.
if hasattr(args, 'max_sentences_valid') and args.max_sentences_valid is None: if hasattr(args, 'max_sentences_valid') and args.max_sentences_valid is None:
args.max_sentences_valid = args.max_sentences args.max_sentences_valid = args.max_sentences
if hasattr(args, 'max_tokens_valid') and args.max_tokens_valid is None:
args.max_tokens_valid = args.max_tokens
if getattr(args, 'memory_efficient_fp16', False): if getattr(args, 'memory_efficient_fp16', False):
args.fp16 = True args.fp16 = True
...@@ -252,6 +254,9 @@ def add_dataset_args(parser, train=False, gen=False): ...@@ -252,6 +254,9 @@ def add_dataset_args(parser, train=False, gen=False):
help='validate every N epochs') help='validate every N epochs')
group.add_argument('--disable-validation', action='store_true', group.add_argument('--disable-validation', action='store_true',
help='disable validation') help='disable validation')
group.add_argument('--max-tokens-valid', type=int, metavar='N',
help='maximum number of tokens in a validation batch'
' (defaults to --max-tokens)')
group.add_argument('--max-sentences-valid', type=int, metavar='N', group.add_argument('--max-sentences-valid', type=int, metavar='N',
help='maximum number of sentences in a validation batch' help='maximum number of sentences in a validation batch'
' (defaults to --max-sentences)') ' (defaults to --max-sentences)')
......
...@@ -198,7 +198,7 @@ def validate(args, trainer, task, epoch_itr, subsets): ...@@ -198,7 +198,7 @@ def validate(args, trainer, task, epoch_itr, subsets):
# Initialize data iterator # Initialize data iterator
itr = task.get_batch_iterator( itr = task.get_batch_iterator(
dataset=task.dataset(subset), dataset=task.dataset(subset),
max_tokens=args.max_tokens, max_tokens=args.max_tokens_valid,
max_sentences=args.max_sentences_valid, max_sentences=args.max_sentences_valid,
max_positions=utils.resolve_max_positions( max_positions=utils.resolve_max_positions(
task.max_positions(), task.max_positions(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment