Unverified Commit e3342edc authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

Flax Speech-Encoder-Decoder Model (#15613)



* rebase

* Delete shift tokens func

* downsample decoder input seq len for init

* correct attention mask

* add tests

* pt flax cross test

* make fixup

* init file for import

* change pt-flax cross test threshold

* pt-flax test logits only

* move tests

* make repo-consistency

* consistent indentation
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 935a76d9
...@@ -230,7 +230,7 @@ Flax), PyTorch, and/or TensorFlow. ...@@ -230,7 +230,7 @@ Flax), PyTorch, and/or TensorFlow.
| SegFormer | ❌ | ❌ | ✅ | ❌ | ❌ | | SegFormer | ❌ | ❌ | ✅ | ❌ | ❌ |
| SEW | ❌ | ❌ | ✅ | ❌ | ❌ | | SEW | ❌ | ❌ | ✅ | ❌ | ❌ |
| SEW-D | ❌ | ❌ | ✅ | ❌ | ❌ | | SEW-D | ❌ | ❌ | ✅ | ❌ | ❌ |
| Speech Encoder decoder | ❌ | ❌ | ✅ | ❌ | | | Speech Encoder decoder | ❌ | ❌ | ✅ | ❌ | |
| Speech2Text | ✅ | ❌ | ✅ | ✅ | ❌ | | Speech2Text | ✅ | ❌ | ✅ | ✅ | ❌ |
| Speech2Text2 | ✅ | ❌ | ❌ | ❌ | ❌ | | Speech2Text2 | ✅ | ❌ | ❌ | ❌ | ❌ |
| Splinter | ✅ | ✅ | ✅ | ❌ | ❌ | | Splinter | ✅ | ✅ | ✅ | ❌ | ❌ |
......
...@@ -33,3 +33,9 @@ An example of how to use a [`SpeechEncoderDecoderModel`] for inference can be se ...@@ -33,3 +33,9 @@ An example of how to use a [`SpeechEncoderDecoderModel`] for inference can be se
[[autodoc]] SpeechEncoderDecoderModel [[autodoc]] SpeechEncoderDecoderModel
- forward - forward
- from_encoder_decoder_pretrained - from_encoder_decoder_pretrained
## FlaxSpeechEncoderDecoderModel
[[autodoc]] FlaxSpeechEncoderDecoderModel
- __call__
- from_encoder_decoder_pretrained
\ No newline at end of file
...@@ -2295,6 +2295,7 @@ if is_flax_available(): ...@@ -2295,6 +2295,7 @@ if is_flax_available():
"FlaxRoFormerPreTrainedModel", "FlaxRoFormerPreTrainedModel",
] ]
) )
_import_structure["models.speech_encoder_decoder"].append("FlaxSpeechEncoderDecoderModel")
_import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"]) _import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
_import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel") _import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel")
_import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"]) _import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"])
...@@ -4183,6 +4184,7 @@ if TYPE_CHECKING: ...@@ -4183,6 +4184,7 @@ if TYPE_CHECKING:
FlaxRoFormerModel, FlaxRoFormerModel,
FlaxRoFormerPreTrainedModel, FlaxRoFormerPreTrainedModel,
) )
from .models.speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel
from .models.vision_text_dual_encoder import FlaxVisionTextDualEncoderModel from .models.vision_text_dual_encoder import FlaxVisionTextDualEncoderModel
......
...@@ -188,6 +188,12 @@ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict( ...@@ -188,6 +188,12 @@ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
] ]
) )
# Registry of Flax speech sequence-to-sequence models, keyed by model type
# string as it appears in the config (mirrors the other FLAX_MODEL_FOR_*_NAMES
# OrderedDicts in this module). Values are class names resolved lazily.
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        ("speech-encoder-decoder", "FlaxSpeechEncoderDecoderModel"),
    ]
)
FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES) FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES)
FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES) FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
...@@ -215,6 +221,9 @@ FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping( ...@@ -215,6 +221,9 @@ FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping( FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
) )
# Lazy config-class -> model-class mapping for Flax speech seq2seq models;
# classes are only imported when the mapping entry is actually accessed,
# matching the pattern used by the other FLAX_MODEL_FOR_* mappings above.
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
)
class FlaxAutoModel(_BaseAutoModelClass): class FlaxAutoModel(_BaseAutoModelClass):
...@@ -309,3 +318,12 @@ class FlaxAutoModelForVision2Seq(_BaseAutoModelClass): ...@@ -309,3 +318,12 @@ class FlaxAutoModelForVision2Seq(_BaseAutoModelClass):
FlaxAutoModelForVision2Seq = auto_class_update(FlaxAutoModelForVision2Seq, head_doc="vision-to-text modeling") FlaxAutoModelForVision2Seq = auto_class_update(FlaxAutoModelForVision2Seq, head_doc="vision-to-text modeling")
class FlaxAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    # Auto class dispatching to the Flax speech seq2seq model that matches a
    # given config/checkpoint, via FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.
    _model_mapping = FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


