Unverified Commit b0e0ac8a authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Generate] Fix output scores greedy search (#17442)

parent 2ef09ecf
......@@ -1689,10 +1689,13 @@ class GenerationMixin:
next_token_logits = outputs.logits[:, -1, :]
# pre-process distribution
next_tokens_scores = logits_processor(input_ids, next_token_logits)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (next_token_logits,)
scores += (next_tokens_scores,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
......@@ -1707,9 +1710,6 @@ class GenerationMixin:
else (outputs.hidden_states,)
)
# pre-process distribution
next_tokens_scores = logits_processor(input_ids, next_token_logits)
# argmax
next_tokens = torch.argmax(next_tokens_scores, dim=-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment