Unverified Commit 0d1f67e6 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Flax] Add wav2vec2 (#12271)



* fix_torch_device_generate_test

* remove @

* start flax wav2vec2

* save intermediate

* forward pass has correct shape

* add weight norm

* add files

* finish ctc

* make style

* finish gumbel quantizer

* correct docstrings

* correct some more files

* fix vit

* finish quality

* correct tests

* correct docstring

* correct tests

* start wav2vec2 pretraining script

* save intermediate

* start pretraining script

* finalize pretraining script

* finish

* finish

* small typo

* finish

* correct

* Apply suggestions from code review
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

* make style

* push
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
parent 3f36a2c0
......@@ -411,7 +411,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | |
| Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
......
......@@ -99,3 +99,23 @@ TFWav2Vec2ForCTC
.. autoclass:: transformers.TFWav2Vec2ForCTC
:members: call
FlaxWav2Vec2Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxWav2Vec2Model
:members: __call__
FlaxWav2Vec2ForCTC
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxWav2Vec2ForCTC
:members: __call__
FlaxWav2Vec2ForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxWav2Vec2ForPreTraining
:members: __call__
# Wav2Vec2 Contrastive Loss PreTraining examples
The following example showcases how to pretrain a wav2vec2 model using the JAX/Flax backend.
Pretraining Wav2Vec2 is rather complex, so it is highly recommended to read the
[official paper](https://arxiv.org/abs/2006.11477).
JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.
Models written in JAX/Flax are **immutable** and updated in a purely functional
way which enables simple and efficient model parallelism.
`run_wav2vec2_pretrain_flax.py` is a lightweight example of how to download and preprocess a dataset from the 🤗 Datasets library or use your own files (jsonlines or csv), then pretrain the wav2vec2 architectures above on it.
For custom datasets in `jsonlines` format please see: [the Datasets documentation](https://huggingface.co/docs/datasets/loading_datasets.html#json-files) and you also will find examples of these below.
Let's start by creating a model repository to save the trained model and logs.
Here we call the model `"wav2vec2-base-robust"`, but you can change the model name as you like.
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
you are logged in) or via the command line:
```
huggingface-cli repo create wav2vec2-base-robust
```
Next we clone the model repository to add the tokenizer and model files.
```
git clone https://huggingface.co/<your-username>/wav2vec2-base-robust
```
To ensure that all tensorboard traces will be uploaded correctly, we need to
track them. You can run the following command inside your model repo to do so.
```
cd wav2vec2-base-robust
git lfs track "*tfevents*"
```
Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.
Next, let's add a symbolic link to the `run_wav2vec2_pretrain_flax`.
```bash
export MODEL_DIR="./wav2vec2-base-robust"
ln -s ~/transformers/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py ./
```
### Create the model configuration
Let's first create the model configuration and store it in the model repository.
Note that many training parameters can be set in the model configuration including
the configuration about the masking distribution (`mask_time_length`, `mask_time_prob`),
dropout (`attention_dropout`, ...), the trade-off between the contrastive loss and
the diversity loss, etc...
Mostly likely you will need to change these parameters depending on your use case.
Again, we highly recommend to read the [official paper](https://arxiv.org/abs/2006.11477)
to better understand which parameters can be set for pretraining.
For this example, we will be using a `"base"`-sized model of Wav2Vec2 with robust
layer norm and keep most of the default settings.
```python
model_dir="./wav2vec2-base-robust"
from transformers import Wav2Vec2Config
config = Wav2Vec2Config.from_pretrained(
"facebook/wav2vec2-base",
mask_time_length=10,
mask_time_prob=0.05,
diversity_loss_weight=0.1,
num_negatives=100,
do_stable_layer_norm=True,
feat_extract_norm="layer",
)
config.save_pretrained(model_dir)
```
### Create a feature extractor configuration
Before we can start the training, we need to define
a feature extractor that takes care of normalization, etc...
Here we can also re-use the feature extractor of [wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base) while making sure that padding is allowed.
```python
model_dir="./wav2vec2-base-robust"
from transformers import Wav2Vec2FeatureExtractor
config = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base", return_attention_mask=True)
config.save_pretrained(model_dir)
```
### Train the model
Finally, we can run the example script to train the model:
```bash
./run_wav2vec2_pretrain_flax.py \
--output_dir=${MODEL_DIR} \
--num_train_epochs="5" \
--per_device_train_batch_size="32" \
--per_device_eval_batch_size="32" \
--learning_rate="5e-4" \
--weight_decay="0.01" \
--warmup_steps="2000" \
--model_name_or_path=${MODEL_DIR} \
--dataset_name="librispeech_asr" \
--dataset_config_name="clean" \
--train_split_name="train.100" \
--preprocessing_num_workers="4" \
--max_duration_in_seconds="10.0" \
--adam_beta1="0.9" \
--adam_beta2="0.98" \
--push_to_hub
```
Note that this script is not fully tested yet, so we cannot ensure that
the above script leads to satisfying results.
#!/usr/bin/env python3
import logging
import sys
import time
from dataclasses import field
from pathlib import Path
from typing import Dict, List, Optional, Union
import numpy as np
from datasets import DatasetDict, load_dataset
from tqdm import tqdm
import flax
import jax
import jax.numpy as jnp
import librosa
import optax
from flax import jax_utils, traverse_util
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from transformers import (
FlaxWav2Vec2ForPreTraining,
HfArgumentParser,
TrainingArguments,
Wav2Vec2Config,
Wav2Vec2FeatureExtractor,
is_tensorboard_available,
)
from transformers.models.wav2vec2.modeling_flax_wav2vec2 import _compute_mask_indices, _sample_negative_indices
logger = logging.getLogger(__name__)
@flax.struct.dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
)
freeze_feature_extractor: Optional[bool] = field(
default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
)
gradient_checkpointing: Optional[bool] = field(
default=False, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
)
verbose_logging: Optional[bool] = field(
default=False,
metadata={"help": "Whether to log verbose messages or not."},
)
max_gumbel_temperature: Optional[float] = field(
default=2.0, metadata={"help": "Maximum temperature for gumbel softmax."}
)
min_gumbel_temperature: Optional[float] = field(
default=0.1, metadata={"help": "Minimum temperature for gumbel softmax."}
)
gumbel_temperature_decay: Optional[float] = field(
default=0.999995, metadata={"help": "Decay of gumbel temperature during training."}
)
@flax.struct.dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
Using `HfArgumentParser` we can turn this class
into argparse arguments to be able to specify them on
the command line.
"""
dataset_name: str = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
train_split_name: Optional[str] = field(
default="train",
metadata={
"help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
},
)
validation_split_name: Optional[str] = field(
default="validation",
metadata={
"help": "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
},
)
speech_file_column: Optional[str] = field(
default="file",
metadata={"help": "Column in the dataset that contains speech file path. Defaults to 'file'"},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
max_duration_in_seconds: Optional[float] = field(
default=20.0, metadata={"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"}
)
pad_to_multiple_of: Optional[int] = field(
default=1024,
metadata={
"help": "If set will pad the sequence to a multiple of the provided value. This is important to avoid triggering recompilations on TPU"
},
)
@flax.struct.dataclass
class FlaxDataCollatorForWav2Vec2Pretraining:
"""
Data collator that will dynamically pad the inputs received and prepare masked indices
for self-supervised pretraining.
Args:
model (:class:`~transformers.FlaxWav2Vec2ForPreTraining`):
The Wav2Vec2 model used for pretraining. The data collator needs to have access
to config and ``_get_feat_extract_output_lengths`` function for correct padding.
feature_extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`):
The processor used for proccessing the data.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence if provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
max_length (:obj:`int`, `optional`):
Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
"""
model: FlaxWav2Vec2ForPreTraining
feature_extractor: Wav2Vec2FeatureExtractor
padding: Union[bool, str] = "longest"
pad_to_multiple_of: Optional[int] = None
max_length: Optional[int] = None
def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
# reformat list to dict and set to pytorch format
batch = self.feature_extractor.pad(
features,
max_length=self.max_length,
padding=self.padding,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors="np",
)
mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
# sample randomly masked indices
batch["mask_time_indices"] = _compute_mask_indices(
(batch["input_values"].shape[0], mask_indices_seq_length),
self.model.config.mask_time_prob,
self.model.config.mask_time_length,
min_masks=2,
)
# sample indices to take for negative vectors
batch["sampled_negative_indices"] = _sample_negative_indices(
(batch["mask_time_indices"].shape + (self.model.config.proj_codevector_dim,)),
self.model.config.num_negatives,
)
return batch
def configure_logger(model_args: ModelArguments, training_args: TrainingArguments):
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
logging_level = logging.WARNING
if model_args.verbose_logging:
logging_level = logging.DEBUG
logger.setLevel(logging_level)
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
summary_writer.scalar("train_time", train_time, step)
train_metrics = get_metrics(train_metrics)
for key, vals in train_metrics.items():
tag = f"train_{key}"
for i, val in enumerate(vals):
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
for metric_name, value in eval_metrics.items():
summary_writer.scalar(f"eval_{metric_name}", value, step)
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
num_samples = len(samples_idx)
samples_to_remove = num_samples % batch_size
if samples_to_remove != 0:
samples_idx = samples_idx[:-samples_to_remove]
sections_split = num_samples // batch_size
batch_idx = np.split(samples_idx, sections_split)
return batch_idx
def compute_contrastive_loss(
quantized_features, transformer_features, negative_indices, mask_time_indices, logits_temp, num_negatives
):
batch_size, sequence_length, hidden_size = quantized_features.shape
# take negative vectors from sampled indices
quantized_negatives = quantized_features.reshape(-1, hidden_size)[negative_indices.reshape(-1)]
quantized_negatives = quantized_negatives.reshape(
batch_size, sequence_length, num_negatives, hidden_size
).transpose(2, 0, 1, 3)
target_features = jnp.concatenate([quantized_features[None, :], quantized_negatives], axis=0)
loss_logits = optax.cosine_similarity(transformer_features, target_features)
loss_logits = loss_logits / logits_temp
neg_is_pos = (quantized_features == quantized_negatives).all(-1)
neg_is_pos = jnp.concatenate([jnp.full((1,) + loss_logits.shape[1:], False), neg_is_pos], axis=0)
# make sure incorrectly sampled vectors don't contribute to loss
loss_logits = jnp.where(neg_is_pos, -1e9, loss_logits)
predictions = loss_logits.transpose(2, 1, 0).reshape(-1, loss_logits.shape[0])
targets = ((1 - mask_time_indices) * -100).transpose(1, 0).flatten()
target_mask = jnp.where(targets >= 0, 1.0, 0.0)
contrastive_loss = optax.softmax_cross_entropy(predictions, onehot(targets, predictions.shape[-1])) * target_mask
contrastive_loss = contrastive_loss.sum()
return contrastive_loss
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
configure_logger(model_args, training_args)
# Downloading and loading a dataset from the hub.
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
if "validation" not in datasets.keys():
# make sure only "validation" and "train" keys remain"
datasets = DatasetDict()
datasets["validation"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=f"{data_args.train_split_name}[:{data_args.validation_split_percentage}%]",
cache_dir=model_args.cache_dir,
)
datasets["train"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=f"{data_args.train_split_name}[{data_args.validation_split_percentage}%:]",
cache_dir=model_args.cache_dir,
)
else:
# make sure only "validation" and "train" keys remain"
datasets = DatasetDict()
datasets["validation"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split="validation",
cache_dir=model_args.cache_dir,
)
datasets["train"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=f"{data_args.train_split_name}",
cache_dir=model_args.cache_dir,
)
# only normalized-inputs-training is supported
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
model_args.model_name_or_path, cache_dir=model_args.cache_dir, do_normalize=True
)
def prepare_dataset(batch):
# check that all files have the correct sampling rate
batch["speech"], _ = librosa.load(batch[data_args.speech_file_column], sr=feature_extractor.sampling_rate)
return batch
# load audio files into numpy arrays
vectorized_datasets = datasets.map(
prepare_dataset, num_proc=data_args.preprocessing_num_workers, remove_columns=datasets["train"].column_names
)
# filter audio files that are too long
vectorized_datasets = vectorized_datasets.filter(
lambda data: len(data["speech"]) < int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
)
def normalize(batch):
return feature_extractor(batch["speech"], sampling_rate=feature_extractor.sampling_rate)
# normalize and transform to `BatchFeatures`
vectorized_datasets = vectorized_datasets.map(
normalize,
batched=True,
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
remove_columns=vectorized_datasets["train"].column_names,
)
# pretraining is only supported for "newer" stable layer norm architecture
# apply_spec_augment has to be True, mask_feature_prob has to be 0.0
config = Wav2Vec2Config.from_pretrained(
model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
gradient_checkpointing=model_args.gradient_checkpointing,
)
if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
raise ValueError(
"PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
)
model = FlaxWav2Vec2ForPreTraining(
config, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
)
data_collator = FlaxDataCollatorForWav2Vec2Pretraining(
model=model, feature_extractor=feature_extractor, pad_to_multiple_of=data_args.pad_to_multiple_of
)
# Enable tensorboard only on the master node
has_tensorboard = is_tensorboard_available()
if has_tensorboard and jax.process_index() == 0:
try:
from flax.metrics.tensorboard import SummaryWriter
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
except ImportError as ie:
has_tensorboard = False
logger.warning(
f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
)
else:
logger.warning(
"Unable to display metrics through TensorBoard because the package is not installed: "
"Please run pip install tensorboard to enable."
)
# Initialize our training
rng = jax.random.PRNGKey(training_args.seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())
gumbel_rngs = jax.random.split(rng, jax.local_device_count())
num_epochs = int(training_args.num_train_epochs)
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
num_train_steps = len(vectorized_datasets["train"]) // train_batch_size * num_epochs
# Create learning rate schedule
warmup_fn = optax.linear_schedule(
init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
)
decay_fn = optax.linear_schedule(
init_value=training_args.learning_rate,
end_value=0,
transition_steps=num_train_steps - training_args.warmup_steps,
)
linear_decay_lr_schedule_fn = optax.join_schedules(
schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
)
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
flat_mask = {
path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")])
for path in flat_params
}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
adamw = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=training_args.adam_beta1,
b2=training_args.adam_beta2,
eps=training_args.adam_epsilon,
weight_decay=training_args.weight_decay,
mask=decay_mask_fn,
)
# Setup train state and define training hyper-parameters
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
num_negatives = model.config.num_negatives
contrastive_logits_temperature = model.config.contrastive_logits_temperature
num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups
diversity_loss_weight = model.config.diversity_loss_weight
# Define gradient update step fn
def train_step(state, batch, dropout_rng, gumbel_rng):
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
gumbel_rng, new_gumbel_rng = jax.random.split(gumbel_rng)
def loss_fn(params):
negative_indices = batch.pop("sampled_negative_indices")
gumbel_temperature = jnp.clip(
model_args.max_gumbel_temperature * model_args.gumbel_temperature_decay ** state.step,
a_min=model_args.min_gumbel_temperature,
)
outputs = state.apply_fn(
**batch,
gumbel_temperature=gumbel_temperature,
params=params,
dropout_rng=dropout_rng,
gumbel_rng=gumbel_rng,
train=True,
)
contrastive_loss = compute_contrastive_loss(
outputs.projected_quantized_states,
outputs.projected_states,
negative_indices,
batch["mask_time_indices"],
contrastive_logits_temperature,
num_negatives,
)
diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
loss = contrastive_loss + diversity_loss_weight * diversity_loss
return loss
grad_fn = jax.value_and_grad(loss_fn)
loss, grad = grad_fn(state.params)
grad = jax.lax.pmean(grad, "batch")
new_state = state.apply_gradients(grads=grad)
metrics = jax.lax.pmean(
{"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
)
return new_state, metrics, new_dropout_rng, new_gumbel_rng
# Create parallel version of the train step
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
# Define eval fn
def eval_step(params, batch):
negative_indices = batch.pop("sampled_negative_indices")
outputs = model(**batch, params=params, train=False)
contrastive_loss = compute_contrastive_loss(
outputs.projected_quantized_states,
outputs.projected_states,
negative_indices,
batch["mask_time_indices"],
contrastive_logits_temperature,
num_negatives,
)
diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
loss = contrastive_loss + diversity_loss_weight * diversity_loss
# summarize metrics
metrics = {"loss": loss.mean(), "codevector_perplexity": outputs.codevector_perplexity}
metrics = jax.lax.pmean(metrics, axis_name="batch")
return metrics
p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
# Replicate the train state on each device
state = jax_utils.replicate(state)
train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
# ======================== Training ================================
train_start = time.time()
train_metrics = []
# Create sampling rng
rng, input_rng = jax.random.split(rng)
# Generate an epoch by shuffling sampling indices from the train dataset
num_train_samples = len(vectorized_datasets["train"])
train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
# Gather the indexes for creating the batch and do a training step
for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]
model_inputs = data_collator(samples)
model_inputs = shard(model_inputs.data)
# Model forward
state, train_metric, dropout_rngs, gumbel_rngs = p_train_step(
state, model_inputs, dropout_rngs, gumbel_rngs
)
train_metrics.append(train_metric)
train_time += time.time() - train_start
epochs.write(
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
)
# ======================== Evaluating ==============================
num_eval_samples = len(vectorized_datasets["validation"])
eval_samples_idx = jnp.arange(num_eval_samples)
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
eval_metrics = []
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
samples = [vectorized_datasets["validation"][int(idx)] for idx in batch_idx]
model_inputs = data_collator(samples)
# Model forward
model_inputs = shard(model_inputs.data)
metrics = p_eval_step(state.params, model_inputs)
eval_metrics.append(metrics)
# get eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
# Update progress bar
epochs.write(
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Perplexity: {eval_metrics['codevector_perplexity']})"
)
# Save metrics
if has_tensorboard and jax.process_index() == 0:
cur_step = epoch * (len(vectorized_datasets["train"]) // train_batch_size)
write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
model.save_pretrained(training_args.output_dir, params=params, push_to_hub=training_args.push_to_hub)
if __name__ == "__main__":
main()
......@@ -1643,6 +1643,9 @@ if is_flax_available():
)
_import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
_import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"])
_import_structure["models.wav2vec2"].extend(
["FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2Model", "FlaxWav2Vec2PreTrainedModel"]
)
else:
from .utils import dummy_flax_objects
......@@ -3023,6 +3026,12 @@ if TYPE_CHECKING:
)
from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
from .models.vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
from .models.wav2vec2 import (
FlaxWav2Vec2ForCTC,
FlaxWav2Vec2ForPreTraining,
FlaxWav2Vec2Model,
FlaxWav2Vec2PreTrainedModel,
)
else:
# Import the same objects as dummies to get them in the namespace.
# They will raise an import error if the user tries to instantiate / use them.
......
......@@ -64,6 +64,7 @@ from ..roberta.modeling_flax_roberta import (
)
from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
from ..wav2vec2.modeling_flax_wav2vec2 import FlaxWav2Vec2ForPreTraining, FlaxWav2Vec2Model
from .auto_factory import auto_class_factory
from .configuration_auto import (
BartConfig,
......@@ -75,6 +76,7 @@ from .configuration_auto import (
RobertaConfig,
T5Config,
ViTConfig,
Wav2Vec2Config,
)
......@@ -93,6 +95,7 @@ FLAX_MODEL_MAPPING = OrderedDict(
(CLIPConfig, FlaxCLIPModel),
(ViTConfig, FlaxViTModel),
(T5Config, FlaxT5Model),
(Wav2Vec2Config, FlaxWav2Vec2Model),
]
)
......@@ -105,6 +108,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
(BartConfig, FlaxBartForConditionalGeneration),
(ElectraConfig, FlaxElectraForPreTraining),
(T5Config, FlaxT5ForConditionalGeneration),
(Wav2Vec2Config, FlaxWav2Vec2ForPreTraining),
]
)
......
......@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule, is_tf_available, is_torch_available
from ...file_utils import _BaseLazyModule, is_flax_available, is_tf_available, is_torch_available
_import_structure = {
......@@ -37,7 +37,6 @@ if is_torch_available():
"Wav2Vec2PreTrainedModel",
]
if is_tf_available():
_import_structure["modeling_tf_wav2vec2"] = [
"TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
......@@ -46,6 +45,14 @@ if is_tf_available():
"TFWav2Vec2PreTrainedModel",
]
if is_flax_available():
_import_structure["modeling_flax_wav2vec2"] = [
"FlaxWav2Vec2ForCTC",
"FlaxWav2Vec2ForPreTraining",
"FlaxWav2Vec2Model",
"FlaxWav2Vec2PreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config
......@@ -71,6 +78,14 @@ if TYPE_CHECKING:
TFWav2Vec2PreTrainedModel,
)
if is_flax_available():
from .modeling_tf_wav2vec2 import (
FlaxWav2Vec2ForCTC,
FlaxWav2Vec2ForPreTraining,
FlaxWav2Vec2Model,
FlaxWav2Vec2PreTrainedModel,
)
else:
import importlib
......
# coding=utf-8
# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Flax Wav2Vec2 model. """
from functools import partial
from typing import Optional, Tuple, Union
import numpy as np
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.linen.attention import dot_product_attention_weights
from jax import lax
from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
from ...utils import logging
from .configuration_wav2vec2 import Wav2Vec2Config
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class FlaxWav2Vec2BaseModelOutput(ModelOutput):
"""
Output type of :class:`~transformers.FlaxWav2Vec2BaseModelOutput`, with potential hidden states and attentions.
Args:
last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
extract_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, last_conv_dim)`):
Sequence of extracted feature vectors of the last convolutional layer of the model with ``last_conv_dim``
being the dimension of the last convolutional layer.
hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: jnp.ndarray = None
extract_features: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxWav2Vec2ForPreTrainingOutput(ModelOutput):
"""
Output type of :class:`~transformers.FlaxWav2Vec2ForPreTrainingOutput`, with potential hidden states and
attentions.
Args:
loss (`optional`, returned when model is in train mode, ``jnp.ndarray`` of shape :obj:`(1,)`):
Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the `official
paper <https://arxiv.org/pdf/2006.11477.pdf>`__ . (classification) loss.
projected_states (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.proj_codevector_dim)`):
Hidden-states of the model projected to `config.proj_codevector_dim` that can be used to predict the masked
projected quantized states.
projected_quantized_states (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.proj_codevector_dim)`):
Quantized extracted feature vectors projected to `config.proj_codevector_dim` representing the positive
target vectors for contrastive loss.
hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
projected_states: jnp.ndarray = None
projected_quantized_states: jnp.ndarray = None
codevector_perplexity: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement `SpecAugment: A Simple Data Augmentation Method for
ASR <https://arxiv.org/abs/1904.08779>`__. Note that this method is not optimized to run on TPU and should be run
on CPU as part of the preprocessing during training.
Args:
shape: the the shape for which to compute masks.
should be of size 2 where first element is batch size and 2nd is timesteps
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
mask_length: size of the mask
min_masks: minimum number of masked spans
"""
batch_size, sequence_length = shape
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
)
# compute number of masked spans in batch
num_masked_spans = int(mask_prob * sequence_length / mask_length + np.random.rand(1).item())
num_masked_spans = max(num_masked_spans, min_masks)
# make sure num masked indices <= sequence_length
if num_masked_spans * mask_length > sequence_length:
num_masked_spans = sequence_length // mask_length
# SpecAugment mask to fill
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=np.bool)
# get random indices to mask
spec_aug_mask_idxs = np.array(
[
np.random.choice(np.arange(sequence_length - (mask_length - 1)), num_masked_spans, replace=False)
for _ in range(batch_size)
]
)
# expand masked indices to masked spans
spec_aug_mask_idxs = np.broadcast_to(spec_aug_mask_idxs[:, :, None], (batch_size, num_masked_spans, mask_length))
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, num_masked_spans * mask_length)
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, num_masked_spans, mask_length)).reshape(
batch_size, num_masked_spans * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
return spec_aug_mask
def _sample_negative_indices(features_shape: Tuple, num_negatives: int):
"""
Sample `num_negatives` vectors from feature vectors.
"""
batch_size, sequence_length, hidden_size = features_shape
if sequence_length <= 1:
raise ValueError(
f"`features should have `sequence_length` > 1, but are of shape "
f"(batch_size, sequence_length, hidden_size) = ({batch_size, sequence_length, hidden_size})."
)
# get `num_negatives` random vector indices from the same utterance
sampled_negative_indices = np.random.randint(
low=0,
high=sequence_length - 1,
size=(batch_size, num_negatives * sequence_length),
)
# generate indices of the positive vectors themselves, repeat them `num_negatives` times
feature_indices = np.broadcast_to(np.arange(sequence_length)[:, None], (sequence_length, num_negatives)).flatten()
# avoid sampling the same positive vector, but keep the distribution uniform
sampled_negative_indices[sampled_negative_indices >= feature_indices] += 1
# correct for batch size
for batch_idx in range(1, batch_size):
sampled_negative_indices[batch_idx] += batch_idx * sequence_length
return sampled_negative_indices
WAV_2_VEC_2_START_DOCSTRING = r"""
Wav2Vec2 was proposed in `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
<https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
generic methods the library implements for all its model (such as downloading or saving, resizing the input
embeddings, pruning heads etc.)
This model is also a Flax Linen `flax.nn.Module
<https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
Module and refer to the Flax documentation for all matter related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
- `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
- `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
- `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__
Parameters:
config (:class:`~transformers.Wav2Vec2Config`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
"""
WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
Args:
input_values (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the :class:`~transformers.Wav2Vec2Processor` should
be used for padding and conversion into a tensor of type `jnp.ndarray`. See
:meth:`transformers.Wav2Vec2Processor.__call__` for details.
attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in ``[0,
1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ .. warning:: :obj:`attention_mask` should
only be passed if the corresponding processor has ``config.return_attention_mask == True``. For all models
whose processor has ``config.return_attention_mask == False``, such as `wav2vec2-base
<https://huggingface.co/facebook/wav2vec2-base-960h>`__, :obj:`attention_mask` should **not** be passed to
avoid degraded performance when doing batched inference. For such models :obj:`input_values` should simply
be padded with 0 and passed without :obj:`attention_mask`. Be aware that these models also yield slightly
different results depending on whether :obj:`input_values` is padded or not.
mask_time_indices (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
masked extracted features in `config.proj_codevector_dim` space.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
class FlaxWav2Vec2LayerNormConvLayer(nn.Module):
config: Wav2Vec2Config
layer_id: int = 0
dtype: jnp.dtype = jnp.float32
def setup(self):
self.in_conv_dim = self.config.conv_dim[self.layer_id] if self.layer_id > 0 else 1
self.out_conv_dim = self.config.conv_dim[self.layer_id]
self.conv = nn.Conv(
features=self.config.conv_dim[self.layer_id],
kernel_size=self.config.conv_kernel[self.layer_id],
strides=(self.config.conv_stride[self.layer_id],),
use_bias=self.config.conv_bias,
kernel_init=jax.nn.initializers.he_normal(dtype=self.dtype),
padding="VALID",
dtype=self.dtype,
)
self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.activation = ACT2FN[self.config.feat_extract_activation]
def __call__(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
class FlaxConvWithWeightNorm(nn.Module):
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32
def setup(self):
self.conv = nn.Conv(
features=self.config.hidden_size,
kernel_size=self.config.num_conv_pos_embeddings,
kernel_init=jax.nn.initializers.he_normal(dtype=self.dtype),
padding="VALID",
feature_group_count=self.config.num_conv_pos_embedding_groups,
dtype=self.dtype,
)
weight_shape = (
self.conv.features,
self.conv.features // self.conv.feature_group_count,
self.conv.kernel_size,
)
self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(dtype=self.dtype), weight_shape)
self.weight_g = self.param("weight_g", lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :])
self.bias = self.param("bias", jax.nn.initializers.zeros, (self.conv.features,))
self.prev_padding = self.conv.kernel_size // 2
def _get_normed_weights(self):
weight_v_norm = jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]
normed_weight_v = jnp.divide(self.weight_v, weight_v_norm)
normed_kernel = jnp.multiply(normed_weight_v, self.weight_g)
return normed_kernel
def __call__(self, hidden_states):
kernel = self._get_normed_weights()
hidden_states = jnp.pad(hidden_states, ((0, 0), (self.prev_padding, self.prev_padding), (0, 0)))
hidden_states = self.conv.apply({"params": {"kernel": kernel.T, "bias": self.bias}}, hidden_states)
return hidden_states
class FlaxWav2Vec2PositionalConvEmbedding(nn.Module):
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32
def setup(self):
self.conv = FlaxConvWithWeightNorm(self.config, dtype=self.dtype)
self.activation = ACT2FN[self.config.feat_extract_activation]
self.num_pad_remove = 1 if self.config.num_conv_pos_embeddings % 2 == 0 else 0
def __call__(self, hidden_states):
hidden_states = hidden_states.transpose((0, 1, 2))
hidden_states = self.conv(hidden_states)
if self.num_pad_remove > 0:
hidden_states = hidden_states[:, : -self.num_pad_remove, :]
hidden_states = self.activation(hidden_states)
hidden_states = hidden_states.transpose((0, 1, 2))
return hidden_states
class FlaxConvLayersCollection(nn.Module):
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32
def setup(self):
if self.config.feat_extract_norm == "layer":
self.layers = [
FlaxWav2Vec2LayerNormConvLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype)
for i in range(self.config.num_feat_extract_layers)
]
elif self.config.feat_extract_norm == "group":
raise NotImplementedError("At the moment only ``config.feat_extact_norm == 'layer'`` is supported")
else:
raise ValueError(
f"`config.feat_extract_norm` is {self.config.feat_extract_norm}, but has to be one of ['group', 'layer']"
)
def __call__(self, hidden_states):
for i, conv_layer in enumerate(self.layers):
hidden_states = conv_layer(hidden_states)
return hidden_states
class FlaxWav2Vec2FeatureExtractor(nn.Module):
"""Construct the featurs from raw audio waveform"""
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32
def setup(self):
self.conv_layers = FlaxConvLayersCollection(self.config, dtype=self.dtype)
def __call__(self, input_values):
hidden_states = input_values[:, :, None]
hidden_states = self.conv_layers(hidden_states)
return hidden_states
class FlaxWav2Vec2FeatureProjection(nn.Module):
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32
def setup(self):
self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.projection = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.feat_proj_dropout)
def __call__(self, hidden_states, deterministic=True):
norm_hidden_states = self.layer_norm(hidden_states)
hidden_states = self.projection(norm_hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
return hidden_states, norm_hidden_states
class FlaxWav2Vec2Attention(nn.Module):
config: Wav2Vec2Config
embed_dim: int
num_heads: int
dropout: float = 0.0
bias: bool = True
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self) -> None:
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
)
dense = partial(
nn.Dense,
self.embed_dim,
use_bias=self.bias,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
)
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
self.out_proj = dense()
self.dropout_layer = nn.Dropout(rate=self.dropout)
def _split_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
def _merge_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
def __call__(
self,
hidden_states: jnp.ndarray,
key_value_states: Optional[jnp.ndarray] = None,
attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = True,
) -> Tuple[jnp.ndarray]:
"""Input shape: Batch x Time x Channel"""
# get query proj
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = self._split_heads(query_states)
key_states = self._split_heads(key_states)
value_states = self._split_heads(value_states)
if attention_mask is not None:
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
# Convert the boolean attention mask to an attention bias.
if attention_mask is not None:
# attention mask in the form of attention bias
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
)
else:
attention_bias = None
dropout_rng = None
if not deterministic and self.dropout > 0.0:
dropout_rng = self.make_rng("dropout")
attn_weights = dot_product_attention_weights(
query_states,
key_states,
bias=attention_bias,
dropout_rng=dropout_rng,
dropout_rate=self.dropout,
broadcast_dropout=True,
deterministic=deterministic,
dtype=self.dtype,
precision=None,
)
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
attn_output = self._merge_heads(attn_output)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
class FlaxWav2Vec2FeedForward(nn.Module):
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32
def setup(self):
self.intermediate_dropout = nn.Dropout(rate=self.config.activation_dropout)
self.intermediate_dense = nn.Dense(
self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
dtype=self.dtype,
)
if isinstance(self.config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[self.config.hidden_act]
else:
self.intermediate_act_fn = self.config.hidden_act
self.output_dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
dtype=self.dtype,
)
self.output_dropout = nn.Dropout(rate=self.config.hidden_dropout)
def __call__(self, hidden_states, deterministic=True):
hidden_states = self.intermediate_dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
hidden_states = self.intermediate_dropout(hidden_states, deterministic=deterministic)
hidden_states = self.output_dense(hidden_states)
hidden_states = self.output_dropout(hidden_states, deterministic=deterministic)
return hidden_states
class FlaxWav2Vec2EncoderLayerStableLayerNorm(nn.Module):
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32
def setup(self):
self.attention = FlaxWav2Vec2Attention(
config=self.config,
embed_dim=self.config.hidden_size,
num_heads=self.config.num_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.feed_forward = FlaxWav2Vec2FeedForward(self.config, dtype=self.dtype)
self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask=None, deterministic=True, output_attentions=False):
attn_residual = hidden_states
hidden_states = self.layer_norm(hidden_states)
hidden_states, attn_weights = self.attention(
hidden_states, attention_mask=attention_mask, deterministic=deterministic
)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = attn_residual + hidden_states
hidden_states = hidden_states + self.feed_forward(
self.final_layer_norm(hidden_states), deterministic=deterministic
)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class FlaxWav2Vec2EncoderLayerStableLayerNormCollection(nn.Module):
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32
def setup(self):
self.layers = [
FlaxWav2Vec2EncoderLayerStableLayerNorm(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.num_hidden_layers)
]
def __call__(
self,
hidden_states,
attention_mask=None,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = layer(
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions += (layer_outputs[1],)
if output_hidden_states:
all_hidden_states += (hidden_states,)
outputs = (hidden_states,)
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
class FlaxWav2Vec2StableLayerNormEncoder(nn.Module):
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32
def setup(self):
self.pos_conv_embed = FlaxWav2Vec2PositionalConvEmbedding(self.config, dtype=self.dtype)
self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
self.layers = FlaxWav2Vec2EncoderLayerStableLayerNormCollection(self.config, dtype=self.dtype)
def __call__(
self,
hidden_states,
attention_mask=None,
deterministic=True,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
if attention_mask is not None:
# make sure padded tokens are not attended to
hidden_states = jnp.where(
jnp.broadcast_to(attention_mask[:, :, None], hidden_states.shape), hidden_states, 0
)
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
outputs = self.layers(
hidden_states,
attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = self.layer_norm(outputs[0])
if not return_dict:
return (hidden_states,) + outputs[1:]
return FlaxBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=outputs.hidden_states, attentions=outputs.attentions
)
class FlaxWav2Vec2GumbelVectorQuantizer(nn.Module):
"""
Vector quantization using gumbel softmax. See `CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX
<https://arxiv.org/pdf/1611.01144.pdf>`__ for more information.
"""
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32
def setup(self):
self.num_groups = self.config.num_codevector_groups
self.num_vars = self.config.num_codevectors_per_group
if self.config.codevector_dim % self.num_groups != 0:
raise ValueError(
f"`config.codevector_dim {self.config.codevector_dim} must be divisible by"
f" `config.num_codevector_groups` {self.num_groups} for concatenation"
)
# storage for codebook variables (codewords)
self.codevectors = self.param(
"codevectors",
jax.nn.initializers.uniform(),
(1, self.num_groups * self.num_vars, self.config.codevector_dim // self.num_groups),
)
self.weight_proj = nn.Dense(
self.num_groups * self.num_vars,
kernel_init=jax.nn.initializers.normal(1.0, self.dtype),
dtype=self.dtype,
)
@staticmethod
def _compute_perplexity(probs, mask=None):
if mask is not None:
mask_extended = jnp.broadcast_to(mask.flatten()[:, None, None], probs.shape)
probs = jnp.where(mask_extended, probs, jnp.zeros_like(probs))
marginal_probs = probs.sum(axis=0) / mask.sum()
else:
marginal_probs = probs.mean(axis=0)
perplexity = jnp.exp(-jnp.sum(marginal_probs * jnp.log(marginal_probs + 1e-7), axis=-1)).sum()
return perplexity
def __call__(self, hidden_states, mask_time_indices=None, deterministic=True, temperature=1):
batch_size, sequence_length, hidden_size = hidden_states.shape
# project to codevector dim
hidden_states = self.weight_proj(hidden_states)
hidden_states = hidden_states.reshape(batch_size * sequence_length * self.num_groups, -1)
if not deterministic:
# sample code vector probs via gumbel in differentiateable way
gumbel_rng = self.make_rng("gumbel")
gumbels = jax.random.gumbel(gumbel_rng, hidden_states.shape)
codevector_probs = nn.softmax((hidden_states + gumbels) / temperature)
# compute perplexity
codevector_soft_dist = nn.softmax(
hidden_states.reshape(batch_size * sequence_length, self.num_groups, -1), axis=-1
)
perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
else:
# take argmax in non-differentiable way
# comptute hard codevector distribution (one hot)
codevector_idx = hidden_states.argmax(axis=-1)
codevector_probs = jax.nn.one_hot(codevector_idx, hidden_states.shape[-1]) * 1.0
codevector_probs = codevector_probs.reshape(batch_size * sequence_length, self.num_groups, -1)
perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
codevector_probs = codevector_probs.reshape(batch_size * sequence_length, -1)
# use probs to retrieve codevectors
codevectors_per_group = jnp.expand_dims(codevector_probs, axis=-1) * self.codevectors
codevectors = codevectors_per_group.reshape(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
codevectors = codevectors.sum(-2).reshape(batch_size, sequence_length, -1)
return codevectors, perplexity
class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = Wav2Vec2Config
base_model_prefix: str = "wav2vec2"
module_class: nn.Module = None
def __init__(
self,
config: Wav2Vec2Config,
input_shape: Tuple = (1, 1024),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
**kwargs,
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
# init input tensors
input_values = jnp.zeros(input_shape, dtype="i4")
attention_mask = jnp.ones_like(input_values)
params_rng, dropout_rng = jax.random.split(rng, 2)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"]
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
def __call__(
self,
input_values,
attention_mask=None,
mask_time_indices=None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
batch_size, sequence_length = input_values.shape
if attention_mask is None:
attention_mask = jnp.ones((batch_size, sequence_length))
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
inputs = {"params": params or self.params}
return self.module.apply(
inputs,
jnp.array(input_values, dtype="f4"),
jnp.array(attention_mask, dtype="i4"),
mask_time_indices,
not train,
output_attentions,
output_hidden_states,
return_dict,
rngs=rngs,
)
def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
return self.module._get_feat_extract_output_lengths(input_lengths)
class FlaxWav2Vec2Module(nn.Module):
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32
def setup(self):
self.feature_extractor = FlaxWav2Vec2FeatureExtractor(self.config, dtype=self.dtype)
self.feature_projection = FlaxWav2Vec2FeatureProjection(self.config, dtype=self.dtype)
self.masked_spec_embed = self.param(
"masked_spec_embed", jax.nn.initializers.uniform(), (self.config.hidden_size,)
)
if self.config.do_stable_layer_norm:
self.encoder = FlaxWav2Vec2StableLayerNormEncoder(self.config, dtype=self.dtype)
else:
raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")
def __call__(
self,
input_values,
attention_mask=None,
mask_time_indices=None,
deterministic=True,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
"""
Returns:
Example::
>>> from transformers import Wav2Vec2Processor, FlaxWav2Vec2Model
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
>>> model = FlaxWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
>>> def map_to_array(batch):
>>> speech, _ = sf.read(batch["file"])
>>> batch["speech"] = speech
>>> return batch
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_values = processor(ds["speech"][0], return_tensors="np").input_values # Batch size 1
>>> hidden_states = model(input_values).last_hidden_state
"""
extract_features = self.feature_extractor(input_values)
if attention_mask is not None:
# compute real output lengths according to convolution formula
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1).astype("i4"))
attention_mask = jnp.zeros(extract_features.shape[:2], dtype=self.dtype)
# these two operations makes sure that all values
# before the output lengths indices are attended to
attention_mask = jax.ops.index_update(
attention_mask, jax.ops.index[jnp.arange(attention_mask.shape[0]), output_lengths - 1], 1
)
attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
if mask_time_indices is not None: # apply SpecAugment along time axis with given indices
hidden_states = jnp.where(
jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape),
jnp.broadcast_to(self.masked_spec_embed[None, None, :], hidden_states.shape),
hidden_states,
)
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = encoder_outputs[0]
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
return FlaxWav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
"""
Computes the output length of the convolutional layers
"""
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return (input_length - kernel_size) // stride + 1
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
return input_lengths
@add_start_docstrings(
"The bare Wav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.",
WAV_2_VEC_2_START_DOCSTRING,
)
class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
module_class = FlaxWav2Vec2Module
class FlaxWav2Vec2ForCTCModule(nn.Module):
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32
def setup(self):
self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.final_dropout)
self.lm_head = nn.Dense(
self.config.vocab_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
dtype=self.dtype,
)
def __call__(
self,
input_values,
attention_mask=None,
mask_time_indices=None,
deterministic=True,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Returns:
Example::
>>> import jax.numpy as jnp
>>> from transformers import Wav2Vec2Processor, FlaxWav2Vec2ForCTC
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
>>> model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
>>> def map_to_array(batch):
>>> speech, _ = sf.read(batch["file"])
>>> batch["speech"] = speech
>>> return batch
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_values = processor(ds["speech"][0], return_tensors="np").input_values # Batch size 1
>>> logits = model(input_values).logits
>>> predicted_ids = jnp.argmax(logits, axis=-1)
>>> transcription = processor.decode(predicted_ids[0])
>>> # should give: "A MAN SAID TO THE UNIVERSE SIR I EXIST"
"""
outputs = self.wav2vec2(
input_values,
attention_mask=attention_mask,
mask_time_indices=mask_time_indices,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
logits = self.lm_head(hidden_states)
if not return_dict:
return (logits,) + outputs[2:]
return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
"""
Computes the output length of the convolutional layers
"""
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return (input_length - kernel_size) // stride + 1
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
return input_lengths
@add_start_docstrings(
"Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).",
WAV_2_VEC_2_START_DOCSTRING,
)
class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel):
module_class = FlaxWav2Vec2ForCTCModule
class FlaxWav2Vec2ForPreTrainingModule(nn.Module):
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32
def setup(self):
self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
self.dropout_features = nn.Dropout(self.config.feat_quantizer_dropout)
self.quantizer = FlaxWav2Vec2GumbelVectorQuantizer(self.config, dtype=self.dtype)
self.project_q = nn.Dense(
self.config.proj_codevector_dim,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
dtype=self.dtype,
)
self.project_hid = nn.Dense(
self.config.proj_codevector_dim,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
dtype=self.dtype,
)
def __call__(
self,
input_values,
attention_mask=None,
mask_time_indices=None,
gumbel_temperature: int = 1,
deterministic: bool = True,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Returns:
Example::
>>> import optax
>>> import numpy as np
>>> import jax.numpy as jnp
>>> from transformers import Wav2Vec2FeatureExtractor, FlaxWav2Vec2ForPreTraining
>>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("patrickvonplaten/wav2vec2-base")
>>> model = FlaxWav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_values = feature_extractor(ds["speech"][0], return_tensors="np").input_values # Batch size 1
>>> # compute masked indices
>>> batch_size, raw_sequence_length = input_values.shape
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
>>> outputs = model(input_values, mask_time_indices=mask_time_indices)
>>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
>>> cosine_sim = optax.cosine_similarity(
... outputs.projected_states, outputs.projected_quantized_states, axis=-1
... )
>>> # show that cosine similarity is much higher than random
>>> assert np.asarray(cosine_sim)[mask_time_indices].mean() > 0.5
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.wav2vec2(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
mask_time_indices=mask_time_indices,
deterministic=deterministic,
return_dict=return_dict,
)
# project all transformed features (including masked) to final vq dim
transformer_features = self.project_hid(outputs[0])
# quantize all (unmasked) extracted features and project to final vq dim
extract_features = self.dropout_features(outputs[1], deterministic=deterministic)
quantized_features, codevector_perplexity = self.quantizer(
extract_features, mask_time_indices, deterministic=deterministic, temperature=gumbel_temperature
)
quantized_features = self.project_q(quantized_features)
if not return_dict:
return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
return FlaxWav2Vec2ForPreTrainingOutput(
projected_states=transformer_features,
projected_quantized_states=quantized_features,
codevector_perplexity=codevector_perplexity,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
"""
Computes the output length of the convolutional layers
"""
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return (input_length - kernel_size) // stride + 1
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
return input_lengths
@add_start_docstrings("""Wav2Vec2 Model with a quantizer and `VQ` head on top. """, WAV_2_VEC_2_START_DOCSTRING)
class FlaxWav2Vec2ForPreTraining(FlaxWav2Vec2PreTrainedModel):
module_class = FlaxWav2Vec2ForPreTrainingModule
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
# overwrite since has `gumbel_temperature` input
def __call__(
self,
input_values,
attention_mask=None,
mask_time_indices=None,
gumbel_temperature: int = 1,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
gumbel_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
batch_size, sequence_length = input_values.shape
if attention_mask is None:
attention_mask = jnp.ones((batch_size, sequence_length))
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
if gumbel_rng is not None:
rngs["gumbel"] = gumbel_rng
inputs = {"params": params or self.params}
return self.module.apply(
inputs,
jnp.array(input_values, dtype="f4"),
jnp.array(attention_mask, dtype="i4"),
mask_time_indices,
gumbel_temperature,
not train,
output_attentions,
output_hidden_states,
return_dict,
rngs=rngs,
)
......@@ -654,3 +654,31 @@ class FlaxViTPreTrainedModel:
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxWav2Vec2ForCTC:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxWav2Vec2ForPreTraining:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxWav2Vec2Model:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxWav2Vec2PreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
import unittest
import numpy as np
from transformers import Wav2Vec2Config, is_flax_available
from transformers.testing_utils import require_datasets, require_flax, require_soundfile, slow
from .test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, random_attention_mask
if is_flax_available():
import jax
import jax.numpy as jnp
import optax
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_flax_wav2vec2 import (
FlaxWav2Vec2ForCTC,
FlaxWav2Vec2ForPreTraining,
FlaxWav2Vec2GumbelVectorQuantizer,
FlaxWav2Vec2Model,
_compute_mask_indices,
_sample_negative_indices,
)
class FlaxWav2Vec2ModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=1024, # speech is longer
is_training=False,
hidden_size=24,
feat_extract_norm="layer",
feat_extract_dropout=0.0,
feat_extract_activation="gelu",
conv_dim=(32, 32, 32),
conv_stride=(4, 4, 4),
conv_kernel=(8, 8, 8),
conv_bias=False,
num_conv_pos_embeddings=16,
num_conv_pos_embedding_groups=2,
num_hidden_layers=4,
num_attention_heads=2,
hidden_dropout_prob=0.1, # this is most likely not correctly set yet
intermediate_size=20,
layer_norm_eps=1e-5,
hidden_act="gelu",
initializer_range=0.02,
vocab_size=32,
do_stable_layer_norm=True,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.hidden_size = hidden_size
self.feat_extract_norm = feat_extract_norm
self.feat_extract_dropout = feat_extract_dropout
self.feat_extract_activation = feat_extract_activation
self.conv_dim = conv_dim
self.conv_stride = conv_stride
self.conv_kernel = conv_kernel
self.conv_bias = conv_bias
self.num_conv_pos_embeddings = num_conv_pos_embeddings
self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_dropout_prob = hidden_dropout_prob
self.intermediate_size = intermediate_size
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.vocab_size = vocab_size
self.do_stable_layer_norm = do_stable_layer_norm
self.scope = scope
output_seq_length = self.seq_length
for kernel, stride in zip(self.conv_kernel, self.conv_stride):
output_seq_length = (output_seq_length - (kernel - 1)) / stride
self.output_seq_length = int(math.ceil(output_seq_length))
self.encoder_seq_length = self.output_seq_length
def prepare_config_and_inputs(self):
input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
config = Wav2Vec2Config(
do_stable_layer_norm=self.do_stable_layer_norm,
hidden_size=self.hidden_size,
feat_extract_norm=self.feat_extract_norm,
feat_extract_dropout=self.feat_extract_dropout,
feat_extract_activation=self.feat_extract_activation,
conv_dim=self.conv_dim,
conv_stride=self.conv_stride,
conv_kernel=self.conv_kernel,
conv_bias=self.conv_bias,
num_conv_pos_embeddings=self.num_conv_pos_embeddings,
num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
hidden_dropout_prob=self.hidden_dropout_prob,
intermediate_size=self.intermediate_size,
layer_norm_eps=self.layer_norm_eps,
hidden_act=self.hidden_act,
initializer_range=self.initializer_range,
vocab_size=self.vocab_size,
)
return config, input_values, attention_mask
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_values, attention_mask = config_and_inputs
inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
return config, inputs_dict
@require_flax
class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (
(FlaxWav2Vec2Model, FlaxWav2Vec2ForCTC, FlaxWav2Vec2ForPreTraining) if is_flax_available() else ()
)
def setUp(self):
self.model_tester = FlaxWav2Vec2ModelTester(self)
def test_train(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_values = inputs_dict["input_values"]
attention_mask = inputs_dict["attention_mask"]
model = FlaxWav2Vec2ForPreTraining(config)
features_shape = (
input_values.shape[0],
model._get_feat_extract_output_lengths(np.array(input_values.shape[1])),
)
batch_size, sequence_length = features_shape[:2]
mask_prob = 0.5
mask_length = 4
mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
dropout_rng, gumbel_rng = jax.random.split(jax.random.PRNGKey(0))
output = model(
input_values,
attention_mask=attention_mask,
mask_time_indices=mask_time_indices,
train=True,
dropout_rng=dropout_rng,
gumbel_rng=gumbel_rng,
)[0]
self.assertTrue(output.shape == (batch_size, sequence_length, model.config.proj_codevector_dim))
# overwrite because of `input_values`
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.__call__)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["input_values", "attention_mask"]
self.assertListEqual(arg_names[:2], expected_arg_names)
@slow
# overwrite because of `input_values`
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
@jax.jit
def model_jitted(input_values, attention_mask=None, **kwargs):
return model(input_values=input_values, attention_mask=attention_mask, **kwargs)
with self.subTest("JIT Enabled"):
jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()
with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = model_jitted(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", from_pt=True)
outputs = model(np.ones((1, 1024), dtype="f4"))
self.assertIsNotNone(outputs)
@require_flax
class FlaxWav2Vec2UtilsTest(unittest.TestCase):
def test_compute_mask_indices(self):
batch_size = 4
sequence_length = 60
mask_prob = 0.5
mask_length = 1
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
def test_compute_mask_indices_overlap(self):
batch_size = 4
sequence_length = 80
mask_prob = 0.5
mask_length = 4
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
# because of overlap mask don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
for batch_sum in mask.sum(axis=-1):
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
def test_compute_perplexity(self):
probs = np.arange(100).reshape(2, 5, 10) / 100
ppl = FlaxWav2Vec2GumbelVectorQuantizer._compute_perplexity(probs)
self.assertTrue(abs(ppl.item() - 141.4291) < 1e-3)
# mask half of the input
mask = np.ones((2,), dtype=np.bool)
mask[0] = 0
ppl = FlaxWav2Vec2GumbelVectorQuantizer._compute_perplexity(probs, mask)
self.assertTrue(abs(ppl.item() - 58.6757) < 1e-3)
def test_sample_negatives(self):
batch_size = 2
sequence_length = 10
hidden_size = 4
num_negatives = 3
features = (np.arange(sequence_length * hidden_size) // hidden_size).reshape(
sequence_length, hidden_size
) # each value in vector consits of same value
features = np.broadcast_to(features[None, :], (batch_size, sequence_length, hidden_size))
negative_indices = _sample_negative_indices(features.shape, num_negatives)
features = features.reshape(-1, hidden_size) # BTC => (BxT)C
# take negative vectors from sampled indices
sampled_negatives = features[negative_indices.reshape(-1)]
negatives = sampled_negatives.reshape(batch_size, sequence_length, num_negatives, hidden_size).transpose(
2, 0, 1, 3
)
self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
# make sure no negatively sampled vector is actually a positive one
for negative in negatives:
self.assertTrue(((negative - features.reshape(negative.shape)) == 0).sum() == 0.0)
# make sure that full vectors are sampled and not values of vectors
# => this means that `unique()` yields a single value for `hidden_size` dim
self.assertTrue(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1))
@require_flax
@require_datasets
@require_soundfile
@slow
class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
def _load_datasamples(self, num_samples):
from datasets import load_dataset
import soundfile as sf
ids = [f"1272-141231-000{i}" for i in range(num_samples)]
# map files to raw
def map_to_array(batch):
speech, _ = sf.read(batch["file"])
batch["speech"] = speech
return batch
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
ds = ds.filter(lambda x: x["id"] in ids).sort("id").map(map_to_array)
return ds["speech"][:num_samples]
def test_inference_ctc_robust_batched(self):
model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", from_pt=True)
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)
input_speech = self._load_datasamples(4)
inputs = processor(input_speech, return_tensors="pt", padding=True, truncation=True)
input_values = inputs.input_values
attention_mask = inputs.attention_mask
logits = model(input_values, attention_mask=attention_mask).logits
predicted_ids = jnp.argmax(logits, axis=-1)
predicted_trans = processor.batch_decode(predicted_ids)
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
"the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about",
"his instant panic was followed by a small sharp blow high on his chest",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
def test_inference_pretrained(self):
model = FlaxWav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-large-lv60", from_pt=True)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
"facebook/wav2vec2-large-lv60", return_attention_mask=True
)
input_speech = self._load_datasamples(2)
inputs_dict = feature_extractor(input_speech, return_tensors="np", padding=True)
features_shape = (
inputs_dict["input_values"].shape[0],
model._get_feat_extract_output_lengths(np.array(inputs_dict["input_values"].shape[1])),
)
mask_time_indices = _compute_mask_indices(
features_shape,
model.config.mask_time_prob,
model.config.mask_time_length,
min_masks=2,
)
outputs = model(
inputs_dict.input_values,
attention_mask=inputs_dict.attention_mask,
mask_time_indices=mask_time_indices,
)
# compute cosine similarity
cosine_sim = optax.cosine_similarity(
outputs.projected_states, outputs.projected_quantized_states, epsilon=1e-8
)
# retrieve cosine sim of masked features
cosine_sim_masked = cosine_sim[mask_time_indices]
# ... now compare to randomly initialized model
config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-large-lv60")
model_rand = FlaxWav2Vec2ForPreTraining(config)
outputs_rand = model_rand(
inputs_dict.input_values,
attention_mask=inputs_dict.attention_mask,
mask_time_indices=mask_time_indices,
)
# compute cosine similarity
cosine_sim_rand = optax.cosine_similarity(
outputs_rand.projected_states, outputs_rand.projected_quantized_states
)
# retrieve cosine sim of masked features
cosine_sim_masked_rand = cosine_sim_rand[mask_time_indices]
# a pretrained wav2vec2 model has learned to predict the quantized latent states
# => the cosine similarity between quantized states and predicted states > 0.5
# a random wav2vec2 model has not learned to predict the quantized latent states
# => the cosine similarity between quantized states and predicted states is very likely < 0.1
self.assertTrue(cosine_sim_masked.mean().item() - 5 * cosine_sim_masked_rand.mean().item() > 0)
......@@ -102,6 +102,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"CLIPVisionModel",
"FlaxCLIPTextModel",
"FlaxCLIPVisionModel",
"FlaxWav2Vec2ForCTC",
"DetrForSegmentation",
"DPRReader",
"FlaubertForQuestionAnswering",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment