Unverified Commit 569743f5 authored by Kamil Akesbi, committed by GitHub

Add sdpa and fa2 to the Wav2Vec2 family. (#30121)



* add sdpa to wav2vec.
Co-authored-by: kamilakesbi <kamil@huggingface.co>
Co-authored-by: jp1924 <jp42maru@gmail.com>

* add fa2 to wav2vec2

* add tests

* fix attention_mask compatibility with fa2

* minor dtype fix

* replace fa2 slow test

* fix fa2 slow test

* apply code review + add fa2 batch test

* add sdpa and fa2 to hubert

* sdpa and fa2 to data2vec_audio

* sdpa and fa2 to Sew

* sdpa to unispeech + unispeech sat

* small fix

* attention mask in tests
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* add_speedup_benchmark_to_doc

---------
Co-authored-by: kamil@huggingface.co <kamil.akesbi@gmail.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
parent 367a0dbd
......@@ -44,6 +44,42 @@ This model was contributed by [patrickvonplaten](https://huggingface.co/patrickv
- Hubert model was fine-tuned using connectionist temporal classification (CTC) so the model output has to be decoded
using [`Wav2Vec2CTCTokenizer`].
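As a minimal sketch of that decoding step (assuming `logits` came from a `HubertForCTC` forward pass and `processor` is a `Wav2Vec2Processor` wrapping the matching CTC tokenizer):
```python
>>> import torch

>>> # greedy CTC decoding: pick the most likely token per frame, then let the
>>> # tokenizer collapse repeated tokens and blanks into text
>>> predicted_ids = torch.argmax(logits, dim=-1)
>>> transcription = processor.batch_decode(predicted_ids)
```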
## Using Flash Attention 2
Flash Attention 2 is a faster, optimized implementation of the model's attention mechanism.
### Installation
First, check whether your hardware is compatible with Flash Attention 2. The latest list of compatible hardware can be found in the [official documentation](https://github.com/Dao-AILab/flash-attention#installation-and-features). If your hardware is not compatible with Flash Attention 2, you can still benefit from attention kernel optimisations through Better Transformer support covered [here](https://huggingface.co/docs/transformers/main/en/model_doc/bark#using-better-transformer).
Next, [install](https://github.com/Dao-AILab/flash-attention#installation-and-features) the latest version of Flash Attention 2:
```bash
pip install -U flash-attn --no-build-isolation
```
### Usage
To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
```python
>>> import torch
>>> from transformers import HubertModel

>>> model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda")
```
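The speedup comparison below also includes SDPA; as a sketch under the same assumptions, only the `attn_implementation` value changes to load the SDPA variant instead:
```python
>>> import torch
>>> from transformers import HubertModel

>>> model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft", torch_dtype=torch.float16, attn_implementation="sdpa").to("cuda")
```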
### Expected speedups
Below is an expected speedup diagram comparing the pure inference time between the native implementation in transformers of the `facebook/hubert-large-ls960-ft` model and the flash-attention-2 and SDPA (scaled dot product attention) versions. We show the average speedup obtained on the `librispeech_asr` `clean` validation split:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/kamilakesbi/transformers_image_doc/resolve/main/data/Hubert_speedup.png">
</div>
## Resources
- [Audio classification task guide](../tasks/audio_classification)
......
......@@ -39,6 +39,42 @@ This model was contributed by [patrickvonplaten](https://huggingface.co/patrickv
- Wav2Vec2 model was trained using connectionist temporal classification (CTC) so the model output has to be decoded
using [`Wav2Vec2CTCTokenizer`].
## Using Flash Attention 2
Flash Attention 2 is a faster, optimized implementation of the model's attention mechanism.
### Installation
First, check whether your hardware is compatible with Flash Attention 2. The latest list of compatible hardware can be found in the [official documentation](https://github.com/Dao-AILab/flash-attention#installation-and-features). If your hardware is not compatible with Flash Attention 2, you can still benefit from attention kernel optimisations through Better Transformer support covered [here](https://huggingface.co/docs/transformers/main/en/model_doc/bark#using-better-transformer).
Next, [install](https://github.com/Dao-AILab/flash-attention#installation-and-features) the latest version of Flash Attention 2:
```bash
pip install -U flash-attn --no-build-isolation
```
### Usage
To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
```python
>>> import torch
>>> from transformers import Wav2Vec2Model

>>> model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda")
```
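To go from the loaded model to a transcription, here is a minimal end-to-end sketch; it assumes a CUDA device, a 1-D 16 kHz `waveform` array you provide, and uses the CTC head (`Wav2Vec2ForCTC`) on the same checkpoint:
```python
>>> import torch
>>> from transformers import AutoProcessor, Wav2Vec2ForCTC

>>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
>>> model = Wav2Vec2ForCTC.from_pretrained(
...     "facebook/wav2vec2-large-960h-lv60-self", torch_dtype=torch.float16, attn_implementation="flash_attention_2"
... ).to("cuda")

>>> # `waveform` is a 1-D 16 kHz audio array you provide
>>> inputs = processor(waveform, sampling_rate=16_000, return_tensors="pt").to("cuda")
>>> with torch.no_grad():
...     logits = model(inputs.input_values.to(torch.float16)).logits
>>> transcription = processor.batch_decode(torch.argmax(logits, dim=-1))
```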
### Expected speedups
Below is an expected speedup diagram comparing the pure inference time between the native implementation in transformers of the `facebook/wav2vec2-large-960h-lv60-self` model and the flash-attention-2 and SDPA (scaled dot product attention) versions. We show the average speedup obtained on the `librispeech_asr` `clean` validation split:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/kamilakesbi/transformers_image_doc/resolve/main/data/Wav2Vec2_speedup.png">
</div>
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Wav2Vec2. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
......
......@@ -70,6 +70,12 @@ FlashAttention-2 is currently supported for the following architectures:
* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
* [Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model)
* [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel)
* [Data2VecAudio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
* [SEW](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel)
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
* [UniSpeechSat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.
......@@ -203,6 +209,13 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
* [Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model)
* [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel)
* [Data2VecAudio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
* [SEW](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel)
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
* [UniSpeechSat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
<Tip>
......
This diff is collapsed.
......@@ -1515,6 +1515,8 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
quantized_features, codevector_perplexity = self.quantizer(
    extract_features, mask_time_indices=mask_time_indices
)
# cast the quantized features to the projection layer's dtype before projecting,
# so the pretraining forward pass also works when the model is loaded in fp16/bf16
quantized_features = quantized_features.to(self.project_q.weight.dtype)
quantized_features = self.project_q(quantized_features)

loss = contrastive_loss = diversity_loss = None
......
......@@ -25,6 +25,7 @@ import unittest
import numpy as np
from datasets import load_dataset
from pytest import mark
from transformers import Wav2Vec2Config, is_torch_available
from transformers.testing_utils import (
......@@ -33,9 +34,11 @@ from transformers.testing_utils import (
is_pt_flax_cross_test,
is_pyctcdecode_available,
is_torchaudio_available,
require_flash_attn,
require_pyctcdecode,
require_soundfile,
require_torch,
require_torch_gpu,
require_torchaudio,
run_test_in_subprocess,
slow,
......@@ -1995,3 +1998,52 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
        for lang in LANG_MAP.keys():
            assert run_model(lang) == TRANSCRIPTIONS[lang]

    @require_flash_attn
    @require_torch_gpu
    @mark.flash_attn_test
    def test_inference_ctc_fa2(self):
        model_fa = Wav2Vec2ForCTC.from_pretrained(
            "facebook/wav2vec2-base-960h", attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
        )
        model_fa.to(torch_device)
        processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
        input_speech = self._load_datasamples(1)

        input_values = processor(input_speech, return_tensors="pt").input_values.to(torch_device)

        with torch.no_grad():
            logits = model_fa(input_values.to(torch.bfloat16)).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        predicted_trans = processor.batch_decode(predicted_ids)

        EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"]
        self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)

    @require_flash_attn
    @require_torch_gpu
    @mark.flash_attn_test
    def test_inference_ctc_fa2_batched(self):
        model_fa = Wav2Vec2ForCTC.from_pretrained(
            "facebook/wav2vec2-base-960h", attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
        )
        model_fa.to(torch_device)

        processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)

        input_speech = self._load_datasamples(2)
        inputs = processor(input_speech, return_tensors="pt", padding=True, return_attention_mask=True)
        inputs = inputs.to(torch_device)

        with torch.no_grad():
            logits = model_fa(inputs.input_values.to(torch.bfloat16), attention_mask=inputs.attention_mask).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        predicted_trans = processor.batch_decode(predicted_ids)

        EXPECTED_TRANSCRIPTIONS = [
            "a man said to the universe sir i exist",
            "sweat covered brion's body trickling into the tight lowing cloth that was the only garment he wore",
        ]
        self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)