"...resnet50_tensorflow.git" did not exist on "b35d9c916f64fab83ded930510eba6c2c8cd7a9e"
Unverified Commit 4afead8a authored by Sanchit Gandhi, committed by GitHub
Browse files

[Whisper] Deprecate forced ids for v4.39 (#29485)

deprecate old funcs
parent 9acce7de
......@@ -25,7 +25,6 @@ from torch import nn
from ...generation.configuration_utils import GenerationConfig
from ...generation.logits_process import (
ForceTokensLogitsProcessor,
LogitsProcessorList,
SuppressTokensAtBeginLogitsProcessor,
SuppressTokensLogitsProcessor,
......@@ -539,11 +538,9 @@ class WhisperGenerationMixin:
num_segment_frames=num_segment_frames,
kwargs=kwargs,
)
# TODO(Sanchit) - passing `decoder_input_ids` is deprecated. One should use `prompt_ids` instead
# This function should be removed in v4.39
self._check_decoder_input_ids(
prompt_ids=prompt_ids, init_tokens=init_tokens, is_shortform=is_shortform, kwargs=kwargs
)
# passing `decoder_input_ids` is deprecated - the only exception is for assisted generation
# where the input ids are handled explicitly by the generate method
self._check_decoder_input_ids(kwargs=kwargs)
# 3. Retrieve logits processors
begin_index = len(init_tokens)
......@@ -1129,15 +1126,13 @@ class WhisperGenerationMixin:
forced_decoder_ids = forced_decoder_ids[1:]
i += 1
# TODO(Sanchit): Let's make sure we don't allow incorrectly / weirdly formatted `forced_decoder_ids` after transformers v4.39
if len(forced_decoder_ids) > 0:
warnings.warn(
f"You are using token ids in `forced_decoder_ids` that do not seem to correctly follow the prompt pattern of Whisper. Make sure that {forced_decoder_ids} has an entry for all indices >= 1 and < {forced_decoder_ids[0][0]}. `forced_decoder_ids` will be passed as a logit processor, but note that this functionality has been deprecated and will throw an error in v4.39.",
FutureWarning,
raise ValueError(
f"You are using token ids in `forced_decoder_ids` that do not seem to correctly follow the prompt pattern of Whisper. Make sure that {forced_decoder_ids} has an entry for all indices >= 1 and < {forced_decoder_ids[0][0]}.",
)
# TODO(Sanchit): set generation_config.forced_decoder_ids to None for v4.39
generation_config.forced_decoder_ids = forced_decoder_ids if len(forced_decoder_ids) > 0 else None
# from v4.39 the forced decoder ids are always None in favour of decoder input ids
generation_config.forced_decoder_ids = None
is_lang_id_undefined = len(init_tokens) <= 1 or (len(init_tokens) > 1 and init_tokens[1] is None)
if language is not None:
......@@ -1282,20 +1277,12 @@ class WhisperGenerationMixin:
return lang_ids
@staticmethod
def _check_decoder_input_ids(prompt_ids, init_tokens, is_shortform, kwargs):
def _check_decoder_input_ids(kwargs):
decoder_input_ids = kwargs.get("decoder_input_ids", None)
if prompt_ids is not None and decoder_input_ids is not None:
assistant_model = kwargs.get("assistant_model", None)
if decoder_input_ids is not None and assistant_model is not None:
raise ValueError(
f"Cannot pass both `prompt_ids`: {prompt_ids} and `decoder_input_ids`: {decoder_input_ids}. Passing `decoder_input_ids` is deprecated, consider not passing it."
)
elif decoder_input_ids is not None and not is_shortform:
raise ValueError(
f"Cannot pass both `decoder_input_ids`: {decoder_input_ids} for long-form generation. Consider passing `prompt_ids` instead."
)
elif decoder_input_ids is not None and is_shortform:
warnings.warn(
f"You have provided `decoder_input_ids` which will overwrite the `init_tokens` {init_tokens}. This might lead to unexpected behavior. Passing `decoder_input_ids` is deprecated and will be removed in v4.39. Consider passing `prompt_ids` instead.",
FutureWarning,
"Passing `decoder_input_ids` is deprecated. Consider passing `prompt_ids` instead.",
)
@staticmethod
......@@ -1436,19 +1423,6 @@ class WhisperGenerationMixin:
)
no_speech_detector.set_model(self)
if is_shortform and generation_config.forced_decoder_ids is not None:
forced_tokens_proc = ForceTokensLogitsProcessor(generation_config.forced_decoder_ids)
# It's important that the `forced_tokens_proc` processor is appended after
# the suppress_tokens processor or else it might happen that all token logits are suppressed to -inf
# which would lead to unexpected behavior
# The better approach here is to NOT make use of the `forced_tokens_proc` for Whisper and instead
# initialize all of them as `decoder_input_ids`.
# TODO(Sanchit): Make sure to deprecate this in v4.39 as there will be no `forced_decoder_ids` anymore.
logits_processor = (
[forced_tokens_proc] if logits_processor is None else logits_processor + [forced_tokens_proc]
)
generation_config.forced_decoder_ids = None
return logits_processor
@staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment