"tests/vscode:/vscode.git/clone" did not exist on "e9636216496240e0e0174631de46d107562e9215"
Unverified commit 2840272c authored by Andy Ehrenberg, committed by GitHub

add flax whisper implementation (#20479)



* add flax whisper implementation

* revert change to setup

* remove unused imports

* revert generation changes

* flax whisper docs

* docs

* import order

* import sorting

* isort

* add dummy objects

* doc formatting

* formatting

* remove trailing whitespaces

* fix flax whisper docs

* add generation logic to unlock flax whisper

* remove scans

* give credits to Flax Bart implementation

* remove unused imports

* add license

* remove assert

* more credits to Bart

* fix style

* formatting

* support left padding

* add flax whisper generation test

* remove copied from comments whenever not a full copy

* fix docstrings for logits processors

* revert change to FlaxForceTokensLogitsProcessor

* revert doc changes

* improve generation docs

* reorganize

* formatting

* cleanup docs

* add tests

* handle empty list case

* fix forced decoder ids in flax tests

* add flax whisper to inits

* update dummy objects

* docs for FlaxAutoModelForSpeechSeq2Seq

* fix decoder_position_ids computation in pretrained model decode/__call__ fns

* add Copied from statements as necessary

* compute position_ids only in __call__ and decode methods of pretrained model subclasses

* improve readability of compute positional embeddings

* check dimensionality of input_features instead of hidden_states

* copied from statement for init_cache

* formatting

* fix copies

* fix copies

* pass attention mask to encoder layers

* fix decoder module outputs

* set dtype
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* smaller flax model for whisper test

* Update src/transformers/generation/flax_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/whisper/modeling_flax_whisper.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update tests/models/whisper/test_modeling_flax_whisper.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* cleanup
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/whisper/modeling_flax_whisper.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* bias cleanup

* doc fix

* align style for force tokens processor

* readability

* fix input shape in tests

* revert FlaxGenerationMixin docstring

* formatting

* fix tests

* fix imports

* consistent encoder hidden states

* consistent hidden states

* input shapes

* typo

* partial class trick

* partial class for input shape

* base_class with correct input shape

* partial base classes

* match by name

* set main_input_name

* compare on names

* formatting

* remove unused import

* safer position ids computation

* safer position id computation

* Update src/transformers/models/whisper/modeling_flax_whisper.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/whisper/modeling_flax_whisper.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* remove identical inherited tests

* fix prompt ids in tests

* use generation config

* use jnp array

* better var names

* more explicit bias use

* import transformers

* formatting

* test formatting

* remove unused imports

* remove unused imports

* formatting

* isort

* docs

* fix ln orders for encoder hidden states

* whisper unique generation stuff

* flake

* use finfo for attention bias

* docs

* Update src/transformers/generation/flax_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* docs

* add timestamp flax test

* jit for timestamps

* formatting

* clean up timestamps processor

* formatting

* remove if_true

* cleanup

---------
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 7735e040
@@ -402,7 +402,7 @@ Flax), PyTorch, and/or TensorFlow.
| Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ |
| Wav2Vec2-Conformer | ❌ | ❌ | ✅ | ❌ | ❌ |
| WavLM | ❌ | ❌ | ✅ | ❌ | ❌ |
| Whisper | ✅ | ❌ | ✅ | ✅ | ✅ |
| X-CLIP | ❌ | ❌ | ✅ | ❌ | ❌ |
| X-MOD | ❌ | ❌ | ✅ | ❌ | ❌ |
| XGLM | ✅ | ✅ | ✅ | ✅ | ✅ |
...
@@ -286,6 +286,10 @@ The following auto classes are available for the following audio tasks.
[[autodoc]] TFAutoModelForSpeechSeq2Seq
### FlaxAutoModelForSpeechSeq2Seq
[[autodoc]] FlaxAutoModelForSpeechSeq2Seq
### AutoModelForAudioXVector
[[autodoc]] AutoModelForAudioXVector
...
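For reference, a minimal usage sketch of the new `FlaxAutoModelForSpeechSeq2Seq` class documented above; the checkpoint name is illustrative, and `from_pt=True` is only needed when a checkpoint does not host Flax weights.

```python
# Sketch only: route a Whisper checkpoint through the new Flax auto class.
from transformers import FlaxAutoModelForSpeechSeq2Seq

# "openai/whisper-tiny" is an example checkpoint; from_pt converts PyTorch weights if required.
model = FlaxAutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny", from_pt=True)
print(type(model).__name__)  # expected: FlaxWhisperForConditionalGeneration
```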
@@ -79,3 +79,14 @@ The original code can be found [here](https://github.com/openai/whisper).
[[autodoc]] TFWhisperForConditionalGeneration
- call
## FlaxWhisperModel
[[autodoc]] FlaxWhisperModel
- __call__
## FlaxWhisperForConditionalGeneration
[[autodoc]] FlaxWhisperForConditionalGeneration
- __call__
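A hedged end-to-end sketch of the new Flax Whisper classes; the checkpoint name and the silent audio input are placeholders, not part of this diff.

```python
# Illustrative transcription sketch for the new Flax Whisper classes.
import numpy as np
from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)

audio = np.zeros(16_000, dtype=np.float32)  # one second of silence at 16 kHz as a stand-in
inputs = processor(audio, sampling_rate=16_000, return_tensors="np")

# generate returns an output object whose `sequences` field holds the generated token ids
generated = model.generate(inputs["input_features"], max_length=64)
text = processor.batch_decode(generated.sequences, skip_special_tokens=True)
print(text)
```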
@@ -3382,6 +3382,7 @@ else:
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING",
"FLAX_MODEL_MAPPING",
@@ -3395,6 +3396,7 @@ else:
"FlaxAutoModelForQuestionAnswering",
"FlaxAutoModelForSeq2SeqLM",
"FlaxAutoModelForSequenceClassification",
"FlaxAutoModelForSpeechSeq2Seq",
"FlaxAutoModelForTokenClassification",
"FlaxAutoModelForVision2Seq",
]
@@ -3578,6 +3580,13 @@ else:
_import_structure["models.wav2vec2"].extend(
["FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2Model", "FlaxWav2Vec2PreTrainedModel"]
)
_import_structure["models.whisper"].extend(
[
"FlaxWhisperForConditionalGeneration",
"FlaxWhisperModel",
"FlaxWhisperPreTrainedModel",
]
)
_import_structure["models.xglm"].extend(
[
"FlaxXGLMForCausalLM",
@@ -6381,6 +6390,7 @@ if TYPE_CHECKING:
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
FLAX_MODEL_MAPPING,
@@ -6394,6 +6404,7 @@ if TYPE_CHECKING:
FlaxAutoModelForQuestionAnswering,
FlaxAutoModelForSeq2SeqLM,
FlaxAutoModelForSequenceClassification,
FlaxAutoModelForSpeechSeq2Seq,
FlaxAutoModelForTokenClassification,
FlaxAutoModelForVision2Seq,
)
@@ -6529,6 +6540,7 @@ if TYPE_CHECKING:
FlaxWav2Vec2Model,
FlaxWav2Vec2PreTrainedModel,
)
from .models.whisper import FlaxWhisperForConditionalGeneration, FlaxWhisperModel, FlaxWhisperPreTrainedModel
from .models.xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel
from .models.xlm_roberta import (
FLAX_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
...
@@ -264,3 +264,191 @@ class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor):
scores = jnp.where(apply_penalty, scores.at[:, self.eos_token_id].set(-float("inf")), scores)
return scores
class FlaxSuppressTokensAtBeginLogitsProcessor(FlaxLogitsProcessor):
r"""
[`FlaxLogitsProcessor`] suppressing a list of tokens as soon as the `generate` function starts generating using
`begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are not sampled at the
beginning of the generation.
Args:
begin_suppress_tokens (`List[int]`):
Tokens to not sample.
begin_index (`int`):
Index where the tokens are suppressed.
"""
def __init__(self, begin_suppress_tokens, begin_index):
self.begin_suppress_tokens = list(begin_suppress_tokens)
self.begin_index = begin_index
def __call__(self, input_ids, scores, cur_len: int):
apply_penalty = 1 - jnp.bool_(cur_len - self.begin_index)
scores = jnp.where(apply_penalty, scores.at[:, self.begin_suppress_tokens].set(-float("inf")), scores)
return scores
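A toy check of the begin-suppression rule; the token ids are made up, and the module path follows the file this hunk edits.

```python
# Toy values: suppress tokens 1 and 3 only at the first generation step after the prompt.
import jax.numpy as jnp
from transformers.generation.flax_logits_process import FlaxSuppressTokensAtBeginLogitsProcessor

suppress_at_begin = FlaxSuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens=[1, 3], begin_index=2)
scores = jnp.zeros((1, 5))

masked = suppress_at_begin(None, scores, cur_len=2)       # cur_len == begin_index -> tokens 1 and 3 become -inf
passthrough = suppress_at_begin(None, scores, cur_len=3)  # any other step -> scores pass through unchanged
```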
class FlaxSuppressTokensLogitsProcessor(FlaxLogitsProcessor):
r"""
[`FlaxLogitsProcessor`] suppressing a list of tokens at each decoding step. The processor will set their log probs
to be `-inf` so they are not sampled.
Args:
suppress_tokens (`list`):
Tokens to not sample.
"""
def __init__(self, suppress_tokens: list):
self.suppress_tokens = list(suppress_tokens)
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
scores = scores.at[..., self.suppress_tokens].set(-float("inf"))
return scores
class FlaxForceTokensLogitsProcessor(FlaxLogitsProcessor):
r"""
[`FlaxLogitsProcessor`] that takes a list of pairs of integers which indicates a mapping from generation indices to
token indices that will be forced before sampling. The processor will set their log probs to 0 and all other tokens
to `-inf` so that they are sampled at their corresponding index.
Args:
force_token_map (`list`):
Map giving token ids and indices where they will be forced to be sampled.
"""
def __init__(self, force_token_map):
force_token_map = dict(force_token_map)
# Converts the dictionary of format {index: token} containing the tokens to be forced to an array, where the
# index of the array corresponds to the index of the token to be forced, for XLA compatibility.
# Indexes without forced tokens will have a negative value.
force_token_array = jnp.ones((max(force_token_map.keys()) + 1), dtype=jnp.int32) * -1
for index, token in force_token_map.items():
force_token_array = force_token_array.at[index].set(token)
self.force_token_array = jnp.int32(force_token_array)
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
def _force_token(generation_idx):
batch_size = scores.shape[0]
current_token = self.force_token_array[generation_idx]
new_scores = jnp.ones_like(scores, dtype=scores.dtype) * -float("inf")
updates = jnp.zeros((batch_size, 1), dtype=scores.dtype)
new_scores = lax.dynamic_update_slice(new_scores, updates, (0, current_token))
return new_scores
scores = lax.cond(
cur_len >= self.force_token_array.shape[0],
# If the current length is geq than the length of force_token_array, the processor does nothing.
lambda: scores,
# Otherwise, it may force a certain token.
lambda: lax.cond(
self.force_token_array[cur_len] >= 0,
# Only valid (positive) tokens are forced
lambda: _force_token(cur_len),
# Otherwise, the processor does nothing.
lambda: scores,
),
)
return scores
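To make the map-to-array conversion and the forcing step concrete, here is a sketch with made-up token ids.

```python
# Made-up ids: force token 7 at generation index 1 and token 9 at index 2.
import jax.numpy as jnp
from transformers.generation.flax_logits_process import FlaxForceTokensLogitsProcessor

force_tokens = FlaxForceTokensLogitsProcessor([[1, 7], [2, 9]])
# Internally the map becomes the array [-1, 7, 9]: index 0 carries no forced token.
scores = jnp.zeros((2, 12))  # batch of 2, toy vocabulary of 12

forced = force_tokens(None, scores, cur_len=1)
# Row-wise, every logit is -inf except column 7, which is 0, so token 7 is certain to be sampled.
```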
class FlaxWhisperTimeStampLogitsProcessor(FlaxLogitsProcessor):
r"""
Whisper specific Processor. This processor can be used to force a list of tokens. The processor will set their log
probs to `inf` so that they are sampled at their corresponding index.
Args:
generate_config (`GenerateConfig`):
The generate config used to generate the output. The following parameters are required:
eos_token_id (`int`, *optional*, defaults to 50257):
The id of the *end-of-sequence* token.
no_timestamps_token_id (`int`, *optional*, defaults to 50363):
The id of the `"<|notimestamps|>"` token.
max_initial_timestamp_index (`int`, *optional*, defaults to 1):
Used to set the maximum value of the initial timestamp. This is used to prevent the model from
predicting timestamps that are too far in the future.
"""
def __init__(self, generate_config, model_config, decoder_input_length):
self.eos_token_id = generate_config.eos_token_id
self.no_timestamps_token_id = generate_config.no_timestamps_token_id
self.timestamp_begin = generate_config.no_timestamps_token_id + 1
self.begin_index = decoder_input_length + 1
if generate_config.is_multilingual:
# room for language token and task token
self.begin_index += 2
if hasattr(generate_config, "max_initial_timestamp_index"):
self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index
else:
self.max_initial_timestamp_index = model_config.vocab_size
if self.max_initial_timestamp_index is None:
self.max_initial_timestamp_index = model_config.vocab_size
def __call__(self, input_ids, scores, cur_len):
# suppress <|notimestamps|> which is handled by without_timestamps
scores = scores.at[:, self.no_timestamps_token_id].set(-float("inf"))
def handle_pairs(input_ids_k, scores_k):
last_was_timestamp = jnp.where((cur_len - self.begin_index) >= 1, True, False)
last_was_timestamp = jnp.where(
input_ids_k[cur_len - 1] >= self.timestamp_begin,
True and last_was_timestamp,
False,
)
penultimate_was_timestamp = jnp.where((cur_len - self.begin_index) < 2, True, False)
penultimate_was_timestamp = jnp.where(
input_ids_k[cur_len - 2] >= self.timestamp_begin,
True,
penultimate_was_timestamp,
)
return jnp.where(
last_was_timestamp,
jnp.where(
penultimate_was_timestamp > 0,
scores_k.at[self.timestamp_begin :].set(-float("inf")),
scores_k.at[: self.eos_token_id].set(-float("inf")),
),
scores_k,
)
scores = jax.vmap(handle_pairs)(input_ids, scores)
apply_max_initial_timestamp = jnp.where(cur_len == self.begin_index, True, False)
apply_max_initial_timestamp = jnp.where(
self.max_initial_timestamp_index is not None,
True and apply_max_initial_timestamp,
False,
)
last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
scores = jnp.where(
apply_max_initial_timestamp,
scores.at[:, last_allowed + 1 :].set(-float("inf")),
scores,
)
# if sum of probability over timestamps is above any other token, sample timestamp
logprobs = jax.nn.log_softmax(scores, axis=-1)
def handle_cumulative_probs(logprobs_k, scores_k):
timestamp_logprob = jax.nn.logsumexp(logprobs_k[self.timestamp_begin :], axis=-1)
max_text_token_logprob = jnp.max(logprobs_k[: self.timestamp_begin])
return jnp.where(
timestamp_logprob > max_text_token_logprob,
scores_k.at[: self.timestamp_begin].set(-float("inf")),
scores_k,
)
scores = jax.vmap(handle_cumulative_probs)(logprobs, scores)
return scores
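The timestamp processor only reads a handful of config fields. Below is a construction sketch with stand-in config objects; the special-token ids mirror the usual Whisper values but are assumptions here, not taken from this diff.

```python
# Stand-in config objects exposing only the fields the processor reads.
from types import SimpleNamespace
from transformers.generation.flax_logits_process import FlaxWhisperTimeStampLogitsProcessor

generate_config = SimpleNamespace(
    eos_token_id=50257,            # assumed Whisper end-of-text id
    no_timestamps_token_id=50363,  # assumed "<|notimestamps|>" id
    is_multilingual=False,
    max_initial_timestamp_index=1,
)
model_config = SimpleNamespace(vocab_size=51865)

ts_processor = FlaxWhisperTimeStampLogitsProcessor(generate_config, model_config, decoder_input_length=1)
# In practice the model's generate() builds this processor itself when timestamps are requested
# (per the "whisper unique generation stuff" commits), so direct construction is rarely needed.
```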
@@ -37,8 +37,11 @@ from .configuration_utils import GenerationConfig
from .flax_logits_process import (
FlaxForcedBOSTokenLogitsProcessor,
FlaxForcedEOSTokenLogitsProcessor,
FlaxForceTokensLogitsProcessor,
FlaxLogitsProcessorList,
FlaxMinLengthLogitsProcessor,
FlaxSuppressTokensAtBeginLogitsProcessor,
FlaxSuppressTokensLogitsProcessor,
FlaxTemperatureLogitsWarper,
FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper,
@@ -164,6 +167,50 @@ class FlaxGenerationMixin:
model_kwargs["encoder_outputs"] = self.encode(input_ids, params=params, return_dict=True, **encoder_kwargs)
return model_kwargs
def _prepare_decoder_input_ids_for_generation(
self,
batch_size: int,
decoder_start_token_id: int = None,
bos_token_id: int = None,
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
) -> jnp.ndarray:
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
# Only use this arg if not None, otherwise just remove from model_kwargs
decoder_input_ids = model_kwargs.pop("decoder_input_ids")
if decoder_input_ids is not None:
return decoder_input_ids
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
return jnp.array(decoder_start_token_id, dtype="i4").reshape(1, -1).repeat(batch_size, axis=0)
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
# retrieve decoder_start_token_id for encoder-decoder models
# fall back to bos_token_id if necessary
decoder_start_token_id = (
decoder_start_token_id
if decoder_start_token_id is not None
else self.generation_config.decoder_start_token_id
)
bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
if decoder_start_token_id is not None:
return decoder_start_token_id
elif (
hasattr(self.config, "decoder")
and hasattr(self.config.decoder, "decoder_start_token_id")
and self.config.decoder.decoder_start_token_id is not None
):
return self.config.decoder.decoder_start_token_id
elif bos_token_id is not None:
return bos_token_id
elif (
hasattr(self.config, "decoder")
and hasattr(self.config.decoder, "bos_token_id")
and self.config.decoder.bos_token_id is not None
):
return self.config.decoder.bos_token_id
raise ValueError(
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
)
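As a quick sanity check of the decoder prompt the new helper builds, the final reshape/repeat yields a `(batch_size, 1)` array filled with the start token; values below are arbitrary.

```python
# Toy check of the prompt shape produced by _prepare_decoder_input_ids_for_generation.
import jax.numpy as jnp

decoder_start_token_id = 50258  # arbitrary example id
batch_size = 4
prompt = jnp.array(decoder_start_token_id, dtype="i4").reshape(1, -1).repeat(batch_size, axis=0)
assert prompt.shape == (4, 1)  # one start token per batch row
```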
@staticmethod
def _expand_to_num_beams(tensor, num_beams):
return jnp.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])
@@ -224,6 +271,7 @@ class FlaxGenerationMixin:
prng_key: Optional[jnp.ndarray] = None,
trace: bool = True,
params: Optional[Dict[str, jnp.ndarray]] = None,
logits_processor: Optional[FlaxLogitsProcessorList] = None,
**kwargs,
):
r"""
@@ -244,6 +292,10 @@ class FlaxGenerationMixin:
considerably slower runtime.
params (`Dict[str, jnp.ndarray]`, *optional*):
Optionally the model parameters can be passed. Can be useful for parallelized generation.
logits_processor (`FlaxLogitsProcessorList`, *optional*):
Custom logits processors that complement the default logits processors built from arguments and
generation config. If a logit processor is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
kwargs:
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
@@ -277,6 +329,8 @@ class FlaxGenerationMixin:
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())
logits_processor = logits_processor if logits_processor is not None else FlaxLogitsProcessorList()
# set init values
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
@@ -306,12 +360,19 @@ class FlaxGenerationMixin:
"generation results, please set `padding_side='left'` when initializing the tokenizer."
)
batch_size = input_ids.shape[0]
if self.config.is_encoder_decoder:
# add encoder_outputs to model_kwargs
if model_kwargs.get("encoder_outputs") is None:
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
# prepare decoder_input_ids for generation
input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * generation_config.decoder_start_token_id
input_ids = self._prepare_decoder_input_ids_for_generation(
batch_size,
decoder_start_token_id=generation_config.decoder_start_token_id,
bos_token_id=generation_config.bos_token_id,
model_kwargs=model_kwargs,
)
# Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
@@ -347,7 +408,11 @@ class FlaxGenerationMixin:
" increasing`max_new_tokens`."
)
logits_processor = self._get_logits_processor(generation_config=generation_config)
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
logits_processor=logits_processor,
)
if not generation_config.do_sample and generation_config.num_beams == 1:
return self._greedy_search(
@@ -419,7 +484,12 @@ class FlaxGenerationMixin:
return warpers
def _get_logits_processor(self, generation_config: GenerationConfig) -> FlaxLogitsProcessorList:
def _get_logits_processor(
self,
generation_config: GenerationConfig,
input_ids_seq_length: int,
logits_processor: Optional[FlaxLogitsProcessorList],
) -> FlaxLogitsProcessorList:
""" """
This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsProcessor`] This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsProcessor`]
instances used to modify the scores of the language model head. instances used to modify the scores of the language model head.
...@@ -440,9 +510,51 @@ class FlaxGenerationMixin: ...@@ -440,9 +510,51 @@ class FlaxGenerationMixin:
processors.append( processors.append(
FlaxForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id) FlaxForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id)
) )
if generation_config.suppress_tokens is not None:
processors.append(FlaxSuppressTokensLogitsProcessor(generation_config.suppress_tokens))
if generation_config.begin_suppress_tokens is not None:
begin_index = input_ids_seq_length
begin_index = (
begin_index
if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
else begin_index + 1
)
if generation_config.forced_decoder_ids is not None and len(generation_config.forced_decoder_ids) > 0:
# generation starts after the last token that is forced
begin_index += generation_config.forced_decoder_ids[-1][0]
processors.append(
FlaxSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
)
if generation_config.forced_decoder_ids is not None:
forced_decoder_ids = [
[input_ids_seq_length + i[0] - 1, i[1]] for i in generation_config.forced_decoder_ids
]
processors.append(FlaxForceTokensLogitsProcessor(forced_decoder_ids))
processors = self._merge_criteria_processor_list(processors, logits_processor)
return processors
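A worked example of the index shift applied above to `forced_decoder_ids` before they are handed to `FlaxForceTokensLogitsProcessor`; the ids are illustrative.

```python
# Whisper-style forced ids as (generation_index, token_id) pairs; ids are illustrative.
forced_decoder_ids = [(1, 50259), (2, 50359), (3, 50363)]
input_ids_seq_length = 3  # e.g. a three-token decoder prompt

shifted = [[input_ids_seq_length + pos - 1, tok] for pos, tok in forced_decoder_ids]
print(shifted)  # [[3, 50259], [4, 50359], [5, 50363]]
```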
def _merge_criteria_processor_list(
self,
default_list: FlaxLogitsProcessorList,
custom_list: FlaxLogitsProcessorList,
) -> FlaxLogitsProcessorList:
if len(custom_list) == 0:
return default_list
for default in default_list:
for custom in custom_list:
if type(custom) is type(default):
object_type = "logits processor"
raise ValueError(
f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
f" `generate`, but it has already been created with the values {default}. {default} has been"
" created by passing the corresponding arguments to generate or by the model's config default"
f" values. If you just want to change the default values of {object_type} consider passing"
f" them as arguments to `generate` instead of using a custom {object_type}."
)
default_list.extend(custom_list)
return default_list
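From the caller's side, the new `logits_processor` argument can carry user-defined processors; the merge above only rejects a custom processor whose type collides with a default one. The processor class below is hypothetical, and `model`/`inputs` are the objects from the transcription sketch earlier.

```python
# Hypothetical custom processor; any FlaxLogitsProcessor subclass not built by default is accepted.
from transformers.generation.flax_logits_process import FlaxLogitsProcessor, FlaxLogitsProcessorList

class ScaleLogits(FlaxLogitsProcessor):
    """Toy processor that halves every logit (no effect on greedy search, shown for wiring only)."""

    def __call__(self, input_ids, scores, cur_len):
        return scores / 2.0

custom = FlaxLogitsProcessorList([ScaleLogits()])
out = model.generate(inputs["input_features"], logits_processor=custom, max_length=64)
```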
def _greedy_search(
self,
input_ids: None,
...
@@ -163,6 +163,7 @@ else:
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING",
"FLAX_MODEL_MAPPING",
@@ -176,6 +177,7 @@ else:
"FlaxAutoModelForQuestionAnswering",
"FlaxAutoModelForSeq2SeqLM",
"FlaxAutoModelForSequenceClassification",
"FlaxAutoModelForSpeechSeq2Seq",
"FlaxAutoModelForTokenClassification",
"FlaxAutoModelForVision2Seq",
]
@@ -320,6 +322,7 @@ if TYPE_CHECKING:
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
FLAX_MODEL_MAPPING,
@@ -333,6 +336,7 @@ if TYPE_CHECKING:
FlaxAutoModelForQuestionAnswering,
FlaxAutoModelForSeq2SeqLM,
FlaxAutoModelForSequenceClassification,
FlaxAutoModelForSpeechSeq2Seq,
FlaxAutoModelForTokenClassification,
FlaxAutoModelForVision2Seq,
)
...
@@ -55,6 +55,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
("vit", "FlaxViTModel"),
("wav2vec2", "FlaxWav2Vec2Model"),
("whisper", "FlaxWhisperModel"),
("xglm", "FlaxXGLMModel"),
("xlm-roberta", "FlaxXLMRobertaModel"),
]
@@ -76,6 +77,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
("roformer", "FlaxRoFormerForMaskedLM"),
("t5", "FlaxT5ForConditionalGeneration"),
("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
("whisper", "FlaxWhisperForConditionalGeneration"),
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
]
)
@@ -219,6 +221,7 @@ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("speech-encoder-decoder", "FlaxSpeechEncoderDecoderModel"),
("whisper", "FlaxWhisperForConditionalGeneration"),
]
)
...
@@ -13,7 +13,13 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_tf_available,
is_torch_available,
)
_import_structure = {
@@ -50,6 +56,19 @@ else:
"TFWhisperPreTrainedModel",
]
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_whisper"] = [
"FlaxWhisperForConditionalGeneration",
"FlaxWhisperModel",
"FlaxWhisperPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_whisper import WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP, WhisperConfig, WhisperOnnxConfig
from .feature_extraction_whisper import WhisperFeatureExtractor
@@ -82,6 +101,18 @@ if TYPE_CHECKING:
TFWhisperPreTrainedModel,
)
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_whisper import (
FlaxWhisperForConditionalGeneration,
FlaxWhisperModel,
FlaxWhisperPreTrainedModel,
)
else:
import sys
...
This diff is collapsed.
@@ -162,6 +162,9 @@ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = None
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = None
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
@@ -241,6 +244,13 @@ class FlaxAutoModelForSequenceClassification(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxAutoModelForSpeechSeq2Seq(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxAutoModelForTokenClassification(metaclass=DummyObject):
_backends = ["flax"]
@@ -1130,6 +1140,27 @@ class FlaxWav2Vec2PreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxWhisperForConditionalGeneration(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxWhisperModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxWhisperPreTrainedModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxXGLMForCausalLM(metaclass=DummyObject):
_backends = ["flax"]
...
This diff is collapsed.
@@ -22,9 +22,10 @@ import unittest
import numpy as np
import transformers
from transformers import WhisperConfig
from transformers.testing_utils import is_torch_available, require_torch, require_torchaudio, slow, torch_device
from transformers.testing_utils import is_pt_flax_cross_test, require_torch, require_torchaudio, slow, torch_device
from transformers.utils import cached_property
from transformers.utils import cached_property, is_flax_available, is_torch_available
from transformers.utils.import_utils import is_datasets_available
from ...generation.test_utils import GenerationTesterMixin
@@ -48,6 +49,13 @@ if is_torch_available():
)
from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder
if is_flax_available():
import jax.numpy as jnp
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
def prepare_whisper_inputs_dict(
config,
@@ -747,6 +755,159 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
self.assertTrue(models_equal)
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
init_shape = (1,) + inputs_dict["input_features"].shape[1:]
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
fx_model_class_name = "Flax" + model_class.__name__
if not hasattr(transformers, fx_model_class_name):
# no flax model exists for this class
return
# Output all for aggressive testing
config.output_hidden_states = True
config.output_attentions = self.has_attentions
fx_model_class = getattr(transformers, fx_model_class_name)
# load PyTorch class
pt_model = model_class(config).eval()
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
# load Flax class
fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)
# make sure only flax inputs are forward that actually exist in function args
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
# prepare inputs
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
# send pytorch inputs to the correct device
pt_inputs = {
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
}
# convert inputs to Flax
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
# send pytorch model to the correct device
pt_model.to(torch_device)
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
fx_outputs = fx_model(**fx_inputs)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**fx_inputs)
fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
init_shape = (1,) + inputs_dict["input_features"].shape[1:]
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
fx_model_class_name = "Flax" + model_class.__name__
if not hasattr(transformers, fx_model_class_name):
# no flax model exists for this class
return
# Output all for aggressive testing
config.output_hidden_states = True
config.output_attentions = self.has_attentions
fx_model_class = getattr(transformers, fx_model_class_name)
# load PyTorch class
pt_model = model_class(config).eval()
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
# load Flax class
fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)
# make sure only flax inputs are forward that actually exist in function args
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
# prepare inputs
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
# send pytorch inputs to the correct device
pt_inputs = {
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
}
# convert inputs to Flax
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
# make sure weights are tied in PyTorch
pt_model.tie_weights()
# send pytorch model to the correct device
pt_model.to(torch_device)
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
fx_outputs = fx_model(**fx_inputs)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
# send pytorch model to the correct device
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
@require_torch
@require_torchaudio
...
@@ -68,6 +68,8 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
"DeformableDetrEncoder", # Building part of bigger (tested) model.
"DeformableDetrDecoder", # Building part of bigger (tested) model.
"OPTDecoder", # Building part of bigger (tested) model.
"FlaxWhisperDecoder", # Building part of bigger (tested) model.
"FlaxWhisperEncoder", # Building part of bigger (tested) model.
"WhisperDecoder", # Building part of bigger (tested) model.
"WhisperEncoder", # Building part of bigger (tested) model.
"DecisionTransformerGPT2Model", # Building part of bigger (tested) model.
...