Unverified Commit c8545d2a authored by bofeng huang, committed by GitHub

[Whisper] Add SpecAugment (#21298)



* Return and rescale attention_mask

* Add SpecAugment to Whisper modeling

* Fix test

* Update docstring

* Add SpecAug related parameters to model config

* Add the _mask_input_features function to doc

* Fix quality

* Apply suggestions from code review
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Remove dev comments

* Add test

* Resolve conflict

* feat: mask {feature, time} prob fast tests

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 75bd49ff
...@@ -72,6 +72,7 @@ The original code can be found [here](https://github.com/openai/whisper).
[[autodoc]] WhisperModel
- forward
- _mask_input_features
## WhisperForConditionalGeneration
......
...@@ -136,6 +136,35 @@ class WhisperConfig(PretrainedConfig):
begin_suppress_tokens (`List[int]`, *optional*, defaults to `[220,50256]`):
A list containing tokens that will be suppressed at the beginning of the sampling process. Initialized as
the token for `" "` (`blank_token_id`) and the `eos_token_id`
apply_spec_augment (`bool`, *optional*, defaults to `False`):
Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
[SpecAugment: A Simple Data Augmentation Method for Automatic Speech
Recognition](https://arxiv.org/abs/1904.08779).
mask_time_prob (`float`, *optional*, defaults to 0.05):
Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
procedure generates `mask_time_prob*len(time_axis)/mask_time_length` independent masks over the axis. If
reasoning from the probability of each feature vector being chosen as the start of the vector span to be
masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
actual percentage of masked vectors. This is only relevant if `apply_spec_augment == True`.
mask_time_length (`int`, *optional*, defaults to 10):
Length of vector span along the time axis.
mask_time_min_masks (`int`, *optional*, defaults to 2):
The minimum number of masks of length `mask_time_length` generated along the time axis, irrespective of
`mask_time_prob`. Only relevant if `mask_time_prob*len(time_axis)/mask_time_length <
mask_time_min_masks`.
mask_feature_prob (`float`, *optional*, defaults to 0.0):
Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
masking procedure generates `mask_feature_prob*len(feature_axis)/mask_feature_length` independent masks
over the axis. If reasoning from the probability of each feature vector being chosen as the start of the
vector span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that
overlap may decrease the actual percentage of masked vectors. This is only relevant if
`apply_spec_augment == True`.
mask_feature_length (`int`, *optional*, defaults to 10):
Length of vector span along the feature axis.
mask_feature_min_masks (`int`, *optional*, defaults to 0):
The minimum number of masks of length `mask_feature_length` generated along the feature axis, irrespective
of `mask_feature_prob`. Only relevant if
`mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`.
Example:
...@@ -185,6 +214,13 @@ class WhisperConfig(PretrainedConfig):
eos_token_id=50256,
suppress_tokens=None,
begin_suppress_tokens=[220, 50256],
apply_spec_augment=False,
mask_time_prob=0.05,
mask_time_length=10,
mask_time_min_masks=2,
mask_feature_prob=0.0,
mask_feature_length=10,
mask_feature_min_masks=0,
**kwargs,
):
self.vocab_size = vocab_size
...@@ -208,6 +244,14 @@ class WhisperConfig(PretrainedConfig):
self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
self.max_source_positions = max_source_positions
self.max_target_positions = max_target_positions
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
self.apply_spec_augment = apply_spec_augment
self.mask_time_prob = mask_time_prob
self.mask_time_length = mask_time_length
self.mask_time_min_masks = mask_time_min_masks
self.mask_feature_prob = mask_feature_prob
self.mask_feature_length = mask_feature_length
self.mask_feature_min_masks = mask_feature_min_masks
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
......
...@@ -307,6 +307,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
max_length=max_length if max_length else self.n_samples,
truncation=truncation,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
)
# make sure list is in array format
input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
...@@ -318,6 +319,10 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
else:
padded_inputs["input_features"] = input_features
if return_attention_mask:
# rescale attention mask from sample level (480000 samples = 30 s at 16 kHz) to feature level (3000 frames)
padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]
if return_tensors is not None:
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
......
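The rescaling works because the log-mel frames are computed with a fixed hop: taking every `hop_length`-th entry of the sample-level mask yields exactly one mask value per frame. A small sketch of the idea; the sizes assume Whisper's defaults (16 kHz audio padded to 30 s, `hop_length=160`) and the speech duration is made up:

```python
import numpy as np

hop_length = 160        # Whisper's STFT hop: one mel frame per 160 samples
n_samples = 480000      # 30 s of 16 kHz audio after padding
sample_mask = np.zeros(n_samples, dtype=np.int64)
sample_mask[:93680] = 1  # e.g. ~5.86 s of real speech, the rest is padding

# one mask entry per frame: shape goes from (480000,) to (3000,)
frame_mask = sample_mask[::hop_length]
print(frame_mask.shape)  # (3000,)
print(frame_mask.sum())  # 586 frames of real speech (93680 / 160, rounded up)
```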
...@@ -19,6 +19,7 @@ import math
import random
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
...@@ -97,6 +98,126 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.
Args:
shape: The shape for which to compute masks. This should be a tuple of size 2 where
the first element is the batch size and the second element is the length of the axis to span.
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
independently generated mask spans of length `mask_length` is computed by
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
actual percentage will be smaller.
mask_length: size of the mask
min_masks: minimum number of masked spans
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
each batch dimension.
"""
batch_size, sequence_length = shape
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
f" and `sequence_length`: {sequence_length}`"
)
# epsilon is used for probabilistic rounding
epsilon = np.random.rand(1).item()
def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# make sure num masked span <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length
# make sure num_masked span is also <= input_length - (mask_length - 1)
if input_length - (mask_length - 1) < num_masked_span:
num_masked_span = max(input_length - (mask_length - 1), 0)
return num_masked_span
# compute number of masked spans in batch
input_lengths = (
attention_mask.sum(-1).detach().tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
# SpecAugment mask to fill
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
spec_aug_mask_idxs = []
max_num_masked_span = compute_num_masked_span(sequence_length)
if max_num_masked_span == 0:
return spec_aug_mask
for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)
# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
# pick first sampled index that will serve as a dummy index to pad vector
# to ensure same dimension for all batches due to probabilistic rounding
# Picking first sample just pads those vectors twice.
if len(spec_aug_mask_idx) == 0:
# this case can only happen if `input_length` is strictly smaller than
# `sequence_length` in which case the last token has to be a padding
# token which we can use as a dummy mask id
dummy_mask_idx = sequence_length - 1
else:
dummy_mask_idx = spec_aug_mask_idx[0]
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# expand masked indices to masked spans
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
# add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# ensure that we cannot have indices larger than sequence_length
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
return spec_aug_mask
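To make the helper's behavior concrete, here is a usage sketch with illustrative shapes and probabilities: each row gets roughly `mask_prob * input_length / mask_length` spans, and right-padded rows get proportionally fewer:

```python
import numpy as np
import torch

# batch of 2 sequences of length 100; the second row is half padding
mask = _compute_mask_indices(
    shape=(2, 100),
    mask_prob=0.2,
    mask_length=10,
    attention_mask=torch.tensor([[1] * 100, [1] * 50 + [0] * 50]),
    min_masks=1,
)
print(mask.shape)         # (2, 100), dtype=bool
print(mask.sum(axis=-1))  # e.g. [20 10]; overlapping spans can lower these counts
```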
class WhisperPositionalEmbedding(nn.Embedding):
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
super().__init__(num_positions, embedding_dim)
...@@ -503,6 +624,14 @@ WHISPER_INPUTS_DOCSTRING = r"""
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing *SpecAugment* data augmentation on padding token indices. Mask values selected in
`[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
...@@ -999,6 +1128,55 @@ class WhisperModel(WhisperPreTrainedModel):
"""
self.encoder._freeze_parameters()
def _mask_input_features(
self,
input_features: torch.FloatTensor,
attention_mask: Optional[torch.LongTensor] = None,
):
"""
Masks extracted features along time axis and/or along feature axis according to
[SpecAugment](https://arxiv.org/abs/1904.08779).
"""
# `config.apply_spec_augment` can set masking to False
if not getattr(self.config, "apply_spec_augment", True):
return input_features
# generate indices & apply SpecAugment along time axis
batch_size, hidden_size, sequence_length = input_features.size()
if self.config.mask_time_prob > 0 and self.training:
# generate indices & apply SpecAugment along time axis
mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length),
mask_prob=self.config.mask_time_prob,
mask_length=self.config.mask_time_length,
attention_mask=attention_mask,
min_masks=self.config.mask_time_min_masks,
)
mask_time_indices = torch.tensor(mask_time_indices, device=input_features.device, dtype=torch.bool)
mask_time_indices = mask_time_indices[:, None].expand(-1, hidden_size, -1)
input_features[mask_time_indices] = 0
if self.config.mask_feature_prob > 0 and self.training:
# generate indices & apply SpecAugment along feature axis
mask_feature_indices = _compute_mask_indices(
(batch_size, hidden_size),
mask_prob=self.config.mask_feature_prob,
mask_length=self.config.mask_feature_length,
min_masks=self.config.mask_feature_min_masks,
)
mask_feature_indices = torch.tensor(mask_feature_indices, device=input_features.device, dtype=torch.bool)
input_features[mask_feature_indices] = 0
return input_features
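Two gates must both be open for any masking to happen: `config.apply_spec_augment` and `model.training`. A quick behavioral sketch (it calls the private helper directly, which you would not do in real code; note the masking is in-place, hence the `clone()` calls):

```python
import torch
from transformers import WhisperModel

model = WhisperModel.from_pretrained("openai/whisper-tiny", apply_spec_augment=True)

# input_features: (batch, n_mels, frames); whisper-tiny expects (1, 80, 3000)
features = torch.randn(1, 80, 3000)

model.eval()
out = model._mask_input_features(features.clone())
assert torch.equal(out, features)  # eval mode: SpecAugment is skipped entirely

model.train()
out = model._mask_input_features(features.clone())
print((out == 0).any())  # tensor(True) with high probability: time spans zeroed across all mel bins
```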
@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_features: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
...@@ -1044,6 +1217,8 @@ class WhisperModel(WhisperPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if encoder_outputs is None:
input_features = self._mask_input_features(input_features, attention_mask=attention_mask)
encoder_outputs = self.encoder(
input_features,
head_mask=head_mask,
...@@ -1139,7 +1314,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_features: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
...@@ -1193,6 +1369,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
outputs = self.model(
input_features,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
......
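Putting the pieces together, a training step with SpecAugment looks roughly like this. A hedged end-to-end sketch: the audio array and label ids are placeholders, and the frame-level `attention_mask` produced by the feature extractor keeps SpecAugment off the padded frames:

```python
import numpy as np
import torch
from transformers import WhisperFeatureExtractor, WhisperForConditionalGeneration

feature_extractor = WhisperFeatureExtractor()
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny", apply_spec_augment=True
)
model.train()

audio = np.random.randn(16000 * 5)  # placeholder: 5 s of 16 kHz audio
inputs = feature_extractor(
    audio, sampling_rate=16000, return_tensors="pt", return_attention_mask=True
)

labels = torch.tensor([[50258, 50259, 50359, 50257]])  # placeholder token ids
outputs = model(
    input_features=inputs.input_features,
    attention_mask=inputs.attention_mask,  # SpecAugment skips the padded frames
    labels=labels,
)
outputs.loss.backward()
```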
...@@ -383,6 +383,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
expected_arg_names = [
"input_features",
"attention_mask",
"decoder_input_ids", "decoder_input_ids",
"decoder_attention_mask", "decoder_attention_mask",
] ]
...@@ -909,6 +910,34 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
def test_mask_feature_prob(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.mask_feature_prob = 0.2
config.mask_feature_length = 2
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.train()
# forward pass
encoder_last_hidden_state = model(**input_dict).encoder_last_hidden_state
self.assertEqual(encoder_last_hidden_state.shape, (13, 30, 16))
def test_mask_time_prob(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.mask_time_prob = 0.2
config.mask_time_length = 2
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.train()
# forward pass
encoder_last_hidden_state = model(**input_dict).encoder_last_hidden_state
self.assertEqual(encoder_last_hidden_state.shape, (13, 30, 16))
@require_torch
@require_torchaudio
...@@ -1289,3 +1318,38 @@ class WhisperModelIntegrationTests(unittest.TestCase):
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
@slow
def test_tiny_specaugment_librispeech(self):
torch_device = "cpu"
set_seed(0)
# Apply SpecAugment
model = WhisperModel.from_pretrained("openai/whisper-tiny", apply_spec_augment=True)
# Set model to training mode to enable SpecAugment
model.train()
model.to(torch_device)
input_speech = self._load_datasamples(1)
feature_extractor = WhisperFeatureExtractor()
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
with torch.no_grad():
logits = model(
input_features,
decoder_input_ids=torch.tensor([[50258, 50259, 50359]]),
output_hidden_states=False,
output_attentions=False,
return_dict=False,
use_cache=False,
)
# fmt: off
EXPECTED_LOGITS = torch.tensor(
[
0.9362, -4.7105, 5.0879, 3.9642, 1.0013, -6.0096, 4.7285, -3.1847,
-0.8648, 1.9631, 6.2653, 3.6936, 0.3575, -4.5818, 3.0564, 7.8712,
2.9951, 0.6848, 9.9497, -2.6638, 1.1571, -6.8546, -1.4333, -7.7584,
1.1200, 3.9030, 4.4655, -4.4919, -1.1703, 9.6241
]
)
# fmt: on
self.assertTrue(torch.allclose(logits[0][0, 0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))