"vscode:/vscode.git/clone" did not exist on "eef0507f3daca400f3021a25a8a5d399cea45338"
Unverified Commit e3342edc authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

Flax Speech-Encoder-Decoder Model (#15613)



* rebase

* Delete shift tokens func

* downsample decoder input seq len for init

* correct attention mask

* add tests

* pt flax cross test

* make fixup

* init file for import

* change pt-flax cross test threshold

* pt-flax test logits only

* move tests

* make repo-consistency

* consistent indentation
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 935a76d9
......@@ -230,7 +230,7 @@ Flax), PyTorch, and/or TensorFlow.
| SegFormer | ❌ | ❌ | ✅ | ❌ | ❌ |
| SEW | ❌ | ❌ | ✅ | ❌ | ❌ |
| SEW-D | ❌ | ❌ | ✅ | ❌ | ❌ |
| Speech Encoder decoder | ❌ | ❌ | ✅ | ❌ | |
| Speech Encoder decoder | ❌ | ❌ | ✅ | ❌ | |
| Speech2Text | ✅ | ❌ | ✅ | ✅ | ❌ |
| Speech2Text2 | ✅ | ❌ | ❌ | ❌ | ❌ |
| Splinter | ✅ | ✅ | ✅ | ❌ | ❌ |
......
......@@ -33,3 +33,9 @@ An example of how to use a [`SpeechEncoderDecoderModel`] for inference can be se
[[autodoc]] SpeechEncoderDecoderModel
- forward
- from_encoder_decoder_pretrained
## FlaxSpeechEncoderDecoderModel
[[autodoc]] FlaxSpeechEncoderDecoderModel
- __call__
- from_encoder_decoder_pretrained
\ No newline at end of file
......@@ -2295,6 +2295,7 @@ if is_flax_available():
"FlaxRoFormerPreTrainedModel",
]
)
_import_structure["models.speech_encoder_decoder"].append("FlaxSpeechEncoderDecoderModel")
_import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
_import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel")
_import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"])
......@@ -4183,6 +4184,7 @@ if TYPE_CHECKING:
FlaxRoFormerModel,
FlaxRoFormerPreTrainedModel,
)
from .models.speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel
from .models.vision_text_dual_encoder import FlaxVisionTextDualEncoderModel
......
......@@ -188,6 +188,12 @@ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
]
)
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("speech-encoder-decoder", "FlaxSpeechEncoderDecoderModel"),
]
)
FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES)
FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
......@@ -215,6 +221,9 @@ FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
)
class FlaxAutoModel(_BaseAutoModelClass):
......@@ -309,3 +318,12 @@ class FlaxAutoModelForVision2Seq(_BaseAutoModelClass):
FlaxAutoModelForVision2Seq = auto_class_update(FlaxAutoModelForVision2Seq, head_doc="vision-to-text modeling")
class FlaxAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
FlaxAutoModelForSpeechSeq2Seq = auto_class_update(
FlaxAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)
......@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_torch_available
from ...file_utils import _LazyModule, is_flax_available, is_torch_available
_import_structure = {
......@@ -28,12 +28,18 @@ _import_structure = {
if is_torch_available():
_import_structure["modeling_speech_encoder_decoder"] = ["SpeechEncoderDecoderModel"]
if is_flax_available():
_import_structure["modeling_flax_speech_encoder_decoder"] = ["FlaxSpeechEncoderDecoderModel"]
if TYPE_CHECKING:
from .configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig
if is_torch_available():
from .modeling_speech_encoder_decoder import SpeechEncoderDecoderModel
if is_flax_available():
from .modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
else:
import sys
......
......@@ -1012,6 +1012,26 @@ class FlaxWav2Vec2Module(nn.Module):
return input_lengths
def _get_feature_vector_attention_mask(
self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
):
# Effectively attention_mask.sum(-1), but not inplace to be able to run
# on inference mode.
non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
batch_size = attention_mask.shape[0]
attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
# these two operations makes sure that all values before the output lengths idxs are attended to
attention_mask = attention_mask.at[(jnp.arange(attention_mask.shape[0]), output_lengths - 1)].set(1)
attention_mask = jnp.flip(jnp.flip(attention_mask, axis=-1).cumsum(axis=-1), axis=-1)
attention_mask = jnp.array(attention_mask, dtype=bool)
return attention_mask
@add_start_docstrings(
"The bare Wav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.",
......
......@@ -879,6 +879,13 @@ class FlaxRoFormerPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxSpeechEncoderDecoderModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxT5ForConditionalGeneration(metaclass=DummyObject):
_backends = ["flax"]
......
......@@ -215,6 +215,7 @@ def get_model_modules():
"modeling_flax_encoder_decoder",
"modeling_flax_utils",
"modeling_speech_encoder_decoder",
"modeling_flax_speech_encoder_decoder",
"modeling_flax_vision_encoder_decoder",
"modeling_transfo_xl_utilities",
"modeling_tf_auto",
......@@ -290,6 +291,7 @@ def get_model_test_files():
"test_modeling_common",
"test_modeling_encoder_decoder",
"test_modeling_flax_encoder_decoder",
"test_modeling_flax_speech_encoder_decoder",
"test_modeling_marian",
"test_modeling_tf_common",
"test_modeling_tf_encoder_decoder",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment