Commit 4f285b35 authored by Tri Dao

FlashAttention-2 release

parent 6d48e14a
[submodule "csrc/cutlass"]
	path = csrc/cutlass
	url = https://github.com/NVIDIA/cutlass.git
Tri Dao, trid@cs.stanford.edu
@@ -2,8 +2,10 @@ recursive-include csrc *.cu
recursive-include csrc *.h
recursive-include csrc *.cuh
recursive-include csrc *.cpp
recursive-include csrc *.hpp
recursive-include flash_attn *.cu
recursive-include flash_attn *.h
recursive-include flash_attn *.cuh
recursive-include flash_attn *.cpp
recursive-include flash_attn *.hpp
# FlashAttention
This repository provides the official implementation of FlashAttention and
FlashAttention-2 from the following papers.

**FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness**
Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
@@ -8,39 +9,22 @@ Paper: https://arxiv.org/abs/2205.14135
IEEE Spectrum [article](https://spectrum.ieee.org/mlperf-rankings-2022) about our submission to the MLPerf 2.0 benchmark using FlashAttention.
![FlashAttention](assets/flashattn_banner.jpg)

**FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning**
Tri Dao
Paper: https://tridao.me/publications/flash2/flash2.pdf

![FlashAttention-2](assets/flashattention_logo.png)

## Usage

We've been very happy to see FlashAttention being widely adopted in such a short
time after its release. This [page](https://github.com/Dao-AILab/flash-attention/blob/main/usage.md)
contains a partial list of places where FlashAttention is being used.

FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE).
Please cite and credit FlashAttention if you use it.

## Installation and features
@@ -53,125 +37,116 @@ We recommend the
container from Nvidia, which has all the required tools to install FlashAttention.

To install:
1. Make sure that PyTorch is installed.
2. Make sure that `packaging` is installed (`pip install packaging`).
3. Make sure that `ninja` is installed and that it works correctly (e.g. `ninja
--version` then `echo $?` should return exit code 0). If not (sometimes `ninja
--version` then `echo $?` returns a nonzero exit code), uninstall then reinstall
`ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`,
compiling can take a very long time (2h) since it does not use multiple CPU
cores. With `ninja`, compiling takes 3-5 minutes on a 64-core machine.
4. Then:
```sh
pip install flash-attn --no-build-isolation
```
Alternatively you can compile from source:
```
python setup.py install
```

Interface: `src/flash_attention_interface.py`

FlashAttention-2 currently supports:
1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing
   GPUs (T4, RTX 2080) is coming soon; please use FlashAttention 1.x for Turing
   GPUs for now.
2. Datatypes fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
3. All head dimensions up to 256. Head dim > 192 backward requires A100/A800 or H100/H800.
## How to use FlashAttention

The main functions implement scaled dot product attention (softmax(Q @ K^T *
softmax_scale) @ V):
```
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
```

```
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False):
"""dropout_p should be set to 0.0 during evaluation.
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of Q, K, V.
Arguments:
    qkv: (batch_size, seqlen, 3, nheads, headdim)
    dropout_p: float. Dropout probability.
    softmax_scale: float. The scaling of QK^T before applying softmax.
        Default to 1 / sqrt(headdim).
    causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
Return:
    out: (batch_size, seqlen, nheads, headdim).
"""
```

```
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False):
"""dropout_p should be set to 0.0 during evaluation.
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
Arguments:
    q: (batch_size, seqlen, nheads, headdim)
    k: (batch_size, seqlen, nheads_k, headdim)
    v: (batch_size, seqlen, nheads_k, headdim)
    dropout_p: float. Dropout probability.
    softmax_scale: float. The scaling of QK^T before applying softmax.
        Default to 1 / sqrt(headdim).
    causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
Return:
    out: (batch_size, seqlen, nheads, headdim).
"""
```

To see how these functions are used in a multi-head attention layer (which
includes QKV projection, output projection), see the MHA [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py).
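
As a quick sanity check, here is a minimal usage sketch based on the signatures above (the shapes, dtype, and head counts are illustrative; it assumes a CUDA GPU with fp16 support):

```python
import torch
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func

device, dtype = "cuda", torch.float16
batch_size, seqlen, headdim = 2, 1024, 64

# Grouped-query attention: 6 query heads share 2 KV heads (6 is divisible by 2).
q = torch.randn(batch_size, seqlen, 6, headdim, device=device, dtype=dtype)
k = torch.randn(batch_size, seqlen, 2, headdim, device=device, dtype=dtype)
v = torch.randn(batch_size, seqlen, 2, headdim, device=device, dtype=dtype)
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)  # (batch_size, seqlen, 6, headdim)

# Packed QKV with the same number of heads for Q, K, V.
qkv = torch.randn(batch_size, seqlen, 3, 8, headdim, device=device, dtype=dtype)
out_packed = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)  # (batch_size, seqlen, 8, headdim)
```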
## Upgrading from FlashAttention (1.x) to FlashAttention-2

These functions have been renamed:
- `flash_attn_unpadded_func` -> `flash_attn_varlen_func`
- `flash_attn_unpadded_qkvpacked_func` -> `flash_attn_varlen_qkvpacked_func`
- `flash_attn_unpadded_kvpacked_func` -> `flash_attn_varlen_kvpacked_func`

If all sequences in a batch have the same length, it is simpler and faster to
use these functions:
```
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
```
```
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
```
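
For variable-length (unpadded) inputs, the renamed varlen functions take the tokens packed into a single dimension together with cumulative sequence lengths. A minimal sketch mirroring the benchmark code in this commit (the shapes are illustrative; the import path is the one used by the benchmark script here):

```python
import torch
from einops import rearrange
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func

device, dtype = "cuda", torch.float16
batch_size, seqlen, nheads, headdim = 2, 1024, 8, 64

qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype)
# Pack the batch into one "total tokens" dimension; sequence boundaries are
# given by cumulative sequence lengths (int32 tensor of length batch_size + 1).
qkv_unpad = rearrange(qkv, "b s ... -> (b s) ...")
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
                          dtype=torch.int32, device=device)
out = flash_attn_varlen_qkvpacked_func(qkv_unpad, cu_seqlens, seqlen, 0.0, causal=False)
# out: one row per token, i.e. (batch_size * seqlen, nheads, headdim)
```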
## Performance

We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).

We currently have benchmarks for these GPUs:
* [A100](#a100)
* [H100](#h100)
<!-- * [RTX 3090](#rtx-3090) -->
<!-- * [T4](#t4) -->

### A100

We display FlashAttention speedup using these parameters:
* Head dimension 64 or 128, hidden dimension 2048 (i.e. either 32 or 16 heads).
* Sequence length 512, 1k, 2k, 4k, 8k, 16k.
* Batch size set to 16k / seqlen.

#### Speedup

![FlashAttention speedup on A100 80GB SXM5 with FP16/BF16](assets/flash2_a100_fwd_bwd_benchmark.png)

#### Memory

@@ -182,38 +157,37 @@ Memory savings are proportional to sequence length -- since standard attention has memory quadratic in sequence length, whereas FlashAttention has memory linear in sequence length.
We see 10X memory savings at sequence length 2K, and 20X at 4K.
As a result, FlashAttention can scale to much longer sequence lengths.
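
As a back-of-the-envelope illustration of that quadratic vs. linear scaling (the sizes are illustrative; this counts only the materialized attention matrix for standard attention and the per-row softmax statistics FlashAttention stores instead, so end-to-end savings are smaller once Q, K, V and other activations are included):

```python
batch_size, nheads, seqlen = 8, 16, 4096

# Standard attention materializes a (seqlen x seqlen) score matrix per head (fp16 = 2 bytes).
std_attn_matrix_bytes = batch_size * nheads * seqlen * seqlen * 2
# FlashAttention keeps only O(seqlen) extra state per head (fp32 softmax statistics).
flash_extra_bytes = batch_size * nheads * seqlen * 4

print(f"standard attention matrix: {std_attn_matrix_bytes / 1e9:.1f} GB; "
      f"FlashAttention extra state: {flash_extra_bytes / 1e6:.1f} MB")
```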
### H100

![FlashAttention speedup on H100 SXM5 with FP16/BF16](assets/flash2_h100_fwd_bwd_benchmark.png)

## Full model code and training script

We have released the full GPT model
[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/gpt.py).
We also provide optimized implementations of other layers (e.g., MLP, LayerNorm,
cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x
compared to the baseline implementation from Huggingface, reaching up to 225
TFLOPs/sec per A100, equivalent to 72% model FLOPs utilization (we don't need
any activation checkpointing).

We also include a training
[script](https://github.com/Dao-AILab/flash-attention/tree/main/training) to
train GPT2 on Openwebtext and GPT3 on The Pile.

## Triton implementation of FlashAttention

Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton:
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py

As Triton is a higher-level language than CUDA, it might be easier to understand
and experiment with. The notations in the Triton implementation are also closer
to what's used in our paper.

We also have an experimental implementation in Triton that supports attention
bias (e.g. ALiBi):
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py

## Tests

We test that FlashAttention produces the same output and gradient as a reference
@@ -228,21 +202,10 @@ pytest -q -s tests/test_flash_attn.py
```

## When you encounter issues

This new release of FlashAttention-2 has been tested on several GPT-style
models, mostly on A100 GPUs.

If you encounter any bugs, please open a GitHub Issue!

## Citation

If you use this codebase, or otherwise found our work valuable, please cite:
@@ -253,4 +216,9 @@
  booktitle={Advances in Neural Information Processing Systems},
  year={2022}
}
@article{dao2023flashattention2,
  title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
  author={Dao, Tri},
  year={2023}
}
```
@@ -6,11 +6,21 @@ import torch.nn.functional as F
from einops import rearrange, repeat

# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from src.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
# # from flash_attn.triton.fused_attention import attention as attention
# from flash_attn.flash_attn_triton import flash_attn_qkvpacked_func
# from flash_attn.flash_attn_triton_og import attention as attention_og
# from triton.ops.flash_attention import attention as attention_triton

try:
    from fav2 import flash_attn_qkvpacked_func as fav2_qkvpacked_func
    from fav2 import flash_attn_kvpacked_func as fav2_kvpacked_func
except ImportError:
    fav2_qkvpacked_func = None
    fav2_kvpacked_func = None

try:
    from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax
@@ -71,16 +81,18 @@ def attention_megatron(qkv):

torch.manual_seed(0)
repeats = 30
batch_size = 2
seqlen = 8192
nheads = 12
headdim = 128
# nheads = 24
# headdim = 64
# batch_size = 64
# seqlen = 512
# nheads = 8
# headdim = 128
dropout_p = 0.1
causal = False
dtype = torch.float16
device = 'cuda'

qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                  requires_grad=True)
@@ -88,18 +100,130 @@
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                          device=qkv.device)

# qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
# benchmark_all(flash_attn_varlen_qkvpacked_func, qkv_unpad,
#               cu_seqlens, seqlen, dropout_p, causal=causal, repeats=repeats, desc='FlashAttention')
# pytorch_profiler(flash_attn_varlen_qkvpacked_func, qkv_unpad,
#                  cu_seqlens, seqlen, dropout_p, causal=causal, backward=True)
# if fav2_qkvpacked_func is not None:
#     benchmark_all(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
#     pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True)

# for dropout_p in [0.1, 0.0]:
#     for causal in [False, True]:
#         print(f"### {dropout_p = }, {causal = } ###")
#         pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True)

# nheads_k = 2
# q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
# kv = torch.randn(batch_size, seqlen, 2, nheads_k, headdim, device=device, dtype=dtype,
#                  requires_grad=True)
# if fav2_kvpacked_func is not None:
#     benchmark_all(fav2_kvpacked_func, q, kv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
#     pytorch_profiler(fav2_kvpacked_func, q, kv, dropout_p, causal=causal, backward=True)

# dropout_p = 0.0
# causal = False
# benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal,
#               repeats=repeats, desc='PyTorch Attention')

# benchmark_all(flash_attn_qkvpacked_func, qkv, None, causal, repeats=repeats, desc='FlashAttention Triton')
# pytorch_profiler(flash_attn_qkvpacked_func, qkv, None, causal, backward=True)
# q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
#                        requires_grad=True) for _ in range(3)]
# benchmark_all(attention_og, q, k, v, 1.0, repeats=repeats, desc='FlashAttention Triton OG')
# # pytorch_profiler(attention, q, k, v, 1.0, backward=True)

# if scaled_upper_triang_masked_softmax is not None:
#     benchmark_all(attention_megatron, qkv, repeats=repeats, desc='Megatron Attention')

# from src.ops.fftconv import fftconv_func
# dim = nheads * headdim
# u = torch.randn(batch_size, dim, seqlen, device=device, dtype=dtype, requires_grad=True)
# k = torch.randn(dim, seqlen, device=device, requires_grad=True)
# D = torch.randn(dim, device=device, requires_grad=True)
# benchmark_all(fftconv_func, u, k, D, repeats=repeats, desc='FFTConv')
# pytorch_profiler(fftconv_func, u, k, D, backward=True)
# pytorch_profiler(torch.fft.rfft, u.float())

flops = 4 * batch_size * seqlen ** 2 * nheads * headdim
ideal_a100_time = flops / 312 / 1e9
print(f"Ideal A100 fwd time: {ideal_a100_time:.3f}ms, bwd time: {ideal_a100_time * 2.5:.3f}ms")
def time_fwd_bwd(func, *args, **kwargs):
    time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
    return time_f[1].mean, time_b[1].mean


bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
causal_vals = [False, True]
headdim_vals = [64, 128]
dim = 2048
dropout_p = 0.0

time_f = {}
time_b = {}
for causal in causal_vals:
    for headdim in headdim_vals:
        for batch_size, seqlen in bs_seqlen_vals:
            nheads = dim // headdim
            qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                              requires_grad=True)
            cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                      device=qkv.device)
            qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
            f, b = time_fwd_bwd(
                flash_attn_varlen_qkvpacked_func, qkv_unpad, cu_seqlens, seqlen, dropout_p,
                causal=causal, repeats=repeats, verbose=False
            )
            time_f[(causal, headdim, batch_size, seqlen), "Flash"] = f
            time_b[(causal, headdim, batch_size, seqlen), "Flash"] = b

            qkv = qkv.detach().requires_grad_(True)
            f, b = time_fwd_bwd(
                fav2_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
            )
            time_f[(causal, headdim, batch_size, seqlen), "Flash2"] = f
            time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = b

            # q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
            #                        requires_grad=True) for _ in range(3)]
            # # Try both values of sequence_parallel and pick the faster one
            # f, b = time_fwd_bwd(
            #     attention_triton, q, k, v, causal, headdim**(-0.5),
            #     False, repeats=repeats, verbose=False
            # )
            # _, b0 = time_fwd_bwd(
            #     attention_triton, q, k, v, causal, headdim**(-0.5),
            #     True, repeats=repeats, verbose=False
            # )
            # time_f[(causal, headdim, batch_size, seqlen), "Triton"] = f
            # time_b[(causal, headdim, batch_size, seqlen), "Triton"] = min(b, b0)

            if seqlen <= 8 * 1024:
                qkv = qkv.detach().requires_grad_(True)
                f, b = time_fwd_bwd(
                    attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
                )
            else:
                f, b = float('nan'), float('nan')
            time_f[(causal, headdim, batch_size, seqlen), "Pytorch"] = f
            time_b[(causal, headdim, batch_size, seqlen), "Pytorch"] = b

            # q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
            #                        requires_grad=True) for _ in range(3)]
            # import xformers.ops as xops
            # f, b = time_fwd_bwd(
            #     xops.memory_efficient_attention, q, k, v,
            #     attn_bias=xops.LowerTriangularMask() if causal else None,
            #     op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp)
            # )
            # time_f[(causal, headdim, batch_size, seqlen), "xformers"] = f
            # time_b[(causal, headdim, batch_size, seqlen), "xformers"] = b

import pickle
with open('flash2_attn_time_h100.plk', 'wb') as fp:
    pickle.dump((time_f, time_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
@@ -8,7 +8,7 @@ from einops import rearrange, repeat
from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func


def attention_ref(qkv, attn_mask, dropout_p, upcast=False, causal=False):
@@ -62,7 +62,7 @@ qkv_unpad = rearrange(Wqkv(x_unpad), 'nnz (t h d) -> nnz t h d', t=3,
                      h=nheads).detach().requires_grad_()
qkv = rearrange(Wqkv(x), 'b s (t h d) -> b s t h d', t=3, h=nheads).detach().requires_grad_()

fn = lambda qkv_unpad: flash_attn_varlen_qkvpacked_func(
    qkv_unpad, cu_seqlens, max_seqlen_in_batch, dropout_p, causal=causal
)
benchmark_all(fn, qkv_unpad, repeats=repeats, desc='FlashAttention')
Subproject commit c4f6b8c6bc94ff69048492fb34df0dfaf1983933
Subproject commit 319a389f42b776fae5701afcb943fc03be5b5c25
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
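// BlockInfo gathers, for one batch index (bidb), the token offsets into the packed Q/K tensors
// (sum_s_q / sum_s_k, read from cu_seqlens) and the actual query/key sequence lengths.
// When Varlen is false or cu_seqlens is null, it falls back to the fixed seqlen_q / seqlen_k
// and a sentinel offset of -1, in which case q_offset / k_offset index by batch stride instead.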
template<bool Varlen=true>
struct BlockInfo {
template<typename Params>
__device__ BlockInfo(const Params &params, const int bidb)
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb])
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
, actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k)
{
}
template <typename index_t>
inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
}
template <typename index_t>
inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
}
const int sum_s_q;
const int sum_s_k;
const uint32_t actual_seqlen_q;
const uint32_t actual_seqlen_k;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <cuda.h>
#include <vector>
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#include <ATen/cuda/CUDAGraphsUtils.cuh>
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Qkv_params {
using index_t = uint32_t;
// The QKV matrices.
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
void *__restrict__ v_ptr;
// The stride between rows of the Q, K and V matrices.
index_t q_batch_stride;
index_t k_batch_stride;
index_t v_batch_stride;
index_t q_row_stride;
index_t k_row_stride;
index_t v_row_stride;
index_t q_head_stride;
index_t k_head_stride;
index_t v_head_stride;
// The number of heads.
int h, h_k;
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
// different from nheads (query).
int h_h_k_ratio; // precompute h / h_k
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Flash_fwd_params : public Qkv_params {
// The O matrix (output).
void * __restrict__ o_ptr;
// The stride between rows of O.
index_t o_batch_stride;
index_t o_row_stride;
index_t o_head_stride;
// The pointer to the P matrix.
void * __restrict__ p_ptr;
// The pointer to the softmax sum.
void * __restrict__ softmax_lse_ptr;
// The dimensions.
int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
// The scaling factors for the kernel.
float scale_softmax;
float scale_softmax_log2;
// array of length b+1 holding starting offset of each sequence.
int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k;
int *__restrict__ blockmask;
// The dropout probability (probability of keeping an activation).
float p_dropout;
// uint32_t p_dropout_in_uint;
// uint16_t p_dropout_in_uint16_t;
uint8_t p_dropout_in_uint8_t;
// Scale factor of 1 / (1 - p_dropout).
float rp_dropout;
float scale_softmax_rp_dropout;
// Random state.
at::PhiloxCudaState philox_args;
bool is_bf16;
bool is_causal;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Flash_bwd_params : public Flash_fwd_params {
// The dO and dQKV matrices.
void *__restrict__ do_ptr;
void *__restrict__ dq_ptr;
void *__restrict__ dk_ptr;
void *__restrict__ dv_ptr;
// To accumulate dQ
void *__restrict__ dq_accum_ptr;
void *__restrict__ dk_accum_ptr;
void *__restrict__ dv_accum_ptr;
// // To accumulate dK and dV in case we're splitting the bwd along seqlen_q dimension
// void *__restrict__ dk_accum_ptr;
// void *__restrict__ dv_accum_ptr;
// The stride between rows of the dO, dQ, dK and dV matrices.
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
// The code probably won't work for arrays larger than 2GB.
index_t do_batch_stride;
index_t do_row_stride;
index_t do_head_stride;
index_t dq_batch_stride;
index_t dk_batch_stride;
index_t dv_batch_stride;
index_t dq_row_stride;
index_t dk_row_stride;
index_t dv_row_stride;
index_t dq_head_stride;
index_t dk_head_stride;
index_t dv_head_stride;
// The pointer to the softmax d sum.
void *__restrict__ dsoftmax_sum;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_bwd_launch_template.h"
// template<>
// void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
// using elem_type = cutlass::bfloat16_t;
// if (params.h == params.h_k) {
// run_flash_bwd<Flash_bwd_kernel_traits<128, 64, 128, 8, 2, 4, 2, false, false, elem_type>>(params, stream, configure);
// } else {
// run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<128, 128, 64, 8, 4, 4, 4, false, false, elem_type>>(params, stream, configure);
// }
// }
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
run_mha_bwd_hdim128<cutlass::bfloat16_t>(params, stream, configure);
}
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_bwd_launch_template.h"
// template<>
// void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
// using elem_type = cutlass::half_t;
// if (params.h == params.h_k) {
// // run_flash_bwd<Flash_bwd_kernel_traits<128, 32, 128, 8, 2, 2, 2, false, false, elem_type>>(params, stream, configure);
// // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
// // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
// // run_flash_bwd<Flash_bwd_kernel_traits<128, 64, 128, 8, 2, 2, 2, false, false, elem_type>>(params, stream, configure);
// run_flash_bwd<Flash_bwd_kernel_traits<128, 64, 128, 8, 2, 4, 2, false, false, elem_type>>(params, stream, configure);
// // run_flash_bwd<Flash_bwd_kernel_traits<128, 64, 128, 8, 2, 4, 4, false, false, elem_type>>(params, stream, configure);
// // run_flash_bwd<Flash_bwd_kernel_traits<128, 128, 64, 8, 4, 4, 4, false, false, elem_type>>(params, stream, configure);
// } else {
// run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<128, 128, 64, 8, 4, 4, 4, false, false, elem_type>>(params, stream, configure);
// }
// }
template<>
void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
run_mha_bwd_hdim128<cutlass::half_t>(params, stream, configure);
}
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 160>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
run_mha_bwd_hdim160<cutlass::bfloat16_t>(params, stream, configure);
}
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::half_t, 160>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
run_mha_bwd_hdim160<cutlass::half_t>(params, stream, configure);
}
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 192>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
run_mha_bwd_hdim192<cutlass::bfloat16_t>(params, stream, configure);
}