## FlashAttention - Alpha release (0.1).

To compile (requires NVCC and an A100 GPU):
```
cd csrc/flash_attn
python setup.py install
```

Interface: `flash_attention.py`
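
A minimal usage sketch is below. The import path, the `FlashAttention` module name, and the packed-QKV call signature are assumptions for illustration; see `flash_attention.py` for the exact interface.
```
# Usage sketch only: the import path, module name, and call signature are
# assumptions; consult flash_attention.py for the actual interface.
import torch
from flash_attention import FlashAttention  # assumed import; run with PYTHONPATH=$PWD

batch, seqlen, nheads, headdim = 2, 1024, 16, 64  # head dim must be 16, 32, or 64

# Packed QKV tensor in fp16 on an A100, the only dtype/GPU supported in this release.
qkv = torch.randn(batch, seqlen, 3, nheads, headdim,
                  dtype=torch.float16, device='cuda')

flash_attn = FlashAttention()
out = flash_attn(qkv)  # assumed output shape: (batch, seqlen, nheads, headdim)
```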

To run the benchmark against PyTorch standard attention: 
```
PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py
```

FlashAttention currently supports:
1. A100 GPUs.
2. fp16.
3. Head dimensions 16, 32, 64.
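
As a sketch only (not part of the package), the support constraints above can be checked with standard PyTorch queries before calling the kernel:
```
import torch

def check_flash_attn_support(qkv):
    """Pre-flight check for the current support matrix (A100 / SM80, fp16,
    head dimensions 16/32/64). A hypothetical helper, not part of the package."""
    # A100 GPUs report compute capability 8.0.
    assert torch.cuda.get_device_capability(qkv.device) == (8, 0), \
        "FlashAttention currently requires an A100 (SM80) GPU"
    assert qkv.dtype == torch.float16, "only fp16 is supported for now"
    assert qkv.shape[-1] in (16, 32, 64), "head dimension must be 16, 32, or 64"
```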

Our roadmap to broaden support:
1. Refactor to use Cutlass.
2. Support SM86 GPUs (e.g. RTX 3080, 3090) and SM75 GPUs (e.g. T4).
3. Support bf16.
4. Support head dimension 128.
5. Make the package pip-installable.
6. Support SM70 GPUs (V100).
7. Fused rotary embedding.
8. Attention with linear bias (e.g. ALiBi).


### Acknowledgments
Our implementation uses Apex's
[FMHA](https://github.com/NVIDIA/apex/tree/master/apex/contrib/csrc/fmha) code
as a starting point.

We thank [Young-Jun Ko](https://yjk21.github.io/) for the in-depth explanation of his FMHA implementation
and for his thoughtful answers to our questions about CUDA.