gaoqiong / flash-attention · Commits

Commit 5a834254, authored Oct 08, 2023 by Tri Dao
Parent: 3a9fe7b0

    Change constexpr int to constexpr static int

Showing 3 changed files with 23 additions and 23 deletions:

  README.md                                          +4  -4
  csrc/flash_attn/src/flash_bwd_launch_template.h    +8  -8
  csrc/flash_attn/src/flash_fwd_launch_template.h    +11 -11
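The edit applied across both launch templates is the same one-keyword change on each function-local compile-time constant. A minimal standalone sketch of the before/after pattern (the function names below are illustrative, not the repository's):

```cpp
#include <cstdio>

// Before this commit: a function-local compile-time constant.
void launch_before() {
    constexpr int Headdim = 32;          // automatic storage duration
    std::printf("Headdim = %d\n", Headdim);
}

// After this commit: the same constant, additionally declared static.
void launch_after() {
    constexpr static int Headdim = 32;   // static storage duration, still constexpr
    std::printf("Headdim = %d\n", Headdim);
}

int main() {
    launch_before();
    launch_after();
    return 0;
}
```

The value and the constant's constexpr-ness are unchanged; `static` only gives the constant static storage duration, which in standard C++ for instance allows it to be named inside a lambda without being captured.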
README.md  (+4, -4)

@@ -198,7 +198,7 @@ includes QKV projection, output projection), see the MHA [implementation](https:
 ## Changelog
-### 2.0
+### 2.0: Complete rewrite, 2x faster
 Upgrading from FlashAttention (1.x) to FlashAttention-2
 These functions have been renamed:
@@ -214,7 +214,7 @@ flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
 ```python
 flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
 ```
-### 2.1
+### 2.1: Change behavior of causal flag
 If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the
 bottom right corner of the attention matrix, instead of the top-left corner.
@@ -243,7 +243,7 @@ v2.1:
     1 1
     1 1
 If the row of the mask is all zero, the output will be zero.
-### 2.2
+### 2.2: Optimize for inference
 Optimize for inference (iterative decoding) when query has very small sequence
 length (e.g., query sequence length = 1). The bottleneck here is to load KV
@@ -256,7 +256,7 @@ See the function `flash_attn_with_kvcache` with more features for inference
 Thanks to the xformers team, and in particular Daniel Haziza, for this
 collaboration.
-### 2.3
+### 2.3: Local (i.e., sliding window) attention
 Implement sliding window attention (i.e., local attention). Thanks to
 [Mistral AI](https://mistral.ai/) and in particular Timothée Lacroix for this
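As a side note on the v2.1 changelog context above: with `causal=True` and `seqlen_q != seqlen_k`, aligning the mask to the bottom-right corner means position (i, j) is kept exactly when `j - i <= seqlen_k - seqlen_q`. A small standalone check of that rule (illustrative only, not part of this diff):

```cpp
#include <cstdio>

// Print the bottom-right-aligned causal mask described in the v2.1 changelog
// entry above (1 = keep, 0 = masked out).
void print_causal_mask(int seqlen_q, int seqlen_k) {
    std::printf("seqlen_q=%d, seqlen_k=%d:\n", seqlen_q, seqlen_k);
    for (int i = 0; i < seqlen_q; ++i) {
        for (int j = 0; j < seqlen_k; ++j) {
            // The last query row always attends to every key position.
            std::printf("%d ", j - i <= seqlen_k - seqlen_q ? 1 : 0);
        }
        std::printf("\n");
    }
}

int main() {
    print_causal_mask(2, 5);  // rows: 1 1 1 1 0 / 1 1 1 1 1
    print_causal_mask(5, 2);  // rows 0-2 all zero, then 1 0 / 1 1
    return 0;
}
```

The all-zero rows in the second case are what the context line "If the row of the mask is all zero, the output will be zero." refers to.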
csrc/flash_attn/src/flash_bwd_launch_template.h  (+8, -8)

@@ -137,7 +137,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool con
 template<typename T>
 void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 32;
+    constexpr static int Headdim = 32;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -158,7 +158,7 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const boo
 template<typename T>
 void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 64;
+    constexpr static int Headdim = 64;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -201,7 +201,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const boo
 template<typename T>
 void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 96;
+    constexpr static int Headdim = 96;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -228,7 +228,7 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const boo
 template<typename T>
 void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 128;
+    constexpr static int Headdim = 128;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -264,7 +264,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bo
 template<typename T>
 void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 160;
+    constexpr static int Headdim = 160;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -281,7 +281,7 @@ void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bo
 template<typename T>
 void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 192;
+    constexpr static int Headdim = 192;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -298,7 +298,7 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bo
 template<typename T>
 void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 224;
+    constexpr static int Headdim = 224;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
     });
@@ -306,7 +306,7 @@ void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream, const bo
 template<typename T>
 void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 256;
+    constexpr static int Headdim = 256;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
csrc/flash_attn/src/flash_fwd_launch_template.h  (+11, -11)

@@ -104,7 +104,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     // We want kBlockM to be as small as possible for more parallelism.
     // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
     // If headdim is divisible by 64, then we set kBlockM = 8, etc.
-    constexpr int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
+    constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
     dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
     BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
         if (params.num_splits <= 2) {
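The comment in this hunk describes a divisibility heuristic for the combine kernel's block size. A standalone restatement that can be checked at compile time (illustrative only; in the repository the head dimension comes from `Kernel_traits::kHeadDim`):

```cpp
// Restatement of the kBlockM heuristic from the hunk above (illustrative only).
constexpr int block_m_for(int head_dim) {
    return head_dim % 128 == 0 ? 4 : (head_dim % 64 == 0 ? 8 : 16);
}

static_assert(block_m_for(128) == 4,  "headdim divisible by 128 -> kBlockM = 4");
static_assert(block_m_for(256) == 4,  "headdim divisible by 128 -> kBlockM = 4");
static_assert(block_m_for(64)  == 8,  "headdim divisible by 64  -> kBlockM = 8");
static_assert(block_m_for(96)  == 16, "otherwise                -> kBlockM = 16");

int main() { return 0; }
```

`grid_combine` then rounds `params.b * params.h * params.seqlen_q` up to a multiple of kBlockM (a ceiling division), so a smaller kBlockM yields more blocks and more parallelism.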
@@ -129,17 +129,17 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T, int Headdim>
 void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int kBlockM = 64;  // Fixed for all head dimensions
+    constexpr static int kBlockM = 64;  // Fixed for all head dimensions
     // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
     // and for headdim 192 with block size 64 x 128.
     // Also for headdim 160 with block size 64 x 128 after the rotary addition.
-    constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
+    constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
     run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
 }
 template<typename T>
 void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 32;
+    constexpr static int Headdim = 32;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
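Likewise, `run_mha_fwd_splitkv_dispatch` picks kBlockN from the head dimension (the TD comment records block sizes that made nvcc segfault). The same selection stated on its own (illustrative only):

```cpp
// Restatement of the split-KV kBlockN heuristic from the hunk above (illustrative only).
constexpr int block_n_for(int head_dim) {
    return head_dim <= 64 ? 256 : (head_dim <= 128 ? 128 : 64);
}

static_assert(block_n_for(32)  == 256, "headdim <= 64       -> kBlockN = 256");
static_assert(block_n_for(96)  == 128, "64 < headdim <= 128 -> kBlockN = 128");
static_assert(block_n_for(192) == 64,  "headdim > 128       -> kBlockN = 64");

int main() { return 0; }
```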
@@ -149,7 +149,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T>
 void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 64;
+    constexpr static int Headdim = 64;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
@@ -171,7 +171,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T>
 void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 96;
+    constexpr static int Headdim = 96;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -197,7 +197,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T>
 void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 128;
+    constexpr static int Headdim = 128;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -234,7 +234,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T>
 void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 160;
+    constexpr static int Headdim = 160;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -264,7 +264,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T>
 void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 192;
+    constexpr static int Headdim = 192;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
@@ -283,7 +283,7 @@ void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T>
 void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 224;
+    constexpr static int Headdim = 224;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -309,7 +309,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T>
 void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 256;
+    constexpr static int Headdim = 256;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_sm, max_smem_per_block;