Unverified Commit 2beabd24 authored by Arthur's avatar Arthur Committed by GitHub
Browse files

[🛠️] Fix-whisper-breaking-changes (#21965)



* temp fix

* temporary fix

* update

* fix tests

* fixup

* update based on review
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* update to fix tests

* update docstring

---------
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
parent 101a6cd2
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" PyTorch Whisper model.""" """ PyTorch Whisper model."""
import math import math
import random import random
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
...@@ -37,6 +36,7 @@ from ...modeling_outputs import ( ...@@ -37,6 +36,7 @@ from ...modeling_outputs import (
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_whisper import WhisperConfig from .configuration_whisper import WhisperConfig
from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -1510,8 +1510,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): ...@@ -1510,8 +1510,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids` Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids`
will be updated accordingly. will be updated accordingly.
language (`bool`, *optional*): language (`bool`, *optional*):
Language token to use for generation, should be in the form `<|en|>`. You can find all the possible Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. You can
language tokens in the `model.generation_config.lang_to_id` dictionary. find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary.
is_multilingual (`bool`, *optional*): is_multilingual (`bool`, *optional*):
Whether or not the model is multilingual. Whether or not the model is multilingual.
kwargs: kwargs:
...@@ -1543,39 +1543,63 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): ...@@ -1543,39 +1543,63 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
generation_config = self.generation_config generation_config = self.generation_config
if return_timestamps is not None: if return_timestamps is not None:
generation_config.return_timestamps = return_timestamps if not hasattr(generation_config, "no_timestamps_token_id"):
raise ValueError(
if task is not None: "You are trying to return timestamps, but the generation config is not properly set."
generation_config.task = task "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`."
"For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
)
if is_multilingual is not None: generation_config.return_timestamps = return_timestamps
generation_config.is_multilingual = is_multilingual else:
generation_config.return_timestamps = False
if language is not None: if language is not None:
generation_config.language = language generation_config.language = language
if task is not None:
generation_config.task = task
forced_decoder_ids = [] forced_decoder_ids = []
if task is not None or language is not None:
if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual:
if hasattr(generation_config, "language"): if hasattr(generation_config, "language"):
forced_decoder_ids.append((1, generation_config.lang_to_id[generation_config.language])) if generation_config.language in generation_config.lang_to_id.keys():
language_token = generation_config.language
elif generation_config.language in TO_LANGUAGE_CODE.keys():
language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>"
else: else:
forced_decoder_ids.append((1, None)) raise ValueError(
f"Unsupported language: {self.language}. Language should be one of:"
f" {list(TO_LANGUAGE_CODE.keys()) if generation_config.language in TO_LANGUAGE_CODE.keys() else list(TO_LANGUAGE_CODE.values())}."
)
forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
else:
forced_decoder_ids.append((1, None)) # automatically detect the language
if hasattr(generation_config, "task"): if hasattr(generation_config, "task"):
if generation_config.task in TASK_IDS:
forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task])) forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
else: else:
forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) raise ValueError(
f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`"
if ( )
hasattr(generation_config, "return_timestamps") and generation_config.return_timestamps
) or return_timestamps:
logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
else: else:
if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id: forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) # defaults to transcribe
if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1 idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id)) forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
# Legacy code for backward compatibility
elif hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
forced_decoder_ids = self.config.forced_decoder_ids
elif (
hasattr(self.generation_config, "forced_decoder_ids")
and self.generation_config.forced_decoder_ids is not None
):
forced_decoder_ids = self.generation_config.forced_decoder_ids
if generation_config.return_timestamps:
logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
if len(forced_decoder_ids) > 0: if len(forced_decoder_ids) > 0:
generation_config.forced_decoder_ids = forced_decoder_ids generation_config.forced_decoder_ids = forced_decoder_ids
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment