Commit 48631f7a authored by Myle Ott's avatar Myle Ott
Browse files

Allow --max-len-a to be a float

parent 813352e1
......@@ -85,11 +85,11 @@ def add_generation_args(parser):
help='beam size')
group.add_argument('--nbest', default=1, type=int, metavar='N',
help='number of hypotheses to output')
group.add_argument('--max-len-a', default=0, type=int, metavar='N',
help=('generate sequence of maximum length ax + b, '
group.add_argument('--max-len-a', default=0, type=float, metavar='N',
help=('generate sequences of maximum length ax + b, '
'where x is the source length'))
group.add_argument('--max-len-b', default=200, type=int, metavar='N',
help=('generate sequence of maximum length ax + b, '
help=('generate sequences of maximum length ax + b, '
'where x is the source length'))
group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
help='remove BPE tokens before scoring')
......
......@@ -48,7 +48,7 @@ class SequenceGenerator(object):
self.positions = self.positions.cuda()
return self
def generate_batched_itr(self, data_itr, maxlen_a=0, maxlen_b=200,
def generate_batched_itr(self, data_itr, maxlen_a=0.0, maxlen_b=200,
cuda_device=None, timer=None):
"""Iterate over a batched dataset and yield individual translations.
......@@ -69,7 +69,7 @@ class SequenceGenerator(object):
if timer is not None:
timer.start()
hypos = self.generate(input['src_tokens'], input['src_positions'],
maxlen=(maxlen_a*srclen + maxlen_b))
maxlen=int(maxlen_a*srclen + maxlen_b))
if timer is not None:
timer.stop(s['ntokens'])
for i, id in enumerate(s['id']):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment