Unverified Commit 80377eb0 authored by fxmarty, committed by GitHub

F.scaled_dot_product_attention support (#26572)



* add sdpa

* wip

* cleaning

* add ref

* yet more cleaning

* and more :)

* wip llama

* working llama

* add output_attentions=True support

* bigcode sdpa support

* fixes

* gpt-bigcode support, require torch>=2.1.1

* add falcon support

* fix conflicts falcon

* style

* fix attention_mask definition

* remove output_attentions from attnmaskconverter

* support whisper without removing any Copied from statement

* fix mbart default to eager renaming

* fix typo in falcon

* fix is_causal in SDPA

* check is_flash_attn_2_available in the models init as well in case the model is not initialized through from_pretrained

* add warnings when falling back on the manual implementation

* precise doc

* wip replace _flash_attn_enabled by config.attn_implementation

* fix typo

* add tests

* style

* add a copy.deepcopy on the config in from_pretrained, as we do not want to modify it inplace

* obey config.attn_implementation if a config is passed in from_pretrained

* fix is_torch_sdpa_available when torch is not installed

* remove dead code

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/bart/modeling_bart.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* remove duplicate pretraining_tp code

* add dropout in llama

* precise comment on attn_mask

* add fmt: off for _unmask_unattended docstring

* precise num_masks comment

* nuke pretraining_tp in LlamaSDPAAttention following Arthur's suggestion

* cleanup modeling_utils

* backward compatibility

* fix style as requested

* style

* improve documentation

* test pass

* style

* add _unmask_unattended tests

* skip meaningless tests for idefics

* hard_check SDPA requirements when specifically requested

* standardize the use of XXX_ATTENTION_CLASSES

* fix SDPA bug with mem-efficient backend on CUDA when using fp32

* fix test

* rely on SDPA is_causal parameter to handle the causal mask in some cases

* fix FALCON_ATTENTION_CLASSES

* remove _flash_attn_2_enabled occurrences

* fix test

* add OPT to the list of supported flash models

* improve test

* properly test on different SDPA backends, on different dtypes & properly handle separately the pad tokens in the test

* remove remaining _flash_attn_2_enabled occurrence

* Update src/transformers/modeling_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/perf_infer_gpu_one.md
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* remove use_attn_implementation

* fix docstring & slight bug

* make attn_implementation internal (_attn_implementation)

* typos

* fix tests

* deprecate use_flash_attention_2=True

* fix test

* add back llama that was removed by mistake

* fix tests

* remove _flash_attn_2_enabled occurrences bis

* add check & test that passed attn_implementation is valid

* fix falcon torchscript export

* fix device of mask in tests

* add tip about torch.jit.trace and move bt doc below sdpa

* fix parameterized.expand order

* move tests from test_modeling_attn_mask_utils to test_modeling_utils as a relevant test class is already there

* update sdpaattention class with the new cache

* Update src/transformers/configuration_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/bark/modeling_bark.py

* address review comments

* WIP torch.jit.trace fix. left: test both eager & sdpa

* add test for torch.jit.trace for both eager/sdpa

* fix falcon with torch==2.0 that needs to use sdpa

* fix doc

* hopefully last fix

* fix key_value_length that has no default now in mask converter

* is it flaky?

* fix speculative decoding bug

* tests do pass

* fix following #27907

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent ce0bbd51
@@ -441,7 +441,7 @@ flush()
```
For comparison, let's run the same function, but enable Flash Attention instead.
To do so, we convert the model to [BetterTransformer](https://huggingface.co/docs/optimum/bettertransformer/overview), which enables PyTorch's [SDPA self-attention](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention), which in turn is able to use Flash Attention.
```python
model.to_bettertransformer()
...
@@ -83,10 +83,10 @@ pip install -U flash-attn --no-build-isolation
##### Usage
To load a model using Flash Attention 2, we can pass the `attn_implementation="flash_attention_2"` flag to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
```python
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device)
```
##### Performance comparison
@@ -114,7 +114,7 @@ import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
# load in fp16 and use Flash Attention 2
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device)
# enable CPU offload
model.enable_cpu_offload()
...
@@ -153,7 +153,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> device = "cuda" # the device to load the model onto
>>> tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
>>> model = AutoModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> text = "Replace me by any text you'd like."
...
@@ -59,7 +59,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto
>>> model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = AutoTokenizer.from_pretrained("bigcode/gpt_bigcode-santacoder")
>>> prompt = "def hello_world():"
...
@@ -67,7 +67,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto
>>> model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
>>> prompt = "def hello_world():"
...
@@ -77,12 +77,12 @@ pip install -U flash-attn --no-build-isolation
### Usage
To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
```python
>>> from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device)
...
```
...
@@ -99,7 +99,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto
>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
>>> prompt = "My favourite condiment is"
...
@@ -80,7 +80,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import OPTForCausalLM, GPT2Tokenizer
>>> device = "cuda" # the device to load the model onto
>>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
>>> prompt = ("A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the "
...
@@ -111,7 +111,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import PhiForCausalLM, AutoTokenizer
>>> # define the model and tokenizer and push the model and tokens to the GPU.
>>> model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda")
>>> tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev")
>>> # feel free to change the prompt to your liking.
@@ -163,4 +163,4 @@ Below is an expected speedup diagram that compares pure inference time between t
- forward
</pt>
</frameworkcontent>
\ No newline at end of file
@@ -36,13 +36,29 @@ FlashAttention-2 is experimental and may change considerably in future versions.
1. additionally parallelizing the attention computation over sequence length
2. partitioning the work between GPU threads to reduce communication and shared memory reads/writes between them
FlashAttention-2 is currently supported for the following architectures:
* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel)
* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.
Before you begin, make sure you have FlashAttention-2 installed. For NVIDIA GPUs, the library is installable through pip: `pip install flash-attn --no-build-isolation`. We strongly suggest referring to the [detailed installation instructions](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features).
FlashAttention-2 is also supported on AMD GPUs, with the current support limited to **Instinct MI210 and Instinct MI250**. We strongly suggest using the following [Dockerfile](https://github.com/huggingface/optimum-amd/tree/main/docker/transformers-pytorch-amd-gpu-flash/Dockerfile) to use FlashAttention-2 on AMD GPUs.
To enable FlashAttention-2, pass the argument `attn_implementation="flash_attention_2"` to [`~AutoModelForCausalLM.from_pretrained`]:
```python
import torch
@@ -54,13 +70,15 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
```
<Tip>
FlashAttention-2 can only be used when the model's dtype is `fp16` or `bf16`. Make sure to cast your model to the appropriate dtype and load it on a supported device before using FlashAttention-2.
Note that `use_flash_attention_2=True` can also be used to enable Flash Attention 2, but is deprecated in favor of `attn_implementation="flash_attention_2"`.
</Tip>
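As a minimal sketch of the deprecation path mentioned in the tip above (reusing a checkpoint from the examples on this page), both calls end up with the same attention implementation, but the first one logs a deprecation warning:

```python
import torch
from transformers import AutoModelForCausalLM

# Deprecated: still accepted, but logs a warning pointing to attn_implementation.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, use_flash_attention_2=True
)

# Preferred going forward.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, attn_implementation="flash_attention_2"
)
```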
@@ -77,14 +95,14 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_8bit=True,
    attn_implementation="flash_attention_2",
)
# load in 4bit
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_4bit=True,
    attn_implementation="flash_attention_2",
)
```
@@ -124,41 +142,21 @@ FlashAttention is more memory efficient, meaning you can train on much larger se
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-2-large-seqlen-padding.png">
</div>
## FlashAttention and memory-efficient attention through PyTorch's scaled_dot_product_attention

PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers, and is used by default for `torch>=2.1.1` when an implementation is available.

For now, Transformers supports inference and training through SDPA for the following architectures:
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
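As a small usage sketch, SDPA can also be requested explicitly through the same `attn_implementation` argument instead of relying on the default dispatch (the Whisper checkpoint below is illustrative; any architecture from the list above works the same way):

```python
import torch
from transformers import AutoModelForSpeechSeq2Seq

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    "openai/whisper-tiny", torch_dtype=torch.float16, attn_implementation="sdpa"
)
```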
Note that FlashAttention can only be used for models with the `fp16` or `bf16` torch dtype, so make sure to cast your model to the appropriate dtype before using it.
By default, `torch.nn.functional.scaled_dot_product_attention` selects the most performant kernel available, but to check whether a backend is available in a given setting (hardware, problem size), you can use [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager:
```diff
import torch
@@ -187,6 +185,43 @@ RuntimeError: No available kernel. Aborting execution.
pip3 install -U --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118
```
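For reference, a minimal sketch of the `torch.backends.cuda.sdp_kernel` usage discussed above (the full snippet is elided in this diff); restricting SDPA to the FlashAttention backend surfaces the `RuntimeError: No available kernel` message when that backend cannot handle the inputs:

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 128, 64, dtype=torch.float16, device="cuda")
k = torch.randn(2, 8, 128, 64, dtype=torch.float16, device="cuda")
v = torch.randn(2, 8, 128, 64, dtype=torch.float16, device="cuda")

# Only allow the FlashAttention backend.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```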
## BetterTransformer
<Tip warning={true}>
Some BetterTransformer features are being upstreamed into Transformers, with native `torch.nn.functional.scaled_dot_product_attention` support by default. BetterTransformer still has wider coverage than the Transformers SDPA integration, but you can expect more and more architectures to natively support SDPA in Transformers.
</Tip>
<Tip>
Check out our benchmarks with BetterTransformer and scaled dot product attention in the [Out of the box acceleration and memory savings of 🤗 decoder models with PyTorch 2.0](https://pytorch.org/blog/out-of-the-box-acceleration/) and learn more about the fastpath execution in the [BetterTransformer](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2) blog post.
</Tip>
BetterTransformer accelerates inference with its fastpath (native PyTorch specialized implementation of Transformer functions) execution. The two optimizations in the fastpath execution are:
1. fusion, which combines multiple sequential operations into a single "kernel" to reduce the number of computation steps
2. skipping the inherent sparsity of padding tokens to avoid unnecessary computation with nested tensors
BetterTransformer also converts all attention operations to use the more memory-efficient [scaled dot product attention (SDPA)](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention), and it calls optimized kernels like [FlashAttention](https://huggingface.co/papers/2205.14135) under the hood.
Before you start, make sure you have 🤗 Optimum [installed](https://huggingface.co/docs/optimum/installation).
Then you can enable BetterTransformer with the [`PreTrainedModel.to_bettertransformer`] method:
```python
model = model.to_bettertransformer()
```
You can return the original Transformers model with the [`~PreTrainedModel.reverse_bettertransformer`] method. You should use this before saving your model to use the canonical Transformers modeling:
```py
model = model.reverse_bettertransformer()
model.save_pretrained("saved_model")
```
## bitsandbytes
bitsandbytes is a quantization library that includes support for 4-bit and 8-bit quantization. Quantization reduces your model size compared to its native full precision version, making it easier to fit large models onto GPUs with limited memory.
...
@@ -82,7 +82,7 @@ AWQ quantization can also be combined with [FlashAttention-2](perf_infer_gpu_one
```py
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", attn_implementation="flash_attention_2", device_map="cuda:0")
```
...
@@ -44,7 +44,7 @@ Flash Attention 2 can only be used when the model's dtype is `fp16` or `bf16`
### Quick usage
To enable Flash Attention 2 for a model, add `attn_implementation="flash_attention_2"` to the `from_pretrained` arguments.
```python
@@ -57,7 +57,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
```
@@ -114,7 +114,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_8bit=True,
    attn_implementation="flash_attention_2",
)
```
@@ -132,7 +132,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_4bit=True,
    attn_implementation="flash_attention_2",
)
```
@@ -151,7 +151,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_4bit=True,
    attn_implementation="flash_attention_2",
)
lora_config = LoraConfig(
...
@@ -66,12 +66,12 @@ model = AutoModelForCausalLM.from_pretrained(model_id).to("cuda:0")
### Combining AWQ and Flash Attention
You can combine AWQ quantization with Flash Attention to get a model that is both quantized and faster. Simply load the model with `from_pretrained` and pass the `attn_implementation="flash_attention_2"` argument.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", attn_implementation="flash_attention_2", device_map="cuda:0")
```
### Benchmarks
...
@@ -236,6 +236,8 @@ class PretrainedConfig(PushToHubMixin):
This attribute is currently not being used during model loading time, but this may change in the future
versions. But we can already start preparing for the future by saving the dtype with save_pretrained.
attn_implementation (`str`, *optional*):
The attention implementation to use in the model. Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (attention using [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (attention using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
> TensorFlow specific parameters
@@ -374,6 +376,9 @@ class PretrainedConfig(PushToHubMixin):
# Config hash
self._commit_hash = kwargs.pop("_commit_hash", None)
# Attention implementation to use, if relevant.
self._attn_implementation_internal = kwargs.pop("attn_implementation", None)
# Drop the transformers version info
self.transformers_version = kwargs.pop("transformers_version", None)
@@ -422,6 +427,22 @@ class PretrainedConfig(PushToHubMixin):
self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
@property
def _attn_implementation(self):
# This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
if hasattr(self, "_attn_implementation_internal"):
if self._attn_implementation_internal is None:
# `config.attn_implementation` should never be None, for backward compatibility.
return "eager"
else:
return self._attn_implementation_internal
else:
return "eager"
@_attn_implementation.setter
def _attn_implementation(self, value):
self._attn_implementation_internal = value
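A quick sketch of how the new private property behaves (direct `PretrainedConfig` instantiation is used here only for illustration):

```python
from transformers import PretrainedConfig

config = PretrainedConfig()
print(config._attn_implementation)  # "eager": the property never returns None, for backward compatibility

config._attn_implementation = "sdpa"
print(config._attn_implementation)  # "sdpa", stored in config._attn_implementation_internal
```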
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
@@ -747,6 +768,9 @@ class PretrainedConfig(PushToHubMixin):
if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
kwargs["_commit_hash"] = config_dict["_commit_hash"]
# We remove it from kwargs so that it does not appear in `return_unused_kwargs`.
config_dict["attn_implementation"] = kwargs.pop("attn_implementation", None)
config = cls(**config_dict)
if hasattr(config, "pruned_heads"):
@@ -861,8 +885,8 @@ class PretrainedConfig(PushToHubMixin):
self.dict_torch_dtype_to_str(serializable_config_dict)
if "_attn_implementation_internal" in serializable_config_dict:
del serializable_config_dict["_attn_implementation_internal"]
return serializable_config_dict
@@ -880,8 +904,8 @@ class PretrainedConfig(PushToHubMixin):
del output["_auto_class"] del output["_auto_class"]
if "_commit_hash" in output: if "_commit_hash" in output:
del output["_commit_hash"] del output["_commit_hash"]
if "_flash_attn_2_enabled" in output: if "_attn_implementation_internal" in output:
del output["_flash_attn_2_enabled"] del output["_attn_implementation_internal"]
# Transformers version when serializing the model # Transformers version when serializing the model
output["transformers_version"] = __version__ output["transformers_version"] = __version__
......
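End to end, the configuration plumbing above means an `attn_implementation` preset on a config is honored by `from_pretrained`; a minimal sketch (checkpoint reused from the OPT example earlier on this page):

```python
from transformers import AutoConfig, AutoModelForCausalLM

# Preset the attention implementation on the config; from_pretrained will obey it.
config = AutoConfig.from_pretrained("facebook/opt-350m", attn_implementation="eager")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", config=config)
print(model.config._attn_implementation)  # "eager"
```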
@@ -68,7 +68,7 @@ class AttentionMaskConverter:
key_value_length: int,
dtype: torch.dtype,
device: Union[torch.device, "str"] = "cpu",
) -> Optional[torch.Tensor]:
"""
Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
bias to upper right hand triangular matrix (causal mask).
@@ -184,6 +184,95 @@ class AttentionMaskConverter:
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
@staticmethod
def _unmask_unattended(
expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float]
):
# fmt: off
"""
Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
Details: https://github.com/pytorch/pytorch/issues/110213
`expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
`attention_mask` is [bsz, src_seq_len].
The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
For example, if `attention_mask` is
```
[[0, 0, 1],
[1, 1, 1],
[0, 1, 1]]
```
and `expanded_mask` is (e.g. here left-padding case)
```
[[[[0, 0, 0],
[0, 0, 0],
[0, 0, 1]]],
[[[1, 0, 0],
[1, 1, 0],
[1, 1, 1]]],
[[[0, 0, 0],
[0, 1, 0],
[0, 1, 1]]]]
```
then the modified `expanded_mask` will be
```
[[[[1, 1, 1], <-- modified
[1, 1, 1], <-- modified
[0, 0, 1]]],
[[[1, 0, 0],
[1, 1, 0],
[1, 1, 1]]],
[[[1, 1, 1], <-- modified
[0, 1, 0],
[0, 1, 1]]]]
```
"""
# fmt: on
# Get the index of the first non-zero value for every sample in the batch.
# In the above example, indices = [[2], [0], [1]]
tmp = torch.arange(attention_mask.shape[1], 0, -1)
indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True)
# Find the batch indexes that have unattended tokens on the leftmost side (e.g. [0, 0, 1, 1, 1]), for which the first rows of the
# expanded mask will be completely unattended.
left_masked_rows = torch.where(indices > 0)[0]
if left_masked_rows.shape[0] == 0:
return expanded_mask
indices = indices[left_masked_rows]
max_len = torch.max(indices)
range_tensor = torch.arange(max_len).unsqueeze(0)
range_tensor = range_tensor.repeat(indices.size(0), 1)
# Avoid unmasking tokens at relevant target positions (on the row axis), by rather unmasking possibly several times the first row that should always be unmasked as we filtered out the batch above.
range_tensor[range_tensor >= indices] = 0
# TODO: we may drop support for 3D attention mask as the refactor from Patrick maybe dropped this case
if expanded_mask.dim() == 4:
num_masks = expanded_mask.shape[1]
if num_masks == 1:
# Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
mask_slice = (left_masked_rows[:, None], 0, range_tensor)
else:
# Broadcast [left_masked_rows, 1, 1], [1, num_masks, 1], [left_masked_rows, 1, max_len]
mask_slice = (
left_masked_rows[:, None, None],
torch.arange(num_masks)[None, :, None],
range_tensor[:, None, :],
)
else:
# Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
mask_slice = (left_masked_rows[:, None], range_tensor)
expanded_mask[mask_slice] = unmasked_value
return expanded_mask
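A small usage sketch of `_unmask_unattended` on the tensors from the docstring example above (import path as introduced by this commit):

```python
import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

attention_mask = torch.tensor([[0, 0, 1], [1, 1, 1], [0, 1, 1]])
expanded_mask = torch.tensor(
    [[[[0, 0, 0], [0, 0, 0], [0, 0, 1]]],
     [[[1, 0, 0], [1, 1, 0], [1, 1, 1]]],
     [[[0, 0, 0], [0, 1, 0], [0, 1, 1]]]]
)
# Fully-masked rows (the first rows of the left-padded samples) are unmasked, which avoids
# NaN outputs with SDPA's memory-efficient backend.
patched = AttentionMaskConverter._unmask_unattended(expanded_mask.clone(), attention_mask, unmasked_value=1)
print(patched[0, 0])  # the first two rows are now [1, 1, 1]
```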
def _prepare_4d_causal_attention_mask(
attention_mask: Optional[torch.Tensor],
@@ -225,6 +314,78 @@ def _prepare_4d_causal_attention_mask(
return attention_mask
# Adapted from _prepare_4d_causal_attention_mask
def _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask: Optional[torch.Tensor],
input_shape: Union[torch.Size, Tuple, List],
inputs_embeds: torch.Tensor,
past_key_values_length: int,
sliding_window: Optional[int] = None,
):
"""
Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.
In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
`key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
"""
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
key_value_length = input_shape[-1] + past_key_values_length
batch_size, query_length = input_shape
# torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
# TODO: Fix this as well when using torchdynamo with fullgraph=True.
is_tracing = torch.jit.is_tracing()
if attention_mask is not None:
if torch.all(attention_mask == 1):
if is_tracing:
pass
elif query_length == 1:
# For query_length == 1, causal attention and bi-directional attention are the same.
attention_mask = None
elif key_value_length == query_length:
attention_mask = None
else:
# Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
# may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
# Reference: https://github.com/pytorch/pytorch/issues/108108
pass
elif query_length > 1 and key_value_length != query_length:
# See the comment above (https://github.com/pytorch/pytorch/issues/108108).
# Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`.
attention_mask = True
elif is_tracing:
raise ValueError(
'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
)
if attention_mask is None:
expanded_4d_mask = None
elif attention_mask is True:
expanded_4d_mask = attn_mask_converter.to_causal_4d(
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
else:
expanded_4d_mask = attn_mask_converter.to_4d(
attention_mask,
input_shape[-1],
dtype=inputs_embeds.dtype,
key_value_length=key_value_length,
)
# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
if query_length > 1:
expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
expanded_4d_mask, attention_mask, unmasked_value=0.0
)
return expanded_4d_mask
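A minimal sketch of how this helper dispatches (shapes are arbitrary and for illustration only):

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa

batch_size, seq_len, hidden = 2, 4, 8
inputs_embeds = torch.randn(batch_size, seq_len, hidden)

# Nothing masked and query_length == key_value_length: returns None, so the model can rely on
# SDPA's is_causal argument and potentially dispatch to the FlashAttention kernel.
full_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
print(_prepare_4d_causal_attention_mask_for_sdpa(full_mask, (batch_size, seq_len), inputs_embeds, 0))

# With left padding, a [batch, 1, seq, seq] float mask is returned (patched by _unmask_unattended).
padded_mask = torch.tensor([[1, 1, 1, 1], [0, 1, 1, 1]])
print(_prepare_4d_causal_attention_mask_for_sdpa(padded_mask, (batch_size, seq_len), inputs_embeds, 0).shape)
```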
def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
@@ -241,13 +402,51 @@ def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len:
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`
Args:
mask (`torch.Tensor` or `None`):
A 2D attention mask of shape `(batch_size, key_value_length)`
dtype (`torch.dtype`):
The torch dtype the created mask shall have.
tgt_len (`int`):
The target length or query length the created mask shall have.
"""
batch_size, key_value_length = mask.shape
tgt_len = tgt_len if tgt_len is not None else key_value_length
# torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
# TODO: Fix this as well when using torchdynamo with fullgraph=True.
is_tracing = torch.jit.is_tracing()
if torch.all(mask == 1):
if is_tracing:
pass
elif tgt_len == 1:
# For query_length == 1, causal attention and bi-directional attention are the same.
return None
elif key_value_length == tgt_len:
return None
else:
# Unfortunately, for query_length > 1 and key_value_length != query_length, we can not generally ignore the attention mask, as SDPA causal mask generation
# may be wrong. We will set is_causal=False in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
# Reference: https://github.com/pytorch/pytorch/issues/108108
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
else:
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
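And a similar sketch for the non-causal SDPA variant:

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa

# Nothing masked: returns None so SDPA can run without an explicit attn_mask.
print(_prepare_4d_attention_mask_for_sdpa(torch.ones(2, 5, dtype=torch.long), torch.float16))

# With padding, a [batch, 1, tgt_len, key_value_length] float mask is returned.
mask = torch.tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]])
print(_prepare_4d_attention_mask_for_sdpa(mask, torch.float16).shape)  # torch.Size([2, 1, 5, 5])
```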
def _create_4d_causal_attention_mask(
input_shape: Union[torch.Size, Tuple, List],
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
sliding_window: Optional[int] = None,
) -> Optional[torch.Tensor]:
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
...
@@ -81,6 +81,7 @@ from .utils import (
is_peft_available,
is_remote_url,
is_safetensors_available,
is_torch_sdpa_available,
is_torch_tpu_available,
logging,
replace_return_docstrings,
@@ -1128,6 +1129,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Flash Attention 2 support
_supports_flash_attn_2 = False
# SDPA support
_supports_sdpa = False
# Has support for a `Cache` instance as `past_key_values`
_supports_cache_class = False
@@ -1154,7 +1158,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
# Save config and origin of the pretrained weights if given in model
config = self._autoset_attn_implementation(
config, torch_dtype=torch.get_default_dtype(), check_device_map=False
)
self.config = config
self.name_or_path = config.name_or_path
self.warnings_issued = {}
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
@@ -1185,8 +1193,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
Args:
torch_dtype (`torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype.
"""
torch_dtype = kwargs.pop("torch_dtype", None)
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
@@ -1196,8 +1202,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if torch_dtype is not None:
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
config = copy.deepcopy(config)  # We do not want to modify the config inplace in _from_config.
config._attn_implementation = kwargs.pop("attn_implementation", None)
config = cls._autoset_attn_implementation(
config, use_flash_attention_2=use_flash_attention_2, check_device_map=False
)
if is_deepspeed_zero3_enabled():
import deepspeed
@@ -1216,6 +1225,67 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return model
@classmethod
def _autoset_attn_implementation(
cls,
config,
use_flash_attention_2: bool = False,
torch_dtype: Optional[torch.dtype] = None,
device_map: Optional[Union[str, Dict[str, int]]] = None,
check_device_map: bool = True,
):
"""
Automatically checks and dispatches to a default attention implementation. In order of priority:
1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained).
2. DEPRECATED: if use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example)
3. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example)
4. The model's default implementation otherwise (`LlamaAttention` for example).
"""
# Here we use config._attn_implementation_internal to check whether the attention implementation was explicitly set by the user.
# The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager").
# The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None:
if config._attn_implementation != "flash_attention_2" and use_flash_attention_2:
raise ValueError(
f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were used when loading the model, which are not compatible.'
' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
)
if config._attn_implementation not in ["eager", "sdpa", "flash_attention_2"]:
message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
if cls._supports_flash_attn_2:
message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
if cls._supports_sdpa:
message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
raise ValueError(message + ".")
# If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
hard_check_only = True
else:
hard_check_only = False
if use_flash_attention_2:
logger.warning_once(
'The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.'
)
config._attn_implementation = "flash_attention_2"
if config._attn_implementation == "flash_attention_2":
cls._check_and_enable_flash_attn_2(
config,
torch_dtype=torch_dtype,
device_map=device_map,
hard_check_only=hard_check_only,
check_device_map=check_device_map,
)
elif cls._supports_sdpa or config._attn_implementation == "sdpa":
# use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
config = cls._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
elif not hard_check_only:
config._attn_implementation = "eager"
return config
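A short sketch of the validation performed above (checkpoint reused from the OPT example earlier on this page; the invalid value is deliberate):

```python
from transformers import AutoModelForCausalLM

# Raises a ValueError listing the valid choices ("eager", plus "sdpa" / "flash_attention_2" when supported).
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", attn_implementation="does_not_exist")
```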
@classmethod
def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
"""
@@ -1266,38 +1336,33 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
@classmethod
def _check_and_enable_flash_attn_2(
cls,
config,
torch_dtype: Optional[torch.dtype] = None,
device_map: Optional[Union[str, Dict[str, int]]] = None,
check_device_map: bool = True,
hard_check_only: bool = False,
) -> PretrainedConfig:
"""
Checks the availability of Flash Attention 2 and compatibility with the current model.
If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module.
"""
if not cls._supports_flash_attn_2:
raise ValueError(
f"{cls.__name__} does not support Flash Attention 2.0 yet. Please open an issue on GitHub to "
"request support for this architecture: https://github.com/huggingface/transformers/issues/new"
)
if not is_flash_attn_2_available(): if not is_flash_attn_2_available():
flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:" preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:"
install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2." install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
if torch.version.cuda:
if importlib.util.find_spec("flash_attn") is None:
raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
flash_attention_version = version.parse(importlib.metadata.version("flash_attn")) if importlib.util.find_spec("flash_attn") is None:
raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
if torch.version.cuda:
if flash_attention_version < version.parse("2.1.0"): if flash_attention_version < version.parse("2.1.0"):
raise ImportError( raise ImportError(
f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}" f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}"
...@@ -1305,9 +1370,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1305,9 +1370,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else: else:
raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}") raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
elif torch.version.hip: elif torch.version.hip:
if importlib.util.find_spec("flash_attn") is None:
raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
if flash_attention_version < version.parse("2.0.4"): if flash_attention_version < version.parse("2.0.4"):
raise ImportError( raise ImportError(
f"{preface} you need flash_attn package version to be greater or equal than 2.0.4. Make sure to have that version installed - detected version {flash_attention_version}. {install_message}" f"{preface} you need flash_attn package version to be greater or equal than 2.0.4. Make sure to have that version installed - detected version {flash_attention_version}. {install_message}"
...@@ -1332,20 +1394,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1332,20 +1394,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
" unexpected behaviour." " unexpected behaviour."
) )
# The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
# or the model may be initialized under the context manager `with torch.device("cuda"):`.
if check_device_map and device_map is None and torch.empty(0).device.type != "cuda":
if torch.cuda.is_available():
logger.warning(
"You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU"
" after initializing it on CPU with `model.to('cuda')`."
)
else:
raise ValueError(
"You are attempting to use Flash Attention 2.0 with a model not initialized on GPU and with no GPU available. "
"This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
"or initialising the model on CPU and then moving it to GPU."
)
elif (
check_device_map
and device_map is not None
and isinstance(device_map, dict)
and ("cpu" in device_map.values() or "disk" in device_map.values())
):

@@ -1353,7 +1418,37 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix

"You are attempting to use Flash Attention 2.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
"initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
)
if not hard_check_only:
config._attn_implementation = "flash_attention_2"
return config
@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
"""
Checks the availability of SDPA for a given model.
If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "sdpa" so that the model can initialize the correct attention module.
"""
if hard_check_only:
if not cls._supports_sdpa:
raise ValueError(
f"{cls.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet. Please open an issue on GitHub to "
"request support for this architecture: https://github.com/huggingface/transformers/issues/new"
)
if not is_torch_sdpa_available():
raise ImportError(
"PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1."
)
if not is_torch_sdpa_available() or not cls._supports_sdpa:
return config
_is_bettertransformer = getattr(cls, "use_bettertransformer", False)
if _is_bettertransformer:
return config
if not hard_check_only:
config._attn_implementation = "sdpa"
return config
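# Illustration only: a rough sketch of the version gate behind the torch>=2.1.1 requirement mentioned
# in the error above. The real check lives in `is_torch_sdpa_available`; the helper name below is
# hypothetical.
import importlib.metadata
from packaging import version

def _sdpa_requirements_met() -> bool:
    # SDPA dispatch in Transformers is gated on a recent enough PyTorch (>= 2.1.1 per the message above).
    return version.parse(importlib.metadata.version("torch")) >= version.parse("2.1.1")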
def enable_input_require_grads(self):

@@ -3312,8 +3407,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix

elif load_in_8bit or load_in_4bit or low_cpu_mem_usage:
init_contexts.append(init_empty_weights())

config = copy.deepcopy(config)  # We do not want to modify the config inplace in from_pretrained.
config = cls._autoset_attn_implementation(
config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map
)

with ContextManagers(init_contexts):
model = cls(config, *model_args, **model_kwargs)
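# Sketch of the behaviour guaranteed by the deepcopy above (the checkpoint name is a placeholder):
# a config object passed by the caller is not mutated in place by `from_pretrained`.
from transformers import AutoConfig, AutoModelForCausalLM

user_config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", config=user_config)

# The implementation selected by `_autoset_attn_implementation` is recorded on the model's own copy
# of the config, while `user_config` is left untouched.
print(model.config._attn_implementation)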
src/transformers/models/bark/modeling_bark.py
@@ -389,7 +389,7 @@ class BarkSelfFlashAttention2(BarkSelfAttention):

BARK_ATTENTION_CLASSES = {
"eager": BarkSelfAttention,
"flash_attention_2": BarkSelfFlashAttention2,
}

@@ -436,8 +436,7 @@ class BarkBlock(nn.Module):

self.layernorm_1 = nn.LayerNorm(config.hidden_size)
self.layernorm_2 = nn.LayerNorm(config.hidden_size)

self.attn = BARK_ATTENTION_CLASSES[config._attn_implementation](config, is_causal=is_causal)

self.mlp = BarkMLP(config)

@@ -670,6 +669,7 @@ class BarkCausalModel(BarkPreTrainedModel):

self.drop = nn.Dropout(config.dropout)
self.layers = nn.ModuleList([BarkBlock(config, is_causal=True) for _ in range(config.num_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.layernorm_final = BarkLayerNorm(config.hidden_size, bias=config.bias)

@@ -805,7 +805,7 @@ class BarkCausalModel(BarkPreTrainedModel):

if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
if self._use_flash_attention_2:
attention_mask = attention_mask if 0 in attention_mask else None
else:
attention_mask = attention_mask.view(batch_size, -1)

@@ -1265,6 +1265,7 @@ class BarkFineModel(BarkPreTrainedModel):

self.drop = nn.Dropout(config.dropout)
self.layers = nn.ModuleList([BarkBlock(config, is_causal=False) for _ in range(config.num_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.layernorm_final = nn.LayerNorm(config.hidden_size)

@@ -1434,7 +1435,7 @@ class BarkFineModel(BarkPreTrainedModel):

if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
if self._use_flash_attention_2:
attention_mask = attention_mask if 0 in attention_mask else None
else:
# [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]

@@ -1875,7 +1876,11 @@ class BarkModel(BarkPreTrainedModel):

@classmethod
def _check_and_enable_flash_attn_2(
cls,
config,
torch_dtype: Optional[torch.dtype] = None,
device_map: Optional[Union[str, Dict[str, int]]] = None,
hard_check_only: bool = False,
):
"""
`_check_and_enable_flash_attn_2` originally doesn't expand flash attention enabling to the model

@@ -1892,12 +1897,14 @@ class BarkModel(BarkPreTrainedModel):

The method checks if the current setup is compatible with Flash Attention as it requires the model to be in
half precision and not run on CPU.

If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flash_attention_2" so that the model
can initialize the correct attention module
"""
config = super()._check_and_enable_flash_attn_2(
config, torch_dtype, device_map, hard_check_only=hard_check_only
)

config.semantic_config._attn_implementation = config._attn_implementation
config.coarse_acoustics_config._attn_implementation = config._attn_implementation
config.fine_acoustics_config._attn_implementation = config._attn_implementation
return config
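# The override above exists because Bark is a composite model: the selected implementation must be
# copied onto each sub-model config so that every BarkBlock picks the right class from
# BARK_ATTENTION_CLASSES. A rough sketch of the effect (assumes a CUDA GPU, half precision and
# flash-attn installed; the checkpoint name is a placeholder):
import torch
from transformers import BarkModel

model = BarkModel.from_pretrained(
    "suno/bark-small", torch_dtype=torch.float16, attn_implementation="flash_attention_2"
).to("cuda")

# The sub-configs now mirror the top-level choice.
assert model.config.semantic_config._attn_implementation == "flash_attention_2"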
src/transformers/models/bart/modeling_bart.py

@@ -25,7 +25,12 @@ from torch import nn

from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,

@@ -505,8 +510,109 @@ class BartFlashAttention2(BartAttention):

)
class BartSdpaAttention(BartAttention):
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
if output_attentions or layer_head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"BartModel is using BartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
key_value_states=key_value_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
query_states = self._shape(query_states, tgt_len, bsz)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
)
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None, past_key_value
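# What the `scaled_dot_product_attention` call above replaces: for self-attention, SDPA fuses
# softmax(q @ k.T / sqrt(d) + mask) @ v into a single call. A small numerical sanity check of that
# equivalence (shapes are illustrative only):
import math
import torch
import torch.nn.functional as F

bsz, num_heads, seq_len, head_dim = 2, 4, 8, 16
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
v = torch.randn(bsz, num_heads, seq_len, head_dim)

sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)

# Manual ("eager") equivalent with an explicit additive causal mask.
causal_mask = torch.full((seq_len, seq_len), float("-inf")).triu(1)
weights = torch.softmax(q @ k.transpose(-1, -2) / math.sqrt(head_dim) + causal_mask, dim=-1)
manual_out = weights @ v

assert torch.allclose(sdpa_out, manual_out, atol=1e-5)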
BART_ATTENTION_CLASSES = {
"eager": BartAttention,
"sdpa": BartSdpaAttention,
"flash_attention_2": BartFlashAttention2,
}
@@ -515,9 +621,8 @@ class BartEncoderLayer(nn.Module):

def __init__(self, config: BartConfig):
super().__init__()
self.embed_dim = config.d_model

self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,

@@ -587,8 +692,7 @@ class BartDecoderLayer(nn.Module):

super().__init__()
self.embed_dim = config.d_model

self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,

@@ -601,7 +705,7 @@ class BartDecoderLayer(nn.Module):

self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)

self.encoder_attn = BART_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,

@@ -735,6 +839,7 @@ class BartPreTrainedModel(PreTrainedModel):
_no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"] _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
_skip_keys_device_placement = "past_key_values" _skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.init_std std = self.config.init_std
...@@ -961,6 +1066,8 @@ class BartEncoder(BartPreTrainedModel): ...@@ -961,6 +1066,8 @@ class BartEncoder(BartPreTrainedModel):
embed_dim, embed_dim,
) )
self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)]) self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"
self.layernorm_embedding = nn.LayerNorm(embed_dim) self.layernorm_embedding = nn.LayerNorm(embed_dim)
self.gradient_checkpointing = False self.gradient_checkpointing = False
...@@ -1048,8 +1155,13 @@ class BartEncoder(BartPreTrainedModel): ...@@ -1048,8 +1155,13 @@ class BartEncoder(BartPreTrainedModel):
# expand attention_mask # expand attention_mask
if attention_mask is not None: if attention_mask is not None:
if getattr(self.config, "_flash_attn_2_enabled", False): if self._use_flash_attention_2:
attention_mask = attention_mask if 0 in attention_mask else None attention_mask = attention_mask if 0 in attention_mask else None
elif self._use_sdpa and head_mask is None and not output_attentions:
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
# the manual implementation that requires a 4D causal mask in all cases.
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
else: else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
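# What the shape comment above describes, in a nutshell: the helpers expand a 2D padding mask into the
# 4D additive mask consumed by both the eager path and SDPA. A simplified sketch of that expansion
# (the real helpers live in modeling_attn_mask_utils and handle more cases, e.g. causal masking):
from typing import Optional
import torch

def expand_padding_mask(mask_2d: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor:
    # [bsz, src_len] -> [bsz, 1, tgt_len, src_len]: 0.0 where attended, dtype-min where masked.
    bsz, src_len = mask_2d.shape
    tgt_len = tgt_len if tgt_len is not None else src_len
    expanded = mask_2d[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    return (1.0 - expanded) * torch.finfo(dtype).min

mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding
print(expand_padding_mask(mask, torch.float32).shape)  # torch.Size([1, 1, 4, 4])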
@@ -1136,6 +1248,9 @@ class BartDecoder(BartPreTrainedModel):

config.d_model,
)
self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"
self.layernorm_embedding = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False

@@ -1254,9 +1369,18 @@ class BartDecoder(BartPreTrainedModel):

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input) * self.embed_scale

if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(

@@ -1265,8 +1389,17 @@ class BartDecoder(BartPreTrainedModel):

# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
if self._use_flash_attention_2:
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask,
inputs_embeds.dtype,
tgt_len=input_shape[-1],
)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -252,7 +252,7 @@ class BlenderbotAttention(nn.Module):

return attn_output, attn_weights_reshaped, past_key_value

BLENDERBOT_ATTENTION_CLASSES = {"eager": BlenderbotAttention}

# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot, MBART->BLENDERBOT

@@ -260,9 +260,8 @@ class BlenderbotEncoderLayer(nn.Module):

def __init__(self, config: BlenderbotConfig):
super().__init__()
self.embed_dim = config.d_model

self.self_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,

@@ -332,9 +331,8 @@ class BlenderbotDecoderLayer(nn.Module):

def __init__(self, config: BlenderbotConfig):
super().__init__()
self.embed_dim = config.d_model

self.self_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,

@@ -347,7 +345,7 @@ class BlenderbotDecoderLayer(nn.Module):

self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)

self.encoder_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
@@ -254,9 +254,8 @@ class BlenderbotSmallEncoderLayer(nn.Module):

def __init__(self, config: BlenderbotSmallConfig):
super().__init__()
self.embed_dim = config.d_model

self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,

@@ -321,7 +320,10 @@ class BlenderbotSmallEncoderLayer(nn.Module):

return outputs

# TODO: Implement attention with SDPA for BlenderbotSmall.
BLENDERBOT_SMALL_ATTENTION_CLASSES = {
"eager": BlenderbotSmallAttention,
}

# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL

@@ -330,8 +332,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):

super().__init__()
self.embed_dim = config.d_model

self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,

@@ -344,7 +345,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):

self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)

self.encoder_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,