Unverified Commit 3d7baef1 authored by Connor Henderson's avatar Connor Henderson Committed by GitHub
Browse files

fix: Whisper generate, move text_prompt_ids trim up for max_new_tokens calculation (#23724)

move text_prompt_ids trimming to top
parent 50a56bed
...@@ -1633,6 +1633,9 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): ...@@ -1633,6 +1633,9 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
) )
prompt_ids = prompt_ids.tolist() prompt_ids = prompt_ids.tolist()
decoder_start_token_id, *text_prompt_ids = prompt_ids decoder_start_token_id, *text_prompt_ids = prompt_ids
# Slicing the text prompt ids in a manner consistent with the OpenAI implementation
# to accomodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599)
text_prompt_ids = text_prompt_ids[-self.config.max_length // 2 - 1 :]
# Set the decoder_start_token_id to <|startofprev|> # Set the decoder_start_token_id to <|startofprev|>
kwargs.update({"decoder_start_token_id": decoder_start_token_id}) kwargs.update({"decoder_start_token_id": decoder_start_token_id})
...@@ -1647,9 +1650,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): ...@@ -1647,9 +1650,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids
) )
forced_decoder_ids = [ forced_decoder_ids = [
# Slicing the text prompt ids in a manner consistent with the OpenAI implementation *text_prompt_ids,
# to accomodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599)
*text_prompt_ids[-self.config.max_length // 2 - 1 :],
generation_config.decoder_start_token_id, generation_config.decoder_start_token_id,
*[token for _rank, token in non_prompt_forced_decoder_ids], *[token for _rank, token in non_prompt_forced_decoder_ids],
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment