Commit e46b924d authored by Xing Zhou, committed by Facebook Github Bot

Nucleus (top-P) sampling (#710)

Summary:
Implement Nucleus (top-P) sampling: sample among the smallest set of elements whose cumulative probability mass exceeds p.
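For reference, the selection rule can be sketched on a single 1-D probability vector as follows (an illustration only, not the code added by this commit; the helper name nucleus_sample is hypothetical):

    import torch

    def nucleus_sample(probs, p):
        # Sample an index from the smallest set of entries whose cumulative
        # probability mass exceeds p (nucleus / top-P sampling).
        sorted_probs, sorted_indices = probs.sort(descending=True)
        cumsum = sorted_probs.cumsum(dim=-1)
        # keep every entry whose preceding cumulative mass is still below p,
        # so the kept set is the smallest one whose total mass exceeds p
        keep = (cumsum - sorted_probs) < p
        kept_probs = sorted_probs[keep]
        choice = torch.multinomial(kept_probs, 1)
        return sorted_indices[keep][choice].item()

    # e.g. for probs = torch.tensor([0.5, 0.3, 0.15, 0.05]) and p = 0.7, only
    # the first two entries (total mass 0.8) are candidates.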

To test it:
python generate.py   ~myleott/data/data-bin/wmt17_zh_en_full/   --path ~myleott/zh_en/model.pt   --remove-bpe   --nbest 5   --beam 5 --sampling --sampling-topp 0.3
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/710

Test Plan:
python generate.py   ~myleott/data/data-bin/wmt17_zh_en_full/   --path ~myleott/zh_en/model.pt   --remove-bpe   --nbest 5   --beam 5 --sampling --sampling-topp 0.3

python tests/test_sequence_generator.py

python tests/test_binaries.py

Reviewed By: myleott

Differential Revision: D16286688

Pulled By: xingz9

fbshipit-source-id: 1776d21e17c4532a3d24ac75bb7e75da9acad58f
parent 473389a3
@@ -472,6 +472,8 @@ def add_generation_args(parser):
help='sample hypotheses instead of using beam search')
group.add_argument('--sampling-topk', default=-1, type=int, metavar='PS',
help='sample from top K likely next words instead of all words')
group.add_argument('--sampling-topp', default=-1.0, type=float, metavar='PS',
help='sample from the smallest set whose cumulative probability mass exceeds p for next words')
group.add_argument('--temperature', default=1., type=float, metavar='N',
help='temperature for generation')
group.add_argument('--diverse-beam-groups', default=-1, type=int, metavar='N',
...
@@ -168,9 +168,54 @@ class DiverseBeamSearch(Search):
class Sampling(Search):
def __init__(self, tgt_dict, sampling_topk=-1, sampling_topp=-1.0):
super().__init__(tgt_dict)
self.sampling_topk = sampling_topk
self.sampling_topp = sampling_topp
def _sample_topp(self, lprobs):
"""Sample among the smallest set of elements whose cumulative probability mass exceeds p.
See `"The Curious Case of Neural Text Degeneration"
(Holtzman et al., 2019) <https://arxiv.org/abs/1904.09751>`_.
Args:
lprobs: (bsz x input_beam_size x vocab_size)
the model's log-probabilities over the vocabulary at the current step
Returns: A tuple of (trimmed_probs, truncated_indices) where:
trimmed_probs: (bsz x input_beam_size x ?)
the model's probabilities over the elements selected to sample from. The
width of the third dimension is determined by top-P.
truncated_indices: (bsz x input_beam_size x ?)
the indices of the chosen elements.
"""
probs = lprobs.exp_()
# sort the last dimension (vocab dimension) in descending order
sorted_probs, sorted_indices = probs.sort(descending=True)
# compute a mask to indicate the words to be included in the top-P set.
cumsum_probs = sorted_probs.cumsum(dim=2)
mask = cumsum_probs.lt(self.sampling_topp)
# note that mask was computed by 'lt'. One more word needs to be included
# so that the cumulative probability mass can exceed p.
cumsum_mask = mask.cumsum(dim=2)
last_included = cumsum_mask[:, :, -1:]
last_included.clamp_(0, mask.size()[2] - 1)
mask = mask.scatter_(2, last_included, 1)
# truncate unnecessary dims.
max_dim = last_included.max()
truncated_mask = mask[:, :, :max_dim + 1]
truncated_probs = sorted_probs[:, :, :max_dim + 1]
truncated_indices = sorted_indices[:, :, :max_dim + 1]
# trim the words that are not in top-P by setting their probabilities
# to 0, so that they would not be sampled later.
trim_mask = 1 - truncated_mask
trimmed_probs = truncated_probs.masked_fill_(trim_mask, 0)
return trimmed_probs, truncated_indices
def step(self, step, lprobs, scores):
super()._init_buffers(lprobs)
@@ -185,12 +230,17 @@ class Sampling(Search):
assert self.pad <= 1, 'sampling assumes the first two symbols can be ignored'
lprobs_nopad = lprobs[:, :, 2:]
if self.sampling_topp > 0:
# only sample from the smallest set of words whose cumulative probability mass exceeds p
probs_nopad, top_indices = self._sample_topp(lprobs_nopad)
elif self.sampling_topk > 0:
# only sample from top-k candidates
lprobs_nopad, top_indices = lprobs_nopad.topk(self.sampling_topk)
probs_nopad = lprobs_nopad.exp_()
else:
probs_nopad = lprobs_nopad.exp_()
# sample
if step == 0:
self.indices_buf = torch.multinomial(
probs_nopad.view(bsz, -1),
@@ -219,10 +269,10 @@ class Sampling(Search):
)
self.scores_buf = self.scores_buf.log_().view(bsz, -1)
# remap indices if using top-k or top-P sampling
if self.sampling_topk > 0 or self.sampling_topp > 0:
self.indices_buf = torch.gather(
top_indices.expand(bsz, beam_size, -1),
dim=2,
index=self.indices_buf.unsqueeze(-1),
).squeeze(2)
...
@@ -28,6 +28,7 @@ class SequenceGenerator(object):
retain_dropout=False,
sampling=False,
sampling_topk=-1,
sampling_topp=-1.0,
temperature=1.,
diverse_beam_groups=-1,
diverse_beam_strength=0.5,
@@ -58,6 +59,9 @@ class SequenceGenerator(object):
(default: False)
sampling_topk (int, optional): only sample among the top-k choices
at each step (default: -1)
sampling_topp (float, optional): only sample among the smallest set
of words whose cumulative probability mass exceeds p
at each step (default: -1.0)
temperature (float, optional): temperature, where values
>1.0 produce more uniform samples and values <1.0 produce
sharper samples (default: 1.0)
@@ -86,10 +90,11 @@ class SequenceGenerator(object):
self.no_repeat_ngram_size = no_repeat_ngram_size
assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'
assert sampling_topp < 0 or sampling, '--sampling-topp requires --sampling'
assert temperature > 0, '--temperature must be greater than 0'
if sampling:
self.search = search.Sampling(tgt_dict, sampling_topk, sampling_topp)
elif diverse_beam_groups > 0:
self.search = search.DiverseBeamSearch(tgt_dict, diverse_beam_groups, diverse_beam_strength)
elif match_source_len:
...
@@ -201,6 +201,7 @@ class FairseqTask(object):
unk_penalty=getattr(args, 'unkpen', 0),
sampling=getattr(args, 'sampling', False),
sampling_topk=getattr(args, 'sampling_topk', -1),
sampling_topp=getattr(args, 'sampling_topp', -1.0),
temperature=getattr(args, 'temperature', 1.),
diverse_beam_groups=getattr(args, 'diverse_beam_groups', -1),
diverse_beam_strength=getattr(args, 'diverse_beam_strength', 0.5),
...
@@ -104,6 +104,12 @@ class TestTranslation(unittest.TestCase):
'--beam', '2',
'--nbest', '2',
])
generate_main(data_dir, [
'--sampling',
'--sampling-topp', '0.2',
'--beam', '2',
'--nbest', '2',
])
generate_main(data_dir, ['--prefix-size', '2'])
def test_lstm(self):
...
@@ -15,7 +15,30 @@ from fairseq.sequence_generator import SequenceGenerator
import tests.utils as test_utils
class TestSequenceGeneratorBase(unittest.TestCase):
def assertHypoTokens(self, hypo, tokens):
self.assertTensorEqual(hypo['tokens'], torch.LongTensor(tokens))
def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.):
pos_scores = torch.FloatTensor(pos_probs).log()
self.assertAlmostEqual(hypo['positional_scores'], pos_scores)
self.assertEqual(pos_scores.numel(), hypo['tokens'].numel())
score = pos_scores.sum()
if normalized:
score /= pos_scores.numel()**lenpen
self.assertLess(abs(score - hypo['score']), 1e-6)
def assertAlmostEqual(self, t1, t2):
self.assertEqual(t1.size(), t2.size(), "size mismatch")
self.assertLess((t1 - t2).abs().max(), 1e-4)
def assertTensorEqual(self, t1, t2):
self.assertEqual(t1.size(), t2.size(), "size mismatch")
self.assertEqual(t1.ne(t2).long().sum(), 0)
class TestSequenceGenerator(TestSequenceGeneratorBase):
def setUp(self):
self.tgt_dict, self.w1, self.w2, src_tokens, src_lengths, self.model = (
@@ -133,28 +156,8 @@ class TestSequenceGenerator(unittest.TestCase):
self.assertHypoTokens(hypos[1][1], [w1, w2, w1, eos])
self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.4, 1.0])
class TestDiverseBeamSearch(TestSequenceGeneratorBase):
def setUp(self):
# construct dummy dictionary
@@ -232,25 +235,156 @@ class TestDiverseBeamSearch(unittest.TestCase):
self.assertHypoTokens(hypos[1][1], [w1, w2, eos])
self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.9])
class TestTopPSamplingSearch(TestSequenceGeneratorBase):
def setUp(self):
# construct dummy dictionary
d = test_utils.dummy_dictionary(vocab_size=2)
self.assertEqual(d.pad(), 1)
self.assertEqual(d.eos(), 2)
self.assertEqual(d.unk(), 3)
self.eos = d.eos()
self.w1 = 4
self.w2 = 5
# construct source data
self.src_tokens = torch.LongTensor([
[self.w1, self.w2, self.eos],
[self.w1, self.w2, self.eos],
])
self.src_lengths = torch.LongTensor([2, 2])
args = argparse.Namespace()
unk = 0.
# Cumulative probability mass of the top 2 tokens.
self.min_top2_prob = 0.75
# Probability mass of the top 1 token.
self.min_top1_prob = 0.4
w1_prob = self.min_top1_prob
w2_prob = self.min_top2_prob - self.min_top1_prob
eos_prob = 1 - self.min_top2_prob
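# With these values, step 1 assigns w1_prob=0.4, w2_prob=0.35 and eos_prob=0.25,
# so a sampling_topp value below 0.4 keeps only w1, while a value between 0.4
# and 0.75 keeps w1 and w2 (the two regimes exercised by the tests below).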
args.beam_probs = [
# step 0:
torch.FloatTensor([
# eos w1 w2
[0.0, unk, 1.0, 0.0],
[0.0, unk, 1.0, 0.0],
[0.0, unk, 1.0, 0.0],
[0.0, unk, 1.0, 0.0],
]),
# step 1:
torch.FloatTensor([
# eos w1 w2
[eos_prob, unk, w1_prob, w2_prob],
[eos_prob, unk, w1_prob, w2_prob],
[eos_prob, unk, w1_prob, w2_prob],
[eos_prob, unk, w1_prob, w2_prob],
]),
# step 2:
torch.FloatTensor([
# eos w1 w2
[1.0, unk, 0.0, 0.0],
[1.0, unk, 0.0, 0.0],
[1.0, unk, 0.0, 0.0],
[1.0, unk, 0.0, 0.0],
]),
]
task = test_utils.TestTranslationTask.setup_task(args, d, d)
self.model = task.build_model(args)
self.tgt_dict = task.target_dictionary
def test_topp_sampling_search_low_prob(self):
# Given a p value low enough for top-P sampling, we expect only the
# top-1 token to be sampled, which always results in the same output.
low_sampling_topp = self.min_top1_prob/2.0
generator = SequenceGenerator(
self.tgt_dict, beam_size=2, sampling=True,
sampling_topp=low_sampling_topp
)
sample = {
'net_input': {
'src_tokens': self.src_tokens,
'src_lengths': self.src_lengths
}
}
hypos = generator.generate([self.model], sample)
eos, w1 = self.eos, self.w1
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, w1, eos])
self.assertHypoScore(hypos[0][0], [1.0, 0.4, 1.0])
# sentence 1, beam 2
self.assertHypoTokens(hypos[0][1], [w1, w1, eos])
self.assertHypoScore(hypos[0][1], [1.0, 0.4, 1.0])
# sentence 2, beam 1
self.assertHypoTokens(hypos[1][0], [w1, w1, eos])
self.assertHypoScore(hypos[1][0], [1.0, 0.4, 1.0])
# sentence 2, beam 2
self.assertHypoTokens(hypos[1][1], [w1, w1, eos])
self.assertHypoScore(hypos[1][1], [1.0, 0.4, 1.0])
def test_topp_sampling_search_high_prob(self):
# Given a p value high enough for top-P sampling, any of the top 2
# tokens could be sampled, which can produce different outputs.
high_sampling_topp = (self.min_top1_prob+self.min_top2_prob)/2.0
generator = SequenceGenerator(
self.tgt_dict, beam_size=2, sampling=True,
sampling_topp=high_sampling_topp
)
sample = {
'net_input': {
'src_tokens': self.src_tokens,
'src_lengths': self.src_lengths
}
}
hypos = generator.generate([self.model], sample)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
self.assertTrue(self.hypoTokens(hypos[0][0], [w1, w1, eos]) or
self.hypoTokens(hypos[0][0], [w1, w2, eos]))
self.assertTrue(self.hypoScore(hypos[0][0], [1.0, 0.4, 1.0]) or
self.hypoScore(hypos[0][0], [1.0, 0.35, 1.0]))
# sentence 1, beam 2
self.assertTrue(self.hypoTokens(hypos[0][1], [w1, w1, eos]) or
self.hypoTokens(hypos[0][1], [w1, w2, eos]))
self.assertTrue(self.hypoScore(hypos[0][1], [1.0, 0.4, 1.0]) or
self.hypoScore(hypos[0][1], [1.0, 0.35, 1.0]))
# sentence 2, beam 1
self.assertTrue(self.hypoTokens(hypos[1][0], [w1, w1, eos]) or
self.hypoTokens(hypos[1][0], [w1, w2, eos]))
self.assertTrue(self.hypoScore(hypos[1][0], [1.0, 0.4, 1.0]) or
self.hypoScore(hypos[1][0], [1.0, 0.35, 1.0]))
# sentence 2, beam 2
self.assertTrue(self.hypoTokens(hypos[1][1], [w1, w1, eos]) or
self.hypoTokens(hypos[1][1], [w1, w2, eos]))
self.assertTrue(self.hypoScore(hypos[1][1], [1.0, 0.4, 1.0]) or
self.hypoScore(hypos[1][1], [1.0, 0.35, 1.0]))
def hypoTokens(self, hypo, tokens):
return self.tensorEqual(hypo['tokens'], torch.LongTensor(tokens))
def hypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.):
pos_scores = torch.FloatTensor(pos_probs).log()
if not self.almostEqual(hypo['positional_scores'], pos_scores):
return False
if pos_scores.numel() != hypo['tokens'].numel():
return False
score = pos_scores.sum()
if normalized:
score /= pos_scores.numel() ** lenpen
return abs(score - hypo['score']) < 1e-6
def almostEqual(self, t1, t2):
return t1.size() == t2.size() and (t1 - t2).abs().max() < 1e-4
def tensorEqual(self, t1, t2):
return t1.size() == t2.size() and t1.ne(t2).long().sum() == 0
if __name__ == '__main__':
...