Commit d9fff84b authored by Tri Dao

Edit roadmap

parent e4ffe5d5
@@ -2,7 +2,7 @@
This repository provides the official implementation of FlashAttention from the
following paper.
-**FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness***
+**FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness**
Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
![FlashAttention](assets/flashattn_banner.jpg)
@@ -14,7 +14,7 @@ cd csrc/flash_attn
python setup.py install
```
-Interface: `flash_attention.py`
+Interface: `src/flash_attention.py`
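
For orientation, the PyTorch standard attention that the benchmark below compares against can be written roughly as follows. This is a minimal sketch, not the benchmark's actual code; the tensor layout, mask handling, and dropout placement are assumptions.

```
import math
import torch
import torch.nn.functional as F

def standard_attention(q, k, v, dropout_p=0.0, causal=False):
    """Reference softmax(QK^T / sqrt(d)) V with optional causal mask and dropout.

    q, k, v: (batch, heads, seqlen, head_dim) tensors (assumed layout).
    Materializes the full (seqlen x seqlen) attention matrix, which is
    exactly what FlashAttention avoids.
    """
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    if causal:
        mask = torch.triu(torch.ones(scores.shape[-2:], dtype=torch.bool,
                                     device=scores.device), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    attn = F.dropout(attn, p=dropout_p)
    return attn @ v
```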
To run the benchmark against PyTorch standard attention:
```
@@ -26,17 +26,18 @@ FlashAttention currently supports:
2. fp16.
3. Head dimensions 16, 32, 64.
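
Before calling into the CUDA extension it can help to sanity-check that the device and dtype fall inside this support matrix. A minimal sketch, assuming (per the roadmap below) that only SM80-class GPUs such as the A100 are supported at this point; the function name and thresholds are illustrative, not part of the package:

```
import torch

def flash_attn_supported(device_index=0, dtype=torch.float16, head_dim=64):
    """Illustrative pre-flight check mirroring the support list above (assumed thresholds)."""
    major, minor = torch.cuda.get_device_capability(device_index)
    sm80 = (major, minor) == (8, 0)  # A100-class; SM86/SM75/SM70 are on the roadmap
    return sm80 and dtype == torch.float16 and head_dim in (16, 32, 64)

print(flash_attn_supported())
```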
-Our roadmap to broaden the support:
-1. Refactor to use Cutlass.
-2. Support SM86 GPUs (e.g. RTX 3080, 3090), support SM75 GPUs (e.g. T4).
-3. Support bf16.
-4. Support head dimension 128.
-5. Make package pip-installable.
-6. Support SM70 GPUs (V100).
-7. Fused rotary embedding.
-8. Attention linear bias (e.g. ALiBi).
+Our tentative roadmap:
+1. [Jun 2022] Make package pip-installable.
+2. [Jun 2022] Support SM86 GPUs (e.g., RTX 3080, 3090).
+3. [Jun 2022] Refactor to use Cutlass.
+4. [Jun 2022] Support SM75 GPUs (e.g., T4).
+5. [Jun 2022] Support bf16.
+6. [Jul 2022] Support head dimension 128.
+7. [Jul 2022] Support SM70 GPUs (V100).
+8. [Aug 2022] Fuse rotary embedding.
+9. [Aug 2022] Support attention linear bias (e.g., ALiBi).
-### Speedup and Memory Savings
+## Speedup and Memory Savings
We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length.
We display FlashAttention speedup using these parameters (similar to BERT-base):
@@ -46,14 +47,14 @@ We display FlashAttention speedup using these parameters (similar to BERT-base):
Our graphs show sequence lengths between 128 and 4096 (when standard attention runs out of memory on an A100), but FlashAttention can scale up to sequence length 64K.
-#### Speedup
+### Speedup
![FlashAttention speedup](assets/flashattn_speedup.jpg)
We generally see 2-4X speedup at sequence lengths between 128 and 4K, and we see more speedup when using dropout and masking, since we fuse the kernels.
At sequence lengths that are popular with language models like 512 and 1K, we see speedups up to 4X when using dropout and masking.
-#### Memory
+### Memory
![FlashAttention memory](assets/flashattn_memory.jpg)
@@ -62,7 +63,7 @@ Memory savings are proportional to sequence length -- since standard attention has
We see 10X memory savings at sequence length 2K, and 20X at 4K.
As a result, FlashAttention can scale to much longer sequence lengths.
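
To make the quadratic term concrete, the attention matrix that standard attention materializes grows with the square of the sequence length. A back-of-the-envelope sketch; the batch size, head count, and fp16 storage here are assumed for illustration, not taken from the benchmark config:

```
def attn_matrix_bytes(batch, heads, seqlen, bytes_per_el=2):
    """Memory of the (seqlen x seqlen) score matrix alone, in bytes (fp16 assumed)."""
    return batch * heads * seqlen * seqlen * bytes_per_el

for seqlen in (2048, 4096):
    gib = attn_matrix_bytes(batch=8, heads=12, seqlen=seqlen) / 2**30
    print(f"seqlen={seqlen}: ~{gib:.2f} GiB")  # roughly quadruples when seqlen doubles
```

FlashAttention never materializes this full matrix, which is where the savings reported above come from.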
-### Acknowledgments
+## Acknowledgments
Our implementation uses Apex's
[FMHA](https://github.com/NVIDIA/apex/tree/master/apex/contrib/csrc/fmha) code
as a starting point.