Unverified commit 4afead8a, authored by Sanchit Gandhi and committed by GitHub
Browse files

[Whisper] Deprecate forced ids for v4.39 (#29485)

deprecate old funcs
parent 9acce7de
...@@ -25,7 +25,6 @@ from torch import nn ...@@ -25,7 +25,6 @@ from torch import nn
from ...generation.configuration_utils import GenerationConfig from ...generation.configuration_utils import GenerationConfig
from ...generation.logits_process import ( from ...generation.logits_process import (
ForceTokensLogitsProcessor,
LogitsProcessorList, LogitsProcessorList,
SuppressTokensAtBeginLogitsProcessor, SuppressTokensAtBeginLogitsProcessor,
SuppressTokensLogitsProcessor, SuppressTokensLogitsProcessor,
...@@ -539,11 +538,9 @@ class WhisperGenerationMixin: ...@@ -539,11 +538,9 @@ class WhisperGenerationMixin:
num_segment_frames=num_segment_frames, num_segment_frames=num_segment_frames,
kwargs=kwargs, kwargs=kwargs,
) )
# TODO(Sanchit) - passing `decoder_input_ids` is deprecated. One should use `prompt_ids` instead # passing `decoder_input_ids` is deprecated - the only exception is for assisted generation
# This function should be removed in v4.39 # where the input ids are handled explicitly by the generate method
self._check_decoder_input_ids( self._check_decoder_input_ids(kwargs=kwargs)
prompt_ids=prompt_ids, init_tokens=init_tokens, is_shortform=is_shortform, kwargs=kwargs
)
# 3. Retrieve logits processors # 3. Retrieve logits processors
begin_index = len(init_tokens) begin_index = len(init_tokens)
...@@ -1129,15 +1126,13 @@ class WhisperGenerationMixin: ...@@ -1129,15 +1126,13 @@ class WhisperGenerationMixin:
forced_decoder_ids = forced_decoder_ids[1:] forced_decoder_ids = forced_decoder_ids[1:]
i += 1 i += 1
# TODO(Sanchit): Let's make sure we don't allow incorrectly / weirdly formatted `forced_decoder_ids` after transformers v4.39
if len(forced_decoder_ids) > 0: if len(forced_decoder_ids) > 0:
warnings.warn( raise ValueError(
f"You are using token ids in `forced_decoder_ids` that do not seem to correctly follow the prompt pattern of Whisper. Make sure that {forced_decoder_ids} has an entry for all indices >= 1 and < {forced_decoder_ids[0][0]}. `forced_decoder_ids` will be passed as a logit processor, but note that this functionality has been deprecated and will throw an error in v4.39.", f"You are using token ids in `forced_decoder_ids` that do not seem to correctly follow the prompt pattern of Whisper. Make sure that {forced_decoder_ids} has an entry for all indices >= 1 and < {forced_decoder_ids[0][0]}.",
FutureWarning,
) )
# TODO(Sanchit): set generation_config.forced_decoder_ids to None for v4.39 # from v4.39 the forced decoder ids are always None in favour of decoder input ids
generation_config.forced_decoder_ids = forced_decoder_ids if len(forced_decoder_ids) > 0 else None generation_config.forced_decoder_ids = None
is_lang_id_undefined = len(init_tokens) <= 1 or (len(init_tokens) > 1 and init_tokens[1] is None) is_lang_id_undefined = len(init_tokens) <= 1 or (len(init_tokens) > 1 and init_tokens[1] is None)
if language is not None: if language is not None:
...@@ -1282,20 +1277,12 @@ class WhisperGenerationMixin: ...@@ -1282,20 +1277,12 @@ class WhisperGenerationMixin:
return lang_ids return lang_ids
@staticmethod @staticmethod
def _check_decoder_input_ids(prompt_ids, init_tokens, is_shortform, kwargs): def _check_decoder_input_ids(kwargs):
decoder_input_ids = kwargs.get("decoder_input_ids", None) decoder_input_ids = kwargs.get("decoder_input_ids", None)
if prompt_ids is not None and decoder_input_ids is not None: assistant_model = kwargs.get("assistant_model", None)
if decoder_input_ids is not None and assistant_model is not None:
raise ValueError( raise ValueError(
f"Cannot pass both `prompt_ids`: {prompt_ids} and `decoder_input_ids`: {decoder_input_ids}. Passing `decoder_input_ids` is deprecated, consider not passing it." "Passing `decoder_input_ids` is deprecated. Consider passing `prompt_ids` instead.",
)
elif decoder_input_ids is not None and not is_shortform:
raise ValueError(
f"Cannot pass both `decoder_input_ids`: {decoder_input_ids} for long-form generation. Consider passing `prompt_ids` instead."
)
elif decoder_input_ids is not None and is_shortform:
warnings.warn(
f"You have provided `decoder_input_ids` which will overwrite the `init_tokens` {init_tokens}. This might lead to unexpected behavior. Passing `decoder_input_ids` is deprecated and will be removed in v4.39. Consider passing `prompt_ids` instead.",
FutureWarning,
) )
@staticmethod @staticmethod
...@@ -1436,19 +1423,6 @@ class WhisperGenerationMixin: ...@@ -1436,19 +1423,6 @@ class WhisperGenerationMixin:
) )
no_speech_detector.set_model(self) no_speech_detector.set_model(self)
if is_shortform and generation_config.forced_decoder_ids is not None:
forced_tokens_proc = ForceTokensLogitsProcessor(generation_config.forced_decoder_ids)
# It's important that the `forced_tokens_proc` processor is appended after
# the suppress_tokens processor or else it might happen that all token logits are suppressed to -inf
# which would lead to unexpected behavior
# The better approach here is to NOT make use of the `forced_tokens_proc` for Whisper and instead
# initialize all of them as `decoder_input_ids`.
# TODO(Sanchit): Make sure to deprecate this in v4.39 as there will be no `forced_decoder_ids` anymore.
logits_processor = (
[forced_tokens_proc] if logits_processor is None else logits_processor + [forced_tokens_proc]
)
generation_config.forced_decoder_ids = None
return logits_processor return logits_processor
@staticmethod @staticmethod
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment