Unverified Commit e3342edc authored by Sanchit Gandhi, committed by GitHub

Flax Speech-Encoder-Decoder Model (#15613)



* rebase

* Delete shift tokens func

* downsample decoder input seq len for init

* correct attention mask

* add tests

* pt flax cross test

* make fixup

* init file for import

* change pt-flax cross test threshold

* pt-flax test logits only

* move tests

* make repo-consistency

* consistent indentation
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 935a76d9
@@ -230,7 +230,7 @@ Flax), PyTorch, and/or TensorFlow.
| SegFormer | ❌ | ❌ | ✅ | ❌ | ❌ |
| SEW | ❌ | ❌ | ✅ | ❌ | ❌ |
| SEW-D | ❌ | ❌ | ✅ | ❌ | ❌ |
| Speech Encoder decoder | ❌ | ❌ | ✅ | ❌ | ❌ |
| Speech Encoder decoder | ❌ | ❌ | ✅ | ❌ | ✅ |
| Speech2Text | ✅ | ❌ | ✅ | ✅ | ❌ |
| Speech2Text2 | ✅ | ❌ | ❌ | ❌ | ❌ |
| Splinter | ✅ | ✅ | ✅ | ❌ | ❌ |
@@ -33,3 +33,9 @@ An example of how to use a [`SpeechEncoderDecoderModel`] for inference can be se
[[autodoc]] SpeechEncoderDecoderModel
- forward
- from_encoder_decoder_pretrained
## FlaxSpeechEncoderDecoderModel
[[autodoc]] FlaxSpeechEncoderDecoderModel
- __call__
- from_encoder_decoder_pretrained
\ No newline at end of file
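A minimal usage sketch for the new Flax class, mirroring the examples in the model docstrings (the checkpoint names below are the ones used there):

```python
from transformers import FlaxSpeechEncoderDecoderModel
import jax.numpy as jnp

# initialize a wav2vec2-2-bart model from pretrained encoder and decoder checkpoints
# (the cross-attention layers are randomly initialized and need fine-tuning)
model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    "facebook/wav2vec2-large-lv60", "facebook/bart-large"
)

# a dummy batch of 2 raw waveforms with 5000 samples each
inputs = jnp.ones((2, 5000), dtype=jnp.float32)
encoder_outputs = model.encode(inputs)
```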
@@ -2295,6 +2295,7 @@ if is_flax_available():
"FlaxRoFormerPreTrainedModel",
]
)
_import_structure["models.speech_encoder_decoder"].append("FlaxSpeechEncoderDecoderModel")
_import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
_import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel")
_import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"])
@@ -4183,6 +4184,7 @@ if TYPE_CHECKING:
FlaxRoFormerModel,
FlaxRoFormerPreTrainedModel,
)
from .models.speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel
from .models.vision_text_dual_encoder import FlaxVisionTextDualEncoderModel
@@ -188,6 +188,12 @@ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
]
)
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("speech-encoder-decoder", "FlaxSpeechEncoderDecoderModel"),
]
)
FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES)
FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
@@ -215,6 +221,9 @@ FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
)
class FlaxAutoModel(_BaseAutoModelClass):
@@ -309,3 +318,12 @@ class FlaxAutoModelForVision2Seq(_BaseAutoModelClass):
FlaxAutoModelForVision2Seq = auto_class_update(FlaxAutoModelForVision2Seq, head_doc="vision-to-text modeling")
class FlaxAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
FlaxAutoModelForSpeechSeq2Seq = auto_class_update(
FlaxAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)
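# Usage sketch (checkpoint name taken from the model docstrings in this PR): a config of
# model type `speech-encoder-decoder` now resolves through this auto class, e.g.
#     model = FlaxAutoModelForSpeechSeq2Seq.from_pretrained("patrickvonplaten/wav2vec2-2-bart-large")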
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_torch_available
from ...file_utils import _LazyModule, is_flax_available, is_torch_available
_import_structure = {
@@ -28,12 +28,18 @@ _import_structure = {
if is_torch_available():
_import_structure["modeling_speech_encoder_decoder"] = ["SpeechEncoderDecoderModel"]
if is_flax_available():
_import_structure["modeling_flax_speech_encoder_decoder"] = ["FlaxSpeechEncoderDecoderModel"]
if TYPE_CHECKING:
from .configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig
if is_torch_available():
from .modeling_speech_encoder_decoder import SpeechEncoderDecoderModel
if is_flax_available():
from .modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
else:
import sys
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Classes to support Flax Speech-Encoder-Decoder architectures"""
import os
from typing import Optional, Tuple, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from jax import lax
from jax.random import PRNGKey
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
from ...modeling_flax_utils import FlaxPreTrainedModel
from ...utils import logging
from ..auto.configuration_auto import AutoConfig
from ..auto.modeling_flax_auto import FlaxAutoModel, FlaxAutoModelForCausalLM
from .configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "SpeechEncoderDecoderConfig"
SPEECH_ENCODER_DECODER_START_DOCSTRING = r"""
This class can be used to initialize a speech-sequence-to-text-sequence model with any pretrained speech
autoencoding model as the encoder and any pretrained text autoregressive model as the decoder. The encoder is
loaded via the [`~AutoModel.from_pretrained`] function and the decoder is loaded via the
[`~AutoModelForCausalLM.from_pretrained`] function. Cross-attention layers are automatically added to the decoder
and should be fine-tuned on a downstream generative task, like automatic speech recognition or speech translation.
The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, and Aliaksei Severyn.
Additionally, in [Large-Scale Self- and Semi-Supervised Learning for Speech
Translation](https://arxiv.org/abs/2104.06678) it is shown how leveraging large pretrained speech models for speech
translation yields a significant performance improvement.
After such a Speech-Encoder-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other
model (see the examples for more information).
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
etc.).
This model is also a Flax Linen
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
Parameters:
config ([`SpeechEncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given `dtype`.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
[`~FlaxPreTrainedModel.to_bf16`].
"""
SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
Args:
inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac*
or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile
library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or
[`Speech2TextProcessor`] should be used for padding and conversion into an array of type
*numpy.ndarray*.
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
For training, `decoder_input_ids` must be provided; this Flax model does not automatically create them by
shifting the `labels` to the right.
decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default.
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range `[0, config.decoder.max_position_embeddings - 1]`.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
If set to `True`, the model will return a [`~file_utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple.
"""
SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r"""
Args:
inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac*
or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile
library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or
[`Speech2TextProcessor`] should be used for padding and conversion into an array of type
*numpy.ndarray*.
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
If set to `True`, the model will return a [`~file_utils.FlaxBaseModelOutput`] instead of a plain tuple.
"""
SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
Args:
decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are decoder input IDs?](../glossary#decoder-input-ids)
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
For sequence-to-sequence training, `decoder_input_ids` should be provided; this model does not create
them automatically by shifting the `input_ids` to the right.
encoder_outputs (`tuple(tuple(jnp.ndarray))`):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default.
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range `[0, config.decoder.max_position_embeddings - 1]`.
past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
If set to `True`, the model will return a [`~file_utils.FlaxCausalLMOutputWithCrossAttentions`] instead of
a plain tuple.
"""
class FlaxSpeechEncoderDecoderModule(nn.Module):
config: SpeechEncoderDecoderConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
encoder_config = self.config.encoder
decoder_config = self.config.decoder
# Copied from `modeling_hybrid_clip.py` with modifications.
from ...models.auto.modeling_flax_auto import FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, FLAX_MODEL_MAPPING
encoder_module = FLAX_MODEL_MAPPING[encoder_config.__class__].module_class
decoder_module = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING[decoder_config.__class__].module_class
self.encoder = encoder_module(encoder_config, dtype=self.dtype)
self.decoder = decoder_module(decoder_config, dtype=self.dtype)
# encoder outputs might need to be projected to different dimension for decoder
if (
self.encoder.config.hidden_size != self.decoder.config.hidden_size
and self.decoder.config.cross_attention_hidden_size is None
):
self.enc_to_dec_proj = nn.Dense(
self.decoder.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range),
dtype=self.dtype,
)
else:
self.enc_to_dec_proj = None
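# e.g. pairing a wav2vec2-large encoder (hidden_size 1024) with a gpt2 decoder (hidden_size 768)
# creates a 1024 -> 768 Dense projection above; sizes are illustrative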
def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
"""
Computes the output length of the convolutional layers
"""
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return (input_length - kernel_size) // stride + 1
for kernel_size, stride in zip(self.config.encoder.conv_kernel, self.config.encoder.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
return input_lengths
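# Worked example (Wav2Vec2 defaults: kernels (10, 3, 3, 3, 3, 2, 2), strides (5, 2, 2, 2, 2, 2, 2)):
# an input of 1024 samples shrinks layer by layer to 203 -> 101 -> 50 -> 24 -> 11 -> 5 -> 2 frames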
def _get_encoder_module(self):
return self.encoder
def _get_projection_module(self):
return self.enc_to_dec_proj
def _get_decoder_module(self):
return self.decoder
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder of the speech encoder so
that its parameters will not be updated during training.
"""
self.encoder.freeze_feature_encoder()
def __call__(
self,
inputs,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
decoder_position_ids,
encoder_outputs=None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
if encoder_outputs is None:
encoder_outputs = self.encoder(
inputs,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
encoder_hidden_states = encoder_outputs[0]
# optionally project encoder_hidden_states
if self.enc_to_dec_proj is not None:
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
# compute correct encoder attention mask
if attention_mask is not None:
encoder_attention_mask = self.encoder._get_feature_vector_attention_mask(
encoder_hidden_states.shape[1], attention_mask
)
else:
encoder_attention_mask = None
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
position_ids=decoder_position_ids,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
if not return_dict:
return decoder_outputs + encoder_outputs
return FlaxSeq2SeqLMOutput(
logits=decoder_outputs.logits,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
@add_start_docstrings(SPEECH_ENCODER_DECODER_START_DOCSTRING)
class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
r"""
[`FlaxSpeechEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture
with the module (flax.nn.Module) of one of the base model classes of the library as encoder module and another one
as decoder module when created with the [`~FlaxAutoModel.from_pretrained`] class method for the
encoder and the [`~FlaxAutoModelForCausalLM.from_pretrained`] class method for the decoder.
"""
config_class = SpeechEncoderDecoderConfig
base_model_prefix: str = "speech_encoder_decoder"
module_class = FlaxSpeechEncoderDecoderModule
def __init__(
self,
config: SpeechEncoderDecoderConfig,
input_shape: Optional[Tuple] = None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
**kwargs
):
if config.decoder.cross_attention_hidden_size is not None:
# Raise ValueError or option to project enc to dec hidden_size (eg EncAdapterLayer)
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
"it has to be equal to the encoder's `hidden_size`. "
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
)
module = self.module_class(config=config, dtype=dtype, **kwargs)
if input_shape is None:
# speech encoders almost always downsample the sequence length dimension
encoder_input_length = 1024
decoder_input_length = module._get_feat_extract_output_lengths(encoder_input_length)
input_shape = ((1, encoder_input_length), (1, decoder_input_length))
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
encoder_input_shape, decoder_input_shape = input_shape
# init input DeviceArrays
inputs = jnp.zeros(encoder_input_shape, dtype="f4")  # raw speech inputs are float values
attention_mask = jnp.ones_like(inputs, dtype="i4")
decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
batch_size, sequence_length = inputs.shape
decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
if decoder_batch_size != batch_size:
raise ValueError(
f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder and {decoder_batch_size} for decoder."
)
decoder_position_ids = jnp.broadcast_to(
jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length)
)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(
rngs,
inputs,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
decoder_position_ids,
)["params"]
def init_cache(self, batch_size, max_length, encoder_outputs):
r"""
Args:
batch_size (`int`):
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
max_length (`int`):
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
cache.
encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray))]`):
`encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
`attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`
is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
cross-attention of the decoder.
"""
# init input variables to retrieve cache
decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
decoder_position_ids = jnp.broadcast_to(
jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
)
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
decoder_module = module._get_decoder_module()
return decoder_module(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
position_ids=decoder_position_ids,
**kwargs,
)
init_variables = self.module.init(
jax.random.PRNGKey(0),
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
init_cache=True,
method=_decoder_forward, # we only need to call the decoder to init the cache
)
return unfreeze(init_variables["cache"])
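# Usage sketch: the returned cache is fed back in as `past_key_values` during incremental
# decoding (note that `decoder_position_ids` must then be provided explicitly), e.g.
#     encoder_outputs = model.encode(inputs)
#     past_key_values = model.init_cache(batch_size, max_length, encoder_outputs)
#     outputs = model.decode(decoder_input_ids, encoder_outputs,
#                            past_key_values=past_key_values, decoder_position_ids=position_ids)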
def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
return self.module._get_feat_extract_output_lengths(input_lengths)
@add_start_docstrings(SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)
def encode(
self,
inputs: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
r"""
Returns:
Example:
```python
>>> from transformers import FlaxSpeechEncoderDecoderModel
>>> import jax.numpy as jnp
>>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
>>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
... "facebook/wav2vec2-large-lv60", "facebook/bart-large"
... )
>>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
>>> encoder_outputs = model.encode(inputs)
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
if attention_mask is None:
attention_mask = jnp.ones_like(inputs)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
def _encoder_forward(module, inputs, attention_mask, **kwargs):
encode_module = module._get_encoder_module()
return encode_module(inputs, attention_mask, **kwargs)
outputs = self.module.apply(
{"params": params or self.params},
inputs=jnp.array(inputs, dtype="f4"),
attention_mask=jnp.array(attention_mask, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
method=_encoder_forward,
)
if return_dict:
outputs = FlaxBaseModelOutput(
last_hidden_state=outputs.last_hidden_state,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
return outputs
@add_start_docstrings(SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def decode(
self,
decoder_input_ids,
encoder_outputs,
encoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
past_key_values: dict = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
r"""
Returns:
Example:
```python
>>> from transformers import FlaxSpeechEncoderDecoderModel
>>> import jax.numpy as jnp
>>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
>>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
... "facebook/wav2vec2-large-lv60", "facebook/bart-large"
... )
>>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
>>> encoder_outputs = model.encode(inputs)
>>> decoder_start_token_id = model.config.decoder.bos_token_id
>>> decoder_input_ids = jnp.ones((inputs.shape[0], 1), dtype="i4") * decoder_start_token_id
>>> outputs = model.decode(decoder_input_ids, encoder_outputs)
>>> logits = outputs.logits
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
encoder_hidden_states = encoder_outputs[0]
if encoder_attention_mask is None:
batch_size, sequence_length = encoder_hidden_states.shape[:2]
encoder_attention_mask = jnp.ones((batch_size, sequence_length))
batch_size, sequence_length = decoder_input_ids.shape
if decoder_attention_mask is None:
decoder_attention_mask = jnp.ones((batch_size, sequence_length))
if decoder_position_ids is None:
if past_key_values is not None:
raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
decoder_position_ids = jnp.broadcast_to(
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
params = {"params": params or self.params}
# if past_key_values are passed then the cache is already initialized; a private flag init_cache has to be
# passed down to ensure the cache is used. It has to be made sure that the cache is marked as mutable so that
# it can be changed by FlaxBartAttention module
if past_key_values:
params["cache"] = past_key_values
mutable = ["cache"]
else:
mutable = False
def _decoder_forward(
module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs
):
projection_module = module._get_projection_module()
decoder_module = module._get_decoder_module()
# optionally project encoder_hidden_states
if projection_module is not None:
encoder_hidden_states = projection_module(encoder_hidden_states)
return decoder_module(
decoder_input_ids,
decoder_attention_mask,
decoder_position_ids,
encoder_hidden_states,
**kwargs,
)
outputs = self.module.apply(
params,
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
mutable=mutable,
method=_decoder_forward,
)
# add updated cache to model output
if past_key_values is not None and return_dict:
outputs, past = outputs
outputs["past_key_values"] = unfreeze(past["cache"])
return outputs
elif past_key_values is not None and not return_dict:
outputs, past = outputs
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
return outputs
@add_start_docstrings_to_model_forward(SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def __call__(
self,
inputs: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
decoder_input_ids: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
r"""
Returns:
Examples:
```python
>>> from transformers import FlaxSpeechEncoderDecoderModel, BartTokenizer
>>> import jax.numpy as jnp
>>> # load a fine-tuned wav2vec2-2-bart model
>>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("patrickvonplaten/wav2vec2-2-bart-large")
>>> # load output tokenizer
>>> tokenizer_output = BartTokenizer.from_pretrained("facebook/bart-large")
>>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
>>> # use bart's special bos, pad and eos tokens
>>> model.config.decoder_start_token_id = model.config.decoder.bos_token_id
>>> model.config.pad_token_id = model.config.decoder.pad_token_id
>>> model.config.eos_token_id = model.config.decoder.eos_token_id
>>> outputs = model.generate(inputs)
```
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
# prepare encoder inputs
if attention_mask is None:
attention_mask = jnp.ones_like(inputs)
# prepare decoder inputs
if decoder_attention_mask is None:
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
if decoder_position_ids is None:
batch_size, sequence_length = decoder_input_ids.shape
decoder_position_ids = jnp.broadcast_to(
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
)
# Handle any PRNG if needed
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
return self.module.apply(
{"params": params or self.params},
inputs=jnp.array(inputs, dtype="f4"),
attention_mask=jnp.array(attention_mask, dtype="i4"),
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
)
def prepare_inputs_for_generation(
self,
decoder_input_ids,
max_length,
attention_mask: Optional[jnp.DeviceArray] = None,
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
encoder_outputs=None,
**kwargs
):
# initializing the cache
batch_size, seq_length = decoder_input_ids.shape
past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
# Note that usually one would have to put 0's in the attention_mask for x > input.shape[-1] and x < cache_length.
# But since the decoder uses a causal mask, those positions are masked anyways.
# Thus we can create a single static attention_mask here, which is more efficient for compilation
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if decoder_attention_mask is not None:
decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
else:
decoder_position_ids = jnp.broadcast_to(
jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
)
return {
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"encoder_attention_mask": attention_mask,
"decoder_attention_mask": extended_attention_mask,
"decoder_position_ids": decoder_position_ids,
}
def update_inputs_for_generation(self, model_outputs, model_kwargs):
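# each generation step feeds a single new token to `decode`, so the cached decoder
# position ids advance by one from the final position of the previous step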
model_kwargs["past_key_values"] = model_outputs.past_key_values
model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
return model_kwargs
@classmethod
def from_encoder_decoder_pretrained(
cls,
encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
*model_args,
**kwargs
) -> FlaxPreTrainedModel:
r"""
Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
checkpoints.
Params:
encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*):
Information necessary to initiate the encoder. Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
user or organization name, like `dbmdz/bert-base-german-cased`.
- A path to a *directory* containing model weights saved using
[`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`):
Information necessary to initiate the decoder. Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
user or organization name, like `dbmdz/bert-base-german-cased`.
- A path to a *directory* containing model weights saved using
[`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
model_args (remaining positional arguments, *optional*):
All remaining positional arguments will be passed to the underlying model's `__init__` method.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g.,
`output_attentions=True`).
- To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
- To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
- To update the parent model configuration, do not use a prefix for each configuration parameter.
Behaves differently depending on whether a `config` is provided or automatically loaded.
Example:
```python
>>> from transformers import FlaxSpeechEncoderDecoderModel
>>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
>>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
... "facebook/wav2vec2-large-lv60", "facebook/bart-large"
... )
>>> # saving model after fine-tuning
>>> model.save_pretrained("./wav2vec2-2-bart-large")
>>> # load fine-tuned model
>>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("./wav2vec2-2-bart-large")
```"""
kwargs_encoder = {
argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
}
kwargs_decoder = {
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
}
# remove encoder, decoder kwargs from kwargs
for key in kwargs_encoder.keys():
del kwargs["encoder_" + key]
for key in kwargs_decoder.keys():
del kwargs["decoder_" + key]
# Load and initialize the encoder and decoder
# The distinction between encoder and decoder at the model level is made
# by the value of the flag `is_decoder` that we need to set correctly.
encoder = kwargs_encoder.pop("model", None)
if encoder is None:
if encoder_pretrained_model_name_or_path is None:
raise ValueError(
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
"to be defined."
)
if "config" not in kwargs_encoder:
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
logger.info(
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
"from a decoder model. Cross-attention and casual mask are disabled."
)
encoder_config.is_decoder = False
encoder_config.add_cross_attention = False
kwargs_encoder["config"] = encoder_config
encoder = FlaxAutoModel.from_pretrained(
encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
)
decoder = kwargs_decoder.pop("model", None)
if decoder is None:
if decoder_pretrained_model_name_or_path is None:
raise ValueError(
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
"to be defined."
)
if "config" not in kwargs_decoder:
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
"cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
kwargs_decoder["config"] = decoder_config
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
logger.warning(
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
)
decoder = FlaxAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
# instantiate config with corresponding kwargs
dtype = kwargs.pop("dtype", jnp.float32)
config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
# init model
model = cls(config, dtype=dtype)
model.params["encoder"] = encoder.params
model.params["decoder"] = decoder.params
return model
@@ -1012,6 +1012,26 @@ class FlaxWav2Vec2Module(nn.Module):
return input_lengths
def _get_feature_vector_attention_mask(
self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
):
# Effectively attention_mask.sum(-1), but not inplace to be able to run
# in inference mode.
non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
batch_size = attention_mask.shape[0]
attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
# these two operations make sure that all values before the output length indices are attended to
attention_mask = attention_mask.at[(jnp.arange(attention_mask.shape[0]), output_lengths - 1)].set(1)
attention_mask = jnp.flip(jnp.flip(attention_mask, axis=-1).cumsum(axis=-1), axis=-1)
attention_mask = jnp.array(attention_mask, dtype=bool)
return attention_mask
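# Worked example (values illustrative): with feature_vector_length=4 and a conv output
# length of 2, the scatter gives [0, 1, 0, 0]; flip -> [0, 0, 1, 0]; cumsum -> [0, 0, 1, 1];
# flip back -> [1, 1, 0, 0], i.e. ones for every frame up to and including the last valid one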
@add_start_docstrings(
"The bare Wav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.",
@@ -879,6 +879,13 @@ class FlaxRoFormerPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxSpeechEncoderDecoderModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxT5ForConditionalGeneration(metaclass=DummyObject):
_backends = ["flax"]
# coding=utf-8
# Copyright 2022 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import numpy as np
from transformers import is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow, torch_device
from ..gpt2.test_modeling_flax_gpt2 import FlaxGPT2ModelTester
from ..test_modeling_flax_common import floats_tensor, ids_tensor, random_attention_mask
from ..wav2vec2.test_modeling_flax_wav2vec2 import FlaxWav2Vec2ModelTester
if is_flax_available():
from transformers import (
FlaxGPT2LMHeadModel,
FlaxSpeechEncoderDecoderModel,
FlaxWav2Vec2Model,
SpeechEncoderDecoderConfig,
)
from transformers.modeling_flax_outputs import FlaxBaseModelOutput
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
if is_torch_available():
import torch
from transformers import SpeechEncoderDecoderModel
@require_flax
class FlaxEncoderDecoderMixin:
def get_encoder_decoder_model(self, config, decoder_config):
raise NotImplementedError
def prepare_config_and_inputs(self):
raise NotImplementedError
def get_pretrained_model_and_inputs(self):
raise NotImplementedError
def check_encoder_decoder_model_from_pretrained_configs(
self,
config,
inputs,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
self.assertTrue(encoder_decoder_config.decoder.is_decoder)
enc_dec_model = FlaxSpeechEncoderDecoderModel(encoder_decoder_config)
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
outputs_encoder_decoder = enc_dec_model(
inputs=inputs,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
)
def check_encoder_decoder_model(
self,
config,
inputs,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
encoder_model=encoder_model, decoder_model=decoder_model
)
self.assertTrue(enc_dec_model.config.decoder.is_decoder)
self.assertTrue(enc_dec_model.config.decoder.add_cross_attention)
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
outputs_encoder_decoder = enc_dec_model(
inputs=inputs,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_hidden_states=True,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
)
encoder_outputs = FlaxBaseModelOutput(last_hidden_state=outputs_encoder_decoder.encoder_hidden_states[-1])
outputs_encoder_decoder = enc_dec_model(
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
)
def check_encoder_decoder_model_from_pretrained(
self,
config,
inputs,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
return_dict,
**kwargs
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model, "return_dict": return_dict}
enc_dec_model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
outputs_encoder_decoder = enc_dec_model(
inputs=inputs,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_hidden_states=True,
return_dict=True,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
)
def check_save_and_load(
self,
config,
inputs,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
enc_dec_model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
outputs = enc_dec_model(
inputs=inputs,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
)
out_2 = np.array(outputs[0])
out_2[np.isnan(out_2)] = 0
with tempfile.TemporaryDirectory() as tmpdirname:
enc_dec_model.save_pretrained(tmpdirname)
enc_dec_model = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname)
after_outputs = enc_dec_model(
inputs=inputs,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
)
out_1 = np.array(after_outputs[0])
out_1[np.isnan(out_1)] = 0
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 4e-2)
def check_encoder_decoder_model_output_attentions(
self,
config,
inputs,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
enc_dec_model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
outputs_encoder_decoder = enc_dec_model(
inputs=inputs,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
)
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
seq_len = enc_dec_model._get_feat_extract_output_lengths(inputs.shape[1])
self.assertEqual(encoder_attentions[0].shape[-3:], (config.num_attention_heads, seq_len, seq_len))
decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
num_decoder_layers = (
decoder_config.num_decoder_layers
if hasattr(decoder_config, "num_decoder_layers")
else decoder_config.num_hidden_layers
)
self.assertEqual(len(decoder_attentions), num_decoder_layers)
self.assertEqual(
decoder_attentions[0].shape[-3:],
(decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
)
cross_attentions = outputs_encoder_decoder["cross_attentions"]
self.assertEqual(len(cross_attentions), num_decoder_layers)
cross_attention_input_seq_len = decoder_input_ids.shape[-1]
self.assertEqual(
cross_attentions[0].shape[-3:],
(decoder_config.num_attention_heads, cross_attention_input_seq_len, seq_len),
)
def check_encoder_decoder_model_generate(self, inputs, config, decoder_config, **kwargs):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
enc_dec_model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
pad_token_id = enc_dec_model.config.decoder.pad_token_id
eos_token_id = enc_dec_model.config.decoder.eos_token_id
decoder_start_token_id = enc_dec_model.config.decoder.decoder_start_token_id
# Copied from generation_utils (GPT2 doesn't have `pad_token_id`)
if pad_token_id is None and eos_token_id is not None:
pad_token_id = eos_token_id
if decoder_start_token_id is None:
decoder_start_token_id = enc_dec_model.config.decoder.bos_token_id
# Bert does not have a bos token id, so use pad_token_id instead
# Copied from `test_modeling_encoder_decoder.py`
if decoder_start_token_id is None:
decoder_start_token_id = pad_token_id
generated_output = enc_dec_model.generate(
inputs,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
)
generated_sequences = generated_output.sequences
self.assertEqual(generated_sequences.shape, (inputs.shape[0],) + (decoder_config.max_length,))
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
pt_model.to(torch_device)
pt_model.eval()
# prepare inputs
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
pt_logits = pt_outputs.logits
pt_outputs = pt_outputs.to_tuple()
fx_outputs = fx_model(**inputs_dict)
fx_logits = fx_outputs.logits
fx_outputs = fx_outputs.to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits, pt_logits.numpy(), 4e-2)
# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**inputs_dict)
fx_logits_loaded = fx_outputs_loaded.logits
fx_outputs_loaded = fx_outputs_loaded.to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits_loaded, pt_logits.numpy(), 4e-2)
# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
pt_logits_loaded = pt_outputs_loaded.logits
pt_outputs_loaded = pt_outputs_loaded.to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
pt_model = SpeechEncoderDecoderModel(encoder_decoder_config)
fx_model = FlaxSpeechEncoderDecoderModel(encoder_decoder_config)
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
def check_equivalence_flax_to_pt(self, config, decoder_config, inputs_dict):
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
pt_model = SpeechEncoderDecoderModel(encoder_decoder_config)
fx_model = FlaxSpeechEncoderDecoderModel(encoder_decoder_config)
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
def test_encoder_decoder_model_from_pretrained_configs(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict)
def test_encoder_decoder_model_from_pretrained(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=False)
def test_encoder_decoder_model_from_pretrained_return_dict(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=True)
def test_save_and_load_from_pretrained(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_save_and_load(**input_ids_dict)
def test_encoder_decoder_model_output_attentions(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
def test_encoder_decoder_model_generate(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_generate(**input_ids_dict)
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
@is_pt_flax_cross_test
def test_pt_flax_equivalence(self):
config_inputs_dict = self.prepare_config_and_inputs()
config = config_inputs_dict.pop("config")
decoder_config = config_inputs_dict.pop("decoder_config")
inputs_dict = config_inputs_dict
# `encoder_hidden_states` is not used in model call/forward
del inputs_dict["encoder_hidden_states"]
# Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
inputs_dict["decoder_attention_mask"] = np.concatenate(
[np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1
)
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
decoder_config.use_cache = False
self.assertTrue(decoder_config.cross_attention_hidden_size is None)
# check without `enc_to_dec_proj` projection
decoder_config.hidden_size = config.hidden_size
self.assertTrue(config.hidden_size == decoder_config.hidden_size)
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
# check `enc_to_dec_proj` work as expected
decoder_config.hidden_size = decoder_config.hidden_size * 2
self.assertTrue(config.hidden_size != decoder_config.hidden_size)
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
@slow
def test_real_model_save_load_from_pretrained(self):
model_2, inputs = self.get_pretrained_model_and_inputs()
outputs = model_2(**inputs)
out_2 = np.array(outputs[0])
out_2[np.isnan(out_2)] = 0
with tempfile.TemporaryDirectory() as tmp_dirname:
model_2.save_pretrained(tmp_dirname)
model_1 = FlaxSpeechEncoderDecoderModel.from_pretrained(tmp_dirname)
after_outputs = model_1(**inputs)
out_1 = np.array(after_outputs[0])
out_1[np.isnan(out_1)] = 0
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 4e-2)
@require_flax
class FlaxWav2Vec2GPT2ModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
def get_pretrained_model_and_inputs(self):
model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
"facebook/wav2vec2-large-lv60", "gpt2-medium"
)
batch_size = 13
input_values = floats_tensor([batch_size, 512], model.config.encoder.vocab_size)
attention_mask = random_attention_mask([batch_size, 512])
decoder_input_ids = ids_tensor([batch_size, 4], model.config.decoder.vocab_size)
decoder_attention_mask = random_attention_mask([batch_size, 4])
inputs = {
"inputs": input_values,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
}
return model, inputs
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = FlaxWav2Vec2Model(config)
decoder_model = FlaxGPT2LMHeadModel(decoder_config)
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
model_tester_encoder = FlaxWav2Vec2ModelTester(self, batch_size=13)
model_tester_decoder = FlaxGPT2ModelTester(self, batch_size=13)
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
(config, inputs, attention_mask) = encoder_config_and_inputs
(
decoder_config,
decoder_input_ids,
decoder_attention_mask,
encoder_hidden_states,
encoder_attention_mask,
) = decoder_config_and_inputs
# make sure that cross attention layers are added
decoder_config.add_cross_attention = True
return {
"config": config,
"inputs": inputs,
"attention_mask": attention_mask,
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"encoder_hidden_states": encoder_hidden_states,
}
@slow
def test_flaxwav2vec2gpt2_pt_flax_equivalence(self):
pt_model = SpeechEncoderDecoderModel.from_pretrained("jsnfly/wav2vec2-large-xlsr-53-german-gpt2")
fx_model = FlaxSpeechEncoderDecoderModel.from_pretrained(
"jsnfly/wav2vec2-large-xlsr-53-german-gpt2", from_pt=True
)
pt_model.to(torch_device)
pt_model.eval()
# prepare inputs
batch_size = 13
input_values = floats_tensor([batch_size, 512], fx_model.config.encoder.vocab_size)
attention_mask = random_attention_mask([batch_size, 512])
decoder_input_ids = ids_tensor([batch_size, 4], fx_model.config.decoder.vocab_size)
decoder_attention_mask = random_attention_mask([batch_size, 4])
inputs_dict = {
"inputs": input_values,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
}
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
pt_logits = pt_outputs.logits
pt_outputs = pt_outputs.to_tuple()
fx_outputs = fx_model(**inputs_dict)
fx_logits = fx_outputs.logits
fx_outputs = fx_outputs.to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits, pt_logits.numpy(), 4e-2)
# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**inputs_dict)
fx_logits_loaded = fx_outputs_loaded.logits
fx_outputs_loaded = fx_outputs_loaded.to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits_loaded, pt_logits.numpy(), 4e-2)
# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
pt_logits_loaded = pt_outputs_loaded.logits
pt_outputs_loaded = pt_outputs_loaded.to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
@@ -215,6 +215,7 @@ def get_model_modules():
"modeling_flax_encoder_decoder",
"modeling_flax_utils",
"modeling_speech_encoder_decoder",
"modeling_flax_speech_encoder_decoder",
"modeling_flax_vision_encoder_decoder",
"modeling_transfo_xl_utilities",
"modeling_tf_auto",
@@ -290,6 +291,7 @@ def get_model_test_files():
"test_modeling_common",
"test_modeling_encoder_decoder",
"test_modeling_flax_encoder_decoder",
"test_modeling_flax_speech_encoder_decoder",
"test_modeling_marian",
"test_modeling_tf_common",
"test_modeling_tf_encoder_decoder",