"docs/source/vscode:/vscode.git/clone" did not exist on "0ba94aceb6e1ab448e0acc896764a4496759cb14"
Unverified Commit 2840272c authored by Andy Ehrenberg's avatar Andy Ehrenberg Committed by GitHub
Browse files

add flax whisper implementation (#20479)



* add flax whisper implementation

* rever change to setup

* remove unused imports

* revert generation changes

* flax whisper docs

* docs

* import order

* import sorting

* isort

* add dummy objects

* doc formatting

* formatting

* remove trailing whitespaces

* fix flax whisper docs

* add generation logic to unlock flax whisper

* remove scans

* give credits to Flax Bart implementation

* remove unused imports

* add license

* remove assert

* more credits to Bart

* fix style

* formatting

* support left padding

* add flax whisper generation test

* remove copied from comments whenever not a full copy

* fix docstrings for logits processors

* revert change to FlaxForceTokensLogitsProcessor

* revert doc changes

* improve generation docs

* reorganize

* formatting

* cleanup docs

* add tests

* handle empty list case

* fix forced decoder ids in flax tests

* add flax whisper to inits

* upate dummy objects

* docs for FlaxAutoModelForSpeechSeq2Seq

* fix decoder_position_ids computation in pretrained model decode/__call__ fns

* add Copied from statements as necessary

* compute position_ids only in __call__ and decode methods of pretrained model subclasses

* improve readabilityof compute positional embeddings

* check dimensionality of input_features instead of hidden_states

* copied from statement for init_cache

* formatting

* fix copies

* fix copies

* pass attention mask to encoder layers

* fix decoder module outputs

* set dtype
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* smaller flax model for whisper test

* Update src/transformers/generation/flax_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/whisper/modeling_flax_whisper.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update tests/models/whisper/test_modeling_flax_whisper.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* cleanup
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/whisper/modeling_flax_whisper.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* bias cleanup

* doc fix

* align style for force tokens processor

* readability

* fix input shape in tests

* revert FlaxGenerationMixin docstring

* formatting

* fix tests

* fix imports

* consistent encoder hidden states

* consistent hidden states

* input shapes

* typo

* partial class trick

* partial class for input shape

* base_class with correct input shape

* partial base classes

* match by name

* set main_input_name

* compare on names

* formatting

* remove unused import

* safer position ids computation

* safer position id computation

* Update src/transformers/models/whisper/modeling_flax_whisper.py
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/whisper/modeling_flax_whisper.py
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* remove identical inherited tests

* fix prompt ids in tests

* use generation config

* use jnp array

* better var names

* more explicit bias use

* import transformers

* formatting

* test formatting

* remove unused imports

* remove unused imports

* formatting

* isort

* docs

* fix ln orders for encoder hidden states

* whisper unique generation stuff

* flake

* use finfo for attention bias

* docs

* Update src/transformers/generation/flax_utils.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* docs

* add timestamp flax test

* jit for timestamps

* formatting

* clean up timestamps processor

* formatting

* remove if_true

* cleanup

---------
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
parent 7735e040
......@@ -402,7 +402,7 @@ Flax), PyTorch, and/or TensorFlow.
| Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ |
| Wav2Vec2-Conformer | ❌ | ❌ | ✅ | ❌ | ❌ |
| WavLM | ❌ | ❌ | ✅ | ❌ | ❌ |
| Whisper | ✅ | ❌ | ✅ | ✅ | |
| Whisper | ✅ | ❌ | ✅ | ✅ | |
| X-CLIP | ❌ | ❌ | ✅ | ❌ | ❌ |
| X-MOD | ❌ | ❌ | ✅ | ❌ | ❌ |
| XGLM | ✅ | ✅ | ✅ | ✅ | ✅ |
......
......@@ -286,6 +286,10 @@ The following auto classes are available for the following audio tasks.
[[autodoc]] TFAutoModelForSpeechSeq2Seq
### FlaxAutoModelForSpeechSeq2Seq
[[autodoc]] FlaxAutoModelForSpeechSeq2Seq
### AutoModelForAudioXVector
[[autodoc]] AutoModelForAudioXVector
......
......@@ -79,3 +79,14 @@ The original code can be found [here](https://github.com/openai/whisper).
[[autodoc]] TFWhisperForConditionalGeneration
- call
## FlaxWhisperModel
[[autodoc]] FlaxWhisperModel
- __call__
## FlaxWhisperForConditionalGeneration
[[autodoc]] FlaxWhisperForConditionalGeneration
- __call__
......@@ -3382,6 +3382,7 @@ else:
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING",
"FLAX_MODEL_MAPPING",
......@@ -3395,6 +3396,7 @@ else:
"FlaxAutoModelForQuestionAnswering",
"FlaxAutoModelForSeq2SeqLM",
"FlaxAutoModelForSequenceClassification",
"FlaxAutoModelForSpeechSeq2Seq",
"FlaxAutoModelForTokenClassification",
"FlaxAutoModelForVision2Seq",
]
......@@ -3578,6 +3580,13 @@ else:
_import_structure["models.wav2vec2"].extend(
["FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2Model", "FlaxWav2Vec2PreTrainedModel"]
)
_import_structure["models.whisper"].extend(
[
"FlaxWhisperForConditionalGeneration",
"FlaxWhisperModel",
"FlaxWhisperPreTrainedModel",
]
)
_import_structure["models.xglm"].extend(
[
"FlaxXGLMForCausalLM",
......@@ -6381,6 +6390,7 @@ if TYPE_CHECKING:
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
FLAX_MODEL_MAPPING,
......@@ -6394,6 +6404,7 @@ if TYPE_CHECKING:
FlaxAutoModelForQuestionAnswering,
FlaxAutoModelForSeq2SeqLM,
FlaxAutoModelForSequenceClassification,
FlaxAutoModelForSpeechSeq2Seq,
FlaxAutoModelForTokenClassification,
FlaxAutoModelForVision2Seq,
)
......@@ -6529,6 +6540,7 @@ if TYPE_CHECKING:
FlaxWav2Vec2Model,
FlaxWav2Vec2PreTrainedModel,
)
from .models.whisper import FlaxWhisperForConditionalGeneration, FlaxWhisperModel, FlaxWhisperPreTrainedModel
from .models.xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel
from .models.xlm_roberta import (
FLAX_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
......
......@@ -264,3 +264,191 @@ class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor):
scores = jnp.where(apply_penalty, scores.at[:, self.eos_token_id].set(-float("inf")), scores)
return scores
class FlaxSuppressTokensAtBeginLogitsProcessor(FlaxLogitsProcessor):
r"""
[`FlaxLogitsProcessor`] supressing a list of tokens as soon as the `generate` function starts generating using
`begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are not sampled at the
begining of the generation.
Args:
begin_suppress_tokens (`List[int]`):
Tokens to not sample.
begin_index (`int`):
Index where the tokens are suppressed.
"""
def __init__(self, begin_suppress_tokens, begin_index):
self.begin_suppress_tokens = list(begin_suppress_tokens)
self.begin_index = begin_index
def __call__(self, input_ids, scores, cur_len: int):
apply_penalty = 1 - jnp.bool_(cur_len - self.begin_index)
scores = jnp.where(apply_penalty, scores.at[:, self.begin_suppress_tokens].set(-float("inf")), scores)
return scores
class FlaxSuppressTokensLogitsProcessor(FlaxLogitsProcessor):
r"""
[`FlaxLogitsProcessor`] suppressing a list of tokens at each decoding step. The processor will set their log probs
to be `-inf` so they are not sampled.
Args:
suppress_tokens (`list`):
Tokens to not sample.
"""
def __init__(self, suppress_tokens: list):
self.suppress_tokens = list(suppress_tokens)
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
scores = scores.at[..., self.suppress_tokens].set(-float("inf"))
return scores
class FlaxForceTokensLogitsProcessor(FlaxLogitsProcessor):
r"""
[`FlaxLogitsProcessor`] that takes a list of pairs of integers which indicates a mapping from generation indices to
token indices that will be forced before sampling. The processor will set their log probs to 0 and all other tokens
to `-inf` so that they are sampled at their corresponding index.
Args:
force_token_map (`list`):
Map giving token ids and indices where they will be forced to be sampled.
"""
def __init__(self, force_token_map):
force_token_map = dict(force_token_map)
# Converts the dictionary of format {index: token} containing the tokens to be forced to an array, where the
# index of the array corresponds to the index of the token to be forced, for XLA compatibility.
# Indexes without forced tokens will have a negative value.
force_token_array = jnp.ones((max(force_token_map.keys()) + 1), dtype=jnp.int32) * -1
for index, token in force_token_map.items():
force_token_array = force_token_array.at[index].set(token)
self.force_token_array = jnp.int32(force_token_array)
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
def _force_token(generation_idx):
batch_size = scores.shape[0]
current_token = self.force_token_array[generation_idx]
new_scores = jnp.ones_like(scores, dtype=scores.dtype) * -float("inf")
updates = jnp.zeros((batch_size, 1), dtype=scores.dtype)
new_scores = lax.dynamic_update_slice(new_scores, updates, (0, current_token))
return new_scores
scores = lax.cond(
cur_len >= self.force_token_array.shape[0],
# If the current length is geq than the length of force_token_array, the processor does nothing.
lambda: scores,
# Otherwise, it may force a certain token.
lambda: lax.cond(
self.force_token_array[cur_len] >= 0,
# Only valid (positive) tokens are forced
lambda: _force_token(cur_len),
# Otherwise, the processor does nothing.
lambda: scores,
),
)
return scores
class FlaxWhisperTimeStampLogitsProcessor(FlaxLogitsProcessor):
r"""
Whisper specific Processor. This processor can be used to force a list of tokens. The processor will set their log
probs to `inf` so that they are sampled at their corresponding index.
Args:
generate_config (`GenerateConfig`):
The generate config used to generate the output. The following parameters are required:
eos_token_id (`int`, *optional*, defaults to 50257):
The id of the *end-of-sequence* token.
no_timestamps_token_id (`int`, *optional*, defaults to 50363):
The id of the `"<|notimestamps|>"` token.
max_initial_timestamp_index (`int`, *optional*, defaults to 1):
Used to set the maximum value of the initial timestamp. This is used to prevent the model from
predicting timestamps that are too far in the future.
"""
def __init__(self, generate_config, model_config, decoder_input_length):
self.eos_token_id = generate_config.eos_token_id
self.no_timestamps_token_id = generate_config.no_timestamps_token_id
self.timestamp_begin = generate_config.no_timestamps_token_id + 1
self.begin_index = decoder_input_length + 1
if generate_config.is_multilingual:
# room for language token and task token
self.begin_index += 2
if hasattr(generate_config, "max_initial_timestamp_index"):
self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index
else:
self.max_initial_timestamp_index = model_config.vocab_size
if self.max_initial_timestamp_index is None:
self.max_initial_timestamp_index = model_config.vocab_size
def __call__(self, input_ids, scores, cur_len):
# suppress <|notimestamps|> which is handled by without_timestamps
scores = scores.at[:, self.no_timestamps_token_id].set(-float("inf"))
def handle_pairs(input_ids_k, scores_k):
last_was_timestamp = jnp.where((cur_len - self.begin_index) >= 1, True, False)
last_was_timestamp = jnp.where(
input_ids_k[cur_len - 1] >= self.timestamp_begin,
True and last_was_timestamp,
False,
)
penultimate_was_timestamp = jnp.where((cur_len - self.begin_index) < 2, True, False)
penultimate_was_timestamp = jnp.where(
input_ids_k[cur_len - 2] >= self.timestamp_begin,
True,
penultimate_was_timestamp,
)
return jnp.where(
last_was_timestamp,
jnp.where(
penultimate_was_timestamp > 0,
scores_k.at[self.timestamp_begin :].set(-float("inf")),
scores_k.at[: self.eos_token_id].set(-float("inf")),
),
scores_k,
)
scores = jax.vmap(handle_pairs)(input_ids, scores)
apply_max_initial_timestamp = jnp.where(cur_len == self.begin_index, True, False)
apply_max_initial_timestamp = jnp.where(
self.max_initial_timestamp_index is not None,
True and apply_max_initial_timestamp,
False,
)
last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
scores = jnp.where(
apply_max_initial_timestamp,
scores.at[:, last_allowed + 1 :].set(-float("inf")),
scores,
)
# if sum of probability over timestamps is above any other token, sample timestamp
logprobs = jax.nn.log_softmax(scores, axis=-1)
def handle_cumulative_probs(logprobs_k, scores_k):
timestamp_logprob = jax.nn.logsumexp(logprobs_k[self.timestamp_begin :], axis=-1)
max_text_token_logprob = jnp.max(logprobs_k[: self.timestamp_begin])
return jnp.where(
timestamp_logprob > max_text_token_logprob,
scores_k.at[: self.timestamp_begin].set(-float("inf")),
scores_k,
)
scores = jax.vmap(handle_cumulative_probs)(logprobs, scores)
return scores
......@@ -37,8 +37,11 @@ from .configuration_utils import GenerationConfig
from .flax_logits_process import (
FlaxForcedBOSTokenLogitsProcessor,
FlaxForcedEOSTokenLogitsProcessor,
FlaxForceTokensLogitsProcessor,
FlaxLogitsProcessorList,
FlaxMinLengthLogitsProcessor,
FlaxSuppressTokensAtBeginLogitsProcessor,
FlaxSuppressTokensLogitsProcessor,
FlaxTemperatureLogitsWarper,
FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper,
......@@ -164,6 +167,50 @@ class FlaxGenerationMixin:
model_kwargs["encoder_outputs"] = self.encode(input_ids, params=params, return_dict=True, **encoder_kwargs)
return model_kwargs
def _prepare_decoder_input_ids_for_generation(
self,
batch_size: int,
decoder_start_token_id: int = None,
bos_token_id: int = None,
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
) -> jnp.ndarray:
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
# Only use this arg if not None, otherwise just remove from model_kwargs
decoder_input_ids = model_kwargs.pop("decoder_input_ids")
if decoder_input_ids is not None:
return decoder_input_ids
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
return jnp.array(decoder_start_token_id, dtype="i4").reshape(1, -1).repeat(batch_size, axis=0)
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
# retrieve decoder_start_token_id for encoder-decoder models
# fall back to bos_token_id if necessary
decoder_start_token_id = (
decoder_start_token_id
if decoder_start_token_id is not None
else self.generation_config.decoder_start_token_id
)
bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
if decoder_start_token_id is not None:
return decoder_start_token_id
elif (
hasattr(self.config, "decoder")
and hasattr(self.config.decoder, "decoder_start_token_id")
and self.config.decoder.decoder_start_token_id is not None
):
return self.config.decoder.decoder_start_token_id
elif bos_token_id is not None:
return bos_token_id
elif (
hasattr(self.config, "decoder")
and hasattr(self.config.decoder, "bos_token_id")
and self.config.decoder.bos_token_id is not None
):
return self.config.decoder.bos_token_id
raise ValueError(
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
)
@staticmethod
def _expand_to_num_beams(tensor, num_beams):
return jnp.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])
......@@ -224,6 +271,7 @@ class FlaxGenerationMixin:
prng_key: Optional[jnp.ndarray] = None,
trace: bool = True,
params: Optional[Dict[str, jnp.ndarray]] = None,
logits_processor: Optional[FlaxLogitsProcessorList] = None,
**kwargs,
):
r"""
......@@ -244,6 +292,10 @@ class FlaxGenerationMixin:
considerably slower runtime.
params (`Dict[str, jnp.ndarray]`, *optional*):
Optionally the model parameters can be passed. Can be useful for parallelized generation.
logits_processor (`FlaxLogitsProcessorList `, *optional*):
Custom logits processors that complement the default logits processors built from arguments and
generation config. If a logit processor is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
kwargs:
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
......@@ -277,6 +329,8 @@ class FlaxGenerationMixin:
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())
logits_processor = logits_processor if logits_processor is not None else FlaxLogitsProcessorList()
# set init values
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
......@@ -306,12 +360,19 @@ class FlaxGenerationMixin:
"generation results, please set `padding_side='left'` when initializing the tokenizer."
)
batch_size = input_ids.shape[0]
if self.config.is_encoder_decoder:
# add encoder_outputs to model_kwargs
if model_kwargs.get("encoder_outputs") is None:
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
# prepare decoder_input_ids for generation
input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * generation_config.decoder_start_token_id
input_ids = self._prepare_decoder_input_ids_for_generation(
batch_size,
decoder_start_token_id=generation_config.decoder_start_token_id,
bos_token_id=generation_config.bos_token_id,
model_kwargs=model_kwargs,
)
# Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
......@@ -347,7 +408,11 @@ class FlaxGenerationMixin:
" increasing`max_new_tokens`."
)
logits_processor = self._get_logits_processor(generation_config=generation_config)
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
logits_processor=logits_processor,
)
if not generation_config.do_sample and generation_config.num_beams == 1:
return self._greedy_search(
......@@ -419,7 +484,12 @@ class FlaxGenerationMixin:
return warpers
def _get_logits_processor(self, generation_config: GenerationConfig) -> FlaxLogitsProcessorList:
def _get_logits_processor(
self,
generation_config: GenerationConfig,
input_ids_seq_length: int,
logits_processor: Optional[FlaxLogitsProcessorList],
) -> FlaxLogitsProcessorList:
"""
This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsProcessor`]
instances used to modify the scores of the language model head.
......@@ -440,9 +510,51 @@ class FlaxGenerationMixin:
processors.append(
FlaxForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id)
)
if generation_config.suppress_tokens is not None:
processors.append(FlaxSuppressTokensLogitsProcessor(generation_config.suppress_tokens))
if generation_config.begin_suppress_tokens is not None:
begin_index = input_ids_seq_length
begin_index = (
begin_index
if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
else begin_index + 1
)
if generation_config.forced_decoder_ids is not None and len(generation_config.forced_decoder_ids) > 0:
# generation starts after the last token that is forced
begin_index += generation_config.forced_decoder_ids[-1][0]
processors.append(
FlaxSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
)
if generation_config.forced_decoder_ids is not None:
forced_decoder_ids = [
[input_ids_seq_length + i[0] - 1, i[1]] for i in generation_config.forced_decoder_ids
]
processors.append(FlaxForceTokensLogitsProcessor(forced_decoder_ids))
processors = self._merge_criteria_processor_list(processors, logits_processor)
return processors
def _merge_criteria_processor_list(
self,
default_list: FlaxLogitsProcessorList,
custom_list: FlaxLogitsProcessorList,
) -> FlaxLogitsProcessorList:
if len(custom_list) == 0:
return default_list
for default in default_list:
for custom in custom_list:
if type(custom) is type(default):
object_type = "logits processor"
raise ValueError(
f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
f" `generate`, but it has already been created with the values {default}. {default} has been"
" created by passing the corresponding arguments to generate or by the model's config default"
f" values. If you just want to change the default values of {object_type} consider passing"
f" them as arguments to `generate` instead of using a custom {object_type}."
)
default_list.extend(custom_list)
return default_list
def _greedy_search(
self,
input_ids: None,
......
......@@ -163,6 +163,7 @@ else:
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING",
"FLAX_MODEL_MAPPING",
......@@ -176,6 +177,7 @@ else:
"FlaxAutoModelForQuestionAnswering",
"FlaxAutoModelForSeq2SeqLM",
"FlaxAutoModelForSequenceClassification",
"FlaxAutoModelForSpeechSeq2Seq",
"FlaxAutoModelForTokenClassification",
"FlaxAutoModelForVision2Seq",
]
......@@ -320,6 +322,7 @@ if TYPE_CHECKING:
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
FLAX_MODEL_MAPPING,
......@@ -333,6 +336,7 @@ if TYPE_CHECKING:
FlaxAutoModelForQuestionAnswering,
FlaxAutoModelForSeq2SeqLM,
FlaxAutoModelForSequenceClassification,
FlaxAutoModelForSpeechSeq2Seq,
FlaxAutoModelForTokenClassification,
FlaxAutoModelForVision2Seq,
)
......
......@@ -55,6 +55,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
("vit", "FlaxViTModel"),
("wav2vec2", "FlaxWav2Vec2Model"),
("whisper", "FlaxWhisperModel"),
("xglm", "FlaxXGLMModel"),
("xlm-roberta", "FlaxXLMRobertaModel"),
]
......@@ -76,6 +77,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
("roformer", "FlaxRoFormerForMaskedLM"),
("t5", "FlaxT5ForConditionalGeneration"),
("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
("whisper", "FlaxWhisperForConditionalGeneration"),
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
]
)
......@@ -219,6 +221,7 @@ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("speech-encoder-decoder", "FlaxSpeechEncoderDecoderModel"),
("whisper", "FlaxWhisperForConditionalGeneration"),
]
)
......
......@@ -13,7 +13,13 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_tf_available,
is_torch_available,
)
_import_structure = {
......@@ -50,6 +56,19 @@ else:
"TFWhisperPreTrainedModel",
]
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_whisper"] = [
"FlaxWhisperForConditionalGeneration",
"FlaxWhisperModel",
"FlaxWhisperPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_whisper import WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP, WhisperConfig, WhisperOnnxConfig
from .feature_extraction_whisper import WhisperFeatureExtractor
......@@ -82,6 +101,18 @@ if TYPE_CHECKING:
TFWhisperPreTrainedModel,
)
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_whisper import (
FlaxWhisperForConditionalGeneration,
FlaxWhisperModel,
FlaxWhisperPreTrainedModel,
)
else:
import sys
......
# coding=utf-8
# Copyright 2022 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Flax whisper model."""
import random
from functools import partial
from typing import Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey
from ...generation.flax_logits_process import FlaxWhisperTimeStampLogitsProcessor
from ...modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxBaseModelOutputWithPastAndCrossAttentions,
FlaxCausalLMOutputWithCrossAttentions,
FlaxSeq2SeqLMOutput,
FlaxSeq2SeqModelOutput,
)
from ...modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
append_call_sample_docstring,
append_replace_return_docstrings,
overwrite_call_docstring,
)
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_whisper import WhisperConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
_CONFIG_FOR_DOC = "WhisperConfig"
WHISPER_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.) This model is also a Flax Linen
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
config ([`WhisperConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
`jax.numpy.bfloat16` (on TPUs). This can be used to enable mixed-precision training or half-precision
inference on GPUs or TPUs. If specified all the computation will be performed with the given `dtype`.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.** If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`]
and [`~FlaxPreTrainedModel.to_bf16`].
"""
WHISPER_INPUTS_DOCSTRING = r"""
Args:
input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`):
Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`WhisperFeatureExtractor`] should be used for extracting the features, padding and conversion into a
tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`]
attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but
is not used. By default the silence in the input log mel spectrogram are ignored.
decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
[`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
[What are decoder input IDs?](../glossary#decoder-input-ids) Whisper uses the `decoder_start_token_id` as
the starting token for `decoder_input_ids` generation.
decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default. If you want to change padding behavior, you should modify to your needs. See diagram 1
in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Whisper does not use `position_ids` in the encoder as `input_features` is always the same size and doesn't
use masking, but this argument is preserved for compatibility. By default the silence in the input log mel
spectrogram are ignored.
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range `[0, config.max_position_embeddings - 1]`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
WHISPER_ENCODE_INPUTS_DOCSTRING = r"""
Args:
input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`):
Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`WhisperFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`].
attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but
is not used. By default the silence in the input log mel spectrogram are ignored.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
WHISPER_DECODE_INPUTS_DOCSTRING = r"""
Args:
decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`):
Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
[`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
[What are decoder input IDs?](../glossary#decoder-input-ids)
encoder_outputs (`tuple(tuple(numpy.ndarray)`):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
encoder_attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
but it is not used. By default the silence in the input log mel spectrogram are ignored.
decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default. If you want to change padding behavior, you should modify to your needs. See diagram 1
in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range `[0, config.max_position_embeddings - 1]`.
past_key_values (`Dict[str, numpy.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class FlaxWhisperAttention(nn.Module):
config: WhisperConfig
embed_dim: int
num_heads: int
dropout: float = 0.0
causal: bool = False
bias: bool = True
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {self.num_heads})."
)
dense = partial(
nn.Dense,
self.embed_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.q_proj = dense(use_bias=self.bias)
self.k_proj = dense(use_bias=False)
self.v_proj = dense(use_bias=self.bias)
self.out_proj = dense(use_bias=self.bias)
if self.causal:
self.causal_mask = make_causal_mask(
jnp.ones((1, self.config.max_target_positions), dtype="bool"), dtype="bool"
)
def __call__(
self,
hidden_states: jnp.ndarray,
key_value_states: Optional[jnp.ndarray] = None,
attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
) -> Tuple[jnp.ndarray]:
is_cross_attention = key_value_states is not None
batch_size = hidden_states.shape[0]
query_states = self.q_proj(hidden_states)
if is_cross_attention:
key_states = self.k_proj(key_value_states)
value_states = self.v_proj(key_value_states)
else:
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = self._split_heads(query_states)
key_states = self._split_heads(key_states)
value_states = self._split_heads(value_states)
if self.causal:
query_length, key_length = query_states.shape[1], key_states.shape[1]
if self.has_variable("cache", "cached_key"):
mask_shift = self.variables["cache"]["cache_index"]
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
causal_mask = lax.dynamic_slice(
self.causal_mask,
(0, 0, mask_shift, 0),
(1, 1, query_length, max_decoder_length),
)
else:
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
# combine masks if needed
if attention_mask is not None and self.causal:
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
attention_mask = combine_masks(attention_mask, causal_mask)
elif self.causal:
attention_mask = causal_mask
elif attention_mask is not None:
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
key_states, value_states, attention_mask = self._concatenate_to_cache(
key_states, value_states, query_states, attention_mask
)
# Convert the boolean attention mask to an attention bias.
if attention_mask is not None:
# attention mask in the form of attention bias
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
)
else:
attention_bias = None
dropout_rng = None
if not deterministic and self.dropout > 0.0:
dropout_rng = self.make_rng("dropout")
attn_weights = dot_product_attention_weights(
query_states,
key_states,
bias=attention_bias,
dropout_rng=dropout_rng,
dropout_rate=self.dropout,
broadcast_dropout=True,
deterministic=deterministic,
dtype=self.dtype,
precision=None,
)
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
attn_output = self._merge_heads(attn_output)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
def _split_heads(self, hidden_state) -> jnp.ndarray:
return hidden_state.reshape(hidden_state.shape[:2] + (self.num_heads, self.head_dim))
def _merge_heads(self, hidden_state) -> jnp.ndarray:
return hidden_state.reshape(hidden_state.shape[:2] + (self.embed_dim,))
@nn.compact
def _concatenate_to_cache(self, key, value, query, attention_mask) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
# detect if we're initializing by absence of existing cache data.
is_initialized = self.has_variable("cache", "cached_key")
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized:
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
# update key, value caches with our new 1d spatial slices
cur_index = cache_index.value
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
key = lax.dynamic_update_slice(cached_key.value, key, indices)
value = lax.dynamic_update_slice(cached_value.value, value, indices)
cached_key.value = key
cached_value.value = value
num_updated_cache_vectors = query.shape[1]
cache_index.value = cache_index.value + num_updated_cache_vectors
# causal mask for cached decoder self-attention: our single query position should only
# attend to those key positions that have already been generated and cached, not the
# remaining zero elements.
pad_mask = jnp.broadcast_to(
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
)
attention_mask = combine_masks(pad_mask, attention_mask)
return key, value, attention_mask
# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayer with MBart->Whisper
class FlaxWhisperEncoderLayer(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
self.embed_dim = self.config.d_model
self.self_attn = FlaxWhisperAttention(
config=self.config,
embed_dim=self.embed_dim,
num_heads=self.config.encoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function]
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
def __call__(
self,
hidden_states: jnp.ndarray,
attention_mask: jnp.ndarray,
output_attentions: bool = True,
deterministic: bool = True,
) -> Tuple[jnp.ndarray]:
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = self.fc2(hidden_states)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayerCollection with MBart->Whisper
class FlaxWhisperEncoderLayerCollection(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.layers = [
FlaxWhisperEncoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.encoder_layers)
]
self.layerdrop = self.config.encoder_layerdrop
def __call__(
self,
hidden_states,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for encoder_layer in self.layers:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if not deterministic and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions,
deterministic,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states += (hidden_states,)
outputs = (hidden_states, all_hidden_states, all_attentions)
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayer with MBart->Whisper
class FlaxWhisperDecoderLayer(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
self.embed_dim = self.config.d_model
self.self_attn = FlaxWhisperAttention(
config=self.config,
embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
causal=True,
dtype=self.dtype,
)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function]
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.encoder_attn = FlaxWhisperAttention(
config=self.config,
embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.fc1 = nn.Dense(
self.config.decoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
def __call__(
self,
hidden_states: jnp.ndarray,
attention_mask: jnp.ndarray,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
output_attentions: bool = True,
deterministic: bool = True,
) -> Tuple[jnp.ndarray]:
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
# Cross-Attention Block
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
hidden_states, cross_attn_weights = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = self.fc2(hidden_states)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
return outputs
# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayerCollection with MBart->Whisper
class FlaxWhisperDecoderLayerCollection(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.layers = [
FlaxWhisperDecoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.decoder_layers)
]
self.layerdrop = self.config.decoder_layerdrop
def __call__(
self,
hidden_states,
attention_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if not deterministic and (dropout_probability < self.layerdrop):
layer_outputs = (None, None, None)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
output_attentions=output_attentions,
deterministic=deterministic,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
class FlaxWhisperEncoder(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
self.conv1 = nn.Conv(
self.config.d_model,
kernel_size=(3,),
padding=1,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
)
self.conv2 = nn.Conv(
self.config.d_model,
kernel_size=(3,),
strides=2,
padding=1,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.layers = FlaxWhisperEncoderLayerCollection(
self.config,
dtype=self.dtype,
)
self.embed_positions = nn.Embed(self.config.max_source_positions, self.config.d_model, dtype=self.dtype)
self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
def __call__(
self,
input_features: jnp.ndarray,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
) -> Tuple[jnp.ndarray]:
if input_features.shape[1:] != (self.config.num_mel_bins, self.config.max_source_positions * 2):
raise ValueError(
"input_features.shape[1:], must be equal to (self.config.num_mel_bins,"
f" self.config.max_source_positions * 2) (got {input_features.shape[1:]}, but should be"
f" ({self.config.num_mel_bins}, {self.config.max_source_positions * 2}))"
)
input_features = input_features.transpose(0, 2, 1)
hidden_states = jax.nn.gelu(self.conv1(input_features), approximate=False)
hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False)
embed_positions = self.embed_positions(jnp.arange(self.config.max_source_positions))
hidden_states = hidden_states + embed_positions
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
outputs = self.layers(
hidden_states,
attention_mask=None,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_states = outputs[0]
last_hidden_states = self.layer_norm(last_hidden_states)
# update the last element in `hidden_states` after applying `layernorm` above
hidden_states = None
if output_hidden_states:
hidden_states = outputs[1]
hidden_states = hidden_states[:-1] + (last_hidden_states,)
if not return_dict:
outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=last_hidden_states,
hidden_states=hidden_states,
attentions=outputs.attentions,
)
class FlaxWhisperDecoder(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
self.embed_tokens = nn.Embed(self.config.vocab_size, self.config.d_model, dtype=self.dtype)
self.embed_positions = nn.Embed(self.config.max_target_positions, self.config.d_model, dtype=self.dtype)
self.layers = FlaxWhisperDecoderLayerCollection(self.config, dtype=self.dtype)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-5)
def __call__(
self,
input_ids: jnp.ndarray,
attention_mask: jnp.ndarray,
position_ids: jnp.ndarray,
encoder_hidden_states: Optional[jnp.ndarray] = None,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
) -> Tuple[jnp.ndarray]:
input_embeds = self.embed_tokens(input_ids)
position_embeds = self.embed_positions(position_ids)
hidden_states = input_embeds + position_embeds
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
outputs = self.layers(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_states = outputs[0]
last_hidden_states = self.layer_norm(last_hidden_states)
# update the last element in `hidden_states` after applying `layernorm` above
hidden_states = None
if output_hidden_states:
hidden_states = outputs[1]
hidden_states = hidden_states[:-1] + (last_hidden_states,)
if not return_dict:
outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=last_hidden_states,
hidden_states=hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
class FlaxWhisperModule(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
self.encoder = FlaxWhisperEncoder(self.config, dtype=self.dtype)
self.decoder = FlaxWhisperDecoder(self.config, dtype=self.dtype)
def __call__(
self,
input_features: jnp.ndarray,
decoder_input_ids: jnp.ndarray,
decoder_attention_mask: jnp.ndarray,
decoder_position_ids: jnp.ndarray,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
encoder_outputs = self.encoder(
input_features,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
if not return_dict:
return decoder_outputs + encoder_outputs
return FlaxSeq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
def _get_encoder_module(self):
return self.encoder
def _get_decoder_module(self):
return self.decoder
class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel):
config_class = WhisperConfig
base_model_prefix: str = "model"
main_input_name = "input_features"
module_class: nn.Module = None
def __init__(
self,
config: WhisperConfig,
input_shape: Tuple[int] = (1, 80, 3000),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_features = jnp.zeros(input_shape, dtype="f4")
input_features = input_features.at[(..., -1)].set(self.config.eos_token_id)
decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4")
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
batch_size, sequence_length = decoder_input_ids.shape
decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
random_params = self.module.init(
rngs,
input_features=input_features,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartPreTrainedModel.init_cache with Bart->Whisper
def init_cache(self, batch_size, max_length, encoder_outputs):
r"""
Args:
batch_size (`int`):
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
max_length (`int`):
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
cache.
encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
`encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
`attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
cross-attention of the decoder.
"""
# init input variables to retrieve cache
decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
decoder_position_ids = jnp.broadcast_to(
jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
)
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
decoder_module = module._get_decoder_module()
return decoder_module(
decoder_input_ids,
decoder_attention_mask,
decoder_position_ids,
**kwargs,
)
init_variables = self.module.init(
jax.random.PRNGKey(0),
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
init_cache=True,
method=_decoder_forward, # we only need to call the decoder to init the cache
)
return unfreeze(init_variables["cache"])
@add_start_docstrings(WHISPER_ENCODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=WhisperConfig)
def encode(
self,
input_features: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
**kwargs,
):
r"""
Returns:
Example:
```python
>>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
>>> from datasets import load_dataset
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
>>> input_features = inputs.input_features
>>> encoder_outputs = model.encode(input_features=input_features)
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
def _encoder_forward(module, input_features, **kwargs):
encode_module = module._get_encoder_module()
return encode_module(input_features, **kwargs)
return self.module.apply(
{"params": params or self.params},
input_features=jnp.array(input_features, dtype="f4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
method=_encoder_forward,
)
@add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=WhisperConfig)
def decode(
self,
decoder_input_ids,
encoder_outputs,
encoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
past_key_values: dict = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
r"""
Returns:
Example:
```python
>>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
>>> from datasets import load_dataset
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
>>> input_features = inputs.input_features
>>> encoder_outputs = model.encode(input_features=input_features)
>>> decoder_start_token_id = model.config.decoder_start_token_id
>>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
>>> outputs = model.decode(decoder_input_ids, encoder_outputs)
>>> last_decoder_hidden_states = outputs.last_hidden_state
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
encoder_hidden_states = encoder_outputs[0]
batch_size, sequence_length = decoder_input_ids.shape
if decoder_position_ids is None:
if past_key_values is not None:
raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
if decoder_attention_mask is not None:
decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
else:
decoder_position_ids = jnp.broadcast_to(
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
)
if decoder_attention_mask is None:
decoder_attention_mask = jnp.ones((batch_size, sequence_length))
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
inputs = {"params": params or self.params}
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be
# passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
# it can be changed by FlaxWhisperAttention module
if past_key_values:
inputs["cache"] = past_key_values
mutable = ["cache"]
else:
mutable = False
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
decoder_module = module._get_decoder_module()
return decoder_module(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
position_ids=decoder_position_ids,
**kwargs,
)
outputs = self.module.apply(
inputs,
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
encoder_hidden_states=encoder_hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
mutable=mutable,
method=_decoder_forward,
)
# add updated cache to model output
if past_key_values is not None and return_dict:
outputs, past = outputs
outputs["past_key_values"] = unfreeze(past["cache"])
return outputs
elif past_key_values is not None and not return_dict:
outputs, past = outputs
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
return outputs
@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
def __call__(
self,
input_features: jnp.ndarray,
decoder_input_ids: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
position_ids: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
# prepare decoder inputs
if decoder_position_ids is None:
if decoder_attention_mask is not None:
decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
else:
batch_size, sequence_length = decoder_input_ids.shape
decoder_position_ids = jnp.broadcast_to(
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
)
if decoder_attention_mask is None:
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
# Handle any PRNG if needed
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
return self.module.apply(
{"params": params or self.params},
input_features=jnp.array(input_features, dtype="f4"),
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
)
@add_start_docstrings(
"The bare Whisper Model transformer outputting raw hidden-states without any specific head on top.",
WHISPER_START_DOCSTRING,
)
class FlaxWhisperModel(FlaxWhisperPreTrainedModel):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
module_class = FlaxWhisperModule
append_call_sample_docstring(FlaxWhisperModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
class FlaxWhisperForConditionalGenerationModule(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
self.model = FlaxWhisperModule(config=self.config, dtype=self.dtype)
self.lm_head = nn.Dense(
self.config.vocab_size,
use_bias=False,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
def _get_encoder_module(self):
return self.model.encoder
def _get_decoder_module(self):
return self.model.decoder
def __call__(
self,
input_features,
decoder_input_ids,
decoder_attention_mask: jnp.ndarray = None,
decoder_position_ids: jnp.ndarray = None,
position_ids: jnp.ndarray = None,
attention_mask: jnp.ndarray = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
outputs = self.model(
input_features=input_features,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
hidden_states = outputs[0]
if self.config.tie_word_embeddings:
shared_embedding = self.model.decoder.embed_tokens.variables["params"]["embedding"]
lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
else:
lm_logits = self.lm_head(hidden_states)
if not return_dict:
output = (lm_logits,) + outputs[1:]
return output
return FlaxSeq2SeqLMOutput(
logits=lm_logits,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
@add_start_docstrings("The Whisper Model with a language modeling head.", WHISPER_START_DOCSTRING)
class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
module_class = FlaxWhisperForConditionalGenerationModule
dtype: jnp.dtype = jnp.float32
@add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=WhisperConfig)
def decode(
self,
decoder_input_ids,
encoder_outputs,
encoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
past_key_values: dict = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
r"""
Returns:
Example:
```python
>>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
>>> from datasets import load_dataset
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
>>> input_features = inputs.input_features
>>> encoder_outputs = model.encode(input_features=input_features)
>>> decoder_start_token_id = model.config.decoder_start_token_id
>>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
>>> outputs = model.decode(decoder_input_ids, encoder_outputs)
>>> last_decoder_hidden_states = outputs.last_hidden_state
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
encoder_hidden_states = encoder_outputs[0]
batch_size, sequence_length = decoder_input_ids.shape
if decoder_position_ids is None:
if past_key_values is not None:
raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
if decoder_attention_mask is not None:
decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
else:
decoder_position_ids = jnp.broadcast_to(
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
)
if decoder_attention_mask is None:
decoder_attention_mask = jnp.ones((batch_size, sequence_length), dtype="i4")
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
inputs = {"params": params or self.params}
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be
# passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
# it can be changed by FlaxWhisperAttention module
if past_key_values:
inputs["cache"] = past_key_values
mutable = ["cache"]
else:
mutable = False
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
decoder_module = module._get_decoder_module()
outputs = decoder_module(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
position_ids=decoder_position_ids,
**kwargs,
)
hidden_states = outputs[0]
if self.config.tie_word_embeddings:
shared_embedding = module.model.decoder.embed_tokens.variables["params"]["embedding"]
lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
else:
lm_logits = module.lm_head(hidden_states)
return lm_logits, outputs
outputs = self.module.apply(
inputs,
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
encoder_hidden_states=encoder_hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
mutable=mutable,
method=_decoder_forward,
)
if past_key_values is None:
lm_logits, decoder_outputs = outputs
else:
(lm_logits, decoder_outputs), past = outputs
if return_dict:
outputs = FlaxCausalLMOutputWithCrossAttentions(
logits=lm_logits,
hidden_states=decoder_outputs.hidden_states,
attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
)
else:
outputs = (lm_logits,) + decoder_outputs[1:]
# add updated cache to model output
if past_key_values is not None and return_dict:
outputs["past_key_values"] = unfreeze(past["cache"])
return outputs
elif past_key_values is not None and not return_dict:
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
return outputs
def generate(
self,
input_features,
generation_config=None,
logits_processor=None,
return_timestamps=None,
task=None,
language=None,
is_multilingual=None,
**kwargs
):
if generation_config is None:
generation_config = self.generation_config
if return_timestamps is not None:
generation_config.return_timestamps = return_timestamps
if task is not None:
generation_config.task = task
if is_multilingual is not None:
generation_config.is_multilingual = is_multilingual
if language is not None:
generation_config.language = language
if kwargs is not None and "decoder_input_ids" in kwargs:
decoder_input_length = len(kwargs["decoder_input_ids"])
else:
decoder_input_length = 1
forced_decoder_ids = []
if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual:
if hasattr(generation_config, "language"):
forced_decoder_ids.append((1, generation_config.lang_to_id[generation_config.language]))
else:
forced_decoder_ids.append((1, None))
if hasattr(generation_config, "task"):
forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
else:
forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))
if (
hasattr(generation_config, "return_timestamps") and generation_config.return_timestamps
) or return_timestamps:
logits_processor = [
FlaxWhisperTimeStampLogitsProcessor(generation_config, self.config, decoder_input_length)
]
else:
if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id:
idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
if len(forced_decoder_ids) > 0:
generation_config.forced_decoder_ids = forced_decoder_ids
return super().generate(
input_features,
generation_config,
logits_processor=logits_processor,
**kwargs,
)
def prepare_inputs_for_generation(
self,
decoder_input_ids,
max_length,
attention_mask: Optional[jnp.DeviceArray] = None,
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
encoder_outputs=None,
**kwargs
):
# initializing the cache
batch_size, seq_length = decoder_input_ids.shape
past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
# But since the decoder uses a causal mask, those positions are masked anyways.
# Thus we can create a single static attention_mask here, which is more efficient for compilation
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if decoder_attention_mask is not None:
position_ids = decoder_attention_mask.cumsum(-1) - 1
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
else:
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
return {
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"encoder_attention_mask": attention_mask,
"decoder_attention_mask": extended_attention_mask,
"decoder_position_ids": position_ids,
}
def update_inputs_for_generation(self, model_outputs, model_kwargs):
model_kwargs["past_key_values"] = model_outputs.past_key_values
model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
return model_kwargs
FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING = r"""
Returns:
Transcription example:
```python
>>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
>>> from datasets import load_dataset
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
>>> input_features = inputs.input_features
>>> generated_ids = model.generate(input_ids=input_features)
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> transcription
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
```
"""
overwrite_call_docstring(
FlaxWhisperForConditionalGeneration, WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING
)
append_replace_return_docstrings(
FlaxWhisperForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
)
......@@ -162,6 +162,9 @@ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = None
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = None
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
......@@ -241,6 +244,13 @@ class FlaxAutoModelForSequenceClassification(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxAutoModelForSpeechSeq2Seq(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxAutoModelForTokenClassification(metaclass=DummyObject):
_backends = ["flax"]
......@@ -1130,6 +1140,27 @@ class FlaxWav2Vec2PreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxWhisperForConditionalGeneration(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxWhisperModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxWhisperPreTrainedModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxXGLMForCausalLM(metaclass=DummyObject):
_backends = ["flax"]
......
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
import tempfile
import unittest
import transformers
from transformers import WhisperConfig, is_flax_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
from transformers.utils import cached_property
from transformers.utils.import_utils import is_datasets_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor
if is_datasets_available():
import datasets
from datasets import load_dataset
if is_flax_available():
import numpy as np
import jax
from flax.core.frozen_dict import unfreeze
from flax.traverse_util import flatten_dict
from transformers import (
FLAX_MODEL_MAPPING,
FlaxWhisperForConditionalGeneration,
FlaxWhisperModel,
WhisperFeatureExtractor,
WhisperProcessor,
)
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
@require_flax
class FlaxWhisperModelTester:
config_cls = WhisperConfig
config_updates = {}
hidden_act = "gelu"
def __init__(
self,
parent,
batch_size=13,
seq_length=60,
is_training=True,
use_labels=False,
vocab_size=99,
d_model=16,
decoder_attention_heads=4,
decoder_ffn_dim=16,
decoder_layers=2,
encoder_attention_heads=4,
encoder_ffn_dim=16,
encoder_layers=2,
input_channels=1,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=70,
max_source_positions=30,
max_target_positions=40,
bos_token_id=98,
eos_token_id=98,
pad_token_id=0,
num_mel_bins=80,
decoder_start_token_id=85,
num_conv_layers=1,
suppress_tokens=None,
begin_suppress_tokens=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_labels = use_labels
self.vocab_size = vocab_size
self.d_model = d_model
self.hidden_size = d_model
self.num_hidden_layers = encoder_layers
self.num_attention_heads = encoder_attention_heads
self.decoder_attention_heads = decoder_attention_heads
self.decoder_ffn_dim = decoder_ffn_dim
self.decoder_layers = decoder_layers
self.encoder_attention_heads = encoder_attention_heads
self.encoder_ffn_dim = encoder_ffn_dim
self.encoder_layers = encoder_layers
self.encoder_seq_length = seq_length // 2
self.decoder_seq_length = 1
self.input_channels = input_channels
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.num_mel_bins = num_mel_bins
self.max_position_embeddings = max_position_embeddings
self.max_source_positions = max_source_positions
self.max_target_positions = max_target_positions
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.decoder_start_token_id = decoder_start_token_id
self.num_conv_layers = num_conv_layers
self.suppress_tokens = suppress_tokens
self.begin_suppress_tokens = begin_suppress_tokens
def prepare_config_and_inputs_for_common(self):
input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length], self.vocab_size)
decoder_input_ids = np.array(self.batch_size * [[self.decoder_start_token_id]])
config = WhisperConfig(
vocab_size=self.vocab_size,
num_mel_bins=self.num_mel_bins,
decoder_start_token_id=self.decoder_start_token_id,
is_encoder_decoder=True,
activation_function=self.hidden_act,
dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_probs_dropout_prob,
max_source_positions=self.max_source_positions,
max_target_positions=self.max_target_positions,
pad_token_id=self.pad_token_id,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
tie_word_embeddings=True,
d_model=self.d_model,
decoder_attention_heads=self.decoder_attention_heads,
decoder_ffn_dim=self.decoder_ffn_dim,
decoder_layers=self.decoder_layers,
encoder_attention_heads=self.encoder_attention_heads,
encoder_ffn_dim=self.encoder_ffn_dim,
encoder_layers=self.encoder_layers,
suppress_tokens=self.suppress_tokens,
begin_suppress_tokens=self.begin_suppress_tokens,
)
inputs_dict = prepare_whisper_inputs_dict(config, input_features, decoder_input_ids)
return config, inputs_dict
def prepare_whisper_inputs_dict(
config,
input_ids,
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
):
if decoder_attention_mask is None:
decoder_attention_mask = np.concatenate(
[
np.ones(decoder_input_ids[:, :1].shape, dtype=np.int8),
np.not_equal(decoder_input_ids[:, 1:], config.pad_token_id).astype(np.int8),
],
axis=-1,
)
return {
"input_features": input_ids,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
}
def partialclass(cls, *args, **kwargs):
class NewCls(cls):
__init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
return NewCls
def make_partial_class(full_class, *args, **kwargs):
partial_class = partialclass(full_class, *args, **kwargs)
partial_class.__name__ = full_class.__name__
partial_class.__module__ = full_class.__module__
return partial_class
@require_flax
class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxWhisperForConditionalGeneration, FlaxWhisperModel) if is_flax_available() else ()
all_generative_model_classes = (FlaxWhisperForConditionalGeneration,) if is_flax_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = FlaxWhisperModelTester(self)
_, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
self.init_shape = (1,) + inputs_dict["input_features"].shape[1:]
self.all_model_classes = (
make_partial_class(model_class, input_shape=self.init_shape) for model_class in self.all_model_classes
)
self.config_tester = ConfigTester(self, config_class=WhisperConfig)
def test_config(self):
self.config_tester.run_common_tests()
# overwrite because of `input_features`
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.__call__)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["input_features", "decoder_input_ids"]
self.assertListEqual(arg_names[:2], expected_arg_names)
# overwrite because of `input_features`
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
@jax.jit
def model_jitted(input_features, decoder_input_ids, **kwargs):
return model(input_features=input_features, decoder_input_ids=decoder_input_ids, **kwargs)
with self.subTest("JIT Enabled"):
jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()
with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = model_jitted(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
# overwrite because of `input_features`
@is_pt_flax_cross_test
def test_save_load_bf16_to_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)
for model_class in self.all_model_classes:
if model_class.__name__ == base_class.__name__:
continue
model = model_class(config)
model.params = model.to_bf16(model.params)
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
base_params = flatten_dict(unfreeze(base_model.params))
for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# overwrite because of `input_features`
@is_pt_flax_cross_test
def test_save_load_from_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)
for model_class in self.all_model_classes:
if model_class.__name__ == base_class.__name__:
continue
model = base_class(config)
base_params = flatten_dict(unfreeze(model.params))
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
# save pt model
pt_model.save_pretrained(tmpdirname)
head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))
for key in base_param_from_head.keys():
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# overwrite because of `input_features`
@is_pt_flax_cross_test
def test_save_load_to_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)
for model_class in self.all_model_classes:
if model_class.__name__ == base_class.__name__:
continue
model = model_class(config)
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
base_params = flatten_dict(unfreeze(base_model.params))
for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# overwrite because of `input_features`
def test_save_load_from_base(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)
for model_class in self.all_model_classes:
if model_class.__name__ == base_class.__name__:
continue
model = base_class(config)
base_params = flatten_dict(unfreeze(model.params))
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
head_model = model_class.from_pretrained(tmpdirname)
base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))
for key in base_param_from_head.keys():
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# overwrite because of `input_features`
def test_save_load_to_base(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)
for model_class in self.all_model_classes:
if model_class.__name__ == base_class.__name__:
continue
model = model_class(config)
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname)
base_params = flatten_dict(unfreeze(base_model.params))
for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
@slow
@require_flax
class FlaxWhisperModelIntegrationTest(unittest.TestCase):
@cached_property
def default_processor(self):
return WhisperProcessor.from_pretrained("openai/whisper-base")
def _load_datasamples(self, num_samples):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
return [x["array"] for x in speech_samples]
def test_tiny_logits_librispeech(self):
model = FlaxWhisperModel.from_pretrained("openai/whisper-tiny", from_pt=True)
input_speech = self._load_datasamples(1)
feature_extractor = WhisperFeatureExtractor()
input_features = feature_extractor(input_speech, return_tensors="np").input_features
logits = model(
input_features,
decoder_input_ids=np.array([[50258, 50259, 50359]]),
output_hidden_states=False,
output_attentions=False,
return_dict=False,
)
# fmt: off
EXPECTED_LOGITS = np.array(
[
2.9892, -6.7607, 5.7348, 3.6096, 0.2152, -5.7321, 4.8855, -1.6407,
0.2823, -1.5718, 10.4269, 3.4427, 0.0219, -8.0612, 3.4784, 8.4246,
4.0575, -2.2864, 11.1084, 0.9963, 0.9884, -8.5154, -3.5469, -9.3713,
0.9786, 3.5435, 7.4850, -5.2579, -1.4366, 10.4841
]
)
# fmt: on
self.assertTrue(np.allclose(logits[0][0, 0, :30], EXPECTED_LOGITS, atol=1e-4))
def test_small_en_logits_librispeech(self):
model = FlaxWhisperModel.from_pretrained("openai/whisper-small.en", from_pt=True)
input_speech = self._load_datasamples(1)
feature_extractor = WhisperFeatureExtractor()
input_features = feature_extractor(input_speech, return_tensors="np").input_features
logits = model(
input_features,
decoder_input_ids=np.array([model.config.decoder_start_token_id]),
output_hidden_states=False,
output_attentions=False,
return_dict=False,
)
logits = logits[0] @ model.params["model"]["decoder"]["embed_tokens"]["embedding"].T
# fmt: off
EXPECTED_LOGITS = np.array(
[
-3.6784, -7.7211, -9.5070, -11.9286, -7.6489, -9.7026, -5.6188,
-8.0104, -4.6238, -5.1833, -9.0485, -3.4079, -5.4874, -2.6935,
-6.3479, -7.3398, -6.9558, -7.6867, -7.4748, -8.3463, -9.9781,
-10.8389, -10.3105, -11.7201, -9.7261, -7.1590, -5.9272, -12.4509,
-11.1146, -8.1918
]
)
# fmt: on
self.assertTrue(np.allclose(logits[0, 0, :30], EXPECTED_LOGITS, atol=1e-4))
def test_large_logits_librispeech(self):
model = FlaxWhisperModel.from_pretrained("openai/whisper-large", from_pt=True)
input_speech = self._load_datasamples(1)
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
processed_inputs = processor(
audio=input_speech, text="This part of the speech", add_special_tokens=False, return_tensors="np"
)
input_features = processed_inputs.input_features
decoder_input_ids = processed_inputs.labels
logits = model(
input_features,
decoder_input_ids=decoder_input_ids,
output_hidden_states=False,
output_attentions=False,
return_dict=False,
)
logits = logits[0] @ model.params["model"]["decoder"]["embed_tokens"]["embedding"].T
# fmt: off
EXPECTED_LOGITS = np.array(
[
2.1382, 0.9381, 4.4671, 3.5589, 2.4022, 3.8576, -0.6521, 2.5472,
1.8301, 1.9957, 2.3432, 1.4678, 0.5459, 2.2597, 1.5179, 2.5357,
1.1624, 0.6194, 1.0757, 1.8259, 2.4076, 1.6601, 2.3503, 1.3376,
1.9891, 1.8635, 3.8931, 5.3699, 4.4772, 3.9184
]
)
# fmt: on
self.assertTrue(np.allclose(logits[0, 0, :30], EXPECTED_LOGITS, atol=1e-4))
def test_tiny_en_generation(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
model.config.decoder_start_token_id = 50257
input_speech = self._load_datasamples(1)
input_features = processor.feature_extractor(
raw_speech=input_speech, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="jax"
).input_features
generated_ids = model.generate(input_features, num_beams=5, max_length=20).sequences
transcript = processor.tokenizer.decode(generated_ids[0])
EXPECTED_TRANSCRIPT = (
"<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle"
" classes and we are glad to"
)
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
def test_tiny_generation(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", from_pt=True)
input_speech = self._load_datasamples(1)
input_features = processor.feature_extractor(
raw_speech=input_speech, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="jax"
).input_features
generated_ids = model.generate(input_features, num_beams=5, max_length=20).sequences
transcript = processor.tokenizer.decode(generated_ids[0])
EXPECTED_TRANSCRIPT = (
"<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle"
" classes and we are glad"
)
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
def test_large_generation(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-large", from_pt=True)
input_speech = self._load_datasamples(1)
input_features = processor.feature_extractor(
raw_speech=input_speech, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="jax"
).input_features
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
generated_ids = model.generate(input_features, num_beams=5, max_length=20).sequences
transcript = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad"
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
def test_large_generation_multilingual(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-large", from_pt=True)
ds = load_dataset("common_voice", "ja", split="test", streaming=True)
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
input_speech = next(iter(ds))["audio"]["array"]
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="np")
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
generated_ids = model.generate(input_features, do_sample=False, max_length=20).sequences
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
generated_ids = model.generate(
input_features,
do_sample=False,
max_length=20,
).sequences
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
EXPECTED_TRANSCRIPT = " Kimura-san called me."
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
generated_ids = model.generate(input_features, do_sample=False, max_length=20).sequences
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
def test_large_batched_generation(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-large", from_pt=True)
input_speech = self._load_datasamples(4)
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="np").input_features
generated_ids = model.generate(input_features, max_length=20).sequences
# fmt: off
EXPECTED_LOGITS = np.array(
[
[50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
[50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
[50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
[50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
]
)
# fmt: on
self.assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS))
# fmt: off
EXPECTED_TRANSCRIPT = [
" Mr. Quilter is the apostle of the middle classes and we are glad to",
" Nor is Mr. Quilter's manner less interesting than his matter.",
" He tells us that at this festive season of the year, with Christmas and roast beef",
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all,",
]
# fmt: on
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)
def test_tiny_en_batched_generation(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
input_speech = self._load_datasamples(4)
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="np").input_features
generated_ids = model.generate(input_features, max_length=20).sequences
# fmt: off
EXPECTED_LOGITS = np.array(
[
[50257, 50362, 1770, 13, 2264, 346, 353, 318, 262, 46329, 286, 262, 3504, 6097, 11, 290, 356, 389, 9675, 284],
[50257, 50362, 5414, 318, 1770, 13, 2264, 346, 353, 338, 5642, 1342, 3499, 621, 465, 2300, 13, 50256, 50256, 50256],
[50257, 50362, 679, 4952, 514, 326, 379, 428, 43856, 1622, 286, 262, 614, 11, 351, 6786, 290, 32595, 12023, 28236],
[50257, 50362, 679, 468, 12296, 17188, 1771, 7361, 26113, 18881, 1122, 338, 670, 318, 1107, 8312, 706, 477, 290, 460]
]
)
# fmt: on
self.assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS))
# fmt: off
EXPECTED_TRANSCRIPT = [
" Mr. Quilter is the apostle of the middle classes, and we are glad to",
" Nor is Mr. Quilter's manner less interesting than his matter.",
" He tells us that at this festive season of the year, with Christmas and roast beef looming",
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can",
]
# fmt: on
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)
@slow
def test_tiny_timestamp_generation(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
input_speech = np.concatenate(self._load_datasamples(4))
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="jax").input_features
generate_fn = jax.jit(functools.partial(model.generate, max_length=448, return_timestamps=True))
generated_ids = generate_fn(input_features)
# fmt: off
EXPECTED_OUTPUT = np.array([50258, 50259, 50359, 50364, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50692, 50692, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50926, 50926, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 51208, 51208, 949, 505, 11, 14138, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51552, 51552, 634, 575, 12525, 22618, 1968, 6144, 35617, 7354, 1292, 6, 589, 307, 534, 10281, 934, 439, 11, 293, 51836, 51836, 50257])
# fmt: on
self.assertTrue(np.allclose(generated_ids, EXPECTED_OUTPUT))
EXPECTED_TRANSCRIPT = [
{
"text": (
" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is"
" Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season"
" of the year, with Christmas and roast beef looming before us, similarly drawn from eating and"
" its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins'"
" work is really Greek after all, and"
),
"offsets": [
{
"text": (
" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
),
"timestamp": (0.0, 6.5600000000000005),
},
{
"text": " Nor is Mr. Quilter's manner less interesting than his matter.",
"timestamp": (6.5600000000000005, 11.24),
},
{
"text": (
" He tells us that at this festive season of the year, with Christmas and roast beef"
" looming"
),
"timestamp": (11.24, 16.88),
},
{
"text": (
" before us, similarly drawn from eating and its results occur most readily to the mind."
),
"timestamp": (16.88, 23.76),
},
{
"text": (
" He has grave doubts whether Sir Frederick Latins' work is really Greek after all, and"
),
"timestamp": (23.76, 29.44),
},
],
}
]
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
......@@ -22,9 +22,10 @@ import unittest
import numpy as np
import transformers
from transformers import WhisperConfig
from transformers.testing_utils import is_torch_available, require_torch, require_torchaudio, slow, torch_device
from transformers.utils import cached_property
from transformers.testing_utils import is_pt_flax_cross_test, require_torch, require_torchaudio, slow, torch_device
from transformers.utils import cached_property, is_flax_available, is_torch_available
from transformers.utils.import_utils import is_datasets_available
from ...generation.test_utils import GenerationTesterMixin
......@@ -48,6 +49,13 @@ if is_torch_available():
)
from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder
if is_flax_available():
import jax.numpy as jnp
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
def prepare_whisper_inputs_dict(
config,
......@@ -747,6 +755,159 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
self.assertTrue(models_equal)
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
init_shape = (1,) + inputs_dict["input_features"].shape[1:]
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
fx_model_class_name = "Flax" + model_class.__name__
if not hasattr(transformers, fx_model_class_name):
# no flax model exists for this class
return
# Output all for aggressive testing
config.output_hidden_states = True
config.output_attentions = self.has_attentions
fx_model_class = getattr(transformers, fx_model_class_name)
# load PyTorch class
pt_model = model_class(config).eval()
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
# load Flax class
fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)
# make sure only flax inputs are forward that actually exist in function args
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
# prepare inputs
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
# send pytorch inputs to the correct device
pt_inputs = {
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
}
# convert inputs to Flax
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
# send pytorch model to the correct device
pt_model.to(torch_device)
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
fx_outputs = fx_model(**fx_inputs)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**fx_inputs)
fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
init_shape = (1,) + inputs_dict["input_features"].shape[1:]
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
fx_model_class_name = "Flax" + model_class.__name__
if not hasattr(transformers, fx_model_class_name):
# no flax model exists for this class
return
# Output all for aggressive testing
config.output_hidden_states = True
config.output_attentions = self.has_attentions
fx_model_class = getattr(transformers, fx_model_class_name)
# load PyTorch class
pt_model = model_class(config).eval()
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
# load Flax class
fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)
# make sure only flax inputs are forward that actually exist in function args
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
# prepare inputs
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
# send pytorch inputs to the correct device
pt_inputs = {
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
}
# convert inputs to Flax
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
# make sure weights are tied in PyTorch
pt_model.tie_weights()
# send pytorch model to the correct device
pt_model.to(torch_device)
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
fx_outputs = fx_model(**fx_inputs)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
# send pytorch model to the correct device
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
@require_torch
@require_torchaudio
......
......@@ -68,6 +68,8 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
"DeformableDetrEncoder", # Building part of bigger (tested) model.
"DeformableDetrDecoder", # Building part of bigger (tested) model.
"OPTDecoder", # Building part of bigger (tested) model.
"FlaxWhisperDecoder", # Building part of bigger (tested) model.
"FlaxWhisperEncoder", # Building part of bigger (tested) model.
"WhisperDecoder", # Building part of bigger (tested) model.
"WhisperEncoder", # Building part of bigger (tested) model.
"DecisionTransformerGPT2Model", # Building part of bigger (tested) model.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment