"test/vscode:/vscode.git/clone" did not exist on "e4b6133b785a48918e326d79e26b61a33036b84d"
Unverified Commit d2cdefb9 authored by Yoach Lacombe's avatar Yoach Lacombe Committed by GitHub
Browse files

Add new meta w2v2-conformer BERT-like model (#28165)



* first commit

* correct default value non causal

* update config and modeling code

* update converting checkpoint

* clean modeling and fix tests

* make style

* add new config parameters to docstring

* fix copied from statements

* Apply suggestions from code review
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* make position_embeddings_type docstrings clearer

* clean converting script

* remove function not used

* clean modeling file

* apply suggestion for test file + add convert script to not_doctested

* modify tests according to review - cleaner logic and more tests

* Apply nit suggestions from code review
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add checker of valid position embeddings type

* instantiate new layer norm layer with the right eps

* fix freeze_feature_encoder since it can be None in some cases

* add test same output in convert script

* restore wav2vec2conformer and add new model

* create processor and FE + clean

* add new model code

* fix convert script and set default config parameters

* correct model id paths

* make style

* make fix-copies and cleaning files

* fix copied from statements

* complete .md and fixe copies

* clean convert script argument defaults

* fix config parameters docstrings

* fix config docstring

* add copied from and enrich FE tests

* fix copied from and repo-consistency

* add autotokenizer

* make test input length shorter and change docstring code

* fix docstrings and copied from

* add add_adapter to ASR training example

* make testing of adapters more robust

* adapt to multi adapter layers

* refactor input_values->input_features and remove w2v2-bert feature extractor

* remove pretraining model

* remove depreciated features and useless lines

* add copied from and ignore statements to modeling tests

* remove pretraining model #2

* change import in convert script

* change default in convert script

* update readme and remove useless line

* Update tests/models/wav2vec2_bert/test_processor_wav2vec2_bert.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* refactor BERT to Bert for consistency

* remove useless ignore copy statement

* add persistent to buffer in rotary

* add eps in LayerNorm init and remove copied from

* add adapter activation parameters and add copied from statements

* Fix copied statements and add unitest.skip reasons

* add copied statement in test_processor

* refactor processor

* make style

* replace numpy random by torch rand

* remove expected output CTC

* improve converting script with processor class

* Apply suggestions from code review
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* remove gumbel class

* remove tests related to previously deleted class

* Update src/transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* correct typos

* remove uused parameters

* update processor to takes both text and audio

* update checkpoints

* update expected output and add ctc expected output

* add label_attention_mask

* replace pt with np in processor tests

* fix typo

* revert to behaviour with labels_attention_mask

---------
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 5d8eb93e
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
"configuration_wav2vec2_bert": [
"WAV2VEC2_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"Wav2Vec2BertConfig",
],
"processing_wav2vec2_bert": ["Wav2Vec2BertProcessor"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_wav2vec2_bert"] = [
"WAV2VEC2_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"Wav2Vec2BertForAudioFrameClassification",
"Wav2Vec2BertForCTC",
"Wav2Vec2BertForSequenceClassification",
"Wav2Vec2BertForXVector",
"Wav2Vec2BertModel",
"Wav2Vec2BertPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_wav2vec2_bert import (
WAV2VEC2_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
Wav2Vec2BertConfig,
)
from .processing_wav2vec2_bert import Wav2Vec2BertProcessor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_wav2vec2_bert import (
WAV2VEC2_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
Wav2Vec2BertForAudioFrameClassification,
Wav2Vec2BertForCTC,
Wav2Vec2BertForSequenceClassification,
Wav2Vec2BertForXVector,
Wav2Vec2BertModel,
Wav2Vec2BertPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
# coding=utf-8
# Copyright 2024 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Wav2Vec2Bert model configuration"""
import functools
import operator
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
WAV2VEC2_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/w2v-bert-2.0": "https://huggingface.co/facebook/w2v-bert-2.0/resolve/main/config.json",
}
class Wav2Vec2BertConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Wav2Vec2BertModel`]. It is used to
instantiate an Wav2Vec2Bert model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the Wav2Vec2Bert
[facebook/wav2vec2-bert-rel-pos-large](https://huggingface.co/facebook/wav2vec2-bert-rel-pos-large)
architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*):
Vocabulary size of the Wav2Vec2Bert model. Defines the number of different tokens that can be
represented by the `inputs_ids` passed when calling [`Wav2Vec2BertModel`]. Vocabulary size of the
model. Defines the different tokens that can be represented by the *inputs_ids* passed to the forward
method of [`Wav2Vec2BertModel`].
hidden_size (`int`, *optional*, defaults to 1024):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
feature_projection_input_dim (`int`, *optional*, defaults to 160):
Input dimension of this model, i.e the dimension after processing input audios with [`SeamlessM4TFeatureExtractor`] or [`Wav2Vec2BertProcessor`].
hidden_act (`str` or `function`, *optional*, defaults to `"swish"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"`, `"swish"` and `"gelu_new"` are supported.
hidden_dropout (`float`, *optional*, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
activation_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for activations inside the fully connected layer.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
feat_proj_dropout (`float`, *optional*, defaults to 0.0):
The dropout probabilitiy for the feature projection.
final_dropout (`float`, *optional*, defaults to 0.1):
The dropout probability for the final projection layer of [`Wav2Vec2BertForCTC`].
layerdrop (`float`, *optional*, defaults to 0.1):
The LayerDrop probability. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) for more
details.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the layer normalization layers.
apply_spec_augment (`bool`, *optional*, defaults to `True`):
Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
[SpecAugment: A Simple Data Augmentation Method for Automatic Speech
Recognition](https://arxiv.org/abs/1904.08779).
mask_time_prob (`float`, *optional*, defaults to 0.05):
Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
procecure generates `mask_time_prob*len(time_axis)/mask_time_length ``independent masks over the axis. If
reasoning from the propability of each feature vector to be chosen as the start of the vector span to be
masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
mask_time_length (`int`, *optional*, defaults to 10):
Length of vector span along the time axis.
mask_time_min_masks (`int`, *optional*, defaults to 2):
The minimum number of masks of length `mask_feature_length` generated along the time axis, each time step,
irrespectively of `mask_feature_prob`. Only relevant if `mask_time_prob*len(time_axis)/mask_time_length <
mask_time_min_masks`.
mask_feature_prob (`float`, *optional*, defaults to 0.0):
Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
masking procecure generates `mask_feature_prob*len(feature_axis)/mask_time_length` independent masks over
the axis. If reasoning from the propability of each feature vector to be chosen as the start of the vector
span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
True`.
mask_feature_length (`int`, *optional*, defaults to 10):
Length of vector span along the feature axis.
mask_feature_min_masks (`int`, *optional*, defaults to 0):
The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
step, irrespectively of `mask_feature_prob`. Only relevant if
`mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`.
ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
instance of [`Wav2Vec2BertForCTC`].
ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
of [`Wav2Vec2BertForCTC`].
use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
instance of [`Wav2Vec2BertForSequenceClassification`].
classifier_proj_size (`int`, *optional*, defaults to 768):
Dimensionality of the projection before token mean-pooling for classification.
tdnn_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
tdnn_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
*XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
tdnn_dilation (`Tuple[int]` or `List[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
*XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
xvector_output_dim (`int`, *optional*, defaults to 512):
Dimensionality of the *XVector* embedding vectors.
pad_token_id (`int`, *optional*, defaults to 0): The id of the _beginning-of-stream_ token.
bos_token_id (`int`, *optional*, defaults to 1): The id of the _padding_ token.
eos_token_id (`int`, *optional*, defaults to 2): The id of the _end-of-stream_ token.
add_adapter (`bool`, *optional*, defaults to `False`):
Whether a convolutional attention network should be stacked on top of the Wav2Vec2Bert Encoder. Can be very
useful for warm-starting Wav2Vec2Bert for SpeechEncoderDecoder models.
adapter_kernel_size (`int`, *optional*, defaults to 3):
Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
adapter_stride (`int`, *optional*, defaults to 2):
Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
num_adapter_layers (`int`, *optional*, defaults to 1):
Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
True`.
adapter_act (`str` or `function`, *optional*, defaults to `"relu"`):
The non-linear activation function (function or string) in the adapter layers. If string, `"gelu"`,
`"relu"`, `"selu"`, `"swish"` and `"gelu_new"` are supported.
use_intermediate_ffn_before_adapter (`bool`, *optional*, defaults to `False`):
Whether an intermediate feed-forward block should be stacked on top of the Wav2Vec2Bert Encoder and before the adapter network.
Only relevant if `add_adapter is True`.
output_hidden_size (`int`, *optional*):
Dimensionality of the encoder output layer. If not defined, this defaults to *hidden-size*. Only relevant
if `add_adapter is True`.
position_embeddings_type (`str`, *optional*, defaults to `"relative_key"`):
Can be specified to :
- `rotary`, for rotary position embeddings.
- `relative`, for relative position embeddings.
- `relative_key`, for relative position embeddings as defined by Shaw in [Self-Attention
with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
If left to `None`, no relative position embeddings is applied.
rotary_embedding_base (`int`, *optional*, defaults to 10000):
If `"rotary"` position embeddings are used, defines the size of the embedding base.
max_source_positions (`int`, *optional*, defaults to 5000):
if `"relative"` position embeddings are used, defines the maximum source input positions.
left_max_position_embeddings (`int`, *optional*, defaults to 64):
If `"relative_key"` (aka Shaw) position embeddings are used, defines the left clipping value for relative positions.
right_max_position_embeddings (`int`, *optional*, defaults to 8):
If `"relative_key"` (aka Shaw) position embeddings are used, defines the right clipping value for relative positions.
conv_depthwise_kernel_size (`int`, *optional*, defaults to 31):
Kernel size of convolutional depthwise 1D layer in Conformer blocks.
conformer_conv_dropout (`float`, *optional*, defaults to 0.1):
The dropout probability for all convolutional layers in Conformer blocks.
Example:
```python
>>> from transformers import Wav2Vec2BertConfig, Wav2Vec2BertModel
>>> # Initializing a Wav2Vec2Bert facebook/wav2vec2-bert-rel-pos-large style configuration
>>> configuration = Wav2Vec2BertConfig()
>>> # Initializing a model (with random weights) from the facebook/wav2vec2-bert-rel-pos-large style configuration
>>> model = Wav2Vec2BertModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "wav2vec2-bert"
def __init__(
self,
vocab_size=None,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=4096,
feature_projection_input_dim=160,
hidden_act="swish",
hidden_dropout=0.0,
activation_dropout=0.0,
attention_dropout=0.0,
feat_proj_dropout=0.0,
final_dropout=0.1,
layerdrop=0.1,
initializer_range=0.02,
layer_norm_eps=1e-5,
apply_spec_augment=True,
mask_time_prob=0.05,
mask_time_length=10,
mask_time_min_masks=2,
mask_feature_prob=0.0,
mask_feature_length=10,
mask_feature_min_masks=0,
ctc_loss_reduction="sum",
ctc_zero_infinity=False,
use_weighted_layer_sum=False,
classifier_proj_size=768,
tdnn_dim=(512, 512, 512, 512, 1500),
tdnn_kernel=(5, 3, 3, 1, 1),
tdnn_dilation=(1, 2, 3, 1, 1),
xvector_output_dim=512,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
add_adapter=False,
adapter_kernel_size=3,
adapter_stride=2,
num_adapter_layers=1,
adapter_act="relu",
use_intermediate_ffn_before_adapter=False,
output_hidden_size=None,
position_embeddings_type="relative_key",
rotary_embedding_base=10000,
max_source_positions=5000,
left_max_position_embeddings=64,
right_max_position_embeddings=8,
conv_depthwise_kernel_size=31,
conformer_conv_dropout=0.1,
**kwargs,
):
super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.num_attention_heads = num_attention_heads
self.feature_projection_input_dim = feature_projection_input_dim
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
self.feat_proj_dropout = feat_proj_dropout
self.final_dropout = final_dropout
self.layerdrop = layerdrop
self.layer_norm_eps = layer_norm_eps
self.initializer_range = initializer_range
self.vocab_size = vocab_size
self.use_weighted_layer_sum = use_weighted_layer_sum
self.max_source_positions = max_source_positions
if position_embeddings_type is not None and position_embeddings_type not in [
"rotary",
"relative",
"relative_key",
]:
raise ValueError(
"""
`position_embeddings_type` is not valid. It must be one of the following values:
`["rotary", "relative", "relative_key"]` or left as `None`.
"""
)
self.position_embeddings_type = position_embeddings_type
self.rotary_embedding_base = rotary_embedding_base
self.left_max_position_embeddings = left_max_position_embeddings
self.right_max_position_embeddings = right_max_position_embeddings
# Conformer-block related
self.conv_depthwise_kernel_size = conv_depthwise_kernel_size
self.conformer_conv_dropout = conformer_conv_dropout
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
self.apply_spec_augment = apply_spec_augment
self.mask_time_prob = mask_time_prob
self.mask_time_length = mask_time_length
self.mask_time_min_masks = mask_time_min_masks
self.mask_feature_prob = mask_feature_prob
self.mask_feature_length = mask_feature_length
self.mask_feature_min_masks = mask_feature_min_masks
# ctc loss
self.ctc_loss_reduction = ctc_loss_reduction
self.ctc_zero_infinity = ctc_zero_infinity
# adapter
self.add_adapter = add_adapter
self.adapter_kernel_size = adapter_kernel_size
self.adapter_stride = adapter_stride
self.num_adapter_layers = num_adapter_layers
self.adapter_act = adapter_act
self.output_hidden_size = output_hidden_size if output_hidden_size is not None else hidden_size
if use_intermediate_ffn_before_adapter and not add_adapter:
raise ValueError("`use_intermediate_ffn_before_adapter` is `True` but `add_adapter` is `False`.")
self.use_intermediate_ffn_before_adapter = use_intermediate_ffn_before_adapter
# SequenceClassification-specific parameter. Feel free to ignore for other classes.
self.classifier_proj_size = classifier_proj_size
# XVector-specific parameters. Feel free to ignore for other classes.
self.tdnn_dim = list(tdnn_dim)
self.tdnn_kernel = list(tdnn_kernel)
self.tdnn_dilation = list(tdnn_dilation)
self.xvector_output_dim = xvector_output_dim
@property
def inputs_to_logits_ratio(self):
return functools.reduce(operator.mul, self.conv_stride, 1)
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Wav2Vec2Bert BERT checkpoint."""
import argparse
import torch
import torchaudio
from fairseq2.data import Collater
from fairseq2.data.audio import WaveformToFbankConverter
from fairseq2.nn.padding import get_seqs_and_padding_mask
from seamless_communication.models.conformer_shaw import load_conformer_shaw_model
from transformers import (
SeamlessM4TFeatureExtractor,
Wav2Vec2BertConfig,
Wav2Vec2BertModel,
logging,
)
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
wav2vec_convert_list = [
("encoder_frontend.model_dim_proj", "feature_projection.projection"),
("encoder_frontend.post_extract_layer_norm", "feature_projection.layer_norm"),
("encoder_frontend.pos_encoder.conv", "encoder.pos_conv_embed.conv"),
("encoder.inner.layers", "encoder.layers"),
("encoder.inner_layer_norm", "encoder.layer_norm"),
("encoder.adaptor_layers", "adapter.layers"),
("inner_proj", "intermediate_dense"),
("self_attn.output_proj", "self_attn.linear_out"),
("output_proj", "output_dense"),
("self_attn.k_proj", "self_attn.linear_k"),
("self_attn.v_proj", "self_attn.linear_v"),
("self_attn.q_proj", "self_attn.linear_q"),
("self_attn.sdpa.u_bias", "self_attn.pos_bias_u"),
("self_attn.sdpa.v_bias", "self_attn.pos_bias_v"),
("self_attn.sdpa.rel_k_embed", "self_attn.distance_embedding"),
("self_attn.sdpa.r_proj", "self_attn.linear_pos"),
("conv.pointwise_conv1", "conv_module.pointwise_conv1"),
("conv.pointwise_conv2", "conv_module.pointwise_conv2"),
("conv.depthwise_conv", "conv_module.depthwise_conv"),
("conv.layer_norm", "conv_module.depthwise_layer_norm"),
("conv_layer_norm", "conv_module.layer_norm"),
("encoder.proj1", "intermediate_ffn.intermediate_dense"),
("encoder.proj2", "intermediate_ffn.output_dense"),
("encoder.layer_norm", "inner_layer_norm"),
("masker.temporal_mask_embed", "masked_spec_embed"),
]
keys_to_remove = {
"quantizer.entry_proj",
"final_proj",
"final_target_proj",
"quantizer.entries",
"quantizer.num_updates",
}
def param_count(model):
return sum(p[1].numel() for p in model.named_parameters() if "final_proj" not in p[0])
def _convert_model(
original_model,
hf_model,
convert_list,
):
state_dict = original_model.state_dict()
for k, v in list(state_dict.items()):
new_key = k
for old_layer_name, new_layer_name in convert_list:
if old_layer_name in new_key:
new_key = new_key.replace(old_layer_name, new_layer_name)
# must do it by hand
if ".layer_norm" in new_key and new_key.split(".layer_norm")[0][-1].isnumeric():
new_key = new_key.replace("layer_norm", "final_layer_norm")
add_key = True
for key in keys_to_remove:
if key in new_key:
state_dict.pop(k)
add_key = False
break
if add_key:
state_dict[new_key] = state_dict.pop(k)
extra_keys = set(state_dict.keys()) - set(hf_model.state_dict().keys())
extra_keys = set({k for k in extra_keys if "num_updates" not in k}) # filter unecessary param
missing_keys = set(hf_model.state_dict().keys()) - set(state_dict.keys())
if len(extra_keys) != 0:
raise ValueError(f"extra keys found: {extra_keys}")
if len(missing_keys) != 0:
raise ValueError(f"missing keys: {missing_keys}")
hf_model.load_state_dict(state_dict, strict=True)
n_params = param_count(hf_model)
logger.info(f"model loaded: {round(n_params/1e6,1)}M params")
hf_model.eval()
del state_dict
return hf_model
@torch.no_grad()
def convert_wav2vec2_bert_checkpoint(
checkpoint_path,
pytorch_dump_folder_path,
config_path=None,
repo_id=None,
):
"""
Copy/paste/tweak model's weights to transformers design.
"""
if config_path is not None:
config = Wav2Vec2BertConfig.from_pretrained(config_path, hidden_act="swish")
else:
config = Wav2Vec2BertConfig(apply_spec_augment=False)
hf_wav2vec = Wav2Vec2BertModel(config)
model = load_conformer_shaw_model(checkpoint_path, dtype=torch.float32)
model.eval()
hf_wav2vec = _convert_model(model, hf_wav2vec, wav2vec_convert_list)
hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
if repo_id:
hf_wav2vec.push_to_hub(repo_id, create_pr=True)
# save feature extractor
fe = SeamlessM4TFeatureExtractor(padding_value=1)
fe._set_processor_class("Wav2Vec2BertProcessor")
fe.save_pretrained(pytorch_dump_folder_path)
if repo_id:
fe.push_to_hub(repo_id, create_pr=True)
if args.audio_path:
waveform, sample_rate = torchaudio.load(args.audio_path)
waveform = torchaudio.functional.resample(waveform, sample_rate, fe.sampling_rate)
fbank_converter = WaveformToFbankConverter(
num_mel_bins=80,
waveform_scale=2**15,
channel_last=True,
standardize=True,
dtype=torch.float32,
)
collater = Collater(pad_value=1)
decoded_audio = {"waveform": waveform.T, "sample_rate": fe.sampling_rate, "format": -1}
src = collater(fbank_converter(decoded_audio))["fbank"]
seqs, padding_mask = get_seqs_and_padding_mask(src)
with torch.inference_mode():
seqs, padding_mask = model.encoder_frontend(seqs, padding_mask)
original_output, padding_mask = model.encoder(seqs, padding_mask)
hf_wav2vec.eval()
inputs = fe(waveform, return_tensors="pt", padding=True)
with torch.no_grad():
outputs = hf_wav2vec(**inputs)
torch.testing.assert_close(original_output, outputs.last_hidden_state, atol=5e-3, rtol=5e-3)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
help="Path to the output PyTorch model.",
)
parser.add_argument(
"--checkpoint_path", default="conformer_shaw", type=str, help="Path to seamless communication checkpoint"
)
parser.add_argument(
"--config_path",
default=None,
type=str,
help="Path to hf config.json of model to convert",
)
parser.add_argument("--repo_id", default=None, type=str, help="Push to this repo id if precised.")
parser.add_argument(
"--audio_path",
default=None,
type=str,
help="If specified, check that the original model and the converted model produce the same outputs.",
)
args = parser.parse_args()
convert_wav2vec2_bert_checkpoint(
args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.repo_id
)
# coding=utf-8
# Copyright 2024 The Seamless Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Wav2Vec2-BERT model."""
import math
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
Wav2Vec2BaseModelOutput,
XVectorOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
)
from .configuration_wav2vec2_bert import Wav2Vec2BertConfig
logger = logging.get_logger(__name__)
_HIDDEN_STATES_START_POSITION = 2
# General docstring
_CONFIG_FOR_DOC = "Wav2Vec2BertConfig"
# Base docstring
_BASE_CHECKPOINT_FOR_DOC = "facebook/w2v-bert-2.0"
_PRETRAINED_CHECKPOINT_FOR_DOC = "hf-audio/wav2vec2-bert-CV16-en"
_EXPECTED_OUTPUT_SHAPE = [1, 146, 1024]
# CTC docstring
_CTC_EXPECTED_OUTPUT = "'mr quilter is the apostle of the middle classes and we are glad to welcome his gospel'"
_CTC_EXPECTED_LOSS = 17.04
WAV2VEC2_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/w2v-bert-2.0",
# See all Wav2Vec2-BERT models at https://huggingface.co/models?filter=wav2vec2-bert
]
# Copied from transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2._compute_new_attention_mask
def _compute_new_attention_mask(hidden_states: torch.Tensor, seq_lens: torch.Tensor):
"""
Computes an attention mask of the form `(batch, seq_len)` with an attention for each element in the batch that
stops at the corresponding element in `seq_lens`.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch, seq_len, *)`):
The sequences to mask, where `*` is any number of sequence-specific dimensions including none.
seq_lens (`torch.Tensor` of shape `(batch)`:
Each element represents the length of the sequence at the same index in `hidden_states`
Returns:
`torch.FloatTensor`: The float attention mask of shape `(batch, seq_len)`
"""
batch_size, mask_seq_len = hidden_states.shape[:2]
indices = torch.arange(mask_seq_len, device=seq_lens.device).expand(batch_size, -1)
bool_mask = indices >= seq_lens.unsqueeze(1).expand(-1, mask_seq_len)
mask = hidden_states.new_ones((batch_size, mask_seq_len))
mask = mask.masked_fill(bool_mask, 0)
return mask
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.
Args:
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
the first element is the batch size and the second element is the length of the axis to span.
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
independently generated mask spans of length `mask_length` is computed by
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
actual percentage will be smaller.
mask_length: size of the mask
min_masks: minimum number of masked spans
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
each batch dimension.
"""
batch_size, sequence_length = shape
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
f" and `sequence_length`: {sequence_length}`"
)
# epsilon is used for probabilistic rounding
epsilon = np.random.rand(1).item()
def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# make sure num masked span <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length
# make sure num_masked span is also <= input_length - (mask_length - 1)
if input_length - (mask_length - 1) < num_masked_span:
num_masked_span = max(input_length - (mask_length - 1), 0)
return num_masked_span
# compute number of masked spans in batch
input_lengths = (
attention_mask.sum(-1).detach().tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
# SpecAugment mask to fill
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
spec_aug_mask_idxs = []
max_num_masked_span = compute_num_masked_span(sequence_length)
if max_num_masked_span == 0:
return spec_aug_mask
for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)
# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
# pick first sampled index that will serve as a dummy index to pad vector
# to ensure same dimension for all batches due to probabilistic rounding
# Picking first sample just pads those vectors twice.
if len(spec_aug_mask_idx) == 0:
# this case can only happen if `input_length` is strictly smaller then
# `sequence_length` in which case the last token has to be a padding
# token which we can use as a dummy mask id
dummy_mask_idx = sequence_length - 1
else:
dummy_mask_idx = spec_aug_mask_idx[0]
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# expand masked indices to masked spans
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
# add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# ensure that we cannot have indices larger than sequence_length
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
return spec_aug_mask
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices
def _sample_negative_indices(
features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
):
"""
Sample `num_negatives` vectors from feature vectors.
"""
batch_size, sequence_length = features_shape
# generate indices of the positive vectors themselves, repeat them `num_negatives` times
sequence_length_range = np.arange(sequence_length)
# get `num_negatives` random vector indices from the same utterance
sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
mask_time_indices = (
mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool)
)
for batch_idx in range(batch_size):
high = mask_time_indices[batch_idx].sum() - 1
mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
# avoid sampling the same positive vector, but keep the distribution uniform
sampled_indices[sampled_indices >= feature_indices] += 1
# remap to actual indices
sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
# correct for batch size
sampled_negative_indices[batch_idx] += batch_idx * sequence_length
return sampled_negative_indices
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerRotaryPositionalEmbedding with Wav2Vec2Conformer->Wav2Vec2Bert
class Wav2Vec2BertRotaryPositionalEmbedding(nn.Module):
"""Rotary positional embedding
Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
"""
def __init__(self, config):
super().__init__()
dim = config.hidden_size // config.num_attention_heads
base = config.rotary_embedding_base
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
# Ignore copy
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.cached_sequence_length = None
self.cached_rotary_positional_embedding = None
def forward(self, hidden_states):
sequence_length = hidden_states.shape[1]
if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
return self.cached_rotary_positional_embedding
self.cached_sequence_length = sequence_length
# Embeddings are computed in the dtype of the inv_freq constant
time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
embeddings = torch.cat((freqs, freqs), dim=-1)
cos_embeddings = embeddings.cos()[:, None, None, :]
sin_embeddings = embeddings.sin()[:, None, None, :]
# Computed embeddings are cast to the dtype of the hidden state inputs
self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings]).type_as(hidden_states)
return self.cached_rotary_positional_embedding
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerRelPositionalEmbedding with Wav2Vec2Conformer->Wav2Vec2Bert
class Wav2Vec2BertRelPositionalEmbedding(nn.Module):
"""Relative positional encoding module."""
def __init__(self, config):
super().__init__()
self.max_len = config.max_source_positions
self.d_model = config.hidden_size
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
def extend_pe(self, x):
# Reset the positional encodings
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` is the position of query vector and `j` is the
# position of key vector. We use positive relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
# Reverse the order of positive indices and concat both positive and
# negative indices. This is used to support the shifting trick
# as in https://arxiv.org/abs/1901.02860
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, hidden_states: torch.Tensor):
self.extend_pe(hidden_states)
start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
relative_position_embeddings = self.pe[:, start_idx:end_idx]
return relative_position_embeddings
class Wav2Vec2BertFeatureProjection(nn.Module):
def __init__(self, config):
super().__init__()
self.layer_norm = nn.LayerNorm(config.feature_projection_input_dim, eps=config.layer_norm_eps)
self.projection = nn.Linear(config.feature_projection_input_dim, config.hidden_size)
self.dropout = nn.Dropout(config.feat_proj_dropout)
def forward(self, hidden_states):
# non-projected hidden states are needed for quantization
norm_hidden_states = self.layer_norm(hidden_states)
hidden_states = self.projection(norm_hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states, norm_hidden_states
class Wav2Vec2BertFeedForward(nn.Module):
def __init__(self, config, act_fn=None, hidden_size=None):
super().__init__()
act_fn = act_fn if act_fn is not None else config.hidden_act
hidden_size = hidden_size if hidden_size is not None else config.hidden_size
self.intermediate_dropout = nn.Dropout(config.activation_dropout)
self.intermediate_dense = nn.Linear(hidden_size, config.intermediate_size)
self.intermediate_act_fn = ACT2FN[act_fn] if isinstance(act_fn, str) else act_fn
self.output_dense = nn.Linear(config.intermediate_size, hidden_size)
self.output_dropout = nn.Dropout(config.hidden_dropout)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward.forward
def forward(self, hidden_states):
hidden_states = self.intermediate_dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
hidden_states = self.intermediate_dropout(hidden_states)
hidden_states = self.output_dense(hidden_states)
hidden_states = self.output_dropout(hidden_states)
return hidden_states
class Wav2Vec2BertConvolutionModule(nn.Module):
"""Convolution block used in the conformer block"""
def __init__(self, config):
super().__init__()
if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding")
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.pointwise_conv1 = nn.Conv1d(
config.hidden_size,
2 * config.hidden_size,
kernel_size=1,
stride=1,
padding=0,
bias=False,
)
self.glu = nn.GLU(dim=1)
self.depthwise_conv = nn.Conv1d(
config.hidden_size,
config.hidden_size,
config.conv_depthwise_kernel_size,
stride=1,
padding=0,
groups=config.hidden_size,
bias=False,
)
self.depthwise_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.activation = ACT2FN[config.hidden_act]
self.pointwise_conv2 = nn.Conv1d(
config.hidden_size,
config.hidden_size,
kernel_size=1,
stride=1,
padding=0,
bias=False,
)
self.dropout = nn.Dropout(config.conformer_conv_dropout)
def forward(self, hidden_states, attention_mask=None):
hidden_states = self.layer_norm(hidden_states)
# Ensure that we do not leak padded positions in depthwise convolution if attention mask is passed.
# Put 0 where necessary
if attention_mask is not None:
hidden_states = hidden_states.masked_fill(~attention_mask.bool().unsqueeze(-1), 0.0)
# exchange the temporal dimension and the feature dimension
hidden_states = hidden_states.transpose(1, 2)
# GLU mechanism
# => (batch, 2*channel, dim)
hidden_states = self.pointwise_conv1(hidden_states)
# => (batch, channel, dim)
hidden_states = self.glu(hidden_states)
# Pad the sequence entirely on the left because of causal convolution.
hidden_states = torch.nn.functional.pad(hidden_states, (self.depthwise_conv.kernel_size[0] - 1, 0))
# 1D Depthwise Conv
hidden_states = self.depthwise_conv(hidden_states)
hidden_states = self.depthwise_layer_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
hidden_states = self.activation(hidden_states)
hidden_states = self.pointwise_conv2(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
return hidden_states
class Wav2Vec2BertSelfAttention(nn.Module):
"""Construct an Wav2Vec2BertSelfAttention object.
Can be enhanced with rotary or relative position embeddings.
"""
def __init__(self, config, is_adapter_attention=False):
super().__init__()
hidden_size = config.hidden_size if not is_adapter_attention else config.output_hidden_size
self.head_size = hidden_size // config.num_attention_heads
self.num_heads = config.num_attention_heads
self.position_embeddings_type = config.position_embeddings_type if not is_adapter_attention else None
self.linear_q = nn.Linear(hidden_size, hidden_size)
self.linear_k = nn.Linear(hidden_size, hidden_size)
self.linear_v = nn.Linear(hidden_size, hidden_size)
self.linear_out = nn.Linear(hidden_size, hidden_size)
self.dropout = nn.Dropout(p=config.attention_dropout)
if self.position_embeddings_type == "relative":
# linear transformation for positional encoding
self.linear_pos = nn.Linear(hidden_size, hidden_size, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
if self.position_embeddings_type == "relative_key":
self.left_max_position_embeddings = config.left_max_position_embeddings
self.right_max_position_embeddings = config.right_max_position_embeddings
num_positions = self.left_max_position_embeddings + self.right_max_position_embeddings + 1
self.distance_embedding = nn.Embedding(num_positions, self.head_size)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
relative_position_embeddings: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# self-attention mechanism
batch_size, sequence_length, hidden_size = hidden_states.size()
# make sure query/key states can be != value states
query_key_states = hidden_states
value_states = hidden_states
if self.position_embeddings_type == "rotary":
if relative_position_embeddings is None:
raise ValueError(
"`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'"
)
query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)
# project query_key_states and value_states
query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)
# => (batch, head, time1, d_k)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
if self.position_embeddings_type == "relative":
if relative_position_embeddings is None:
raise ValueError(
"`relative_position_embeddings` has to be defined when `self.position_embeddings_type =="
" 'relative'"
)
# apply relative_position_embeddings to qk scores
# as proposed in Transformer_XL: https://arxiv.org/abs/1901.02860
scores = self._apply_relative_embeddings(
query=query, key=key, relative_position_embeddings=relative_position_embeddings
)
else:
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size)
if self.position_embeddings_type == "relative_key":
query_length, key_length = query.shape[2], key.shape[2]
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_r - position_ids_l
distance = torch.clamp(distance, -self.left_max_position_embeddings, self.right_max_position_embeddings)
positional_embedding = self.distance_embedding(distance + self.left_max_position_embeddings)
positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility
relative_position_attn_weights = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
scores = scores + (relative_position_attn_weights / math.sqrt(self.head_size))
# apply attention_mask if necessary
if attention_mask is not None:
scores = scores + attention_mask
# => (batch, head, time1, time2)
probs = torch.softmax(scores, dim=-1)
probs = self.dropout(probs)
# => (batch, head, time1, d_k)
hidden_states = torch.matmul(probs, value)
# => (batch, time1, hidden_size)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
hidden_states = self.linear_out(hidden_states)
return hidden_states, probs
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerSelfAttention._apply_rotary_embedding
def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
batch_size, sequence_length, hidden_size = hidden_states.size()
hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)
cos = relative_position_embeddings[0, :sequence_length, ...]
sin = relative_position_embeddings[1, :sequence_length, ...]
# rotate hidden_states with rotary embeddings
hidden_states = hidden_states.transpose(0, 1)
rotated_states_begin = hidden_states[..., : self.head_size // 2]
rotated_states_end = hidden_states[..., self.head_size // 2 :]
rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1)
hidden_states = (hidden_states * cos) + (rotated_states * sin)
hidden_states = hidden_states.transpose(0, 1)
hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size)
return hidden_states
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerSelfAttention._apply_relative_embeddings
def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
# 1. project positional embeddings
# => (batch, head, 2*time1-1, d_k)
proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings)
proj_relative_position_embeddings = proj_relative_position_embeddings.view(
relative_position_embeddings.size(0), -1, self.num_heads, self.head_size
)
proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2)
proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3)
# 2. Add bias to query
# => (batch, head, time1, d_k)
query = query.transpose(1, 2)
q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
# 3. attention score: first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# => (batch, head, time1, time2)
scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
# 4. then compute matrix b and matrix d
# => (batch, head, time1, 2*time1-1)
scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings)
# 5. shift matrix b and matrix d
zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype)
scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1)
scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2])
scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape)
scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd)
scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1]
# 6. sum matrices
# => (batch, head, time1, time2)
scores = (scores_ac + scores_bd) / math.sqrt(self.head_size)
return scores
class Wav2Vec2BertEncoderLayer(nn.Module):
"""Conformer block based on https://arxiv.org/abs/2005.08100."""
def __init__(self, config):
super().__init__()
embed_dim = config.hidden_size
dropout = config.attention_dropout
# Feed-forward 1
self.ffn1_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.ffn1 = Wav2Vec2BertFeedForward(config)
# Self-Attention
self.self_attn_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.self_attn_dropout = nn.Dropout(dropout)
self.self_attn = Wav2Vec2BertSelfAttention(config)
# Conformer Convolution
self.conv_module = Wav2Vec2BertConvolutionModule(config)
# Feed-forward 2
self.ffn2_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.ffn2 = Wav2Vec2BertFeedForward(config)
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
def forward(
self,
hidden_states,
attention_mask: Optional[torch.Tensor] = None,
relative_position_embeddings: Optional[torch.Tensor] = None,
output_attentions: bool = False,
conv_attention_mask: Optional[torch.Tensor] = None,
):
hidden_states = hidden_states
# 1. Feed-Forward 1 layer
residual = hidden_states
hidden_states = self.ffn1_layer_norm(hidden_states)
hidden_states = self.ffn1(hidden_states)
hidden_states = hidden_states * 0.5 + residual
residual = hidden_states
# 2. Self-Attention layer
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, attn_weigts = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
relative_position_embeddings=relative_position_embeddings,
output_attentions=output_attentions,
)
hidden_states = self.self_attn_dropout(hidden_states)
hidden_states = hidden_states + residual
# 3. Convolutional Layer
residual = hidden_states
hidden_states = self.conv_module(hidden_states, attention_mask=conv_attention_mask)
hidden_states = residual + hidden_states
# 4. Feed-Forward 2 Layer
residual = hidden_states
hidden_states = self.ffn2_layer_norm(hidden_states)
hidden_states = self.ffn2(hidden_states)
hidden_states = hidden_states * 0.5 + residual
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states, attn_weigts
class Wav2Vec2BertEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
if config.position_embeddings_type == "relative":
self.embed_positions = Wav2Vec2BertRelPositionalEmbedding(config)
elif config.position_embeddings_type == "rotary":
self.embed_positions = Wav2Vec2BertRotaryPositionalEmbedding(config)
else:
self.embed_positions = None
self.dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.ModuleList([Wav2Vec2BertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
conv_attention_mask = attention_mask
if attention_mask is not None:
# make sure padded tokens output 0
hidden_states = hidden_states.masked_fill(~attention_mask.bool().unsqueeze(-1), 0.0)
# extend attention_mask
attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
hidden_states = self.dropout(hidden_states)
if self.embed_positions is not None:
relative_position_embeddings = self.embed_positions(hidden_states)
else:
relative_position_embeddings = None
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
for i, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = torch.rand([])
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
relative_position_embeddings,
output_attentions,
conv_attention_mask,
)
else:
layer_outputs = layer(
hidden_states,
attention_mask=attention_mask,
relative_position_embeddings=relative_position_embeddings,
output_attentions=output_attentions,
conv_attention_mask=conv_attention_mask,
)
hidden_states = layer_outputs[0]
if skip_the_layer:
layer_outputs = (None, None)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class Wav2Vec2BertAdapter(nn.Module):
def __init__(self, config):
super().__init__()
# feature dim might need to be down-projected
if config.output_hidden_size != config.hidden_size:
self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size, eps=config.layer_norm_eps)
else:
self.proj = self.proj_layer_norm = None
self.layers = nn.ModuleList(Wav2Vec2BertAdapterLayer(config) for _ in range(config.num_adapter_layers))
self.layerdrop = config.layerdrop
self.kernel_size = config.adapter_kernel_size
self.stride = config.adapter_stride
def _compute_sub_sample_lengths_from_attention_mask(self, seq_lens):
if seq_lens is None:
return seq_lens
pad = self.kernel_size // 2
seq_lens = ((seq_lens + 2 * pad - self.kernel_size) / self.stride) + 1
return seq_lens.floor()
def forward(self, hidden_states, attention_mask=None):
# down project hidden_states if necessary
if self.proj is not None and self.proj_layer_norm is not None:
hidden_states = self.proj(hidden_states)
hidden_states = self.proj_layer_norm(hidden_states)
sub_sampled_lengths = None
if attention_mask is not None:
sub_sampled_lengths = (attention_mask.size(1) - (1 - attention_mask.int()).sum(1)).to(hidden_states.device)
for layer in self.layers:
layerdrop_prob = torch.rand([])
sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(sub_sampled_lengths)
if not self.training or (layerdrop_prob > self.layerdrop):
hidden_states = layer(
hidden_states, attention_mask=attention_mask, sub_sampled_lengths=sub_sampled_lengths
)
return hidden_states
class Wav2Vec2BertAdapterLayer(nn.Module):
def __init__(self, config):
super().__init__()
embed_dim = config.output_hidden_size
dropout = config.conformer_conv_dropout
self.kernel_size = config.adapter_kernel_size
self.stride = config.adapter_stride
# 1. residual convolution
self.residual_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.residual_conv = nn.Conv1d(
embed_dim,
2 * embed_dim,
self.kernel_size,
stride=self.stride,
padding=self.stride // 2,
)
self.activation = nn.GLU(dim=1)
# Self-Attention
self.self_attn_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.self_attn_conv = nn.Conv1d(
embed_dim,
2 * embed_dim,
self.kernel_size,
stride=self.stride,
padding=self.stride // 2,
)
self.self_attn = Wav2Vec2BertSelfAttention(config, is_adapter_attention=True)
self.self_attn_dropout = nn.Dropout(dropout)
# Feed-forward
self.ffn_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.ffn = Wav2Vec2BertFeedForward(config, act_fn=config.adapter_act, hidden_size=embed_dim)
def forward(
self,
hidden_states,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
sub_sampled_lengths: Optional[torch.Tensor] = None,
):
residual = self.residual_layer_norm(hidden_states)
# Apply pooling to the residual to match the sequence length of the
# multi-head attention output.
# (batch, seq_len, feature_dim) -> (batch, feature_dim, seq_len)
residual = residual.transpose(1, 2)
residual = self.residual_conv(residual)
residual = self.activation(residual)
# (batch, feature_dim, seq_len) -> (batch, seq_len, feature_dim)
residual = residual.transpose(1, 2)
hidden_states = self.self_attn_layer_norm(hidden_states)
# Apply pooling before feeding to the multihead-attention layer.
# (batch, seq_len, feature_dim) -> (batch, feature_dim, seq_len)
hidden_states = hidden_states.transpose(1, 2)
hidden_states = self.self_attn_conv(hidden_states)
hidden_states = self.activation(hidden_states)
# (batch, feature_dim, seq_len) -> (batch, seq_len, feature_dim)
hidden_states = hidden_states.transpose(1, 2)
if attention_mask is not None:
attention_mask = _compute_new_attention_mask(hidden_states=hidden_states, seq_lens=sub_sampled_lengths)
attention_mask = _prepare_4d_attention_mask(
attention_mask,
hidden_states.dtype,
)
# The rest of the computation is identical to a vanilla Transformer
# encoder layer.
hidden_states, attn_weigths = self.self_attn(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
hidden_states = self.self_attn_dropout(hidden_states)
hidden_states = hidden_states + residual
residual = hidden_states
hidden_states = self.ffn_layer_norm(hidden_states)
hidden_states = self.ffn(hidden_states) + residual
return hidden_states
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerPreTrainedModel with Wav2Vec2Conformer->Wav2Vec2Bert,wav2vec2_conformer->wav2vec2_bert, input_values->input_features
class Wav2Vec2BertPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = Wav2Vec2BertConfig
base_model_prefix = "wav2vec2_bert"
main_input_name = "input_features"
supports_gradient_checkpointing = True
# Ignore copy
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, Wav2Vec2BertSelfAttention):
if hasattr(module, "pos_bias_u"):
nn.init.xavier_uniform_(module.pos_bias_u)
if hasattr(module, "pos_bias_v"):
nn.init.xavier_uniform_(module.pos_bias_v)
elif isinstance(module, Wav2Vec2BertFeatureProjection):
k = math.sqrt(1 / module.projection.in_features)
nn.init.uniform_(module.projection.weight, a=-k, b=k)
nn.init.uniform_(module.projection.bias, a=-k, b=k)
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Conv1d):
nn.init.kaiming_normal_(module.weight)
if module.bias is not None:
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
nn.init.uniform_(module.bias, a=-k, b=k)
# Ignore copy
def _get_feat_extract_output_lengths(
self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
):
"""
Computes the output length of the convolutional layers
"""
add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
def _conv_out_length(input_length, kernel_size, stride, padding):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return torch.div(input_length + 2 * padding - kernel_size, stride, rounding_mode="floor") + 1
if add_adapter:
padding = self.config.adapter_kernel_size // 2
for _ in range(self.config.num_adapter_layers):
input_lengths = _conv_out_length(
input_lengths, self.config.adapter_kernel_size, self.config.adapter_stride, padding
)
return input_lengths
def _get_feature_vector_attention_mask(
self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
):
# Effectively attention_mask.sum(-1), but not inplace to be able to run
# on inference mode.
non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
output_lengths = output_lengths.to(torch.long)
batch_size = attention_mask.shape[0]
attention_mask = torch.zeros(
(batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
)
# these two operations makes sure that all values before the output lengths idxs are attended to
attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
return attention_mask
WAV2VEC2_BERT_START_DOCSTRING = r"""
Wav2Vec2Bert was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
Auli.
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving etc.).
This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a
regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
Parameters:
config ([`Wav2Vec2BertConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
WAV2VEC2_BERT_INPUTS_DOCSTRING = r"""
Args:
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_features`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2BertProcessor.__call__`] for details.
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The bare Wav2Vec2Bert Model transformer outputting raw hidden-states without any specific head on top.",
WAV2VEC2_BERT_START_DOCSTRING,
)
class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel):
def __init__(self, config: Wav2Vec2BertConfig):
super().__init__(config)
self.config = config
self.feature_projection = Wav2Vec2BertFeatureProjection(config)
# model only needs masking vector if mask prob is > 0.0
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
self.encoder = Wav2Vec2BertEncoder(config)
self.adapter = Wav2Vec2BertAdapter(config) if config.add_adapter else None
self.intermediate_ffn = None
if config.use_intermediate_ffn_before_adapter:
self.intermediate_ffn = Wav2Vec2BertFeedForward(config, act_fn="relu")
# Initialize weights and apply final processing
self.post_init()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
def _mask_hidden_states(
self,
hidden_states: torch.FloatTensor,
mask_time_indices: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
"""
Masks extracted features along time axis and/or along feature axis according to
[SpecAugment](https://arxiv.org/abs/1904.08779).
"""
# `config.apply_spec_augment` can set masking to False
if not getattr(self.config, "apply_spec_augment", True):
return hidden_states
# generate indices & apply SpecAugment along time axis
batch_size, sequence_length, hidden_size = hidden_states.size()
if mask_time_indices is not None:
# apply SpecAugment along time axis with given mask_time_indices
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
elif self.config.mask_time_prob > 0 and self.training:
mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length),
mask_prob=self.config.mask_time_prob,
mask_length=self.config.mask_time_length,
attention_mask=attention_mask,
min_masks=self.config.mask_time_min_masks,
)
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
if self.config.mask_feature_prob > 0 and self.training:
# generate indices & apply SpecAugment along feature axis
mask_feature_indices = _compute_mask_indices(
(batch_size, hidden_size),
mask_prob=self.config.mask_feature_prob,
mask_length=self.config.mask_feature_length,
min_masks=self.config.mask_feature_min_masks,
)
mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
hidden_states[mask_feature_indices] = 0
return hidden_states
@add_start_docstrings_to_model_forward(WAV2VEC2_BERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_PRETRAINED_CHECKPOINT_FOR_DOC,
output_type=Wav2Vec2BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self,
input_features: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
mask_time_indices: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states, extract_features = self.feature_projection(input_features)
hidden_states = self._mask_hidden_states(
hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
)
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = encoder_outputs[0]
if self.intermediate_ffn:
expanded_hidden_states = self.intermediate_ffn(hidden_states)
hidden_states = hidden_states + 0.5 * expanded_hidden_states
if self.adapter is not None:
hidden_states = self.adapter(hidden_states, attention_mask=attention_mask)
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
return Wav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
@add_start_docstrings(
"""Wav2Vec2Bert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
WAV2VEC2_BERT_START_DOCSTRING,
)
class Wav2Vec2BertForCTC(Wav2Vec2BertPreTrainedModel):
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForCTC.__init__ with Wav2Vec2Conformer->Wav2Vec2Bert,WAV2VEC2_CONFORMER->WAV2VEC2_BERT,wav2vec2_conformer->wav2vec2_bert
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
self.wav2vec2_bert = Wav2Vec2BertModel(config)
self.dropout = nn.Dropout(config.final_dropout)
self.target_lang = target_lang
if config.vocab_size is None:
raise ValueError(
f"You are trying to instantiate {self.__class__} with a configuration that "
"does not define the vocabulary size of the language model head. Please "
"instantiate the model as follows: `Wav2Vec2BertForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
"or define `vocab_size` of your model's configuration."
)
output_hidden_size = (
config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
)
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(WAV2VEC2_BERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_PRETRAINED_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_CTC_EXPECTED_OUTPUT,
expected_loss=_CTC_EXPECTED_LOSS,
)
def forward(
self,
input_features: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.Tensor] = None,
) -> Union[Tuple, CausalLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
config.vocab_size - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.wav2vec2_bert(
input_features,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
if labels.max() >= self.config.vocab_size:
raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
# retrieve loss input_lengths from attention_mask
attention_mask = (
attention_mask
if attention_mask is not None
else torch.ones(input_features.shape[:2], device=input_features.device, dtype=torch.long)
)
input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum([-1])).to(torch.long)
# assuming that padded tokens are filled with -100
# when not being attended to
labels_mask = labels >= 0
target_lengths = labels_mask.sum(-1)
flattened_targets = labels.masked_select(labels_mask)
# ctc_loss doesn't support fp16
log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
with torch.backends.cudnn.flags(enabled=False):
loss = nn.functional.ctc_loss(
log_probs,
flattened_targets,
input_lengths,
target_lengths,
blank=self.config.pad_token_id,
reduction=self.config.ctc_loss_reduction,
zero_infinity=self.config.ctc_zero_infinity,
)
if not return_dict:
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
)
@add_start_docstrings(
"""
Wav2Vec2Bert Model with a sequence classification head on top (a linear layer over the pooled output) for
tasks like SUPERB Keyword Spotting.
""",
WAV2VEC2_BERT_START_DOCSTRING,
)
class Wav2Vec2BertForSequenceClassification(Wav2Vec2BertPreTrainedModel):
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Bert,wav2vec2->wav2vec2_bert
def __init__(self, config):
super().__init__(config)
if hasattr(config, "add_adapter") and config.add_adapter:
raise ValueError(
"Sequence classification does not support the use of Wav2Vec2Bert adapters (config.add_adapter=True)"
)
self.wav2vec2_bert = Wav2Vec2BertModel(config)
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
be updated during training. Only the classification head will be updated.
"""
for param in self.wav2vec2_bert.parameters():
param.requires_grad = False
@add_start_docstrings_to_model_forward(WAV2VEC2_BERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_BASE_CHECKPOINT_FOR_DOC,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Bert,wav2vec2->wav2vec2_bert,WAV_2_VEC_2->WAV2VEC2_BERT, input_values->input_features
def forward(
self,
input_features: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.Tensor] = None,
) -> Union[Tuple, SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
outputs = self.wav2vec2_bert(
input_features,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if self.config.use_weighted_layer_sum:
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
hidden_states = torch.stack(hidden_states, dim=1)
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
else:
hidden_states = outputs[0]
hidden_states = self.projector(hidden_states)
if attention_mask is None:
pooled_output = hidden_states.mean(dim=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
hidden_states[~padding_mask] = 0.0
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
Wav2Vec2Bert Model with a frame classification head on top for tasks like Speaker Diarization.
""",
WAV2VEC2_BERT_START_DOCSTRING,
)
class Wav2Vec2BertForAudioFrameClassification(Wav2Vec2BertPreTrainedModel):
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForAudioFrameClassification.__init__ with Wav2Vec2Conformer->Wav2Vec2Bert,WAV2VEC2_CONFORMER->WAV2VEC2_BERT,wav2vec2_conformer->wav2vec2_bert
def __init__(self, config):
super().__init__(config)
if hasattr(config, "add_adapter") and config.add_adapter:
raise ValueError(
"Audio frame classification does not support the use of Wav2Vec2Bert adapters (config.add_adapter=True)"
)
self.wav2vec2_bert = Wav2Vec2BertModel(config)
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.num_labels = config.num_labels
self.init_weights()
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForAudioFrameClassification.freeze_base_model with wav2vec2_conformer->wav2vec2_bert
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
be updated during training. Only the classification head will be updated.
"""
for param in self.wav2vec2_bert.parameters():
param.requires_grad = False
@add_start_docstrings_to_model_forward(WAV2VEC2_BERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_BASE_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForAudioFrameClassification.forward with wav2vec2_conformer->wav2vec2_bert, input_values->input_features
def forward(
self,
input_features: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
outputs = self.wav2vec2_bert(
input_features,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if self.config.use_weighted_layer_sum:
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
hidden_states = torch.stack(hidden_states, dim=1)
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
else:
hidden_states = outputs[0]
logits = self.classifier(hidden_states)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
if not return_dict:
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
return output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
class AMSoftmaxLoss(nn.Module):
def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
super(AMSoftmaxLoss, self).__init__()
self.scale = scale
self.margin = margin
self.num_labels = num_labels
self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
self.loss = nn.CrossEntropyLoss()
def forward(self, hidden_states, labels):
labels = labels.flatten()
weight = nn.functional.normalize(self.weight, dim=0)
hidden_states = nn.functional.normalize(hidden_states, dim=1)
cos_theta = torch.mm(hidden_states, weight)
psi = cos_theta - self.margin
onehot = nn.functional.one_hot(labels, self.num_labels)
logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
loss = self.loss(logits, labels)
return loss
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
class TDNNLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
self.out_conv_dim = config.tdnn_dim[layer_id]
self.kernel_size = config.tdnn_kernel[layer_id]
self.dilation = config.tdnn_dilation[layer_id]
self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
self.activation = nn.ReLU()
def forward(self, hidden_states):
hidden_states = hidden_states.unsqueeze(1)
hidden_states = nn.functional.unfold(
hidden_states,
(self.kernel_size, self.in_conv_dim),
stride=(1, self.in_conv_dim),
dilation=(self.dilation, 1),
)
hidden_states = hidden_states.transpose(1, 2)
hidden_states = self.kernel(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
@add_start_docstrings(
"""
Wav2Vec2Bert Model with an XVector feature extraction head on top for tasks like Speaker Verification.
""",
WAV2VEC2_BERT_START_DOCSTRING,
)
class Wav2Vec2BertForXVector(Wav2Vec2BertPreTrainedModel):
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForXVector.__init__ with Wav2Vec2Conformer->Wav2Vec2Bert,WAV2VEC2_CONFORMER->WAV2VEC2_BERT,wav2vec2_conformer->wav2vec2_bert
def __init__(self, config):
super().__init__(config)
self.wav2vec2_bert = Wav2Vec2BertModel(config)
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
self.tdnn = nn.ModuleList(tdnn_layers)
self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
self.init_weights()
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForXVector.freeze_base_model with wav2vec2_conformer->wav2vec2_bert
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
be updated during training. Only the classification head will be updated.
"""
for param in self.wav2vec2_bert.parameters():
param.requires_grad = False
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForXVector._get_tdnn_output_lengths
def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
"""
Computes the output length of the TDNN layers
"""
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return (input_length - kernel_size) // stride + 1
for kernel_size in self.config.tdnn_kernel:
input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
return input_lengths
@add_start_docstrings_to_model_forward(WAV2VEC2_BERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_BASE_CHECKPOINT_FOR_DOC,
output_type=XVectorOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForXVector.forward with wav2vec2_conformer->wav2vec2_bert, input_values->input_features
def forward(
self,
input_features: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.Tensor] = None,
) -> Union[Tuple, XVectorOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
outputs = self.wav2vec2_bert(
input_features,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if self.config.use_weighted_layer_sum:
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
hidden_states = torch.stack(hidden_states, dim=1)
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
else:
hidden_states = outputs[0]
hidden_states = self.projector(hidden_states)
for tdnn_layer in self.tdnn:
hidden_states = tdnn_layer(hidden_states)
# Statistic Pooling
if attention_mask is None:
mean_features = hidden_states.mean(dim=1)
std_features = hidden_states.std(dim=1)
else:
feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
mean_features = []
std_features = []
for i, length in enumerate(tdnn_output_lengths):
mean_features.append(hidden_states[i, :length].mean(dim=0))
std_features.append(hidden_states[i, :length].std(dim=0))
mean_features = torch.stack(mean_features)
std_features = torch.stack(std_features)
statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
output_embeddings = self.feature_extractor(statistic_pooling)
logits = self.classifier(output_embeddings)
loss = None
if labels is not None:
loss = self.objective(logits, labels)
if not return_dict:
output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
return ((loss,) + output) if loss is not None else output
return XVectorOutput(
loss=loss,
logits=logits,
embeddings=output_embeddings,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Speech processor class for Wav2Vec2-BERT
"""
import warnings
from ...processing_utils import ProcessorMixin
from ..seamless_m4t.feature_extraction_seamless_m4t import SeamlessM4TFeatureExtractor
from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
class Wav2Vec2BertProcessor(ProcessorMixin):
r"""
Constructs a Wav2Vec2-BERT processor which wraps a Wav2Vec2-BERT feature extractor and a Wav2Vec2 CTC tokenizer into a single
processor.
[`Wav2Vec2Processor`] offers all the functionalities of [`SeamlessM4TFeatureExtractor`] and [`PreTrainedTokenizer`].
See the docstring of [`~Wav2Vec2Processor.__call__`] and [`~Wav2Vec2Processor.decode`] for more information.
Args:
feature_extractor (`SeamlessM4TFeatureExtractor`):
An instance of [`SeamlessM4TFeatureExtractor`]. The feature extractor is a required input.
tokenizer ([`PreTrainedTokenizer`]):
An instance of [`PreTrainedTokenizer`]. The tokenizer is a required input.
"""
feature_extractor_class = "SeamlessM4TFeatureExtractor"
tokenizer_class = "AutoTokenizer"
def __init__(self, feature_extractor, tokenizer):
super().__init__(feature_extractor, tokenizer)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
try:
return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
except OSError:
warnings.warn(
f"Loading a tokenizer inside {cls.__name__} from a config that does not"
" include a `tokenizer_class` attribute is deprecated and will be "
"removed in v5. Please add `'tokenizer_class': 'Wav2Vec2CTCTokenizer'`"
" attribute to either your `config.json` or `tokenizer_config.json` "
"file to suppress this warning: ",
FutureWarning,
)
feature_extractor = SeamlessM4TFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
def __call__(self, audio=None, text=None, **kwargs):
"""
Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `audio`
and `kwargs` arguments to SeamlessM4TFeatureExtractor's [`~SeamlessM4TFeatureExtractor.__call__`] if `audio` is not
`None` to pre-process the audio. To prepare the target sequences(s), this method forwards the `text` and `kwargs` arguments to
PreTrainedTokenizer's [`~PreTrainedTokenizer.__call__`] if `text` is not `None`. Please refer to the doctsring of the above two methods for more information.
Args:
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
The audio or batch of audios to be prepared. Each audio can be NumPy array or PyTorch tensor. In case
of a NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is a number of channels,
and T the sample length of the audio.
kwargs (*optional*):
Remaining dictionary of keyword arguments that will be passed to the feature extractor and/or the
tokenizer.
Returns:
[`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
- **input_features** -- Audio input features to be fed to a model. Returned when `audio` is not `None`.
- **attention_mask** -- List of indices specifying which timestamps should be attended to by the model when `audio` is not `None`.
When only `text` is specified, returns the token attention mask.
- **labels** -- List of token ids to be fed to a model. Returned when both `text` and `audio` are not `None`.
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None` and `audio` is `None`.
"""
sampling_rate = kwargs.pop("sampling_rate", None)
if audio is None and text is None:
raise ValueError("You need to specify either an `audio` or `text` input to process.")
if audio is not None:
inputs = self.feature_extractor(audio, sampling_rate=sampling_rate, **kwargs)
if text is not None:
encodings = self.tokenizer(text, **kwargs)
if text is None:
return inputs
elif audio is None:
return encodings
else:
inputs["labels"] = encodings["input_ids"]
return inputs
def pad(self, input_features=None, labels=None, **kwargs):
"""
If `input_features` is not `None`, this method forwards the `input_features` and `kwargs` arguments to SeamlessM4TFeatureExtractor's [`~SeamlessM4TFeatureExtractor.pad`] to pad the input features.
If `labels` is not `None`, this method forwards the `labels` and `kwargs` arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.pad`] to pad the label(s).
Please refer to the doctsring of the above two methods for more information.
"""
if input_features is None and labels is None:
raise ValueError("You need to specify either an `input_features` or `labels` input to pad.")
if input_features is not None:
input_features = self.feature_extractor.pad(input_features, **kwargs)
if labels is not None:
labels = self.tokenizer.pad(labels, **kwargs)
if labels is None:
return input_features
elif input_features is None:
return labels
else:
input_features["labels"] = labels["input_ids"]
return input_features
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer
to the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
...@@ -8758,6 +8758,51 @@ class Wav2Vec2PreTrainedModel(metaclass=DummyObject): ...@@ -8758,6 +8758,51 @@ class Wav2Vec2PreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
WAV2VEC2_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
class Wav2Vec2BertForAudioFrameClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Wav2Vec2BertForCTC(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Wav2Vec2BertForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Wav2Vec2BertForXVector(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Wav2Vec2BertModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Wav2Vec2BertPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Wav2Vec2-BERT model. """
import tempfile
import unittest
from datasets import load_dataset
from transformers import Wav2Vec2BertConfig, is_torch_available
from transformers.testing_utils import (
is_pt_flax_cross_test,
require_torch,
require_torch_accelerator,
require_torch_fp16,
slow,
torch_device,
)
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
ModelTesterMixin,
_config_zero_init,
floats_tensor,
ids_tensor,
random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import (
AutoFeatureExtractor,
Wav2Vec2BertForAudioFrameClassification,
Wav2Vec2BertForCTC,
Wav2Vec2BertForSequenceClassification,
Wav2Vec2BertForXVector,
Wav2Vec2BertModel,
)
from transformers.models.wav2vec2_bert.modeling_wav2vec2_bert import (
_compute_mask_indices,
_sample_negative_indices,
)
# Copied from tests.models.wav2vec2_conformer.test_modeling_wav2vec2_conformer.Wav2Vec2ConformerModelTester with Conformer->Bert, input_values->input_features
class Wav2Vec2BertModelTester:
# Ignore copy
def __init__(
self,
parent,
batch_size=13,
seq_length=200, # speech is longer
is_training=False,
hidden_size=16,
feature_projection_input_dim=16,
num_conv_pos_embeddings=16,
num_conv_pos_embedding_groups=2,
num_hidden_layers=2,
num_attention_heads=2,
hidden_dropout_prob=0.1,
intermediate_size=20,
layer_norm_eps=1e-5,
hidden_act="gelu",
initializer_range=0.02,
mask_time_prob=0.5,
mask_time_length=2,
vocab_size=32,
do_stable_layer_norm=False,
num_adapter_layers=2,
adapter_stride=2,
tdnn_dim=(32, 32),
tdnn_kernel=(5, 3),
tdnn_dilation=(1, 2),
xvector_output_dim=32,
position_embeddings_type="relative",
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.hidden_size = hidden_size
self.feature_projection_input_dim = feature_projection_input_dim
self.num_conv_pos_embeddings = num_conv_pos_embeddings
self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_dropout_prob = hidden_dropout_prob
self.intermediate_size = intermediate_size
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.vocab_size = vocab_size
self.do_stable_layer_norm = do_stable_layer_norm
self.num_adapter_layers = num_adapter_layers
self.adapter_stride = adapter_stride
self.mask_time_prob = mask_time_prob
self.mask_time_length = mask_time_length
self.scope = scope
self.tdnn_dim = tdnn_dim
self.tdnn_kernel = tdnn_kernel
self.tdnn_dilation = tdnn_dilation
self.xvector_output_dim = xvector_output_dim
self.position_embeddings_type = position_embeddings_type
self.output_seq_length = self.seq_length
self.encoder_seq_length = self.output_seq_length
self.adapter_output_seq_length = self.output_seq_length
for _ in range(num_adapter_layers):
self.adapter_output_seq_length = (self.adapter_output_seq_length - 1) // adapter_stride + 1
# Ignore copy
def prepare_config_and_inputs(self, position_embeddings_type="relative"):
input_shape = [self.batch_size, self.seq_length, self.feature_projection_input_dim]
input_features = floats_tensor(input_shape, self.vocab_size)
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
config = self.get_config(position_embeddings_type=position_embeddings_type)
return config, input_features, attention_mask
# Ignore copy
def get_config(self, position_embeddings_type="relative"):
return Wav2Vec2BertConfig(
hidden_size=self.hidden_size,
feature_projection_input_dim=self.feature_projection_input_dim,
mask_time_prob=self.mask_time_prob,
mask_time_length=self.mask_time_length,
num_conv_pos_embeddings=self.num_conv_pos_embeddings,
num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
hidden_dropout_prob=self.hidden_dropout_prob,
intermediate_size=self.intermediate_size,
layer_norm_eps=self.layer_norm_eps,
do_stable_layer_norm=self.do_stable_layer_norm,
hidden_act=self.hidden_act,
initializer_range=self.initializer_range,
vocab_size=self.vocab_size,
num_adapter_layers=self.num_adapter_layers,
adapter_stride=self.adapter_stride,
tdnn_dim=self.tdnn_dim,
tdnn_kernel=self.tdnn_kernel,
tdnn_dilation=self.tdnn_dilation,
xvector_output_dim=self.xvector_output_dim,
position_embeddings_type=position_embeddings_type,
)
def create_and_check_model(self, config, input_features, attention_mask):
model = Wav2Vec2BertModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_features, attention_mask=attention_mask)
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
)
def create_and_check_model_with_adapter(self, config, input_features, attention_mask):
config.add_adapter = True
model = Wav2Vec2BertModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_features, attention_mask=attention_mask)
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, self.adapter_output_seq_length, self.hidden_size)
)
def create_and_check_model_with_adapter_for_ctc(self, config, input_features, attention_mask):
config.add_adapter = True
config.output_hidden_size = 2 * config.hidden_size
model = Wav2Vec2BertForCTC(config=config)
model.to(torch_device)
model.eval()
result = model(input_features, attention_mask=attention_mask)
self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.adapter_output_seq_length, self.vocab_size)
)
# Ignore copy
def create_and_check_model_with_intermediate_ffn_before_adapter(self, config, input_features, attention_mask):
config.add_adapter = True
config.use_intermediate_ffn_before_adapter = True
model = Wav2Vec2BertModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_features, attention_mask=attention_mask)
self.parent.assertEqual(
result.last_hidden_state.shape,
(self.batch_size, self.adapter_output_seq_length, config.output_hidden_size),
)
# also try with different adapter proj dim
config.output_hidden_size = 8
model = Wav2Vec2BertModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_features, attention_mask=attention_mask)
self.parent.assertEqual(
result.last_hidden_state.shape,
(self.batch_size, self.adapter_output_seq_length, config.output_hidden_size),
)
def create_and_check_model_with_adapter_proj_dim(self, config, input_features, attention_mask):
config.add_adapter = True
config.output_hidden_size = 8
model = Wav2Vec2BertModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_features, attention_mask=attention_mask)
self.parent.assertEqual(
result.last_hidden_state.shape,
(self.batch_size, self.adapter_output_seq_length, config.output_hidden_size),
)
def create_and_check_model_float16(self, config, input_features, attention_mask):
model = Wav2Vec2BertModel(config=config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = Wav2Vec2BertModel.from_pretrained(tmpdirname, torch_dtype=torch.float16)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model(input_features.type(dtype=torch.float16), attention_mask=attention_mask)
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
)
def create_and_check_batch_inference(self, config, input_features, *args):
# test does not pass for models making use of `group_norm`
# check: https://github.com/pytorch/fairseq/issues/3227
model = Wav2Vec2BertModel(config=config)
model.to(torch_device)
model.eval()
input_features = input_features[:3]
attention_mask = torch.ones(input_features.shape, device=torch_device, dtype=torch.bool)
input_lengths = [input_features.shape[-1] // i for i in [4, 2, 1]]
# pad input
for i in range(len(input_lengths)):
input_features[i, input_lengths[i] :] = 0.0
attention_mask[i, input_lengths[i] :] = 0.0
batch_outputs = model(input_features, attention_mask=attention_mask).last_hidden_state
for i in range(input_features.shape[0]):
input_slice = input_features[i : i + 1, : input_lengths[i]]
output = model(input_slice).last_hidden_state
batch_output = batch_outputs[i : i + 1, : output.shape[1]]
self.parent.assertTrue(torch.allclose(output, batch_output, atol=1e-3))
def check_ctc_loss(self, config, input_features, *args):
model = Wav2Vec2BertForCTC(config=config)
model.to(torch_device)
# make sure that dropout is disabled
model.eval()
input_features = input_features[:3]
# Ignore copy
attention_mask = torch.ones(input_features.shape[:2], device=torch_device, dtype=torch.long)
input_lengths = [input_features.shape[1] // i for i in [4, 2, 1]]
max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
labels = ids_tensor((input_features.shape[0], min(max_length_labels) - 1), model.config.vocab_size)
# pad input
for i in range(len(input_lengths)):
input_features[i, input_lengths[i] :] = 0.0
attention_mask[i, input_lengths[i] :] = 0
model.config.ctc_loss_reduction = "sum"
sum_loss = model(input_features, attention_mask=attention_mask, labels=labels).loss.item()
model.config.ctc_loss_reduction = "mean"
mean_loss = model(input_features, attention_mask=attention_mask, labels=labels).loss.item()
self.parent.assertTrue(isinstance(sum_loss, float))
self.parent.assertTrue(isinstance(mean_loss, float))
def check_seq_classifier_loss(self, config, input_features, *args):
model = Wav2Vec2BertForSequenceClassification(config=config)
model.to(torch_device)
# make sure that dropout is disabled
model.eval()
input_features = input_features[:3]
# Ignore copy
attention_mask = torch.ones(input_features.shape[:2], device=torch_device, dtype=torch.long)
input_lengths = [input_features.shape[1] // i for i in [4, 2, 1]]
labels = ids_tensor((input_features.shape[0], 1), len(model.config.id2label))
# pad input
for i in range(len(input_lengths)):
input_features[i, input_lengths[i] :] = 0.0
attention_mask[i, input_lengths[i] :] = 0
masked_loss = model(input_features, attention_mask=attention_mask, labels=labels).loss.item()
unmasked_loss = model(input_features, labels=labels).loss.item()
self.parent.assertTrue(isinstance(masked_loss, float))
self.parent.assertTrue(isinstance(unmasked_loss, float))
self.parent.assertTrue(masked_loss != unmasked_loss)
def check_ctc_training(self, config, input_features, *args):
config.ctc_zero_infinity = True
model = Wav2Vec2BertForCTC(config=config)
model.to(torch_device)
model.train()
# Ignore copy
input_features = input_features[:3]
input_lengths = [input_features.shape[1] // i for i in [4, 2, 1]]
max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
labels = ids_tensor((input_features.shape[0], max(max_length_labels) - 2), model.config.vocab_size)
# pad input
for i in range(len(input_lengths)):
input_features[i, input_lengths[i] :] = 0.0
if max_length_labels[i] < labels.shape[-1]:
# it's important that we make sure that target lengths are at least
# one shorter than logit lengths to prevent -inf
labels[i, max_length_labels[i] - 1 :] = -100
loss = model(input_features, labels=labels).loss
self.parent.assertFalse(torch.isinf(loss).item())
loss.backward()
def check_seq_classifier_training(self, config, input_features, *args):
config.ctc_zero_infinity = True
model = Wav2Vec2BertForSequenceClassification(config=config)
model.to(torch_device)
model.train()
# freeze everything but the classification head
model.freeze_base_model()
input_features = input_features[:3]
# Ignore copy
input_lengths = [input_features.shape[1] // i for i in [4, 2, 1]]
labels = ids_tensor((input_features.shape[0], 1), len(model.config.id2label))
# pad input
for i in range(len(input_lengths)):
input_features[i, input_lengths[i] :] = 0.0
loss = model(input_features, labels=labels).loss
self.parent.assertFalse(torch.isinf(loss).item())
loss.backward()
def check_xvector_training(self, config, input_features, *args):
config.ctc_zero_infinity = True
model = Wav2Vec2BertForXVector(config=config)
model.to(torch_device)
model.train()
# freeze everything but the classification head
model.freeze_base_model()
input_features = input_features[:3]
input_lengths = [input_features.shape[-1] // i for i in [4, 2, 1]]
labels = ids_tensor((input_features.shape[0], 1), len(model.config.id2label))
# pad input
for i in range(len(input_lengths)):
input_features[i, input_lengths[i] :] = 0.0
loss = model(input_features, labels=labels).loss
self.parent.assertFalse(torch.isinf(loss).item())
loss.backward()
def check_labels_out_of_vocab(self, config, input_features, *args):
model = Wav2Vec2BertForCTC(config)
model.to(torch_device)
model.train()
input_features = input_features[:3]
input_lengths = [input_features.shape[-1] // i for i in [4, 2, 1]]
max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
labels = ids_tensor((input_features.shape[0], max(max_length_labels) - 2), model.config.vocab_size + 100)
with self.parent.assertRaises(ValueError):
model(input_features, labels=labels)
def prepare_config_and_inputs_for_common(self):
config, input_features, attention_mask = self.prepare_config_and_inputs()
inputs_dict = {"input_features": input_features, "attention_mask": attention_mask}
return config, inputs_dict
@require_torch
# Copied from tests.models.wav2vec2_conformer.test_modeling_wav2vec2_conformer.Wav2Vec2ConformerModelTest with Conformer->Bert, input_values->input_features
class Wav2Vec2BertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
# Ignore copy
all_model_classes = (
(
Wav2Vec2BertForCTC,
Wav2Vec2BertModel,
Wav2Vec2BertForSequenceClassification,
Wav2Vec2BertForAudioFrameClassification,
Wav2Vec2BertForXVector,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = (
{
"audio-classification": Wav2Vec2BertForSequenceClassification,
"automatic-speech-recognition": Wav2Vec2BertForCTC,
"feature-extraction": Wav2Vec2BertModel,
}
if is_torch_available()
else {}
)
test_pruning = False
test_headmasking = False
test_torchscript = False
def setUp(self):
self.model_tester = Wav2Vec2BertModelTester(self)
self.config_tester = ConfigTester(self, config_class=Wav2Vec2BertConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_with_relative(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="relative")
self.model_tester.create_and_check_model(*config_and_inputs)
# Ignore copy
def test_model_with_relative_key(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="relative_key")
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_with_rotary(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="rotary")
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_with_no_rel_pos(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type=None)
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_with_adapter(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_with_adapter(*config_and_inputs)
def test_model_with_adapter_for_ctc(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_with_adapter_for_ctc(*config_and_inputs)
# Ignore copy
def test_model_with_intermediate_ffn_before_adapter(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_with_intermediate_ffn_before_adapter(*config_and_inputs)
def test_model_with_adapter_proj_dim(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs)
@require_torch_accelerator
@require_torch_fp16
def test_model_float16_with_relative(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="relative")
self.model_tester.create_and_check_model_float16(*config_and_inputs)
# Ignore copy
@require_torch_accelerator
@require_torch_fp16
def test_model_float16_with_relative_key(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="relative_key")
self.model_tester.create_and_check_model_float16(*config_and_inputs)
@require_torch_accelerator
@require_torch_fp16
def test_model_float16_with_rotary(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="rotary")
self.model_tester.create_and_check_model_float16(*config_and_inputs)
def test_ctc_loss_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_loss(*config_and_inputs)
def test_seq_classifier_loss_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_seq_classifier_loss(*config_and_inputs)
def test_ctc_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_training(*config_and_inputs)
def test_seq_classifier_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_seq_classifier_training(*config_and_inputs)
def test_xvector_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_xvector_training(*config_and_inputs)
def test_labels_out_of_vocab(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
# Ignore copy
@unittest.skip(reason="Wav2Vec2Bert has no inputs_embeds")
def test_inputs_embeds(self):
pass
# Ignore copy
@unittest.skip(reason="`input_ids` is renamed to `input_features`")
def test_forward_signature(self):
pass
# Ignore copy
@unittest.skip(reason="Wav2Vec2Bert has no tokens embeddings")
def test_resize_tokens_embeddings(self):
pass
# Ignore copy
@unittest.skip(reason="Wav2Vec2Bert has no inputs_embeds")
def test_model_common_attributes(self):
pass
# Ignore copy
@unittest.skip(reason="non-robust architecture does not exist in Flax")
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
pass
# Ignore copy
@unittest.skip(reason="non-robust architecture does not exist in Flax")
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
pass
def test_retain_grad_hidden_states_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True
config.output_attentions = True
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
model.to(torch_device)
# set layer drop to 0
model.config.layerdrop = 0.0
input_features = inputs_dict["input_features"]
input_lengths = torch.tensor(
[input_features.shape[1] for _ in range(input_features.shape[0])], dtype=torch.long, device=torch_device
)
output_lengths = model._get_feat_extract_output_lengths(input_lengths)
labels = ids_tensor((input_features.shape[0], output_lengths[0] - 2), self.model_tester.vocab_size)
inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"])
inputs_dict["labels"] = labels
outputs = model(**inputs_dict)
output = outputs[0]
# Encoder-/Decoder-only models
hidden_states = outputs.hidden_states[0]
attentions = outputs.attentions[0]
hidden_states.retain_grad()
attentions.retain_grad()
output.flatten()[0].backward(retain_graph=True)
self.assertIsNotNone(hidden_states.grad)
self.assertIsNotNone(attentions.grad)
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
uniform_init_parms = [
"conv.weight",
"conv.parametrizations.weight",
"masked_spec_embed",
"codevectors",
"quantizer.weight_proj.weight",
"project_hid.weight",
"project_hid.bias",
"project_q.weight",
"project_q.bias",
"pos_bias_v",
"pos_bias_u",
"pointwise_conv1",
"pointwise_conv2",
"feature_projection.projection.weight",
"feature_projection.projection.bias",
"objective.weight",
]
if param.requires_grad:
if any(x in name for x in uniform_init_parms):
self.assertTrue(
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
else:
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
# overwrite from test_modeling_common
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)
if hasattr(module, "weight_g") and module.weight_g is not None:
module.weight_g.data.fill_(3)
if hasattr(module, "weight_v") and module.weight_v is not None:
module.weight_v.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3)
if hasattr(module, "pos_bias_u") and module.pos_bias_u is not None:
module.pos_bias_u.data.fill_(3)
if hasattr(module, "pos_bias_v") and module.pos_bias_v is not None:
module.pos_bias_v.data.fill_(3)
if hasattr(module, "codevectors") and module.codevectors is not None:
module.codevectors.data.fill_(3)
if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
module.masked_spec_embed.data.fill_(3)
# Ignore copy
@unittest.skip(reason="Kept to make #Copied from working")
def test_mask_feature_prob_ctc(self):
pass
# Ignore copy
@unittest.skip(reason="Kept to make #Copied from working")
def test_mask_time_prob_ctc(self):
pass
@unittest.skip(reason="Feed forward chunking is not implemented")
def test_feed_forward_chunking(self):
pass
@slow
def test_model_from_pretrained(self):
# Ignore copy
model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0")
self.assertIsNotNone(model)
@require_torch
# Copied from tests.models.wav2vec2_conformer.test_modeling_wav2vec2_conformer.Wav2Vec2ConformerUtilsTest with Conformer->Bert, input_values->input_features
class Wav2Vec2BertUtilsTest(unittest.TestCase):
def test_compute_mask_indices(self):
batch_size = 4
sequence_length = 60
mask_prob = 0.5
mask_length = 1
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
mask = torch.from_numpy(mask).to(torch_device)
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
def test_compute_mask_indices_low_prob(self):
# with these settings num_masked_spans=0.5, which means probabilistic rounding
# ensures that in 5 out of 10 method calls, num_masked_spans=0, and in
# the other 5 out of 10, cases num_masked_spans=1
n_trials = 100
batch_size = 4
sequence_length = 100
mask_prob = 0.05
mask_length = 10
count_dimensions_masked = 0
count_dimensions_not_masked = 0
for _ in range(n_trials):
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
mask = torch.from_numpy(mask).to(torch_device)
num_masks = torch.sum(mask).item()
if num_masks > 0:
count_dimensions_masked += 1
else:
count_dimensions_not_masked += 1
# as we test for at least 10 masked dimension and at least
# 10 non-masked dimension, this test could fail with probability:
# P(100 coin flips, at most 9 heads) = 1.66e-18
self.assertGreater(count_dimensions_masked, int(n_trials * 0.1))
self.assertGreater(count_dimensions_not_masked, int(n_trials * 0.1))
def test_compute_mask_indices_overlap(self):
batch_size = 4
sequence_length = 80
mask_prob = 0.5
mask_length = 4
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
mask = torch.from_numpy(mask).to(torch_device)
# because of overlap mask don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
for batch_sum in mask.sum(axis=-1):
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
def test_compute_mask_indices_attn_mask_overlap(self):
batch_size = 4
sequence_length = 80
mask_prob = 0.5
mask_length = 4
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
attention_mask[:2, sequence_length // 2 :] = 0
mask = _compute_mask_indices(
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
)
mask = torch.from_numpy(mask).to(torch_device)
for batch_sum in mask.sum(axis=-1):
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
self.assertTrue(mask[:2, sequence_length // 2 :].sum() == 0)
def test_compute_mask_indices_short_audio(self):
batch_size = 4
sequence_length = 100
mask_prob = 0.05
mask_length = 10
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
# force one example to be heavily padded
attention_mask[0, 5:] = 0
mask = _compute_mask_indices(
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask, min_masks=2
)
# make sure that non-padded examples cannot be padded
self.assertFalse(mask[0][attention_mask[0].to(torch.bool).cpu()].any())
# Ignore copy
@unittest.skip(reason="Kept to make #Copied from working. Test a class used for pretraining, not yet supported.")
def test_compute_perplexity(self):
pass
def test_sample_negatives(self):
batch_size = 2
sequence_length = 10
hidden_size = 4
num_negatives = 3
features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
sequence_length, hidden_size
) # each value in vector consits of same value
features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
# sample negative indices
sampled_negative_indices = _sample_negative_indices((batch_size, sequence_length), num_negatives, None)
sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)
self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
# make sure no negatively sampled vector is actually a positive one
for negative in negatives:
self.assertTrue(((negative - features) == 0).sum() == 0.0)
# make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
def test_sample_negatives_with_mask(self):
batch_size = 2
sequence_length = 10
hidden_size = 4
num_negatives = 3
# second half of last input tensor is padded
mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
mask[-1, sequence_length // 2 :] = 0
features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
sequence_length, hidden_size
) # each value in vector consits of same value
features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
# replace masked feature vectors with -100 to test that those are not sampled
features = torch.where(mask[:, :, None].expand(features.shape).bool(), features, -100)
# sample negative indices
sampled_negative_indices = _sample_negative_indices(
(batch_size, sequence_length), num_negatives, mask.cpu().numpy()
)
sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)
self.assertTrue((negatives >= 0).all().item())
self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
# make sure no negatively sampled vector is actually a positive one
for negative in negatives:
self.assertTrue(((negative - features) == 0).sum() == 0.0)
# make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
@require_torch
@slow
class Wav2Vec2BertModelIntegrationTest(unittest.TestCase):
def _load_datasamples(self, num_samples):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").filter(lambda x: x["id"] in [f"1272-141231-000{i}" for i in range(num_samples)])
speech_samples = speech_samples[:num_samples]["audio"]
return [x["array"] for x in speech_samples]
def test_inference_w2v2_bert(self):
model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0")
model.to(torch_device)
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
input_speech = self._load_datasamples(2)
inputs = feature_extractor(input_speech, return_tensors="pt", padding=True).to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**inputs, output_attentions=True)
# fmt: off
expected_slice_0 = torch.tensor(
[[-0.0098, -0.0570, -0.1286, 0.0439, -0.1037, -0.0235],
[-0.0767, 0.0574, -0.3224, 0.0482, 0.0440, -0.0193],
[ 0.0220, -0.0878, -0.2027, -0.0028, -0.0666, 0.0721],
[ 0.0307, -0.1099, 0.0273, -0.0416, -0.0715, 0.0094],
[ 0.0758, -0.0291, 0.1084, 0.0004, -0.0751, -0.0116],
[ 0.0349, -0.0343, -0.0098, 0.0415, -0.0617, 0.0241],
[-0.0193, -0.0171, 0.1965, 0.0797, -0.0308, 0.2033],
[-0.0323, -0.0315, 0.0948, 0.0944, -0.0254, 0.1241],
[-0.0493, 0.0010, -0.1762, 0.0034, -0.0787, 0.0832],
[ 0.0043, -0.1228, -0.0739, 0.0266, -0.0337, -0.0068]]
).to(torch_device)
# fmt: on
# fmt: off
expected_slice_1 = torch.tensor(
[[-0.0348, -0.0521, -0.3036, 0.0285, -0.0715, -0.0453],
[-0.0102, 0.0114, -0.3266, 0.0027, -0.0558, 0.0038],
[ 0.0454, 0.0148, -0.2418, -0.0392, -0.0455, 0.0478],
[-0.0013, 0.0825, -0.1730, -0.0091, -0.0426, 0.0360],
[-0.0227, 0.0687, -0.1168, 0.0569, -0.0160, 0.0759],
[-0.0318, 0.0562, -0.0508, 0.0605, 0.0150, 0.0953],
[-0.0415, 0.0438, 0.0233, 0.0336, 0.0262, 0.0860],
[-0.0163, 0.0048, 0.0807, 0.0119, 0.0712, 0.0158],
[ 0.0244, -0.0145, 0.0262, -0.0237, 0.0283, -0.0125],
[-0.0587, -0.0516, -0.0368, -0.0196, 0.0307, -0.1434]]
).to(torch_device)
# fmt: on
self.assertTrue((outputs.last_hidden_state[0, 25:35, 4:10] - expected_slice_0).abs().max() <= 1e-4)
self.assertTrue((outputs.last_hidden_state[1, 25:35, 4:10] - expected_slice_1).abs().max() <= 1e-4)
self.assertAlmostEqual(outputs.last_hidden_state[1].mean().item(), 3.3123e-05)
self.assertAlmostEqual(outputs.last_hidden_state[1].std().item(), 0.1545, delta=2e-5)
self.assertListEqual(list(outputs.last_hidden_state.shape), [2, 326, 1024])
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import shutil
import tempfile
import unittest
from transformers.models.seamless_m4t import SeamlessM4TFeatureExtractor
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
from transformers.models.wav2vec2_bert import Wav2Vec2BertProcessor
from transformers.utils import FEATURE_EXTRACTOR_NAME
from ..wav2vec2.test_feature_extraction_wav2vec2 import floats_list
# Copied from tests.models.wav2vec2.test_processor_wav2vec2.Wav2Vec2ProcessorTest with Wav2Vec2FeatureExtractor->SeamlessM4TFeatureExtractor, Wav2Vec2Processor->Wav2Vec2BertProcessor
class Wav2Vec2BertProcessorTest(unittest.TestCase):
def setUp(self):
vocab = "<pad> <s> </s> <unk> | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ")
vocab_tokens = dict(zip(vocab, range(len(vocab))))
self.add_kwargs_tokens_map = {
"pad_token": "<pad>",
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
}
feature_extractor_map = {
"feature_size": 1,
"padding_value": 0.0,
"sampling_rate": 16000,
"return_attention_mask": False,
"do_normalize": True,
}
self.tmpdirname = tempfile.mkdtemp()
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
self.feature_extraction_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME)
with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens) + "\n")
with open(self.feature_extraction_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(feature_extractor_map) + "\n")
def get_tokenizer(self, **kwargs_init):
kwargs = self.add_kwargs_tokens_map.copy()
kwargs.update(kwargs_init)
return Wav2Vec2CTCTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_feature_extractor(self, **kwargs):
return SeamlessM4TFeatureExtractor.from_pretrained(self.tmpdirname, **kwargs)
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def test_save_load_pretrained_default(self):
tokenizer = self.get_tokenizer()
feature_extractor = self.get_feature_extractor()
processor = Wav2Vec2BertProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
processor.save_pretrained(self.tmpdirname)
processor = Wav2Vec2BertProcessor.from_pretrained(self.tmpdirname)
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
self.assertIsInstance(processor.tokenizer, Wav2Vec2CTCTokenizer)
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
self.assertIsInstance(processor.feature_extractor, SeamlessM4TFeatureExtractor)
def test_save_load_pretrained_additional_features(self):
processor = Wav2Vec2BertProcessor(
tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor()
)
processor.save_pretrained(self.tmpdirname)
tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False, padding_value=1.0)
processor = Wav2Vec2BertProcessor.from_pretrained(
self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
)
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
self.assertIsInstance(processor.tokenizer, Wav2Vec2CTCTokenizer)
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
self.assertIsInstance(processor.feature_extractor, SeamlessM4TFeatureExtractor)
def test_feature_extractor(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = Wav2Vec2BertProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
raw_speech = floats_list((3, 1000))
input_feat_extract = feature_extractor(raw_speech, return_tensors="np")
input_processor = processor(raw_speech, return_tensors="np")
for key in input_feat_extract.keys():
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
def test_tokenizer(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = Wav2Vec2BertProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
input_str = "This is a test string"
encoded_processor = processor(text=input_str)
encoded_tok = tokenizer(input_str)
for key in encoded_tok.keys():
self.assertListEqual(encoded_tok[key], encoded_processor[key])
def test_tokenizer_decode(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = Wav2Vec2BertProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
decoded_processor = processor.batch_decode(predicted_ids)
decoded_tok = tokenizer.batch_decode(predicted_ids)
self.assertListEqual(decoded_tok, decoded_processor)
def test_model_input_names(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = Wav2Vec2BertProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
self.assertListEqual(
processor.model_input_names,
feature_extractor.model_input_names,
msg="`processor` and `feature_extractor` model input names do not match",
)
...@@ -762,6 +762,7 @@ OBJECTS_TO_IGNORE = [ ...@@ -762,6 +762,7 @@ OBJECTS_TO_IGNORE = [
"VitMatteForImageMatting", "VitMatteForImageMatting",
"VitsTokenizer", "VitsTokenizer",
"VivitModel", "VivitModel",
"Wav2Vec2BertForCTC",
"Wav2Vec2CTCTokenizer", "Wav2Vec2CTCTokenizer",
"Wav2Vec2Config", "Wav2Vec2Config",
"Wav2Vec2ConformerConfig", "Wav2Vec2ConformerConfig",
......
...@@ -878,6 +878,7 @@ src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to ...@@ -878,6 +878,7 @@ src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to
src/transformers/models/wav2vec2/convert_wav2vec2_original_s3prl_checkpoint_to_pytorch.py src/transformers/models/wav2vec2/convert_wav2vec2_original_s3prl_checkpoint_to_pytorch.py
src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
src/transformers/models/wav2vec2_bert/convert_wav2vec2_seamless_checkpoint.py
src/transformers/models/wav2vec2_conformer/convert_wav2vec2_conformer_original_pytorch_checkpoint_to_pytorch.py src/transformers/models/wav2vec2_conformer/convert_wav2vec2_conformer_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/wavlm/convert_wavlm_original_pytorch_checkpoint_to_pytorch.py src/transformers/models/wavlm/convert_wavlm_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/wavlm/convert_wavlm_original_s3prl_checkpoint_to_pytorch.py src/transformers/models/wavlm/convert_wavlm_original_s3prl_checkpoint_to_pytorch.py
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment