Commit 5a834254 authored by Tri Dao

Change constexpr int to constexpr static int

parent 3a9fe7b0
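Most of the diff below applies this change uniformly to the head-dimension and block-size constants in the forward and backward kernel dispatch code (alongside a few changelog heading updates). As a minimal, self-contained sketch of the pattern being touched (not code from this repository; `Traits`, `boolean_switch`, and `dispatch` are invented stand-ins for the real kernel-traits templates and the `BOOL_SWITCH` macro), a function-local `constexpr static int` can be used exactly like the plain `constexpr int` it replaces, including as a non-type template argument inside a lambda:

```cpp
#include <cstdio>
#include <type_traits>

// Stand-in for a kernel-traits template parameterized on the head dimension.
template <int kHeadDim>
struct Traits {
    static constexpr int value = kHeadDim;
};

// Stand-in for the BOOL_SWITCH macro: forwards a runtime bool as a compile-time constant.
template <typename F>
void boolean_switch(bool cond, F&& f) {
    if (cond) { f(std::integral_constant<bool, true>{}); }
    else      { f(std::integral_constant<bool, false>{}); }
}

void dispatch(bool is_dropout) {
    // The commit changes declarations like this one from
    // `constexpr int` to `constexpr static int`.
    constexpr static int Headdim = 64;
    boolean_switch(is_dropout, [&](auto dropout) {
        // Headdim is still usable as a template argument inside the lambda.
        std::printf("Headdim=%d dropout=%d\n", Traits<Headdim>::value, dropout.value ? 1 : 0);
    });
}

int main() { dispatch(true); }
```

Marking the constant `static` gives it static storage duration; in both spellings it remains a constant expression, so instantiations such as `Traits<Headdim>` compile unchanged.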
@@ -198,7 +198,7 @@ includes QKV projection, output projection), see the MHA [implementation](https:
## Changelog
-### 2.0
+### 2.0: Complete rewrite, 2x faster
Upgrading from FlashAttention (1.x) to FlashAttention-2
These functions have been renamed:
@@ -214,7 +214,7 @@ flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
```
-### 2.1
+### 2.1: Change behavior of causal flag
If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the
bottom right corner of the attention matrix, instead of the top-left corner.
@@ -243,7 +243,7 @@ v2.1:
1 1
If the row of the mask is all zero, the output will be zero.
-### 2.2
+### 2.2: Optimize for inference
Optimize for inference (iterative decoding) when query has very small sequence
length (e.g., query sequence length = 1). The bottleneck here is to load KV
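For the 2.1 behavior above, a minimal sketch (illustrative only, not code from this repository; `print_causal_mask` is an invented helper) that reproduces the changelog's mask examples by aligning the causal mask to the bottom-right corner:

```cpp
#include <cstdio>

// Illustrative only: 1 = keep, 0 = masked out.
// v2.0 (top-left aligned):     keep if kv_idx <= q_idx
// v2.1 (bottom-right aligned): keep if kv_idx <= q_idx + seqlen_k - seqlen_q
void print_causal_mask(int seqlen_q, int seqlen_k, bool bottom_right) {
    for (int i = 0; i < seqlen_q; ++i) {
        for (int j = 0; j < seqlen_k; ++j) {
            int limit = bottom_right ? i + seqlen_k - seqlen_q : i;
            std::printf("%d ", j <= limit ? 1 : 0);
        }
        std::printf("\n");
    }
}

int main() {
    // seqlen_q = 2, seqlen_k = 5: prints "1 1 1 1 0" / "1 1 1 1 1",
    // matching the v2.1 example in the changelog.
    print_causal_mask(2, 5, /*bottom_right=*/true);
    // seqlen_q = 5, seqlen_k = 2: the first three rows are all zero in v2.1.
    print_causal_mask(5, 2, /*bottom_right=*/true);
}
```

With bottom-right alignment, a query row whose mask is entirely zero (as in the seqlen_q = 5, seqlen_k = 2 case) produces an all-zero output row, as the changelog notes.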
@@ -256,7 +256,7 @@ See the function `flash_attn_with_kvcache` with more features for inference
Thanks to the xformers team, and in particular Daniel Haziza, for this
collaboration.
-### 2.3
+### 2.3: Local (i.e., sliding window) attention
Implement sliding window attention (i.e., local attention). Thanks to [Mistral
AI](https://mistral.ai/) and in particular Timothée Lacroix for this
...
@@ -137,7 +137,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool con
template<typename T>
void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 32;
+    constexpr static int Headdim = 32;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
@@ -158,7 +158,7 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const boo
template<typename T>
void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 64;
+    constexpr static int Headdim = 64;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
@@ -201,7 +201,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const boo
template<typename T>
void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 96;
+    constexpr static int Headdim = 96;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
@@ -228,7 +228,7 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const boo
template<typename T>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 128;
+    constexpr static int Headdim = 128;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
@@ -264,7 +264,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bo
template<typename T>
void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 160;
+    constexpr static int Headdim = 160;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
@@ -281,7 +281,7 @@ void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bo
template<typename T>
void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 192;
+    constexpr static int Headdim = 192;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
@@ -298,7 +298,7 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bo
template<typename T>
void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 224;
+    constexpr static int Headdim = 224;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
    });
@@ -306,7 +306,7 @@ void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream, const bo
template<typename T>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 256;
+    constexpr static int Headdim = 256;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
...
@@ -104,7 +104,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    // We want kBlockM to be as small as possible for more parallelism.
    // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
    // If headdim is divisible by 64, then we set kBlockM = 8, etc.
-    constexpr int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
+    constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
    dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
    BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
        if (params.num_splits <= 2) {
@@ -129,17 +129,17 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T, int Headdim>
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int kBlockM = 64; // Fixed for all head dimensions
+    constexpr static int kBlockM = 64; // Fixed for all head dimensions
    // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
    // and for headdim 192 with block size 64 x 128.
    // Also for headdim 160 with block size 64 x 128 after the rotary addition.
-    constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
+    constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
    run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
}
template<typename T>
void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 32;
+    constexpr static int Headdim = 32;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
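`run_flash_splitkv_fwd` above processes the key/value sequence in several splits and then launches a separate combine kernel (its grid is `grid_combine`, sized by `kBlockM`) to merge the per-split results. The sketch below is a rough CPU-side illustration of that merge for a single query row, under the assumption that each split's partial output is already normalized by its own softmax denominator; `PartialResult` and `combine_splits` are invented names, not the repository's API:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Illustrative only. Each split holds a partial attention output for one query
// row (normalized within its own split) and the log-sum-exp of its scores.
struct PartialResult {
    std::vector<float> out;  // length = head dimension
    float lse;               // log(sum_j exp(score_j)) over this split's keys
};

// Merge per-split results into the full-attention output for that row:
// out = sum_s exp(lse_s - lse_total) * out_s, lse_total = logsumexp_s(lse_s).
// Assumes at least one split.
std::vector<float> combine_splits(const std::vector<PartialResult>& splits) {
    float max_lse = -INFINITY;
    for (const auto& s : splits) max_lse = std::max(max_lse, s.lse);
    float sum_exp = 0.f;
    for (const auto& s : splits) sum_exp += std::exp(s.lse - max_lse);
    float lse_total = max_lse + std::log(sum_exp);

    std::vector<float> out(splits[0].out.size(), 0.f);
    for (const auto& s : splits) {
        float scale = std::exp(s.lse - lse_total);
        for (size_t d = 0; d < out.size(); ++d) out[d] += scale * s.out[d];
    }
    return out;
}
```

Each partial is reweighted by exp(lse_s - lse_total), i.e., the share of the global softmax mass held by split s, so the merged result equals attention over all keys. Splitting the KV sequence this way is what lets short-query (decoding) workloads use more thread blocks, at the cost of this extra combine pass.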
@@ -149,7 +149,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 64;
+    constexpr static int Headdim = 64;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            if constexpr(!Is_dropout) {
@@ -171,7 +171,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 96;
+    constexpr static int Headdim = 96;
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -197,7 +197,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 128;
+    constexpr static int Headdim = 128;
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -234,7 +234,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T>
void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 160;
+    constexpr static int Headdim = 160;
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -264,7 +264,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T>
void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 192;
+    constexpr static int Headdim = 192;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            if constexpr(!Is_dropout) {
@@ -283,7 +283,7 @@ void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T>
void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 224;
+    constexpr static int Headdim = 224;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
@@ -309,7 +309,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 256;
+    constexpr static int Headdim = 256;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_sm, max_smem_per_block;
...
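Several of the dispatch functions above pick kernel block sizes based on the device's shared-memory limits (the `max_smem_per_block` / `max_smem_per_sm` variables and the `is_sm8x` check). A standalone sketch of that query using standard CUDA runtime attributes (the printing and the branching comment are illustrative additions, not repository code):

```cpp
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int device;
    cudaGetDevice(&device);

    // Maximum dynamic shared memory a single block may opt in to, and the
    // total shared memory available per multiprocessor.
    int max_smem_per_block = 0, max_smem_per_sm = 0;
    cudaDeviceGetAttribute(&max_smem_per_block,
                           cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    cudaDeviceGetAttribute(&max_smem_per_sm,
                           cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);

    std::printf("smem per block (opt-in): %d bytes\n", max_smem_per_block);
    std::printf("smem per SM:             %d bytes\n", max_smem_per_sm);

    // A dispatcher can branch on these values to pick smaller tiles on parts
    // with less shared memory and larger tiles where more is available.
    return 0;
}
```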