"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "63a5399bc42fd4cae42f90354c5801354a2e30f6"
Unverified commit 6701fb78, authored by Patrick von Platen, committed by GitHub

fix beam_search behavior when sampling (#3106)

* fix beam_search behavior when sampling

* delete print

* make correct style
parent e9e6efdc
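In short: before this change, sampling with num_beams > 1 reused the greedy-only beam mask and drew two tokens per beam independently, adding the running beam scores only afterwards; after it, the beam scores are folded into the log-probabilities first, all beams are flattened together, and 2 * num_beams candidates are drawn jointly. A minimal sketch of the new sampled step (toy shapes and illustrative variable names, not the library code; the real change is in the third hunk below):

import torch
import torch.nn.functional as F

batch_size, num_beams, vocab_size = 1, 3, 5
scores = torch.randn(batch_size * num_beams, vocab_size)  # model logits for each beam
beam_scores = torch.zeros(batch_size * num_beams)         # running log-prob of each beam

scores = F.log_softmax(scores, dim=-1)                    # per-beam token log-probs
_scores = scores + beam_scores[:, None]                   # fold in the running beam scores
_scores = _scores.view(batch_size, num_beams * vocab_size)  # group all beams together

# draw 2 * num_beams candidates jointly across all beams
next_words = torch.multinomial(F.softmax(_scores, dim=-1), num_samples=2 * num_beams)
next_scores = torch.gather(_scores, -1, next_words)       # (batch_size, 2 * num_beams)

# a flattened index decomposes into (beam, token)
beam_ids = next_words // vocab_size
word_ids = next_words % vocab_size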
@@ -564,7 +564,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         model.eval()
         if output_loading_info:
-            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
+            loading_info = {
+                "missing_keys": missing_keys,
+                "unexpected_keys": unexpected_keys,
+                "error_msgs": error_msgs,
+            }
             return model, loading_info
         return model
@@ -941,7 +945,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         # scores for each sentence in the beam
         beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
-        beam_scores[:, 1:] = -1e9
+
+        # for greedy decoding, make sure that only tokens of the first beam are considered to avoid sampling the exact same words num_beams times
+        if do_sample is False:
+            beam_scores[:, 1:] = -1e9
         beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)

         # cache compute states
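Why the new guard matters: at the first step every beam holds the same prefix, so greedy top-k needs the -1e9 mask on beams 1..num_beams-1 to avoid selecting the identical token once per beam. When sampling, however, the beam scores are now added to the logits before the softmax (see the next hunk), so keeping the mask would collapse the sampling distribution onto the first beam. A toy demonstration (illustrative names, not library code):

import torch
import torch.nn.functional as F

num_beams, vocab_size = 3, 5
beam_scores = torch.zeros(num_beams)
beam_scores[1:] = -1e9                              # the greedy-only mask

logits = torch.randn(num_beams, vocab_size)
masked = logits + beam_scores[:, None]              # beam scores folded into logits
probs = F.softmax(masked.view(1, -1), dim=-1)       # flattened sampling distribution

print(probs.view(num_beams, vocab_size).sum(dim=-1))
# tensor([1., 0., 0.]) -- beams 1 and 2 could never be sampled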
@@ -967,19 +974,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 # Temperature (higher temperature => more likely to sample low probability tokens)
                 if temperature != 1.0:
                     scores = scores / temperature
+
+                scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
+                _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
                 # Top-p/top-k filtering
-                scores = top_k_top_p_filtering(
-                    scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
+                _scores = top_k_top_p_filtering(
+                    _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
                 )  # (batch_size * num_beams, vocab_size)
+                # re-organize to group the beam together to sample from all beam_idxs
+                _scores = _scores.contiguous().view(
+                    batch_size, num_beams * vocab_size
+                )  # (batch_size, num_beams * vocab_size)
+
                 # Sample 2 next words for each beam (so we have some spare tokens and match output of greedy beam search)
-                next_words = torch.multinomial(F.softmax(scores, dim=-1), num_samples=2)  # (batch_size * num_beams, 2)
+                next_words = torch.multinomial(
+                    F.softmax(_scores, dim=-1), num_samples=2 * num_beams
+                )  # (batch_size, num_beams * 2)
                 # Compute next scores
-                _scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
-                _scores = torch.gather(_scores, -1, next_words)  # (batch_size * num_beams, 2)
-                next_scores = _scores + beam_scores[:, None].expand_as(_scores)  # (batch_size * num_beams, 2)
-                # Match shape of greedy beam search
-                next_words = next_words.view(batch_size, 2 * num_beams)  # (batch_size, 2 * num_beams)
-                next_scores = next_scores.view(batch_size, 2 * num_beams)  # (batch_size, 2 * num_beams)
+                next_scores = torch.gather(_scores, -1, next_words)  # (batch_size, num_beams * 2)
             else:
                 # do greedy beam search
                 scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
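A usage sketch that exercises the fixed path, assuming a stock GPT-2 checkpoint (model choice, prompt, and generation parameters are illustrative): with do_sample=True and num_beams > 1, generate now samples candidates jointly across beams instead of duplicating the first beam's draws.

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer.encode("The weather today is", return_tensors="pt")
output = model.generate(
    input_ids, max_length=20, num_beams=3, do_sample=True, top_k=50
)
print(tokenizer.decode(output[0]))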
@@ -1026,7 +1042,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 # add to generated hypotheses if end of sentence or last iteration
                 if eos_token_ids is not None and word_id.item() in eos_token_ids:
                     generated_hyps[batch_idx].add(
-                        input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item()
+                        input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item(),
                     )
                 else:
                     # add next predicted word if it is not eos_token