gaoqiong / flash-attention · Commits · 5a834254

Commit 5a834254, authored Oct 08, 2023 by Tri Dao

    Change constexpr int to constexpr static int

Parent: 3a9fe7b0

Showing 3 changed files with 23 additions and 23 deletions (+23 / -23):

* README.md (+4 / -4)
* csrc/flash_attn/src/flash_bwd_launch_template.h (+8 / -8)
* csrc/flash_attn/src/flash_fwd_launch_template.h (+11 / -11)
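The commit message does not spell out the motivation, so the following is a hedged illustration only. The sketch below (hypothetical names `launch`, `run_hdim32`, `dispatch`; this is not code from the repository) shows the pattern being changed: each launcher defines a head-dimension constant and then dispatches through lambdas. Giving the constant static storage duration lets the lambdas name it without capturing it, a common workaround for ODR-use/capture quirks that some toolchains (notably CUDA extended-lambda builds) have with non-static `constexpr` locals.

```cpp
// Hedged sketch of the pattern this commit touches (hypothetical names;
// not code from the repository).
#include <cstdio>
#include <type_traits>

template <int kHeadDim, bool kIsDropout>
void launch() {  // stand-in for the real run_flash_fwd / run_flash_bwd
    std::printf("Headdim=%d, dropout=%d\n", kHeadDim, int(kIsDropout));
}

void run_hdim32(float p_dropout) {
    // Before the commit: constexpr int Headdim = 32;
    // After the commit:  constexpr static int Headdim = 32;
    // With static storage duration, the lambda below can refer to Headdim
    // without capturing it.
    constexpr static int Headdim = 32;

    auto dispatch = [&](auto is_dropout) {
        launch<Headdim, decltype(is_dropout)::value>();
    };
    if (p_dropout < 1.f) { dispatch(std::true_type{}); }
    else                 { dispatch(std::false_type{}); }
}

int main() { run_hdim32(0.1f); }
```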
README.md

````diff
@@ -198,7 +198,7 @@ includes QKV projection, output projection), see the MHA [implementation](https:
 ## Changelog
-### 2.0: Complete rewrite, 2x faster
+### 2.0: Complete rewrite, 2x faster
 Upgrading from FlashAttention (1.x) to FlashAttention-2
 These functions have been renamed:
@@ -214,7 +214,7 @@ flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
 ```python
 flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
 ```
-### 2.1: Change behavior of causal flag
+### 2.1: Change behavior of causal flag
 If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the
 bottom right corner of the attention matrix, instead of the top-left corner.
@@ -243,7 +243,7 @@ v2.1:
     1 1
 If the row of the mask is all zero, the output will be zero.
-### 2.2: Optimize for inference
+### 2.2: Optimize for inference
 Optimize for inference (iterative decoding) when query has very small sequence
 length (e.g., query sequence length = 1). The bottleneck here is to load KV
@@ -256,7 +256,7 @@ See the function `flash_attn_with_kvcache` with more features for inference
 Thanks to the xformers team, and in particular Daniel Haziza, for this
 collaboration.
-### 2.3: Local (i.e., sliding window) attention
+### 2.3: Local (i.e., sliding window) attention
 Implement sliding window attention (i.e., local attention). Thanks to
 [Mistral AI](https://mistral.ai/) and in particular Timothée Lacroix for this
````
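One of the changed README headings above ("### 2.1: Change behavior of causal flag") concerns aligning the causal mask to the bottom-right corner of the attention matrix when seqlen_q != seqlen_k. As a small illustration of that documented behavior (standalone C++, not library code; the 0-indexed predicate is my reading of "bottom-right aligned"):

```cpp
// Bottom-right-aligned causal mask (FlashAttention >= 2.1 semantics).
// 1 = position is attended to, 0 = masked out.
#include <cstdio>

bool causal_keep(int i, int j, int seqlen_q, int seqlen_k) {
    // Query row i may attend to key column j iff the mask lines up with the
    // bottom-right corner of the seqlen_q x seqlen_k matrix.
    return j <= i + (seqlen_k - seqlen_q);
}

int main() {
    const int seqlen_q = 2, seqlen_k = 5;
    for (int i = 0; i < seqlen_q; ++i) {
        for (int j = 0; j < seqlen_k; ++j) {
            std::printf("%d ", causal_keep(i, j, seqlen_q, seqlen_k) ? 1 : 0);
        }
        std::printf("\n");
    }
    // Prints (bottom-right aligned, v2.1+):
    // 1 1 1 1 0
    // 1 1 1 1 1
    // whereas top-left alignment (v2.0 and earlier) would give:
    // 1 0 0 0 0
    // 1 1 0 0 0
}
```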
csrc/flash_attn/src/flash_bwd_launch_template.h

```diff
@@ -137,7 +137,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool con
 template<typename T>
 void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 32;
+    constexpr static int Headdim = 32;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
```
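Each backward launcher in these hunks queries the current device and declares `max_smem_per_block` before the hunk is cut off; the lines hidden behind the fold are not reproduced here. Purely for orientation, the snippet below is a generic CUDA sketch of how such an opt-in per-block shared-memory limit is usually queried (an assumption about intent, not the repository's code):

```cpp
// Generic sketch: query how much dynamic shared memory a block may opt in to
// on the current device. Standalone illustration, not the repository's code.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int device = 0;
    cudaGetDevice(&device);

    int max_smem_per_block = 0;
    cudaDeviceGetAttribute(&max_smem_per_block,
                           cudaDevAttrMaxSharedMemoryPerBlockOptin, device);

    std::printf("max opt-in shared memory per block: %d bytes\n", max_smem_per_block);
    return 0;
}
```

Kernels that want more than the default 48 KB of dynamic shared memory additionally have to opt in before launch via `cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)`.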
```diff
@@ -158,7 +158,7 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const boo
 template<typename T>
 void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 64;
+    constexpr static int Headdim = 64;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -201,7 +201,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const boo
 template<typename T>
 void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 96;
+    constexpr static int Headdim = 96;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -228,7 +228,7 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const boo
 template<typename T>
 void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 128;
+    constexpr static int Headdim = 128;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -264,7 +264,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bo
 template<typename T>
 void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 160;
+    constexpr static int Headdim = 160;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -281,7 +281,7 @@ void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bo
 template<typename T>
 void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 192;
+    constexpr static int Headdim = 192;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -298,7 +298,7 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bo
 template<typename T>
 void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 224;
+    constexpr static int Headdim = 224;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
     });
```
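The hdim224 hunk above shows the dispatch idiom used throughout these launchers: `BOOL_SWITCH` turns the runtime condition `params.p_dropout < 1.f` into a compile-time constant `Is_dropout` and instantiates `run_flash_bwd` for it. The macro itself lives in the repository (static_switch.h) and is not part of this diff; the version below is only a representative sketch of the technique, not necessarily the exact macro text.

```cpp
// Representative sketch of a BOOL_SWITCH-style macro: turn a runtime bool into
// a compile-time constant by instantiating the body for both values.
// Illustration only; see csrc/flash_attn/src/static_switch.h for the real one.
#include <cstdio>

#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            constexpr static bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            constexpr static bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()

template <bool Is_dropout>
void run_kernel() { std::printf("Is_dropout = %d\n", int(Is_dropout)); }

int main() {
    float p_dropout = 0.1f;
    BOOL_SWITCH(p_dropout < 1.f, Is_dropout, [&] { run_kernel<Is_dropout>(); });
}
```

Because the macro body is a lambda, constants defined in the enclosing launcher (such as `Headdim`) end up being referenced from inside lambdas, which is plausibly why giving them static storage duration is attractive.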
```diff
@@ -306,7 +306,7 @@ void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream, const bo
 template<typename T>
 void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 256;
+    constexpr static int Headdim = 256;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
```
csrc/flash_attn/src/flash_fwd_launch_template.h

```diff
@@ -104,7 +104,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     // We want kBlockM to be as small as possible for more parallelism.
     // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
     // If headdim is divisible by 64, then we set kBlockM = 8, etc.
-    constexpr int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
+    constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
     dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
     BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
         if (params.num_splits <= 2) {
```
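As a worked example of the heuristic and grid arithmetic in the hunk above (a standalone sketch with made-up example sizes, not the repository's code): 128 threads loading 512 elements per step cover 512 / headdim rows at a time, so a head dimension divisible by 128 gives 4 rows per block, one divisible by 64 gives 8, otherwise 16; the combine grid is then a ceiling division of `b * h * seqlen_q` rows by `kBlockM`.

```cpp
// Worked illustration of the kBlockM heuristic and the ceiling-division grid
// size from the hunk above. Standalone sketch, not the repository's code.
#include <cstdio>

constexpr int choose_kBlockM(int kHeadDim) {
    // 128 threads load 512 elements per step, i.e. 512 / kHeadDim rows at a time.
    return kHeadDim % 128 == 0 ? 4 : (kHeadDim % 64 == 0 ? 8 : 16);
}

static_assert(choose_kBlockM(128) == 4, "headdim divisible by 128 -> 4 rows");
static_assert(choose_kBlockM(64)  == 8, "headdim divisible by 64 -> 8 rows");
static_assert(choose_kBlockM(96)  == 16, "otherwise -> 16 rows");

int main() {
    // Example: batch 2, 16 heads, query length 1 (iterative decoding).
    const int b = 2, h = 16, seqlen_q = 1;
    const int kBlockM = choose_kBlockM(128);
    // Same ceiling division as: dim3 grid_combine((b*h*seqlen_q + kBlockM - 1) / kBlockM);
    const int grid_combine = (b * h * seqlen_q + kBlockM - 1) / kBlockM;
    std::printf("kBlockM = %d, grid_combine = %d blocks\n", kBlockM, grid_combine);  // 4, 8
}
```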
```diff
@@ -129,17 +129,17 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T, int Headdim>
 void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int kBlockM = 64;  // Fixed for all head dimensions
+    constexpr static int kBlockM = 64;  // Fixed for all head dimensions
     // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
     // and for headdim 192 with block size 64 x 128.
     // Also for headdim 160 with block size 64 x 128 after the rotary addition.
-    constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
+    constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
     run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
 }
 
 template<typename T>
 void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 32;
+    constexpr static int Headdim = 32;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
```
```diff
@@ -149,7 +149,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T>
 void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 64;
+    constexpr static int Headdim = 64;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
```
```diff
@@ -171,7 +171,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T>
 void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 96;
+    constexpr static int Headdim = 96;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
```
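The hdim96 hunk above also shows the launchers reading device properties: `is_sm8x` is true on compute capability 8.x parts other than 8.0, i.e. consumer Ampere/Ada (sm86/sm89), which have less shared memory per SM than A100's sm80, so the launch configuration is chosen accordingly. A minimal sketch of that check, with PyTorch's `at::cuda::getCurrentDeviceProperties()` replaced by the plain CUDA runtime (illustration only):

```cpp
// Minimal sketch of the is_sm8x check using the plain CUDA runtime instead of
// PyTorch's at::cuda::getCurrentDeviceProperties(). Illustration only.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int device = 0;
    cudaGetDevice(&device);

    cudaDeviceProp props{};
    cudaGetDeviceProperties(&props, device);

    // 8.0 is A100; other 8.x parts (e.g. 8.6 / 8.9) have a smaller shared-memory
    // budget, so launchers typically pick smaller tiles on them.
    bool is_sm8x = props.major == 8 && props.minor > 0;
    std::printf("sm%d%d, is_sm8x = %d\n", props.major, props.minor, int(is_sm8x));
}
```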
```diff
@@ -197,7 +197,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T>
 void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 128;
+    constexpr static int Headdim = 128;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -234,7 +234,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T>
 void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 160;
+    constexpr static int Headdim = 160;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -264,7 +264,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T>
 void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 192;
+    constexpr static int Headdim = 192;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
@@ -283,7 +283,7 @@ void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T>
 void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 224;
+    constexpr static int Headdim = 224;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -309,7 +309,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T>
 void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 256;
+    constexpr static int Headdim = 256;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_sm, max_smem_per_block;
```