Commit 4f285b35 authored by Tri Dao

FlashAttention-2 release

parent 6d48e14a
[submodule "csrc/flash_attn/cutlass"]
path = csrc/flash_attn/cutlass
[submodule "csrc/cutlass"]
path = csrc/cutlass
url = https://github.com/NVIDIA/cutlass.git
Tri Dao, trid@stanford.edu
Dan Fu, danfu@cs.stanford.edu
\ No newline at end of file
Tri Dao, trid@cs.stanford.edu
\ No newline at end of file
@@ -2,8 +2,10 @@ recursive-include csrc *.cu
recursive-include csrc *.h
recursive-include csrc *.cuh
recursive-include csrc *.cpp
recursive-include csrc *.hpp
recursive-include flash_attn *.cu
recursive-include flash_attn *.h
recursive-include flash_attn *.cuh
recursive-include flash_attn *.cpp
recursive-include flash_attn *.hpp
# FlashAttention
This repository provides the official implementation of FlashAttention from the
following paper.
This repository provides the official implementation of FlashAttention and
FlashAttention-2 from the
following papers.
**FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness**
Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
@@ -8,39 +9,22 @@ Paper: https://arxiv.org/abs/2205.14135
IEEE Spectrum [article](https://spectrum.ieee.org/mlperf-rankings-2022) about our submission to the MLPerf 2.0 benchmark using FlashAttention.
![FlashAttention](assets/flashattn_banner.jpg)
## Usage
**FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning**
Tri Dao
We've been very happy to see FlashAttention being widely adopted in such a short
time after its release. This [page](https://github.com/HazyResearch/flash-attention/blob/main/usage.md)
contains a partial list of places where FlashAttention is being used.
Paper: https://tridao.me/publications/flash2/flash2.pdf
## Full model code and training script
![FlashAttention-2](assets/flashattention_logo.png)
We have released the full GPT model
[implementation](https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/models/gpt.py).
We also provide optimized implementations of other layers (e.g., MLP, LayerNorm,
cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x
compared to the baseline implementation from Huggingface, reaching up to 189
TFLOPs/sec per A100, equivalent to 60.6% model FLOPs utilization (we don't need
any activation checkpointing).
We also include a training
[script](https://github.com/HazyResearch/flash-attention/tree/main/training) to
train GPT2 on Openwebtext and GPT3 on The Pile.
## Triton implementation of FlashAttention
Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton:
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
As Triton is a higher-level language than CUDA, it might be easier to understand
and experiment with. The notations in the Triton implementation are also closer
to what's used in our paper.
## Usage
We also have an experimental implementation in Triton that supports attention
bias (e.g. ALiBi):
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py
We've been very happy to see FlashAttention being widely adopted in such a short
time after its release. This [page](https://github.com/Dao-AILab/flash-attention/blob/main/usage.md)
contains a partial list of places where FlashAttention is being used.
FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE).
Please cite and credit FlashAttention if you use it.
## Installation and features
@@ -53,125 +37,116 @@ We recommend the
container from Nvidia, which has all the required tools to install FlashAttention.
To install:
1. Make sure that PyTorch is installed.
2. Make sure that `packaging` is installed (`pip install packaging`)
3. Make sure that `ninja` is installed and that it works correctly (e.g. `ninja
--version` then `echo $?` should return exit code 0). If not (sometimes `ninja
--version` then `echo $?` returns a nonzero exit code), uninstall then reinstall
`ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`,
compiling can take a very long time (2h) since it does not use multiple CPU
cores. With `ninja`, compiling takes 3-5 minutes on a 64-core machine.
4. Then:
```sh
pip install flash-attn
pip install flash-attn --no-build-isolation
```
Alternatively you can compile from source:
```
python setup.py install
```
Interface: `src/flash_attention.py`
To run the benchmark against PyTorch standard attention:
```
PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py
```
Interface: `src/flash_attention_interface.py`
FlashAttention currently supports:
1. Turing, Ampere, Ada, or Hopper GPUs (e.g., H100, A100, RTX 3090, T4, RTX 2080).
2. fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
3. Head dimensions that are multiples of 8, up to 128 (e.g., 8, 16, 24, ...,
128). Head dim > 64 backward requires A100 or H100.
Our tentative roadmap:
1. ~~[Jun 2022] Make package pip-installable~~[Done, thanks to lucidrains].
2. ~~[Jun 2022] Support SM86 GPUs (e.g., RTX 3080, 3090)~~[Done].
3. ~~[Jun 2022] Support SM75 GPUs (e.g. T4)~~[Done].
4. ~~[Jun 2022] Support bf16~~[Done].
5. ~~[Jul 2022] Implement cross-attention~~[Done].
6. ~~[Jul 2022] Support head dimension 128~~[Done].
7. ~~[Aug 2022] Fuse rotary embedding~~[Done].
8. ~~[Mar 2023] Support SM90 GPUs (H100)~~[Done].
FlashAttention-2 currently supports:
1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing
GPUs (T4, RTX 2080) is coming soon; please use FlashAttention 1.x for Turing
GPUs for now.
2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
3. All head dimensions up to 256. Head dim > 192 backward requires A100/A800 or H100/H800.
## How to use FlashAttention
Here's a simple example:
```python
import torch
from flash_attn.flash_attention import FlashMHA
# Replace this with your correct GPU device
device = "cuda:0"
# Create attention layer. This is similar to torch.nn.MultiheadAttention,
# and it includes the input and output linear layers
flash_mha = FlashMHA(
embed_dim=128, # total channels (= num_heads * head_dim)
num_heads=8, # number of heads
device=device,
dtype=torch.float16,
)
# Run forward pass with dummy data
x = torch.randn(
(64, 256, 128), # (batch, seqlen, embed_dim)
device=device,
dtype=torch.float16
)
output = flash_mha(x)[0]
The main functions implement scaled dot product attention (softmax(Q @ K^T *
softmax_scale) @ V):
```
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
```
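For clarity, here is a schematic, unfused reference of the computation these kernels implement (an illustrative sketch, not the library's code; it materializes the full seqlen x seqlen score matrix and omits dropout):
```python
import math
import torch

def attention_ref(q, k, v, softmax_scale=None, causal=False):
    # q, k, v: (batch, seqlen, nheads, headdim) -> out: (batch, seqlen, nheads, headdim)
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.einsum("bthd,bshd->bhts", q, k) * softmax_scale
    if causal:
        # Mask out positions where the key index is greater than the query index.
        mask = torch.triu(torch.ones(scores.shape[-2:], dtype=torch.bool, device=q.device), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    return torch.einsum("bhts,bshd->bthd", attn, v)
```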
Alternatively, you can import the inner attention layer only (so that the input
and output linear layers are not included):
```python
from flash_attn.flash_attention import FlashAttention
```
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of Q, K, V.
Arguments:
qkv: (batch_size, seqlen, 3, nheads, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Defaults to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
Return:
out: (batch_size, seqlen, nheads, headdim).
```
# Create the nn.Module
flash_attention = FlashAttention()
```
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
v: (batch_size, seqlen, nheads_k, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Defaults to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
Return:
out: (batch_size, seqlen, nheads, headdim).
```
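To make the MQA/GQA head mapping above concrete, here is a minimal sketch (shapes are illustrative, not from the original text):
```python
import torch
from flash_attn import flash_attn_func

batch, seqlen, headdim = 2, 1024, 64
# 6 query heads and 2 KV heads: query heads 0-2 attend to KV head 0, query heads 3-5 to KV head 1.
q = torch.randn(batch, seqlen, 6, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, seqlen, 2, headdim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, seqlen, 2, headdim, device="cuda", dtype=torch.float16)
out = flash_attn_func(q, k, v, causal=True)  # (batch, seqlen, 6, headdim)
```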
Or, if you need more fine-grained control, you can import one of the lower-level
functions (this is more similar to the `torch.nn.functional` style):
```python
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
To see how these functions are used in a multi-head attention layer (which
includes QKV projection, output projection), see the MHA [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py).
# or
## Upgrading from FlashAttention (1.x) to FlashAttention-2
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func
These functions have been renamed (a short call-site sketch follows this list):
- `flash_attn_unpadded_func` -> `flash_attn_varlen_func`
- `flash_attn_unpadded_qkvpacked_func` -> `flash_attn_varlen_qkvpacked_func`
- `flash_attn_unpadded_kvpacked_func` -> `flash_attn_varlen_kvpacked_func`
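Updating a 1.x call site is usually just the rename. A minimal sketch of the varlen packed call (shapes and values are illustrative, not from the original text):
```python
import torch
# FlashAttention 1.x: from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func

batch, seqlen, nheads, headdim = 4, 512, 16, 64
# Packed QKV for all tokens in the batch, plus cumulative sequence lengths.
qkv = torch.randn(batch * seqlen, 3, nheads, headdim, device="cuda", dtype=torch.float16)
cu_seqlens = torch.arange(0, (batch + 1) * seqlen, step=seqlen, dtype=torch.int32, device="cuda")
out = flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, seqlen, dropout_p=0.0, causal=True)
```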
# etc.
If the inputs have the same sequence lengths in the same batch, it is simpler
and faster to use these functions:
```
flash_attn_qkvpacked_func(qkv, dropout_p, softmax_scale=None, causal=False)
```
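A minimal fixed-length call sketch (shapes are illustrative, not from the original text):
```python
import torch
from flash_attn import flash_attn_qkvpacked_func

batch, seqlen, nheads, headdim = 2, 2048, 16, 64
qkv = torch.randn(batch, seqlen, 3, nheads, headdim, device="cuda", dtype=torch.bfloat16)
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)  # (batch, seqlen, nheads, headdim)
```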
There are also separate Python files with various FlashAttention extensions:
```python
# Import the triton implementation (torch.nn.functional version only)
from flash_attn.flash_attn_triton import flash_attn_func
# Import block sparse attention (nn.Module version)
from flash_attn.flash_blocksparse_attention import FlashBlocksparseMHA, FlashBlocksparseAttention
# Import block sparse attention (torch.nn.functional version)
from flash_attn.flash_blocksparse_attn_interface import flash_blocksparse_attn_func
```
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
```
## Speedup and Memory Savings
## Performance
We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).
We currently have benchmarks for these GPUs:
* [A100](#a100)
* [RTX 3090](#rtx-3090)
* [T4](#t4)
* [H100](#h100)
<!-- * [RTX 3090](#rtx-3090) -->
<!-- * [T4](#t4) -->
### A100
We display FlashAttention speedup using these parameters (similar to BERT-base):
* Batch size 8
* Head dimension 64
* 12 attention heads
Our graphs show sequence lengths between 128 and 4096 (when standard attention runs out of memory on an A100), but FlashAttention can scale up to sequence length 64K.
We display FlashAttention speedup using these parameters:
* Head dimension 64 or 128, hidden dimension 2048 (i.e. either 32 or 16 heads).
* Sequence length 512, 1k, 2k, 4k, 8k, 16k.
* Batch size set to 16k / seqlen.
#### Speedup
![FlashAttention speedup](assets/flashattn_speedup.jpg)
We generally see 2-4X speedup at sequence lengths between 128 and 4K, and we see more speedup when using dropout and masking, since we fuse the kernels.
At sequence lengths that are popular with language models like 512 and 1K, we see speedups up to 4X when using dropout and masking.
![FlashAttention speedup on A100 80GB SXM5 with FP16/BF16](assets/flash2_a100_fwd_bwd_benchmark.png)
#### Memory
@@ -182,38 +157,37 @@ Memory savings are proportional to sequence length -- since standard attention h
We see 10X memory savings at sequence length 2K, and 20X at 4K.
As a result, FlashAttention can scale to much longer sequence lengths.
#### Head Dimension 128
![FlashAttention speedup, head dimension 128](assets/flashattn_speedup_a100_d128.jpg)
We show speedup with head dimension 128.
Here we show batch size 16 with 12 heads.
Speedup is less than with the smaller head sizes, since we have to make the block size smaller in the tiling.
But speedup is still significant, especially with a causal mask.
### RTX 3090
### H100
For the RTX 3090, we use batch size 12 with 12 attention heads.
Memory savings are the same as on an A100, so we'll only show speedup here.
![FlashAttention speedup on H100 SXM5 with FP16/BF16](assets/flash2_h100_fwd_bwd_benchmark.png)
![FlashAttention speedup RTX 3090](assets/flashattn_speedup_3090.jpg)
We see slightly higher speedups (between 2.5-4.5x) on the RTX 3090, since the memory bandwidth of its GDDR6X is lower than that of A100 HBM (~900 GB/s vs. ~1.5 TB/s).
## Full model code and training script
### T4
We have released the full GPT model
[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/gpt.py).
We also provide optimized implementations of other layers (e.g., MLP, LayerNorm,
cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x
compared to the baseline implementation from Huggingface, reaching up to 225
TFLOPs/sec per A100, equivalent to 72% model FLOPs utilization (we don't need
any activation checkpointing).
We again use batch size 12 with 12 attention heads.
We also include a training
[script](https://github.com/Dao-AILab/flash-attention/tree/main/training) to
train GPT2 on Openwebtext and GPT3 on The Pile.
![Flashattention speedup T4](assets/flashattn_speedup_t4.jpg)
## Triton implementation of FlashAttention
T4 SRAM is smaller than that of the newer GPUs (64 KB), so we see less speedup (we need to make the block sizes smaller, so we end up doing more R/W).
This matches the IO complexity analysis from section 3.2 of [our paper](https://arxiv.org/abs/2205.14135).
Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton:
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
T4 GPUs are commonly used for inference, so we also measure speedup on the forward pass only (note that these are not directly comparable to the graphs above):
As Triton is a higher-level language than CUDA, it might be easier to understand
and experiment with. The notations in the Triton implementation are also closer
to what's used in our paper.
![FlashAttention speedup T4 fwd](assets/flashattn_speedup_t4_fwd.jpg)
We also have an experimental implementation in Triton that supports attention
bias (e.g. ALiBi):
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py
We see speedups between 2.5x-4.5x on the forward pass.
## Tests
We test that FlashAttention produces the same output and gradient as a reference
@@ -228,21 +202,10 @@ pytest -q -s tests/test_flash_attn.py
```
## When you encounter issues
This alpha release of FlashAttention contains code written for a research
project to validate ideas on speeding up attention.
We have tested it on several models (BERT, GPT2, ViT).
However, there might still be bugs in the implementation that we hope to iron
out in the next few months.
This new release of FlashAttention-2 has been tested on several GPT-style
models, mostly on A100 GPUs.
If you encounter any of these bugs, please open a GitHub Issue!
## Acknowledgments
Our implementation uses Apex's
[FMHA](https://github.com/NVIDIA/apex/tree/master/apex/contrib/csrc/fmha) code
as a starting point.
We thank [Young-Jun Ko](https://yjk21.github.io/) for the in-depth explanation of his FMHA implementation
and for his thoughtful answers to our questions about CUDA.
If you encounter any bugs, please open a GitHub Issue!
## Citation
If you use this codebase, or otherwise found our work valuable, please cite:
@@ -253,4 +216,9 @@ If you use this codebase, or otherwise found our work valuable, please cite:
booktitle={Advances in Neural Information Processing Systems},
year={2022}
}
@article{dao2023flashattention2,
title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
author={Dao, Tri},
year={2023}
}
```
@@ -6,11 +6,21 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn.utils.benchmark import benchmark_forward, benchmark_all, pytorch_profiler
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
# from flash_attn.triton.fused_attention import attention as attention
from flash_attn.flash_attn_triton import flash_attn_qkvpacked_func
from flash_attn.flash_attn_triton_og import attention as attention_og
# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from src.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
# # from flash_attn.triton.fused_attention import attention as attention
# from flash_attn.flash_attn_triton import flash_attn_qkvpacked_func
# from flash_attn.flash_attn_triton_og import attention as attention_og
# from triton.ops.flash_attention import attention as attention_triton
try:
from fav2 import flash_attn_qkvpacked_func as fav2_qkvpacked_func
from fav2 import flash_attn_kvpacked_func as fav2_kvpacked_func
except ImportError:
fav2_qkvpacked_func = None
fav2_kvpacked_func = None
try:
from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax
@@ -71,16 +81,18 @@ def attention_megatron(qkv):
torch.manual_seed(0)
repeats = 30
batch_size = 2
seqlen = 4096
seqlen = 8192
nheads = 12
headdim = 128
# nheads = 24
# headdim = 64
# batch_size = 64
# seqlen = 512
# nheads = 8
# headdim = 128
dropout_p = 0.0
causal = True
dtype = torch.bfloat16
dropout_p = 0.1
causal = False
dtype = torch.float16
device = 'cuda'
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
@@ -88,18 +100,130 @@ qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=d
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
device=qkv.device)
benchmark_all(flash_attn_unpadded_qkvpacked_func, rearrange(qkv, 'b s ... -> (b s) ...'),
cu_seqlens, seqlen, dropout_p, causal=causal, repeats=repeats, desc='FlashAttention')
benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal,
repeats=repeats, desc='PyTorch Attention')
# qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
# benchmark_all(flash_attn_varlen_qkvpacked_func, qkv_unpad,
# cu_seqlens, seqlen, dropout_p, causal=causal, repeats=repeats, desc='FlashAttention')
# pytorch_profiler(flash_attn_varlen_qkvpacked_func, qkv_unpad,
# cu_seqlens, seqlen, dropout_p, causal=causal, backward=True)
# if fav2_qkvpacked_func is not None:
# benchmark_all(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
# pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True)
# for dropout_p in [0.1, 0.0]:
# for causal in [False, True]:
# print(f"### {dropout_p = }, {causal = } ###")
# pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True)
# nheads_k = 2
# q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
# kv = torch.randn(batch_size, seqlen, 2, nheads_k, headdim, device=device, dtype=dtype,
# requires_grad=True)
# if fav2_kvpacked_func is not None:
# benchmark_all(fav2_kvpacked_func, q, kv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
# pytorch_profiler(fav2_kvpacked_func, q, kv, dropout_p, causal=causal, backward=True)
# dropout_p = 0.0
# causal = False
# benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal,
# repeats=repeats, desc='PyTorch Attention')
# benchmark_all(flash_attn_qkvpacked_func, qkv, None, causal, repeats=repeats, desc='FlashAttention Triton')
# pytorch_profiler(flash_attn_qkvpacked_func, qkv, None, causal, backward=True)
# q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
# requires_grad=True) for _ in range(3)]
# benchmark_all(attention_og, q, k, v, 1.0, repeats=repeats, desc='FlashAttention Triton OG')
# # pytorch_profiler(attention, q, k, v, 1.0, backward=True)
# if scaled_upper_triang_masked_softmax is not None:
# benchmark_all(attention_megatron, qkv, repeats=repeats, desc='Megatron Attention')
# from src.ops.fftconv import fftconv_func
# dim = nheads * headdim
# u = torch.randn(batch_size, dim, seqlen, device=device, dtype=dtype, requires_grad=True)
# k = torch.randn(dim, seqlen, device=device, requires_grad=True)
# D = torch.randn(dim, device=device, requires_grad=True)
# benchmark_all(fftconv_func, u, k, D, repeats=repeats, desc='FFTConv')
# pytorch_profiler(fftconv_func, u, k, D, backward=True)
# pytorch_profiler(torch.fft.rfft, u.float())
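# FLOPs for one attention forward pass: two matmuls (Q @ K^T and P @ V), each 2 * seqlen^2 * headdim
# FLOPs per head per batch element. 312 TFLOPs/s is the A100 fp16/bf16 tensor-core peak; the backward
# pass is estimated at roughly 2.5x the forward FLOPs.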
flops = 4 * batch_size * seqlen ** 2 * nheads * headdim
ideal_a100_time = flops / 312 / 1e9
print(f"Ideal A100 fwd time: {ideal_a100_time:.3f}ms, bwd time: {ideal_a100_time * 2.5:.3f}ms")
def time_fwd_bwd(func, *args, **kwargs):
time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
return time_f[1].mean, time_b[1].mean
bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
causal_vals = [False, True]
headdim_vals = [64, 128]
dim = 2048
dropout_p = 0.0
time_f = {}
time_b = {}
for causal in causal_vals:
for headdim in headdim_vals:
for batch_size, seqlen in bs_seqlen_vals:
nheads = dim // headdim
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
requires_grad=True)
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
device=qkv.device)
qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
f, b = time_fwd_bwd(
flash_attn_varlen_qkvpacked_func, qkv_unpad, cu_seqlens, seqlen, dropout_p,
causal=causal, repeats=repeats, verbose=False
)
time_f[(causal, headdim, batch_size, seqlen), "Flash"] = f
time_b[(causal, headdim, batch_size, seqlen), "Flash"] = b
qkv = qkv.detach().requires_grad_(True)
f, b = time_fwd_bwd(
fav2_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
)
time_f[(causal, headdim, batch_size, seqlen), "Flash2"] = f
time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = b
# q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
# requires_grad=True) for _ in range(3)]
# # Try both values of sequence_parallel and pick the faster one
# f, b = time_fwd_bwd(
# attention_triton, q, k, v, causal, headdim**(-0.5),
# False, repeats=repeats, verbose=False
# )
# _, b0 = time_fwd_bwd(
# attention_triton, q, k, v, causal, headdim**(-0.5),
# True, repeats=repeats, verbose=False
# )
# time_f[(causal, headdim, batch_size, seqlen), "Triton"] = f
# time_b[(causal, headdim, batch_size, seqlen), "Triton"] = min(b, b0)
if seqlen <= 8 * 1024:
qkv = qkv.detach().requires_grad_(True)
f, b = time_fwd_bwd(
attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
)
else:
f, b = float('nan'), float('nan')
time_f[(causal, headdim, batch_size, seqlen), "Pytorch"] = f
time_b[(causal, headdim, batch_size, seqlen), "Pytorch"] = b
benchmark_all(flash_attn_qkvpacked_func, qkv, None, causal, repeats=repeats, desc='FlashAttention Triton')
pytorch_profiler(flash_attn_qkvpacked_func, qkv, None, causal, backward=True)
# q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
# requires_grad=True) for _ in range(3)]
# import xformers.ops as xops
# f, b = time_fwd_bwd(
# xops.memory_efficient_attention, q, k, v,
# attn_bias=xops.LowerTriangularMask() if causal else None,
# op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp)
# )
# time_f[(causal, headdim, batch_size, seqlen), "xformers"] = f
# time_b[(causal, headdim, batch_size, seqlen), "xformers"] = b
q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
requires_grad=True) for _ in range(3)]
benchmark_all(attention_og, q, k, v, 1.0, repeats=repeats, desc='FlashAttention Triton OG')
# pytorch_profiler(attention, q, k, v, 1.0, backward=True)
if scaled_upper_triang_masked_softmax is not None:
benchmark_all(attention_megatron, qkv, repeats=repeats, desc='Megatron Attention')
import pickle
with open('flash2_attn_time_h100.plk', 'wb') as fp:
pickle.dump((time_f, time_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
@@ -8,7 +8,7 @@ from einops import rearrange, repeat
from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
def attention_ref(qkv, attn_mask, dropout_p, upcast=False, causal=False):
@@ -62,7 +62,7 @@ qkv_unpad = rearrange(Wqkv(x_unpad), 'nnz (t h d) -> nnz t h d', t=3,
h=nheads).detach().requires_grad_()
qkv = rearrange(Wqkv(x), 'b s (t h d) -> b s t h d', t=3, h=nheads).detach().requires_grad_()
fn = lambda qkv_unpad: flash_attn_unpadded_qkvpacked_func(
fn = lambda qkv_unpad: flash_attn_varlen_qkvpacked_func(
qkv_unpad, cu_seqlens, max_seqlen_in_batch, dropout_p, causal=causal
)
benchmark_all(fn, qkv_unpad, repeats=repeats, desc='FlashAttention')
Subproject commit c4f6b8c6bc94ff69048492fb34df0dfaf1983933
Subproject commit 319a389f42b776fae5701afcb943fc03be5b5c25
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool Varlen=true>
struct BlockInfo {
template<typename Params>
__device__ BlockInfo(const Params &params, const int bidb)
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb])
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
, actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k)
{
}
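// q_offset/k_offset return the offset of this batch's first row: fixed-length layouts index by
// batch_stride, while varlen layouts use the cumulative token offset (sum_s_q/sum_s_k from
// cu_seqlens) times row_stride.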
template <typename index_t>
inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
}
template <typename index_t>
inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
}
const int sum_s_q;
const int sum_s_k;
const uint32_t actual_seqlen_q;
const uint32_t actual_seqlen_k;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <cuda.h>
#include <vector>
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#include <ATen/cuda/CUDAGraphsUtils.cuh>
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Qkv_params {
using index_t = uint32_t;
// The QKV matrices.
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
void *__restrict__ v_ptr;
// The stride between rows of the Q, K and V matrices.
index_t q_batch_stride;
index_t k_batch_stride;
index_t v_batch_stride;
index_t q_row_stride;
index_t k_row_stride;
index_t v_row_stride;
index_t q_head_stride;
index_t k_head_stride;
index_t v_head_stride;
// The number of heads.
int h, h_k;
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
// different from nheads (query).
int h_h_k_ratio; // precompute h / h_k.
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Flash_fwd_params : public Qkv_params {
// The O matrix (output).
void * __restrict__ o_ptr;
// The stride between rows of O.
index_t o_batch_stride;
index_t o_row_stride;
index_t o_head_stride;
// The pointer to the P matrix.
void * __restrict__ p_ptr;
// The pointer to the softmax sum.
void * __restrict__ softmax_lse_ptr;
// The dimensions.
int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
// The scaling factors for the kernel.
float scale_softmax;
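// Typically scale_softmax * log2(e), so the kernel can use exp2 instead of exp.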
float scale_softmax_log2;
// array of length b+1 holding starting offset of each sequence.
int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k;
int *__restrict__ blockmask;
// The dropout probability (probability of keeping an activation).
float p_dropout;
// uint32_t p_dropout_in_uint;
// uint16_t p_dropout_in_uint16_t;
uint8_t p_dropout_in_uint8_t;
// Scale factor of 1 / (1 - p_dropout).
float rp_dropout;
float scale_softmax_rp_dropout;
// Random state.
at::PhiloxCudaState philox_args;
bool is_bf16;
bool is_causal;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Flash_bwd_params : public Flash_fwd_params {
// The dO and dQKV matrices.
void *__restrict__ do_ptr;
void *__restrict__ dq_ptr;
void *__restrict__ dk_ptr;
void *__restrict__ dv_ptr;
// To accumulate dQ
void *__restrict__ dq_accum_ptr;
void *__restrict__ dk_accum_ptr;
void *__restrict__ dv_accum_ptr;
// // To accumulate dK and dV in case we're splitting the bwd along the seqlen_q dimension
// // void *__restrict__ dk_accum_ptr;
// // void *__restrict__ dv_accum_ptr;
// The stride between rows of the dO, dQ, dK and dV matrices.
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
// The code probably won't work for arrays larger than 2GB.
index_t do_batch_stride;
index_t do_row_stride;
index_t do_head_stride;
index_t dq_batch_stride;
index_t dk_batch_stride;
index_t dv_batch_stride;
index_t dq_row_stride;
index_t dk_row_stride;
index_t dv_row_stride;
index_t dq_head_stride;
index_t dk_head_stride;
index_t dv_head_stride;
// The pointer to the softmax d sum.
void *__restrict__ dsoftmax_sum;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_bwd_launch_template.h"
// template<>
// void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
// using elem_type = cutlass::bfloat16_t;
// if (params.h == params.h_k) {
// run_flash_bwd<Flash_bwd_kernel_traits<128, 64, 128, 8, 2, 4, 2, false, false, elem_type>>(params, stream, configure);
// } else {
// run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<128, 128, 64, 8, 4, 4, 4, false, false, elem_type>>(params, stream, configure);
// }
// }
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
run_mha_bwd_hdim128<cutlass::bfloat16_t>(params, stream, configure);
}
\ No newline at end of file
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_bwd_launch_template.h"
// template<>
// void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
// using elem_type = cutlass::half_t;
// if (params.h == params.h_k) {
// // run_flash_bwd<Flash_bwd_kernel_traits<128, 32, 128, 8, 2, 2, 2, false, false, elem_type>>(params, stream, configure);
// // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
// // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
// // run_flash_bwd<Flash_bwd_kernel_traits<128, 64, 128, 8, 2, 2, 2, false, false, elem_type>>(params, stream, configure);
// run_flash_bwd<Flash_bwd_kernel_traits<128, 64, 128, 8, 2, 4, 2, false, false, elem_type>>(params, stream, configure);
// // run_flash_bwd<Flash_bwd_kernel_traits<128, 64, 128, 8, 2, 4, 4, false, false, elem_type>>(params, stream, configure);
// // run_flash_bwd<Flash_bwd_kernel_traits<128, 128, 64, 8, 4, 4, 4, false, false, elem_type>>(params, stream, configure);
// } else {
// run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<128, 128, 64, 8, 4, 4, 4, false, false, elem_type>>(params, stream, configure);
// }
// }
template<>
void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
run_mha_bwd_hdim128<cutlass::half_t>(params, stream, configure);
}
\ No newline at end of file
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 160>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
run_mha_bwd_hdim160<cutlass::bfloat16_t>(params, stream, configure);
}
\ No newline at end of file
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::half_t, 160>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
run_mha_bwd_hdim160<cutlass::half_t>(params, stream, configure);
}
\ No newline at end of file
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 192>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
run_mha_bwd_hdim192<cutlass::bfloat16_t>(params, stream, configure);
}
\ No newline at end of file