Commit 8c0ca1a0 authored by Myle Ott's avatar Myle Ott
Browse files

Diverse Beam Search

parent ba9f32cc
...@@ -305,6 +305,10 @@ def add_generation_args(parser): ...@@ -305,6 +305,10 @@ def add_generation_args(parser):
help='sample from top K likely next words instead of all words') help='sample from top K likely next words instead of all words')
group.add_argument('--sampling-temperature', default=1, type=float, metavar='N', group.add_argument('--sampling-temperature', default=1, type=float, metavar='N',
help='temperature for random sampling') help='temperature for random sampling')
group.add_argument('--diverse-beam-groups', default=1, type=int, metavar='N',
help='number of groups for Diverse Beam Search')
group.add_argument('--diverse-beam-strength', default=0.5, type=float, metavar='N',
help='strength of diversity penalty for Diverse Beam Search')
group.add_argument('--print-alignment', action='store_true', group.add_argument('--print-alignment', action='store_true',
help='if set, uses attention feedback to compute and print alignment to source tokens') help='if set, uses attention feedback to compute and print alignment to source tokens')
group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT', group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
......
...@@ -80,6 +80,69 @@ class BeamSearch(Search): ...@@ -80,6 +80,69 @@ class BeamSearch(Search):
return self.scores_buf, self.indices_buf, self.beams_buf return self.scores_buf, self.indices_buf, self.beams_buf
class DiverseBeamSearch(Search):
    """Diverse Beam Search.

    See "Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence
    Models" (Vijayakumar et al., 2016) for details.

    We only implement the Hamming Diversity penalty here, which performed best
    in the original paper.
    """

    def __init__(self, tgt_dict, num_groups, diversity_strength):
        super().__init__(tgt_dict)
        self.num_groups = num_groups
        # stored negated so the penalty can be applied with a single add below
        self.diversity_strength = -diversity_strength
        # lazily-allocated (bsz x vocab_size) buffer counting how many times
        # each token has been selected by earlier groups at the current step
        self.diversity_buf = None
        # each group is expanded with vanilla beam search
        self.beam = BeamSearch(tgt_dict)

    def step(self, step, lprobs, scores):
        """Take a single search step, penalizing tokens already chosen by
        other groups at this step.

        Args:
            step: current time step (0-based)
            lprobs: (bsz x beam_size x vocab_size) log-probabilities of the
                next token for each hypothesis
            scores: cumulative per-step scores of each hypothesis so far,
                sliced per group and forwarded to BeamSearch (None at step 0)

        Returns:
            (scores_buf, indices_buf, beams_buf), each (bsz x beam_size),
            with results from the groups interleaved along the beam dim.

        Raises:
            ValueError: if beam_size is not divisible by num_groups.
        """
        super()._init_buffers(lprobs)
        bsz, beam_size, vocab_size = lprobs.size()
        if beam_size % self.num_groups != 0:
            raise ValueError(
                'DiverseBeamSearch requires --beam to be divisible by the number of groups'
            )

        # (re-)initialize diversity penalty: zeroed at the start of every step
        if self.diversity_buf is None:
            self.diversity_buf = lprobs.new()
        torch.zeros(lprobs[:, 0, :].size(), out=self.diversity_buf)

        scores_G, indices_G, beams_G = [], [], []
        for g in range(self.num_groups):
            # hypotheses are assigned to groups round-robin along the beam dim
            lprobs_g = lprobs[:, g::self.num_groups, :]
            scores_g = scores[:, g::self.num_groups, :] if step > 0 else None

            # apply diversity penalty (diversity_strength is stored negated);
            # torch.add(input, other, alpha=...) replaces the deprecated
            # torch.add(input, alpha, other) positional form
            if g > 0:
                lprobs_g = torch.add(
                    lprobs_g,
                    self.diversity_buf.unsqueeze(1),
                    alpha=self.diversity_strength,
                )
            else:
                lprobs_g = lprobs_g.contiguous()

            scores_buf, indices_buf, beams_buf = self.beam.step(step, lprobs_g, scores_g)
            # map group-local beam indices back to global beam indices
            beams_buf.mul_(self.num_groups).add_(g)

            scores_G.append(scores_buf.clone())
            indices_G.append(indices_buf.clone())
            beams_G.append(beams_buf.clone())

            # update diversity penalty: count each token this group selected
            self.diversity_buf.scatter_add_(
                1,
                indices_buf,
                self.diversity_buf.new_ones(indices_buf.size())
            )

        # interleave results from the different groups
        self.scores_buf = torch.stack(scores_G, dim=2, out=self.scores_buf).view(bsz, -1)
        self.indices_buf = torch.stack(indices_G, dim=2, out=self.indices_buf).view(bsz, -1)
        self.beams_buf = torch.stack(beams_G, dim=2, out=self.beams_buf).view(bsz, -1)
        return self.scores_buf, self.indices_buf, self.beams_buf
