Unverified Commit 5a995735 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Add Wav2Vec2Conformer (#16812)



* save intermediate

* add wav2vec2 conformer

* add more code

* more

* first test passes

* make all checkpoints work

* update

* up

* more clean ups

* save clean-up

* save clean-up

* save more

* remove bogus

* finalize design conformer

* remove vision

* finish all tests

* more changes

* finish code

* add doc tests

* add slow tests

* fix autoconfig test

* up

* correct docstring

* up

* update

* fix

* Apply suggestions from code review
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarAnton Lozhkov <aglozhkov@gmail.com>

* Update docs/source/en/model_doc/wav2vec2-conformer.mdx

* upload

* save copied from

* correct configs

* fix model outputs

* add to docs

* fix imports

* finish

* finish code

* correct copied from

* correct again

* correct make fix

* improve make fix copies

* save

* correct fix copy from

* correct init structure

* correct

* fix import

* apply suggestions
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarAnton Lozhkov <aglozhkov@gmail.com>
parent f0395cf5
......@@ -27,7 +27,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...deepspeed import is_deepspeed_zero3_enabled
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, Wav2Vec2BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import torch_int_div
from ...utils import (
......@@ -71,35 +71,6 @@ UNISPEECH_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
@dataclass
class UniSpeechBaseModelOutput(ModelOutput):
"""
Output type of [`UniSpeechBaseModelOutput`], with potential hidden states and attentions.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
Sequence of extracted feature vectors of the last convolutional layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: torch.FloatTensor = None
extract_features: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class UniSpeechForPreTrainingOutput(ModelOutput):
"""
......@@ -1158,7 +1129,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
@add_code_sample_docstrings(
processor_class=_PROCESSOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=UniSpeechBaseModelOutput,
output_type=Wav2Vec2BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
......@@ -1171,7 +1142,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, UniSpeechBaseModelOutput]:
) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
......@@ -1203,7 +1174,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
return UniSpeechBaseModelOutput(
return Wav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
......
......@@ -81,13 +81,13 @@ class UniSpeechSatConfig(PretrainedConfig):
extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
The dropout probabilitiy for quantized feature encoder states.
conv_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
of *conv_stride* defines the number of convolutional layers and has to match the the length of *conv_dim*.
conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
length of *conv_kernel* defines the number of convolutional layers and has to match the the length of
*conv_dim*.
......@@ -159,13 +159,13 @@ class UniSpeechSatConfig(PretrainedConfig):
instance of [`UniSpeechSatForSequenceClassification`].
classifier_proj_size (`int`, *optional*, defaults to 256):
Dimensionality of the projection before token mean-pooling for classification.
tdnn_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
tdnn_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
tdnn_kernel (`Tuple[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
tdnn_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
*XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
tdnn_dilation (`Tuple[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
tdnn_dilation (`Tuple[int]` or `List[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
*XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
xvector_output_dim (`int`, *optional*, defaults to 512):
......
......@@ -27,7 +27,14 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...deepspeed import is_deepspeed_zero3_enabled
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput
from ...modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
Wav2Vec2BaseModelOutput,
XVectorOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import torch_int_div
from ...utils import (
......@@ -77,35 +84,6 @@ UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
@dataclass
class UniSpeechSatBaseModelOutput(ModelOutput):
"""
Output type of [`UniSpeechSatBaseModelOutput`], with potential hidden states and attentions.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
Sequence of extracted feature vectors of the last convolutional layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: torch.FloatTensor = None
extract_features: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class UniSpeechSatForPreTrainingOutput(ModelOutput):
"""
......@@ -143,38 +121,6 @@ class UniSpeechSatForPreTrainingOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class XVectorOutput(ModelOutput):
"""
Output type of [`Wav2Vec2ForXVector`].
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification loss.
logits (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
Classification hidden states before AMSoftmax.
embeddings (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
Utterance embeddings used for vector similarity-based retrieval.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
embeddings: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
shape: Tuple[int, int],
......@@ -1198,7 +1144,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
@add_code_sample_docstrings(
processor_class=_PROCESSOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=UniSpeechSatBaseModelOutput,
output_type=Wav2Vec2BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
......@@ -1211,7 +1157,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, UniSpeechSatBaseModelOutput]:
) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
......@@ -1243,7 +1189,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
return UniSpeechSatBaseModelOutput(
return Wav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
......
......@@ -78,13 +78,13 @@ class Wav2Vec2Config(PretrainedConfig):
extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
The dropout probabilitiy for quantized feature encoder states.
conv_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
length of *conv_kernel* defines the number of convolutional layers and has to match the length of
*conv_dim*.
......@@ -156,13 +156,13 @@ class Wav2Vec2Config(PretrainedConfig):
instance of [`Wav2Vec2ForSequenceClassification`].
classifier_proj_size (`int`, *optional*, defaults to 256):
Dimensionality of the projection before token mean-pooling for classification.
tdnn_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
tdnn_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
tdnn_kernel (`Tuple[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
tdnn_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
*XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
tdnn_dilation (`Tuple[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
tdnn_dilation (`Tuple[int]` or `List[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
*XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
xvector_output_dim (`int`, *optional*, defaults to 512):
......
......@@ -33,6 +33,8 @@ from ...modeling_outputs import (
MaskedLMOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
Wav2Vec2BaseModelOutput,
XVectorOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import torch_int_div
......@@ -88,35 +90,6 @@ WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
@dataclass
class Wav2Vec2BaseModelOutput(ModelOutput):
"""
Output type of [`Wav2Vec2BaseModelOutput`], with potential hidden states and attentions.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
Sequence of extracted feature vectors of the last convolutional layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: torch.FloatTensor = None
extract_features: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class Wav2Vec2ForPreTrainingOutput(ModelOutput):
"""
......@@ -159,38 +132,6 @@ class Wav2Vec2ForPreTrainingOutput(ModelOutput):
diversity_loss: Optional[torch.FloatTensor] = None
@dataclass
class XVectorOutput(ModelOutput):
"""
Output type of [`Wav2Vec2ForXVector`].
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification loss.
logits (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
Classification hidden states before AMSoftmax.
embeddings (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
Utterance embeddings used for vector similarity-based retrieval.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
embeddings: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
......@@ -1025,11 +966,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
# use probs to retrieve codevectors
codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
codevectors = (
codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
.sum(-2)
.view(batch_size, sequence_length, -1)
)
codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
return codevectors, perplexity
......@@ -1473,13 +1411,13 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
```python
>>> import torch
>>> from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForPreTraining
>>> from transformers import AutoFeatureExtractor, Wav2Vec2ForPreTraining
>>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("patrickvonplaten/wav2vec2-base")
>>> model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
>>> model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
"configuration_wav2vec2_conformer": [
"WAV2VEC2_CONFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"Wav2Vec2ConformerConfig",
],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_wav2vec2_conformer"] = [
"WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"Wav2Vec2ConformerForAudioFrameClassification",
"Wav2Vec2ConformerForCTC",
"Wav2Vec2ConformerForPreTraining",
"Wav2Vec2ConformerForSequenceClassification",
"Wav2Vec2ConformerForXVector",
"Wav2Vec2ConformerModel",
"Wav2Vec2ConformerPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_wav2vec2_conformer import (
WAV2VEC2_CONFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
Wav2Vec2ConformerConfig,
)
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_wav2vec2_conformer import (
WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
Wav2Vec2ConformerForAudioFrameClassification,
Wav2Vec2ConformerForCTC,
Wav2Vec2ConformerForPreTraining,
Wav2Vec2ConformerForSequenceClassification,
Wav2Vec2ConformerForXVector,
Wav2Vec2ConformerModel,
Wav2Vec2ConformerPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
# coding=utf-8
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Wav2Vec2Conformer model configuration"""
import functools
import operator
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
WAV2VEC2_CONFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/wav2vec2-conformer-large-rel-pos": (
"https://huggingface.co/facebook/wav2vec2-conformer-large-rel-pos/resolve/main/config.json"
),
}
class Wav2Vec2ConformerConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Wav2Vec2ConformerModel`]. It is used to
instantiate an Wav2Vec2Conformer model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the Wav2Vec2Conformer
[facebook/wav2vec2-conformer-large-rel-pos](https://huggingface.co/facebook/wav2vec2-conformer-large-rel-pos)
architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*):
Vocabulary size of the Wav2Vec2Conformer model. Defines the number of different tokens that can be
represented by the `inputs_ids` passed when calling [`Wav2Vec2ConformerModel`]. Vocabulary size of the
model. Defines the different tokens that can be represented by the *inputs_ids* passed to the forward
method of [`Wav2Vec2ConformerModel`].
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` are supported.
hidden_dropout (`float`, *optional*, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (`float`, *optional*, defaults to 0.1):
The dropout ratio for the attention probabilities.
final_dropout (`float`, *optional*, defaults to 0.1):
The dropout probability for the final projection layer of [`Wav2Vec2ConformerForCTC`].
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
feat_extract_norm (`str`, *optional*, defaults to `"group"`):
The norm to be applied to 1D convolutional layers in feature encoder. One of `"group"` for group
normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
convolutional layers.
feat_proj_dropout (`float`, *optional*, defaults to 0.0):
The dropout probability for output of the feature encoder.
feat_extract_activation (`str, `optional`, defaults to `"gelu"`):
The non-linear activation function (function or string) in the 1D convolutional layers of the feature
extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
The dropout probabilitiy for quantized feature encoder states.
conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
length of *conv_kernel* defines the number of convolutional layers and has to match the length of
*conv_dim*.
conv_bias (`bool`, *optional*, defaults to `False`):
Whether the 1D convolutional layers have a bias.
num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
embeddings layer.
num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
Number of groups of 1D convolutional positional embeddings layer.
apply_spec_augment (`bool`, *optional*, defaults to `True`):
Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
[SpecAugment: A Simple Data Augmentation Method for Automatic Speech
Recognition](https://arxiv.org/abs/1904.08779).
mask_time_prob (`float`, *optional*, defaults to 0.05):
Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
procecure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. If
reasoning from the propability of each feature vector to be chosen as the start of the vector span to be
masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
mask_time_length (`int`, *optional*, defaults to 10):
Length of vector span along the time axis.
mask_time_min_masks (`int`, *optional*, defaults to 2),:
The minimum number of masks of length `mask_feature_length` generated along the time axis, each time step,
irrespectively of `mask_feature_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length <
mask_time_min_masks''
mask_feature_prob (`float`, *optional*, defaults to 0.0):
Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
masking procecure generates ''mask_feature_prob*len(feature_axis)/mask_time_length'' independent masks over
the axis. If reasoning from the propability of each feature vector to be chosen as the start of the vector
span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
True`.
mask_feature_length (`int`, *optional*, defaults to 10):
Length of vector span along the feature axis.
mask_feature_min_masks (`int`, *optional*, defaults to 0),:
The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
step, irrespectively of `mask_feature_prob`. Only relevant if
''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks''
num_codevectors_per_group (`int`, *optional*, defaults to 320):
Number of entries in each quantization codebook (group).
num_codevector_groups (`int`, *optional*, defaults to 2):
Number of codevector groups for product codevector quantization.
contrastive_logits_temperature (`float`, *optional*, defaults to 0.1):
The temperature *kappa* in the contrastive loss.
feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
The dropout probabilitiy for the output of the feature encoder that's used by the quantizer.
num_negatives (`int`, *optional*, defaults to 100):
Number of negative samples for the contrastive loss.
codevector_dim (`int`, *optional*, defaults to 256):
Dimensionality of the quantized feature vectors.
proj_codevector_dim (`int`, *optional*, defaults to 256):
Dimensionality of the final projection of both the quantized and the transformer features.
diversity_loss_weight (`int`, *optional*, defaults to 0.1):
The weight of the codebook diversity loss component.
ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
instance of [`Wav2Vec2ConformerForCTC`].
ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
of [`Wav2Vec2ConformerForCTC`].
use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
instance of [`Wav2Vec2ConformerForSequenceClassification`].
classifier_proj_size (`int`, *optional*, defaults to 256):
Dimensionality of the projection before token mean-pooling for classification.
tdnn_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
tdnn_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
*XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
tdnn_dilation (`Tuple[int]` or `List[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
*XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
xvector_output_dim (`int`, *optional*, defaults to 512):
Dimensionality of the *XVector* embedding vectors.
add_adapter (`bool`, *optional*, defaults to `False`):
Whether a convolutional network should be stacked on top of the Wav2Vec2Conformer Encoder. Can be very
useful for warm-starting Wav2Vec2Conformer for SpeechEncoderDecoder models.
adapter_kernel_size (`int`, *optional*, defaults to 3):
Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
adapter_stride (`int`, *optional*, defaults to 2):
Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
num_adapter_layers (`int`, *optional*, defaults to 3):
Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
True`.
output_hidden_size (`int`, *optional*):
Dimensionality of the encoder output layer. If not defined, this defaults to *hidden-size*. Only relevant
if `add_adapter is True`.
position_embeddings_type (`str`, *optional*, defaults to `"relative"`):
Can be specified to `relative` or `rotary` for relative or rotary position embeddings respectively. If left
`None` no relative position embedding is applied.
rotary_embedding_base (`int`, *optional*, defaults to 10000):
If `"rotary"` position embeddings are used, defines the size of the embedding base.
max_source_positions (`int`, *optional*, defaults to 5000):
if `"relative"` position embeddings are used, defines the maximum source input positions.
conv_depthwise_kernel_size (`int`, defaults to 31):
Kernel size of convolutional depthwise 1D layer in Conformer blocks.
conformer_conv_dropout (`float`, defaults to 0.1):
The dropout probability for all convolutional layers in Conformer blocks.
Example:
```python
>>> from transformers import Wav2Vec2ConformerModel, Wav2Vec2ConformerConfig
>>> # Initializing a Wav2Vec2Conformer facebook/wav2vec2-conformer-large-rel-pos style configuration
>>> configuration = Wav2Vec2ConformerConfig()
>>> # Initializing a model from the facebook/wav2vec2-conformer-large-rel-pos style configuration
>>> model = Wav2Vec2ConformerModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "wav2vec2-conformer"
def __init__(
self,
vocab_size=None,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout=0.1,
activation_dropout=0.1,
attention_dropout=0.1,
feat_proj_dropout=0.0,
feat_quantizer_dropout=0.0,
final_dropout=0.1,
layerdrop=0.1,
initializer_range=0.02,
layer_norm_eps=1e-5,
feat_extract_norm="group",
feat_extract_activation="gelu",
conv_dim=(512, 512, 512, 512, 512, 512, 512),
conv_stride=(5, 2, 2, 2, 2, 2, 2),
conv_kernel=(10, 3, 3, 3, 3, 2, 2),
conv_bias=False,
num_conv_pos_embeddings=128,
num_conv_pos_embedding_groups=16,
apply_spec_augment=True,
mask_time_prob=0.05,
mask_time_length=10,
mask_time_min_masks=2,
mask_feature_prob=0.0,
mask_feature_length=10,
mask_feature_min_masks=0,
num_codevectors_per_group=320,
num_codevector_groups=2,
contrastive_logits_temperature=0.1,
num_negatives=100,
codevector_dim=256,
proj_codevector_dim=256,
diversity_loss_weight=0.1,
ctc_loss_reduction="sum",
ctc_zero_infinity=False,
use_weighted_layer_sum=False,
classifier_proj_size=256,
tdnn_dim=(512, 512, 512, 512, 1500),
tdnn_kernel=(5, 3, 3, 1, 1),
tdnn_dilation=(1, 2, 3, 1, 1),
xvector_output_dim=512,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
add_adapter=False,
adapter_kernel_size=3,
adapter_stride=2,
num_adapter_layers=3,
output_hidden_size=None,
position_embeddings_type="relative",
rotary_embedding_base=10000,
max_source_positions=5000,
conv_depthwise_kernel_size=31,
conformer_conv_dropout=0.1,
**kwargs
):
super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
self.hidden_size = hidden_size
self.feat_extract_norm = feat_extract_norm
self.feat_extract_activation = feat_extract_activation
self.conv_dim = list(conv_dim)
self.conv_stride = list(conv_stride)
self.conv_kernel = list(conv_kernel)
self.conv_bias = conv_bias
self.num_conv_pos_embeddings = num_conv_pos_embeddings
self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
self.num_feat_extract_layers = len(self.conv_dim)
self.num_hidden_layers = num_hidden_layers
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.num_attention_heads = num_attention_heads
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
self.feat_proj_dropout = feat_proj_dropout
self.final_dropout = final_dropout
self.layerdrop = layerdrop
self.layer_norm_eps = layer_norm_eps
self.initializer_range = initializer_range
self.vocab_size = vocab_size
self.use_weighted_layer_sum = use_weighted_layer_sum
self.max_source_positions = max_source_positions
self.position_embeddings_type = position_embeddings_type
self.rotary_embedding_base = rotary_embedding_base
if (
(len(self.conv_stride) != self.num_feat_extract_layers)
or (len(self.conv_kernel) != self.num_feat_extract_layers)
or (len(self.conv_dim) != self.num_feat_extract_layers)
):
raise ValueError(
"Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
" `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
)
# Conformer-block related
self.conv_depthwise_kernel_size = conv_depthwise_kernel_size
self.conformer_conv_dropout = conformer_conv_dropout
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
self.apply_spec_augment = apply_spec_augment
self.mask_time_prob = mask_time_prob
self.mask_time_length = mask_time_length
self.mask_time_min_masks = mask_time_min_masks
self.mask_feature_prob = mask_feature_prob
self.mask_feature_length = mask_feature_length
self.mask_feature_min_masks = mask_feature_min_masks
# parameters for pretraining with codevector quantized representations
self.num_codevectors_per_group = num_codevectors_per_group
self.num_codevector_groups = num_codevector_groups
self.contrastive_logits_temperature = contrastive_logits_temperature
self.feat_quantizer_dropout = feat_quantizer_dropout
self.num_negatives = num_negatives
self.codevector_dim = codevector_dim
self.proj_codevector_dim = proj_codevector_dim
self.diversity_loss_weight = diversity_loss_weight
# ctc loss
self.ctc_loss_reduction = ctc_loss_reduction
self.ctc_zero_infinity = ctc_zero_infinity
# adapter
self.add_adapter = add_adapter
self.adapter_kernel_size = adapter_kernel_size
self.adapter_stride = adapter_stride
self.num_adapter_layers = num_adapter_layers
self.output_hidden_size = output_hidden_size or hidden_size
# SequenceClassification-specific parameter. Feel free to ignore for other classes.
self.classifier_proj_size = classifier_proj_size
# XVector-specific parameters. Feel free to ignore for other classes.
self.tdnn_dim = list(tdnn_dim)
self.tdnn_kernel = list(tdnn_kernel)
self.tdnn_dilation = list(tdnn_dilation)
self.xvector_output_dim = xvector_output_dim
@property
def inputs_to_logits_ratio(self):
return functools.reduce(operator.mul, self.conv_stride, 1)
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Wav2Vec2Conformer checkpoint."""
import argparse
import json
import os
import fairseq
import torch
from fairseq.data import Dictionary
from transformers import (
Wav2Vec2ConformerConfig,
Wav2Vec2ConformerForCTC,
Wav2Vec2ConformerForPreTraining,
Wav2Vec2CTCTokenizer,
Wav2Vec2FeatureExtractor,
Wav2Vec2Processor,
logging,
)
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
MAPPING = {
"post_extract_proj": "feature_projection.projection",
"encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
"self_attn.linear_k": "encoder.layers.*.self_attn.linear_k",
"self_attn.linear_v": "encoder.layers.*.self_attn.linear_v",
"self_attn.linear_q": "encoder.layers.*.self_attn.linear_q",
"self_attn.pos_bias_u": "encoder.layers.*.self_attn.pos_bias_u",
"self_attn.pos_bias_v": "encoder.layers.*.self_attn.pos_bias_v",
"self_attn.linear_out": "encoder.layers.*.self_attn.linear_out",
"self_attn.linear_pos": "encoder.layers.*.self_attn.linear_pos",
"self_attn.rotary_emb": "encoder.embed_positions",
"self_attn_layer_norm": "encoder.layers.*.self_attn_layer_norm",
"conv_module.pointwise_conv1": "encoder.layers.*.conv_module.pointwise_conv1",
"conv_module.pointwise_conv2": "encoder.layers.*.conv_module.pointwise_conv2",
"conv_module.depthwise_conv": "encoder.layers.*.conv_module.depthwise_conv",
"conv_module.batch_norm": "encoder.layers.*.conv_module.batch_norm",
"conv_module.layer_norm": "encoder.layers.*.conv_module.layer_norm",
"ffn1.w_1": "encoder.layers.*.ffn1.intermediate_dense",
"ffn1.w_2": "encoder.layers.*.ffn1.output_dense",
"ffn1.layer_norm": "encoder.layers.*.ffn1_layer_norm",
"ffn2.w_1": "encoder.layers.*.ffn2.intermediate_dense",
"ffn2.w_2": "encoder.layers.*.ffn2.output_dense",
"ffn2.layer_norm": "encoder.layers.*.ffn2_layer_norm",
"final_layer_norm": "encoder.layers.*.final_layer_norm",
"encoder.layer_norm": "encoder.layer_norm",
"w2v_model.layer_norm": "feature_projection.layer_norm",
"quantizer.weight_proj": "quantizer.weight_proj",
"quantizer.vars": "quantizer.codevectors",
"project_q": "project_q",
"final_proj": "project_hid",
"w2v_encoder.proj": "lm_head",
"mask_emb": "masked_spec_embed",
}
TOP_LEVEL_KEYS = [
"lm_head",
"quantizer.weight_proj",
"quantizer.codevectors",
"project_q",
"project_hid",
]
def set_recursively(hf_pointer, key, value, full_name, weight_type):
for attribute in key.split("."):
hf_pointer = getattr(hf_pointer, attribute)
if weight_type is not None:
hf_shape = getattr(hf_pointer, weight_type).shape
else:
hf_shape = hf_pointer.shape
if hf_shape != value.shape:
raise ValueError(
f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
f" {value.shape} for {full_name}"
)
if weight_type == "weight":
hf_pointer.weight.data = value
elif weight_type == "weight_g":
hf_pointer.weight_g.data = value
elif weight_type == "weight_v":
hf_pointer.weight_v.data = value
elif weight_type == "bias":
hf_pointer.bias.data = value
elif weight_type == "running_mean":
hf_pointer.running_mean.data = value
elif weight_type == "running_var":
hf_pointer.running_var.data = value
elif weight_type == "num_batches_tracked":
hf_pointer.num_batches_tracked.data = value
elif weight_type == "inv_freq":
hf_pointer.inv_freq.data = value
else:
hf_pointer.data = value
logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
def recursively_load_weights(fairseq_model, hf_model, is_headless):
unused_weights = []
fairseq_dict = fairseq_model.state_dict()
feature_extractor = hf_model.wav2vec2_conformer.feature_extractor
for name, value in fairseq_dict.items():
is_used = False
if "conv_layers" in name:
load_conv_layer(
name,
value,
feature_extractor,
unused_weights,
hf_model.config.feat_extract_norm == "group",
)
is_used = True
else:
for key, mapped_key in MAPPING.items():
mapped_key = "wav2vec2_conformer." + mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key
if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
is_used = True
if "*" in mapped_key:
layer_index = name.split(key)[0].split(".")[-2]
mapped_key = mapped_key.replace("*", layer_index)
if "pos_bias_u" in name:
weight_type = None
elif "pos_bias_v" in name:
weight_type = None
elif "weight_g" in name:
weight_type = "weight_g"
elif "weight_v" in name:
weight_type = "weight_v"
elif "bias" in name:
weight_type = "bias"
elif "weight" in name:
# TODO: don't match quantizer.weight_proj
weight_type = "weight"
elif "running_mean" in name:
weight_type = "running_mean"
elif "inv_freq" in name:
weight_type = "inv_freq"
elif "running_var" in name:
weight_type = "running_var"
elif "num_batches_tracked" in name:
weight_type = "num_batches_tracked"
else:
weight_type = None
set_recursively(hf_model, mapped_key, value, name, weight_type)
continue
if not is_used:
unused_weights.append(name)
logger.warning(f"Unused weights: {unused_weights}")
# Copied from transformers.models.wav2vec2.convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.load_conv_layer
def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm):
name = full_name.split("conv_layers.")[-1]
items = name.split(".")
layer_id = int(items[0])
type_id = int(items[1])
if type_id == 0:
if "bias" in name:
if value.shape != feature_extractor.conv_layers[layer_id].conv.bias.data.shape:
raise ValueError(
f"{full_name} has size {value.shape}, but"
f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
if value.shape != feature_extractor.conv_layers[layer_id].conv.weight.data.shape:
raise ValueError(
f"{full_name} has size {value.shape}, but"
f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
if "bias" in name:
if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape:
raise ValueError(
f"{full_name} has size {value.shape}, but"
f" {feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape:
raise ValueError(
f"{full_name} has size {value.shape}, but"
f" {feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
else:
unused_weights.append(full_name)
@torch.no_grad()
def convert_wav2vec2_conformer_checkpoint(
checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True
):
"""
Copy/paste/tweak model's weights to transformers design.
"""
if config_path is not None:
config = Wav2Vec2ConformerConfig.from_pretrained(config_path, hidden_act="swish")
else:
config = Wav2Vec2ConformerConfig()
if "rope" in checkpoint_path:
config.position_embeddings_type = "rotary"
if is_finetuned:
if dict_path:
target_dict = Dictionary.load(dict_path)
# important change bos & pad token id since CTC symbol is <pad> and
# not <s> as in fairseq
config.bos_token_id = target_dict.pad_index
config.pad_token_id = target_dict.bos_index
config.eos_token_id = target_dict.eos_index
config.vocab_size = len(target_dict.symbols)
vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")
if not os.path.isdir(pytorch_dump_folder_path):
logger.error("--pytorch_dump_folder_path ({}) should be a directory".format(pytorch_dump_folder_path))
return
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
vocab_dict = target_dict.indices
# fairseq has the <pad> and <s> switched
vocab_dict["<pad>"] = 0
vocab_dict["<s>"] = 1
with open(vocab_path, "w", encoding="utf-8") as vocab_handle:
json.dump(vocab_dict, vocab_handle)
tokenizer = Wav2Vec2CTCTokenizer(
vocab_path,
unk_token=target_dict.unk_word,
pad_token=target_dict.pad_word,
bos_token=target_dict.bos_word,
eos_token=target_dict.eos_word,
word_delimiter_token="|",
do_lower_case=False,
)
return_attention_mask = True if config.feat_extract_norm == "layer" else False
feature_extractor = Wav2Vec2FeatureExtractor(
feature_size=1,
sampling_rate=16000,
padding_value=0,
do_normalize=True,
return_attention_mask=return_attention_mask,
)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
processor.save_pretrained(pytorch_dump_folder_path)
hf_wav2vec = Wav2Vec2ConformerForCTC(config)
else:
hf_wav2vec = Wav2Vec2ConformerForPreTraining(config)
if is_finetuned:
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
)
else:
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])
model = model[0].eval()
recursively_load_weights(model, hf_wav2vec, not is_finetuned)
hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
parser.add_argument(
"--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
)
args = parser.parse_args()
convert_wav2vec2_conformer_checkpoint(
args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, not args.not_finetuned
)
# coding=utf-8
# Copyright 2022 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Wav2Vec2-Conformer model."""
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...deepspeed import is_deepspeed_zero3_enabled
from ...modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
Wav2Vec2BaseModelOutput,
XVectorOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import torch_int_div
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig
logger = logging.get_logger(__name__)
_HIDDEN_STATES_START_POSITION = 2
# General docstring
_CONFIG_FOR_DOC = "Wav2Vec2ConformerConfig"
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
# CTC docstring
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
_CTC_EXPECTED_LOSS = 64.21
# Audio class docstring
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
_SEQ_CLASS_CHECKPOINT = "hf-internal-testing/wav2vec2-conformer-seq-class"
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'"
_SEQ_CLASS_EXPECTED_LOSS = 0.68
# Frame class docstring
_FRAME_CLASS_CHECKPOINT = "hf-internal-testing/wav2vec2-conformer-frame-class"
_FRAME_EXPECTED_OUTPUT = [1, 0]
# Speaker Verification docstring
_XVECTOR_CHECKPOINT = "hf-internal-testing/wav2vec2-conformer-xvector"
_XVECTOR_EXPECTED_OUTPUT = 1.0
WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/wav2vec2-conformer-large-rel-pos",
# See all Wav2Vec2Conformer models at https://huggingface.co/models?filter=wav2vec2-conformer
]
@dataclass
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput):
"""
Output type of [`Wav2Vec2ConformerForPreTraining`], with potential hidden states and attentions.
Args:
loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
paper](https://arxiv.org/pdf/2006.11477.pdf) . (classification) loss.
projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
projected quantized states.
projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
target vectors for contrastive loss.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
The contrastive loss (L_m) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
The diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
"""
loss: Optional[torch.FloatTensor] = None
projected_states: torch.FloatTensor = None
projected_quantized_states: torch.FloatTensor = None
codevector_perplexity: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
contrastive_loss: Optional[torch.FloatTensor] = None
diversity_loss: Optional[torch.FloatTensor] = None
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.
Args:
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
the first element is the batch size and the second element is the length of the axis to span.
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
independently generated mask spans of length `mask_length` is computed by
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
actual percentage will be smaller.
mask_length: size of the mask
min_masks: minimum number of masked spans
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
each batch dimension.
"""
batch_size, sequence_length = shape
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
f" and `sequence_length`: {sequence_length}`"
)
# epsilon is used for probabilistic rounding
epsilon = np.random.rand(1).item()
def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# make sure num masked span <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length
# make sure num_masked span is also <= input_length - (mask_length - 1)
if input_length - (mask_length - 1) < num_masked_span:
num_masked_span = max(input_length - (mask_length - 1), 0)
return num_masked_span
# compute number of masked spans in batch
input_lengths = (
attention_mask.sum(-1).detach().tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
# SpecAugment mask to fill
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=np.bool)
spec_aug_mask_idxs = []
max_num_masked_span = compute_num_masked_span(sequence_length)
if max_num_masked_span == 0:
return spec_aug_mask
for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)
# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
# pick first sampled index that will serve as a dummy index to pad vector
# to ensure same dimension for all batches due to probabilistic rounding
# Picking first sample just pads those vectors twice.
if len(spec_aug_mask_idx) == 0:
# this case can only happen if `input_length` is strictly smaller then
# `sequence_length` in which case the last token has to be a padding
# token which we can use as a dummy mask id
dummy_mask_idx = sequence_length - 1
else:
dummy_mask_idx = spec_aug_mask_idx[0]
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# expand masked indices to masked spans
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
# add offset to the starting indexes so that that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# ensure that we cannot have indices larger than sequence_length
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
return spec_aug_mask
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices
def _sample_negative_indices(
features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
):
"""
Sample `num_negatives` vectors from feature vectors.
"""
batch_size, sequence_length = features_shape
# generate indices of the positive vectors themselves, repeat them `num_negatives` times
sequence_length_range = np.arange(sequence_length)
# get `num_negatives` random vector indices from the same utterance
sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
mask_time_indices = (
mask_time_indices.astype(np.bool) if mask_time_indices is not None else np.ones(features_shape, dtype=np.bool)
)
for batch_idx in range(batch_size):
high = mask_time_indices[batch_idx].sum() - 1
mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
# avoid sampling the same positive vector, but keep the distribution uniform
sampled_indices[sampled_indices >= feature_indices] += 1
# remap to actual indices
sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
# correct for batch size
sampled_negative_indices[batch_idx] += batch_idx * sequence_length
return sampled_negative_indices
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.layer_norm(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.activation(hidden_states)
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerGroupNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module):
def __init__(self, config):
super().__init__()
self.conv = nn.Conv1d(
config.hidden_size,
config.hidden_size,
kernel_size=config.num_conv_pos_embeddings,
padding=config.num_conv_pos_embeddings // 2,
groups=config.num_conv_pos_embedding_groups,
)
if is_deepspeed_zero3_enabled():
import deepspeed
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
else:
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.padding = Wav2Vec2ConformerSamePadLayer(config.num_conv_pos_embeddings)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = hidden_states.transpose(1, 2)
hidden_states = self.conv(hidden_states)
hidden_states = self.padding(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
return hidden_states
class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
"""Rotary positional embedding
Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
"""
def __init__(self, config):
super().__init__()
dim = config.hidden_size // config.num_attention_heads
base = config.rotary_embedding_base
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
self.cached_sequence_length = None
self.cached_rotary_positional_embedding = None
def forward(self, hidden_states):
sequence_length = hidden_states.shape[1]
if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
return self.cached_rotary_positional_embedding
self.cached_sequence_length = sequence_length
time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
embeddings = torch.cat((freqs, freqs), dim=-1)
cos_embeddings = embeddings.cos()[:, None, None, :]
sin_embeddings = embeddings.sin()[:, None, None, :]
self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings])
return self.cached_rotary_positional_embedding
class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
"""Relative positional encoding module."""
def __init__(self, config):
super().__init__()
self.max_len = config.max_source_positions
self.d_model = config.hidden_size
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
def extend_pe(self, x):
# Reset the positional encodings
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` is the position of query vector and `j` is the
# position of key vector. We use positive relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
# Reverse the order of positive indices and concat both positive and
# negative indices. This is used to support the shifting trick
# as in https://arxiv.org/abs/1901.02860
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, hidden_states: torch.Tensor):
self.extend_pe(hidden_states)
start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
relative_position_embeddings = self.pe[:, start_idx:end_idx]
return relative_position_embeddings
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerSamePadLayer(nn.Module):
def __init__(self, num_conv_pos_embeddings):
super().__init__()
self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
def forward(self, hidden_states):
if self.num_pad_remove > 0:
hidden_states = hidden_states[:, :, : -self.num_pad_remove]
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerFeatureEncoder(nn.Module):
"""Construct the features from raw audio waveform"""
def __init__(self, config):
super().__init__()
if config.feat_extract_norm == "group":
conv_layers = [Wav2Vec2ConformerGroupNormConvLayer(config, layer_id=0)] + [
Wav2Vec2ConformerNoLayerNormConvLayer(config, layer_id=i + 1)
for i in range(config.num_feat_extract_layers - 1)
]
elif config.feat_extract_norm == "layer":
conv_layers = [
Wav2Vec2ConformerLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
]
else:
raise ValueError(
f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
)
self.conv_layers = nn.ModuleList(conv_layers)
self.gradient_checkpointing = False
self._requires_grad = True
def _freeze_parameters(self):
for param in self.parameters():
param.requires_grad = False
self._requires_grad = False
def forward(self, input_values):
hidden_states = input_values[:, None]
# make sure hidden_states require grad for gradient_checkpointing
if self._requires_grad and self.training:
hidden_states.requires_grad = True
for conv_layer in self.conv_layers:
if self._requires_grad and self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(conv_layer),
hidden_states,
)
else:
hidden_states = conv_layer(hidden_states)
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerFeatureProjection(nn.Module):
def __init__(self, config):
super().__init__()
self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
self.dropout = nn.Dropout(config.feat_proj_dropout)
def forward(self, hidden_states):
# non-projected hidden states are needed for quantization
norm_hidden_states = self.layer_norm(hidden_states)
hidden_states = self.projection(norm_hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states, norm_hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerFeedForward(nn.Module):
def __init__(self, config):
super().__init__()
self.intermediate_dropout = nn.Dropout(config.activation_dropout)
self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.output_dropout = nn.Dropout(config.hidden_dropout)
def forward(self, hidden_states):
hidden_states = self.intermediate_dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
hidden_states = self.intermediate_dropout(hidden_states)
hidden_states = self.output_dense(hidden_states)
hidden_states = self.output_dropout(hidden_states)
return hidden_states
class Wav2Vec2ConformerConvolutionModule(nn.Module):
"""Convolution block used in the conformer block"""
def __init__(self, config):
super().__init__()
if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding")
self.layer_norm = nn.LayerNorm(config.hidden_size)
self.pointwise_conv1 = torch.nn.Conv1d(
config.hidden_size,
2 * config.hidden_size,
kernel_size=1,
stride=1,
padding=0,
bias=False,
)
self.glu = torch.nn.GLU(dim=1)
self.depthwise_conv = torch.nn.Conv1d(
config.hidden_size,
config.hidden_size,
config.conv_depthwise_kernel_size,
stride=1,
padding=(config.conv_depthwise_kernel_size - 1) // 2,
groups=config.hidden_size,
bias=False,
)
self.batch_norm = torch.nn.BatchNorm1d(config.hidden_size)
self.activation = ACT2FN[config.hidden_act]
self.pointwise_conv2 = torch.nn.Conv1d(
config.hidden_size,
config.hidden_size,
kernel_size=1,
stride=1,
padding=0,
bias=False,
)
self.dropout = torch.nn.Dropout(config.conformer_conv_dropout)
def forward(self, hidden_states):
hidden_states = self.layer_norm(hidden_states)
# exchange the temporal dimension and the feature dimension
hidden_states = hidden_states.transpose(1, 2)
# GLU mechanism
# => (batch, 2*channel, dim)
hidden_states = self.pointwise_conv1(hidden_states)
# => (batch, channel, dim)
hidden_states = self.glu(hidden_states)
# 1D Depthwise Conv
hidden_states = self.depthwise_conv(hidden_states)
hidden_states = self.batch_norm(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.pointwise_conv2(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
return hidden_states
class Wav2Vec2ConformerSelfAttention(nn.Module):
"""Construct an Wav2Vec2ConformerSelfAttention object.
Can be enhanced with rotary or relative position embeddings.
"""
def __init__(self, config):
super().__init__()
self.head_size = config.hidden_size // config.num_attention_heads
self.num_heads = config.num_attention_heads
self.position_embeddings_type = config.position_embeddings_type
self.linear_q = nn.Linear(config.hidden_size, config.hidden_size)
self.linear_k = nn.Linear(config.hidden_size, config.hidden_size)
self.linear_v = nn.Linear(config.hidden_size, config.hidden_size)
self.linear_out = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(p=config.attention_dropout)
if self.position_embeddings_type == "relative":
# linear transformation for positional encoding
self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(self.num_heads, self.head_size))
self.pos_bias_v = nn.Parameter(torch.Tensor(self.num_heads, self.head_size))
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
relative_position_embeddings: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# self-attention mechanism
batch_size, sequence_length, hidden_size = hidden_states.size()
# make sure query/key states can be != value states
query_key_states = hidden_states
value_states = hidden_states
if self.position_embeddings_type == "rotary":
if relative_position_embeddings is None:
raise ValueError(
"`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'"
)
query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)
# project query_key_states and value_states
query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)
# => (batch, head, time1, d_k)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
if self.position_embeddings_type == "relative":
if relative_position_embeddings is None:
raise ValueError(
"`relative_position_embeddings` has to be defined when `self.position_embeddings_type =="
" 'relative'"
)
# apply relative_position_embeddings to qk scores
# as proposed in Transformer_XL: https://arxiv.org/abs/1901.02860
scores = self._apply_relative_embeddings(
query=query, key=key, relative_position_embeddings=relative_position_embeddings
)
else:
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size)
# apply attention_mask if necessary
if attention_mask is not None:
scores = scores + attention_mask
# => (batch, head, time1, time2)
probs = torch.softmax(scores, dim=-1)
probs = self.dropout(probs)
# => (batch, head, time1, d_k)
hidden_states = torch.matmul(probs, value)
# => (batch, time1, hidden_size)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
hidden_states = self.linear_out(hidden_states)
return hidden_states, probs
def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
batch_size, sequence_length, hidden_size = hidden_states.size()
hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)
cos = relative_position_embeddings[0, :sequence_length, ...]
sin = relative_position_embeddings[1, :sequence_length, ...]
# rotate hidden_states with rotary embeddings
hidden_states = hidden_states.transpose(0, 1)
rotated_states_begin = hidden_states[..., : self.head_size // 2]
rotated_states_end = hidden_states[..., self.head_size // 2 :]
rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1)
hidden_states = (hidden_states * cos) + (rotated_states * sin)
hidden_states = hidden_states.transpose(0, 1)
hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size)
return hidden_states
def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
# 1. project positional embeddings
# => (batch, head, 2*time1-1, d_k)
proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings)
proj_relative_position_embeddings = proj_relative_position_embeddings.view(
relative_position_embeddings.size(0), -1, self.num_heads, self.head_size
)
proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2)
proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3)
# 2. Add bias to query
# => (batch, head, time1, d_k)
query = query.transpose(1, 2)
q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
# 3. attention score: first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# => (batch, head, time1, time2)
scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
# 4. then compute matrix b and matrix d
# => (batch, head, time1, 2*time1-1)
scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings)
# 5. shift matrix b and matrix d
zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype)
scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1)
scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2])
scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape)
scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd)
scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1]
# 6. sum matrices
# => (batch, head, time1, time2)
scores = (scores_ac + scores_bd) / math.sqrt(self.head_size)
return scores
class Wav2Vec2ConformerEncoderLayer(nn.Module):
"""Conformer block based on https://arxiv.org/abs/2005.08100."""
def __init__(self, config):
super().__init__()
embed_dim = config.hidden_size
dropout = config.attention_dropout
# Feed-forward 1
self.ffn1_layer_norm = nn.LayerNorm(embed_dim)
self.ffn1 = Wav2Vec2ConformerFeedForward(config)
# Self-Attention
self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
self.self_attn_dropout = torch.nn.Dropout(dropout)
self.self_attn = Wav2Vec2ConformerSelfAttention(config)
# Conformer Convolution
self.conv_module = Wav2Vec2ConformerConvolutionModule(config)
# Feed-forward 2
self.ffn2_layer_norm = nn.LayerNorm(embed_dim)
self.ffn2 = Wav2Vec2ConformerFeedForward(config)
self.final_layer_norm = nn.LayerNorm(embed_dim)
def forward(
self,
hidden_states,
attention_mask: Optional[torch.Tensor] = None,
relative_position_embeddings: Optional[torch.Tensor] = None,
output_attentions: bool = False,
):
hidden_states = hidden_states
# 1. Feed-Forward 1 layer
residual = hidden_states
hidden_states = self.ffn1_layer_norm(hidden_states)
hidden_states = self.ffn1(hidden_states)
hidden_states = hidden_states * 0.5 + residual
residual = hidden_states
# 2. Self-Attention layer
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, attn_weigts = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
relative_position_embeddings=relative_position_embeddings,
output_attentions=output_attentions,
)
hidden_states = self.self_attn_dropout(hidden_states)
hidden_states = hidden_states + residual
# 3. Convolutional Layer
residual = hidden_states
hidden_states = self.conv_module(hidden_states)
hidden_states = residual + hidden_states
# 4. Feed-Forward 2 Layer
residual = hidden_states
hidden_states = self.ffn2_layer_norm(hidden_states)
hidden_states = self.ffn2(hidden_states)
hidden_states = hidden_states * 0.5 + residual
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states, attn_weigts
class Wav2Vec2ConformerEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
if config.position_embeddings_type == "relative":
self.embed_positions = Wav2Vec2ConformerRelPositionalEmbedding(config)
elif config.position_embeddings_type == "rotary":
self.embed_positions = Wav2Vec2ConformerRotaryPositionalEmbedding(config)
else:
self.embed_positions = None
self.pos_conv_embed = Wav2Vec2ConformerPositionalConvEmbedding(config)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.ModuleList([Wav2Vec2ConformerEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
if attention_mask is not None:
# make sure padded tokens output 0
hidden_states[~attention_mask] = 0.0
# extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
hidden_states = self.dropout(hidden_states)
if self.embed_positions is not None:
relative_position_embeddings = self.embed_positions(hidden_states)
else:
relative_position_embeddings = None
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
for i, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = np.random.uniform(0, 1)
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
# create gradient checkpointing function
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
hidden_states,
attention_mask,
relative_position_embeddings,
)
else:
layer_outputs = layer(
hidden_states,
attention_mask=attention_mask,
relative_position_embeddings=relative_position_embeddings,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if skip_the_layer:
layer_outputs = (None, None)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
hidden_states = self.layer_norm(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GumbelVectorQuantizer with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module):
"""
Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH
GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
"""
def __init__(self, config):
super().__init__()
self.num_groups = config.num_codevector_groups
self.num_vars = config.num_codevectors_per_group
if config.codevector_dim % self.num_groups != 0:
raise ValueError(
f"`config.codevector_dim {config.codevector_dim} must be divisible "
f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
)
# storage for codebook variables (codewords)
self.codevectors = nn.Parameter(
torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
)
self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
# can be decayed for training
self.temperature = 2
@staticmethod
def _compute_perplexity(probs, mask=None):
if mask is not None:
mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
marginal_probs = probs.sum(dim=0) / mask.sum()
else:
marginal_probs = probs.mean(dim=0)
perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
return perplexity
def forward(self, hidden_states, mask_time_indices=None):
batch_size, sequence_length, hidden_size = hidden_states.shape
# project to codevector dim
hidden_states = self.weight_proj(hidden_states)
hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
if self.training:
# sample code vector probs via gumbel in differentiateable way
codevector_probs = nn.functional.gumbel_softmax(
hidden_states.float(), tau=self.temperature, hard=True
).type_as(hidden_states)
# compute perplexity
codevector_soft_dist = torch.softmax(
hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
)
perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
else:
# take argmax in non-differentiable way
# comptute hard codevector distribution (one hot)
codevector_idx = hidden_states.argmax(dim=-1)
codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
-1, codevector_idx.view(-1, 1), 1.0
)
codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
# use probs to retrieve codevectors
codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
return codevectors, perplexity
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerAdapter(nn.Module):
def __init__(self, config):
super().__init__()
# feature dim might need to be down-projected
if config.output_hidden_size != config.hidden_size:
self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
else:
self.proj = self.proj_layer_norm = None
self.layers = nn.ModuleList(Wav2Vec2ConformerAdapterLayer(config) for _ in range(config.num_adapter_layers))
self.layerdrop = config.layerdrop
def forward(self, hidden_states):
# down project hidden_states if necessary
if self.proj is not None and self.proj_layer_norm is not None:
hidden_states = self.proj(hidden_states)
hidden_states = self.proj_layer_norm(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
for layer in self.layers:
layerdrop_prob = np.random.random()
if not self.training or (layerdrop_prob > self.layerdrop):
hidden_states = layer(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerAdapterLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.conv = nn.Conv1d(
config.output_hidden_size,
2 * config.output_hidden_size,
config.adapter_kernel_size,
stride=config.adapter_stride,
padding=1,
)
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = nn.functional.glu(hidden_states, dim=1)
return hidden_states
class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = Wav2Vec2ConformerConfig
base_model_prefix = "wav2vec2_conformer"
main_input_name = "input_values"
_keys_to_ignore_on_load_missing = [r"position_ids"]
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
# gumbel softmax requires special init
if isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer):
module.weight_proj.weight.data.normal_(mean=0.0, std=1)
module.weight_proj.bias.data.zero_()
nn.init.uniform_(module.codevectors)
elif isinstance(module, Wav2Vec2ConformerSelfAttention):
if hasattr(module, "pos_bias_u"):
nn.init.xavier_uniform_(module.pos_bias_u)
if hasattr(module, "pos_bias_v"):
nn.init.xavier_uniform_(module.pos_bias_v)
elif isinstance(module, Wav2Vec2ConformerPositionalConvEmbedding):
nn.init.normal_(
module.conv.weight,
mean=0,
std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
)
nn.init.constant_(module.conv.bias, 0)
elif isinstance(module, Wav2Vec2ConformerFeatureProjection):
k = math.sqrt(1 / module.projection.in_features)
nn.init.uniform_(module.projection.weight, a=-k, b=k)
nn.init.uniform_(module.projection.bias, a=-k, b=k)
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Conv1d):
nn.init.kaiming_normal_(module.weight)
if module.bias is not None:
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
nn.init.uniform_(module.bias, a=-k, b=k)
def _get_feat_extract_output_lengths(
self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
):
"""
Computes the output length of the convolutional layers
"""
add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return torch_int_div(input_length - kernel_size, stride) + 1
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
if add_adapter:
for _ in range(self.config.num_adapter_layers):
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
return input_lengths
def _get_feature_vector_attention_mask(
self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
):
# Effectively attention_mask.sum(-1), but not inplace to be able to run
# on inference mode.
non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
output_lengths = output_lengths.to(torch.long)
batch_size = attention_mask.shape[0]
attention_mask = torch.zeros(
(batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
)
# these two operations makes sure that all values before the output lengths idxs are attended to
attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
return attention_mask
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)):
module.gradient_checkpointing = value
WAV2VEC2_CONFORMER_START_DOCSTRING = r"""
Wav2Vec2Conformer was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
Auli.
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving etc.).
This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a
regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
Parameters:
config ([`Wav2Vec2ConformerConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r"""
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
soundfile*). To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should be used for padding
and conversion into a tensor of type *torch.FloatTensor*. See [`Wav2Vec2Processor.__call__`] for details.
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
<Tip warning={true}>
`attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
True`. For all models whose processor has `config.return_attention_mask == False`, such as
[wav2vec2_conformer-base](https://huggingface.co/facebook/wav2vec2-conformer-large-rel-pos),
`attention_mask` should **not** be passed to avoid degraded performance when doing batched inference. For
such models `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware
that these models also yield slightly different results depending on whether `input_values` is padded or
not.
</Tip>
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The bare Wav2Vec2Conformer Model transformer outputting raw hidden-states without any specific head on top.",
WAV2VEC2_CONFORMER_START_DOCSTRING,
)
class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
def __init__(self, config: Wav2Vec2ConformerConfig):
super().__init__(config)
self.config = config
self.feature_extractor = Wav2Vec2ConformerFeatureEncoder(config)
self.feature_projection = Wav2Vec2ConformerFeatureProjection(config)
# model only needs masking vector if mask prob is > 0.0
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
self.encoder = Wav2Vec2ConformerEncoder(config)
self.adapter = Wav2Vec2ConformerAdapter(config) if config.add_adapter else None
# Initialize weights and apply final processing
self.post_init()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.freeze_feature_encoder
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
not be updated during training.
"""
self.feature_extractor._freeze_parameters()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
def _mask_hidden_states(
self,
hidden_states: torch.FloatTensor,
mask_time_indices: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
"""
Masks extracted features along time axis and/or along feature axis according to
[SpecAugment](https://arxiv.org/abs/1904.08779).
"""
# `config.apply_spec_augment` can set masking to False
if not getattr(self.config, "apply_spec_augment", True):
return hidden_states
# generate indices & apply SpecAugment along time axis
batch_size, sequence_length, hidden_size = hidden_states.size()
if mask_time_indices is not None:
# apply SpecAugment along time axis with given mask_time_indices
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
elif self.config.mask_time_prob > 0 and self.training:
mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length),
mask_prob=self.config.mask_time_prob,
mask_length=self.config.mask_time_length,
attention_mask=attention_mask,
min_masks=self.config.mask_time_min_masks,
)
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
if self.config.mask_feature_prob > 0 and self.training:
# generate indices & apply SpecAugment along feature axis
mask_feature_indices = _compute_mask_indices(
(batch_size, hidden_size),
mask_prob=self.config.mask_feature_prob,
mask_length=self.config.mask_feature_length,
min_masks=self.config.mask_feature_min_masks,
)
mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
hidden_states[mask_feature_indices] = 0
return hidden_states
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_PROCESSOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Wav2Vec2BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.forward with wav2vec2->wav2vec2_conformer
def forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
mask_time_indices: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
extract_features = self.feature_extractor(input_values)
extract_features = extract_features.transpose(1, 2)
if attention_mask is not None:
# compute reduced attention_mask corresponding to feature vectors
attention_mask = self._get_feature_vector_attention_mask(
extract_features.shape[1], attention_mask, add_adapter=False
)
hidden_states, extract_features = self.feature_projection(extract_features)
hidden_states = self._mask_hidden_states(
hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
)
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = encoder_outputs[0]
if self.adapter is not None:
hidden_states = self.adapter(hidden_states)
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
return Wav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
@add_start_docstrings(
"""Wav2Vec2Conformer Model with a quantizer and `VQ` head on top.""", WAV2VEC2_CONFORMER_START_DOCSTRING
)
class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
def __init__(self, config: Wav2Vec2ConformerConfig):
super().__init__(config)
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
self.quantizer = Wav2Vec2ConformerGumbelVectorQuantizer(config)
# Initialize weights and apply final processing
self.post_init()
# make sure that project_hid & project_q are initialized like normal linear layers
self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature
def set_gumbel_temperature(self, temperature: int):
"""
Set the Gumbel softmax temperature to a given value. Only necessary for training
"""
self.quantizer.temperature = temperature
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
not be updated during training.
"""
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
@staticmethod
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.compute_contrastive_logits
def compute_contrastive_logits(
target_features: torch.FloatTensor,
negative_features: torch.FloatTensor,
predicted_features: torch.FloatTensor,
temperature: int = 0.1,
):
"""
Compute logits for contrastive loss based using cosine similarity as the distance measure between
`[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
"""
target_features = torch.cat([target_features, negative_features], dim=0)
logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as(
target_features
)
# apply temperature
logits = logits / temperature
return logits
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2-base->wav2vec2-conformer-rel-pos-large,wav2vec2->wav2vec2_conformer
def forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
mask_time_indices: Optional[torch.BoolTensor] = None,
sampled_negative_indices: Optional[torch.BoolTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Wav2Vec2ConformerForPreTrainingOutput]:
r"""
mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
masked extracted features in *config.proj_codevector_dim* space.
sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*):
Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
Required input for pre-training.
Returns:
Example:
```python
>>> import torch
>>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
>>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import _compute_mask_indices
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> feature_extractor = AutoFeatureExtractor.from_pretrained(
... "facebook/wav2vec2_conformer-conformer-rel-pos-large"
... )
>>> model = Wav2Vec2ConformerForPreTraining.from_pretrained(
... "facebook/wav2vec2_conformer-conformer-rel-pos-large"
... )
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1
>>> # compute masked indices
>>> batch_size, raw_sequence_length = input_values.shape
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
>>> mask_time_indices = torch.tensor(mask_time_indices, device=input_values.device, dtype=torch.long)
>>> with torch.no_grad():
... outputs = model(input_values, mask_time_indices=mask_time_indices)
>>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
>>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
>>> # show that cosine similarity is much higher than random
>>> cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5
tensor(True)
>>> # for contrastive loss training model should be put into train mode
>>> model = model.train()
>>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if mask_time_indices is not None:
mask_time_indices = mask_time_indices.to(torch.bool)
outputs = self.wav2vec2_conformer(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
mask_time_indices=mask_time_indices,
return_dict=return_dict,
)
# 1. project all transformed features (including masked) to final vq dim
transformer_features = self.project_hid(outputs[0])
# 2. quantize all (unmasked) extracted features and project to final vq dim
extract_features = self.dropout_features(outputs[1])
if attention_mask is not None:
# compute reduced attention_mask correponding to feature vectors
attention_mask = self._get_feature_vector_attention_mask(
extract_features.shape[1], attention_mask, add_adapter=False
)
quantized_features, codevector_perplexity = self.quantizer(
extract_features, mask_time_indices=mask_time_indices
)
quantized_features = self.project_q(quantized_features)
loss = contrastive_loss = diversity_loss = None
if sampled_negative_indices is not None:
batch_size, sequence_length, hidden_size = quantized_features.shape
# for training, we sample negatives
# 3. sample K negatives (distractors) quantized states for contrastive loss
# if attention_mask is passed, make sure that padded feature vectors cannot be sampled
# sample negative quantized vectors BTC => (BxT)C
negative_quantized_features = quantized_features.view(-1, hidden_size)[
sampled_negative_indices.long().view(-1)
]
negative_quantized_features = negative_quantized_features.view(
batch_size, sequence_length, -1, hidden_size
).permute(2, 0, 1, 3)
# 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
# of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
logits = self.compute_contrastive_logits(
quantized_features[None, :],
negative_quantized_features,
transformer_features,
self.config.contrastive_logits_temperature,
)
# 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
# its cosine similarity will be masked
neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
if neg_is_pos.any():
logits[1:][neg_is_pos] = float("-inf")
# 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) =
# -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
# 7. compute diversity loss: \mathbf{L}_d
num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
# 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
if not return_dict:
if loss is not None:
return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
return Wav2Vec2ConformerForPreTrainingOutput(
loss=loss,
projected_states=transformer_features,
projected_quantized_states=quantized_features,
codevector_perplexity=codevector_perplexity,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
contrastive_loss=contrastive_loss,
diversity_loss=diversity_loss,
)
@add_start_docstrings(
"""Wav2Vec2Conformer Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
WAV2VEC2_CONFORMER_START_DOCSTRING,
)
class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
def __init__(self, config):
super().__init__(config)
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
self.dropout = nn.Dropout(config.final_dropout)
if config.vocab_size is None:
raise ValueError(
f"You are trying to instantiate {self.__class__} with a configuration that does not define the"
" vocabulary size of the language model head. Please instantiate the model as follows:"
" `Wav2Vec2ConformerForCTC.from_pretrained(..., vocab_size=vocab_size)`. or define `vocab_size` of"
" your model's configuration."
)
output_hidden_size = (
config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
)
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
# Initialize weights and apply final processing
self.post_init()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
not be updated during training.
"""
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_PROCESSOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_CTC_EXPECTED_OUTPUT,
expected_loss=_CTC_EXPECTED_LOSS,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
def forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.Tensor] = None,
) -> Union[Tuple, CausalLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
config.vocab_size - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.wav2vec2_conformer(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
if labels.max() >= self.config.vocab_size:
raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
# retrieve loss input_lengths from attention_mask
attention_mask = (
attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
)
input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
# assuming that padded tokens are filled with -100
# when not being attended to
labels_mask = labels >= 0
target_lengths = labels_mask.sum(-1)
flattened_targets = labels.masked_select(labels_mask)
# ctc_loss doesn't support fp16
log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
with torch.backends.cudnn.flags(enabled=False):
loss = nn.functional.ctc_loss(
log_probs,
flattened_targets,
input_lengths,
target_lengths,
blank=self.config.pad_token_id,
reduction=self.config.ctc_loss_reduction,
zero_infinity=self.config.ctc_zero_infinity,
)
if not return_dict:
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
)
@add_start_docstrings(
"""
Wav2Vec2Conformer Model with a sequence classification head on top (a linear layer over the pooled output) for
tasks like SUPERB Keyword Spotting.
""",
WAV2VEC2_CONFORMER_START_DOCSTRING,
)
class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedModel):
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
def __init__(self, config):
super().__init__(config)
if hasattr(config, "add_adapter") and config.add_adapter:
raise ValueError(
"Sequence classification does not support the use of Wav2Vec2Conformer adapters"
" (config.add_adapter=True)"
)
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
not be updated during training.
"""
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
be updated during training. Only the classification head will be updated.
"""
for param in self.wav2vec2_conformer.parameters():
param.requires_grad = False
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_SEQ_CLASS_CHECKPOINT,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
def forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.Tensor] = None,
) -> Union[Tuple, SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
outputs = self.wav2vec2_conformer(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if self.config.use_weighted_layer_sum:
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
hidden_states = torch.stack(hidden_states, dim=1)
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
else:
hidden_states = outputs[0]
hidden_states = self.projector(hidden_states)
if attention_mask is None:
pooled_output = hidden_states.mean(dim=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
hidden_states[~padding_mask] = 0.0
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
Wav2Vec2Conformer Model with a frame classification head on top for tasks like Speaker Diarization.
""",
WAV2VEC2_CONFORMER_START_DOCSTRING,
)
class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedModel):
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
def __init__(self, config):
super().__init__(config)
if hasattr(config, "add_adapter") and config.add_adapter:
raise ValueError(
"Audio frame classification does not support the use of Wav2Vec2Conformer adapters"
" (config.add_adapter=True)"
)
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.init_weights()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
not be updated during training.
"""
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_base_model with wav2vec2->wav2vec2_conformer
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
be updated during training. Only the classification head will be updated.
"""
for param in self.wav2vec2_conformer.parameters():
param.requires_grad = False
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_FRAME_CLASS_CHECKPOINT,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_FRAME_EXPECTED_OUTPUT,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->wav2vec2_conformer
def forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
outputs = self.wav2vec2_conformer(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if self.config.use_weighted_layer_sum:
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
hidden_states = torch.stack(hidden_states, dim=1)
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
else:
hidden_states = outputs[0]
logits = self.classifier(hidden_states)
if not return_dict:
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
return output
return TokenClassifierOutput(
loss=None,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
class AMSoftmaxLoss(nn.Module):
def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
super(AMSoftmaxLoss, self).__init__()
self.scale = scale
self.margin = margin
self.num_labels = num_labels
self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
self.loss = nn.CrossEntropyLoss()
def forward(self, hidden_states, labels):
labels = labels.flatten()
weight = nn.functional.normalize(self.weight, dim=0)
hidden_states = nn.functional.normalize(hidden_states, dim=1)
cos_theta = torch.mm(hidden_states, weight)
psi = cos_theta - self.margin
onehot = nn.functional.one_hot(labels, self.num_labels)
logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
loss = self.loss(logits, labels)
return loss
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
class TDNNLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
self.out_conv_dim = config.tdnn_dim[layer_id]
self.kernel_size = config.tdnn_kernel[layer_id]
self.dilation = config.tdnn_dilation[layer_id]
self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
self.activation = nn.ReLU()
def forward(self, hidden_states):
hidden_states = hidden_states.unsqueeze(1)
hidden_states = nn.functional.unfold(
hidden_states,
(self.kernel_size, self.in_conv_dim),
stride=(1, self.in_conv_dim),
dilation=(self.dilation, 1),
)
hidden_states = hidden_states.transpose(1, 2)
hidden_states = self.kernel(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
@add_start_docstrings(
"""
Wav2Vec2Conformer Model with an XVector feature extraction head on top for tasks like Speaker Verification.
""",
WAV2VEC2_CONFORMER_START_DOCSTRING,
)
class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
self.tdnn = nn.ModuleList(tdnn_layers)
self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
self.init_weights()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
not be updated during training.
"""
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_base_model with wav2vec2->wav2vec2_conformer
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
be updated during training. Only the classification head will be updated.
"""
for param in self.wav2vec2_conformer.parameters():
param.requires_grad = False
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector._get_tdnn_output_lengths with wav2vec2->wav2vec2_conformer
def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
"""
Computes the output length of the TDNN layers
"""
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return (input_length - kernel_size) // stride + 1
for kernel_size in self.config.tdnn_kernel:
input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
return input_lengths
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_XVECTOR_CHECKPOINT,
output_type=XVectorOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_XVECTOR_EXPECTED_OUTPUT,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
def forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.Tensor] = None,
) -> Union[Tuple, XVectorOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
outputs = self.wav2vec2_conformer(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if self.config.use_weighted_layer_sum:
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
hidden_states = torch.stack(hidden_states, dim=1)
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
else:
hidden_states = outputs[0]
hidden_states = self.projector(hidden_states)
for tdnn_layer in self.tdnn:
hidden_states = tdnn_layer(hidden_states)
# Statistic Pooling
if attention_mask is None:
mean_features = hidden_states.mean(dim=1)
std_features = hidden_states.std(dim=1)
else:
feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
mean_features = []
std_features = []
for i, length in enumerate(tdnn_output_lengths):
mean_features.append(hidden_states[i, :length].mean(dim=0))
std_features.append(hidden_states[i, :length].std(dim=0))
mean_features = torch.stack(mean_features)
std_features = torch.stack(std_features)
statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
output_embeddings = self.feature_extractor(statistic_pooling)
logits = self.classifier(output_embeddings)
loss = None
if labels is not None:
loss = self.objective(logits, labels)
if not return_dict:
output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
return ((loss,) + output) if loss is not None else output
return XVectorOutput(
loss=loss,
logits=logits,
embeddings=output_embeddings,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -77,13 +77,13 @@ class WavLMConfig(PretrainedConfig):
extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
The dropout probabilitiy for quantized feature encoder states.
conv_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
of *conv_stride* defines the number of convolutional layers and has to match the the length of *conv_dim*.
conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
length of *conv_kernel* defines the number of convolutional layers and has to match the the length of
*conv_dim*.
......@@ -146,13 +146,13 @@ class WavLMConfig(PretrainedConfig):
instance of [`WavLMForSequenceClassification`].
classifier_proj_size (`int`, *optional*, defaults to 256):
Dimensionality of the projection before token mean-pooling for classification.
tdnn_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
tdnn_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
tdnn_kernel (`Tuple[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
tdnn_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
*XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
tdnn_dilation (`Tuple[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
tdnn_dilation (`Tuple[int]` or `List[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
*XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
xvector_output_dim (`int`, *optional*, defaults to 512):
......
......@@ -16,7 +16,6 @@
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
......@@ -28,16 +27,17 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...deepspeed import is_deepspeed_zero3_enabled
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput
from ...modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
Wav2Vec2BaseModelOutput,
XVectorOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import torch_int_div
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
)
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_wavlm import WavLMConfig
......@@ -80,67 +80,6 @@ WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
@dataclass
class WavLMBaseModelOutput(ModelOutput):
"""
Output type of [`WavLMBaseModelOutput`], with potential hidden states and attentions.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
Sequence of extracted feature vectors of the last convolutional layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: torch.FloatTensor = None
extract_features: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class XVectorOutput(ModelOutput):
"""
Output type of [`Wav2Vec2ForXVector`].
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification loss.
logits (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
Classification hidden states before AMSoftmax.
embeddings (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
Utterance embeddings used for vector similarity-based retrieval.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
embeddings: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
shape: Tuple[int, int],
......@@ -1184,7 +1123,7 @@ WAVLM_INPUTS_DOCSTRING = r"""
"The bare WavLM Model transformer outputting raw hidden-states without any specific head on top.",
WAVLM_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM, WavLMBaseModelOutput->Wav2Vec2BaseModelOutput
class WavLMModel(WavLMPreTrainedModel):
def __init__(self, config: WavLMConfig):
super().__init__(config)
......@@ -1275,7 +1214,7 @@ class WavLMModel(WavLMPreTrainedModel):
@add_code_sample_docstrings(
processor_class=_PROCESSOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=WavLMBaseModelOutput,
output_type=Wav2Vec2BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
......@@ -1288,7 +1227,7 @@ class WavLMModel(WavLMPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, WavLMBaseModelOutput]:
) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
......@@ -1325,7 +1264,7 @@ class WavLMModel(WavLMPreTrainedModel):
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
return WavLMBaseModelOutput(
return Wav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
......
......@@ -4440,6 +4440,58 @@ class Wav2Vec2PreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
class Wav2Vec2ConformerForAudioFrameClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Wav2Vec2ConformerForCTC(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Wav2Vec2ConformerForPreTraining(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Wav2Vec2ConformerForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Wav2Vec2ConformerForXVector(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Wav2Vec2ConformerModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Wav2Vec2ConformerPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Wav2Vec2-Conformer model. """
import math
import unittest
import numpy as np
from datasets import load_dataset
from transformers import Wav2Vec2ConformerConfig, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
ModelTesterMixin,
_config_zero_init,
floats_tensor,
ids_tensor,
random_attention_mask,
)
if is_torch_available():
import torch
from transformers import (
Wav2Vec2ConformerForAudioFrameClassification,
Wav2Vec2ConformerForCTC,
Wav2Vec2ConformerForPreTraining,
Wav2Vec2ConformerForSequenceClassification,
Wav2Vec2ConformerForXVector,
Wav2Vec2ConformerModel,
Wav2Vec2FeatureExtractor,
Wav2Vec2Processor,
)
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
Wav2Vec2ConformerGumbelVectorQuantizer,
_compute_mask_indices,
_sample_negative_indices,
)
class Wav2Vec2ConformerModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=1024, # speech is longer
is_training=False,
hidden_size=16,
feat_extract_norm="group",
feat_extract_dropout=0.0,
feat_extract_activation="gelu",
conv_dim=(32, 32, 32),
conv_stride=(4, 4, 4),
conv_kernel=(8, 8, 8),
conv_bias=False,
num_conv_pos_embeddings=16,
num_conv_pos_embedding_groups=2,
num_hidden_layers=4,
num_attention_heads=2,
hidden_dropout_prob=0.1,
intermediate_size=20,
layer_norm_eps=1e-5,
hidden_act="gelu",
initializer_range=0.02,
mask_time_prob=0.5,
mask_time_length=2,
vocab_size=32,
do_stable_layer_norm=False,
num_adapter_layers=1,
adapter_stride=2,
tdnn_dim=(32, 32),
tdnn_kernel=(5, 3),
tdnn_dilation=(1, 2),
xvector_output_dim=32,
position_embeddings_type="relative",
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.hidden_size = hidden_size
self.feat_extract_norm = feat_extract_norm
self.feat_extract_dropout = feat_extract_dropout
self.feat_extract_activation = feat_extract_activation
self.conv_dim = conv_dim
self.conv_stride = conv_stride
self.conv_kernel = conv_kernel
self.conv_bias = conv_bias
self.num_conv_pos_embeddings = num_conv_pos_embeddings
self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_dropout_prob = hidden_dropout_prob
self.intermediate_size = intermediate_size
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.vocab_size = vocab_size
self.do_stable_layer_norm = do_stable_layer_norm
self.num_adapter_layers = num_adapter_layers
self.adapter_stride = adapter_stride
self.mask_time_prob = mask_time_prob
self.mask_time_length = mask_time_length
self.scope = scope
self.tdnn_dim = tdnn_dim
self.tdnn_kernel = tdnn_kernel
self.tdnn_dilation = tdnn_dilation
self.xvector_output_dim = xvector_output_dim
self.position_embeddings_type = position_embeddings_type
output_seq_length = self.seq_length
for kernel, stride in zip(self.conv_kernel, self.conv_stride):
output_seq_length = (output_seq_length - (kernel - 1)) / stride
self.output_seq_length = int(math.ceil(output_seq_length))
self.encoder_seq_length = self.output_seq_length
self.adapter_output_seq_length = (self.output_seq_length - 1) // adapter_stride + 1
def prepare_config_and_inputs(self, position_embeddings_type="relative"):
input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
config = self.get_config(position_embeddings_type=position_embeddings_type)
return config, input_values, attention_mask
def get_config(self, position_embeddings_type="relative"):
return Wav2Vec2ConformerConfig(
hidden_size=self.hidden_size,
feat_extract_norm=self.feat_extract_norm,
feat_extract_dropout=self.feat_extract_dropout,
feat_extract_activation=self.feat_extract_activation,
conv_dim=self.conv_dim,
conv_stride=self.conv_stride,
conv_kernel=self.conv_kernel,
conv_bias=self.conv_bias,
mask_time_prob=self.mask_time_prob,
mask_time_length=self.mask_time_length,
num_conv_pos_embeddings=self.num_conv_pos_embeddings,
num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
hidden_dropout_prob=self.hidden_dropout_prob,
intermediate_size=self.intermediate_size,
layer_norm_eps=self.layer_norm_eps,
do_stable_layer_norm=self.do_stable_layer_norm,
hidden_act=self.hidden_act,
initializer_range=self.initializer_range,
vocab_size=self.vocab_size,
num_adapter_layers=self.num_adapter_layers,
adapter_stride=self.adapter_stride,
tdnn_dim=self.tdnn_dim,
tdnn_kernel=self.tdnn_kernel,
tdnn_dilation=self.tdnn_dilation,
xvector_output_dim=self.xvector_output_dim,
position_embeddings_type=position_embeddings_type,
)
def create_and_check_model(self, config, input_values, attention_mask):
model = Wav2Vec2ConformerModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_values, attention_mask=attention_mask)
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
)
def create_and_check_model_with_adapter(self, config, input_values, attention_mask):
config.add_adapter = True
model = Wav2Vec2ConformerModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_values, attention_mask=attention_mask)
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, self.adapter_output_seq_length, self.hidden_size)
)
def create_and_check_model_with_adapter_for_ctc(self, config, input_values, attention_mask):
config.add_adapter = True
config.output_hidden_size = 2 * config.hidden_size
model = Wav2Vec2ConformerForCTC(config=config)
model.to(torch_device)
model.eval()
result = model(input_values, attention_mask=attention_mask)
self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.adapter_output_seq_length, self.vocab_size)
)
def create_and_check_model_with_adapter_proj_dim(self, config, input_values, attention_mask):
config.add_adapter = True
config.output_hidden_size = 8
model = Wav2Vec2ConformerModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_values, attention_mask=attention_mask)
self.parent.assertEqual(
result.last_hidden_state.shape,
(self.batch_size, self.adapter_output_seq_length, config.output_hidden_size),
)
def create_and_check_batch_inference(self, config, input_values, *args):
# test does not pass for models making use of `group_norm`
# check: https://github.com/pytorch/fairseq/issues/3227
model = Wav2Vec2ConformerModel(config=config)
model.to(torch_device)
model.eval()
input_values = input_values[:3]
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.bool)
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
attention_mask[i, input_lengths[i] :] = 0.0
batch_outputs = model(input_values, attention_mask=attention_mask).last_hidden_state
for i in range(input_values.shape[0]):
input_slice = input_values[i : i + 1, : input_lengths[i]]
output = model(input_slice).last_hidden_state
batch_output = batch_outputs[i : i + 1, : output.shape[1]]
self.parent.assertTrue(torch.allclose(output, batch_output, atol=1e-3))
def check_ctc_loss(self, config, input_values, *args):
model = Wav2Vec2ConformerForCTC(config=config)
model.to(torch_device)
# make sure that dropout is disabled
model.eval()
input_values = input_values[:3]
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
labels = ids_tensor((input_values.shape[0], min(max_length_labels) - 1), model.config.vocab_size)
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
attention_mask[i, input_lengths[i] :] = 0
model.config.ctc_loss_reduction = "sum"
sum_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
model.config.ctc_loss_reduction = "mean"
mean_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
self.parent.assertTrue(isinstance(sum_loss, float))
self.parent.assertTrue(isinstance(mean_loss, float))
def check_seq_classifier_loss(self, config, input_values, *args):
model = Wav2Vec2ConformerForSequenceClassification(config=config)
model.to(torch_device)
# make sure that dropout is disabled
model.eval()
input_values = input_values[:3]
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
attention_mask[i, input_lengths[i] :] = 0
masked_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
unmasked_loss = model(input_values, labels=labels).loss.item()
self.parent.assertTrue(isinstance(masked_loss, float))
self.parent.assertTrue(isinstance(unmasked_loss, float))
self.parent.assertTrue(masked_loss != unmasked_loss)
def check_ctc_training(self, config, input_values, *args):
config.ctc_zero_infinity = True
model = Wav2Vec2ConformerForCTC(config=config)
model.to(torch_device)
model.train()
# freeze feature encoder
model.freeze_feature_encoder()
input_values = input_values[:3]
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size)
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
if max_length_labels[i] < labels.shape[-1]:
# it's important that we make sure that target lenghts are at least
# one shorter than logit lenghts to prevent -inf
labels[i, max_length_labels[i] - 1 :] = -100
loss = model(input_values, labels=labels).loss
self.parent.assertFalse(torch.isinf(loss).item())
loss.backward()
def check_seq_classifier_training(self, config, input_values, *args):
config.ctc_zero_infinity = True
model = Wav2Vec2ConformerForSequenceClassification(config=config)
model.to(torch_device)
model.train()
# freeze everything but the classification head
model.freeze_base_model()
input_values = input_values[:3]
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
loss = model(input_values, labels=labels).loss
self.parent.assertFalse(torch.isinf(loss).item())
loss.backward()
def check_xvector_training(self, config, input_values, *args):
config.ctc_zero_infinity = True
model = Wav2Vec2ConformerForXVector(config=config)
model.to(torch_device)
model.train()
# freeze everything but the classification head
model.freeze_base_model()
input_values = input_values[:3]
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
loss = model(input_values, labels=labels).loss
self.parent.assertFalse(torch.isinf(loss).item())
loss.backward()
def check_labels_out_of_vocab(self, config, input_values, *args):
model = Wav2Vec2ConformerForCTC(config)
model.to(torch_device)
model.train()
input_values = input_values[:3]
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size + 100)
with self.parent.assertRaises(ValueError):
model(input_values, labels=labels)
def prepare_config_and_inputs_for_common(self):
config, input_values, attention_mask = self.prepare_config_and_inputs()
inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
return config, inputs_dict
@require_torch
class Wav2Vec2ConformerModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
Wav2Vec2ConformerForCTC,
Wav2Vec2ConformerModel,
Wav2Vec2ConformerForSequenceClassification,
Wav2Vec2ConformerForPreTraining,
Wav2Vec2ConformerForAudioFrameClassification,
Wav2Vec2ConformerForXVector,
)
if is_torch_available()
else ()
)
test_pruning = False
test_headmasking = False
test_torchscript = False
def setUp(self):
self.model_tester = Wav2Vec2ConformerModelTester(self)
self.config_tester = ConfigTester(self, config_class=Wav2Vec2ConformerConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_with_relative(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="relative")
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_with_rotary(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="rotary")
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_with_no_rel_pos(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type=None)
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_with_adapter(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_with_adapter(*config_and_inputs)
def test_model_with_adapter_for_ctc(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_with_adapter_for_ctc(*config_and_inputs)
def test_model_with_adapter_proj_dim(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs)
def test_ctc_loss_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_loss(*config_and_inputs)
def test_seq_classifier_loss_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_seq_classifier_loss(*config_and_inputs)
def test_ctc_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_training(*config_and_inputs)
def test_seq_classifier_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_seq_classifier_training(*config_and_inputs)
def test_xvector_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_xvector_training(*config_and_inputs)
def test_labels_out_of_vocab(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
# Wav2Vec2Conformer has no inputs_embeds
def test_inputs_embeds(self):
pass
# `input_ids` is renamed to `input_values`
def test_forward_signature(self):
pass
# Wav2Vec2Conformer cannot resize token embeddings
# since it has no tokens embeddings
def test_resize_tokens_embeddings(self):
pass
# Wav2Vec2Conformer has no inputs_embeds
# and thus the `get_input_embeddings` fn
# is not implemented
def test_model_common_attributes(self):
pass
@is_pt_flax_cross_test
# non-robust architecture does not exist in Flax
def test_equivalence_flax_to_pt(self):
pass
@is_pt_flax_cross_test
# non-robust architecture does not exist in Flax
def test_equivalence_pt_to_flax(self):
pass
def test_retain_grad_hidden_states_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True
config.output_attentions = True
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
model.to(torch_device)
# set layer drop to 0
model.config.layerdrop = 0.0
input_values = inputs_dict["input_values"]
input_lengths = torch.tensor(
[input_values.shape[1] for _ in range(input_values.shape[0])], dtype=torch.long, device=torch_device
)
output_lengths = model._get_feat_extract_output_lengths(input_lengths)
labels = ids_tensor((input_values.shape[0], output_lengths[0] - 2), self.model_tester.vocab_size)
inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"])
inputs_dict["labels"] = labels
outputs = model(**inputs_dict)
output = outputs[0]
# Encoder-/Decoder-only models
hidden_states = outputs.hidden_states[0]
attentions = outputs.attentions[0]
hidden_states.retain_grad()
attentions.retain_grad()
output.flatten()[0].backward(retain_graph=True)
self.assertIsNotNone(hidden_states.grad)
self.assertIsNotNone(attentions.grad)
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
uniform_init_parms = [
"conv.weight",
"masked_spec_embed",
"codevectors",
"quantizer.weight_proj.weight",
"project_hid.weight",
"project_hid.bias",
"project_q.weight",
"project_q.bias",
"pos_bias_v",
"pos_bias_u",
"pointwise_conv1",
"pointwise_conv2",
"feature_projection.projection.weight",
"feature_projection.projection.bias",
"objective.weight",
]
if param.requires_grad:
if any([x in name for x in uniform_init_parms]):
self.assertTrue(
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
else:
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
# overwrite from test_modeling_common
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)
if hasattr(module, "weight_g") and module.weight_g is not None:
module.weight_g.data.fill_(3)
if hasattr(module, "weight_v") and module.weight_v is not None:
module.weight_v.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3)
if hasattr(module, "codevectors") and module.codevectors is not None:
module.codevectors.data.fill_(3)
if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
module.masked_spec_embed.data.fill_(3)
def test_mask_feature_prob_ctc(self):
model = Wav2Vec2ConformerForCTC.from_pretrained(
"hf-internal-testing/tiny-random-wav2vec2-conformer", mask_feature_prob=0.2, mask_feature_length=2
)
model.to(torch_device).train()
processor = Wav2Vec2Processor.from_pretrained(
"hf-internal-testing/tiny-random-wav2vec2-conformer", return_attention_mask=True
)
batch_duration_in_seconds = [1, 3, 2, 6]
input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds]
batch = processor(
input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt"
)
logits = model(
input_values=batch["input_values"].to(torch_device),
attention_mask=batch["attention_mask"].to(torch_device),
).logits
self.assertEqual(logits.shape, (4, 1498, 32))
def test_mask_time_prob_ctc(self):
model = Wav2Vec2ConformerForCTC.from_pretrained(
"hf-internal-testing/tiny-random-wav2vec2-conformer", mask_time_prob=0.2, mask_time_length=2
)
model.to(torch_device).train()
processor = Wav2Vec2Processor.from_pretrained(
"hf-internal-testing/tiny-random-wav2vec2-conformer", return_attention_mask=True
)
batch_duration_in_seconds = [1, 3, 2, 6]
input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds]
batch = processor(
input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt"
)
logits = model(
input_values=batch["input_values"].to(torch_device),
attention_mask=batch["attention_mask"].to(torch_device),
).logits
self.assertEqual(logits.shape, (4, 1498, 32))
@unittest.skip(reason="Feed forward chunking is not implemented")
def test_feed_forward_chunking(self):
pass
@slow
def test_model_from_pretrained(self):
model = Wav2Vec2ConformerModel.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
self.assertIsNotNone(model)
@require_torch
class Wav2Vec2ConformerUtilsTest(unittest.TestCase):
def test_compute_mask_indices(self):
batch_size = 4
sequence_length = 60
mask_prob = 0.5
mask_length = 1
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
mask = torch.from_numpy(mask).to(torch_device)
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
def test_compute_mask_indices_low_prob(self):
# with these settings num_masked_spans=0.5, which means probabilistic rounding
# ensures that in 5 out of 10 method calls, num_masked_spans=0, and in
# the other 5 out of 10, cases num_masked_spans=1
n_trials = 100
batch_size = 4
sequence_length = 100
mask_prob = 0.05
mask_length = 10
count_dimensions_masked = 0
count_dimensions_not_masked = 0
for _ in range(n_trials):
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
mask = torch.from_numpy(mask).to(torch_device)
num_masks = torch.sum(mask).item()
if num_masks > 0:
count_dimensions_masked += 1
else:
count_dimensions_not_masked += 1
# as we test for at least 10 masked dimension and at least
# 10 non-masked dimension, this test could fail with probability:
# P(100 coin flips, at most 9 heads) = 1.66e-18
self.assertGreater(count_dimensions_masked, int(n_trials * 0.1))
self.assertGreater(count_dimensions_not_masked, int(n_trials * 0.1))
def test_compute_mask_indices_overlap(self):
batch_size = 4
sequence_length = 80
mask_prob = 0.5
mask_length = 4
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
mask = torch.from_numpy(mask).to(torch_device)
# because of overlap mask don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
for batch_sum in mask.sum(axis=-1):
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
def test_compute_mask_indices_attn_mask_overlap(self):
batch_size = 4
sequence_length = 80
mask_prob = 0.5
mask_length = 4
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
attention_mask[:2, sequence_length // 2 :] = 0
mask = _compute_mask_indices(
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
)
mask = torch.from_numpy(mask).to(torch_device)
for batch_sum in mask.sum(axis=-1):
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
self.assertTrue(mask[:2, sequence_length // 2 :].sum() == 0)
def test_compute_mask_indices_short_audio(self):
batch_size = 4
sequence_length = 100
mask_prob = 0.05
mask_length = 10
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
# force one example to be heavily padded
attention_mask[0, 5:] = 0
mask = _compute_mask_indices(
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask, min_masks=2
)
# make sure that non-padded examples cannot be padded
self.assertFalse(mask[0][attention_mask[0].to(torch.bool).cpu()].any())
def test_compute_perplexity(self):
probs = torch.arange(100, device=torch_device).reshape(2, 5, 10) / 100
ppl = Wav2Vec2ConformerGumbelVectorQuantizer._compute_perplexity(probs)
self.assertTrue(abs(ppl.item() - 141.4291) < 1e-3)
# mask half of the input
mask = torch.ones((2,), device=torch_device, dtype=torch.bool)
mask[0] = 0
ppl = Wav2Vec2ConformerGumbelVectorQuantizer._compute_perplexity(probs, mask)
self.assertTrue(abs(ppl.item() - 58.6757) < 1e-3)
def test_sample_negatives(self):
batch_size = 2
sequence_length = 10
hidden_size = 4
num_negatives = 3
features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
sequence_length, hidden_size
) # each value in vector consits of same value
features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
# sample negative indices
sampled_negative_indices = _sample_negative_indices((batch_size, sequence_length), num_negatives, None)
sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)
self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
# make sure no negatively sampled vector is actually a positive one
for negative in negatives:
self.assertTrue(((negative - features) == 0).sum() == 0.0)
# make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
def test_sample_negatives_with_mask(self):
batch_size = 2
sequence_length = 10
hidden_size = 4
num_negatives = 3
# second half of last input tensor is padded
mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
mask[-1, sequence_length // 2 :] = 0
features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
sequence_length, hidden_size
) # each value in vector consits of same value
features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
# replace masked feature vectors with -100 to test that those are not sampled
features = torch.where(mask[:, :, None].expand(features.shape).bool(), features, -100)
# sample negative indices
sampled_negative_indices = _sample_negative_indices(
(batch_size, sequence_length), num_negatives, mask.cpu().numpy()
)
sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)
self.assertTrue((negatives >= 0).all().item())
self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
# make sure no negatively sampled vector is actually a positive one
for negative in negatives:
self.assertTrue(((negative - features) == 0).sum() == 0.0)
# make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
@require_torch
@slow
class Wav2Vec2ConformerModelIntegrationTest(unittest.TestCase):
def _load_datasamples(self, num_samples):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").filter(lambda x: x["id"] in [f"1272-141231-000{i}" for i in range(num_samples)])
speech_samples = speech_samples[:num_samples]["audio"]
return [x["array"] for x in speech_samples]
def test_inference_ctc_normal_batched_rel_pos(self):
model = Wav2Vec2ConformerForCTC.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large-960h-ft")
model.to(torch_device)
processor = Wav2Vec2Processor.from_pretrained(
"facebook/wav2vec2-conformer-rel-pos-large-960h-ft", do_lower_case=True
)
input_speech = self._load_datasamples(2)
inputs = processor(input_speech, return_tensors="pt", padding=True)
input_values = inputs.input_values.to(torch_device)
with torch.no_grad():
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
predicted_trans = processor.batch_decode(predicted_ids)
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",
"sweat covered brion's body trickling into the tight loincloth that was the only garment he wore",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
def test_inference_ctc_normal_batched_rope(self):
model = Wav2Vec2ConformerForCTC.from_pretrained("facebook/wav2vec2-conformer-rope-large-960h-ft")
model.to(torch_device)
processor = Wav2Vec2Processor.from_pretrained(
"facebook/wav2vec2-conformer-rope-large-960h-ft", do_lower_case=True
)
input_speech = self._load_datasamples(2)
inputs = processor(input_speech, return_tensors="pt", padding=True)
input_values = inputs.input_values.to(torch_device)
with torch.no_grad():
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
predicted_trans = processor.batch_decode(predicted_ids)
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
def test_inference_pretrained(self):
model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
model.to(torch_device)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
"facebook/wav2vec2-conformer-rel-pos-large", return_attention_mask=True
)
input_speech = self._load_datasamples(2)
inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
batch_size = inputs_dict["input_values"].shape[0]
feature_seq_length = int(model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]))
features_shape = (batch_size, feature_seq_length)
torch.manual_seed(0)
mask_time_indices = _compute_mask_indices(
features_shape,
model.config.mask_time_prob,
model.config.mask_time_length,
min_masks=2,
)
mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)
with torch.no_grad():
outputs = model(
inputs_dict.input_values.to(torch_device),
attention_mask=inputs_dict.attention_mask.to(torch_device),
mask_time_indices=mask_time_indices,
)
# compute cosine similarity
cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
# retrieve cosine sim of masked features
cosine_sim_masked = cosine_sim[mask_time_indices]
# ... now compare to randomly initialized model
config = Wav2Vec2ConformerConfig.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
model_rand = Wav2Vec2ConformerForPreTraining(config).to(torch_device).eval()
with torch.no_grad():
outputs_rand = model_rand(
inputs_dict.input_values.to(torch_device),
attention_mask=inputs_dict.attention_mask.to(torch_device),
mask_time_indices=mask_time_indices,
)
# compute cosine similarity
cosine_sim_rand = torch.cosine_similarity(
outputs_rand.projected_states, outputs_rand.projected_quantized_states, dim=-1
)
# retrieve cosine sim of masked features
cosine_sim_masked_rand = cosine_sim_rand[mask_time_indices]
# a pretrained wav2vec2_conformer model has learned to predict the quantized latent states
# => the cosine similarity between quantized states and predicted states > 0.5
# a random wav2vec2_conformer model has not learned to predict the quantized latent states
# => the cosine similarity between quantized states and predicted states is very likely < 0.1
self.assertTrue(cosine_sim_masked.mean().item() - 5 * cosine_sim_masked_rand.mean().item() > 0)
......@@ -61,6 +61,7 @@ src/transformers/models/vit/modeling_tf_vit.py
src/transformers/models/vit_mae/modeling_vit_mae.py
src/transformers/models/wav2vec2/modeling_wav2vec2.py
src/transformers/models/wav2vec2/tokenization_wav2vec2.py
src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
src/transformers/models/wavlm/modeling_wavlm.py
src/transformers/models/yolos/modeling_yolos.py
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment