This repository provides the official implementation of FlashAttention from the following paper.

**FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness**
Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré

...

```
cd csrc/flash_attn
python setup.py install
```
Interface: `src/flash_attention.py`
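
As a rough usage sketch (the import path, class name, tensor layout, and call signature below are assumptions for illustration, not taken from this README; check `src/flash_attention.py` for the actual interface), calling it on fp16 CUDA tensors might look like this:

```
import torch
# Assumed import path and class name; this sketch is not copied from the file itself.
from src.flash_attention import FlashAttention

# A packed-QKV layout of (batch, seqlen, 3, nheads, headdim) is assumed here.
# fp16 and head dimension 64 match the "currently supports" list below.
qkv = torch.randn(8, 512, 3, 12, 64, device="cuda", dtype=torch.float16)

attn = FlashAttention(attention_dropout=0.1)  # constructor argument is an assumption
out, _ = attn(qkv, causal=False)              # forward signature is an assumption
print(out.shape)                              # expected: (8, 512, 12, 64)
```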
To run the benchmark against PyTorch standard attention:
```
...
```
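
For context, the PyTorch standard attention baseline computes softmax(QKᵀ/√d)V by materializing the full seqlen × seqlen attention matrix, which is the quadratic memory cost FlashAttention avoids. A minimal reference sketch follows (illustrative only; this is not the repository's benchmark code, and the function name is made up):

```
import math
import torch
import torch.nn.functional as F

def standard_attention(q, k, v, dropout_p=0.0, causal=False):
    """Illustrative baseline: q, k, v have shape (batch, nheads, seqlen, headdim)."""
    # Materializes the full (seqlen x seqlen) score matrix, whose size grows
    # quadratically with sequence length.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    if causal:
        s = q.shape[-2]
        mask = torch.triu(torch.ones(s, s, dtype=torch.bool, device=q.device), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))
    probs = F.dropout(torch.softmax(scores, dim=-1), p=dropout_p, training=True)
    return probs @ v
```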

FlashAttention currently supports:
2. fp16.
3. Head dimensions 16, 32, 64.
Our tentative roadmap:
1. [Jun 2022] Make package pip-installable.
2. [Jun 2022] Support SM86 GPUs (e.g. RTX 3080, 3090).
3. [Jun 2022] Refactor to use Cutlass.
4. [Jun 2022] Support SM75 GPUs (e.g. T4).
5. [Jun 2022] Support bf16.
6. [Jul 2022] Support head dimension 128.
7. [Jul 2022] Support SM70 GPUs (V100).
8. [Aug 2022] Fuse rotary embedding.
9. [Aug 2022] Support attention linear bias (e.g. ALiBi).

## Speedup and Memory Savings

We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length.
We display FlashAttention speedup using these parameters (similar to BERT-base):
...
Our graphs show sequence lengths between 128 and 4096 (when standard attention runs out of memory on an A100), but FlashAttention can scale up to sequence length 64K.
We generally see 2-4X speedup at sequence lengths between 128 and 4K, and we see more speedup when using dropout and masking, since we fuse the kernels.
At sequence lengths that are popular with language models like 512 and 1K, we see speedups up to 4X when using dropout and masking.
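
To get a rough sense of the speedup and memory savings on your own shapes, one simple approach is to time each implementation with CUDA events and read peak memory from `torch.cuda.max_memory_allocated`. The helper below is a sketch, not the repository's benchmark script, and the `flash_fn` / `standard_fn` callables mentioned in the comments are placeholders:

```
import torch

def profile(fn, *args, warmup=3, iters=10):
    # Rough helper: average time (ms) and peak GPU memory (MiB) for fn(*args).
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters, torch.cuda.max_memory_allocated() / 2**20

# Placeholder usage: compare profile(flash_fn, qkv) against profile(standard_fn, q, k, v).
# The numbers above combine forward and backward, so include the backward pass inside
# fn if you want a comparable measurement.
```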