"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "d0d1632958c7d543e07afc672a8501d704e5a65f"
Unverified commit b0e0ac8a authored by Patrick von Platen, committed by GitHub

[Generate] Fix output scores greedy search (#17442)

parent 2ef09ecf
@@ -1689,10 +1689,13 @@ class GenerationMixin:
             next_token_logits = outputs.logits[:, -1, :]
 
+            # pre-process distribution
+            next_tokens_scores = logits_processor(input_ids, next_token_logits)
+
             # Store scores, attentions and hidden_states when required
             if return_dict_in_generate:
                 if output_scores:
-                    scores += (next_token_logits,)
+                    scores += (next_tokens_scores,)
                 if output_attentions:
                     decoder_attentions += (
                         (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
@@ -1707,9 +1710,6 @@ class GenerationMixin:
                         else (outputs.hidden_states,)
                     )
 
-            # pre-process distribution
-            next_tokens_scores = logits_processor(input_ids, next_token_logits)
-
             # argmax
             next_tokens = torch.argmax(next_tokens_scores, dim=-1)
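A minimal usage sketch (not part of the commit) of what this change affects: with return_dict_in_generate=True and output_scores=True, greedy search now stores the logits after the logits processors have been applied, so the argmax of each step's scores matches the token that was actually generated. The "gpt2" checkpoint, the prompt, and the generation settings below are placeholder choices for illustration.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello, my dog is", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=5,
    do_sample=False,  # greedy search
    return_dict_in_generate=True,
    output_scores=True,
)

# `scores` holds one (batch_size, vocab_size) tensor per generated token
generated_ids = outputs.sequences[0, inputs["input_ids"].shape[1]:]
for step_scores, token_id in zip(outputs.scores, generated_ids):
    # with this fix the stored scores are the processed logits, so their
    # argmax is the token that greedy search actually picked at that step
    assert torch.argmax(step_scores[0]).item() == token_id.item()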