# FlashAttention
This repository provides the official implementation of FlashAttention and
FlashAttention-2 from the following papers.

**FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness**  
Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré  
Paper: https://arxiv.org/abs/2205.14135  
IEEE Spectrum [article](https://spectrum.ieee.org/mlperf-rankings-2022) about our submission to the MLPerf 2.0 benchmark using FlashAttention.
![FlashAttention](assets/flashattn_banner.jpg)

**FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning**  
Tri Dao

Paper: https://tridao.me/publications/flash2/flash2.pdf

![FlashAttention-2](assets/flashattention_logo.png)


## Usage

We've been very happy to see FlashAttention being widely adopted in such a short
time after its release. This [page](https://github.com/Dao-AILab/flash-attention/blob/main/usage.md)
contains a partial list of places where FlashAttention is being used.

FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE).
Please cite and credit FlashAttention if you use it.

## Installation and features

Requirements:
- CUDA 11.6 and above.
- PyTorch 1.12 and above.
- Linux. Windows might work starting with v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)), but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via a GitHub issue.

We recommend the
[PyTorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
container from NVIDIA, which has all the required tools to install FlashAttention.

To install:
1. Make sure that PyTorch is installed.
2. Make sure that `packaging` is installed (`pip install packaging`).
3. Make sure that `ninja` is installed and that it works correctly: running
   `ninja --version` and then `echo $?` should print exit code 0. If it prints a
   nonzero exit code (this sometimes happens), uninstall and reinstall `ninja`
   (`pip uninstall -y ninja && pip install ninja`). Without `ninja`, compiling
   can take a very long time (2h) since it does not use multiple CPU cores. With
   `ninja`, compiling takes 3-5 minutes on a 64-core machine.
4. Then:
```sh
pip install flash-attn --no-build-isolation
```
Alternatively you can compile from source:
```sh
python setup.py install
```

If your machine has less than 96GB of RAM and lots of CPU cores, `ninja` might
run so many parallel compilation jobs that it exhausts the available RAM. To
limit the number of parallel compilation jobs, you can set the environment
variable `MAX_JOBS`:
```sh
MAX_JOBS=4 pip install flash-attn --no-build-isolation
```
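
The `ninja` check from step 3 can be scripted before starting a long build. Below is a minimal sketch of a hypothetical helper, `ninja_works` (not part of flash-attn or PyTorch), that runs `ninja --version` and checks for exit code 0, which is the same thing `echo $?` reports in the shell.

```python
import shutil
import subprocess


def ninja_works() -> bool:
    """Hypothetical helper: True if `ninja --version` runs and exits with code 0."""
    if shutil.which("ninja") is None:
        # ninja is not on PATH at all
        return False
    try:
        result = subprocess.run(
            ["ninja", "--version"], capture_output=True, text=True, timeout=30
        )
    except (OSError, subprocess.TimeoutExpired):
        return False
    # Exit code 0 is what `echo $?` would print after a working `ninja --version`
    return result.returncode == 0


if __name__ == "__main__":
    if ninja_works():
        print("ninja OK")
    else:
        print("ninja missing or broken; try: pip uninstall -y ninja && pip install ninja")
```

Checking first is cheap, whereas discovering a broken `ninja` only after the build has silently fallen back to single-core compilation can cost hours.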

Interface: `flash_attn/flash_attn_interface.py`

FlashAttention-2 currently supports:
1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing
   GPUs (T4, RTX 2080) is coming soon; please use FlashAttention 1.x for Turing
   GPUs for now.
2. Datatypes fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
3. All head dimensions up to 256. Head dim > 192 backward requires A100/A800 or H100/H800.
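
To fail fast on unsupported setups, the support list above can be encoded as a small pre-flight check. This is a sketch of a hypothetical helper (not part of flash-attn) that works in terms of CUDA compute capability. It assumes the usual mapping of Ampere to 8.0/8.6, Ada to 8.9, Hopper to 9.0, and Turing to 7.5, and it approximates the "A100/A800 or H100/H800" rule as capability 8.0 or 9.0. In practice you would pass `torch.cuda.get_device_capability()` as the first argument.

```python
def flash_attn2_supported(compute_capability, dtype, headdim, needs_backward=False):
    """Hypothetical pre-flight check for the FlashAttention-2 support list.

    compute_capability: (major, minor), e.g. torch.cuda.get_device_capability().
    dtype: "fp16" or "bf16".
    """
    major, minor = compute_capability
    if major < 8:
        # Turing (7.5: T4, RTX 2080) and older: use FlashAttention 1.x instead
        return False
    if dtype not in ("fp16", "bf16"):
        return False
    if headdim > 256:
        return False
    # Head dim > 192 backward: approximated as A100/A800 (8.0) or H100/H800 (9.0)
    if needs_backward and headdim > 192 and (major, minor) not in ((8, 0), (9, 0)):
        return False
    return True


print(flash_attn2_supported((8, 6), "bf16", 256))                       # True (forward only)
print(flash_attn2_supported((8, 6), "bf16", 256, needs_backward=True))  # False
```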

## How to use FlashAttention

The main functions implement scaled dot product attention (softmax(Q @ K^T *
softmax_scale) @ V):
```python
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
```

```python
def flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
                              window_size=(-1, -1)):
    """dropout_p should be set to 0.0 during evaluation.

    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
    """
```
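
To make the semantics of `softmax_scale`, `causal`, and `window_size` concrete, here is a naive plain-Python reference for a single head of a single batch element. It is a sketch for checking behavior on tiny inputs, not how the CUDA kernel computes attention (FlashAttention never materializes the full score matrix). It assumes `seqlen_q == seqlen_k`, as in the packed-QKV case, and treats `causal=True` as a window with `right = 0`; `reference_attention` is a hypothetical name.

```python
import math


def reference_attention(q, k, v, softmax_scale=None, causal=False, window_size=(-1, -1)):
    """Naive single-head attention: softmax(Q @ K^T * softmax_scale) @ V.

    q, k, v: lists of rows with shape (seqlen, headdim); assumes seqlen_q == seqlen_k.
    window_size: (left, right); -1 means unbounded on that side.
    Query rows with no visible keys get an all-zero output.
    """
    seqlen, headdim = len(q), len(q[0])
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(headdim)  # default 1 / sqrt(headdim)
    left, right = window_size
    if causal:
        right = 0  # causal: a query never looks at keys to its right
    out = []
    for i in range(seqlen):
        # Keys visible to query i: [i - left, i + right], clipped to [0, seqlen - 1]
        lo = 0 if left < 0 else max(0, i - left)
        hi = seqlen - 1 if right < 0 else min(seqlen - 1, i + right)
        keys = range(lo, hi + 1)
        if len(keys) == 0:
            out.append([0.0] * len(v[0]))
            continue
        scores = [softmax_scale * sum(a * b for a, b in zip(q[i], k[j])) for j in keys]
        m = max(scores)  # subtract the row max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * v[j][d] for w, j in zip(weights, keys)) / total
            for d in range(len(v[0]))
        ])
    return out


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(reference_attention(q, q, v, causal=True)[0])  # [1.0, 2.0]: row 0 sees only key 0
```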

```python
def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
                    window_size=(-1, -1)):
    """dropout_p should be set to 0.0 during evaluation.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen_q, nheads, headdim)
        k: (batch_size, seqlen_k, nheads_k, headdim)
        v: (batch_size, seqlen_k, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.

    Return:
        out: (batch_size, seqlen_q, nheads, headdim).
    """
```
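
The MQA/GQA head mapping and the local-attention key range for `seqlen_q != seqlen_k` are plain index arithmetic. The sketch below uses hypothetical helpers (`kv_head_for_q_head` and `visible_keys`, not library functions). It reproduces the docstring's 6-head/2-head example and its `[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]]` range, clipped to valid key indices; an empty range (`lo > hi`) means the query row sees no keys.

```python
def kv_head_for_q_head(q_head, nheads, nheads_k):
    """Index of the K/V head that Q head `q_head` attends to (hypothetical helper)."""
    if nheads % nheads_k != 0:
        raise ValueError("nheads must be divisible by nheads_k")
    group_size = nheads // nheads_k  # consecutive Q heads sharing one K/V head
    return q_head // group_size


def visible_keys(i, seqlen_q, seqlen_k, window_size=(-1, -1)):
    """Inclusive key range (lo, hi) for query i under local attention; lo > hi means none."""
    left, right = window_size
    offset = seqlen_k - seqlen_q  # aligns the last query with the last key
    lo = 0 if left < 0 else max(0, i + offset - left)
    hi = seqlen_k - 1 if right < 0 else min(seqlen_k - 1, i + offset + right)
    return lo, hi


print([kv_head_for_q_head(h, nheads=6, nheads_k=2) for h in range(6)])  # [0, 0, 0, 1, 1, 1]
print(visible_keys(0, seqlen_q=2, seqlen_k=5, window_size=(1, 0)))     # (2, 3)
```

With `nheads_k == 1` every Q head maps to K/V head 0, which is the multi-query case.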

```python
from typing import Optional, Union

import torch


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    rotary_interleaved=True,
    num_splits=0,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *in place* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    Also applies rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
            might come from any of the duplicate indices.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
            Don't change this unless you know what you are doing.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
    """
```
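
The cache bookkeeping described in this docstring can be modeled in a few lines of plain Python. This is a sketch under simplifying assumptions: heads and head dimension are folded into opaque rows, the update happens in Python lists rather than inside the attention kernel, and `update_kv_cache` and `rotary_pairs` are hypothetical names. It shows where the new keys/values land (starting at `cache_seqlens`, in the cache row chosen by `cache_batch_idx`) and which dimensions `rotary_interleaved` pairs up.

```python
def update_kv_cache(k_cache, v_cache, k_new, v_new, cache_seqlens, cache_batch_idx=None):
    """Model of the in-place append done when k and v are passed (hypothetical helper).

    k_cache, v_cache: [batch_size_cache][seqlen_cache] rows (heads/headdim folded away).
    k_new, v_new: [batch_size][seqlen_new] rows from the current decoding step.
    cache_seqlens: int, or one current length per batch element.
    Returns the updated lengths; the caches are modified in place.
    """
    batch_size = len(k_new)
    if isinstance(cache_seqlens, int):
        cache_seqlens = [cache_seqlens] * batch_size
    if cache_batch_idx is None:
        cache_batch_idx = list(range(batch_size))  # default: [0, 1, ..., batch_size - 1]
    new_lens = []
    for b in range(batch_size):
        slot = cache_batch_idx[b]  # cache row used by batch element b
        start = cache_seqlens[b]   # first write position (also k_new[b][0]'s rotary index)
        for t in range(len(k_new[b])):
            k_cache[slot][start + t] = k_new[b][t]
            v_cache[slot][start + t] = v_new[b][t]
        new_lens.append(start + len(k_new[b]))
    return new_lens


def rotary_pairs(rotary_dim, interleaved=True):
    """Dimension pairs that rotary embedding rotates together."""
    half = rotary_dim // 2
    if interleaved:
        return [(2 * i, 2 * i + 1) for i in range(half)]  # 0 & 1, 2 & 3, ...
    return [(i, i + half) for i in range(half)]  # GPT-NeoX style: 0 & half, 1 & half + 1, ...


k_cache = [[None] * 4 for _ in range(2)]
v_cache = [[None] * 4 for _ in range(2)]
print(update_kv_cache(k_cache, v_cache, [["k0"]], [["v0"]], cache_seqlens=[2], cache_batch_idx=[1]))  # [3]
print(k_cache[1])  # [None, None, 'k0', None]
```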

To see how these functions are used in a multi-head attention layer (which
includes QKV projection, output projection), see the MHA [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py).

## Changelog

### 2.0: Complete rewrite, 2x faster
Upgrading from FlashAttention (1.x) to FlashAttention-2

These functions have been renamed:
- `flash_attn_unpadded_func` -> `flash_attn_varlen_func`
- `flash_attn_unpadded_qkvpacked_func` -> `flash_attn_varlen_qkvpacked_func`
- `flash_attn_unpadded_kvpacked_func` -> `flash_attn_varlen_kvpacked_func`

If the inputs have the same sequence lengths in the same batch, it is simpler
and faster to use these functions:
```python
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
```
```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
```
### 2.1: Change behavior of causal flag

If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the
bottom right corner of the attention matrix, instead of the top-left corner.

For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 =
masked out) is:  
v2.0:  
    1 0 0 0 0  
    1 1 0 0 0  
v2.1:  
    1 1 1 1 0  
    1 1 1 1 1  

If seqlen_q = 5 and seqlen_k = 2, the causal mask is:  
v2.0:  
    1 0  
    1 1  
    1 1  
    1 1  
    1 1  
v2.1:  
    0 0  
    0 0  
    0 0  
    1 0  
    1 1  
If the row of the mask is all zero, the output will be zero.
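
The two alignments differ only by an offset of `seqlen_k - seqlen_q` on the diagonal: query row i keeps key j when `j <= i + offset`. The hypothetical helper below (not a library function) is a sketch that reproduces the example masks listed above.

```python
def causal_mask(seqlen_q, seqlen_k, align="bottom_right"):
    """Causal mask, 1 = keep, 0 = masked out (hypothetical helper).

    align="top_left" gives the v2.0 behavior, align="bottom_right" the v2.1 behavior.
    """
    offset = seqlen_k - seqlen_q if align == "bottom_right" else 0
    # Query row i keeps key j when j <= i + offset
    return [[1 if j <= i + offset else 0 for j in range(seqlen_k)] for i in range(seqlen_q)]


for row in causal_mask(5, 2):
    print(*row)
# 0 0
# 0 0
# 0 0
# 1 0
# 1 1
```

The first three rows come out all zero, which is exactly the case where the output is zero.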

### 2.2: Optimize for inference

Optimize for inference (iterative decoding) when the query has a very small sequence
length (e.g., query sequence length = 1). The bottleneck here is loading the KV
cache as fast as possible, so we split the loading across different thread
blocks, with a separate kernel to combine the results.

See the function `flash_attn_with_kvcache`, which has more features for inference
(applying rotary embedding, updating the KV cache in place).

Thanks to the xformers team, and in particular Daniel Haziza, for this
collaboration.
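
Splitting the KV loading only pays off if the per-split results can be merged exactly. A standard way to do this is for each split to return its normalized output together with the log-sum-exp (LSE) of its scores, then reweight the splits by `exp(lse)`. The sketch below shows this math for a single query with precomputed, already-scaled scores. The names `attend_chunk`, `combine_chunks`, and `split_kv_attention` are hypothetical, and the sketch illustrates the arithmetic rather than the CUDA combine kernel.

```python
import math


def attend_chunk(scores, values):
    """Softmax attention over one KV chunk; returns (normalized output, log-sum-exp)."""
    m = max(scores)  # chunk-local max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]
    return out, m + math.log(total)


def combine_chunks(partials):
    """Merge per-chunk (out, lse) pairs into the full-softmax result."""
    lse_max = max(lse for _, lse in partials)
    scales = [math.exp(lse - lse_max) for _, lse in partials]  # relative softmax mass per chunk
    denom = sum(scales)
    dim = len(partials[0][0])
    return [
        sum(s * out[d] for s, (out, _) in zip(scales, partials)) / denom
        for d in range(dim)
    ]


def split_kv_attention(scores, values, num_splits):
    """Attention for one query with the KV sequence split into about num_splits chunks."""
    n = len(scores)
    chunk = math.ceil(n / num_splits)
    partials = [
        attend_chunk(scores[s:s + chunk], values[s:s + chunk])
        for s in range(0, n, chunk)
    ]
    return combine_chunks(partials)
```

Because each split needs only its own max and sum, the splits can be computed independently (for example, by different thread blocks), and only the small (out, lse) pairs need a second pass.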

### 2.3: Local (i.e., sliding window) attention

Implement sliding window attention (i.e., local attention). Thanks to [Mistral
AI](https://mistral.ai/) and in particular Timothée Lacroix for this
contribution. Sliding window was used in the [Mistral 7B](https://mistral.ai/news/announcing-mistral-7b/) model.

## Performance

We present the expected speedup (combined forward + backward pass) and memory savings from using FlashAttention instead of PyTorch standard attention, depending on sequence length, on different GPUs (the speedup depends on memory bandwidth; we see more speedup on slower GPU memory).

We currently have benchmarks for these GPUs:
* [A100](#a100)
* [H100](#h100)
<!-- * [RTX 3090](#rtx-3090) -->
<!-- * [T4](#t4) -->

### A100

We display FlashAttention speedup using these parameters:
* Head dimension 64 or 128, hidden dimension 2048 (i.e. either 32 or 16 heads).
* Sequence length 512, 1k, 2k, 4k, 8k, 16k.
* Batch size set to 16k / seqlen.
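
The benchmark grid follows directly from these three bullets. This sketch assumes "1k" means 1024 and "16k" means 16384 tokens. Because the number of tokens per batch stays fixed, shorter sequences get proportionally larger batches.

```python
HIDDEN_DIM = 2048
TOTAL_TOKENS = 16 * 1024  # "16k" tokens per batch (assumed to mean 16384)

configs = []
for headdim in (64, 128):
    nheads = HIDDEN_DIM // headdim  # 32 heads at dim 64, 16 heads at dim 128
    for seqlen in (512, 1024, 2048, 4096, 8192, 16384):
        batch_size = TOTAL_TOKENS // seqlen  # 32 at seqlen 512, down to 1 at 16k
        configs.append((headdim, nheads, seqlen, batch_size))

for c in configs:
    print("headdim=%d nheads=%d seqlen=%d batch_size=%d" % c)
```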

#### Speedup

![FlashAttention speedup on A100 80GB SXM4 with FP16/BF16](assets/flash2_a100_fwd_bwd_benchmark.png)

#### Memory

![FlashAttention memory](assets/flashattn_memory.jpg)

We show memory savings in this graph (note that the memory footprint is the same whether or not you use dropout or masking).
Memory savings are proportional to sequence length, since standard attention has memory quadratic in sequence length, whereas FlashAttention has memory linear in sequence length.
We see 10X memory savings at sequence length 2K, and 20X at 4K.
As a result, FlashAttention can scale to much longer sequence lengths.
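
Back-of-the-envelope arithmetic shows why the savings grow with sequence length. For one head in fp16/bf16 (2 bytes per element), standard attention materializes a seqlen x seqlen score matrix, while Q, K, and V grow only linearly. Their size ratio is seqlen / (3 * headdim), so it doubles every time seqlen doubles. The helpers below are hypothetical, assume headdim 64, and count only these tensors, so the printed ratios illustrate the trend rather than reproduce the measured savings in the graph.

```python
def attention_matrix_bytes(seqlen, bytes_per_elem=2):
    """Size of one head's seqlen x seqlen score matrix (fp16/bf16 = 2 bytes)."""
    return seqlen * seqlen * bytes_per_elem


def qkv_bytes(seqlen, headdim, bytes_per_elem=2):
    """Size of one head's Q, K and V together: linear in seqlen."""
    return 3 * seqlen * headdim * bytes_per_elem


for n in (1024, 2048, 4096):
    ratio = attention_matrix_bytes(n) / qkv_bytes(n, headdim=64)
    print(f"seqlen={n}: score matrix {attention_matrix_bytes(n) / 2**20:.0f} MiB per head, "
          f"{ratio:.1f}x the size of Q, K, V")
```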

### H100

![FlashAttention speedup on H100 SXM5 with FP16/BF16](assets/flash2_h100_fwd_bwd_benchmark.png)

## Full model code and training script

We have released the full GPT model
[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/gpt.py).
We also provide optimized implementations of other layers (e.g., MLP, LayerNorm,
cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x
compared to the baseline implementation from Hugging Face, reaching up to 225
TFLOPs/sec per A100, equivalent to 72% model FLOPs utilization (we don't need
any activation checkpointing).
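
The 72% figure is consistent with dividing the achieved throughput by the A100's peak dense FP16/BF16 tensor-core throughput. This sketch assumes that peak is the 312 TFLOPS figure from NVIDIA's A100 spec sheet; the text above does not say which peak was used.

```python
A100_PEAK_BF16_TFLOPS = 312.0  # assumed: dense FP16/BF16 tensor-core peak from the A100 spec sheet
achieved_tflops = 225.0        # training throughput quoted above

mfu = achieved_tflops / A100_PEAK_BF16_TFLOPS
print(f"model FLOPs utilization ~ {mfu:.0%}")  # 225 / 312 is about 0.72
```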

We also include a training
[script](https://github.com/Dao-AILab/flash-attention/tree/main/training) to
train GPT2 on OpenWebText and GPT3 on The Pile.

## Triton implementation of FlashAttention

Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton:
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py

As Triton is a higher-level language than CUDA, it might be easier to understand
and experiment with. The notations in the Triton implementation are also closer
to what's used in our paper.

We also have an experimental implementation in Triton that supports attention
bias (e.g. ALiBi):
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py


## Tests
We test that FlashAttention produces the same output and gradient as a reference
implementation, up to some numerical tolerance. In particular, we check that the
maximum numerical error of FlashAttention is at most twice the numerical error
of a baseline implementation in PyTorch (for different head dimensions, input
dtype, sequence length, causal / non-causal).

To run the tests:
```sh
pytest -q -s tests/test_flash_attn.py
```
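
The tolerance rule described in this section can be written as a small predicate. This is a sketch on flat lists of floats. It assumes the reference is computed in higher precision (e.g., fp32) and the baseline in the same low-precision dtype as FlashAttention. `within_tolerance` is a hypothetical name; the real tests in `tests/test_flash_attn.py` operate on PyTorch tensors and may differ in detail.

```python
def max_abs_error(out, ref):
    """Largest elementwise |out - ref| over flat lists of floats."""
    return max(abs(a - b) for a, b in zip(out, ref))


def within_tolerance(out_flash, out_baseline, out_ref, factor=2.0):
    """Pass if FlashAttention's error vs. the reference is at most `factor` times
    the error of the low-precision baseline implementation."""
    return max_abs_error(out_flash, out_ref) <= factor * max_abs_error(out_baseline, out_ref)
```

Scaling the tolerance by the baseline's own error, rather than using a fixed threshold, keeps the check meaningful across dtypes and head dimensions whose inherent rounding error differs.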
## When you encounter issues

This new release of FlashAttention-2 has been tested on several GPT-style
models, mostly on A100 GPUs.

If you encounter bugs, please open a GitHub Issue!

## Citation
If you use this codebase, or otherwise found our work valuable, please cite:
```
@inproceedings{dao2022flashattention,
  title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
  author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
  booktitle={Advances in Neural Information Processing Systems},
  year={2022}
}
@article{dao2023flashattention2,
  title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
  author={Dao, Tri},
  year={2023}
}
```