Unverified Commit d472bd7b authored by Anton Lozhkov, committed by GitHub

Wav2Vec2 Pretraining (#11306)



* Working quantizer forward

* Working quantizer forward

* Clean up unused model parts, test reproducibility

* Working quantizer forward

* Clean up unused model parts, test reproducibility

* Remove custom outputs from the shared ones

* correct conversion

* correct bug

* add first pretrain script

* save intermediate

* static shapes

* save intermediate

* finish first pretrain script version

* more refactor

* remove wanddb

* refactor more

* improve test

* correct perplexity compute bug

* finish model implementation

* add to docs

* finish docs

* finish pretraining script

* finish pretraining script

* remove wandb

* finish PR for merge

* finish config

* finish

* make deepspeed work

* Apply suggestions from code review
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* apply suggestions

* fix flaky test
Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent b1a8aa94
@@ -79,3 +79,10 @@ Wav2Vec2ForCTC
.. autoclass:: transformers.Wav2Vec2ForCTC
:members: forward

Wav2Vec2ForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.Wav2Vec2ForPreTraining
:members: forward
@@ -184,3 +184,35 @@ run_asr.py \
--preprocessing_num_workers=1 --group_by_length --freeze_feature_extractor --verbose_logging \
--deepspeed ds_config_wav2vec2_zero3.json
```
### Pretraining Wav2Vec2
The `run_pretrain.py` script allows one to pretrain a Wav2Vec2 model from scratch using Wav2Vec2's contrastive loss objective (see the official [paper](https://arxiv.org/abs/2006.11477) for more information).
It is recommended to pretrain Wav2Vec2 with the Trainer and DeepSpeed (please refer to [this guide](https://huggingface.co/transformers/master/main_classes/deepspeed.html#deepspeed-trainer-integration) for more information).
Here is an example of how you can use DeepSpeed ZeRO-2 to pretrain a small Wav2Vec2 model:
```
PYTHONPATH=../../../src deepspeed --num_gpus 2 run_pretrain.py \
--output_dir="./wav2vec2-base-libri-100h" \
--num_train_epochs="3" \
--per_device_train_batch_size="32" \
--per_device_eval_batch_size="32" \
--gradient_accumulation_steps="2" \
--save_total_limit="3" \
--save_steps="500" \
--logging_steps="10" \
--learning_rate="5e-4" \
--weight_decay="0.01" \
--warmup_steps="3000" \
--model_name_or_path="patrickvonplaten/wav2vec2-base-libri-100h" \
--dataset_name="librispeech_asr" \
--dataset_config_name="clean" \
--train_split_name="train.100" \
--preprocessing_num_workers="4" \
--max_duration_in_seconds="10.0" \
--group_by_length \
--verbose_logging \
--fp16 \
--deepspeed ds_config_wav2vec2_zero2.json
```
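The `ds_config_wav2vec2_zero2.json` file referenced above is not included in this diff. As a starting point, the sketch below writes a fairly typical ZeRO-2 configuration for the Trainer integration; the bucket sizes and the optimizer/scheduler choices are illustrative assumptions, and the `"auto"` values are resolved by the Trainer from its command-line arguments.
```
import json

# Minimal ZeRO-2 DeepSpeed config sketch (illustrative values only).
# "auto" entries are filled in by the Hugging Face Trainer from its own arguments.
ds_config = {
    "fp16": {"enabled": "auto", "loss_scale": 0, "initial_scale_power": 16},
    "optimizer": {
        "type": "AdamW",
        "params": {"lr": "auto", "betas": "auto", "eps": "auto", "weight_decay": "auto"},
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {"warmup_min_lr": "auto", "warmup_max_lr": "auto", "warmup_num_steps": "auto"},
    },
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": True,
        "allgather_bucket_size": 200000000,
        "overlap_comm": True,
        "reduce_scatter": True,
        "reduce_bucket_size": 200000000,
        "contiguous_gradients": True,
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_micro_batch_size_per_gpu": "auto",
}

with open("ds_config_wav2vec2_zero2.json", "w") as f:
    json.dump(ds_config, f, indent=4)
```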
#!/usr/bin/env python3
import logging
import sys
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
import torch
import torch.nn as nn
from datasets import DatasetDict, load_dataset
from packaging import version
import librosa
from transformers import (
HfArgumentParser,
Trainer,
TrainingArguments,
Wav2Vec2Config,
Wav2Vec2FeatureExtractor,
Wav2Vec2ForPreTraining,
is_apex_available,
trainer_utils,
)
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
if is_apex_available():
from apex import amp
if version.parse(torch.__version__) >= version.parse("1.6"):
_is_native_amp_available = True
from torch.cuda.amp import autocast
logger = logging.getLogger(__name__)
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model configuration and feature extractor we are going to use for pretraining.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
)
freeze_feature_extractor: Optional[bool] = field(
default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
)
gradient_checkpointing: Optional[bool] = field(
default=False, metadata={"help": "Whether to use gradient checkpointing to save memory at the expense of a slower backward pass."}
)
verbose_logging: Optional[bool] = field(
default=False,
metadata={"help": "Whether to log verbose messages or not."},
)
max_gumbel_temperature: Optional[float] = field(
default=2.0, metadata={"help": "Maximum temperature for gumbel softmax."}
)
min_gumbel_temperature: Optional[float] = field(
default=0.5, metadata={"help": "Minimum temperature for gumbel softmax."}
)
gumbel_temperature_decay: Optional[float] = field(
default=0.999995, metadata={"help": "Decay of gumbel temperature during training."}
)
def configure_logger(model_args: ModelArguments, training_args: TrainingArguments):
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
logging_level = logging.WARNING
if model_args.verbose_logging:
logging_level = logging.DEBUG
elif trainer_utils.is_main_process(training_args.local_rank):
logging_level = logging.INFO
logger.setLevel(logging_level)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
Using `HfArgumentParser` we can turn this class
into argparse arguments to be able to specify them on
the command line.
"""
dataset_name: str = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
train_split_name: Optional[str] = field(
default="train",
metadata={
"help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
},
)
validation_split_name: Optional[str] = field(
default="validation",
metadata={
"help": "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
},
)
# note: referenced in main() when the dataset has no validation split; the 1% default here is an assumed value
validation_split_percentage: Optional[int] = field(
default=1,
metadata={"help": "The percentage of the train split used as a validation set if there is no validation split."},
)
speech_file_column: Optional[str] = field(
default="file",
metadata={"help": "Column in the dataset that contains speech file path. Defaults to 'file'"},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
max_duration_in_seconds: Optional[float] = field(
default=20.0, metadata={"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"}
)
@dataclass
class DataCollatorForWav2Vec2Pretraining:
"""
Data collator that will dynamically pad the inputs received and prepare masked indices
for self-supervised pretraining.
Args:
model (:class:`~transformers.Wav2Vec2ForPreTraining`):
The Wav2Vec2 model used for pretraining. The data collator needs to have access
to config and ``_get_feat_extract_output_lengths`` function for correct padding.
feature_extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`):
The feature extractor used for processing the data.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
different lengths).
max_length (:obj:`int`, `optional`):
Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
"""
model: Wav2Vec2ForPreTraining
feature_extractor: Wav2Vec2FeatureExtractor
padding: Union[bool, str] = "longest"
pad_to_multiple_of: Optional[int] = None
max_length: Optional[int] = None
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# reformat list to dict and set to pytorch format
batch = self.feature_extractor.pad(
features,
max_length=self.max_length,
padding=self.padding,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors="pt",
)
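# the time-axis mask is sampled over the feature encoder's output frames, not raw samples,
# so convert the padded input length to the corresponding number of output frames first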
mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
# sample randomly masked indices
batch["mask_time_indices"] = _compute_mask_indices(
(batch["input_values"].shape[0], mask_indices_seq_length),
self.model.config.mask_time_prob,
self.model.config.mask_time_length,
device=batch["input_values"].device,
min_masks=2,
)
return batch
class Wav2Vec2PreTrainer(Trainer):
"""
Subclassed :class:`~transformers.Trainer` for Wav2Vec2-like pretraining. The trainer decays the Gumbel softmax temperature during training.
"""
def __init__(self, *args, max_gumbel_temp=1, min_gumbel_temp=0, gumbel_temp_decay=1.0, **kwargs):
super().__init__(*args, **kwargs)
self.num_update_step = 0
self.max_gumbel_temp = max_gumbel_temp
self.min_gumbel_temp = min_gumbel_temp
self.gumbel_temp_decay = gumbel_temp_decay
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
"""
Perform a training step on a batch of inputs.
Subclass and override to inject custom behavior.
Args:
model (:obj:`nn.Module`):
The model to train.
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
Return:
:obj:`torch.Tensor`: The tensor with training loss on this batch.
"""
model.train()
inputs = self._prepare_inputs(inputs)
if self.use_amp:
with autocast():
loss = self.compute_loss(model, inputs)
else:
loss = self.compute_loss(model, inputs)
if self.args.n_gpu > 1 or self.deepspeed:
if model.module.config.ctc_loss_reduction == "mean":
loss = loss.mean()
elif model.module.config.ctc_loss_reduction == "sum":
loss = loss.sum() / (inputs["mask_time_indices"]).sum()
else:
raise ValueError(f"{model.config.ctc_loss_reduction} is not valid. Choose one of ['mean', 'sum']")
if self.args.gradient_accumulation_steps > 1:
loss = loss / self.args.gradient_accumulation_steps
if self.use_amp:
self.scaler.scale(loss).backward()
elif self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
elif self.deepspeed:
self.deepspeed.backward(loss)
else:
loss.backward()
self.num_update_step += 1
# make sure gumbel softmax temperature is decayed
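# the schedule is max_gumbel_temp * gumbel_temp_decay ** num_update_step, clamped below at min_gumbel_temp;
# with the defaults above (2.0, 0.5, 0.999995) the floor is reached after roughly 277k training steps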
if self.args.n_gpu > 1 or self.deepspeed:
model.module.set_gumbel_temperature(
max(self.max_gumbel_temp * self.gumbel_temp_decay ** self.num_update_step, self.min_gumbel_temp)
)
else:
model.set_gumbel_temperature(
max(self.max_gumbel_temp * self.gumbel_temp_decay ** self.num_update_step, self.min_gumbel_temp)
)
return loss.detach()
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
configure_logger(model_args, training_args)
# Downloading and loading a dataset from the hub.
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
if "validation" not in datasets.keys():
# make sure only "validation" and "train" keys remain"
datasets = DatasetDict()
datasets["validation"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=f"{data_args.train_split_name}[:{data_args.validation_split_percentage}%]",
cache_dir=model_args.cache_dir,
)
datasets["train"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=f"{data_args.train_split_name}[{data_args.validation_split_percentage}%:]",
cache_dir=model_args.cache_dir,
)
else:
# make sure only "validation" and "train" keys remain"
datasets = DatasetDict()
datasets["validation"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split="validation",
cache_dir=model_args.cache_dir,
)
datasets["train"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=f"{data_args.train_split_name}",
cache_dir=model_args.cache_dir,
)
# only normalized-inputs-training is supported
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
model_args.model_name_or_path, cache_dir=model_args.cache_dir, do_normalize=True
)
def prepare_dataset(batch):
# check that all files have the correct sampling rate
batch["speech"], _ = librosa.load(batch[data_args.speech_file_column], sr=feature_extractor.sampling_rate)
return batch
# load audio files into numpy arrays
vectorized_datasets = datasets.map(
prepare_dataset, num_proc=data_args.preprocessing_num_workers, remove_columns=datasets["train"].column_names
)
# filter audio files that are too long
vectorized_datasets = vectorized_datasets.filter(
lambda data: len(data["speech"]) < int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
)
def normalize(batch):
return feature_extractor(batch["speech"], sampling_rate=feature_extractor.sampling_rate)
# normalize and transform to `BatchFeatures`
vectorized_datasets = vectorized_datasets.map(
normalize,
batched=True,
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
remove_columns=vectorized_datasets["train"].column_names,
)
# pretraining is only supported for "newer" stable layer norm architecture
# apply_spec_augment has to be True, mask_feature_prob has to be 0.0
config = Wav2Vec2Config.from_pretrained(
model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
gradient_checkpointing=model_args.gradient_checkpointing,
)
if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
raise ValueError(
"PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
)
model = Wav2Vec2ForPreTraining(config)
data_collator = DataCollatorForWav2Vec2Pretraining(model=model, feature_extractor=feature_extractor)
trainer = Wav2Vec2PreTrainer(
model=model,
data_collator=data_collator,
args=training_args,
train_dataset=vectorized_datasets["train"],
eval_dataset=vectorized_datasets["validation"],
tokenizer=feature_extractor,
max_gumbel_temp=model_args.max_gumbel_temperature,
min_gumbel_temp=model_args.min_gumbel_temperature,
gumbel_temp_decay=model_args.gumbel_temperature_decay,
)
trainer.train()
if __name__ == "__main__":
main()
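To make the moving parts above concrete, the following standalone sketch runs a single pretraining forward pass without the Trainer, assuming the `DataCollatorForWav2Vec2Pretraining` class defined above is in scope. The random waveforms and the setup are purely illustrative; the output attributes (`loss`, `codevector_perplexity`) are the ones exercised by the tests later in this commit.
```
import torch
from transformers import Wav2Vec2Config, Wav2Vec2FeatureExtractor, Wav2Vec2ForPreTraining

# pretraining requires the stable-layer-norm architecture (see the check in main())
config = Wav2Vec2Config(do_stable_layer_norm=True, feat_extract_norm="layer")
model = Wav2Vec2ForPreTraining(config)
feature_extractor = Wav2Vec2FeatureExtractor(do_normalize=True)

# two random "utterances" standing in for real speech loaded with librosa
raw_speech = [torch.randn(16000).numpy(), torch.randn(12000).numpy()]
features = [feature_extractor(x, sampling_rate=16000)["input_values"][0] for x in raw_speech]

# pad to the longest sample and sample the time-step mask, as during training
collator = DataCollatorForWav2Vec2Pretraining(model=model, feature_extractor=feature_extractor)
batch = collator([{"input_values": f} for f in features])

outputs = model(batch["input_values"], mask_time_indices=batch["mask_time_indices"])
print(outputs.loss, outputs.codevector_perplexity)
```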
@@ -1046,6 +1046,7 @@ if is_torch_available():
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Wav2Vec2ForCTC",
"Wav2Vec2ForMaskedLM",
"Wav2Vec2ForPreTraining",
"Wav2Vec2Model",
"Wav2Vec2PreTrainedModel",
]
@@ -2411,6 +2412,7 @@ if TYPE_CHECKING:
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
Wav2Vec2ForCTC,
Wav2Vec2ForMaskedLM,
Wav2Vec2ForPreTraining,
Wav2Vec2Model,
Wav2Vec2PreTrainedModel,
)
...
@@ -269,7 +269,7 @@ from ..tapas.modeling_tapas import (
from ..transfo_xl.modeling_transfo_xl import TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel
from ..visual_bert.modeling_visual_bert import VisualBertForPreTraining, VisualBertModel
from ..vit.modeling_vit import ViTForImageClassification, ViTModel
from ..wav2vec2.modeling_wav2vec2 import Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining, Wav2Vec2Model
from ..xlm.modeling_xlm import (
XLMForMultipleChoice,
XLMForQuestionAnsweringSimple,
@@ -463,6 +463,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
(IBertConfig, IBertForMaskedLM),
(DebertaConfig, DebertaForMaskedLM),
(DebertaV2Config, DebertaV2ForMaskedLM),
(Wav2Vec2Config, Wav2Vec2ForPreTraining),
]
)
...
@@ -32,6 +32,7 @@ if is_torch_available():
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Wav2Vec2ForCTC",
"Wav2Vec2ForMaskedLM",
"Wav2Vec2ForPreTraining",
"Wav2Vec2Model",
"Wav2Vec2PreTrainedModel",
]
@@ -48,6 +49,7 @@ if TYPE_CHECKING:
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
Wav2Vec2ForCTC,
Wav2Vec2ForMaskedLM,
Wav2Vec2ForPreTraining,
Wav2Vec2Model,
Wav2Vec2PreTrainedModel,
)
...
@@ -71,6 +71,8 @@ class Wav2Vec2Config(PretrainedConfig):
feat_extract_activation (:obj:`str`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the 1D convolutional layers of the feature
extractor. If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
feat_quantizer_dropout (:obj:`float`, `optional`, defaults to 0.0):
The dropout probability for quantized feature extractor states.
conv_dim (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(512, 512, 512, 512, 512, 512, 512)`):
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
feature extractor. The length of `conv_dim` defines the number of 1D convolutional layers.
@@ -108,6 +110,22 @@ class Wav2Vec2Config(PretrainedConfig):
masked along the time axis. This is only relevant if ``apply_spec_augment is True``.
mask_feature_length (:obj:`int`, `optional`, defaults to 10):
Length of vector span along the feature axis.
num_codevectors_per_group (:obj:`int`, `optional`, defaults to 320):
Number of entries in each quantization codebook (group).
num_codevector_groups (:obj:`int`, `optional`, defaults to 2):
Number of codevector groups for product codevector quantization.
contrastive_logits_temperature (:obj:`float`, `optional`, defaults to 0.1):
The temperature `kappa` in the contrastive loss.
feat_quantizer_dropout (:obj:`float`, `optional`, defaults to 0.0):
The dropout probability for the output of the feature extractor that's used by the quantizer.
num_negatives (:obj:`int`, `optional`, defaults to 100):
Number of negative samples for the contrastive loss.
codevector_dim (:obj:`int`, `optional`, defaults to 256):
Dimensionality of the quantized feature vectors.
proj_codevector_dim (:obj:`int`, `optional`, defaults to 256):
Dimensionality of the final projection of both the quantized and the transformer features.
diversity_loss_weight (:obj:`float`, `optional`, defaults to 0.1):
The weight of the codebook diversity loss component.
ctc_loss_reduction (:obj:`str`, `optional`, defaults to :obj:`"sum"`):
Specifies the reduction to apply to the output of ``torch.nn.CTCLoss``. Only relevant when training an
instance of :class:`~transformers.Wav2Vec2ForCTC`.
@@ -145,6 +163,7 @@ class Wav2Vec2Config(PretrainedConfig):
activation_dropout=0.1,
attention_dropout=0.1,
feat_proj_dropout=0.1,
feat_quantizer_dropout=0.0,
final_dropout=0.1,
layerdrop=0.1,
initializer_range=0.02,
@@ -163,6 +182,13 @@ class Wav2Vec2Config(PretrainedConfig):
mask_time_length=10,
mask_feature_prob=0.0,
mask_feature_length=10,
num_codevectors_per_group=320,
num_codevector_groups=2,
contrastive_logits_temperature=0.1,
num_negatives=100,
codevector_dim=256,
proj_codevector_dim=256,
diversity_loss_weight=0.1,
ctc_loss_reduction="sum", ctc_loss_reduction="sum",
ctc_zero_infinity=False, ctc_zero_infinity=False,
gradient_checkpointing=False, gradient_checkpointing=False,
...@@ -217,6 +243,16 @@ class Wav2Vec2Config(PretrainedConfig): ...@@ -217,6 +243,16 @@ class Wav2Vec2Config(PretrainedConfig):
self.mask_feature_prob = mask_feature_prob self.mask_feature_prob = mask_feature_prob
self.mask_feature_length = mask_feature_length self.mask_feature_length = mask_feature_length
# parameters for pretraining with codevector quantized representations
self.num_codevectors_per_group = num_codevectors_per_group
self.num_codevector_groups = num_codevector_groups
self.contrastive_logits_temperature = contrastive_logits_temperature
self.feat_quantizer_dropout = feat_quantizer_dropout
self.num_negatives = num_negatives
self.codevector_dim = codevector_dim
self.proj_codevector_dim = proj_codevector_dim
self.diversity_loss_weight = diversity_loss_weight
# ctc loss
self.ctc_loss_reduction = ctc_loss_reduction
self.ctc_zero_infinity = ctc_zero_infinity
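For reference, the new pretraining-related fields above can be set directly when constructing a config; the values in this short sketch are just the documented defaults, plus the two architecture settings that `run_pretrain.py` requires for pretraining.
```
from transformers import Wav2Vec2Config, Wav2Vec2ForPreTraining

config = Wav2Vec2Config(
    do_stable_layer_norm=True,    # required for pretraining (checked in run_pretrain.py)
    feat_extract_norm="layer",    # required for pretraining (checked in run_pretrain.py)
    num_codevectors_per_group=320,
    num_codevector_groups=2,
    contrastive_logits_temperature=0.1,
    num_negatives=100,
    codevector_dim=256,
    proj_codevector_dim=256,
    diversity_loss_weight=0.1,
    feat_quantizer_dropout=0.0,
)
model = Wav2Vec2ForPreTraining(config)
```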
@@ -28,7 +28,7 @@ from transformers import (
Wav2Vec2CTCTokenizer,
Wav2Vec2FeatureExtractor,
Wav2Vec2ForCTC,
Wav2Vec2ForPreTraining,
Wav2Vec2Processor,
logging,
)
@@ -50,9 +50,20 @@ MAPPING = {
"final_layer_norm": "encoder.layers.*.final_layer_norm",
"encoder.layer_norm": "encoder.layer_norm",
"w2v_model.layer_norm": "feature_projection.layer_norm",
"quantizer.weight_proj": "quantizer.weight_proj",
"quantizer.vars": "quantizer.codevectors",
"project_q": "project_q",
"final_proj": "project_hid",
"w2v_encoder.proj": "lm_head", "w2v_encoder.proj": "lm_head",
"mask_emb": "masked_spec_embed", "mask_emb": "masked_spec_embed",
} }
TOP_LEVEL_KEYS = [
"lm_head",
"quantizer.weight_proj",
"quantizer.codevectors",
"project_q",
"project_hid",
]
def set_recursively(hf_pointer, key, value, full_name, weight_type):
@@ -82,11 +93,11 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
def recursively_load_weights(fairseq_model, hf_model, is_headless):
unused_weights = []
fairseq_dict = fairseq_model.state_dict()
feature_extractor = hf_model.wav2vec2.feature_extractor
for name, value in fairseq_dict.items():
is_used = False
@@ -101,9 +112,8 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
is_used = True
else:
for key, mapped_key in MAPPING.items():
mapped_key = "wav2vec2." + mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key
if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
is_used = True
if "*" in mapped_key:
layer_index = name.split(key)[0].split(".")[-2]
@@ -112,10 +122,11 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
weight_type = "weight_g"
elif "weight_v" in name:
weight_type = "weight_v"
elif "bias" in name:
weight_type = "bias"
elif "weight" in name:
# TODO: don't match quantizer.weight_proj
weight_type = "weight"
else:
weight_type = None
set_recursively(hf_model, mapped_key, value, name, weight_type)
@@ -213,7 +224,7 @@ def convert_wav2vec2_checkpoint(
hf_wav2vec = Wav2Vec2ForCTC(config)
else:
hf_wav2vec = Wav2Vec2ForPreTraining(config)
if is_finetuned:
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
@@ -224,7 +235,7 @@ def convert_wav2vec2_checkpoint(
model = model[0].eval()
recursively_load_weights(model, hf_wav2vec, not is_finetuned)
hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
...
@@ -2969,6 +2969,11 @@ class Wav2Vec2ForMaskedLM:
requires_backends(self, ["torch"])
class Wav2Vec2ForPreTraining:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Wav2Vec2Model:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
...
@@ -29,8 +29,16 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
if is_torch_available():
import torch
from transformers import (
Wav2Vec2Config,
Wav2Vec2FeatureExtractor,
Wav2Vec2ForCTC,
Wav2Vec2ForMaskedLM,
Wav2Vec2ForPreTraining,
Wav2Vec2Model,
Wav2Vec2Processor,
)
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2GumbelVectorQuantizer, _compute_mask_indices
class Wav2Vec2ModelTester:
@@ -219,13 +227,7 @@ class Wav2Vec2ModelTester:
@require_torch
class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining) if is_torch_available() else ()
)
test_pruning = False
test_headmasking = False
@@ -316,8 +318,14 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
uniform_init_parms = [
"conv.weight",
"masked_spec_embed",
"codevectors",
"quantizer.weight_proj.weight",
]
if param.requires_grad:
if any([x in name for x in uniform_init_parms]):
self.assertTrue(
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
@@ -333,10 +341,14 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)
if hasattr(module, "weight_g") and module.weight_g is not None:
module.weight_g.data.fill_(3)
if hasattr(module, "weight_v") and module.weight_v is not None:
module.weight_v.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None: if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3) module.bias.data.fill_(3)
if hasattr(module, "codevectors") and module.codevectors is not None:
module.codevectors.data.fill_(3)
@slow
def test_model_from_pretrained(self):
@@ -346,7 +358,9 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
@require_torch
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining) if is_torch_available() else ()
)
test_pruning = False
test_headmasking = False
test_torchscript = False
@@ -442,8 +456,14 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
uniform_init_parms = [
"conv.weight",
"masked_spec_embed",
"codevectors",
"quantizer.weight_proj.weight",
]
if param.requires_grad:
if any([x in name for x in uniform_init_parms]):
self.assertTrue(
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
@@ -459,10 +479,47 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)
if hasattr(module, "weight_g") and module.weight_g is not None:
module.weight_g.data.fill_(3)
if hasattr(module, "weight_v") and module.weight_v is not None:
module.weight_v.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None: if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3) module.bias.data.fill_(3)
if hasattr(module, "codevectors") and module.codevectors is not None:
module.codevectors.data.fill_(3)
def test_model_for_pretraining(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = Wav2Vec2ForPreTraining(config).to(torch_device)
features_shape = (
inputs_dict["input_values"].shape[0],
model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
)
mask_time_indices = _compute_mask_indices(
features_shape,
model.config.mask_time_prob,
model.config.mask_time_length,
device=inputs_dict["input_values"].device,
min_masks=2,
).to(torch_device)
loss = model(
inputs_dict["input_values"],
attention_mask=inputs_dict["attention_mask"],
mask_time_indices=mask_time_indices,
).loss
mask_time_indices[:, : mask_time_indices.shape[-1] // 2] = True
loss_more_masked = model(
inputs_dict["input_values"],
attention_mask=inputs_dict["attention_mask"],
mask_time_indices=mask_time_indices,
).loss
# loss_more_masked has to be bigger than or equal to loss since more masked inputs have to be predicted
self.assertTrue(loss.detach().item() <= loss_more_masked.detach().item())
@slow
def test_model_from_pretrained(self):
@@ -484,24 +541,56 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
def test_compute_mask_indices_overlap(self):
batch_size = 4
sequence_length = 80
mask_prob = 0.5
mask_length = 4
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
# because of overlap, masks don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
for batch_sum in mask.sum(axis=-1):
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
def test_compute_perplexity(self):
probs = torch.arange(100, device=torch_device).reshape(2, 5, 10) / 100
ppl = Wav2Vec2GumbelVectorQuantizer._compute_perplexity(probs)
self.assertTrue(abs(ppl.item() - 141.4291) < 1e-3)
# mask half of the input
mask = torch.ones((2,), device=torch_device, dtype=torch.bool)
mask[0] = 0
ppl = Wav2Vec2GumbelVectorQuantizer._compute_perplexity(probs, mask)
self.assertTrue(abs(ppl.item() - 58.6757) < 1e-3)
def test_sample_negatives(self):
batch_size = 2
sequence_length = 10
hidden_size = 4
num_negatives = 3
features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
sequence_length, hidden_size
) # each value in the vector consists of the same value
features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives)
self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
# make sure no negatively sampled vector is actually a positive one
for negative in negatives:
self.assertTrue(((negative - features) == 0).sum() == 0.0)
# make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
@require_torch
@require_datasets
@require_soundfile
@slow
class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
def _load_datasamples(self, num_samples):
from datasets import load_dataset
@@ -586,3 +675,160 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
"his instant panic was followed by a small sharp blow high on his chest",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
def test_inference_integration(self):
model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
model.to(torch_device)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
"patrickvonplaten/wav2vec2-base", return_attention_mask=True
)
input_speech = self._load_datasamples(2)
inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
features_shape = (
inputs_dict["input_values"].shape[0],
model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
)
torch.manual_seed(0)
mask_time_indices = _compute_mask_indices(
features_shape,
model.config.mask_time_prob,
model.config.mask_time_length,
device=inputs_dict["input_values"].device,
min_masks=2,
).to(torch_device)
with torch.no_grad():
outputs = model(
inputs_dict.input_values.to(torch_device),
attention_mask=inputs_dict.attention_mask.to(torch_device),
mask_time_indices=mask_time_indices,
)
# compute cosine similarity
cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
# retrieve cosine sim of masked features
cosine_sim_masked = cosine_sim[mask_time_indices]
# fmt: off
expected_cosine_sim_masked = torch.tensor(
[0.7458, 0.7188, 0.6418, 0.3729, 0.3741, 0.3694, 0.3110, 0.2257, 0.4403, 0.5415, 0.3950, 0.3701, 0.8831, 0.8613, 0.5229, 0.6696, 0.7206, 0.7877, 0.6758, 0.8746, 0.6596, 0.6282, 0.6178, 0.5839, 0.5926, 0.6651, 0.4635, 0.6332, 0.6572, 0.8776, 0.4999, 0.7001, 0.7257, 0.5098, 0.6229, 0.4566, 0.5261, 0.6363, 0.5371, 0.6997],
device=torch_device,
)
# fmt: on
self.assertTrue(torch.allclose(cosine_sim_masked, expected_cosine_sim_masked, atol=1e-3))
def test_inference_pretrained(self):
model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
model.to(torch_device)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
"patrickvonplaten/wav2vec2-base", return_attention_mask=True
)
input_speech = self._load_datasamples(2)
inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
features_shape = (
inputs_dict["input_values"].shape[0],
model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
)
torch.manual_seed(0)
mask_time_indices = _compute_mask_indices(
features_shape,
model.config.mask_time_prob,
model.config.mask_time_length,
device=inputs_dict["input_values"].device,
min_masks=2,
).to(torch_device)
with torch.no_grad():
outputs = model(
inputs_dict.input_values.to(torch_device),
attention_mask=inputs_dict.attention_mask.to(torch_device),
mask_time_indices=mask_time_indices,
)
# compute cosine similarity
cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
# retrieve cosine sim of masked features
cosine_sim_masked = cosine_sim[mask_time_indices]
# ... now compare to randomly initialized model
config = Wav2Vec2Config.from_pretrained("patrickvonplaten/wav2vec2-base")
model_rand = Wav2Vec2ForPreTraining(config).to(torch_device).eval()
with torch.no_grad():
outputs_rand = model_rand(
inputs_dict.input_values.to(torch_device),
attention_mask=inputs_dict.attention_mask.to(torch_device),
mask_time_indices=mask_time_indices,
)
# compute cosine similarity
cosine_sim_rand = torch.cosine_similarity(
outputs_rand.projected_states, outputs_rand.projected_quantized_states, dim=-1
)
# retrieve cosine sim of masked features
cosine_sim_masked_rand = cosine_sim_rand[mask_time_indices]
# a pretrained wav2vec2 model has learned to predict the quantized latent states
# => the cosine similarity between quantized states and predicted states > 0.5
# a random wav2vec2 model has not learned to predict the quantized latent states
# => the cosine similarity between quantized states and predicted states is very likely < 0.1
self.assertTrue(cosine_sim_masked.mean().item() - 5 * cosine_sim_masked_rand.mean().item() > 0)
def test_loss_pretraining(self):
model = Wav2Vec2ForPreTraining.from_pretrained(
"patrickvonplaten/wav2vec2-base",
attention_dropout=0.0,
feat_proj_dropout=0.0,
hidden_dropout=0.0,
layerdrop=0.0,
)
model.to(torch_device).train()
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
"patrickvonplaten/wav2vec2-base", return_attention_mask=True
)
input_speech = self._load_datasamples(2)
inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
features_shape = (
inputs_dict["input_values"].shape[0],
model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]),
)
torch.manual_seed(0)
mask_time_indices = _compute_mask_indices(
features_shape,
model.config.mask_time_prob,
model.config.mask_time_length,
device=inputs_dict["input_values"].device,
min_masks=2,
).to(torch_device)
with torch.no_grad():
outputs = model(
inputs_dict.input_values.to(torch_device),
attention_mask=inputs_dict.attention_mask.to(torch_device),
mask_time_indices=mask_time_indices,
)
# check diversity loss
num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups
diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
self.assertTrue(abs(diversity_loss.item() - 0.8859) < 1e-3)
# check overall loss (contrastive loss + diversity loss)
expected_loss = 62.5170 if model.device.type == "cpu" else 50.3612
self.assertTrue(abs(outputs.loss.item() - expected_loss) < 1e-3)
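As a sanity check on the numbers asserted in `test_loss_pretraining`: the diversity loss is `(num_codevectors - codevector_perplexity) / num_codevectors` with `num_codevectors = num_codevectors_per_group * num_codevector_groups = 640` for the default config, so the perplexity can be back-solved from the asserted value (approximate, for illustration only).
```
# back-solve the codevector perplexity from the asserted diversity loss (illustrative)
num_codevectors = 320 * 2                  # num_codevectors_per_group * num_codevector_groups
diversity_loss = 0.8859                    # value asserted in test_loss_pretraining
perplexity = num_codevectors * (1 - diversity_loss)
print(round(perplexity, 1))                # ~73.0 of 640 codebook entries effectively used
```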