Unverified Commit 700a748f authored by Patrick von Platen, committed by GitHub

[Wav2Vec2] Add New Wav2Vec2 Translation (#14392)

* add new wav2vec2 translation

* correct

* up

* add tests

* correct end copy

* correct more

* up

* correct unispeech sat

* finish

* finalize

* finish

* up
parent b567510c
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Wav2Vec2 checkpoint."""
import argparse
import fairseq
import torch
from torch import nn
from transformers import (
MBart50Tokenizer,
MBartConfig,
MBartForCausalLM,
SpeechEncoderDecoderConfig,
SpeechEncoderDecoderModel,
Wav2Vec2Config,
Wav2Vec2FeatureExtractor,
Wav2Vec2Model,
logging,
)
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
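# MAPPING translates fairseq parameter-name fragments into the corresponding HF attribute paths;
# a "*" in the target path is replaced by the encoder layer index parsed from the fairseq name.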
MAPPING = {
"post_extract_proj": "feature_projection.projection",
"encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
"self_attn.k_proj": "encoder.layers.*.attention.k_proj",
"self_attn.v_proj": "encoder.layers.*.attention.v_proj",
"self_attn.q_proj": "encoder.layers.*.attention.q_proj",
"self_attn.out_proj": "encoder.layers.*.attention.out_proj",
"self_attn_layer_norm": "encoder.layers.*.layer_norm",
"fc1": "encoder.layers.*.feed_forward.intermediate_dense",
"fc2": "encoder.layers.*.feed_forward.output_dense",
"final_layer_norm": "encoder.layers.*.final_layer_norm",
"encoder.layer_norm": "encoder.layer_norm",
"w2v_model.layer_norm": "feature_projection.layer_norm",
"quantizer.weight_proj": "quantizer.weight_proj",
"quantizer.vars": "quantizer.codevectors",
"project_q": "project_q",
"final_proj": "project_hid",
"w2v_encoder.proj": "lm_head",
"mask_emb": "masked_spec_embed",
}
TOP_LEVEL_KEYS = [
"lm_head",
"quantizer.weight_proj",
"quantizer.codevectors",
"project_q",
"project_hid",
]
def set_recursively(hf_pointer, key, value, full_name, weight_type):
for attribute in key.split("."):
hf_pointer = getattr(hf_pointer, attribute)
if weight_type is not None:
hf_shape = getattr(hf_pointer, weight_type).shape
else:
hf_shape = hf_pointer.shape
assert (
hf_shape == value.shape
), f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
if weight_type == "weight":
hf_pointer.weight.data = value
elif weight_type == "weight_g":
hf_pointer.weight_g.data = value
elif weight_type == "weight_v":
hf_pointer.weight_v.data = value
elif weight_type == "bias":
hf_pointer.bias.data = value
else:
hf_pointer.data = value
logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
def recursively_load_weights_wav2vec2(fairseq_model, hf_model):
unused_weights = []
fairseq_dict = fairseq_model.state_dict()
feature_extractor = hf_model.feature_extractor
adapter = hf_model.adapter
for name, value in fairseq_dict.items():
is_used = False
if "conv_layers" in name:
load_conv_layer(
name,
value,
feature_extractor,
unused_weights,
hf_model.config.feat_extract_norm == "group",
)
is_used = True
elif any(x in name for x in ["adaptor", "w2v_encoder.proj.", "w2v_proj_ln."]):
load_adapter(name, value, adapter, unused_weights)
is_used = True
else:
for key, mapped_key in MAPPING.items():
if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
is_used = True
if "*" in mapped_key:
layer_index = name.split(key)[0].split(".")[-2]
mapped_key = mapped_key.replace("*", layer_index)
if "weight_g" in name:
weight_type = "weight_g"
elif "weight_v" in name:
weight_type = "weight_v"
elif "bias" in name:
weight_type = "bias"
elif "weight" in name:
weight_type = "weight"
else:
weight_type = None
set_recursively(hf_model, mapped_key, value, name, weight_type)
continue
if not is_used:
unused_weights.append(name)
logger.warning(f"Unused weights: {unused_weights}")
def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm):
name = full_name.split("conv_layers.")[-1]
items = name.split(".")
layer_id = int(items[0])
type_id = int(items[1])
if type_id == 0:
if "bias" in name:
assert (
value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape
), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
assert (
value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape
), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
if "bias" in name:
assert (
value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape
), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
assert (
value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape
), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
else:
unused_weights.append(full_name)
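# Adapter weights come either from the projection head ("w2v_encoder.proj." / "w2v_proj_ln.") or from
# the strided convolutions under "adaptor."; the conv layer index is parsed from the weight name.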
def load_adapter(full_name, value, adapter, unused_weights):
name = full_name.split("adaptor.")[-1]
items = name.split(".")
if items[1].isdigit():
layer_id = int(items[1])
else:
layer_id = None
if "adaptor" not in full_name:
if "proj_ln" in full_name:
# has to be layer norm
if "bias" in name:
assert (
value.shape == adapter.proj_layer_norm.bias.data.shape
), f"{full_name} has size {value.shape}, but {adapter.proj_layer_norm.bias.data.shape} was found."
adapter.proj_layer_norm.bias.data = value
logger.info(f"Adapter proj layer norm bias was initialized from {full_name}.")
if "weight" in name:
assert (
value.shape == adapter.proj_layer_norm.weight.data.shape
), f"{full_name} has size {value.shape}, but {adapter.proj_layer_norm.weight.data.shape} was found."
adapter.proj_layer_norm.weight.data = value
else:
# has to be projection layer
if "bias" in name:
assert (
value.shape == adapter.proj.bias.data.shape
), f"{full_name} has size {value.shape}, but {adapter.proj.bias.data.shape} was found."
adapter.proj.bias.data = value
logger.info(f"Adapter proj layer bias was initialized from {full_name}.")
if "weight" in name:
assert (
value.shape == adapter.proj.weight.data.shape
), f"{full_name} has size {value.shape}, but {adapter.proj.weight.data.shape} was found."
adapter.proj.weight.data = value
logger.info(f"Adapter proj layer weight was initialized from {full_name}.")
elif isinstance(layer_id, int):
if "bias" in name:
assert (
value.shape == adapter.layers[layer_id].conv.bias.data.shape
), f"{full_name} has size {value.shape}, but {adapter.layers[layer_id].conv.bias.data.shape} was found."
adapter.layers[layer_id].conv.bias.data = value
logger.info(f"Adapter layer {layer_id} bias was initialized from {full_name}.")
elif "weight" in name:
assert (
value.shape == adapter.layers[layer_id].conv.weight.data.shape
), f"{full_name} has size {value.shape}, but {adapter.layers[layer_id].conv.weight.data.shape} was found."
adapter.layers[layer_id].conv.weight.data = value
logger.info(f"Adapter layer {layer_id} bias was initialized from {full_name}.")
else:
unused_weights.append(full_name)
def make_linear_from_emb(emb):
    # build a bias-free LM head whose weights are shared with the embedding matrix
    vocab_size, emb_size = emb.weight.shape
    lin_layer = nn.Linear(emb_size, vocab_size, bias=False)
    lin_layer.weight.data = emb.weight.data
    return lin_layer
@torch.no_grad()
def convert_wav2vec2_checkpoint(
checkpoint_path,
pytorch_dump_folder_path,
dict_path,
config_yaml_path,
encoder_config_path,
decoder_config_path,
add_adapter,
adapter_kernel_size,
adapter_stride,
decoder_start_token_id,
encoder_output_dim,
):
"""
Copy/paste/tweak model's weights to transformers design.
"""
# load configs
encoder_config = Wav2Vec2Config.from_pretrained(
encoder_config_path,
add_adapter=True,
adapter_stride=adapter_stride,
adapter_kernel_size=adapter_kernel_size,
use_auth_token=True,
output_hidden_size=encoder_output_dim,
)
decoder_config = MBartConfig.from_pretrained(decoder_config_path)
# load model
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[checkpoint_path],
arg_overrides={
"config_yaml": config_yaml_path,
"data": "/".join(dict_path.split("/")[:-1]),
"w2v_path": checkpoint_path,
"load_pretrained_decoder_from": None,
},
)
model = model[0].eval()
# load feature extractor
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(encoder_config_path, use_auth_token=True)
# set weights for wav2vec2 encoder
hf_encoder = Wav2Vec2Model(encoder_config)
recursively_load_weights_wav2vec2(model.encoder, hf_encoder)
# load decoder weights
hf_decoder = MBartForCausalLM(decoder_config)
missing_keys, unexpected_keys = hf_decoder.model.decoder.load_state_dict(model.decoder.state_dict(), strict=False)
logger.warning(f"The following keys are missing when loading the decoder weights: {missing_keys}")
logger.warning(f"The following keys are unexpected when loading the decoder weights: {unexpected_keys}")
hf_wav2vec = SpeechEncoderDecoderModel(encoder=hf_encoder, decoder=hf_decoder)
hf_wav2vec.config.tie_word_embeddings = False
tokenizer = MBart50Tokenizer(dict_path)
tokenizer.save_pretrained(pytorch_dump_folder_path)
config = hf_wav2vec.config.to_dict()
config["pad_token_id"] = tokenizer.pad_token_id
config["bos_token_id"] = tokenizer.bos_token_id
config["eos_token_id"] = tokenizer.eos_token_id
config["tokenizer_class"] = "mbart50"
config["feature_extractor_type"] = "wav2vec2"
config["decoder_start_token_id"] = tokenizer.eos_token_id
config["forced_bos_token_id"] = 250004
config["forced_eos_token_id"] = tokenizer.eos_token_id
hf_wav2vec.config = SpeechEncoderDecoderConfig.from_dict(config)
hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
feature_extractor.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
parser.add_argument("--config_yaml_path", default=None, type=str, help="Path to yaml file of fine-tuned model")
parser.add_argument(
"--encoder_config_path",
default="facebook/wav2vec2-xls-r-1b",
type=str,
help="Path to hf encoder wav2vec2 checkpoint config",
)
parser.add_argument(
"--decoder_config_path",
default="facebook/mbart-large-50-one-to-many-mmt",
type=str,
help="Path to hf decoder checkpoint config",
)
parser.add_argument("--add_adapter", default=True, type=bool, help="whethere to add model adapter layers")
parser.add_argument("--adapter_stride", default=2, type=int, help="stride of adapter layers")
parser.add_argument("--adapter_kernel_size", default=3, type=int, help="kernel size of adapter layers")
parser.add_argument("--encoder_output_dim", default=1024, type=int, help="encoder output dim")
parser.add_argument("--start_token_id", default=250004, type=int, help="`decoder_start_token_id` of model config")
args = parser.parse_args()
convert_wav2vec2_checkpoint(
args.checkpoint_path,
args.pytorch_dump_folder_path,
args.dict_path,
args.config_yaml_path,
encoder_config_path=args.encoder_config_path,
decoder_config_path=args.decoder_config_path,
add_adapter=args.add_adapter,
adapter_kernel_size=args.adapter_kernel_size,
adapter_stride=args.adapter_stride,
decoder_start_token_id=args.start_token_id,
encoder_output_dim=args.encoder_output_dim,
)
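# Example invocation (script and file names below are illustrative placeholders):
#   python convert_wav2vec2_seq2seq_checkpoint.py \
#       --checkpoint_path /path/to/fairseq/checkpoint_best.pt \
#       --dict_path /path/to/dict.txt \
#       --config_yaml_path /path/to/config.yaml \
#       --pytorch_dump_folder_path ./wav2vec2-xls-r-mbart50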
......@@ -223,7 +223,9 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
self.encoder.config = self.config.encoder
self.decoder.config = self.config.decoder
if self.encoder.config.hidden_size != self.decoder.config.hidden_size:
# get encoder output hidden size
self.encoder_output_dim = getattr(config.encoder, "output_hidden_size", config.encoder.hidden_size)
if self.encoder_output_dim != self.decoder.config.hidden_size:
# encoder outputs might need to be projected to different dimension for decoder
self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)
......@@ -471,7 +473,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
encoder_hidden_states = encoder_outputs[0]
# project encoder_hidden_states
if self.encoder.config.hidden_size != self.decoder.config.hidden_size:
if self.encoder_output_dim != self.decoder.config.hidden_size:
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
# compute correct encoder attention mask
......
......@@ -892,7 +892,6 @@ class UniSpeechGumbelVectorQuantizer(nn.Module):
return codevectors, perplexity
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PreTrainedModel with Wav2Vec2->UniSpeech, wav2vec2->unispeech
class UniSpeechPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
......@@ -1032,7 +1031,6 @@ UNISPEECH_INPUTS_DOCSTRING = r"""
"The bare UniSpeech Model transformer outputting raw hidden-states without any specific head on top.",
UNISPEECH_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model with Wav2Vec2->UniSpeech, wav2vec2->unispeech, WAV_2_VEC_2->UNISPEECH
class UniSpeechModel(UniSpeechPreTrainedModel):
def __init__(self, config: UniSpeechConfig):
super().__init__(config)
......@@ -1049,6 +1047,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
self.init_weights()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
def _mask_hidden_states(
self,
hidden_states: torch.FloatTensor,
......
......@@ -893,7 +893,6 @@ class UniSpeechSatGumbelVectorQuantizer(nn.Module):
return codevectors, perplexity
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PreTrainedModel with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat
class UniSpeechSatPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
......@@ -1033,7 +1032,6 @@ UNISPEECH_SAT_INPUTS_DOCSTRING = r"""
"The bare UniSpeechSat Model transformer outputting raw hidden-states without any specific head on top.",
UNISPEECH_SAT_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT
class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
def __init__(self, config: UniSpeechSatConfig):
super().__init__(config)
......@@ -1050,6 +1048,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
self.init_weights()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
def _mask_hidden_states(
self,
hidden_states: torch.FloatTensor,
......
......@@ -140,6 +140,19 @@ class Wav2Vec2Config(PretrainedConfig):
instance of :class:`~transformers.Wav2Vec2ForSequenceClassification`.
classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
Dimensionality of the projection before token mean-pooling for classification.
add_adapter (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether a convolutional network should be stacked on top of the Wav2Vec2 Encoder. Can be very useful for
warm-starting Wav2Vec2 for SpeechEncoderDecoder models.
adapter_kernel_size (:obj:`int`, `optional`, defaults to 3):
Kernel size of the convolutional layers in the adapter network. Only relevant if ``add_adapter is True``.
adapter_stride (:obj:`int`, `optional`, defaults to 2):
Stride of the convolutional layers in the adapter network. Only relevant if ``add_adapter is True``.
num_adapter_layers (:obj:`int`, `optional`, defaults to 3):
Number of convolutional layers that should be used in the adapter network. Only relevant if ``add_adapter
is True``.
output_hidden_size (:obj:`int`, `optional`):
            Dimensionality of the encoder output layer. If not defined, this defaults to :obj:`hidden_size`. Only relevant
if ``add_adapter is True``.
Example::
......@@ -201,6 +214,11 @@ class Wav2Vec2Config(PretrainedConfig):
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
add_adapter=False,
adapter_kernel_size=3,
adapter_stride=2,
num_adapter_layers=3,
output_hidden_size=None,
**kwargs
):
super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
......@@ -263,3 +281,10 @@ class Wav2Vec2Config(PretrainedConfig):
# ctc loss
self.ctc_loss_reduction = ctc_loss_reduction
self.ctc_zero_infinity = ctc_zero_infinity
# adapter
self.add_adapter = add_adapter
self.adapter_kernel_size = adapter_kernel_size
self.adapter_stride = adapter_stride
self.num_adapter_layers = num_adapter_layers
self.output_hidden_size = output_hidden_size or hidden_size
......@@ -935,6 +935,55 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
return codevectors, perplexity
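# The adapter stacks strided 1-D convolutions on top of the encoder output to downsample the time
# dimension (optionally projecting to `output_hidden_size` first); per the config docstring this is
# mainly useful for warm-starting Wav2Vec2 in SpeechEncoderDecoder models.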
class Wav2Vec2Adapter(nn.Module):
def __init__(self, config):
super().__init__()
# feature dim might need to be down-projected
if config.output_hidden_size != config.hidden_size:
self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
else:
self.proj = self.proj_layer_norm = None
self.layers = nn.ModuleList(Wav2Vec2AdapterLayer(config) for _ in range(config.num_adapter_layers))
self.layerdrop = config.layerdrop
def forward(self, hidden_states):
# down project hidden_states if necessary
if self.proj is not None and self.proj_layer_norm is not None:
hidden_states = self.proj(hidden_states)
hidden_states = self.proj_layer_norm(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
for layer in self.layers:
layerdrop_prob = np.random.random()
if not self.training or (layerdrop_prob > self.layerdrop):
hidden_states = layer(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
return hidden_states
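# Each adapter layer is a strided 1-D convolution that doubles the channel dimension, followed by a
# GLU that gates and halves it back, so the feature size stays at `output_hidden_size` while the
# sequence length is reduced by the stride.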
class Wav2Vec2AdapterLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.conv = nn.Conv1d(
config.output_hidden_size,
2 * config.output_hidden_size,
config.adapter_kernel_size,
stride=config.adapter_stride,
padding=1,
)
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = nn.functional.glu(hidden_states, dim=1)
return hidden_states
class Wav2Vec2PreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
......@@ -979,11 +1028,15 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
nn.init.uniform_(module.bias, a=-k, b=k)
def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
def _get_feat_extract_output_lengths(
self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
):
"""
Computes the output length of the convolutional layers
"""
add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
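            # i.e. (input_length - kernel_size) // stride + 1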
......@@ -992,13 +1045,22 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
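        # the adapter convs (kernel 3, stride `adapter_stride`, padding 1) shorten the sequence exactly
        # like a kernel-1 convolution with the same stride, hence kernel_size=1 in the loop below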
if add_adapter:
for _ in range(self.config.num_adapter_layers):
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
return input_lengths
def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
def _get_feature_vector_attention_mask(
self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
):
        # Effectively attention_mask.sum(-1), but not in-place so that it can run
        # in inference mode.
non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
output_lengths = output_lengths.to(torch.long)
batch_size = attention_mask.shape[0]
attention_mask = torch.zeros(
......@@ -1088,6 +1150,8 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
else:
self.encoder = Wav2Vec2Encoder(config)
self.adapter = Wav2Vec2Adapter(config) if config.add_adapter else None
self.init_weights()
def _mask_hidden_states(
......@@ -1163,7 +1227,9 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
if attention_mask is not None:
# compute reduced attention_mask corresponding to feature vectors
attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
attention_mask = self._get_feature_vector_attention_mask(
extract_features.shape[1], attention_mask, add_adapter=False
)
hidden_states, extract_features = self.feature_projection(extract_features)
hidden_states = self._mask_hidden_states(
......@@ -1180,6 +1246,9 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
hidden_states = encoder_outputs[0]
if self.adapter is not None:
hidden_states = self.adapter(hidden_states)
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
......@@ -1328,7 +1397,9 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
attention_mask = self._get_feature_vector_attention_mask(
extract_features.shape[1], attention_mask, add_adapter=False
)
quantized_features, codevector_perplexity = self.quantizer(
extract_features, mask_time_indices=mask_time_indices
......
......@@ -82,6 +82,8 @@ class Wav2Vec2ModelTester:
mask_time_length=2,
vocab_size=32,
do_stable_layer_norm=False,
num_adapter_layers=1,
adapter_stride=2,
scope=None,
):
self.parent = parent
......@@ -107,6 +109,8 @@ class Wav2Vec2ModelTester:
self.initializer_range = initializer_range
self.vocab_size = vocab_size
self.do_stable_layer_norm = do_stable_layer_norm
self.num_adapter_layers = num_adapter_layers
self.adapter_stride = adapter_stride
self.mask_time_prob = mask_time_prob
self.mask_time_length = mask_time_length
self.scope = scope
......@@ -117,6 +121,8 @@ class Wav2Vec2ModelTester:
self.output_seq_length = int(math.ceil(output_seq_length))
self.encoder_seq_length = self.output_seq_length
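        # expected length after the adapter: each adapter conv (kernel 3, stride `adapter_stride`,
        # padding 1) maps a length L to (L - 1) // stride + 1; this tester uses a single adapter layer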
self.adapter_output_seq_length = (self.output_seq_length - 1) // adapter_stride + 1
def prepare_config_and_inputs(self):
input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
......@@ -148,6 +154,8 @@ class Wav2Vec2ModelTester:
hidden_act=self.hidden_act,
initializer_range=self.initializer_range,
vocab_size=self.vocab_size,
num_adapter_layers=self.num_adapter_layers,
adapter_stride=self.adapter_stride,
)
def create_and_check_model(self, config, input_values, attention_mask):
......@@ -159,6 +167,28 @@ class Wav2Vec2ModelTester:
result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
)
def create_and_check_model_with_adapter(self, config, input_values, attention_mask):
config.add_adapter = True
model = Wav2Vec2Model(config=config)
model.to(torch_device)
model.eval()
result = model(input_values, attention_mask=attention_mask)
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, self.adapter_output_seq_length, self.hidden_size)
)
def create_and_check_model_with_adapter_proj_dim(self, config, input_values, attention_mask):
config.add_adapter = True
config.output_hidden_size = 8
model = Wav2Vec2Model(config=config)
model.to(torch_device)
model.eval()
result = model(input_values, attention_mask=attention_mask)
self.parent.assertEqual(
result.last_hidden_state.shape,
(self.batch_size, self.adapter_output_seq_length, config.output_hidden_size),
)
def create_and_check_batch_inference(self, config, input_values, *args):
# test does not pass for models making use of `group_norm`
# check: https://github.com/pytorch/fairseq/issues/3227
......@@ -332,6 +362,14 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_with_adapter(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_with_adapter(*config_and_inputs)
def test_model_with_adapter_proj_dim(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs)
def test_ctc_loss_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_loss(*config_and_inputs)
......@@ -544,6 +582,14 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_with_adapter(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_with_adapter(*config_and_inputs)
def test_model_with_adapter_proj_dim(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs)
def test_batched_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_batch_inference(*config_and_inputs)
......
......@@ -203,3 +203,41 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
data = f.read()
output = asr(data)
self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."})
@slow
@require_torch
@require_torchaudio
@require_datasets
def test_xls_r_to_en(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="facebook/wav2vec2-xls-r-1b-21-to-en",
feature_extractor="facebook/wav2vec2-xls-r-1b-21-to-en",
framework="pt",
)
from datasets import load_dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
self.assertEqual(output, {"text": "A man said to the universe: “Sir, I exist."})
@slow
@require_torch
@require_torchaudio
@require_datasets
def test_xls_r_from_en(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="facebook/wav2vec2-xls-r-1b-en-to-15",
feature_extractor="facebook/wav2vec2-xls-r-1b-en-to-15",
framework="pt",
)
from datasets import load_dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
self.assertEqual(output, {"text": "Ein Mann sagte zu dem Universum, Sir, ich bin da."})