"vscode:/vscode.git/clone" did not exist on "4402879ee48dcff0f657738d8af5e35b266bd0ed"
Unverified Commit e3f028f3 authored by amyeroberts, committed by GitHub

Add TF whisper (#19378)



* simplify loop

* add feature extractor

* add model

* start conversion

* add dropout

* initial commit of test files

* conversion for all models

* update processor for correct padding

* update feature extraction

* update integration test logits match

* fmt: off for the logits

* on the fly mel bank

* small nit

* update test

* update tokenizer

* nit feature extraction

* update

* update tokenizer test

* add logits processor and update tokenizer to get suppress tokens

* style

* clean convert

* revert to original modeling tf utils

* Update

* update

* nit

* clean convert file

* update tests and nits

* quality

* slow generation test

* ffn_dim to allow customization

* update readme

* add to toctree

* start fixing integration tests

* update tests and code

* fix feature extractor

* fix config tests common

* update code to fix tests

* fix feature extractor

* nit feature extraction

* update test for new feature extractor

* style

* add abstract

* large logits with custom decoder input ids

* wrap around is_torch_available

* fix feature extractor

* correct logits for whisper small.en

* nit

* fix encoder_attention_mask

* some fixes

* remove unnecessary inputs

* nits

* add normalizer file

* update test tokenization

* fix attention mask not defined

* fix generate

* remove useless encoder attention mask

* update test modeling whisper

* update config to add second non-suppress tokens

* nits on feature extractor

* nit for test tokenizers

* update tests

* update tests

* update tokenization test

* fixup

* invalidated hf token. Clean convert openai to whisper

* fix logit tests

* fixup

* Add model to README

* Fix doc tests

* clean merge

* revert toc_tree changes

* remove useless LogitProcessor

* Update whisper.mdx

* update config file doc

* update configuration docstring

* update test tokenization

* update test tokenization

* update tokenization whisper
Added copied from where needed

* update feature extraction

* nit test name

* style

* quality

* remove get suppress tokens and update non_speech tokens global variables

* Update src/transformers/models/whisper/feature_extraction_whisper.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* clean modeling whisper and test
Removed the attention mask arguments that are deprecated

* fix large test

* Add multilingual audio test, and translate test

* style

* fix large multilingual test

* nits

* add copied from for attention layer

* remove attention masks in doc

* add english normalizer

* Update docs/source/en/model_doc/whisper.mdx
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* update tokenization test

* remove copied from in whisper attention : no bias in k_proj only

* wrap around dependencies in english normalizer

* style

* correct import generation logits

* for now, wrap feature extractor with torch

* remove torch dependencies for feature extraction and style

* Update src/transformers/models/whisper/convert_openai_whisper_to_tfms.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/whisper/configuration_whisper.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update docs/source/en/model_doc/whisper.mdx
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* fixup

* nit

* update logits

* style

* nit

* nits and fix final tests

* add `is_more_itertools_available` to utils

* quality

* add begin suppress tokens, suppress tokens to generate args and config

* clean SuppressTokensLogitsProcessor in generation logits

* Nit naming

* add SuppressTokensAtBegin

* update tests, suppress tokens to None or correct values

* nit and style

* update RAG to fit test and generate_logit

* add copy-pasted statement on english normalizer

* add arguments to config_common_kwargs

* Update src/transformers/generation_utils.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/generation_logits_process.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* revert changes based on reviews

* update doc and nits

* Update src/transformers/models/whisper/configuration_whisper.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* more nits

* last nits

* update test configuration common

* add BART name in decoder attention mask documentation

* Update src/transformers/models/whisper/modeling_whisper.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* style

* nit

* nit

* add english.json file to git

* nits on documentation

* nit

* nits

* last styling

* add main toctree file

* remove sentence piece dependency

* clean init file

* fix tokenizer that has no dependencies on sentencepiece

* update whisper init file, nit

* remove english.json file

* add get decoder prompt id

* All weights loading

* Remove hanging pdb

* Fixup and tidy up

* Use same copied from as PT model

* Remove whitespace changes

* Remove torch references

* Tie embeddings

* Remove logits processor input to generate

* Update logit values

* revert changes and add forced logit processor

* nit

* clean normalizer

* remove protected

* Add logit processors and update generation code & tests

* Some tidy up

* Update docstring

* update

* update based on review

* Update src/transformers/models/whisper/configuration_whisper.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/whisper/configuration_whisper.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update to reflect changes on the PT model branch

* Tidy up

* Remove extra whitespace

* Fix test - make input ids small enough we can append

* Include upstream changes on main

* PR comments - add batch tests, remove comments & defaults

* Fix model output imports

* Update src/transformers/models/whisper/modeling_tf_whisper.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation_tf_logits_process.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/models/whisper/modeling_tf_whisper.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/models/whisper/modeling_tf_whisper.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/models/whisper/test_modeling_tf_whisper.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/models/whisper/modeling_tf_whisper.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/models/whisper/modeling_tf_whisper.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update docstring example

* Update src/transformers/models/whisper/modeling_tf_whisper.py
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Remove changes to adjust_logits_during_generation function

* Update src/transformers/models/whisper/modeling_tf_whisper.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Tidy up imports that don't require TF

* Update tests - skip and no more skip

* Update tests/generation/test_generation_tf_logits_process.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/models/whisper/modeling_tf_whisper.py

* Update src/transformers/models/whisper/modeling_tf_whisper.py
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Add training flags

* Add (skipped) XLA generation tests

* Add embedding correctness test

* Add constant ids for generation tests

* Make logits finding a bit tidier

* Remove unused args

* xla generation enabled

* Don't skip XLA tests anymore

* Fix tests - add position ids to expected signature and update rag generation

* Undo method reorder

* Remove added whitespace

* Remove copy-paste gradient checkpoint ref

* Remove

* Trigger CI - (issue with refs when pulling)
Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: NielsRogge <niels.rogge1@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
Co-authored-by: Joao Gante <joao@huggingface.co>
parent af69360b
@@ -330,7 +330,7 @@ Flax), PyTorch, and/or TensorFlow.
| Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ |
| Wav2Vec2-Conformer | ❌ | ❌ | ✅ | ❌ | ❌ |
| WavLM | ❌ | ❌ | ✅ | ❌ | ❌ |
| Whisper | ✅ | ❌ | ✅ | ✅ | ❌ |
| X-CLIP | ❌ | ❌ | ✅ | ❌ | ❌ |
| XGLM | ✅ | ✅ | ✅ | ✅ | ✅ |
| XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
@@ -27,7 +27,7 @@ Tips:
- The architecture follows a classic encoder-decoder architecture, which means that it relies on the [`~generation_utils.GenerationMixin.generate`] function for inference.
- One can use [`WhisperProcessor`] to prepare audio for the model, and decode the predicted ID's back into text.
This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ). The Tensorflow version of this model was contributed by [amyeroberts](https://huggingface.co/amyeroberts).
The original code can be found [here](https://github.com/openai/whisper).
@@ -66,3 +66,14 @@ The original code can be found [here](https://github.com/openai/whisper).
[[autodoc]] WhisperForConditionalGeneration
- forward
## TFWhisperModel
[[autodoc]] TFWhisperModel
- call
## TFWhisperForConditionalGeneration
[[autodoc]] TFWhisperForConditionalGeneration
- call
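As a quick orientation, a minimal usage sketch of the new TF classes (the checkpoint name and the silent waveform are assumptions for illustration, not part of this diff):

```python
import numpy as np
from transformers import WhisperProcessor, TFWhisperForConditionalGeneration

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# stand-in for a real 1-D float32 waveform sampled at 16 kHz
waveform = np.zeros(16000, dtype=np.float32)
# note: this version of the feature extractor no longer takes `sampling_rate`
input_features = processor(waveform, return_tensors="tf").input_features

generated_ids = model.generate(input_features)
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
```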
@@ -2754,6 +2754,14 @@ else:
"TFWav2Vec2PreTrainedModel",
]
)
_import_structure["models.whisper"].extend(
[
"TF_WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFWhisperForConditionalGeneration",
"TFWhisperModel",
"TFWhisperPreTrainedModel",
]
)
_import_structure["models.xglm"].extend( _import_structure["models.xglm"].extend(
[ [
"TF_XGLM_PRETRAINED_MODEL_ARCHIVE_LIST", "TF_XGLM_PRETRAINED_MODEL_ARCHIVE_LIST",
...@@ -5303,6 +5311,12 @@ if TYPE_CHECKING: ...@@ -5303,6 +5311,12 @@ if TYPE_CHECKING:
TFWav2Vec2Model, TFWav2Vec2Model,
TFWav2Vec2PreTrainedModel, TFWav2Vec2PreTrainedModel,
) )
from .models.whisper import (
TF_WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFWhisperForConditionalGeneration,
TFWhisperModel,
TFWhisperPreTrainedModel,
)
from .models.xglm import (
TF_XGLM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFXGLMForCausalLM,
@@ -504,3 +504,84 @@ class TFForcedEOSTokenLogitsProcessor(TFLogitsProcessor):
axis=-1,
)
return scores
class TFSuppressTokensAtBeginLogitsProcessor(TFLogitsProcessor):
r"""
[`TFSuppressTokensAtBeginLogitsProcessor`] suppresses a list of tokens as soon as the `generate` function starts
generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are not
sampled at the beginning of the generation.
"""
def __init__(self, begin_suppress_tokens, begin_index):
self.begin_suppress_tokens = list(begin_suppress_tokens)
self.begin_index = begin_index
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
scores = tf.cond(
tf.equal(cur_len, self.begin_index),
lambda: tf.tensor_scatter_nd_update(
scores,
indices=[[i, token] for i in range(scores.shape[0]) for token in self.begin_suppress_tokens],
updates=[-float("inf") for _ in range(scores.shape[0] * len(self.begin_suppress_tokens))],
),
lambda: scores,
)
return scores
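A minimal eager-mode sketch of how this processor behaves (toy vocabulary and token ids assumed):

```python
import tensorflow as tf
from transformers.generation_tf_logits_process import TFSuppressTokensAtBeginLogitsProcessor

processor = TFSuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens=[0, 2], begin_index=1)

input_ids = tf.constant([[50258]])  # batch of 1, one (forced) token generated so far
scores = tf.zeros((1, 5))           # uniform logits over a toy 5-token vocabulary
scores = processor(input_ids, scores, cur_len=1)  # cur_len == begin_index, suppression fires
# scores[0, 0] and scores[0, 2] are now -inf; all other logits are unchanged
```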
class TFSuppressTokensLogitsProcessor(TFLogitsProcessor):
r"""This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so that they
are not sampled."""
def __init__(self, suppress_tokens):
self.suppress_tokens = list(suppress_tokens)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
scores = tf.tensor_scatter_nd_update(
scores,
indices=[[i, token] for i in range(scores.shape[0]) for token in self.suppress_tokens],
updates=[-float("inf") for _ in range(scores.shape[0] * len(self.suppress_tokens))],
)
return scores
class TFForceTokensLogitsProcessor(TFLogitsProcessor):
r"""This processor can be used to force a list of tokens. The processor will set their log probs to `0` and all
other tokens to `-inf` so that they are sampled at their corresponding index."""
def __init__(self, force_token_map):
force_token_map = dict(force_token_map)
# Converts the dictionary of format {index: token} containing the tokens to be forced to an array, where the
# index of the array corresponds to the index of the token to be forced, for XLA compatibility.
# Indexes without forced tokens will have a negative value.
force_token_array = np.ones((max(force_token_map.keys()) + 1), dtype=np.int32) * -1
for index, token in force_token_map.items():
force_token_array[index] = token
self.force_token_array = tf.convert_to_tensor(force_token_array, dtype=tf.int32)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
def _force_token(generation_idx):
batch_size = scores.shape[0]
current_token = self.force_token_array[generation_idx]
new_scores = tf.ones_like(scores, dtype=scores.dtype) * -float("inf")
indices = tf.stack((tf.range(batch_size), tf.tile([current_token], [batch_size])), axis=1)
updates = tf.zeros((batch_size,), dtype=scores.dtype)
new_scores = tf.tensor_scatter_nd_update(new_scores, indices, updates)
return new_scores
scores = tf.cond(
tf.greater_equal(cur_len, tf.shape(self.force_token_array)[0]),
# If the current length is geq than the length of force_token_array, the processor does nothing.
lambda: tf.identity(scores),
# Otherwise, it may force a certain token.
lambda: tf.cond(
tf.greater_equal(self.force_token_array[cur_len], 0),
# Only valid (positive) tokens are forced
lambda: _force_token(cur_len),
# Otherwise, the processor does nothing.
lambda: scores,
),
)
return scores
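Since these processors avoid data-dependent Python control flow, they can be compiled with XLA; a sketch with toy values, mirroring the tests added below:

```python
import tensorflow as tf
from transformers.generation_tf_logits_process import TFForceTokensLogitsProcessor

processor = TFForceTokensLogitsProcessor(force_token_map={1: 2})  # force token 2 at index 1
xla_processor = tf.function(processor, jit_compile=True)

input_ids = tf.constant([[11], [7]])  # batch of 2, one token generated so far
scores = tf.zeros((2, 5))             # uniform logits over a toy 5-token vocabulary
scores = xla_processor(input_ids, scores, 1)
# scores[:, 2] is now 0 and every other logit is -inf, so token 2 is always sampled
```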
@@ -26,11 +26,14 @@ from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice
from .generation_tf_logits_process import (
TFForcedBOSTokenLogitsProcessor,
TFForcedEOSTokenLogitsProcessor,
TFForceTokensLogitsProcessor,
TFLogitsProcessorList,
TFMinLengthLogitsProcessor,
TFNoBadWordsLogitsProcessor,
TFNoRepeatNGramLogitsProcessor,
TFRepetitionPenaltyLogitsProcessor,
TFSuppressTokensAtBeginLogitsProcessor,
TFSuppressTokensLogitsProcessor,
TFTemperatureLogitsWarper,
TFTopKLogitsWarper,
TFTopPLogitsWarper,
@@ -401,6 +404,9 @@ class TFGenerationMixin:
return_dict_in_generate=None,
forced_bos_token_id=None,
forced_eos_token_id=None,
suppress_tokens: Optional[List[int]] = None,
begin_suppress_tokens: Optional[List[int]] = None,
forced_decoder_ids: Optional[List[int]] = None,
**model_kwargs,
) -> Union[TFGreedySearchOutput, TFSampleOutput, TFBeamSearchOutput, TFBeamSampleOutput, tf.Tensor]:
r"""
@@ -494,6 +500,14 @@ class TFGenerationMixin:
the target language token.
forced_eos_token_id (`int`, *optional*):
The id of the token to force as the last generated token when `max_length` is reached.
suppress_tokens (`List[int]`, *optional*, defaults to `model.config.suppress_tokens`):
A list of tokens that will be suppressed at generation. The `SuppressTokens` logit processor will set
their log probs to `-inf` so that they are not sampled.
begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`):
A list of tokens that will be suppressed at the beginning of the generation. The `SuppressBeginTokens`
logit processor will set their log probs to `-inf` so that they are not sampled.
forced_decoder_ids (`List[int]`, *optional*, defaults to `model.config.forced_decoder_ids`):
A list of tokens that will be forced as beginning tokens, before sampling.
model_specific_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model.
@@ -609,6 +623,9 @@ class TFGenerationMixin:
return_dict_in_generate=return_dict_in_generate,
forced_bos_token_id=forced_bos_token_id,
forced_eos_token_id=forced_eos_token_id,
suppress_tokens=suppress_tokens,
begin_suppress_tokens=begin_suppress_tokens,
forced_decoder_ids=forced_decoder_ids,
**model_kwargs,
)
@@ -648,6 +665,12 @@ class TFGenerationMixin:
forced_eos_token_id = (
forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
)
suppress_tokens = suppress_tokens if suppress_tokens is not None else self.config.suppress_tokens
begin_suppress_tokens = (
begin_suppress_tokens if begin_suppress_tokens is not None else self.config.begin_suppress_tokens
)
if forced_decoder_ids is None and hasattr(self.config, "forced_decoder_ids"):
forced_decoder_ids = self.config.forced_decoder_ids
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1368,6 +1391,9 @@ class TFGenerationMixin:
return_dict_in_generate=None,
forced_bos_token_id=None,
forced_eos_token_id=None,
suppress_tokens=None,
begin_suppress_tokens=None,
forced_decoder_ids=None,
**model_kwargs,
) -> Union[TFGreedySearchOutput, TFSampleOutput, TFBeamSearchOutput, TFBeamSampleOutput, tf.Tensor]:
r"""
@@ -1461,6 +1487,15 @@ class TFGenerationMixin:
the target language token.
forced_eos_token_id (`int`, *optional*):
The id of the token to force as the last generated token when `max_length` is reached.
suppress_tokens (`List[int]`, *optional*, defaults to `model.config.suppress_tokens`):
A list of tokens that will be suppressed at generation. The `SuppressTokens` logit processor will set
their log probs to `-inf` so that they are not sampled.
begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`):
A list of tokens that will be suppressed at the beginning of the generation. The `SuppressBeginTokens`
logit processor will set their log probs to `-inf` so that they are not sampled.
forced_decoder_ids (`List[int]`, *optional*, defaults to `model.config.forced_decoder_ids`):
A list of tokens that will be forced as beginning tokens.
model_kwargs:
Additional model specific kwargs will be forwarded to the `call` function of the model.
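As a sketch of how a caller would combine the new arguments (all token ids below are made up; `model` and `input_features` are as in the usage example earlier):

```python
generated_ids = model.generate(
    input_features,
    suppress_tokens=[1, 2, 7],                    # never sampled at any position
    begin_suppress_tokens=[220, 50256],           # never sampled as the first free token
    forced_decoder_ids=[[1, 50259], [2, 50359]],  # (position, token_id) pairs forced before sampling
)
```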
@@ -1695,12 +1730,16 @@ class TFGenerationMixin:
logits_processor = self._get_logits_processor(
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
input_ids_seq_length=input_ids_seq_length,
bad_words_ids=bad_words_ids,
min_length=min_length,
max_length=max_length,
eos_token_id=eos_token_id,
forced_bos_token_id=forced_bos_token_id,
forced_eos_token_id=forced_eos_token_id,
suppress_tokens=suppress_tokens,
begin_suppress_tokens=begin_suppress_tokens,
forced_decoder_ids=forced_decoder_ids,
)
# 9. go into different generation modes
@@ -1994,7 +2033,7 @@ class TFGenerationMixin:
def _initialize_past(past, num_padding_values, batch_axis):
"""initialize past with zeros -- the structure depends on `batch_axis`"""
if batch_axis == 0:
padding_values = tf.constant([[0, 0], [0, 0], [0, num_padding_values], [0, 0]], dtype=tf.int32)
new_past = ()
for past_layer in past:
new_past_layer = list(past_layer)
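For context, the new `padding_values` constant is the per-axis `[before, after]` spec handed to `tf.pad`, padding only the sequence axis on the right; a sketch assuming past tensors of shape `[batch, num_heads, seq_len, head_dim]`:

```python
import tensorflow as tf

num_padding_values = 3
padding_values = tf.constant([[0, 0], [0, 0], [0, num_padding_values], [0, 0]], dtype=tf.int32)

past_layer = tf.ones((2, 4, 5, 8))           # [batch, num_heads, seq_len, head_dim]
padded = tf.pad(past_layer, padding_values)  # zero-pads seq_len on the right: 5 -> 8
print(padded.shape)                          # (2, 4, 8, 8)
```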
@@ -2099,12 +2138,16 @@ class TFGenerationMixin:
self,
repetition_penalty: float,
no_repeat_ngram_size: int,
input_ids_seq_length: int,
bad_words_ids: List[List[int]],
min_length: int,
max_length: int,
eos_token_id: int,
forced_bos_token_id: int,
forced_eos_token_id: int,
suppress_tokens: Optional[List[int]] = None,
begin_suppress_tokens: Optional[List[int]] = None,
forced_decoder_ids: Optional[List[int]] = None,
) -> TFLogitsProcessorList:
"""
This class returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsProcessor`]
@@ -2118,6 +2161,12 @@ class TFGenerationMixin:
)
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
suppress_tokens = suppress_tokens if suppress_tokens is not None else self.config.suppress_tokens
begin_suppress_tokens = (
begin_suppress_tokens if begin_suppress_tokens is not None else self.config.begin_suppress_tokens
)
if forced_decoder_ids is None and hasattr(self.config, "forced_decoder_ids"):
forced_decoder_ids = self.config.forced_decoder_ids
# instantiate processors list
if repetition_penalty is not None and repetition_penalty != 1.0:
@@ -2132,7 +2181,16 @@ class TFGenerationMixin:
processors.append(TFForcedBOSTokenLogitsProcessor(forced_bos_token_id))
if forced_eos_token_id is not None:
processors.append(TFForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
if suppress_tokens is not None:
processors.append(TFSuppressTokensLogitsProcessor(suppress_tokens))
if begin_suppress_tokens is not None:
begin_index = input_ids_seq_length
begin_index = begin_index if (input_ids_seq_length > 1 or forced_bos_token_id is None) else begin_index + 1
if forced_decoder_ids is not None:
begin_index += forced_decoder_ids[-1][0] # generation starts after the last token that is forced
processors.append(TFSuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens, begin_index))
if forced_decoder_ids is not None:
processors.append(TFForceTokensLogitsProcessor(forced_decoder_ids))
return processors
def greedy_search(
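A worked example of the `begin_index` arithmetic above, with assumed values:

```python
# Toy values: decoder starts from a single start token, with two forced tokens after it.
input_ids_seq_length = 1
forced_bos_token_id = None
forced_decoder_ids = [[1, 50259], [2, 50359]]  # hypothetical (position, token_id) pairs

begin_index = input_ids_seq_length             # 1
if not (input_ids_seq_length > 1 or forced_bos_token_id is None):
    begin_index += 1                           # skipped here, forced_bos_token_id is None
begin_index += forced_decoder_ids[-1][0]       # 1 + 2 = 3

# TFSuppressTokensAtBeginLogitsProcessor fires at cur_len == 3: the first
# position that is actually sampled rather than forced.
```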
@@ -80,6 +80,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("vit", "TFViTModel"),
("vit_mae", "TFViTMAEModel"),
("wav2vec2", "TFWav2Vec2Model"),
("whisper", "TFWhisperModel"),
("xglm", "TFXGLMModel"), ("xglm", "TFXGLMModel"),
("xlm", "TFXLMModel"), ("xlm", "TFXLMModel"),
("xlm-roberta", "TFXLMRobertaModel"), ("xlm-roberta", "TFXLMRobertaModel"),
...@@ -145,6 +146,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( ...@@ -145,6 +146,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
("t5", "TFT5ForConditionalGeneration"), ("t5", "TFT5ForConditionalGeneration"),
("tapas", "TFTapasForMaskedLM"), ("tapas", "TFTapasForMaskedLM"),
("transfo-xl", "TFTransfoXLLMHeadModel"), ("transfo-xl", "TFTransfoXLLMHeadModel"),
("whisper", "TFWhisperForConditionalGeneration"),
("xlm", "TFXLMWithLMHeadModel"), ("xlm", "TFXLMWithLMHeadModel"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"), ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
("xlnet", "TFXLNetLMHeadModel"), ("xlnet", "TFXLNetLMHeadModel"),
...@@ -253,6 +255,7 @@ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ...@@ -253,6 +255,7 @@ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict( TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
[ [
("speech_to_text", "TFSpeech2TextForConditionalGeneration"), ("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
("whisper", "TFWhisperForConditionalGeneration"),
]
)
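With these mappings registered, the TF model should also be reachable through the auto classes; a sketch (checkpoint name assumed):

```python
from transformers import TFAutoModelForSpeechSeq2Seq

# resolves to TFWhisperForConditionalGeneration via the speech-seq2seq mapping above
model = TFAutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny")
```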
@@ -1262,6 +1262,7 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
eos_token_id=eos_token_id,
forced_bos_token_id=None,
forced_eos_token_id=None,
input_ids_seq_length=tf.shape(decoder_input_ids)[-1],
)
model_kwargs["attention_mask"] = context_attention_mask
@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
_import_structure = {
@@ -41,6 +41,18 @@ else:
"WhisperPreTrainedModel",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_whisper"] = [
"TF_WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFWhisperForConditionalGeneration",
"TFWhisperModel",
"TFWhisperPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_whisper import WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP, WhisperConfig
@@ -61,6 +73,19 @@ if TYPE_CHECKING:
WhisperPreTrainedModel,
)
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_whisper import (
TF_WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFWhisperForConditionalGeneration,
TFWhisperModel,
TFWhisperPreTrainedModel,
)
else:
import sys
@@ -218,7 +218,6 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
return_attention_mask: Optional[bool] = None,
padding: Optional[str] = "max_length",
max_length: Optional[int] = None,
sampling_rate: Optional[int] = None,
**kwargs
) -> BatchFeature:
"""
@@ -262,19 +261,6 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
The value that is used to fill the padding values / vectors.
"""
if sampling_rate is not None:
if sampling_rate != self.sampling_rate:
raise ValueError(
f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
f" {self.sampling_rate} and not {sampling_rate}."
)
else:
logger.warning(
"It is strongly recommended to pass the `sampling_rate` argument to this function. "
"Failing to do so can result in silent errors that might be hard to debug."
)
is_batched = bool(
isinstance(raw_speech, (list, tuple))
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
@@ -2394,6 +2394,30 @@ class TFWav2Vec2PreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["tf"])
TF_WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TFWhisperForConditionalGeneration(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFWhisperModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFWhisperPreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
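These dummies keep the names importable when TensorFlow is missing; a sketch of the resulting behavior (assuming an environment without TF):

```python
# The top-level import succeeds even without TensorFlow installed...
from transformers import TFWhisperModel

# ...but instantiation raises an ImportError via requires_backends,
# telling the user to install the "tf" backend.
model = TFWhisperModel()
```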
TF_XGLM_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -29,11 +29,14 @@ if is_tf_available():
from transformers.generation_tf_logits_process import (
TFForcedBOSTokenLogitsProcessor,
TFForcedEOSTokenLogitsProcessor,
TFForceTokensLogitsProcessor,
TFLogitsProcessorList,
TFMinLengthLogitsProcessor,
TFNoBadWordsLogitsProcessor,
TFNoRepeatNGramLogitsProcessor,
TFRepetitionPenaltyLogitsProcessor,
TFSuppressTokensAtBeginLogitsProcessor,
TFSuppressTokensLogitsProcessor,
TFTemperatureLogitsWarper,
TFTopKLogitsWarper,
TFTopPLogitsWarper,
@@ -331,6 +334,86 @@ class TFLogitsProcessorTest(unittest.TestCase):
scores = logits_processor(input_ids, scores, cur_len)
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
@parameterized.expand([(False,), (True,)])
def test_suppress_tokens_at_begin_logits_processor(self, use_xla):
vocab_size = 20
batch_size = 4
begin_suppress_tokens = [1, 2, 3]
begin_index = 5
logits_processor = TFSuppressTokensAtBeginLogitsProcessor(
begin_suppress_tokens=begin_suppress_tokens, begin_index=begin_index
)
if use_xla:
logits_processor = tf.function(logits_processor, jit_compile=True)
# Check that no scores are suppressed if begin_index is not reached
cur_len = 4
input_ids = tf.convert_to_tensor([[11, 17, 15, 8], [14, 0, 19, 5], [13, 11, 18, 19], [11, 12, 16, 15]])
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores, cur_len)
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
# Check that scores are suppressed if begin_index is reached
cur_len = 5
input_ids = tf.convert_to_tensor([[5, 5, 5, 0, 17], [18, 1, 9, 14, 17], [18, 6, 8, 15, 19], [8, 12, 17, 1, 2]])
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores, cur_len)
self.assertTrue(tf.math.reduce_all(tf.math.is_inf(tf.gather(scores, begin_suppress_tokens, axis=1))))
@parameterized.expand([(False,), (True,)])
def test_suppress_tokens_logits_processor(self, use_xla):
vocab_size = 20
batch_size = 4
suppress_tokens = [1, 3, 5]
keep_tokens = [i for i in range(vocab_size) if i not in suppress_tokens]
logits_processor = TFSuppressTokensLogitsProcessor(suppress_tokens=suppress_tokens)
if use_xla:
logits_processor = tf.function(logits_processor, jit_compile=True)
# Check that suppress_tokens are suppressed and others are not
cur_len = 5
input_ids = tf.convert_to_tensor([[0, 10, 19, 6, 3], [17, 4, 8, 17, 2], [7, 1, 11, 6, 15], [5, 8, 13, 16, 0]])
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores, cur_len)
self.assertTrue(tf.math.reduce_all(tf.math.is_inf(tf.gather(scores, suppress_tokens, axis=1))))
self.assertFalse(tf.math.reduce_any(tf.math.is_inf(tf.gather(scores, keep_tokens, axis=1))))
@parameterized.expand([(False,), (True,)])
def test_force_tokens_logits_processor(self, use_xla):
vocab_size = 20
batch_size = 4
force_token_map = {1: 2, 3: 2}
logits_processor = TFForceTokensLogitsProcessor(force_token_map=force_token_map)
if use_xla:
logits_processor = tf.function(logits_processor, jit_compile=True)
# check that if the cur_len is contained in the force_token_map, the logits are the same
# for all tokens except the one the force_token_map points to
cur_len = 1
input_ids = tf.convert_to_tensor([[11], [7], [5], [15]])
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores, cur_len)
tf.debugging.assert_near(tf.gather(scores, [force_token_map[cur_len]], axis=1), 0.0)
non_forced_inds = [i for i in range(vocab_size) if i != force_token_map[cur_len]]
self.assertTrue(
tf.math.reduce_all(tf.math.is_inf(tf.gather(scores, [non_forced_inds], axis=1))),
)
# check that if the cur_len is not contained in the force_token_map, the logits are not modified
cur_len = 2
input_ids = tf.convert_to_tensor([[2, 19], [19, 15], [4, 9], [7, 6]])
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores, cur_len)
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
@parameterized.expand([(False,), (True,)])
def test_processor_list(self, use_xla):
# TODO (Joao): reintroduce TFNoRepeatNGramLogitsProcessor when it gets compatible with XLA
@@ -736,6 +736,23 @@ class TFModelTesterMixin:
dtype="float32",
),
}
elif model_class.__name__ in ["TFWhisperModel", "TFWhisperForConditionalGeneration"]:
inputs = {
"decoder_input_ids": tf.keras.Input(
batch_shape=(2, max_input),
name="decoder_input_ids",
dtype="int32",
),
"input_features": tf.keras.Input(
batch_shape=(
2,
self.model_tester.num_mel_bins,
self.model_tester.seq_length,
),
name="input_features",
dtype="float32",
),
}
elif self.is_encoder_decoder:
inputs = {
"decoder_input_ids": tf.keras.Input(
@@ -1223,8 +1240,17 @@ class TFModelTesterMixin:
# fetch the output for an input exclusively made of new members of the vocabulary
inputs_dict = copy.deepcopy(original_inputs_dict)
ids_feat_name = None
if "input_ids" in inputs_dict:
ids_feat_name = "input_ids"
elif "decoder_input_ids" in inputs_dict:
ids_feat_name = "decoder_input_ids"
else:
assert False, "No input ids feature found in the inputs dict"
new_vocab_input_ids = ids_tensor(inputs_dict[ids_feat_name].shape, new_tokens_size)
new_vocab_input_ids += old_total_size
inputs_dict[ids_feat_name] = new_vocab_input_ids
if "input_ids" in inputs_dict: if "input_ids" in inputs_dict:
inputs_dict["input_ids"] = new_vocab_input_ids inputs_dict["input_ids"] = new_vocab_input_ids
if "decoder_input_ids" in inputs_dict: if "decoder_input_ids" in inputs_dict:
......
@@ -105,6 +105,8 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
"TFElectraMainLayer",  # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?)
"TFRobertaForMultipleChoice",  # TODO: fix
"TrOCRDecoderWrapper",  # Building part of bigger (tested) model.
"TFWhisperEncoder", # Building part of bigger (tested) model.
"TFWhisperDecoder", # Building part of bigger (tested) model.
"SeparableConv1D", # Building part of bigger (tested) model. "SeparableConv1D", # Building part of bigger (tested) model.
"FlaxBartForCausalLM", # Building part of bigger (tested) model. "FlaxBartForCausalLM", # Building part of bigger (tested) model.
"FlaxBertForCausalLM", # Building part of bigger (tested) model. Tested implicitly through FlaxRobertaForCausalLM. "FlaxBertForCausalLM", # Building part of bigger (tested) model. Tested implicitly through FlaxRobertaForCausalLM.
......
@@ -97,4 +97,5 @@ src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
src/transformers/models/wavlm/modeling_wavlm.py
src/transformers/models/whisper/modeling_whisper.py
src/transformers/models/whisper/modeling_tf_whisper.py
src/transformers/models/yolos/modeling_yolos.py