Unverified Commit 2beabd24 authored by Arthur's avatar Arthur Committed by GitHub
Browse files

[🛠️] Fix-whisper-breaking-changes (#21965)



* temp fix

* temporary fix

* update

* fix tests

* fixup

* update based on reveiew
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* update to fix tests

* update docstring

---------
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
parent 101a6cd2
......@@ -14,7 +14,6 @@
# limitations under the License.
""" PyTorch Whisper model."""
import math
import random
from typing import Optional, Tuple, Union
......@@ -37,6 +36,7 @@ from ...modeling_outputs import (
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_whisper import WhisperConfig
from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
logger = logging.get_logger(__name__)
......@@ -1510,8 +1510,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids`
will be updated accordingly.
language (`bool`, *optional*):
Language token to use for generation, should be in the form `<|en|>`. You can find all the possible
language tokens in the `model.generation_config.lang_to_id` dictionary.
Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. You can
find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary.
is_multilingual (`bool`, *optional*):
Whether or not the model is multilingual.
kwargs:
......@@ -1543,39 +1543,63 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
generation_config = self.generation_config
if return_timestamps is not None:
generation_config.return_timestamps = return_timestamps
if task is not None:
generation_config.task = task
if not hasattr(generation_config, "no_timestamps_token_id"):
raise ValueError(
"You are trying to return timestamps, but the generation config is not properly set."
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`."
"For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
)
if is_multilingual is not None:
generation_config.is_multilingual = is_multilingual
generation_config.return_timestamps = return_timestamps
else:
generation_config.return_timestamps = False
if language is not None:
generation_config.language = language
if task is not None:
generation_config.task = task
forced_decoder_ids = []
if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual:
if task is not None or language is not None:
if hasattr(generation_config, "language"):
forced_decoder_ids.append((1, generation_config.lang_to_id[generation_config.language]))
if generation_config.language in generation_config.lang_to_id.keys():
language_token = generation_config.language
elif generation_config.language in TO_LANGUAGE_CODE.keys():
language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>"
else:
forced_decoder_ids.append((1, None))
raise ValueError(
f"Unsupported language: {self.language}. Language should be one of:"
f" {list(TO_LANGUAGE_CODE.keys()) if generation_config.language in TO_LANGUAGE_CODE.keys() else list(TO_LANGUAGE_CODE.values())}."
)
forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
else:
forced_decoder_ids.append((1, None)) # automatically detect the language
if hasattr(generation_config, "task"):
if generation_config.task in TASK_IDS:
forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
else:
forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))
if (
hasattr(generation_config, "return_timestamps") and generation_config.return_timestamps
) or return_timestamps:
logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
raise ValueError(
f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`"
)
else:
if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id:
forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) # defaults to transcribe
if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
# Legacy code for backward compatibility
elif hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
forced_decoder_ids = self.config.forced_decoder_ids
elif (
hasattr(self.generation_config, "forced_decoder_ids")
and self.generation_config.forced_decoder_ids is not None
):
forced_decoder_ids = self.generation_config.forced_decoder_ids
if generation_config.return_timestamps:
logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
if len(forced_decoder_ids) > 0:
generation_config.forced_decoder_ids = forced_decoder_ids
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment