Commit 0714080b authored by Stephen Roller's avatar Stephen Roller Committed by Myle Ott
Browse files

Sequence generator bug fix.

parent e6d45d5c
...@@ -84,9 +84,9 @@ class SequenceGenerator(object): ...@@ -84,9 +84,9 @@ class SequenceGenerator(object):
with torch.no_grad(): with torch.no_grad():
hypos = self.generate( hypos = self.generate(
beam_size=beam_size, beam_size=beam_size,
maxlen=int(maxlen_a*srclen + maxlen_b maxlen=int(maxlen_a*srclen + maxlen_b),
**net_input),
prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None, prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
**net_input,
) )
if timer is not None: if timer is not None:
timer.stop(sum(len(h[0]['tokens']) for h in hypos)) timer.stop(sum(len(h[0]['tokens']) for h in hypos))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment