Unverified Commit 0d1f67e6 authored by Patrick von Platen, committed by GitHub

[Flax] Add wav2vec2 (#12271)



* fix_torch_device_generate_test

* remove @

* start flax wav2vec2

* save intermediate

* forward pass has correct shape

* add weight norm

* add files

* finish ctc

* make style

* finish gumbel quantizer

* correct docstrings

* correct some more files

* fix vit

* finish quality

* correct tests

* correct docstring

* correct tests

* start wav2vec2 pretraining script

* save intermediate

* start pretraining script

* finalize pretraining script

* finish

* finish

* small typo

* finish

* correct

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* make style

* push
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent 3f36a2c0
...@@ -411,7 +411,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
...@@ -99,3 +99,23 @@ TFWav2Vec2ForCTC
.. autoclass:: transformers.TFWav2Vec2ForCTC
    :members: call

FlaxWav2Vec2Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxWav2Vec2Model
    :members: __call__

FlaxWav2Vec2ForCTC
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxWav2Vec2ForCTC
    :members: __call__

FlaxWav2Vec2ForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxWav2Vec2ForPreTraining
    :members: __call__
# Wav2Vec2 Contrastive Loss PreTraining examples
The following example showcases how to pretrain a wav2vec2 model using the JAX/Flax backend.
Pretraining Wav2Vec2 is rather complex, so it is highly recommended to read the
[official paper](https://arxiv.org/abs/2006.11477).
JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.
Models written in JAX/Flax are **immutable** and updated in a purely functional
way which enables simple and efficient model parallelism.
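As a minimal sketch of what this means in practice (the names below are illustrative and not part of the example script), a training step in JAX/Flax is a pure function that takes the current parameters and returns an updated copy:
```python
import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    # stand-in loss: mean squared error of a single linear layer
    preds = batch["inputs"] @ params["w"]
    return jnp.mean((preds - batch["targets"]) ** 2)


@jax.jit  # traced once and compiled to fused XLA code for GPU/TPU
def train_step(params, batch, learning_rate=1e-3):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # parameters are never mutated in place; an updated copy is returned instead
    new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return new_params, loss
```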
`run_wav2vec2_pretrain_flax.py` is a lightweight example of how to download and preprocess a dataset from the 🤗 Datasets library or use your own files (jsonlines or csv), then pretrain the wav2vec2 architectures above on it.
For custom datasets in `jsonlines` format please see [the Datasets documentation](https://huggingface.co/docs/datasets/loading_datasets.html#json-files); you will also find an example of this below.
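For instance, a local `jsonlines` file can be loaded with 🤗 Datasets as follows (a minimal sketch; the file name and columns are placeholders, not requirements of the script):
```python
from datasets import load_dataset

# every line of train.json is a JSON object, e.g. {"file": "path/to/audio.wav"}
raw_dataset = load_dataset("json", data_files={"train": "train.json"}, split="train")
print(raw_dataset.column_names)
```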
Let's start by creating a model repository to save the trained model and logs.
Here we call the model `"wav2vec2-base-robust"`, but you can change the model name as you like.
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
you are logged in) or via the command line:
```bash
huggingface-cli repo create wav2vec2-base-robust
```
Next we clone the model repository to add the tokenizer and model files.
```bash
git clone https://huggingface.co/<your-username>/wav2vec2-base-robust
```
To ensure that all tensorboard traces will be uploaded correctly, we need to
track them. You can run the following command inside your model repo to do so.
```bash
cd wav2vec2-base-robust
git lfs track "*tfevents*"
```
Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.
Next, let's add a symbolic link to the `run_wav2vec2_pretrain_flax.py` script.
```bash
export MODEL_DIR="./wav2vec2-base-robust"
ln -s ~/transformers/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py ./
```
### Create the model configuration
Let's first create the model configuration and store it in the model repository.
Note that many training parameters can be set in the model configuration, including
the masking distribution (`mask_time_length`, `mask_time_prob`),
dropout (`attention_dropout`, ...), the trade-off between the contrastive loss and
the diversity loss, etc...
Most likely you will need to change these parameters depending on your use case.
Again, we highly recommend reading the [official paper](https://arxiv.org/abs/2006.11477)
to better understand which parameters can be set for pretraining.
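As a rough sketch of that trade-off (the function below is only illustrative and not the script's API), the total pretraining objective combines the two terms like this:
```python
def total_pretraining_loss(contrastive_loss, diversity_loss, diversity_loss_weight=0.1):
    # a larger `diversity_loss_weight` pushes the quantizer to spread probability mass
    # over more codebook entries; a smaller one puts more weight on the contrastive term
    return contrastive_loss + diversity_loss_weight * diversity_loss
```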
For this example, we will use a `"base"`-sized Wav2Vec2 model with robust
layer norm and keep most of the default settings.
```python
model_dir="./wav2vec2-base-robust"
from transformers import Wav2Vec2Config
config = Wav2Vec2Config.from_pretrained(
"facebook/wav2vec2-base",
mask_time_length=10,
mask_time_prob=0.05,
diversity_loss_weight=0.1,
num_negatives=100,
do_stable_layer_norm=True,
feat_extract_norm="layer",
)
config.save_pretrained(model_dir)
```
### Create a feature extractor configuration
Before we can start the training, we need to define
a feature extractor that takes care of normalization, etc...
Here we can also re-use the feature extractor of [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) while making sure that padding is allowed.
```python
from transformers import Wav2Vec2FeatureExtractor

model_dir = "./wav2vec2-base-robust"
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base", return_attention_mask=True)
feature_extractor.save_pretrained(model_dir)
```
### Train the model
Finally, we can run the example script to train the model:
```bash
./run_wav2vec2_pretrain_flax.py \
--output_dir=${MODEL_DIR} \
--num_train_epochs="5" \
--per_device_train_batch_size="32" \
--per_device_eval_batch_size="32" \
--learning_rate="5e-4" \
--weight_decay="0.01" \
--warmup_steps="2000" \
--model_name_or_path=${MODEL_DIR} \
--dataset_name="librispeech_asr" \
--dataset_config_name="clean" \
--train_split_name="train.100" \
--preprocessing_num_workers="4" \
--max_duration_in_seconds="10.0" \
--adam_beta1="0.9" \
--adam_beta2="0.98" \
--push_to_hub
```
Note that this script is not fully tested yet, so we cannot ensure that
the above command leads to satisfactory results.
...@@ -1643,6 +1643,9 @@ if is_flax_available():
)
_import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
_import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"])
_import_structure["models.wav2vec2"].extend(
["FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2Model", "FlaxWav2Vec2PreTrainedModel"]
)
else:
from .utils import dummy_flax_objects
...@@ -3023,6 +3026,12 @@ if TYPE_CHECKING:
)
from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
from .models.vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
from .models.wav2vec2 import (
FlaxWav2Vec2ForCTC,
FlaxWav2Vec2ForPreTraining,
FlaxWav2Vec2Model,
FlaxWav2Vec2PreTrainedModel,
)
else:
# Import the same objects as dummies to get them in the namespace.
# They will raise an import error if the user tries to instantiate / use them.
...@@ -64,6 +64,7 @@ from ..roberta.modeling_flax_roberta import (
)
from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
from ..wav2vec2.modeling_flax_wav2vec2 import FlaxWav2Vec2ForPreTraining, FlaxWav2Vec2Model
from .auto_factory import auto_class_factory
from .configuration_auto import (
BartConfig,
...@@ -75,6 +76,7 @@ from .configuration_auto import (
RobertaConfig,
T5Config,
ViTConfig,
Wav2Vec2Config,
)
...@@ -93,6 +95,7 @@ FLAX_MODEL_MAPPING = OrderedDict(
(CLIPConfig, FlaxCLIPModel),
(ViTConfig, FlaxViTModel),
(T5Config, FlaxT5Model),
(Wav2Vec2Config, FlaxWav2Vec2Model),
]
)
...@@ -105,6 +108,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
(BartConfig, FlaxBartForConditionalGeneration),
(ElectraConfig, FlaxElectraForPreTraining),
(T5Config, FlaxT5ForConditionalGeneration),
(Wav2Vec2Config, FlaxWav2Vec2ForPreTraining),
]
)
...@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule, is_flax_available, is_tf_available, is_torch_available
_import_structure = {
...@@ -37,7 +37,6 @@ if is_torch_available():
"Wav2Vec2PreTrainedModel",
]
if is_tf_available():
_import_structure["modeling_tf_wav2vec2"] = [
"TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
...@@ -46,6 +45,14 @@ if is_tf_available():
"TFWav2Vec2PreTrainedModel",
]
if is_flax_available():
_import_structure["modeling_flax_wav2vec2"] = [
"FlaxWav2Vec2ForCTC",
"FlaxWav2Vec2ForPreTraining",
"FlaxWav2Vec2Model",
"FlaxWav2Vec2PreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config
...@@ -71,6 +78,14 @@ if TYPE_CHECKING:
TFWav2Vec2PreTrainedModel,
)
if is_flax_available():
from .modeling_flax_wav2vec2 import (
FlaxWav2Vec2ForCTC,
FlaxWav2Vec2ForPreTraining,
FlaxWav2Vec2Model,
FlaxWav2Vec2PreTrainedModel,
)
else:
import importlib
This diff is collapsed.
...@@ -654,3 +654,31 @@ class FlaxViTPreTrainedModel:
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxWav2Vec2ForCTC:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxWav2Vec2ForPreTraining:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxWav2Vec2Model:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxWav2Vec2PreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
import unittest
import numpy as np
from transformers import Wav2Vec2Config, is_flax_available
from transformers.testing_utils import require_datasets, require_flax, require_soundfile, slow
from .test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, random_attention_mask
if is_flax_available():
import jax
import jax.numpy as jnp
import optax
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_flax_wav2vec2 import (
FlaxWav2Vec2ForCTC,
FlaxWav2Vec2ForPreTraining,
FlaxWav2Vec2GumbelVectorQuantizer,
FlaxWav2Vec2Model,
_compute_mask_indices,
_sample_negative_indices,
)
class FlaxWav2Vec2ModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=1024, # speech is longer
is_training=False,
hidden_size=24,
feat_extract_norm="layer",
feat_extract_dropout=0.0,
feat_extract_activation="gelu",
conv_dim=(32, 32, 32),
conv_stride=(4, 4, 4),
conv_kernel=(8, 8, 8),
conv_bias=False,
num_conv_pos_embeddings=16,
num_conv_pos_embedding_groups=2,
num_hidden_layers=4,
num_attention_heads=2,
hidden_dropout_prob=0.1, # this is most likely not correctly set yet
intermediate_size=20,
layer_norm_eps=1e-5,
hidden_act="gelu",
initializer_range=0.02,
vocab_size=32,
do_stable_layer_norm=True,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.hidden_size = hidden_size
self.feat_extract_norm = feat_extract_norm
self.feat_extract_dropout = feat_extract_dropout
self.feat_extract_activation = feat_extract_activation
self.conv_dim = conv_dim
self.conv_stride = conv_stride
self.conv_kernel = conv_kernel
self.conv_bias = conv_bias
self.num_conv_pos_embeddings = num_conv_pos_embeddings
self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_dropout_prob = hidden_dropout_prob
self.intermediate_size = intermediate_size
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.vocab_size = vocab_size
self.do_stable_layer_norm = do_stable_layer_norm
self.scope = scope
output_seq_length = self.seq_length
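# each conv layer shortens the sequence according to its kernel size and stride, so derive the feature encoder's output length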
for kernel, stride in zip(self.conv_kernel, self.conv_stride):
output_seq_length = (output_seq_length - (kernel - 1)) / stride
self.output_seq_length = int(math.ceil(output_seq_length))
self.encoder_seq_length = self.output_seq_length
def prepare_config_and_inputs(self):
input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
config = Wav2Vec2Config(
do_stable_layer_norm=self.do_stable_layer_norm,
hidden_size=self.hidden_size,
feat_extract_norm=self.feat_extract_norm,
feat_extract_dropout=self.feat_extract_dropout,
feat_extract_activation=self.feat_extract_activation,
conv_dim=self.conv_dim,
conv_stride=self.conv_stride,
conv_kernel=self.conv_kernel,
conv_bias=self.conv_bias,
num_conv_pos_embeddings=self.num_conv_pos_embeddings,
num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
hidden_dropout_prob=self.hidden_dropout_prob,
intermediate_size=self.intermediate_size,
layer_norm_eps=self.layer_norm_eps,
hidden_act=self.hidden_act,
initializer_range=self.initializer_range,
vocab_size=self.vocab_size,
)
return config, input_values, attention_mask
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_values, attention_mask = config_and_inputs
inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
return config, inputs_dict
@require_flax
class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (
(FlaxWav2Vec2Model, FlaxWav2Vec2ForCTC, FlaxWav2Vec2ForPreTraining) if is_flax_available() else ()
)
def setUp(self):
self.model_tester = FlaxWav2Vec2ModelTester(self)
def test_train(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_values = inputs_dict["input_values"]
attention_mask = inputs_dict["attention_mask"]
model = FlaxWav2Vec2ForPreTraining(config)
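# the feature encoder downsamples the raw waveform, so compute the number of output frames before building the time mask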
features_shape = (
input_values.shape[0],
model._get_feat_extract_output_lengths(np.array(input_values.shape[1])),
)
batch_size, sequence_length = features_shape[:2]
mask_prob = 0.5
mask_length = 4
mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
dropout_rng, gumbel_rng = jax.random.split(jax.random.PRNGKey(0))
output = model(
input_values,
attention_mask=attention_mask,
mask_time_indices=mask_time_indices,
train=True,
dropout_rng=dropout_rng,
gumbel_rng=gumbel_rng,
)[0]
self.assertTrue(output.shape == (batch_size, sequence_length, model.config.proj_codevector_dim))
# overwrite because of `input_values`
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.__call__)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["input_values", "attention_mask"]
self.assertListEqual(arg_names[:2], expected_arg_names)
@slow
# overwrite because of `input_values`
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
@jax.jit
def model_jitted(input_values, attention_mask=None, **kwargs):
return model(input_values=input_values, attention_mask=attention_mask, **kwargs)
with self.subTest("JIT Enabled"):
jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()
with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = model_jitted(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", from_pt=True)
outputs = model(np.ones((1, 1024), dtype="f4"))
self.assertIsNotNone(outputs)
@require_flax
class FlaxWav2Vec2UtilsTest(unittest.TestCase):
def test_compute_mask_indices(self):
batch_size = 4
sequence_length = 60
mask_prob = 0.5
mask_length = 1
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
def test_compute_mask_indices_overlap(self):
batch_size = 4
sequence_length = 80
mask_prob = 0.5
mask_length = 4
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
# because of overlap mask don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
for batch_sum in mask.sum(axis=-1):
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
def test_compute_perplexity(self):
probs = np.arange(100).reshape(2, 5, 10) / 100
ppl = FlaxWav2Vec2GumbelVectorQuantizer._compute_perplexity(probs)
self.assertTrue(abs(ppl.item() - 141.4291) < 1e-3)
# mask half of the input
mask = np.ones((2,), dtype=bool)
mask[0] = 0
ppl = FlaxWav2Vec2GumbelVectorQuantizer._compute_perplexity(probs, mask)
self.assertTrue(abs(ppl.item() - 58.6757) < 1e-3)
def test_sample_negatives(self):
batch_size = 2
sequence_length = 10
hidden_size = 4
num_negatives = 3
features = (np.arange(sequence_length * hidden_size) // hidden_size).reshape(
sequence_length, hidden_size
) # each feature vector consists of the same value repeated
features = np.broadcast_to(features[None, :], (batch_size, sequence_length, hidden_size))
negative_indices = _sample_negative_indices(features.shape, num_negatives)
features = features.reshape(-1, hidden_size) # BTC => (BxT)C
# take negative vectors from sampled indices
sampled_negatives = features[negative_indices.reshape(-1)]
negatives = sampled_negatives.reshape(batch_size, sequence_length, num_negatives, hidden_size).transpose(
2, 0, 1, 3
)
self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
# make sure no negatively sampled vector is actually a positive one
for negative in negatives:
self.assertTrue(((negative - features.reshape(negative.shape)) == 0).sum() == 0.0)
# make sure that full vectors are sampled and not values of vectors
# => this means that `unique()` yields a single value for `hidden_size` dim
self.assertEqual(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1))
@require_flax
@require_datasets
@require_soundfile
@slow
class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
def _load_datasamples(self, num_samples):
from datasets import load_dataset
import soundfile as sf
ids = [f"1272-141231-000{i}" for i in range(num_samples)]
# map files to raw
def map_to_array(batch):
speech, _ = sf.read(batch["file"])
batch["speech"] = speech
return batch
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
ds = ds.filter(lambda x: x["id"] in ids).sort("id").map(map_to_array)
return ds["speech"][:num_samples]
def test_inference_ctc_robust_batched(self):
model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", from_pt=True)
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)
input_speech = self._load_datasamples(4)
inputs = processor(input_speech, return_tensors="pt", padding=True, truncation=True)
input_values = inputs.input_values
attention_mask = inputs.attention_mask
logits = model(input_values, attention_mask=attention_mask).logits
predicted_ids = jnp.argmax(logits, axis=-1)
predicted_trans = processor.batch_decode(predicted_ids)
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
"the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about",
"his instant panic was followed by a small sharp blow high on his chest",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
def test_inference_pretrained(self):
model = FlaxWav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-large-lv60", from_pt=True)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
"facebook/wav2vec2-large-lv60", return_attention_mask=True
)
input_speech = self._load_datasamples(2)
inputs_dict = feature_extractor(input_speech, return_tensors="np", padding=True)
features_shape = (
inputs_dict["input_values"].shape[0],
model._get_feat_extract_output_lengths(np.array(inputs_dict["input_values"].shape[1])),
)
mask_time_indices = _compute_mask_indices(
features_shape,
model.config.mask_time_prob,
model.config.mask_time_length,
min_masks=2,
)
outputs = model(
inputs_dict.input_values,
attention_mask=inputs_dict.attention_mask,
mask_time_indices=mask_time_indices,
)
# compute cosine similarity
cosine_sim = optax.cosine_similarity(
outputs.projected_states, outputs.projected_quantized_states, epsilon=1e-8
)
# retrieve cosine sim of masked features
cosine_sim_masked = cosine_sim[mask_time_indices]
# ... now compare to randomly initialized model
config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-large-lv60")
model_rand = FlaxWav2Vec2ForPreTraining(config)
outputs_rand = model_rand(
inputs_dict.input_values,
attention_mask=inputs_dict.attention_mask,
mask_time_indices=mask_time_indices,
)
# compute cosine similarity
cosine_sim_rand = optax.cosine_similarity(
outputs_rand.projected_states, outputs_rand.projected_quantized_states
)
# retrieve cosine sim of masked features
cosine_sim_masked_rand = cosine_sim_rand[mask_time_indices]
# a pretrained wav2vec2 model has learned to predict the quantized latent states
# => the cosine similarity between quantized states and predicted states > 0.5
# a random wav2vec2 model has not learned to predict the quantized latent states
# => the cosine similarity between quantized states and predicted states is very likely < 0.1
self.assertTrue(cosine_sim_masked.mean().item() - 5 * cosine_sim_masked_rand.mean().item() > 0)
...@@ -102,6 +102,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"CLIPVisionModel",
"FlaxCLIPTextModel",
"FlaxCLIPVisionModel",
"FlaxWav2Vec2ForCTC",
"DetrForSegmentation", "DetrForSegmentation",
"DPRReader", "DPRReader",
"FlaubertForQuestionAnswering", "FlaubertForQuestionAnswering",