Unverified commit 0fced067 authored by BakerBunker, committed by GitHub

Fix `beam_scores` shape when token scores shape changes after `logits_processor` (#25980)

parent a796f7ee
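
For context, a minimal sketch of the shape problem this commit addresses. The variable names mirror the diff below; the toy logits_processor that widens the scores' vocabulary dimension is hypothetical, standing in for any processor that changes the shape of the scores it is given.

import torch

batch_size, num_beams, vocab_size = 2, 3, 5
beam_scores = torch.zeros(batch_size * num_beams)                    # (batch_size * num_beams,)
next_token_scores = torch.randn(batch_size * num_beams, vocab_size)  # (batch_size * num_beams, vocab_size)

def logits_processor(input_ids, scores):
    # Hypothetical processor that appends scores for two extra tokens,
    # so the output shape no longer matches the input shape.
    extra = torch.full((scores.shape[0], 2), float("-inf"))
    return torch.cat([scores, extra], dim=-1)                         # (batch_size * num_beams, vocab_size + 2)

next_token_scores_processed = logits_processor(None, next_token_scores)

# Before this commit, the beam scores were expanded against the *unprocessed* tensor:
#     next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
# which keeps the old vocab size (5 columns) and fails to broadcast against the
# processed scores (7 columns). Expanding against the processed tensor makes the
# shapes agree, which is exactly the change in the hunks below:
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
    next_token_scores_processed
)
print(next_token_scores.shape)  # torch.Size([6, 7])
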
@@ -3038,7 +3038,9 @@ class GenerationMixin:
             ) # (batch_size * num_beams, vocab_size)
             next_token_scores_processed = logits_processor(input_ids, next_token_scores)
-            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
+            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
+                next_token_scores_processed
+            )
             # Store scores, attentions and hidden_states when required
             if return_dict_in_generate:
@@ -3363,7 +3365,9 @@ class GenerationMixin:
             ) # (batch_size * num_beams, vocab_size)
             next_token_scores_processed = logits_processor(input_ids, next_token_scores)
-            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
+            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
+                next_token_scores_processed
+            )
             # Note: logits warpers are intentionally applied after adding running beam scores. On some logits warpers
             # (like top_p) this is indiferent, but on others (like temperature) it is not. For reference, see
             # https://github.com/huggingface/transformers/pull/5420#discussion_r449779867
@@ -4080,7 +4084,9 @@ class GenerationMixin:
             next_token_scores_processed = logits_processor(input_ids, next_token_scores)
-            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
+            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
+                next_token_scores_processed
+            )
             scores_for_all_vocab = next_token_scores.clone()
......
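
All three hunks make the same one-line change: the running beam scores are now broadcast to the shape of next_token_scores_processed rather than the pre-processing next_token_scores, so generation keeps working when a logits processor changes the vocabulary dimension of the scores. When the processor preserves the shape, the two tensors have identical sizes, so behavior for existing processors is unchanged.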