Commit 48631f7a authored by Myle Ott's avatar Myle Ott
Browse files

Allow --max-len-a to be a float

parent 813352e1
...@@ -85,11 +85,11 @@ def add_generation_args(parser): ...@@ -85,11 +85,11 @@ def add_generation_args(parser):
help='beam size') help='beam size')
group.add_argument('--nbest', default=1, type=int, metavar='N', group.add_argument('--nbest', default=1, type=int, metavar='N',
help='number of hypotheses to output') help='number of hypotheses to output')
group.add_argument('--max-len-a', default=0, type=int, metavar='N', group.add_argument('--max-len-a', default=0, type=float, metavar='N',
help=('generate sequence of maximum length ax + b, ' help=('generate sequences of maximum length ax + b, '
'where x is the source length')) 'where x is the source length'))
group.add_argument('--max-len-b', default=200, type=int, metavar='N', group.add_argument('--max-len-b', default=200, type=int, metavar='N',
help=('generate sequence of maximum length ax + b, ' help=('generate sequences of maximum length ax + b, '
'where x is the source length')) 'where x is the source length'))
group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None, group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
help='remove BPE tokens before scoring') help='remove BPE tokens before scoring')
......
...@@ -48,7 +48,7 @@ class SequenceGenerator(object): ...@@ -48,7 +48,7 @@ class SequenceGenerator(object):
self.positions = self.positions.cuda() self.positions = self.positions.cuda()
return self return self
def generate_batched_itr(self, data_itr, maxlen_a=0, maxlen_b=200, def generate_batched_itr(self, data_itr, maxlen_a=0.0, maxlen_b=200,
cuda_device=None, timer=None): cuda_device=None, timer=None):
"""Iterate over a batched dataset and yield individual translations. """Iterate over a batched dataset and yield individual translations.
...@@ -69,7 +69,7 @@ class SequenceGenerator(object): ...@@ -69,7 +69,7 @@ class SequenceGenerator(object):
if timer is not None: if timer is not None:
timer.start() timer.start()
hypos = self.generate(input['src_tokens'], input['src_positions'], hypos = self.generate(input['src_tokens'], input['src_positions'],
maxlen=(maxlen_a*srclen + maxlen_b)) maxlen=int(maxlen_a*srclen + maxlen_b))
if timer is not None: if timer is not None:
timer.stop(s['ntokens']) timer.stop(s['ntokens'])
for i, id in enumerate(s['id']): for i, id in enumerate(s['id']):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment