Unverified commit 0234de84, authored by Patrick von Platen, committed by GitHub

Add Fine-Tuning for Wav2Vec2 (#10145)



* add encode labels function to tokenizer

* start adding finetuning

* init dropout

* upload

* correct convert script

* apply changes

* fix second typo

* make first dummy training run

* adapt convert script

* push confg for comparison

* remove conf

* finish training

* adapt data collator

* add research folder

* update according to fairseq feedback

* some minor corrections

* refactor masking indices a bit

* some minor changes

* clean tokenizer

* finish clean-up

* remove previous logic

* update run script

* correct training

* finish changes

* finish model

* correct bug

* fix training a bit more

* add some tests

* finish gradient checkpointing

* finish example

* correct gradient checkpointing

* improve tokenization method

* revert changes in tokenizer

* revert general change

* adapt fine-tuning

* update

* save intermediate test

* Update README.md

* finish finetuning

* delete conversion script

* Update src/transformers/models/wav2vec2/configuration_wav2vec2.py

* Update src/transformers/models/wav2vec2/processing_wav2vec2.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* finish wav2vec2 script

* finish wav2vec2 fine-tuning

* finalize test

* correct test

* adapt tests

* finish

* remove test file
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent 3c733f32
......@@ -62,7 +62,7 @@ Wav2Vec2Processor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.Wav2Vec2Processor
:members: __call__, from_pretrained, save_pretrained, batch_decode, decode, as_target_processor
:members: __call__, pad, from_pretrained, save_pretrained, batch_decode, decode, as_target_processor
Wav2Vec2Model
......
## Fine-tuning Wav2Vec2
The `run_asr.py` script allows one to fine-tune pretrained Wav2Vec2 models that can be found [here](https://huggingface.co/models?search=facebook/wav2vec2).
This fine-tuning script can also be run in a Google Colab notebook [TODO: here]( ).
The script is actively maintained by [Patrick von Platen](https://github.com/patrickvonplaten).
Feel free to ask a question on the [Forum](https://discuss.huggingface.co/) or post an issue on [GitHub](https://github.com/huggingface/transformers/issues/new/choose), tagging `@patrickvonplaten`.
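After fine-tuning with one of the example commands below, the resulting checkpoint can be loaded for inference. The snippet is only a minimal sketch; the output directory and audio file are placeholders, not part of this commit.

import soundfile as sf
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# load the checkpoint written to --output_dir (placeholder path)
processor = Wav2Vec2Processor.from_pretrained("./wav2vec2-base-100h")
model = Wav2Vec2ForCTC.from_pretrained("./wav2vec2-base-100h")
model.eval()

speech, sampling_rate = sf.read("sample.flac")  # placeholder 16 kHz mono file
inputs = processor(speech, sampling_rate=sampling_rate, return_tensors="pt")

with torch.no_grad():
    logits = model(inputs.input_values).logits

pred_ids = torch.argmax(logits, dim=-1)
print(processor.batch_decode(pred_ids))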
#!/usr/bin/env bash
python run_asr.py \
--output_dir="./wav2vec2-base-100h" \
--num_train_epochs="30" \
--per_device_train_batch_size="32" \
--per_device_eval_batch_size="32" \
--evaluation_strategy="steps" \
--save_total_limit="3" \
--save_steps="500" \
--eval_steps="100" \
--logging_steps="50" \
--learning_rate="5e-4" \
--warmup_steps="3000" \
--model_name_or_path="facebook/wav2vec2-base" \
--fp16 \
--dataset_name="librispeech_asr" \
--dataset_config_name="clean" \
--train_split_name="train.100" \
--preprocessing_num_workers="32" \
--group_by_length \
--freeze_feature_extractor
#!/usr/bin/env bash
python run_asr.py \
--output_dir="./wav2vec2-large-lv60-100h" \
--num_train_epochs="30" \
--per_device_train_batch_size="16" \
--per_device_eval_batch_size="16" \
--evaluation_strategy="steps" \
--save_total_limit="3" \
--save_steps="500" \
--eval_steps="100" \
--logging_steps="50" \
--learning_rate="5e-4" \
--warmup_steps="3000" \
--model_name_or_path="facebook/wav2vec2-large-lv60" \
--fp16 \
--dataset_name="librispeech_asr" \
--dataset_config_name="clean" \
--train_split_name="train.100" \
--preprocessing_num_workers="32" \
--group_by_length \
--freeze_feature_extractor
transformers
datasets
torch >= 1.5.0
jiwer
#!/usr/bin/env python3
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
import datasets
import numpy as np
import torch
import torch.nn as nn
from packaging import version
import soundfile as sf
from transformers import (
HfArgumentParser,
Trainer,
TrainingArguments,
Wav2Vec2ForCTC,
Wav2Vec2Processor,
is_apex_available,
)
if is_apex_available():
from apex import amp
if version.parse(torch.__version__) >= version.parse("1.6"):
_is_native_amp_available = True
from torch.cuda.amp import autocast
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
)
freeze_feature_extractor: Optional[bool] = field(
default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
Using `HfArgumentParser` we can turn this class
into argparse arguments to be able to specify them on
the command line.
"""
dataset_name: str = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
train_split_name: Optional[str] = field(
default="train",
metadata={
"help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
@dataclass
class DataCollatorCTCWithPadding:
"""
Data collator that will dynamically pad the inputs received.
Args:
processor (:class:`~transformers.Wav2Vec2Processor`)
The processor used for processing the data.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
different lengths).
max_length (:obj:`int`, `optional`):
Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
max_length_labels (:obj:`int`, `optional`):
Maximum length of the returned ``labels`` list and optionally padding length (see above).
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
"""
processor: Wav2Vec2Processor
padding: Union[bool, str] = True
max_length: Optional[int] = None
max_length_labels: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
pad_to_multiple_of_labels: Optional[int] = None
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# split inputs and labels since they have different lengths and need
# different padding methods
input_features = [{"input_values": feature["input_values"]} for feature in features]
label_features = [{"input_ids": feature["labels"]} for feature in features]
batch = self.processor.pad(
input_features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors="pt",
)
with self.processor.as_target_processor():
labels_batch = self.processor.pad(
label_features,
padding=self.padding,
max_length=self.max_length_labels,
pad_to_multiple_of=self.pad_to_multiple_of_labels,
return_tensors="pt",
)
# replace padding with -100 to ignore loss correctly
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
batch["labels"] = labels
return batch
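# Illustrative usage of the collator (not executed here; `processor` is assumed to be a
# Wav2Vec2Processor loaded as in `main()` below):
#
#   data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
#   batch = data_collator(
#       [
#           {"input_values": [0.1, 0.2, 0.3], "labels": [5, 9]},
#           {"input_values": [0.4, 0.5], "labels": [7]},
#       ]
#   )
#   # batch["input_values"] and batch["labels"] are padded to the longest example in the
#   # batch; label padding is set to -100 so it is ignored by the CTC loss.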
class CTCTrainer(Trainer):
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
"""
Perform a training step on a batch of inputs.
Subclass and override to inject custom behavior.
Args:
model (:obj:`nn.Module`):
The model to train.
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
Return:
:obj:`torch.Tensor`: The tensor with training loss on this batch.
"""
model.train()
inputs = self._prepare_inputs(inputs)
if self.use_amp:
with autocast():
loss = self.compute_loss(model, inputs)
else:
loss = self.compute_loss(model, inputs)
if self.args.n_gpu > 1:
if model.module.config.ctc_loss_reduction == "mean":
loss = loss.mean()
elif model.module.config.ctc_loss_reduction == "sum":
loss = loss.sum() / (inputs["labels"] >= 0).sum()
else:
raise ValueError(f"{model.config.ctc_loss_reduction} is not valid. Choose one of ['mean', 'sum']")
if self.args.gradient_accumulation_steps > 1:
loss = loss / self.args.gradient_accumulation_steps
if self.use_amp:
self.scaler.scale(loss).backward()
elif self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
elif self.deepspeed:
self.deepspeed.backward(loss)
else:
loss.backward()
return loss.detach()
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model = Wav2Vec2ForCTC.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
processor = Wav2Vec2Processor.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
train_dataset = datasets.load_dataset(
data_args.dataset_name, data_args.dataset_config_name, split=data_args.train_split_name
)
val_dataset = datasets.load_dataset(data_args.dataset_name, data_args.dataset_config_name, split="validation")
wer_metric = datasets.load_metric("wer")
def map_to_array(batch):
speech_array, sampling_rate = sf.read(batch["file"])
batch["speech"] = speech_array
batch["sampling_rate"] = sampling_rate
return batch
train_dataset = train_dataset.map(map_to_array, remove_columns=["file"])
val_dataset = val_dataset.map(map_to_array, remove_columns=["file"])
def prepare_dataset(batch):
# check that all files in the batch have the same sampling rate
assert (
len(set(batch["sampling_rate"])) == 1
), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."
batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
with processor.as_target_processor():
batch["labels"] = processor(batch["text"]).input_ids
return batch
train_dataset = train_dataset.map(
prepare_dataset,
batch_size=training_args.per_device_train_batch_size,
batched=True,
num_proc=data_args.preprocessing_num_workers,
)
val_dataset = val_dataset.map(
prepare_dataset,
batch_size=training_args.per_device_train_batch_size,
batched=True,
num_proc=data_args.preprocessing_num_workers,
)
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
def compute_metrics(pred):
pred_logits = pred.predictions
pred_ids = np.argmax(pred_logits, axis=-1)
# -100 was used to mask padded label positions for the loss; map it back to the pad token id so the labels can be decoded
pred.label_ids[pred.label_ids == -100] = 0
pred_str = processor.batch_decode(pred_ids)
# we do not want to group tokens when computing the metrics
label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
wer = wer_metric.compute(predictions=pred_str, references=label_str)
return {"wer": wer}
if model_args.freeze_feature_extractor:
model.freeze_feature_extractor()
trainer = CTCTrainer(
model=model,
data_collator=data_collator,
args=training_args,
compute_metrics=compute_metrics,
train_dataset=train_dataset,
eval_dataset=val_dataset,
tokenizer=processor.feature_extractor,
)
trainer.train()
if __name__ == "__main__":
main()
......@@ -92,6 +92,33 @@ class Wav2Vec2Config(PretrainedConfig):
Whether to apply the `stable` layer norm architecture of the Transformer encoder. ``do_stable_layer_norm is
True`` corresponds to applying layer norm before the attention layer, whereas ``do_stable_layer_norm is
False`` corresponds to applying layer norm after the attention layer.
freeze_feat_extract_train (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to freeze the weights of the feature extractor when training.
apply_spec_augment (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to apply *SpecAugment* data augmentation to the outputs of the feature extractor. For reference see
`SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
<https://arxiv.org/abs/1904.08779>`__.
mask_time_prob (:obj:`float`, `optional`, defaults to 0.05):
Probability for each feature vector along the time axis to be chosen as the start of a vector span to be
masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` spans of length
``mask_time_length`` will be masked along the time axis. This is only relevant if ``apply_spec_augment is True``.
mask_time_length (:obj:`int`, `optional`, defaults to 10):
Length of vector span along the time axis.
mask_feature_prob (:obj:`float`, `optional`, defaults to 0.0):
Probability for each feature vector along the feature axis to be chosen as the start of a vector span to
be masked. Approximately ``mask_feature_prob * hidden_size // mask_feature_length`` spans of length
``mask_feature_length`` will be masked along the feature axis. This is only relevant if ``apply_spec_augment is True``.
mask_feature_length (:obj:`int`, `optional`, defaults to 10):
Length of vector span along the feature axis.
ctc_loss_reduction (:obj:`str`, `optional`, defaults to :obj:`"sum"`):
Specifies the reduction to apply to the output of ``torch.nn.CTCLoss``. Only relevant when training an
instance of :class:`~transformers.Wav2Vec2ForCTC`.
ctc_zero_infinity (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. Infinite losses
mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an
instance of :class:`~transformers.Wav2Vec2ForCTC`.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, use gradient checkpointing to save memory at the expense of a slower backward pass.
Example::
......@@ -116,12 +143,15 @@ class Wav2Vec2Config(PretrainedConfig):
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1, # TODO(PVP) this is most likely not correctly set yet - correct when adding train
attention_probs_dropout_prob=0.1, # TODO(PVP) this is most likely not correctly set yet - correct when adding train
hidden_dropout=0.1,
activation_dropout=0.1,
attention_dropout=0.1,
feat_proj_dropout=0.1,
final_dropout=0.1,
layerdrop=0.1,
initializer_range=0.02,
layer_norm_eps=1e-5,
feat_extract_norm="group",
feat_extract_dropout=0.0,
feat_extract_activation="gelu",
conv_dim=(512, 512, 512, 512, 512, 512, 512),
conv_stride=(5, 2, 2, 2, 2, 2, 2),
......@@ -130,6 +160,15 @@ class Wav2Vec2Config(PretrainedConfig):
num_conv_pos_embeddings=128,
num_conv_pos_embedding_groups=16,
do_stable_layer_norm=False,
freeze_feat_extract_train=True,
apply_spec_augment=True,
mask_time_prob=0.05,
mask_time_length=10,
mask_feature_prob=0.0,
mask_feature_length=10,
ctc_loss_reduction="sum",
ctc_zero_infinity=False,
gradient_checkpointing=False,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
......@@ -138,7 +177,6 @@ class Wav2Vec2Config(PretrainedConfig):
super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
self.hidden_size = hidden_size
self.feat_extract_norm = feat_extract_norm
self.feat_extract_dropout = feat_extract_dropout
self.feat_extract_activation = feat_extract_activation
self.conv_dim = list(conv_dim)
self.conv_stride = list(conv_stride)
......@@ -151,12 +189,18 @@ class Wav2Vec2Config(PretrainedConfig):
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.num_attention_heads = num_attention_heads
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
self.feat_proj_dropout = feat_proj_dropout
self.final_dropout = final_dropout
self.layerdrop = layerdrop
self.layer_norm_eps = layer_norm_eps
self.initializer_range = initializer_range
self.vocab_size = vocab_size
self.do_stable_layer_norm = do_stable_layer_norm
self.freeze_feat_extract_train = freeze_feat_extract_train
self.gradient_checkpointing = gradient_checkpointing
if (
(len(self.conv_stride) != self.num_feat_extract_layers)
......@@ -169,3 +213,14 @@ class Wav2Vec2Config(PretrainedConfig):
f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride)"
f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
)
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
self.apply_spec_augment = apply_spec_augment
self.mask_time_prob = mask_time_prob
self.mask_time_length = mask_time_length
self.mask_feature_prob = mask_feature_prob
self.mask_feature_length = mask_feature_length
# ctc loss
self.ctc_loss_reduction = ctc_loss_reduction
self.ctc_zero_infinity = ctc_zero_infinity
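A short sketch of how the new fine-tuning options fit together; the values below mirror the defaults and flags added in this diff and are only illustrative, not a recommended configuration.

from transformers import Wav2Vec2Config, Wav2Vec2ForCTC

config = Wav2Vec2Config(
    apply_spec_augment=True,
    mask_time_prob=0.05,        # SpecAugment masking along the time axis
    mask_time_length=10,
    ctc_loss_reduction="sum",
    ctc_zero_infinity=True,     # zero out infinite CTC losses caused by too-short inputs
    gradient_checkpointing=True,
)
model = Wav2Vec2ForCTC(config)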
......@@ -20,26 +20,27 @@ import argparse
import fairseq
import torch
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, logging
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2Model, logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
MAPPING = {
"post_extract_proj": "wav2vec2.feature_projection.projection",
"encoder.pos_conv.0": "wav2vec2.encoder.pos_conv_embed.conv",
"self_attn.k_proj": "wav2vec2.encoder.layers.*.attention.k_proj",
"self_attn.v_proj": "wav2vec2.encoder.layers.*.attention.v_proj",
"self_attn.q_proj": "wav2vec2.encoder.layers.*.attention.q_proj",
"self_attn.out_proj": "wav2vec2.encoder.layers.*.attention.out_proj",
"self_attn_layer_norm": "wav2vec2.encoder.layers.*.layer_norm",
"fc1": "wav2vec2.encoder.layers.*.feed_forward.intermediate_dense",
"fc2": "wav2vec2.encoder.layers.*.feed_forward.output_dense",
"final_layer_norm": "wav2vec2.encoder.layers.*.final_layer_norm",
"encoder.layer_norm": "wav2vec2.encoder.layer_norm",
"w2v_model.layer_norm": "wav2vec2.feature_projection.layer_norm",
"post_extract_proj": "feature_projection.projection",
"encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
"self_attn.k_proj": "encoder.layers.*.attention.k_proj",
"self_attn.v_proj": "encoder.layers.*.attention.v_proj",
"self_attn.q_proj": "encoder.layers.*.attention.q_proj",
"self_attn.out_proj": "encoder.layers.*.attention.out_proj",
"self_attn_layer_norm": "encoder.layers.*.layer_norm",
"fc1": "encoder.layers.*.feed_forward.intermediate_dense",
"fc2": "encoder.layers.*.feed_forward.output_dense",
"final_layer_norm": "encoder.layers.*.final_layer_norm",
"encoder.layer_norm": "encoder.layer_norm",
"w2v_model.layer_norm": "feature_projection.layer_norm",
"w2v_encoder.proj": "lm_head",
"mask_emb": "masked_spec_embed",
}
......@@ -47,7 +48,11 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
for attribute in key.split("."):
hf_pointer = getattr(hf_pointer, attribute)
hf_shape = getattr(hf_pointer, weight_type).shape
if weight_type is not None:
hf_shape = getattr(hf_pointer, weight_type).shape
else:
hf_shape = hf_pointer.shape
assert (
hf_shape == value.shape
), f"Shape of hf {key + '.' + weight_type} is {hf_shape}, but should be {value.shape} for {full_name}"
......@@ -59,26 +64,32 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
hf_pointer.weight_v.data = value
elif weight_type == "bias":
hf_pointer.bias.data = value
logger.info(f"{key + '.' + weight_type} was initialized from {full_name}.")
else:
hf_pointer.data = value
logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
def recursively_load_weights(fairseq_model, hf_model):
def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
unused_weights = []
fairseq_dict = fairseq_model.state_dict()
feature_extractor = hf_model.wav2vec2.feature_extractor if is_finetuned else hf_model.feature_extractor
for name, value in fairseq_dict.items():
is_used = False
if "conv_layers" in name:
load_conv_layer(
name,
value,
hf_model.wav2vec2.feature_extractor,
feature_extractor,
unused_weights,
hf_model.config.feat_extract_norm == "group",
)
is_used = True
else:
for key, mapped_key in MAPPING.items():
mapped_key = "wav2vec2." + mapped_key if (is_finetuned and mapped_key != "lm_head") else mapped_key
if key in name:
is_used = True
if "*" in mapped_key:
......@@ -92,6 +103,8 @@ def recursively_load_weights(fairseq_model, hf_model):
weight_type = "weight"
elif "bias" in name:
weight_type = "bias"
else:
weight_type = None
set_recursively(hf_model, mapped_key, value, name, weight_type)
continue
if not is_used:
......@@ -137,18 +150,32 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
@torch.no_grad()
def convert_wav2vec2_checkpoint(checkpoint_path, pytorch_dump_folder_path, dict_path=None):
def convert_wav2vec2_checkpoint(
checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True
):
"""
Copy/paste/tweak model's weights to transformers design.
"""
hf_wav2vec = Wav2Vec2ForCTC(Wav2Vec2Config())
if config_path is not None:
config = Wav2Vec2Config.from_pretrained(config_path)
else:
config = Wav2Vec2Config()
if is_finetuned:
hf_wav2vec = Wav2Vec2ForCTC(config)
else:
hf_wav2vec = Wav2Vec2Model(config)
if is_finetuned:
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[checkpoint_path], arg_overrides={"data": dict_path}
)
else:
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[checkpoint_path], arg_overrides={"data": dict_path}
)
model = model[0].eval()
recursively_load_weights(model, hf_wav2vec)
recursively_load_weights(model, hf_wav2vec, is_finetuned)
hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
......@@ -158,5 +185,11 @@ if __name__ == "__main__":
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
parser.add_argument(
"--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
)
args = parser.parse_args()
convert_wav2vec2_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.dict_path)
convert_wav2vec2_checkpoint(
args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, not args.not_finetuned
)
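For orientation, a hedged sketch of what the new flags amount to when the conversion entry point is called directly; the paths are placeholders, and the call mirrors the `__main__` block above.

# Equivalent to running the script with --not_finetuned and without --dict_path:
convert_wav2vec2_checkpoint(
    checkpoint_path="/path/to/wav2vec_small.pt",           # placeholder fairseq checkpoint
    pytorch_dump_folder_path="./wav2vec2-base-converted",  # placeholder output directory
    config_path=None,
    dict_path=None,
    is_finetuned=False,  # loads the weights into Wav2Vec2Model instead of Wav2Vec2ForCTC
)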
......@@ -115,6 +115,16 @@ class Wav2Vec2Processor:
"""
return self.current_processor(*args, **kwargs)
def pad(self, *args, **kwargs):
"""
When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's
:meth:`~transformers.Wav2Vec2FeatureExtractor.pad` and returns its output. If used in the context manager
:meth:`~transformers.Wav2Vec2Processor.as_target_processor`, this method forwards all its arguments to
Wav2Vec2CTCTokenizer's :meth:`~transformers.Wav2Vec2CTCTokenizer.pad`. Please refer to the docstring of the
above two methods for more information.
"""
return self.current_processor.pad(*args, **kwargs)
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Wav2Vec2CTCTokenizer's
......
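A small sketch of the new `pad` behaviour described above; the checkpoint name is taken from the fine-tuning scripts earlier in this commit, and the inputs are toy values.

from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")

# normal mode: forwards to the feature extractor and pads raw audio features
batch = processor.pad(
    [{"input_values": [0.1, 0.2, 0.3]}, {"input_values": [0.4]}],
    padding=True,
    return_tensors="pt",
)

# target mode: forwards to the CTC tokenizer and pads label ids
with processor.as_target_processor():
    labels = processor.pad(
        [{"input_ids": [5, 9]}, {"input_ids": [7]}],
        padding=True,
        return_tensors="pt",
    )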
......@@ -509,11 +509,18 @@ class Trainer:
# Build the sampler.
if self.args.group_by_length:
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
if num_processes <= 1:
return LengthGroupedSampler(self.train_dataset, self.args.train_batch_size)
return LengthGroupedSampler(
self.train_dataset, self.args.train_batch_size, model_input_name=model_input_name
)
else:
return DistributedLengthGroupedSampler(
self.train_dataset, self.args.train_batch_size, num_replicas=num_processes, rank=process_index
self.train_dataset,
self.args.train_batch_size,
num_replicas=num_processes,
rank=process_index,
model_input_name=model_input_name,
)
else:
......
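From the training script's side, these sampler changes are what make `--group_by_length` work for speech inputs: when the feature extractor is passed as `tokenizer`, `model_input_names[0]` is expected to be `"input_values"`, so lengths are inferred from the audio rather than from `"input_ids"`. A minimal, illustrative sketch with placeholder argument values:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./wav2vec2-base-100h",   # placeholder
    group_by_length=True,                # selects (Distributed)LengthGroupedSampler above
    per_device_train_batch_size=32,
)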
......@@ -452,16 +452,23 @@ class LengthGroupedSampler(Sampler):
keeping a bit of randomness.
"""
def __init__(self, dataset: Dataset, batch_size: int, lengths: Optional[List[int]] = None):
def __init__(
self,
dataset: Dataset,
batch_size: int,
lengths: Optional[List[int]] = None,
model_input_name: Optional[str] = None,
):
self.dataset = dataset
self.batch_size = batch_size
self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
if lengths is None:
if not isinstance(dataset[0], dict) or "input_ids" not in dataset[0]:
if not isinstance(dataset[0], dict) or self.model_input_name not in dataset[0]:
raise ValueError(
"Can only automatically infer lengths for datasets whose items are dictionaries with an "
"'input_ids' key."
f"'{self.model_input_name}' key."
)
lengths = [len(feature["input_ids"]) for feature in dataset]
lengths = [len(feature[self.model_input_name]) for feature in dataset]
self.lengths = lengths
def __len__(self):
......@@ -487,6 +494,7 @@ class DistributedLengthGroupedSampler(DistributedSampler):
seed: int = 0,
drop_last: bool = False,
lengths: Optional[List[int]] = None,
model_input_name: Optional[str] = None,
):
if num_replicas is None:
if not dist.is_available():
......@@ -513,14 +521,15 @@ class DistributedLengthGroupedSampler(DistributedSampler):
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
self.total_size = self.num_samples * self.num_replicas
self.seed = seed
self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
if lengths is None:
if not isinstance(dataset[0], dict) or "input_ids" not in dataset[0]:
if not isinstance(dataset[0], dict) or self.model_input_name not in dataset[0]:
raise ValueError(
"Can only automatically infer lengths for datasets whose items are dictionaries with an "
"'input_ids' key."
f"'{self.model_input_name}' key."
)
lengths = [len(feature["input_ids"]) for feature in dataset]
lengths = [len(feature[self.model_input_name]) for feature in dataset]
self.lengths = lengths
def __iter__(self) -> Iterator:
......
......@@ -735,6 +735,8 @@ class ModelTesterMixin:
inputs = self._prepare_for_class(inputs_dict, model_class)
outputs = model(**inputs)
print(outputs)
output = outputs[0]
if config.is_encoder_decoder:
......
......@@ -18,7 +18,7 @@
import math
import unittest
from tests.test_modeling_common import floats_tensor, random_attention_mask
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from transformers import is_torch_available
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
......@@ -30,6 +30,7 @@ if is_torch_available():
import torch
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
class Wav2Vec2ModelTester:
......@@ -128,9 +129,7 @@ class Wav2Vec2ModelTester:
)
def create_and_check_batch_inference(self, config, input_values, *args):
# Not sure how to make this test pass at the moment. Batched input yields
# same results as official fairseq implementation, but gives different results
# depending on whether batched input is used or not
# test does not pass for models making use of `group_norm`
# check: https://github.com/pytorch/fairseq/issues/3227
model = Wav2Vec2Model(config=config)
model.to(torch_device)
......@@ -155,6 +154,62 @@ class Wav2Vec2ModelTester:
batch_output = batch_outputs[i : i + 1, : output.shape[1]]
self.parent.assertTrue(torch.allclose(output, batch_output, atol=1e-3))
def check_ctc_loss(self, config, input_values, *args):
model = Wav2Vec2ForCTC(config=config)
model.to(torch_device)
# make sure that dropout is disabled
model.eval()
input_values = input_values[:3]
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.bool)
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
labels = ids_tensor((input_values.shape[0], min(max_length_labels) - 1), model.config.vocab_size)
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
attention_mask[i, input_lengths[i] :] = 0.0
model.config.ctc_loss_reduction = "sum"
sum_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss
model.config.ctc_loss_reduction = "mean"
mean_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss
self.parent.assertTrue(abs(labels.shape[0] * labels.shape[1] * mean_loss.item() - sum_loss.item()) < 1e-3)
def check_training(self, config, input_values, *args):
config.ctc_zero_infinity = True
model = Wav2Vec2ForCTC(config=config)
model.to(torch_device)
model.train()
# freeze feature encoder
model.freeze_feature_extractor()
input_values = input_values[:3]
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size)
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
if max_length_labels[i] < labels.shape[-1]:
# it's important that we make sure that target lengths are at least
# one shorter than logit lengths to prevent -inf
labels[i, max_length_labels[i] - 1 :] = -100
loss = model(input_values, labels=labels).loss
self.parent.assertFalse(torch.isinf(loss).item())
loss.backward()
def prepare_config_and_inputs_for_common(self):
config, input_values, attention_mask = self.prepare_config_and_inputs()
inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
......@@ -165,6 +220,7 @@ class Wav2Vec2ModelTester:
class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
Wav2Vec2ForCTC,
Wav2Vec2Model,
Wav2Vec2ForMaskedLM,
)
......@@ -186,6 +242,14 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_ctc_loss_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_loss(*config_and_inputs)
def test_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_training(*config_and_inputs)
# Wav2Vec2 has no inputs_embeds
def test_inputs_embeds(self):
pass
......@@ -205,6 +269,46 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
def test_model_common_attributes(self):
pass
def test_retain_grad_hidden_states_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True
config.output_attentions = True
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
model.to(torch_device)
# set layer drop to 0
model.config.layerdrop = 0.0
input_values = inputs_dict["input_values"]
input_lengths = torch.tensor(
[input_values.shape[1] for _ in range(input_values.shape[0])], dtype=torch.long, device=torch_device
)
output_lengths = model._get_feat_extract_output_lengths(input_lengths)
labels = ids_tensor((input_values.shape[0], output_lengths[0] - 2), self.model_tester.vocab_size)
inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"])
inputs_dict["labels"] = labels
outputs = model(**inputs_dict)
output = outputs[0]
# Encoder-/Decoder-only models
hidden_states = outputs.hidden_states[0]
attentions = outputs.attentions[0]
hidden_states.retain_grad()
attentions.retain_grad()
output.flatten()[0].backward(retain_graph=True)
self.assertIsNotNone(hidden_states.grad)
self.assertIsNotNone(attentions.grad)
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......@@ -213,7 +317,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
if param.requires_grad:
if "conv.weight" in name:
if "conv.weight" in name or "masked_spec_embed" in name:
self.assertTrue(
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
msg="Parameter {} of model {} seems not properly initialized".format(name, model_class),
......@@ -233,7 +337,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
@require_torch
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForCTC) if is_torch_available() else ()
all_model_classes = (Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM) if is_torch_available() else ()
test_pruning = False
test_headmasking = False
test_torchscript = False
......@@ -255,6 +359,14 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_batch_inference(*config_and_inputs)
def test_ctc_loss_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_loss(*config_and_inputs)
def test_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_training(*config_and_inputs)
# Wav2Vec2 has no inputs_embeds
def test_inputs_embeds(self):
pass
......@@ -274,6 +386,46 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
def test_model_common_attributes(self):
pass
def test_retain_grad_hidden_states_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True
config.output_attentions = True
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
model.to(torch_device)
# set layer drop to 0
model.config.layerdrop = 0.0
input_values = inputs_dict["input_values"]
input_lengths = torch.tensor(
[input_values.shape[1] for _ in range(input_values.shape[0])], dtype=torch.long, device=torch_device
)
output_lengths = model._get_feat_extract_output_lengths(input_lengths)
labels = ids_tensor((input_values.shape[0], output_lengths[0] - 2), self.model_tester.vocab_size)
inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"])
inputs_dict["labels"] = labels
outputs = model(**inputs_dict)
output = outputs[0]
# Encoder-/Decoder-only models
hidden_states = outputs.hidden_states[0]
attentions = outputs.attentions[0]
hidden_states.retain_grad()
attentions.retain_grad()
output.flatten()[0].backward(retain_graph=True)
self.assertIsNotNone(hidden_states.grad)
self.assertIsNotNone(attentions.grad)
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......@@ -282,7 +434,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
if param.requires_grad:
if "conv.weight" in name:
if "conv.weight" in name or "masked_spec_embed" in name:
self.assertTrue(
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
msg="Parameter {} of model {} seems not properly initialized".format(name, model_class),
......@@ -300,6 +452,59 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(model)
@require_torch
class Wav2Vec2UtilsTest(unittest.TestCase):
def test_compute_mask_indices(self):
batch_size = 4
sequence_length = 60
mask_prob = 0.5
mask_length = 1
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
attention_mask = torch.ones((batch_size, sequence_length), device=torch_device, dtype=torch.long)
attention_mask[:, -sequence_length // 2 :] = 0
mask = _compute_mask_indices(
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
)
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length // 2 for _ in range(batch_size)])
def test_compute_mask_indices_overlap(self):
batch_size = 4
sequence_length = 60
mask_prob = 0.5
mask_length = 4
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
# because of overlap there is a range of possible masks
for batch_sum in mask.sum(axis=-1):
self.assertIn(
int(batch_sum),
list(range(int(mask_prob // mask_length * sequence_length), int(mask_prob * sequence_length))),
)
attention_mask = torch.ones((batch_size, sequence_length), device=torch_device, dtype=torch.long)
attention_mask[:, -sequence_length // 2 :] = 0
mask = _compute_mask_indices(
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
)
# because of overlap there is a range of possible masks
for batch_sum in mask.sum(axis=-1):
self.assertIn(
int(batch_sum),
list(
range(int(mask_prob // mask_length * sequence_length // 2), int(mask_prob * sequence_length // 2))
),
)
@require_torch
@slow
@require_datasets
......