chenpangpang / transformers · Commits

Unverified commit b0e0ac8a, authored May 31, 2022 by Patrick von Platen, committed by GitHub on May 31, 2022

[Generate] Fix output scores greedy search (#17442)
Parent: 2ef09ecf

Showing 1 changed file with 4 additions and 4 deletions:
src/transformers/generation_utils.py (+4, -4)
src/transformers/generation_utils.py @ b0e0ac8a

@@ -1689,10 +1689,13 @@ class GenerationMixin:
             next_token_logits = outputs.logits[:, -1, :]

+            # pre-process distribution
+            next_tokens_scores = logits_processor(input_ids, next_token_logits)
+
             # Store scores, attentions and hidden_states when required
             if return_dict_in_generate:
                 if output_scores:
-                    scores += (next_token_logits,)
+                    scores += (next_tokens_scores,)
                 if output_attentions:
                     decoder_attentions += (
                         (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
@@ -1707,9 +1710,6 @@ class GenerationMixin:
                     else (outputs.hidden_states,)
                 )

-            # pre-process distribution
-            next_tokens_scores = logits_processor(input_ids, next_token_logits)
-
             # argmax
             next_tokens = torch.argmax(next_tokens_scores, dim=-1)
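Before this change, greedy search stored the raw next_token_logits in scores but selected tokens from the processed next_tokens_scores, so the reported scores could disagree with the tokens that were actually generated. The commit moves the logits_processor call ahead of the bookkeeping block and stores next_tokens_scores instead. The sketch below is not part of the commit; it only illustrates, through the public generate() API, the invariant the fix restores. Model name, prompt, and generation settings are arbitrary assumptions.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative model and prompt, not taken from the commit.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is", return_tensors="pt").input_ids

# Greedy search with score reporting enabled.
out = model.generate(
    input_ids,
    max_new_tokens=5,
    do_sample=False,
    return_dict_in_generate=True,
    output_scores=True,
)

# out.scores holds one (batch_size, vocab_size) tensor per generated step.
# With this fix, each entry is the post-processor score, so its argmax matches
# the token greedy search emitted at that step (assuming the sequence does not
# finish early and get padded).
generated = out.sequences[:, input_ids.shape[-1]:]
for step, step_scores in enumerate(out.scores):
    print(step, generated[0, step].item(), torch.argmax(step_scores, dim=-1).item())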