"docs/vscode:/vscode.git/clone" did not exist on "fadb5ae5f73ab67bdffaac096da1f217ee80fb80"
Commit e6d45d5c authored by Stephen Roller, committed by Myle Ott

Generator: net_input instead of manual src_tokens.

parent 25524f19
@@ -83,10 +83,9 @@ class SequenceGenerator(object):
                 timer.start()
             with torch.no_grad():
                 hypos = self.generate(
-                    input['src_tokens'],
-                    input['src_lengths'],
                     beam_size=beam_size,
                     maxlen=int(maxlen_a*srclen + maxlen_b),
+                    **net_input,
                     prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
                 )
             if timer is not None:
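The maxlen passed here caps hypothesis length as a linear function of the source length. A quick worked example of the formula, with illustrative coefficients rather than fairseq defaults:

    maxlen_a, maxlen_b = 0.5, 10                # illustrative values, not fairseq defaults
    srclen = 20                                 # source length in tokens
    maxlen = int(maxlen_a * srclen + maxlen_b)  # int(0.5 * 20 + 10)
    print(maxlen)                               # 20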
@@ -97,12 +96,13 @@ class SequenceGenerator(object):
                 ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
                 yield id, src, ref, hypos[i]

-    def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
+    def generate(self, beam_size=None, maxlen=None, prefix_tokens=None, **net_input):
         """Generate a batch of translations."""
         with torch.no_grad():
-            return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)
+            return self._generate(beam_size, maxlen, prefix_tokens, **net_input)

-    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
+    def _generate(self, beam_size=None, maxlen=None, prefix_tokens=None, **net_input):
+        src_tokens = net_input['src_tokens']
         bsz, srclen = src_tokens.size()
         maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
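For readers skimming the signature change: **net_input in a def collects whatever keyword arguments the caller supplies (here, the contents of the batch's net_input dict), while **net_input at a call site unpacks such a dict back into keyword arguments. A minimal, self-contained sketch of the round trip, with made-up tensors in place of a real fairseq batch:

    import torch

    def generate(beam_size=None, maxlen=None, prefix_tokens=None, **net_input):
        # model-specific inputs arrive through the **net_input catch-all
        src_tokens = net_input['src_tokens']
        bsz, srclen = src_tokens.size()
        return bsz, srclen, beam_size

    net_input = {
        'src_tokens': torch.ones(2, 5, dtype=torch.long),  # bsz x srclen
        'src_lengths': torch.tensor([5, 3]),               # per-sentence lengths
    }
    print(generate(beam_size=2, maxlen=10, **net_input))   # (2, 5, 2)

One consequence of this design is that the generator no longer hard-codes which inputs the model consumes; any key in net_input is forwarded untouched.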
@@ -121,10 +121,10 @@ class SequenceGenerator(object):
                 incremental_states[model] = None

             # compute the encoder output for each beam
-            encoder_out = model.encoder(
-                src_tokens.repeat(1, beam_size).view(-1, srclen),
-                src_lengths.expand(beam_size, src_lengths.numel()).t().contiguous().view(-1),
-            )
+            encoder_out = model.encoder(**net_input)
+            new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
+            new_order = new_order.to(net_input['src_tokens'].device)
+            encoder_out = model.encoder.reorder_encoder_out(encoder_out, new_order)
             encoder_outs.append(encoder_out)

         # initialize buffers
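The encoder change is the substantive part of the hunk above: previously the source tokens were tiled beam_size times before encoding, so the encoder processed bsz * beam_size rows; now the encoder runs once on the original batch and its output is replicated per beam through reorder_encoder_out. That method is model-specific and not shown in this diff, so the sketch below uses index_select as a stand-in to check that new_order repeats rows exactly the way the old repeat/view expansion did:

    import torch

    bsz, srclen, beam_size = 3, 4, 2
    src_tokens = torch.arange(bsz * srclen).view(bsz, srclen)

    # old approach: tile the tokens themselves, one copy per beam
    old = src_tokens.repeat(1, beam_size).view(-1, srclen)

    # new approach: build an index that repeats each sentence id beam_size times
    new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
    print(new_order)  # tensor([0, 0, 1, 1, 2, 2])

    # gather rows with the index (a stand-in for reorder_encoder_out)
    assert torch.equal(old, src_tokens.index_select(0, new_order))

Both layouts interleave copies per sentence ([s0, s0, s1, s1, ...]), so the downstream beam bookkeeping is unchanged.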
@@ -85,7 +85,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_with_normalization(self):
         generator = SequenceGenerator([self.model], self.tgt_dict)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(src_tokens=self.src_tokens, src_lengths=self.src_lengths, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])
@@ -104,7 +104,7 @@ class TestSequenceGenerator(unittest.TestCase):
         # Sentence 1: unchanged from the normalized case
         # Sentence 2: beams swap order
         generator = SequenceGenerator([self.model], self.tgt_dict, normalize_scores=False)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(src_tokens=self.src_tokens, src_lengths=self.src_lengths, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])
@@ -122,7 +122,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_with_lenpen_favoring_short_hypos(self):
         lenpen = 0.6
         generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(src_tokens=self.src_tokens, src_lengths=self.src_lengths, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])
@@ -140,7 +140,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_with_lenpen_favoring_long_hypos(self):
         lenpen = 5.0
         generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(src_tokens=self.src_tokens, src_lengths=self.src_lengths, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w2, w1, w2, eos])
@@ -157,7 +157,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_maxlen(self):
         generator = SequenceGenerator([self.model], self.tgt_dict, maxlen=2)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(src_tokens=self.src_tokens, src_lengths=self.src_lengths, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])
@@ -174,7 +174,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_no_stop_early(self):
         generator = SequenceGenerator([self.model], self.tgt_dict, stop_early=False)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(src_tokens=self.src_tokens, src_lengths=self.src_lengths, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])
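With src_tokens and src_lengths now routed through **net_input, the tests have to pass them by keyword: the old positional style would bind self.src_tokens to beam_size. A toy stand-in (not the real SequenceGenerator) showing the failure mode:

    class Gen:
        def generate(self, beam_size=None, maxlen=None, prefix_tokens=None, **net_input):
            return beam_size, sorted(net_input)

    g = Gen()
    print(g.generate(src_tokens=[4, 5, 2], src_lengths=[3], beam_size=2))
    # (2, ['src_lengths', 'src_tokens'])

    # the pre-change call style now misbinds its positional arguments:
    # g.generate([4, 5, 2], [3], beam_size=2)
    # TypeError: generate() got multiple values for argument 'beam_size'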