Commit bfeb7732 authored by Stephen Roller, committed by Myle Ott

Pass encoder_input to generator, rather than src_tokens/src_lengths.

parent 8bd8ec8f
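
In short, the commit converts SequenceGenerator's public API from two positional tensors to a single dict whose keys match the keyword arguments of model.encoder.forward. A minimal before/after sketch of a call site (translator, tokens, and lengths are hypothetical stand-ins for a SequenceGenerator instance and its batch tensors):

    # before this commit: positional source tensors
    hypos = translator.generate(tokens, lengths, beam_size=5)

    # after: a single dict, later unpacked into model.encoder(**encoder_input)
    encoder_input = {'src_tokens': tokens, 'src_lengths': lengths}
    hypos = translator.generate(encoder_input, beam_size=5)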
@@ -78,13 +78,18 @@ class SequenceGenerator(object):
             if 'net_input' not in s:
                 continue
             input = s['net_input']
-            srclen = input['src_tokens'].size(1)
+            # model.forward normally channels prev_output_tokens into the decoder
+            # separately, but SequenceGenerator directly calls model.encoder
+            encoder_input = {
+                k: v for k, v in input.items()
+                if k != 'prev_output_tokens'
+            }
+            srclen = encoder_input['src_tokens'].size(1)
             if timer is not None:
                 timer.start()
             with torch.no_grad():
                 hypos = self.generate(
-                    input['src_tokens'],
-                    input['src_lengths'],
+                    encoder_input,
                     beam_size=beam_size,
                     maxlen=int(maxlen_a*srclen + maxlen_b),
                     prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
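
The dict comprehension above is needed because net_input also carries prev_output_tokens, the shifted target sequence that model.forward feeds to the decoder for teacher forcing; the encoder does not accept it, so it must be stripped before the dict is unpacked into model.encoder. A rough sketch under that assumption (tensor names are illustrative):

    # typical contents of s['net_input'] in fairseq at this point
    net_input = {
        'src_tokens': src_tokens,           # (bsz, srclen) padded source IDs
        'src_lengths': src_lengths,         # (bsz,) unpadded source lengths
        'prev_output_tokens': prev_tokens,  # decoder-side input; invalid for the encoder
    }
    encoder_input = {k: v for k, v in net_input.items() if k != 'prev_output_tokens'}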
@@ -97,12 +102,23 @@ class SequenceGenerator(object):
                 ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
                 yield id, src, ref, hypos[i]
 
-    def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
-        """Generate a batch of translations."""
+    def generate(self, encoder_input, beam_size=None, maxlen=None, prefix_tokens=None):
+        """Generate a batch of translations.
+
+        Args:
+            encoder_input: dictionary containing the inputs to
+                model.encoder.forward
+            beam_size: int overriding the beam size. defaults to
+                self.beam_size
+            max_len: maximum length of the generated sequence
+            prefix_tokens: force decoder to begin with these tokens
+        """
         with torch.no_grad():
-            return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)
+            return self._generate(encoder_input, beam_size, maxlen, prefix_tokens)
 
-    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
+    def _generate(self, encoder_input, beam_size=None, maxlen=None, prefix_tokens=None):
+        """See generate"""
+        src_tokens = encoder_input['src_tokens']
         bsz, srclen = src_tokens.size()
         maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
 
@@ -121,10 +137,10 @@ class SequenceGenerator(object):
                 incremental_states[model] = None
 
             # compute the encoder output for each beam
-            encoder_out = model.encoder(
-                src_tokens.repeat(1, beam_size).view(-1, srclen),
-                src_lengths.expand(beam_size, src_lengths.numel()).t().contiguous().view(-1),
-            )
+            encoder_out = model.encoder(**encoder_input)
+            new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
+            new_order = new_order.to(src_tokens.device)
+            encoder_out = model.encoder.reorder_encoder_out(encoder_out, new_order)
             encoder_outs.append(encoder_out)
 
         # initialize buffers
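
Note that the encoder now runs once per batch instead of on beam_size tiled copies of src_tokens; reorder_encoder_out then replicates each sentence's encoder output per beam. The index tensor interleaves each sentence index beam_size times, which a standalone snippet can verify:

    import torch

    bsz, beam_size = 2, 3
    new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
    print(new_order)  # tensor([0, 0, 0, 1, 1, 1]): beam_size copies per sentence

This is the same row order the removed code produced by tiling src_tokens directly via repeat(1, beam_size).view(-1, srclen).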
...
@@ -145,9 +145,9 @@ def main(args):
                 tokens = tokens.cuda()
                 lengths = lengths.cuda()
 
+            encoder_input = {'src_tokens': tokens, 'src_lengths': lengths}
             translations = translator.generate(
-                tokens,
-                lengths,
+                encoder_input,
                 maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b),
             )
...
@@ -33,6 +33,10 @@ class TestSequenceGenerator(unittest.TestCase):
             [self.w1, self.w2, self.eos],
         ])
         self.src_lengths = torch.LongTensor([2, 2])
+        self.encoder_input = {
+            'src_tokens': self.src_tokens,
+            'src_lengths': self.src_lengths,
+        }
 
         args = argparse.Namespace()
         unk = 0.
@@ -85,7 +89,7 @@ class TestSequenceGenerator(unittest.TestCase):
 
     def test_with_normalization(self):
         generator = SequenceGenerator([self.model], self.tgt_dict)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(self.encoder_input, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])
@@ -104,7 +108,7 @@ class TestSequenceGenerator(unittest.TestCase):
         # Sentence 1: unchanged from the normalized case
         # Sentence 2: beams swap order
         generator = SequenceGenerator([self.model], self.tgt_dict, normalize_scores=False)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(self.encoder_input, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])
@@ -122,7 +126,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_with_lenpen_favoring_short_hypos(self):
         lenpen = 0.6
         generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(self.encoder_input, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])
@@ -140,7 +144,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_with_lenpen_favoring_long_hypos(self):
         lenpen = 5.0
         generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(self.encoder_input, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w2, w1, w2, eos])
@@ -157,7 +161,7 @@ class TestSequenceGenerator(unittest.TestCase):
 
     def test_maxlen(self):
         generator = SequenceGenerator([self.model], self.tgt_dict, maxlen=2)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(self.encoder_input, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])
@@ -174,7 +178,7 @@ class TestSequenceGenerator(unittest.TestCase):
 
     def test_no_stop_early(self):
         generator = SequenceGenerator([self.model], self.tgt_dict, stop_early=False)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(self.encoder_input, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])
@@ -273,7 +277,8 @@ class TestDiverseBeamSearch(unittest.TestCase):
             [self.model], self.tgt_dict,
             beam_size=2, diverse_beam_groups=2, diverse_beam_strength=0.,
         )
-        hypos = generator.generate(self.src_tokens, self.src_lengths)
+        encoder_input = {'src_tokens': self.src_tokens, 'src_lengths': self.src_lengths}
+        hypos = generator.generate(encoder_input)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, w1, eos])
...