# Rebind with generated docstrings/examples; auto_class_update returns the
# same class decorated with documentation built from head_doc.
FlaxAutoModelForSpeechSeq2Seq = auto_class_update(
    FlaxAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_torch_available from ...file_utils import _LazyModule, is_flax_available, is_torch_available
_import_structure = { _import_structure = {
...@@ -28,12 +28,18 @@ _import_structure = { ...@@ -28,12 +28,18 @@ _import_structure = {
if is_torch_available(): if is_torch_available():
_import_structure["modeling_speech_encoder_decoder"] = ["SpeechEncoderDecoderModel"] _import_structure["modeling_speech_encoder_decoder"] = ["SpeechEncoderDecoderModel"]
if is_flax_available():
_import_structure["modeling_flax_speech_encoder_decoder"] = ["FlaxSpeechEncoderDecoderModel"]
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig from .configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig
if is_torch_available(): if is_torch_available():
from .modeling_speech_encoder_decoder import SpeechEncoderDecoderModel from .modeling_speech_encoder_decoder import SpeechEncoderDecoderModel
if is_flax_available():
from .modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
else: else:
import sys import sys
......
...@@ -1012,6 +1012,26 @@ class FlaxWav2Vec2Module(nn.Module): ...@@ -1012,6 +1012,26 @@ class FlaxWav2Vec2Module(nn.Module):
return input_lengths return input_lengths
def _get_feature_vector_attention_mask(
    self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
):
    """Downsample an input-level attention mask to the feature-vector level.

    The convolutional feature extractor shortens the time axis, so the
    original (batch, input_length) mask must be converted into a
    (batch, feature_vector_length) boolean mask where, for each example,
    exactly the first `output_length` positions are True.

    Args:
        feature_vector_length: time length of the extracted feature sequence.
        attention_mask: input-level mask; assumed to be 1s followed by 0s
            (right padding) — TODO confirm, since only the total count is used.
        add_adapter: forwarded to `_get_feat_extract_output_lengths`.

    Returns:
        jnp.ndarray of bool, shape (batch, feature_vector_length).
    """
    # Effectively attention_mask.sum(-1), but not inplace to be able to run
    # on inference mode.
    non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]
    # Map input lengths through the feature extractor's downsampling formula.
    output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
    batch_size = attention_mask.shape[0]
    attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
    # these two operations makes sure that all values before the output lengths idxs are attended to
    # Place a 1 at the last valid position of each row (functional `.at` update,
    # since jax arrays are immutable). NOTE(review): an output length of 0 would
    # index -1 and wrap to the last position — presumably lengths are >= 1.
    attention_mask = attention_mask.at[(jnp.arange(attention_mask.shape[0]), output_lengths - 1)].set(1)
    # Reverse cumulative sum: turns the single trailing 1 into 1s over the
    # whole prefix up to and including that position.
    attention_mask = jnp.flip(jnp.flip(attention_mask, axis=-1).cumsum(axis=-1), axis=-1)
    attention_mask = jnp.array(attention_mask, dtype=bool)
    return attention_mask
@add_start_docstrings( @add_start_docstrings(
"The bare Wav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.", "The bare Wav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.",
......
...@@ -879,6 +879,13 @@ class FlaxRoFormerPreTrainedModel(metaclass=DummyObject): ...@@ -879,6 +879,13 @@ class FlaxRoFormerPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
class FlaxSpeechEncoderDecoderModel(metaclass=DummyObject):
    # Import-time placeholder used when Flax is not installed: any attempt to
    # instantiate it reports the missing "flax" backend via requires_backends,
    # matching the other DummyObject stubs in this module.
    _backends = ["flax"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])
class FlaxT5ForConditionalGeneration(metaclass=DummyObject): class FlaxT5ForConditionalGeneration(metaclass=DummyObject):
_backends = ["flax"] _backends = ["flax"]
......
...@@ -215,6 +215,7 @@ def get_model_modules(): ...@@ -215,6 +215,7 @@ def get_model_modules():
"modeling_flax_encoder_decoder", "modeling_flax_encoder_decoder",
"modeling_flax_utils", "modeling_flax_utils",
"modeling_speech_encoder_decoder", "modeling_speech_encoder_decoder",
"modeling_flax_speech_encoder_decoder",
"modeling_flax_vision_encoder_decoder", "modeling_flax_vision_encoder_decoder",
"modeling_transfo_xl_utilities", "modeling_transfo_xl_utilities",
"modeling_tf_auto", "modeling_tf_auto",
...@@ -290,6 +291,7 @@ def get_model_test_files(): ...@@ -290,6 +291,7 @@ def get_model_test_files():
"test_modeling_common", "test_modeling_common",
"test_modeling_encoder_decoder", "test_modeling_encoder_decoder",
"test_modeling_flax_encoder_decoder", "test_modeling_flax_encoder_decoder",
"test_modeling_flax_speech_encoder_decoder",
"test_modeling_marian", "test_modeling_marian",
"test_modeling_tf_common", "test_modeling_tf_common",
"test_modeling_tf_encoder_decoder", "test_modeling_tf_encoder_decoder",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment