"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3fd31eef518b73ee592f82435f3d370a716ead4f"
Commit a091d239 authored by alexeib's avatar alexeib Committed by Myle Ott
Browse files

respect max tokens and ignore invalid inputs when evaluating lm

parent cc85d411
...@@ -38,10 +38,12 @@ def main(args): ...@@ -38,10 +38,12 @@ def main(args):
itr = data.EpochBatchIterator( itr = data.EpochBatchIterator(
dataset=task.dataset(args.gen_subset), dataset=task.dataset(args.gen_subset),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences or 4, max_sentences=args.max_sentences or 4,
max_positions=model.max_positions(), max_positions=model.max_positions(),
num_shards=args.num_shards, num_shards=args.num_shards,
shard_id=args.shard_id, shard_id=args.shard_id,
ignore_invalid_inputs=True,
).next_epoch_itr(shuffle=False) ).next_epoch_itr(shuffle=False)
gen_timer = StopwatchMeter() gen_timer = StopwatchMeter()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment