Unverified Commit b6f332ec authored by Anton Lozhkov, committed by GitHub

Add Wav2Vec2 & Hubert ForSequenceClassification (#13153)

* Add hubert classifier + tests

* Add hubert classifier + tests

* Dummies for all classification tests

* Wav2Vec2 classifier + ER test

* Fix hubert integration tests

* Add hubert IC

* Pass tests for all classification tasks on Hubert

* Pass all tests + copies

* Move models to the SUPERB org
parent 2bef3433
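At a glance, the new heads expose the usual from_pretrained / forward API. The sketch below is adapted from the docstring example included further down in this diff; the checkpoint and dataset names are the ones referenced there, and the rest is an illustrative, untested usage outline.

import torch
from datasets import load_dataset
from transformers import HubertForSequenceClassification, Wav2Vec2FeatureExtractor

# Checkpoint and dataset names taken from the docstring example in this commit.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-ks")
model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ks")

ds = load_dataset("anton-l/superb_dummy", "ks", split="test")
input_values = feature_extractor(ds["speech"][4], return_tensors="pt").input_values  # batch size 1

with torch.no_grad():
    logits = model(input_values).logits
predicted_class_id = torch.argmax(logits, dim=-1).item()
print(model.config.id2label[predicted_class_id])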
@@ -64,6 +64,14 @@ HubertForCTC
.. autoclass:: transformers.HubertForCTC
    :members: forward
HubertForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.HubertForSequenceClassification
:members: forward
TFHubertModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...
@@ -96,6 +96,14 @@ Wav2Vec2ForCTC
.. autoclass:: transformers.Wav2Vec2ForCTC
    :members: forward
Wav2Vec2ForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.Wav2Vec2ForSequenceClassification
:members: forward
Wav2Vec2ForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...
@@ -818,6 +818,7 @@ if is_torch_available():
[
"HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"HubertForCTC",
"HubertForSequenceClassification",
"HubertModel",
"HubertPreTrainedModel",
]
@@ -1128,6 +1129,7 @@ if is_torch_available():
"Wav2Vec2ForCTC",
"Wav2Vec2ForMaskedLM",
"Wav2Vec2ForPreTraining",
"Wav2Vec2ForSequenceClassification",
"Wav2Vec2Model",
"Wav2Vec2PreTrainedModel",
]
@@ -2424,6 +2426,7 @@ if TYPE_CHECKING:
from .models.hubert import (
HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
HubertForCTC,
HubertForSequenceClassification,
HubertModel,
HubertPreTrainedModel,
)
@@ -2681,6 +2684,7 @@ if TYPE_CHECKING:
Wav2Vec2ForCTC,
Wav2Vec2ForMaskedLM,
Wav2Vec2ForPreTraining,
Wav2Vec2ForSequenceClassification,
Wav2Vec2Model,
Wav2Vec2PreTrainedModel,
)
...
@@ -28,6 +28,7 @@ if is_torch_available():
_import_structure["modeling_hubert"] = [
"HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"HubertForCTC",
"HubertForSequenceClassification",
"HubertModel",
"HubertPreTrainedModel",
]
@@ -48,6 +49,7 @@ if TYPE_CHECKING:
from .modeling_hubert import (
HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
HubertForCTC,
HubertForSequenceClassification,
HubertModel,
HubertPreTrainedModel,
)
...
@@ -115,6 +115,11 @@ class HubertConfig(PretrainedConfig):
Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. Infinite losses
mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an
instance of :class:`~transformers.HubertForCTC`.
use_weighted_layer_sum (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
instance of :class:`~transformers.HubertForSequenceClassification`.
classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
Dimensionality of the projection before token mean-pooling for classification.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
@@ -165,6 +170,8 @@ class HubertConfig(PretrainedConfig):
mask_feature_length=10,
ctc_loss_reduction="sum",
ctc_zero_infinity=False,
use_weighted_layer_sum=False,
classifier_proj_size=256,
gradient_checkpointing=False,
pad_token_id=0,
bos_token_id=1,
@@ -197,6 +204,8 @@ class HubertConfig(PretrainedConfig):
self.vocab_size = vocab_size
self.do_stable_layer_norm = do_stable_layer_norm
self.gradient_checkpointing = gradient_checkpointing
self.use_weighted_layer_sum = use_weighted_layer_sum
self.classifier_proj_size = classifier_proj_size
if (
(len(self.conv_stride) != self.num_feat_extract_layers)
...
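A quick sketch of the two new configuration fields in use; the base checkpoint name is only illustrative, and num_labels is the standard config field that sizes the classification layer.

from transformers import HubertConfig, HubertForSequenceClassification

# Enable the learned weighted sum over all layer outputs and keep the default
# 256-dimensional projection before mean pooling; 12 labels is an arbitrary example.
config = HubertConfig.from_pretrained(
    "facebook/hubert-base-ls960",  # illustrative base checkpoint
    use_weighted_layer_sum=True,
    classifier_proj_size=256,
    num_labels=12,
)
model = HubertForSequenceClassification(config)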
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Hubert checkpoint."""
import argparse
import torch
from transformers import HubertConfig, HubertForSequenceClassification, Wav2Vec2FeatureExtractor, logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
SUPPORTED_MODELS = ["UtteranceLevel"]
@torch.no_grad()
def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, model_dump_path):
"""
Copy/paste/tweak model's weights to transformers design.
"""
checkpoint = torch.load(checkpoint_path, map_location="cpu")
if checkpoint["Config"]["downstream_expert"]["modelrc"]["select"] not in SUPPORTED_MODELS:
raise NotImplementedError(f"The supported s3prl models are {SUPPORTED_MODELS}")
downstream_dict = checkpoint["Downstream"]
hf_congfig = HubertConfig.from_pretrained(config_path)
hf_model = HubertForSequenceClassification.from_pretrained(base_model_name, config=hf_congfig)
hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
base_model_name, return_attention_mask=True, do_normalize=False
)
if hf_congfig.use_weighted_layer_sum:
hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"]
hf_model.projector.weight.data = downstream_dict["projector.weight"]
hf_model.projector.bias.data = downstream_dict["projector.bias"]
hf_model.classifier.weight.data = downstream_dict["model.post_net.linear.weight"]
hf_model.classifier.bias.data = downstream_dict["model.post_net.linear.bias"]
hf_feature_extractor.save_pretrained(model_dump_path)
hf_model.save_pretrained(model_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--base_model_name", default=None, type=str, help="Name of the huggingface pretrained base model."
)
parser.add_argument("--config_path", default=None, type=str, help="Path to the huggingface classifier config.")
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to the s3prl checkpoint.")
parser.add_argument("--model_dump_path", default=None, type=str, help="Path to the final converted model.")
args = parser.parse_args()
convert_s3prl_checkpoint(args.base_model_name, args.config_path, args.checkpoint_path, args.model_dump_path)
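For reference, a hedged sketch of calling the conversion entry point above directly from Python; all paths and the base model name are placeholders rather than values from this commit.

# Placeholder arguments; in practice these point at a pretrained HuggingFace base model,
# a HubertConfig JSON with the classifier fields set, and an s3prl UtteranceLevel checkpoint.
convert_s3prl_checkpoint(
    base_model_name="facebook/hubert-base-ls960",
    config_path="./classifier_config.json",
    checkpoint_path="./s3prl_utterance_level.ckpt",
    model_dump_path="./converted-hubert-classifier",
)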
@@ -20,12 +20,13 @@ import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.deepspeed import is_deepspeed_zero3_enabled
from ...activations import ACT2FN
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_hubert import HubertConfig
@@ -735,6 +736,18 @@ class HubertPreTrainedModel(PreTrainedModel):
return input_lengths
def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
batch_size = attention_mask.shape[0]
attention_mask = torch.zeros(
(batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
)
# these two operations makes sure that all values before the output lengths idxs are attended to
attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
return attention_mask
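To see why the scatter followed by flip/cumsum/flip marks every frame up to each output length as attended, here is a small self-contained sketch of the same trick with arbitrary example lengths.

import torch

# Two examples whose feature sequences are 3 and 5 frames long out of 6 total.
output_lengths = torch.tensor([3, 5])
feature_vector_length = 6
mask = torch.zeros((2, feature_vector_length), dtype=torch.long)
# set a 1 at the last valid frame of each example ...
mask[(torch.arange(2), output_lengths - 1)] = 1
# ... then flip/cumsum/flip propagates that 1 to all earlier frames
mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
# mask[0] -> [True, True, True, False, False, False]
# mask[1] -> [True, True, True, True, True, False]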
HUBERT_START_DOCSTRING = r"""
Hubert was proposed in `HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units
@@ -904,19 +917,8 @@ class HubertModel(HubertPreTrainedModel):
extract_features = extract_features.transpose(1, 2)
if attention_mask is not None:
    # compute reduced attention_mask corresponding to feature vectors
    attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
hidden_states = self.feature_projection(extract_features)
hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
@@ -1070,3 +1072,128 @@ class HubertForCTC(HubertPreTrainedModel):
return CausalLMOutput(
    loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
)
@add_start_docstrings(
"""
Hubert Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
SUPERB Keyword Spotting.
""",
HUBERT_START_DOCSTRING,
)
class HubertForSequenceClassification(HubertPreTrainedModel):
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Hubert, wav2vec2->hubert
def __init__(self, config):
super().__init__(config)
self.hubert = HubertModel(config)
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
self.init_weights()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_extractor with wav2vec2->hubert
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature extractor so that its parameters
will not be updated during training.
"""
self.hubert.feature_extractor._freeze_parameters()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_base_model with wav2vec2->hubert
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
be updated during training. Only the classification head will be updated.
"""
for param in self.hubert.parameters():
param.requires_grad = False
@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_values,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
Example::
>>> import torch
>>> from transformers import Wav2Vec2FeatureExtractor, HubertForSequenceClassification
>>> from datasets import load_dataset
>>> processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-ks")
>>> model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ks")
>>> ds = load_dataset("anton-l/superb_dummy", "ks", split="test")
>>> input_values = processor(ds["speech"][4], return_tensors="pt").input_values # Batch size 1
>>> logits = model(input_values).logits
>>> predicted_class_ids = torch.argmax(logits, dim=-1)
>>> # compute loss
>>> target_label = "down"
>>> labels = torch.tensor([model.config.label2id[target_label]])
>>> loss = model(input_values, labels=labels).loss
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
outputs = self.hubert(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if self.config.use_weighted_layer_sum:
hidden_states = outputs[1]
hidden_states = torch.stack(hidden_states, dim=1)
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
else:
hidden_states = outputs[0]
hidden_states = self.projector(hidden_states)
if attention_mask is None:
pooled_output = hidden_states.mean(dim=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
hidden_states[~padding_mask] = 0.0
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
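The head's forward pass combines two pieces: an optional softmax-weighted sum over all hidden states (use_weighted_layer_sum) and padding-aware mean pooling over time. The following standalone sketch reproduces just that math with random data and illustrative shapes.

import torch
from torch import nn

batch, num_layers, seq_len, hidden = 2, 13, 50, 768  # illustrative shapes
hidden_states = torch.randn(batch, num_layers, seq_len, hidden)
layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)

# softmax-normalized weighted sum over layers -> (batch, seq_len, hidden)
norm_weights = nn.functional.softmax(layer_weights, dim=-1)
summed = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)

# padding-aware mean pooling over time -> (batch, hidden)
padding_mask = torch.ones(batch, seq_len, dtype=torch.bool)
padding_mask[1, 30:] = False  # pretend the second example is padded after frame 30
summed[~padding_mask] = 0.0
pooled = summed.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)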
@@ -33,6 +33,7 @@ if is_torch_available():
"Wav2Vec2ForCTC",
"Wav2Vec2ForMaskedLM",
"Wav2Vec2ForPreTraining",
"Wav2Vec2ForSequenceClassification",
"Wav2Vec2Model",
"Wav2Vec2PreTrainedModel",
]
@@ -66,6 +67,7 @@ if TYPE_CHECKING:
Wav2Vec2ForCTC,
Wav2Vec2ForMaskedLM,
Wav2Vec2ForPreTraining,
Wav2Vec2ForSequenceClassification,
Wav2Vec2Model,
Wav2Vec2PreTrainedModel,
)
...
@@ -133,6 +133,11 @@ class Wav2Vec2Config(PretrainedConfig):
Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. Infinite losses
mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an
instance of :class:`~transformers.Wav2Vec2ForCTC`.
use_weighted_layer_sum (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
instance of :class:`~transformers.Wav2Vec2ForSequenceClassification`.
classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
Dimensionality of the projection before token mean-pooling for classification.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
@@ -191,6 +196,8 @@ class Wav2Vec2Config(PretrainedConfig):
diversity_loss_weight=0.1,
ctc_loss_reduction="sum",
ctc_zero_infinity=False,
use_weighted_layer_sum=False,
classifier_proj_size=256,
gradient_checkpointing=False,
pad_token_id=0,
bos_token_id=1,
@@ -223,6 +230,8 @@ class Wav2Vec2Config(PretrainedConfig):
self.vocab_size = vocab_size
self.do_stable_layer_norm = do_stable_layer_norm
self.gradient_checkpointing = gradient_checkpointing
self.use_weighted_layer_sum = use_weighted_layer_sum
self.classifier_proj_size = classifier_proj_size
if (
(len(self.conv_stride) != self.num_feat_extract_layers)
...
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Hubert checkpoint."""
import argparse
import torch
from transformers import Wav2Vec2Config, Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification, logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
SUPPORTED_MODELS = ["UtteranceLevel"]
@torch.no_grad()
def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, model_dump_path):
"""
Copy/paste/tweak model's weights to transformers design.
"""
checkpoint = torch.load(checkpoint_path, map_location="cpu")
if checkpoint["Config"]["downstream_expert"]["modelrc"]["select"] not in SUPPORTED_MODELS:
raise NotImplementedError(f"The supported s3prl models are {SUPPORTED_MODELS}")
downstream_dict = checkpoint["Downstream"]
hf_congfig = Wav2Vec2Config.from_pretrained(config_path)
hf_model = Wav2Vec2ForSequenceClassification.from_pretrained(base_model_name, config=hf_congfig)
hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
base_model_name, return_attention_mask=True, do_normalize=False
)
if hf_congfig.use_weighted_layer_sum:
hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"]
hf_model.projector.weight.data = downstream_dict["projector.weight"]
hf_model.projector.bias.data = downstream_dict["projector.bias"]
hf_model.classifier.weight.data = downstream_dict["model.post_net.linear.weight"]
hf_model.classifier.bias.data = downstream_dict["model.post_net.linear.bias"]
hf_feature_extractor.save_pretrained(model_dump_path)
hf_model.save_pretrained(model_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--base_model_name", default=None, type=str, help="Name of the huggingface pretrained base model."
)
parser.add_argument("--config_path", default=None, type=str, help="Path to the huggingface classifier config.")
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to the s3prl checkpoint.")
parser.add_argument("--model_dump_path", default=None, type=str, help="Path to the final converted model.")
args = parser.parse_args()
convert_s3prl_checkpoint(args.base_model_name, args.config_path, args.checkpoint_path, args.model_dump_path)
@@ -83,9 +83,6 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
"""
Every array in the list is normalized to have zero mean and unit variance
"""
normed_input_values = [
    (x - np.mean(x[:i])) / np.sqrt(np.var(x[:i]) + 1e-5) for x, i in zip(input_values, input_lengths)
]
@@ -205,6 +202,9 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
padded_input_values = padded_inputs["input_values"]
input_lengths = [padded_input_values.shape[-1] for _ in range(padded_input_values.shape[0])]
if isinstance(padded_inputs["input_values"][0], np.ndarray):
    padded_inputs["input_values"] = [x.astype(np.float32) for x in padded_inputs["input_values"]]
# zero-mean and unit-variance normalization
if self.do_normalize:
    padded_inputs["input_values"] = self.zero_mean_unit_var_norm(
...
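For context, the normalization this cast now precedes computes per-utterance statistics only over the unpadded samples and then rescales the whole array; a small NumPy sketch with arbitrary values follows.

import numpy as np

# One padded example: the last two samples are padding, so mean/variance are taken
# over the first `length` samples only, mirroring zero_mean_unit_var_norm above.
x = np.array([0.1, 0.5, -0.3, 0.2, 0.0, 0.0], dtype=np.float32)
length = 4
normed = (x - np.mean(x[:length])) / np.sqrt(np.var(x[:length]) + 1e-5)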
@@ -22,6 +22,7 @@ import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...deepspeed import is_deepspeed_zero3_enabled
@@ -31,7 +32,7 @@ from ...file_utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_wav2vec2 import Wav2Vec2Config
@@ -1057,7 +1058,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
extract_features = extract_features.transpose(1, 2)
if attention_mask is not None:
    # compute reduced attention_mask corresponding to feature vectors
    attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
hidden_states, extract_features = self.feature_projection(extract_features)
@@ -1527,3 +1528,126 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
return CausalLMOutput(
    loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
)
@add_start_docstrings(
"""
Wav2Vec2 Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
SUPERB Keyword Spotting.
""",
WAV_2_VEC_2_START_DOCSTRING,
)
class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.wav2vec2 = Wav2Vec2Model(config)
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
self.init_weights()
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature extractor so that its parameters
will not be updated during training.
"""
self.wav2vec2.feature_extractor._freeze_parameters()
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
be updated during training. Only the classification head will be updated.
"""
for param in self.wav2vec2.parameters():
param.requires_grad = False
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_values,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
Example::
>>> import torch
>>> from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification
>>> from datasets import load_dataset
>>> processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ks")
>>> model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ks")
>>> ds = load_dataset("anton-l/superb_dummy", "ks", split="test")
>>> input_values = processor(ds["speech"][4], return_tensors="pt").input_values # Batch size 1
>>> logits = model(input_values).logits
>>> predicted_class_ids = torch.argmax(logits, dim=-1)
>>> # compute loss
>>> target_label = "down"
>>> labels = torch.tensor([model.config.label2id[target_label]])
>>> loss = model(input_values, labels=labels).loss
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
outputs = self.wav2vec2(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# End copy
if self.config.use_weighted_layer_sum:
hidden_states = outputs[2]
hidden_states = torch.stack(hidden_states, dim=1)
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
else:
hidden_states = outputs[0]
hidden_states = self.projector(hidden_states)
if attention_mask is None:
pooled_output = hidden_states.mean(dim=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
hidden_states[~padding_mask] = 0.0
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@@ -1863,6 +1863,15 @@ class HubertForCTC:
requires_backends(self, ["torch"])
class HubertForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class HubertModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@@ -3473,6 +3482,15 @@ class Wav2Vec2ForPreTraining:
requires_backends(self, ["torch"])
class Wav2Vec2ForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class Wav2Vec2Model:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
...
@@ -31,7 +31,13 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
if is_torch_available():
import torch
from transformers import (
HubertForCTC,
HubertForSequenceClassification,
HubertModel,
Wav2Vec2FeatureExtractor,
Wav2Vec2Processor,
)
from transformers.models.hubert.modeling_hubert import _compute_mask_indices
@@ -187,7 +193,32 @@ class HubertModelTester:
self.parent.assertTrue(isinstance(sum_loss, float))
self.parent.assertTrue(isinstance(mean_loss, float))
def check_seq_classifier_loss(self, config, input_values, *args):
model = HubertForSequenceClassification(config=config)
model.to(torch_device)
# make sure that dropout is disabled
model.eval()
input_values = input_values[:3]
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
attention_mask[i, input_lengths[i] :] = 0
masked_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
unmasked_loss = model(input_values, labels=labels).loss.item()
self.parent.assertTrue(isinstance(masked_loss, float))
self.parent.assertTrue(isinstance(unmasked_loss, float))
self.parent.assertTrue(masked_loss != unmasked_loss)
def check_ctc_training(self, config, input_values, *args):
config.ctc_zero_infinity = True
model = HubertForCTC(config=config)
model.to(torch_device)
@@ -216,6 +247,29 @@ class HubertModelTester:
loss.backward()
def check_seq_classifier_training(self, config, input_values, *args):
config.ctc_zero_infinity = True
model = HubertForSequenceClassification(config=config)
model.to(torch_device)
model.train()
# freeze everything but the classification head
model.freeze_base_model()
input_values = input_values[:3]
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
loss = model(input_values, labels=labels).loss
self.parent.assertFalse(torch.isinf(loss).item())
loss.backward()
def check_labels_out_of_vocab(self, config, input_values, *args):
model = HubertForCTC(config)
model.to(torch_device)
@@ -238,7 +292,7 @@ class HubertModelTester:
@require_torch
class HubertModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (HubertForCTC, HubertForSequenceClassification, HubertModel) if is_torch_available() else ()
test_pruning = False
test_headmasking = False
test_torchscript = False
@@ -258,9 +312,17 @@ class HubertModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_loss(*config_and_inputs)
def test_seq_classifier_loss_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_seq_classifier_loss(*config_and_inputs)
def test_ctc_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_training(*config_and_inputs)
def test_seq_classifier_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_seq_classifier_training(*config_and_inputs)
def test_labels_out_of_vocab(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -371,7 +433,7 @@ class HubertModelTest(ModelTesterMixin, unittest.TestCase):
@require_torch
class HubertRobustModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (HubertForCTC, HubertForSequenceClassification, HubertModel) if is_torch_available() else ()
test_pruning = False
test_headmasking = False
test_torchscript = False
@@ -397,9 +459,17 @@ class HubertRobustModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_loss(*config_and_inputs)
def test_seq_classifier_loss_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_seq_classifier_loss(*config_and_inputs)
def test_ctc_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_training(*config_and_inputs)
def test_seq_classifier_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_seq_classifier_training(*config_and_inputs)
def test_labels_out_of_vocab(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -557,6 +627,13 @@ class HubertModelIntegrationTest(unittest.TestCase):
return ds["speech"][:num_samples]
def _load_superb(self, task, num_samples):
from datasets import load_dataset
ds = load_dataset("anton-l/superb_dummy", task, split="test")
return ds[:num_samples]
def test_inference_ctc_batched(self):
model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").to(torch_device)
processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft", do_lower_case=True)
@@ -579,3 +656,95 @@ class HubertModelIntegrationTest(unittest.TestCase):
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
def test_inference_keyword_spotting(self):
model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ks").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-ks")
input_data = self._load_superb("ks", 4)
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
with torch.no_grad():
outputs = model(input_values, attention_mask=attention_mask)
predicted_logits, predicted_ids = torch.max(outputs.logits, dim=-1)
expected_labels = [2, 6, 10, 9]
# s3prl logits for the same batch
expected_logits = torch.tensor([7.6692, 17.7795, 11.1562, 11.8232], device=torch_device)
self.assertListEqual(predicted_ids.tolist(), expected_labels)
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2))
def test_inference_intent_classification(self):
model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ic").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-ic")
input_data = self._load_superb("ic", 4)
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
with torch.no_grad():
outputs = model(input_values, attention_mask=attention_mask)
predicted_logits_action, predicted_ids_action = torch.max(outputs.logits[:, :6], dim=-1)
predicted_logits_object, predicted_ids_object = torch.max(outputs.logits[:, 6:20], dim=-1)
predicted_logits_location, predicted_ids_location = torch.max(outputs.logits[:, 20:24], dim=-1)
expected_labels_action = [1, 0, 4, 3]
expected_logits_action = torch.tensor([5.9052, 12.5865, 4.4840, 10.0240], device=torch_device)
expected_labels_object = [1, 10, 3, 4]
expected_logits_object = torch.tensor([5.5316, 11.7946, 8.1672, 23.2415], device=torch_device)
expected_labels_location = [0, 0, 0, 1]
expected_logits_location = torch.tensor([5.2053, 8.9577, 10.0447, 8.1481], device=torch_device)
self.assertListEqual(predicted_ids_action.tolist(), expected_labels_action)
self.assertListEqual(predicted_ids_object.tolist(), expected_labels_object)
self.assertListEqual(predicted_ids_location.tolist(), expected_labels_location)
# TODO: lower the tolerance after merging the padding fix https://github.com/pytorch/fairseq/pull/3572
self.assertTrue(torch.allclose(predicted_logits_action, expected_logits_action, atol=3e-1))
self.assertTrue(torch.allclose(predicted_logits_object, expected_logits_object, atol=3e-1))
self.assertTrue(torch.allclose(predicted_logits_location, expected_logits_location, atol=3e-1))
def test_inference_speaker_identification(self):
model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-sid").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-sid")
input_data = self._load_superb("si", 4)
output_logits = []
with torch.no_grad():
for example in input_data["speech"]:
input = processor(example, return_tensors="pt", padding=True)
output = model(input.input_values.to(torch_device), attention_mask=None)
output_logits.append(output.logits[0])
output_logits = torch.stack(output_logits)
predicted_logits, predicted_ids = torch.max(output_logits, dim=-1)
expected_labels = [5, 1, 1, 3]
# s3prl logits for the same batch
expected_logits = torch.tensor([78231.5547, 123166.6094, 122785.4141, 84851.2969], device=torch_device)
self.assertListEqual(predicted_ids.tolist(), expected_labels)
# TODO: lower the tolerance after merging the padding fix https://github.com/pytorch/fairseq/pull/3572
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=10))
def test_inference_emotion_recognition(self):
model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-er").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-er")
input_data = self._load_superb("er", 4)
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
with torch.no_grad():
outputs = model(input_values, attention_mask=attention_mask)
predicted_logits, predicted_ids = torch.max(outputs.logits, dim=-1)
expected_labels = [1, 1, 2, 2]
# s3prl logits for the same batch
expected_logits = torch.tensor([2.8384, 2.3389, 3.8564, 4.5558], device=torch_device)
self.assertListEqual(predicted_ids.tolist(), expected_labels)
# TODO: lower the tolerance after merging the padding fix https://github.com/pytorch/fairseq/pull/3572
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-1))
@@ -14,7 +14,6 @@
# limitations under the License.
""" Testing suite for the PyTorch Wav2Vec2 model. """
import math
import unittest
@@ -36,6 +35,7 @@ if is_torch_available():
Wav2Vec2ForCTC,
Wav2Vec2ForMaskedLM,
Wav2Vec2ForPreTraining,
Wav2Vec2ForSequenceClassification,
Wav2Vec2Model,
Wav2Vec2Processor,
)
@@ -194,7 +194,32 @@ class Wav2Vec2ModelTester:
self.parent.assertTrue(isinstance(sum_loss, float))
self.parent.assertTrue(isinstance(mean_loss, float))
def check_seq_classifier_loss(self, config, input_values, *args):
model = Wav2Vec2ForSequenceClassification(config=config)
model.to(torch_device)
# make sure that dropout is disabled
model.eval()
input_values = input_values[:3]
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
attention_mask[i, input_lengths[i] :] = 0
masked_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
unmasked_loss = model(input_values, labels=labels).loss.item()
self.parent.assertTrue(isinstance(masked_loss, float))
self.parent.assertTrue(isinstance(unmasked_loss, float))
self.parent.assertTrue(masked_loss != unmasked_loss)
def check_ctc_training(self, config, input_values, *args):
config.ctc_zero_infinity = True
model = Wav2Vec2ForCTC(config=config)
model.to(torch_device)
@@ -223,6 +248,29 @@ class Wav2Vec2ModelTester:
loss.backward()
def check_seq_classifier_training(self, config, input_values, *args):
config.ctc_zero_infinity = True
model = Wav2Vec2ForSequenceClassification(config=config)
model.to(torch_device)
model.train()
# freeze everything but the classification head
model.freeze_base_model()
input_values = input_values[:3]
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
loss = model(input_values, labels=labels).loss
self.parent.assertFalse(torch.isinf(loss).item())
loss.backward()
def check_labels_out_of_vocab(self, config, input_values, *args):
model = Wav2Vec2ForCTC(config)
model.to(torch_device)
@@ -246,7 +294,9 @@ class Wav2Vec2ModelTester:
@require_torch
class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForSequenceClassification, Wav2Vec2ForPreTraining)
if is_torch_available()
else ()
)
test_pruning = False
test_headmasking = False
@@ -267,9 +317,17 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_loss(*config_and_inputs)
def test_seq_classifier_loss_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_seq_classifier_loss(*config_and_inputs)
def test_ctc_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_training(*config_and_inputs)
def test_seq_classifier_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_seq_classifier_training(*config_and_inputs)
def test_labels_out_of_vocab(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -384,7 +442,9 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
@require_torch
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForSequenceClassification, Wav2Vec2ForPreTraining)
if is_torch_available()
else ()
)
test_pruning = False
test_headmasking = False
@@ -411,9 +471,17 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_loss(*config_and_inputs)
def test_seq_classifier_loss_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_seq_classifier_loss(*config_and_inputs)
def test_ctc_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_training(*config_and_inputs)
def test_seq_classifier_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_seq_classifier_training(*config_and_inputs)
def test_labels_out_of_vocab(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -691,6 +759,13 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
return ds["speech"][:num_samples]
def _load_superb(self, task, num_samples):
from datasets import load_dataset
ds = load_dataset("anton-l/superb_dummy", task, split="test")
return ds[:num_samples]
def test_inference_ctc_normal(self):
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model.to(torch_device)
@@ -795,7 +870,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
# fmt: off
expected_cosine_sim_masked = torch.tensor(
[0.7458, 0.7188, 0.6418, 0.3729, 0.3741, 0.3694, 0.3110, 0.2257, 0.4403, 0.5415, 0.3950, 0.3701, 0.8831,
0.8613, 0.5229, 0.6696, 0.7206, 0.7877, 0.6758, 0.8746, 0.6596, 0.6282, 0.6178, 0.5839, 0.5926, 0.6651,
0.4635, 0.6332, 0.6572, 0.8776, 0.4999, 0.7001, 0.7257, 0.5098, 0.6229, 0.4566, 0.5261, 0.6363, 0.5371,
0.6997],
device=torch_device,
)
# fmt: on
@@ -913,3 +991,92 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
expected_loss = 62.5170
self.assertTrue(abs(outputs.loss.item() - expected_loss) < 1e-3)
def test_inference_keyword_spotting(self):
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ks").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ks")
input_data = self._load_superb("ks", 4)
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
with torch.no_grad():
outputs = model(input_values, attention_mask=attention_mask)
predicted_logits, predicted_ids = torch.max(outputs.logits, dim=-1)
expected_labels = [7, 6, 10, 9]
# s3prl logits for the same batch
expected_logits = torch.tensor([6.1186, 11.8961, 10.2931, 6.0898], device=torch_device)
self.assertListEqual(predicted_ids.tolist(), expected_labels)
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2))
def test_inference_intent_classification(self):
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ic").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ic")
input_data = self._load_superb("ic", 4)
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
with torch.no_grad():
outputs = model(input_values, attention_mask=attention_mask)
predicted_logits_action, predicted_ids_action = torch.max(outputs.logits[:, :6], dim=-1)
predicted_logits_object, predicted_ids_object = torch.max(outputs.logits[:, 6:20], dim=-1)
predicted_logits_location, predicted_ids_location = torch.max(outputs.logits[:, 20:24], dim=-1)
expected_labels_action = [0, 0, 2, 3]
expected_logits_action = torch.tensor([0.4568, 11.0848, 1.6621, 9.3841], device=torch_device)
expected_labels_object = [3, 10, 3, 4]
expected_logits_object = torch.tensor([1.5322, 10.7094, 5.2469, 22.1318], device=torch_device)
expected_labels_location = [0, 0, 0, 1]
expected_logits_location = torch.tensor([1.5335, 6.5096, 10.5704, 11.0569], device=torch_device)
self.assertListEqual(predicted_ids_action.tolist(), expected_labels_action)
self.assertListEqual(predicted_ids_object.tolist(), expected_labels_object)
self.assertListEqual(predicted_ids_location.tolist(), expected_labels_location)
self.assertTrue(torch.allclose(predicted_logits_action, expected_logits_action, atol=1e-2))
self.assertTrue(torch.allclose(predicted_logits_object, expected_logits_object, atol=1e-2))
self.assertTrue(torch.allclose(predicted_logits_location, expected_logits_location, atol=1e-2))
def test_inference_speaker_identification(self):
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-sid").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-sid")
input_data = self._load_superb("si", 4)
output_logits = []
with torch.no_grad():
for example in input_data["speech"]:
input = processor(example, return_tensors="pt", padding=True)
output = model(input.input_values.to(torch_device), attention_mask=None)
output_logits.append(output.logits[0])
output_logits = torch.stack(output_logits)
predicted_logits, predicted_ids = torch.max(output_logits, dim=-1)
expected_labels = [251, 1, 1, 3]
# s3prl logits for the same batch
expected_logits = torch.tensor([37.5627, 71.6362, 64.2419, 31.7778], device=torch_device)
self.assertListEqual(predicted_ids.tolist(), expected_labels)
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2))
def test_inference_emotion_recognition(self):
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-er").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-er")
input_data = self._load_superb("er", 4)
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
with torch.no_grad():
outputs = model(input_values, attention_mask=attention_mask)
predicted_logits, predicted_ids = torch.max(outputs.logits, dim=-1)
expected_labels = [1, 1, 2, 2]
# s3prl logits for the same batch
expected_logits = torch.tensor([2.1722, 3.0779, 8.0287, 6.6797], device=torch_device)
self.assertListEqual(predicted_ids.tolist(), expected_labels)
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2))
@@ -122,6 +122,8 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"TFRagTokenForGeneration",
"Wav2Vec2ForCTC",
"HubertForCTC",
"Wav2Vec2ForSequenceClassification",
"HubertForSequenceClassification",
"XLMForQuestionAnswering",
"XLNetForQuestionAnswering",
"SeparableConv1D",
...