chenpangpang / transformers · Commits

Commit 0fced067 (Unverified)
Authored Sep 14, 2023 by BakerBunker; committed by GitHub Sep 13, 2023
Parent: a796f7ee

Fix `beam_scores` shape when token scores shape changes after `logits_processor` (#25980)

Showing 1 changed file with 9 additions and 3 deletions:

src/transformers/generation/utils.py (+9, -3)
```diff
@@ -3038,7 +3038,9 @@ class GenerationMixin:
             )  # (batch_size * num_beams, vocab_size)

             next_token_scores_processed = logits_processor(input_ids, next_token_scores)
-            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
+            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
+                next_token_scores_processed
+            )

             # Store scores, attentions and hidden_states when required
             if return_dict_in_generate:
```
```diff
@@ -3363,7 +3365,9 @@ class GenerationMixin:
             )  # (batch_size * num_beams, vocab_size)

             next_token_scores_processed = logits_processor(input_ids, next_token_scores)
-            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
+            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
+                next_token_scores_processed
+            )
             # Note: logits warpers are intentionally applied after adding running beam scores. On some logits warpers
             # (like top_p) this is indiferent, but on others (like temperature) it is not. For reference, see
             # https://github.com/huggingface/transformers/pull/5420#discussion_r449779867
```
```diff
@@ -4080,7 +4084,9 @@ class GenerationMixin:
             next_token_scores_processed = logits_processor(input_ids, next_token_scores)
-            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
+            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
+                next_token_scores_processed
+            )

             scores_for_all_vocab = next_token_scores.clone()
```
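The rationale for the fix can be illustrated with a minimal sketch. The shapes and the shape-changing processor below are hypothetical stand-ins (not a real transformers `LogitsProcessor`): if a processor appends columns to the score tensor, `beam_scores[:, None]` must be expanded against the *processed* tensor, since expanding against the raw `next_token_scores` yields a tensor whose last dimension no longer matches.

```python
import torch

# Hypothetical shapes for illustration only.
batch_size, num_beams, vocab_size = 2, 3, 10

# Running log-prob of each beam, flattened to (batch_size * num_beams,).
beam_scores = torch.zeros(batch_size * num_beams)

# Raw next-token scores from the model: (batch_size * num_beams, vocab_size).
next_token_scores = torch.randn(batch_size * num_beams, vocab_size)

def shape_changing_processor(input_ids, scores):
    # Illustrative processor: appends two masked-out columns,
    # growing the vocab dimension from 10 to 12.
    extra = torch.full((scores.shape[0], 2), float("-inf"))
    return torch.cat([scores, extra], dim=-1)

next_token_scores_processed = shape_changing_processor(None, next_token_scores)

# The fix: expand against the processed tensor so the shapes always match.
# Expanding against the raw scores would give a (6, 10) tensor that cannot
# be added to the (6, 12) processed scores.
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
    next_token_scores_processed
)
print(tuple(next_token_scores.shape))  # (6, 12)
```

With identity-shaped processors the two variants are equivalent, which is why the bug only surfaced once a processor changed the score tensor's shape.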