"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "002915aa2ad7d7826294fabd7ba4e6297772768c"
Unverified Commit 7c393181 authored by Sanchit Gandhi, committed by GitHub

[Whisper] Add model for audio classification (#21754)

* [Whisper] Add model for audio classification

* make fix-copies

* add to docs

* add docstring

* empty returns

* add code example

* switch to fleurs

* stick everything on one line
parent 9402788b
@@ -79,6 +79,11 @@ The original code can be found [here](https://github.com/openai/whisper).
[[autodoc]] WhisperForConditionalGeneration
    - forward

## WhisperForAudioClassification
[[autodoc]] WhisperForAudioClassification
- forward

## TFWhisperModel
......
@@ -28,7 +28,7 @@ The task illustrated in this tutorial is supported by the following model architectures:
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
[Audio Spectrogram Transformer](../model_doc/audio-spectrogram-transformer), [Data2VecAudio](../model_doc/data2vec-audio), [Hubert](../model_doc/hubert), [SEW](../model_doc/sew), [SEW-D](../model_doc/sew-d), [UniSpeech](../model_doc/unispeech), [UniSpeechSat](../model_doc/unispeech-sat), [Wav2Vec2](../model_doc/wav2vec2), [Wav2Vec2-Conformer](../model_doc/wav2vec2-conformer), [WavLM](../model_doc/wavlm), [Whisper](../model_doc/whisper)
<!--End of the generated tip-->
......
@@ -2575,6 +2575,7 @@ else:
    _import_structure["models.whisper"].extend(
        [
            "WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST",
            "WhisperForAudioClassification",
            "WhisperForConditionalGeneration",
            "WhisperModel",
            "WhisperPreTrainedModel",
@@ -5782,6 +5783,7 @@ if TYPE_CHECKING:
        )
        from .models.whisper import (
            WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST,
            WhisperForAudioClassification,
            WhisperForConditionalGeneration,
            WhisperModel,
            WhisperPreTrainedModel,
......
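With the class added to both the lazy `_import_structure` and the `TYPE_CHECKING` branch above, it becomes importable from the package root. A minimal sketch of what this registration enables (requires the torch backend):

```python
# Minimal sketch: after the registration above, both import paths resolve to the
# same class (requires torch to be installed).
import transformers
from transformers import WhisperForAudioClassification

assert transformers.WhisperForAudioClassification is WhisperForAudioClassification
```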
@@ -877,6 +877,7 @@ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
        ("wav2vec2", "Wav2Vec2ForSequenceClassification"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForSequenceClassification"),
        ("wavlm", "WavLMForSequenceClassification"),
        ("whisper", "WhisperForAudioClassification"),
    ]
)
......
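Registering the model in `MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES` is what lets the auto classes resolve a Whisper checkpoint for this task. A sketch, reusing the checkpoint from the docstring example later in this diff:

```python
# Sketch: the mapping entry above lets the auto class dispatch on the Whisper config type.
# The checkpoint name is the one used in the code example further down.
from transformers import AutoModelForAudioClassification

model = AutoModelForAudioClassification.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
print(type(model).__name__)  # WhisperForAudioClassification
```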
@@ -49,6 +49,7 @@ else:
        "WhisperForConditionalGeneration",
        "WhisperModel",
        "WhisperPreTrainedModel",
        "WhisperForAudioClassification",
    ]

try:
@@ -99,6 +100,7 @@ if TYPE_CHECKING:
    else:
        from .modeling_whisper import (
            WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST,
            WhisperForAudioClassification,
            WhisperForConditionalGeneration,
            WhisperModel,
            WhisperPreTrainedModel,
......
@@ -136,6 +136,12 @@ class WhisperConfig(PretrainedConfig):
        begin_suppress_tokens (`List[int]`, *optional*, defaults to `[220,50256]`):
            A list containing tokens that will be suppressed at the beginning of the sampling process. Initialized as
            the token for `" "` (`blank_token_id`) and the `eos_token_id`.
        use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
            Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
            instance of [`WhisperForAudioClassification`].
        classifier_proj_size (`int`, *optional*, defaults to 256):
            Dimensionality of the projection before token mean-pooling for classification. Only relevant when using an
            instance of [`WhisperForAudioClassification`].
        apply_spec_augment (`bool`, *optional*, defaults to `False`):
            Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
            [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
@@ -214,6 +220,8 @@ class WhisperConfig(PretrainedConfig):
        eos_token_id=50256,
        suppress_tokens=None,
        begin_suppress_tokens=[220, 50256],
        use_weighted_layer_sum=False,
        classifier_proj_size=256,
        apply_spec_augment=False,
        mask_time_prob=0.05,
        mask_time_length=10,
@@ -244,6 +252,11 @@ class WhisperConfig(PretrainedConfig):
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
        self.max_source_positions = max_source_positions
        self.max_target_positions = max_target_positions

        # Audio Classification-specific parameters. Feel free to ignore for other classes.
        self.classifier_proj_size = classifier_proj_size
        self.use_weighted_layer_sum = use_weighted_layer_sum

        # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
        self.apply_spec_augment = apply_spec_augment
        self.mask_time_prob = mask_time_prob
......
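The two new fields only affect the classification head. A short sketch of how they shape the layers (the tiny sizes are made up purely for illustration):

```python
# Sketch of the two new config knobs (tiny sizes chosen only for illustration).
from transformers import WhisperConfig, WhisperForAudioClassification

config = WhisperConfig(
    d_model=64,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    use_weighted_layer_sum=True,  # learn a softmax-weighted mix over all encoder layers
    classifier_proj_size=32,      # project d_model -> 32 before token mean-pooling
    num_labels=10,
)
model = WhisperForAudioClassification(config)
print(model.projector)            # Linear(in_features=64, out_features=32, bias=True)
print(model.classifier)           # Linear(in_features=32, out_features=10, bias=True)
print(model.layer_weights.shape)  # torch.Size([3]) -> 2 layers + input embeddings
```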
@@ -32,6 +32,7 @@ from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    SequenceClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
@@ -701,6 +702,33 @@ WHISPER_INPUTS_DOCSTRING = r"""
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
WHISPER_ENCODER_INPUTS_DOCSTRING = r"""
Args:
input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
hidden-states at the output of the last layer of the encoder.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

class WhisperEncoder(WhisperPreTrainedModel):
    """
@@ -1578,3 +1606,123 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
        for layer_past in past_key_values:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_past

@add_start_docstrings(
    """
    Whisper Encoder Model with a sequence classification head on top (a linear layer over the pooled output) for tasks
    like SUPERB Keyword Spotting.
    """,
    WHISPER_ENCODER_INPUTS_DOCSTRING,
)
class WhisperForAudioClassification(WhisperPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.encoder = WhisperEncoder(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
        self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()
    def freeze_encoder(self):
        """
        Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
        not be updated during training. Only the projection layers and classification head will be updated.
        """
        self.encoder._freeze_parameters()

    @add_start_docstrings_to_model_forward(WHISPER_ENCODER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
Example:
```python
>>> import torch
>>> from transformers import AutoFeatureExtractor, WhisperForAudioClassification
>>> from datasets import load_dataset
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
>>> model = WhisperForAudioClassification.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
>>> ds = load_dataset("google/fleurs", "all", split="validation", streaming=True)
>>> sample = next(iter(ds))
>>> inputs = feature_extractor(
... sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="pt"
... )
>>> input_features = inputs.input_features
>>> with torch.no_grad():
... logits = model(input_features).logits
>>> predicted_class_ids = torch.argmax(logits).item()
>>> predicted_label = model.config.id2label[predicted_class_ids]
>>> predicted_label
'af_za'
```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if self.config.use_weighted_layer_sum:
            # the weighted layer sum needs the per-layer hidden states, so force them on
            output_hidden_states = True

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_features,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        if self.config.use_weighted_layer_sum:
            # stack the hidden states of all layers (position 1 in both the plain tuple and the `ModelOutput`)
            hidden_states = torch.stack(encoder_outputs[1], dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = encoder_outputs[0]

        hidden_states = self.projector(hidden_states)
        pooled_output = hidden_states.mean(dim=1)

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + encoder_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
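Numerically, the weighted layer sum is a softmax-normalized convex combination of the stacked per-layer hidden states, followed by mean pooling over time. A standalone sketch with made-up shapes:

```python
# Standalone sketch of the weighted-layer-sum pooling used above (shapes are illustrative).
import torch
import torch.nn as nn

batch, num_layers, seq_len, hidden = 2, 5, 10, 16
stacked = torch.randn(batch, num_layers, seq_len, hidden)          # stacked per-layer hidden states
layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)  # learned; initialized uniform

norm_weights = nn.functional.softmax(layer_weights, dim=-1)        # non-negative, sums to 1
mixed = (stacked * norm_weights.view(-1, 1, 1)).sum(dim=1)         # (batch, seq_len, hidden)
pooled = mixed.mean(dim=1)                                         # token mean pooling -> (batch, hidden)
print(mixed.shape, pooled.shape)
```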
@@ -6797,6 +6797,13 @@ class WavLMPreTrainedModel(metaclass=DummyObject):
WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = None


class WhisperForAudioClassification(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class WhisperForConditionalGeneration(metaclass=DummyObject):
    _backends = ["torch"]
......
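The dummy mirrors the real class name so that `from transformers import WhisperForAudioClassification` still succeeds in a torch-free environment; the failure is deferred to instantiation. A sketch of that behavior:

```python
# Sketch: in an environment without torch, the import resolves to the dummy above,
# and instantiating it raises an ImportError pointing at the missing backend.
from transformers import WhisperForAudioClassification

try:
    WhisperForAudioClassification()
except ImportError as err:
    print(err)  # explains that the torch backend is required
```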
@@ -43,6 +43,7 @@ if is_torch_available():
    from transformers import (
        WhisperFeatureExtractor,
        WhisperForAudioClassification,
        WhisperForConditionalGeneration,
        WhisperModel,
        WhisperProcessor,
@@ -1372,3 +1373,191 @@ class WhisperModelIntegrationTests(unittest.TestCase):
        )
        # fmt: on
        self.assertTrue(torch.allclose(logits[0][0, 0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))

def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
    if head_mask is None:
        head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
    return {"input_features": input_features, "head_mask": head_mask}

@require_torch
class WhisperEncoderModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=60,
        is_training=True,
        use_labels=True,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=4,
        input_channels=1,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=20,
        max_source_positions=30,
        num_mel_bins=80,
        num_conv_layers=1,
        suppress_tokens=None,
        begin_suppress_tokens=None,
        classifier_proj_size=4,
        num_labels=2,
        is_encoder_decoder=False,
        is_decoder=False,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_labels = use_labels
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.input_channels = input_channels
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.num_mel_bins = num_mel_bins
        self.max_position_embeddings = max_position_embeddings
        self.max_source_positions = max_source_positions
        self.num_conv_layers = num_conv_layers
        self.suppress_tokens = suppress_tokens
        self.begin_suppress_tokens = begin_suppress_tokens
        self.classifier_proj_size = classifier_proj_size
        self.num_labels = num_labels
        self.is_encoder_decoder = is_encoder_decoder
        self.is_decoder = is_decoder
    def get_config(self):
        return WhisperConfig(
            d_model=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            input_channels=self.input_channels,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            max_source_positions=self.max_source_positions,
            decoder_ffn_dim=self.hidden_size,
            encoder_ffn_dim=self.hidden_size,
            suppress_tokens=self.suppress_tokens,
            begin_suppress_tokens=self.begin_suppress_tokens,
            classifier_proj_size=self.classifier_proj_size,
            num_labels=self.num_labels,
            is_encoder_decoder=self.is_encoder_decoder,
            is_decoder=self.is_decoder,
        )
    def prepare_config_and_inputs(self):
        input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length])

        config = self.get_config()
        inputs_dict = prepare_whisper_encoder_inputs_dict(
            config,
            input_features=input_features,
        )
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict
    def get_subsampled_output_lengths(self, input_lengths):
        """
        Computes the output length of the convolutional layers
        """
        for i in range(self.num_conv_layers):
            # each stride-2 conv layer roughly halves the sequence length, e.g. 60 -> 30
            input_lengths = (input_lengths - 1) // 2 + 1

        return input_lengths

    @property
    def encoder_seq_length(self):
        return self.get_subsampled_output_lengths(self.seq_length)
    def create_and_check_model_forward(self, config, inputs_dict, freeze_encoder=False):
        model = WhisperForAudioClassification(config=config).to(torch_device).eval()

        if freeze_encoder:
            model.freeze_encoder()

        input_features = inputs_dict["input_features"]

        # forward pass: the classifier returns one logit per label for each batch element
        logits = model(input_features).logits
        self.parent.assertEqual(logits.shape, (self.batch_size, self.num_labels))

@require_torch
class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    all_model_classes = (WhisperForAudioClassification,) if is_torch_available() else ()
    is_encoder_decoder = False
    fx_compatible = False
    test_pruning = False
    test_missing_keys = False

    input_name = "input_features"

    def setUp(self):
        self.model_tester = WhisperEncoderModelTester(self)
        self.config_tester = ConfigTester(self, config_class=WhisperConfig)
        self.maxDiff = 3000

    def test_config(self):
        self.config_tester.run_common_tests()
    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["input_features", "head_mask", "encoder_outputs"]
            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    # input embeds is meaningless for an encoder-only acoustic model
    def test_inputs_embeds(self):
        pass
    # the equivalent test is passing the encoder outputs directly to the model
    def test_encoder_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))

            with torch.no_grad():
                outputs = model(**inputs)[0]

            input_ids = inputs["input_features"]
            del inputs["input_features"]

            encoder = model.encoder

            with torch.no_grad():
                inputs["encoder_outputs"] = encoder(input_ids)
                outputs_embeds = model(**inputs)[0]

            self.assertTrue((outputs_embeds == outputs).all())
    # WhisperEncoder has no inputs_embeds and thus the `get_input_embeddings` fn is not implemented
    def test_model_common_attributes(self):
        pass

    # WhisperEncoder cannot resize token embeddings since it has no token embeddings
    def test_resize_tokens_embeddings(self):
        pass
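These tests live alongside the existing Whisper suite; a sketch of invoking just the new encoder-only test class from Python (the file path is assumed from the repository layout of this period, not taken from the diff):

```python
# Sketch: run only the new encoder-only tests (the file path is an assumption).
import pytest

pytest.main(["-q", "tests/models/whisper/test_modeling_whisper.py", "-k", "WhisperEncoderModelTest"])
```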