Commit a091d239 authored by alexeib's avatar alexeib Committed by Myle Ott
Browse files

respect max tokens and ignore invalid inputs when evaluating lm

parent cc85d411
...@@ -38,10 +38,12 @@ def main(args): ...@@ -38,10 +38,12 @@ def main(args):
itr = data.EpochBatchIterator( itr = data.EpochBatchIterator(
dataset=task.dataset(args.gen_subset), dataset=task.dataset(args.gen_subset),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences or 4, max_sentences=args.max_sentences or 4,
max_positions=model.max_positions(), max_positions=model.max_positions(),
num_shards=args.num_shards, num_shards=args.num_shards,
shard_id=args.shard_id, shard_id=args.shard_id,
ignore_invalid_inputs=True,
).next_epoch_itr(shuffle=False) ).next_epoch_itr(shuffle=False)
gen_timer = StopwatchMeter() gen_timer = StopwatchMeter()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment