Commit 96ac28d3 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Fix and generalize --temperature option (#508)

Summary:
Pull Request resolved: https://github.com/pytorch/translate/pull/508

The previous version applied the temperature after the softmax. Fix that, and
also generalize so it works with other search approaches.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/694

Differential Revision: D15175160

Pulled By: myleott

fbshipit-source-id: cc87ff0e97a8a1dd37f9983163f58a8641155ab0
parent fc1a19a3
...@@ -422,8 +422,8 @@ def add_generation_args(parser): ...@@ -422,8 +422,8 @@ def add_generation_args(parser):
help='sample hypotheses instead of using beam search') help='sample hypotheses instead of using beam search')
group.add_argument('--sampling-topk', default=-1, type=int, metavar='PS', group.add_argument('--sampling-topk', default=-1, type=int, metavar='PS',
help='sample from top K likely next words instead of all words') help='sample from top K likely next words instead of all words')
group.add_argument('--sampling-temperature', default=1, type=float, metavar='N', group.add_argument('--temperature', default=1., type=float, metavar='N',
help='temperature for random sampling') help='temperature for generation')
group.add_argument('--diverse-beam-groups', default=-1, type=int, metavar='N', group.add_argument('--diverse-beam-groups', default=-1, type=int, metavar='N',
help='number of groups for Diverse Beam Search') help='number of groups for Diverse Beam Search')
group.add_argument('--diverse-beam-strength', default=0.5, type=float, metavar='N', group.add_argument('--diverse-beam-strength', default=0.5, type=float, metavar='N',
......
...@@ -168,10 +168,9 @@ class DiverseBeamSearch(Search): ...@@ -168,10 +168,9 @@ class DiverseBeamSearch(Search):
class Sampling(Search): class Sampling(Search):
def __init__(self, tgt_dict, sampling_topk=-1, sampling_temperature=1.): def __init__(self, tgt_dict, sampling_topk=-1):
super().__init__(tgt_dict) super().__init__(tgt_dict)
self.sampling_topk = sampling_topk self.sampling_topk = sampling_topk
self.sampling_temperature = sampling_temperature
def step(self, step, lprobs, scores): def step(self, step, lprobs, scores):
super()._init_buffers(lprobs) super()._init_buffers(lprobs)
...@@ -190,10 +189,6 @@ class Sampling(Search): ...@@ -190,10 +189,6 @@ class Sampling(Search):
if self.sampling_topk > 0: if self.sampling_topk > 0:
lprobs_nopad, topk_indices = lprobs_nopad.topk(self.sampling_topk) lprobs_nopad, topk_indices = lprobs_nopad.topk(self.sampling_topk)
# sampling temperature
if self.sampling_temperature != 1.:
lprobs_nopad = lprobs_nopad.div_(self.sampling_temperature)
# sample # sample
probs_nopad = lprobs_nopad.exp_() probs_nopad = lprobs_nopad.exp_()
if step == 0: if step == 0:
......
...@@ -28,7 +28,7 @@ class SequenceGenerator(object): ...@@ -28,7 +28,7 @@ class SequenceGenerator(object):
retain_dropout=False, retain_dropout=False,
sampling=False, sampling=False,
sampling_topk=-1, sampling_topk=-1,
sampling_temperature=1., temperature=1.,
diverse_beam_groups=-1, diverse_beam_groups=-1,
diverse_beam_strength=0.5, diverse_beam_strength=0.5,
match_source_len=False, match_source_len=False,
...@@ -58,9 +58,9 @@ class SequenceGenerator(object): ...@@ -58,9 +58,9 @@ class SequenceGenerator(object):
(default: False) (default: False)
sampling_topk (int, optional): only sample among the top-k choices sampling_topk (int, optional): only sample among the top-k choices
at each step (default: -1) at each step (default: -1)
sampling_temperature (float, optional): temperature for sampling, temperature (float, optional): temperature, where values
where values >1.0 produces more uniform sampling and values >1.0 produce more uniform samples and values <1.0 produce
<1.0 produces sharper sampling (default: 1.0) sharper samples (default: 1.0)
diverse_beam_groups/strength (float, optional): parameters for diverse_beam_groups/strength (float, optional): parameters for
Diverse Beam Search sampling Diverse Beam Search sampling
match_source_len (bool, optional): outputs should match the source match_source_len (bool, optional): outputs should match the source
...@@ -81,13 +81,15 @@ class SequenceGenerator(object): ...@@ -81,13 +81,15 @@ class SequenceGenerator(object):
self.len_penalty = len_penalty self.len_penalty = len_penalty
self.unk_penalty = unk_penalty self.unk_penalty = unk_penalty
self.retain_dropout = retain_dropout self.retain_dropout = retain_dropout
self.temperature = temperature
self.match_source_len = match_source_len self.match_source_len = match_source_len
self.no_repeat_ngram_size = no_repeat_ngram_size self.no_repeat_ngram_size = no_repeat_ngram_size
assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling' assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'
assert temperature > 0, '--temperature must be greater than 0'
if sampling: if sampling:
self.search = search.Sampling(tgt_dict, sampling_topk, sampling_temperature) self.search = search.Sampling(tgt_dict, sampling_topk)
elif diverse_beam_groups > 0: elif diverse_beam_groups > 0:
self.search = search.DiverseBeamSearch(tgt_dict, diverse_beam_groups, diverse_beam_strength) self.search = search.DiverseBeamSearch(tgt_dict, diverse_beam_groups, diverse_beam_strength)
elif match_source_len: elif match_source_len:
...@@ -304,7 +306,9 @@ class SequenceGenerator(object): ...@@ -304,7 +306,9 @@ class SequenceGenerator(object):
model.reorder_incremental_state(reorder_state) model.reorder_incremental_state(reorder_state)
model.reorder_encoder_out(encoder_outs, reorder_state) model.reorder_encoder_out(encoder_outs, reorder_state)
lprobs, avg_attn_scores = model.forward_decoder(tokens[:, :step + 1], encoder_outs) lprobs, avg_attn_scores = model.forward_decoder(
tokens[:, :step + 1], encoder_outs, temperature=self.temperature,
)
lprobs[:, self.pad] = -math.inf # never select pad lprobs[:, self.pad] = -math.inf # never select pad
lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
...@@ -547,7 +551,7 @@ class EnsembleModel(torch.nn.Module): ...@@ -547,7 +551,7 @@ class EnsembleModel(torch.nn.Module):
return [model.encoder(**encoder_input) for model in self.models] return [model.encoder(**encoder_input) for model in self.models]
@torch.no_grad() @torch.no_grad()
def forward_decoder(self, tokens, encoder_outs): def forward_decoder(self, tokens, encoder_outs, temperature=1.):
if len(self.models) == 1: if len(self.models) == 1:
return self._decode_one( return self._decode_one(
tokens, tokens,
...@@ -555,12 +559,20 @@ class EnsembleModel(torch.nn.Module): ...@@ -555,12 +559,20 @@ class EnsembleModel(torch.nn.Module):
encoder_outs[0] if self.has_encoder() else None, encoder_outs[0] if self.has_encoder() else None,
self.incremental_states, self.incremental_states,
log_probs=True, log_probs=True,
temperature=temperature,
) )
log_probs = [] log_probs = []
avg_attn = None avg_attn = None
for model, encoder_out in zip(self.models, encoder_outs): for model, encoder_out in zip(self.models, encoder_outs):
probs, attn = self._decode_one(tokens, model, encoder_out, self.incremental_states, log_probs=True) probs, attn = self._decode_one(
tokens,
model,
encoder_out,
self.incremental_states,
log_probs=True,
temperature=temperature,
)
log_probs.append(probs) log_probs.append(probs)
if attn is not None: if attn is not None:
if avg_attn is None: if avg_attn is None:
...@@ -572,12 +584,17 @@ class EnsembleModel(torch.nn.Module): ...@@ -572,12 +584,17 @@ class EnsembleModel(torch.nn.Module):
avg_attn.div_(len(self.models)) avg_attn.div_(len(self.models))
return avg_probs, avg_attn return avg_probs, avg_attn
def _decode_one(self, tokens, model, encoder_out, incremental_states, log_probs): def _decode_one(
self, tokens, model, encoder_out, incremental_states, log_probs,
temperature=1.,
):
if self.incremental_states is not None: if self.incremental_states is not None:
decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=self.incremental_states[model])) decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=self.incremental_states[model]))
else: else:
decoder_out = list(model.decoder(tokens, encoder_out)) decoder_out = list(model.decoder(tokens, encoder_out))
decoder_out[0] = decoder_out[0][:, -1:, :] decoder_out[0] = decoder_out[0][:, -1:, :]
if temperature != 1.:
decoder_out[0].div_(temperature)
attn = decoder_out[1] attn = decoder_out[1]
if type(attn) is dict: if type(attn) is dict:
attn = attn['attn'] attn = attn['attn']
......
...@@ -197,7 +197,7 @@ class FairseqTask(object): ...@@ -197,7 +197,7 @@ class FairseqTask(object):
unk_penalty=args.unkpen, unk_penalty=args.unkpen,
sampling=args.sampling, sampling=args.sampling,
sampling_topk=args.sampling_topk, sampling_topk=args.sampling_topk,
sampling_temperature=args.sampling_temperature, temperature=args.temperature,
diverse_beam_groups=args.diverse_beam_groups, diverse_beam_groups=args.diverse_beam_groups,
diverse_beam_strength=args.diverse_beam_strength, diverse_beam_strength=args.diverse_beam_strength,
match_source_len=args.match_source_len, match_source_len=args.match_source_len,
......
...@@ -94,7 +94,7 @@ class TestTranslation(unittest.TestCase): ...@@ -94,7 +94,7 @@ class TestTranslation(unittest.TestCase):
train_translation_model(data_dir, 'fconv_iwslt_de_en') train_translation_model(data_dir, 'fconv_iwslt_de_en')
generate_main(data_dir, [ generate_main(data_dir, [
'--sampling', '--sampling',
'--sampling-temperature', '2', '--temperature', '2',
'--beam', '2', '--beam', '2',
'--nbest', '2', '--nbest', '2',
]) ])
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.