Commit 67af40c9 authored by Alexei Baevski, committed by Myle Ott

allow specifying max_tokens for generation

parent a5e49364
@@ -106,7 +106,7 @@ def add_dataset_args(parser, train=False, gen=False):
                        help='max number of tokens in the target sequence')
     group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
                        help='Ignore too long or too short lines in valid and test set')
-    group.add_argument('--max-tokens', default=6000, type=int, metavar='N',
+    group.add_argument('--max-tokens', type=int, metavar='N',
                        help='maximum number of tokens in a batch')
     group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
                        help='maximum number of sentences in a batch')
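Dropping default=6000 from the shared dataset arguments means --max-tokens parses to None when the flag is not given, so each entry point can choose its own fallback. A minimal argparse sketch of that behavior (a standalone illustration, not fairseq's actual options module):

import argparse

parser = argparse.ArgumentParser()
# No `default=` here: an unspecified --max-tokens stays None,
# leaving the fallback decision to whichever script parses the args.
parser.add_argument('--max-tokens', type=int, metavar='N',
                    help='maximum number of tokens in a batch')
parser.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
                    help='maximum number of sentences in a batch')

args = parser.parse_args([])               # no batching flags given
assert args.max_tokens is None             # caller picks a default later

args = parser.parse_args(['--max-tokens', '4000'])
assert args.max_tokens == 4000             # an explicit value always wins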
@@ -16,6 +16,10 @@ from fairseq.sequence_scorer import SequenceScorer
 def main(args):
     assert args.path is not None, '--path required for generation!'
+    if args.max_tokens is None and args.max_sentences is None:
+        args.max_tokens = 12000
     print(args)
     assert not args.sampling or args.nbest == args.beam, \
         '--sampling requires --nbest to be equal to --beam'
@@ -58,12 +62,13 @@ def main(args):
     # Load alignment dictionary for unknown word replacement
     # (None if no unknown word replacement, empty if no path to align dictionary)
     align_dict = utils.load_align_dict(args.replace_unk)

     # Load dataset (possibly sharded)
     max_positions = min(model.max_encoder_positions() for model in models)
     itr = dataset.eval_dataloader(
         args.gen_subset,
-        max_sentences=args.max_sentences or 128,
+        max_tokens=args.max_tokens,
+        max_sentences=args.max_sentences,
         max_positions=max_positions,
         skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
     )
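With max_tokens now forwarded to eval_dataloader, generation batches can be capped by total token count instead of only by a fixed sentence count. The sketch below shows the general idea of token-budget batching; batch_by_size here is a hypothetical helper for illustration, not the dataloader's actual implementation:

def batch_by_size(lengths, max_tokens=None, max_sentences=None):
    """Group example indices into batches, capping total tokens and/or
    sentence count. Assumes lengths[i] is the token length of example i."""
    batches, batch, batch_tokens = [], [], 0
    for idx, num_tokens in enumerate(lengths):
        over_tokens = max_tokens is not None and batch_tokens + num_tokens > max_tokens
        over_sents = max_sentences is not None and len(batch) == max_sentences
        if batch and (over_tokens or over_sents):
            batches.append(batch)          # close the current batch
            batch, batch_tokens = [], 0
        batch.append(idx)
        batch_tokens += num_tokens
    if batch:
        batches.append(batch)
    return batches

# e.g. batch_by_size([900, 300, 500, 700], max_tokens=1200) -> [[0, 1], [2, 3]]

A token cap keeps per-batch memory roughly constant regardless of sentence length; the hunk above falls back to 12000 tokens only when the user supplied neither cap.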
@@ -18,6 +18,10 @@ from fairseq.meters import AverageMeter, StopwatchMeter
 def main(args):
+    if args.max_tokens is None:
+        args.max_tokens = 6000
     print(args)
     if not torch.cuda.is_available():
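Taken together, the entry points now resolve their own batching defaults: training falls back to 6000 tokens when --max-tokens is unset, while generation falls back to 12000 tokens only if neither --max-tokens nor --max-sentences was given. A condensed sketch of that logic (a hypothetical helper written for illustration, not code from this commit):

def resolve_batching_defaults(args, generation=False):
    """Mirror the per-script fallbacks introduced in this commit (illustrative only)."""
    if generation:
        # generation: only default the token budget if the user set no cap at all
        if args.max_tokens is None and args.max_sentences is None:
            args.max_tokens = 12000
    else:
        # training: keep the previous 6000-token default unless overridden
        if args.max_tokens is None:
            args.max_tokens = 6000
    return args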