class Sampling(Search): class Sampling(Search):
def __init__(self, tgt_dict, sampling_topk=-1, sampling_temperature=1.): def __init__(self, tgt_dict, sampling_topk=-1, sampling_temperature=1.):
......
...@@ -18,6 +18,7 @@ class SequenceGenerator(object): ...@@ -18,6 +18,7 @@ class SequenceGenerator(object):
self, models, tgt_dict, beam_size=1, minlen=1, maxlen=None, stop_early=True, self, models, tgt_dict, beam_size=1, minlen=1, maxlen=None, stop_early=True,
normalize_scores=True, len_penalty=1, unk_penalty=0, retain_dropout=False, normalize_scores=True, len_penalty=1, unk_penalty=0, retain_dropout=False,
sampling=False, sampling_topk=-1, sampling_temperature=1, sampling=False, sampling_topk=-1, sampling_temperature=1,
diverse_beam_groups=-1, diverse_beam_strength=0.5,
): ):
"""Generates translations of a given source sentence. """Generates translations of a given source sentence.
Args: Args:
...@@ -48,6 +49,8 @@ class SequenceGenerator(object): ...@@ -48,6 +49,8 @@ class SequenceGenerator(object):
if sampling: if sampling:
self.search = search.Sampling(tgt_dict, sampling_topk, sampling_temperature) self.search = search.Sampling(tgt_dict, sampling_topk, sampling_temperature)
elif diverse_beam_groups > 0:
self.search = search.DiverseBeamSearch(tgt_dict, diverse_beam_groups, diverse_beam_strength)
else: else:
self.search = search.BeamSearch(tgt_dict) self.search = search.BeamSearch(tgt_dict)
...@@ -402,6 +405,7 @@ class SequenceGenerator(object): ...@@ -402,6 +405,7 @@ class SequenceGenerator(object):
active_mask, k=beam_size, dim=1, largest=False, active_mask, k=beam_size, dim=1, largest=False,
out=(_ignore, active_hypos) out=(_ignore, active_hypos)
) )
active_bbsz_idx = buffer('active_bbsz_idx') active_bbsz_idx = buffer('active_bbsz_idx')
torch.gather( torch.gather(
cand_bbsz_idx, dim=1, index=active_hypos, cand_bbsz_idx, dim=1, index=active_hypos,
......
...@@ -71,11 +71,11 @@ def main(args): ...@@ -71,11 +71,11 @@ def main(args):
translator = SequenceScorer(models, task.target_dictionary) translator = SequenceScorer(models, task.target_dictionary)
else: else:
translator = SequenceGenerator( translator = SequenceGenerator(
models, task.target_dictionary, beam_size=args.beam, models, task.target_dictionary, beam_size=args.beam, minlen=args.min_len,
stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized), stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
len_penalty=args.lenpen, unk_penalty=args.unkpen, len_penalty=args.lenpen, unk_penalty=args.unkpen,
sampling=args.sampling, sampling_topk=args.sampling_topk, minlen=args.min_len, sampling=args.sampling, sampling_topk=args.sampling_topk, sampling_temperature=args.sampling_temperature,
sampling_temperature=args.sampling_temperature, diverse_beam_groups=args.diverse_beam_groups, diverse_beam_strength=args.diverse_beam_strength,
) )
if use_cuda: if use_cuda:
......
...@@ -90,10 +90,11 @@ def main(args): ...@@ -90,10 +90,11 @@ def main(args):
# Initialize generator # Initialize generator
translator = SequenceGenerator( translator = SequenceGenerator(
models, tgt_dict, beam_size=args.beam, stop_early=(not args.no_early_stop), models, tgt_dict, beam_size=args.beam, minlen=args.min_len,
normalize_scores=(not args.unnormalized), len_penalty=args.lenpen, stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
unk_penalty=args.unkpen, sampling=args.sampling, sampling_topk=args.sampling_topk, len_penalty=args.lenpen, unk_penalty=args.unkpen,
minlen=args.min_len, sampling_temperature=args.sampling_temperature sampling=args.sampling, sampling_topk=args.sampling_topk, sampling_temperature=args.sampling_temperature,
diverse_beam_groups=args.diverse_beam_groups, diverse_beam_strength=args.diverse_beam_strength,
) )
if use_cuda: if use_cuda:
......
...@@ -210,5 +210,104 @@ class TestSequenceGenerator(unittest.TestCase): ...@@ -210,5 +210,104 @@ class TestSequenceGenerator(unittest.TestCase):
self.assertEqual(t1.ne(t2).long().sum(), 0) self.assertEqual(t1.ne(t2).long().sum(), 0)
class TestDiverseBeamSearch(unittest.TestCase):

    def setUp(self):
        # build a dummy dictionary and sanity-check the special symbols
        dictionary = test_utils.dummy_dictionary(vocab_size=2)
        self.assertEqual(dictionary.pad(), 1)
        self.assertEqual(dictionary.eos(), 2)
        self.assertEqual(dictionary.unk(), 3)
        self.eos = dictionary.eos()
        self.w1, self.w2 = 4, 5

        # two identical source sentences
        self.src_tokens = torch.LongTensor([
            [self.w1, self.w2, self.eos],
            [self.w1, self.w2, self.eos],
        ])
        self.src_lengths = torch.LongTensor([2, 2])

        args = argparse.Namespace()
        unk = 0.
        # next-word distributions fed to the dummy model, one tensor per step;
        # columns are [eos, unk, w1, w2], rows are (sentence, beam) pairs
        args.beam_probs = [
            # step 0:
            torch.FloatTensor([
                # sentence 1:
                [0.0, unk, 0.9, 0.1],  # beam 1
                [0.0, unk, 0.9, 0.1],  # beam 2
                # sentence 2:
                [0.0, unk, 0.7, 0.3],
                [0.0, unk, 0.7, 0.3],
            ]),
            # step 1:
            torch.FloatTensor([
                # sentence 1:
                [0.0, unk, 0.6, 0.4],
                [0.0, unk, 0.6, 0.4],
                # sentence 2:
                [0.25, unk, 0.35, 0.4],
                [0.25, unk, 0.35, 0.4],
            ]),
            # step 2:
            torch.FloatTensor([
                # sentence 1:
                [1.0, unk, 0.0, 0.0],
                [1.0, unk, 0.0, 0.0],
                # sentence 2:
                [0.9, unk, 0.1, 0.0],
                [0.9, unk, 0.1, 0.0],
            ]),
        ]

        task = test_utils.TestTranslationTask.setup_task(args, dictionary, dictionary)
        self.model = task.build_model(args)
        self.tgt_dict = task.target_dictionary

    def test_diverse_beam_search(self):
        generator = SequenceGenerator(
            [self.model], self.tgt_dict,
            beam_size=2, diverse_beam_groups=2, diverse_beam_strength=0.,
        )
        hypos = generator.generate(self.src_tokens, self.src_lengths)
        eos, w1, w2 = self.eos, self.w1, self.w2
        # with diversity strength 0, both beams in each sentence converge on
        # the same hypothesis
        expected = [
            # sentence 1:
            [([w1, w1, eos], [0.9, 0.6, 1.0]),
             ([w1, w1, eos], [0.9, 0.6, 1.0])],
            # sentence 2:
            [([w1, w2, eos], [0.7, 0.4, 0.9]),
             ([w1, w2, eos], [0.7, 0.4, 0.9])],
        ]
        for sent, beams in enumerate(expected):
            for beam, (tokens, probs) in enumerate(beams):
                self.assertHypoTokens(hypos[sent][beam], tokens)
                self.assertHypoScore(hypos[sent][beam], probs)

    def assertHypoTokens(self, hypo, tokens):
        # compare the hypothesis token sequence against the expected one
        self.assertTensorEqual(hypo['tokens'], torch.LongTensor(tokens))

    def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.):
        # check positional scores and the (optionally length-normalized) total
        expected = torch.FloatTensor(pos_probs).log()
        self.assertAlmostEqual(hypo['positional_scores'], expected)
        self.assertEqual(expected.numel(), hypo['tokens'].numel())
        total = expected.sum()
        if normalized:
            total /= expected.numel() ** lenpen
        self.assertLess(abs(total - hypo['score']), 1e-6)

    def assertAlmostEqual(self, t1, t2):
        # elementwise float comparison with tolerance 1e-4
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertLess((t1 - t2).abs().max(), 1e-4)

    def assertTensorEqual(self, t1, t2):
        # exact elementwise equality
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertEqual(t1.ne(t2).long().sum(), 0)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment