Unverified Commit f614b6e3 authored by Joao Gante, committed by GitHub

Whisper: fix prompted max length (#24666)

parent 49572942
src/transformers/generation/stopping_criteria.py
@@ -6,7 +6,10 @@ from typing import Optional

 import torch

-from ..utils import add_start_docstrings
+from ..utils import add_start_docstrings, logging
+
+
+logger = logging.get_logger(__name__)


 STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
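For context, `logging.get_logger` returns a standard Python logger extended with `warning_once`, which deduplicates repeated messages; that is what lets the criteria below warn on every generation step without flooding the logs. A minimal sketch (the logger name "demo" is arbitrary, not from the commit):

from transformers.utils import logging

logging.set_verbosity_warning()
logger = logging.get_logger("demo")

for _ in range(3):
    logger.warning_once("emitted a single time")  # deduplicated across repeated calls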
@@ -46,14 +49,25 @@ class MaxLengthCriteria(StoppingCriteria):
     Args:
         max_length (`int`):
             The maximum length that the output sequence can have in number of tokens.
+        max_position_embeddings (`int`, `optional`):
+            The maximum model length, as defined by the model's `config.max_position_embeddings` attribute.
     """

-    def __init__(self, max_length: int):
+    def __init__(self, max_length: int, max_position_embeddings: Optional[int] = None):
         self.max_length = max_length
+        self.max_position_embeddings = max_position_embeddings

     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        return input_ids.shape[-1] >= self.max_length
+        cur_len = input_ids.shape[-1]
+        is_done = cur_len >= self.max_length
+        if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings:
+            logger.warning_once(
+                "This is a friendly reminder - the current text generation call will exceed the model's predefined "
+                f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe "
+                "exceptions, performance degradation, or nothing at all."
+            )
+        return is_done


 class MaxNewTokensCriteria(StoppingCriteria):
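A minimal sketch of the new behaviour (the lengths are illustrative, not from the commit): the criteria still only stops generation at `max_length`, but now warns once when the sequence has already grown past the model's maximum.

import torch

from transformers import MaxLengthCriteria

criteria = MaxLengthCriteria(max_length=64, max_position_embeddings=32)

input_ids = torch.ones((1, 40), dtype=torch.long)  # 40 tokens: past the model max, below max_length
scores = torch.zeros((1, 100))  # dummy logits, unused by this criteria

print(criteria(input_ids, scores))  # False -> generation continues, but the warning is logged once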
src/transformers/generation/utils.py
@@ -954,7 +954,13 @@ class GenerationMixin:
     ) -> StoppingCriteriaList:
         criteria = StoppingCriteriaList()
         if generation_config.max_length is not None:
-            criteria.append(MaxLengthCriteria(max_length=generation_config.max_length))
+            max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
+            criteria.append(
+                MaxLengthCriteria(
+                    max_length=generation_config.max_length,
+                    max_position_embeddings=max_position_embeddings,
+                )
+            )
         if generation_config.max_time is not None:
             criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
         criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
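The `getattr` with a `None` default matters because not every model config defines `max_position_embeddings`. A hedged sketch of the same wiring outside of `generate` (GPT2Config is used only because it exposes the attribute, via its `n_positions` alias):

from transformers import GPT2Config, MaxLengthCriteria, StoppingCriteriaList

config = GPT2Config()  # n_positions, aliased as max_position_embeddings, defaults to 1024

stopping_criteria = StoppingCriteriaList(
    [
        MaxLengthCriteria(
            max_length=2048,  # deliberately above the model maximum
            max_position_embeddings=getattr(config, "max_position_embeddings", None),
        )
    ]
)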
src/transformers/models/whisper/modeling_whisper.py
@@ -1715,11 +1715,9 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
             # Set the decoder_start_token_id to <|startofprev|>
             kwargs.update({"decoder_start_token_id": decoder_start_token_id})

-            # Update the max generation length to include the prompt
-            specified_max_length = kwargs.pop("max_new_tokens", None) or kwargs.pop("max_length", None)
-            default_max_length = generation_config.max_new_tokens or generation_config.max_length
-            non_prompt_max_length = specified_max_length or default_max_length
-            kwargs["max_new_tokens"] = non_prompt_max_length + len(text_prompt_ids)
+            # If the user passes `max_new_tokens`, increase its number to account for the prompt
+            if kwargs.get("max_new_tokens", None) is not None:
+                kwargs["max_new_tokens"] += len(text_prompt_ids)

             # Reformat the forced_decoder_ids to incorporate the prompt
             non_prompt_forced_decoder_ids = (
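Before this change, prompted generation always rewrote `max_new_tokens` from whichever of `max_new_tokens`/`max_length` it could find plus the prompt length, which could silently push the total length past the model's maximum; now only a user-supplied `max_new_tokens` is extended by the prompt length, and the new `MaxLengthCriteria` warning above catches the overflow case. Roughly how the fixed path is exercised (a sketch: this downloads openai/whisper-tiny, and the zero-filled features stand in for 30 seconds of audio just to keep the example self-contained):

import torch

from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

prompt_ids = processor.get_prompt_ids("Mr. Quilter")  # token ids beginning with <|startofprev|>
input_features = torch.zeros((1, 80, 3000))  # dummy log-mel features

# `max_new_tokens` keeps its meaning: the budget is extended by the prompt length
# internally, so the user-visible count of newly generated tokens stays at 10.
predicted_ids = model.generate(input_features, prompt_ids=prompt_ids, max_new_tokens=10)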