gaoqiong / flash-attention · Commits · 083e8f52

Implement local attention

Authored Sep 24, 2023 by Tri Dao
Co-authored-by: Timothee Lacroix <t@mistral.ai>
Parent: 4c8ff915

Showing 9 changed files with 706 additions and 255 deletions (+706 -255)
Files changed:
  csrc/flash_attn/flash_api.cpp                    +45  -13
  csrc/flash_attn/src/flash.h                      +3   -0
  csrc/flash_attn/src/flash_bwd_kernel.h           +59  -7
  csrc/flash_attn/src/flash_bwd_launch_template.h  +16  -12
  csrc/flash_attn/src/flash_fwd_kernel.h           +69  -45
  csrc/flash_attn/src/flash_fwd_launch_template.h  +40  -34
  csrc/flash_attn/src/softmax.h                    +17  -6
  flash_attn/flash_attn_interface.py               +159 -25
  tests/test_flash_attn.py                         +298 -113
csrc/flash_attn/flash_api.cpp
...
...
@@ -40,7 +40,8 @@ void set_params_fprop(Flash_fwd_params &params,
                       void *softmax_lse_d,
                       float p_dropout,
                       float softmax_scale,
-                      bool is_causal) {
+                      int window_size_left,
+                      int window_size_right) {

     // Reset the parameters
     memset(&params, 0, sizeof(params));
...
...
@@ -105,7 +106,15 @@ void set_params_fprop(Flash_fwd_params &params,
     params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
     TORCH_CHECK(p_dropout < 1.f);

-    params.is_causal = is_causal;
+    // Causal is the special case where window_size_right == 0 and window_size_left < 0.
+    // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
+    params.is_causal = window_size_left < 0 && window_size_right == 0;
+
+    if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }
+    if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }
+    params.window_size_left = window_size_left;
+    params.window_size_right = window_size_right;

     params.is_seqlens_k_cumulative = true;
 }
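For orientation while reading the rest of the diff, here is a minimal Python sketch (illustration only, not part of the commit) of how the new `window_size_left` / `window_size_right` pair relates to the old `is_causal` flag in `set_params_fprop`: causal attention is the special case of an unbounded left window with a right width of 0, and a negative width on either side is normalized to "unbounded" by widening it to `seqlen_k`.

    # Hedged sketch of the window-size normalization above; seqlen_k is the key sequence length.
    def normalize_window(window_size_left, window_size_right, seqlen_k):
        # Causal is the special case: unbounded left window and a right window of width 0.
        is_causal = window_size_left < 0 and window_size_right == 0
        # A negative width means "unbounded" on that side; it is widened to seqlen_k.
        if window_size_left < 0 and window_size_right >= 0:
            window_size_left = seqlen_k
        if window_size_left >= 0 and window_size_right < 0:
            window_size_right = seqlen_k
        return is_causal, window_size_left, window_size_right

    # (-1, 0) is causal, (-1, -1) is full attention, (128, 0) is a causal sliding window.
    assert normalize_window(-1, 0, 4096) == (True, 4096, 0)
    assert normalize_window(-1, -1, 4096) == (False, -1, -1)
    assert normalize_window(128, 0, 4096) == (False, 128, 0)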
...
...
@@ -138,7 +147,8 @@ void set_params_dgrad(Flash_bwd_params &params,
                       void *dsoftmax_sum_d,
                       float p_dropout,
                       float softmax_scale,
-                      bool is_causal) {
+                      int window_size_left,
+                      int window_size_right) {

     set_params_fprop(params,
                      b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
...
@@ -149,7 +159,8 @@ void set_params_dgrad(Flash_bwd_params ¶ms,
softmax_lse_d
,
p_dropout
,
softmax_scale
,
is_causal
);
window_size_left
,
window_size_right
);
// Set the pointers and strides.
params
.
do_ptr
=
dout
.
data_ptr
();
...
...
@@ -242,6 +253,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         const float p_dropout,
         const float softmax_scale,
         bool is_causal,
+        const int window_size_left,
+        int window_size_right,
         const bool return_softmax,
         c10::optional<at::Generator> gen_) {
...
...
@@ -281,10 +294,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

     if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
+    if (is_causal) { window_size_right = 0; }

     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size_og % 8 == 0;
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0;
     if (seqlenq_ngroups_swapped) {
         const int ngroups = num_heads / num_heads_k;
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
...
...
@@ -353,7 +367,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
                      softmax_lse.data_ptr(),
                      p_dropout,
                      softmax_scale,
-                     is_causal);
+                     window_size_left,
+                     window_size_right);

     // This needs to match with run_mha_fwd_splitkv_dispatch
     const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
...
...
@@ -421,9 +436,12 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                const float softmax_scale,
                const bool zero_tensors,
                const bool is_causal,
+               const int window_size_left,
+               int window_size_right,
                const bool return_softmax,
                c10::optional<at::Generator> gen_) {

+    if (is_causal) { window_size_right = 0; }
+
     auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
...
...
@@ -534,7 +552,8 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                      softmax_lse.data_ptr(),
                      p_dropout,
                      softmax_scale,
-                     is_causal);
+                     window_size_left,
+                     window_size_right);

     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
...
...
@@ -600,8 +619,12 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
         const float p_dropout,         // probability to drop
         const float softmax_scale,
         const bool is_causal,
+        const int window_size_left,
+        int window_size_right,
         c10::optional<at::Generator> gen_,
         c10::optional<at::Tensor> &rng_state) {

+    if (is_causal) { window_size_right = 0; }
+
     auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
...
...
@@ -748,7 +771,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
                      softmax_d.data_ptr(),
                      p_dropout,
                      softmax_scale,
-                     is_causal);
+                     window_size_left,
+                     window_size_right);

     auto launch = &run_mha_bwd;
     // launch(params, stream, /*configure=*/true);
...
...
@@ -804,9 +828,12 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
                const float softmax_scale,
                const bool zero_tensors,
                const bool is_causal,
+               const int window_size_left,
+               int window_size_right,
                c10::optional<at::Generator> gen_,
                c10::optional<at::Tensor> &rng_state) {

+    if (is_causal) { window_size_right = 0; }
+
     auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
...
...
@@ -969,7 +996,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
                      softmax_d.data_ptr(),
                      p_dropout,
                      softmax_scale,
-                     is_causal);
+                     window_size_left,
+                     window_size_right);

     auto launch = &run_mha_bwd;
     // launch(params, stream, /*configure=*/true);
...
...
@@ -1019,6 +1047,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                 c10::optional<at::Tensor> &out_,   // batch_size x seqlen_q x num_heads x head_size
                 const float softmax_scale,
                 bool is_causal,
+                const int window_size_left,
+                int window_size_right,
                 bool is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
                 int num_splits) {
...
...
@@ -1059,10 +1089,11 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

     if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
+    if (is_causal) { window_size_right = 0; }

     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && head_size_og % 8 == 0;
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0;
     if (seqlenq_ngroups_swapped) {
         const int ngroups = num_heads / num_heads_k;
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
...
...
@@ -1125,7 +1156,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                      softmax_lse.data_ptr(),
                      /*p_dropout=*/0.f,
                      softmax_scale,
-                     is_causal);
+                     window_size_left,
+                     window_size_right);

     at::Tensor k, v, k_padded, v_padded;
     if (k_.has_value()) {
...
...
csrc/flash_attn/src/flash.h
...
...
@@ -105,6 +105,9 @@ struct Flash_fwd_params : public Qkv_params {
     float rp_dropout;
     float scale_softmax_rp_dropout;

+    // Local window size
+    int window_size_left, window_size_right;
+
     // Random state.
     at::PhiloxCudaState philox_args;
...
...
csrc/flash_attn/src/flash_bwd_kernel.h
...
...
@@ -422,7 +422,7 @@ inline __device__ void convert_dKV(const Params &params) {

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
 inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {

     using Element = typename Kernel_traits::Element;
...
...
@@ -447,6 +447,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     if (n_block * kBlockN >= binfo.actual_seqlen_k || binfo.actual_seqlen_q == 0) return;

     int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
+    if (Is_local) {
+        m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM));
+    }

     const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
         + (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
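A hedged Python sketch (illustration only, not from the commit) of the backward-pass block-range arithmetic above: for a fixed block of K/V columns starting at n_block * kBlockN, local attention bounds the row blocks of Q that can attend to it, so the backward loop over rows is clamped on both ends, and an empty range means that K/V block receives zero gradient.

    # Assumed helper names for illustration; mirrors the m_block_max / m_block_min clamping above.
    def ceil_div(a, b):
        return (a + b - 1) // b

    def bwd_row_block_range(n_block, kBlockM, kBlockN, seqlen_q, seqlen_k,
                            window_left, window_right, is_causal, is_local):
        m_block_max = ceil_div(seqlen_q, kBlockM)
        if is_local:
            # Rows further than window_left below the diagonal never see this K/V block.
            m_block_max = min(m_block_max,
                              ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + window_left, kBlockM))
        if is_causal or is_local:
            # Rows more than window_right above the diagonal never see this K/V block either.
            m_block_min = max(0, (n_block * kBlockN + seqlen_q - seqlen_k - window_right) // kBlockM)
        else:
            m_block_min = 0
        return m_block_min, m_block_max  # empty range => write zeros to dK/dV and exit early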
...
...
@@ -655,14 +658,53 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded;

     int m_block = m_block_max - 1;
-    int m_block_min = !Is_causal ? 0 : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k) / kBlockM);
-    // We're guaranteed that m_block_min <= m_block:
+    int m_block_min = (!Is_causal && !Is_local)
+        ? 0
+        : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right) / kBlockM);
+    // If not local, we're guaranteed that m_block_min <= m_block:
     // We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case,
     // n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q.
     // So m_block_min <= (actual_seqlen_q - 1) / kBlockM.
     // Recall that m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM) = (actual_seqlen_q + kBlockM - 1) / kBlockM.
     // So m_block_m - 1 = (actual_seqlen_q - 1) / kBlockM.
     // We conclude that m_block_min <= m_block, so we will always have at least 1 iteration of the for loop.
+    // However, if local, then this possible to have some blocks of K & V not attending to any query.
+    // We might need to exit early and write 0 to dK and dV for those blocks.
+    // Otherwise we get wrong result for the case where we don't enter the for loop.
+    // And we might read OOB elements from gQ and gdO.
+    if (Is_local && m_block < m_block_min) {
+        const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
+            + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
+        const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
+            + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
+        Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
+                                 Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                                 make_stride(params.dk_row_stride, _1{}));
+        Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
+                                 Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                                 make_stride(params.dv_row_stride, _1{}));
+        typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
+        auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
+        Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
+        Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
+        Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
+        Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
+        clear(tdKrdK);
+        clear(tdVrdV);
+        Tensor cdKV = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK)));  // (BLK_N,BLK_K) -> (blk_n,blk_k)
+        Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
+        Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
+        #pragma unroll
+        for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
+        // Clear_OOB_K must be false since we don't want to write zeros to gmem
+        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+            gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
+        );
+        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+            gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
+        );
+        return;
+    }

     if (Double_buffer && m_block % 2 == 1) {  // Double buffer for sQ
         tQsQ.data() = tQsQ.data() + size(sQ);
...
...
@@ -777,12 +819,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // However, it's possible that the values in acc_s are so large that they overflow
         // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ.
         // So we need to mask out the elements beyond actual_seqlen_k.
-        if (!Is_causal) {
+        if (!Is_causal && !Is_local) {
             if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) {
                 flash::apply_mask(scores, binfo.actual_seqlen_k,
                                   n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16);
             }
-        } else {
+        } else if (Is_causal) {
             // Putting this causal masking right after acc_s is *much* slower for some reason.
             // TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short
             // (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking.
...
...
@@ -795,6 +837,16 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                 //                         binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
                                           AtomLayoutMS * 16);
             }
-        }
+        } else if (Is_local) {
+            if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right
+                || (m_block + 1) * kBlockM >= n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left
+                || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
+                flash::apply_mask_local(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
+                                        binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
+                                        binfo.actual_seqlen_q, AtomLayoutMS * 16,
+                                        params.window_size_left, params.window_size_right);
+            }
+        }
         // if (cute::thread(32, 0)) { print(scores); }

         // Compute the exponential value.
...
...
@@ -1510,7 +1562,7 @@ inline __device__ void compute_dq_dk_dv(const Params &params) {

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, typename Params>
 inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {

     const int n_block = blockIdx.x;
...
...
@@ -1519,7 +1571,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
     // The block index for the head.
     const int bidh = blockIdx.z;

-    compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
+    compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////
...
...
csrc/flash_attn/src/flash_bwd_launch_template.h
...
...
@@ -23,9 +23,10 @@ __global__ void flash_bwd_dq_dk_dv_loop_kernel(Flash_bwd_params params) {
     flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K>(params);
 }

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K>
 __global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) {
-    flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K>(params);
+    static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
+    flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K>(params);
 }

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K>
...
...
@@ -62,16 +63,19 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
     BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
         BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
-                auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
-                // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
-                if (smem_size_dq_dk_dv >= 48 * 1024)  {
-                    C10_CUDA_CHECK(cudaFuncSetAttribute(
-                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
-                }
-                kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
-                C10_CUDA_KERNEL_LAUNCH_CHECK();
+                BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] {
+                    // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                    // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
+                    // If Is_local, set Is_causal to false
+                    auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
+                    // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
+                    if (smem_size_dq_dk_dv >= 48 * 1024)  {
+                        C10_CUDA_CHECK(cudaFuncSetAttribute(
+                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
+                    }
+                    kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
+                    C10_CUDA_KERNEL_LAUNCH_CHECK();
+                });
             });
         });
     });
...
...
csrc/flash_attn/src/flash_fwd_kernel.h
...
...
@@ -71,7 +71,7 @@ inline __device__ void write_softmax_to_gmem(

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {

     using Element = typename Kernel_traits::Element;
...
...
@@ -93,16 +93,17 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
     if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;

+    const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
     int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
-    if (Is_causal) {
+    if (Is_causal || Is_local) {
         n_block_max = std::min(n_block_max,
-                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN));
+                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
         // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
         //     printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
         // }
         // We exit early and write 0 to gO and gLSE.
         // Otherwise we might read OOB elements from gK and gV.
-        if (n_block_max <= 0) {
+        if (n_block_max <= n_block_min) {
            // Save seed and offset for backward. If we don't have this here, the 0-th thread block might
            // exit early and no one saves the rng state.
            if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
...
...
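A hedged Python sketch (not part of the diff) of the forward-pass column-block range computed in the hunk above: each row block of queries only iterates over the K/V blocks that intersect its window, and when that range is empty the kernel exits early, writing zeros to the output and LSE.

    # Illustrative only; follows the n_block_min / n_block_max formulas above.
    def ceil_div(a, b):
        return (a + b - 1) // b

    def fwd_col_block_range(m_block, kBlockM, kBlockN, seqlen_q, seqlen_k,
                            window_left, window_right, is_causal, is_local):
        n_block_min = 0
        if is_local:
            n_block_min = max(0, (m_block * kBlockM + seqlen_k - seqlen_q - window_left) // kBlockN)
        n_block_max = ceil_div(seqlen_k, kBlockN)
        if is_causal or is_local:
            n_block_max = min(n_block_max,
                              ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q + window_right, kBlockN))
        return n_block_min, n_block_max  # n_block_max <= n_block_min triggers the early exit above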
@@ -145,6 +146,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             return;
         }
     }
+    // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); }

     // We iterate over the blocks in reverse order. This is because the last block is the only one
     // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
...
...
@@ -326,9 +328,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
     // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
-    constexpr int n_masking_steps = !Is_causal
+    constexpr int n_masking_steps = (!Is_causal && !Is_local)
         ? 1
-        : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
+        : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
     #pragma unroll
     for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
...
...
@@ -356,11 +358,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi

         // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
-        // if (cute::thread0()) { print(scores); }
+        // if (cute::thread0()) { print_tensor(scores); }

         // We don't put the masking before the matmul S = Q K^T because we don't clear sK
         // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
         // can produce Inf / NaN.
-        if (!Is_causal) {
+        if (!Is_causal && !Is_local) {
             if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
         } else {
             // Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n)
...
...
@@ -374,18 +376,21 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             // Idk why it's get<1> and not get<0> of the stride.
             // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
             // I can't get the stride from idx_row
-            flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
-                                     // m_block * kBlockM + get<0>(idx_row(0)),
-                                     m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                                     binfo.actual_seqlen_q, kNWarps * 16);
-                                     // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
-                                     // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
+            flash::apply_mask_local</*HasWSLeft=*/Is_local>(scores, n_block * kBlockN, binfo.actual_seqlen_k,
+                                                            // m_block * kBlockM + get<0>(idx_row(0)),
+                                                            m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                                                            binfo.actual_seqlen_q, kNWarps * 16,
+                                                            params.window_size_left, params.window_size_right
+                                                            // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16
+                                                            // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16
+                                                            );
+            // if (cute::thread0()) { print_tensor(scores); }
         }

         flash::cp_async_wait<0>();
         __syncthreads();
-        if (n_block > 0) {
+        if (n_block > n_block_min) {
             // Advance gK
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
             flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
...
...
@@ -396,8 +401,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // TODO: when we have key_padding_mask we'll need to Check_inf
         masking_step == 0
-            ? softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
-            : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+            ? softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
+            : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);

         // Convert scores from fp32 to fp16/bf16
         Tensor rP = flash::convert_type<Element>(scores);
...
...
@@ -426,14 +431,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // if (cute::thread0()) { print(scores); }

         // This check is at the end of the loop since we always have at least 1 iteration
-        if (n_masking_steps > 1 && n_block <= 0) {
+        if (n_masking_steps > 1 && n_block <= n_block_min) {
             --n_block;
             break;
         }
     }

     // These are the iterations where we don't need masking on S
-    for (; n_block >= 0; --n_block) {
+    for (; n_block >= n_block_min; --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
         clear(acc_s);
         flash::cp_async_wait<0>();
...
...
@@ -450,7 +455,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         flash::cp_async_wait<0>();
         __syncthreads();
-        if (n_block > 0) {
+        if (n_block > n_block_min) {
             // Advance gK
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
             flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
...
...
@@ -461,7 +466,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi

         // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
-        softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+        if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
+            flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k,
+                                    m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                                    binfo.actual_seqlen_q, kNWarps * 16,
+                                    params.window_size_left, params.window_size_right);
+        }
+        softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);

         Tensor rP = flash::convert_type<Element>(scores);
         // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
...
...
@@ -568,7 +581,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
 inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {

     using Element = typename Kernel_traits::Element;
...
...
@@ -599,11 +612,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

     const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
-    const int n_block_min = n_split_idx * n_blocks_per_split;
+    const int n_block_min = !Is_local
+        ? n_split_idx * n_blocks_per_split
+        : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
     int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split);
-    if (Is_causal) {
+    if (Is_causal || Is_local) {
         n_block_max = std::min(n_block_max,
-                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN));
+                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
     }
     if (n_block_min >= n_block_max) {  // This also covers the case where n_block_max <= 0
         // We exit early and write 0 to gOaccum and -inf to gLSEaccum.
...
...
@@ -842,21 +857,21 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                                binfo.actual_seqlen_q - m_block * kBlockM);
         } else {
-            const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
+            const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
             // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache.
             // We do this by setting the row stride of gCos / gSin to 0.
             Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
                                       Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
-                                      make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
+                                      make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
             Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
                                       Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
-                                      make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
+                                      make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
             Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
                                           Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                                          make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
+                                          make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
             Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
                                           Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                                          make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
+                                          make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
             Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
             Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
             Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
...
@@ -895,9 +910,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
// If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
// mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
constexpr
int
n_masking_steps
=
!
Is_causal
constexpr
int
n_masking_steps
=
(
!
Is_causal
&&
!
Is_local
)
?
1
:
(
Is_even_MN
?
cute
::
ceil_div
(
kBlockM
,
kBlockN
)
:
cute
::
ceil_div
(
kBlockM
,
kBlockN
)
+
1
);
:
(
(
Is_even_MN
&&
Is_causal
)
?
cute
::
ceil_div
(
kBlockM
,
kBlockN
)
:
cute
::
ceil_div
(
kBlockM
,
kBlockN
)
+
1
);
#pragma unroll
for
(
int
masking_step
=
0
;
masking_step
<
n_masking_steps
;
++
masking_step
,
--
n_block
)
{
Tensor
acc_s
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
// (MMA=4, MMA_M, MMA_N)
...
...
@@ -929,13 +944,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // We don't put the masking before the matmul S = Q K^T because we don't clear sK
         // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
         // can produce Inf / NaN.
-        if (!Is_causal) {
+        if (!Is_causal && !Is_local) {
             if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
         } else {
-            flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
-                                     m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                                     binfo.actual_seqlen_q, kNWarps * 16);
+            flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k,
+                                    m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                                    binfo.actual_seqlen_q, kNWarps * 16,
+                                    params.window_size_left, params.window_size_right);
         }

         flash::cp_async_wait<0>();
...
...
@@ -954,8 +970,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // We have key_padding_mask so we'll need to Check_inf
         masking_step == 0
-            ? softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
-            : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+            ? softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
+            : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
         // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }

         // Convert scores from fp32 to fp16/bf16
...
...
@@ -1003,7 +1019,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons

         // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
-        softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+        if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
+            flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k,
+                                    m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                                    binfo.actual_seqlen_q, kNWarps * 16,
+                                    params.window_size_left, params.window_size_right);
+        }
+        softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);

         Tensor rP = flash::convert_type<Element>(scores);
         // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
...
...
@@ -1106,7 +1130,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
 inline __device__ void compute_attn(const Params &params) {
     const int m_block = blockIdx.x;
     // The block index for the batch.
...
...
@@ -1122,12 +1146,12 @@ inline __device__ void compute_attn(const Params &params) {
     // the attention matrix. This way, as long as we have the batch, head, and the location of
     // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.

-    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
+    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
 inline __device__ void compute_attn_splitkv(const Params &params) {
     const int m_block = blockIdx.x;
     // The block index for the batch.
...
...
@@ -1136,7 +1160,7 @@ inline __device__ void compute_attn_splitkv(const Params &params) {
     const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
     const int n_split_idx = Split ? blockIdx.y : 0;
     const int num_n_splits = Split ? gridDim.y : 1;
-    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
+    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////
...
...
csrc/flash_attn/src/flash_fwd_launch_template.h
...
...
@@ -10,14 +10,15 @@
 #include "flash.h"
 #include "flash_fwd_kernel.h"

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
 __global__ void flash_fwd_kernel(Flash_fwd_params params) {
-    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params);
+    static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
+    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Return_softmax>(params);
 }

-template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
+template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
 __global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
-    flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Split, Append_KV>(params);
+    flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Is_even_MN, Is_even_K, Split, Append_KV>(params);
 }

 template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K>
...
...
@@ -42,23 +43,25 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     const bool return_softmax = params.p_ptr != nullptr;
     BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
         BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-            BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
-                // Will only return softmax if dropout, to reduce compilation time.
-                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                // If return_softmax, set IsEvenMNConst to false to reduce number of templates
-                // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
-                auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst && IsEvenKConst && (!ReturnSoftmaxConst) && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
-                // auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, true, ReturnSoftmaxConst && Is_dropout>;
-                if (smem_size >= 48 * 1024) {
-                    C10_CUDA_CHECK(cudaFuncSetAttribute(
-                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-                }
-                // int ctas_per_sm;
-                // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-                //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
-                // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
-                kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
-                C10_CUDA_KERNEL_LAUNCH_CHECK();
+            BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] {
+                BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
+                    // Will only return softmax if dropout, to reduce compilation time.
+                    // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                    // If return_softmax, set IsEvenMNConst to false to reduce number of templates
+                    // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
+                    // If Is_local, set Is_causal to false
+                    auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
+                    if (smem_size >= 48 * 1024) {
+                        C10_CUDA_CHECK(cudaFuncSetAttribute(
+                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                    }
+                    // int ctas_per_sm;
+                    // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                    //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
+                    // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
+                    kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+                    C10_CUDA_KERNEL_LAUNCH_CHECK();
+                });
             });
         });
     });
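A hedged sketch (Python, illustration only; names are descriptive, not real APIs) of the compile-time dispatch introduced above: Is_local is derived at runtime from the window sizes, and when it is set the causal template parameter is forced off, since the kernel's static_assert forbids both flags being true at once.

    # Illustration of the BOOL_SWITCH logic above.
    def kernel_template_flags(is_causal_const, window_size_left, window_size_right):
        # Is_local is true whenever either window bound has a finite (>= 0) width.
        is_local = window_size_left >= 0 or window_size_right >= 0
        # "If Is_local, set Is_causal to false": the two flags are mutually exclusive,
        # matching static_assert(!(Is_causal && Is_local)) in flash_fwd_kernel.
        return {"Is_causal": is_causal_const and not is_local, "Is_local": is_local}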
...
...
@@ -76,19 +79,22 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                BOOL_SWITCH(params.num_splits > 1, Split, [&] {
-                    BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
-                        // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
-                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                        auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
-                        // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
-                        // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
-                        if (smem_size >= 48 * 1024) {
-                            C10_CUDA_CHECK(cudaFuncSetAttribute(
-                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-                        }
-                        kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
-                        C10_CUDA_KERNEL_LAUNCH_CHECK();
+                BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] {
+                    BOOL_SWITCH(params.num_splits > 1, Split, [&] {
+                        BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
+                            // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
+                            // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                            // If Is_local, set Is_causal to false
+                            auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
+                            // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
+                            // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
+                            if (smem_size >= 48 * 1024) {
+                                C10_CUDA_CHECK(cudaFuncSetAttribute(
+                                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                            }
+                            kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+                            C10_CUDA_KERNEL_LAUNCH_CHECK();
+                        });
                     });
                 });
             });
...
...
csrc/flash_attn/src/softmax.h
...
...
@@ -139,10 +139,11 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_
     }
 }

-template <typename Engine, typename Layout>
-inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
-                                         const int max_seqlen_k, const int row_idx_offset_,
-                                         const int max_seqlen_q, const int warp_row_stride) {
+template <bool HasWSLeft=true, typename Engine, typename Layout>
+inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+                                        const int max_seqlen_k, const int row_idx_offset_,
+                                        const int max_seqlen_q, const int warp_row_stride,
+                                        const int window_size_left, const int window_size_right) {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
     const int lane_id = threadIdx.x % 32;
...
...
@@ -155,14 +156,15 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const i
     #pragma unroll
     for (int i = 0; i < size<0, 0>(tensor); ++i) {
         const int row_idx = row_idx_base + i * 8;
-        const int col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q);
+        const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
+        const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
         #pragma unroll
         for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
             const int col_idx_base = col_idx_offset + nj * 8;
             #pragma unroll
             for (int j = 0; j < size<1, 0>(tensor); ++j) {
                 const int col_idx = col_idx_base + j;
-                if (col_idx >= col_idx_limit) {
+                if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
                     tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                 }
             }
...
...
@@ -176,6 +178,15 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const i
     }
 }

+template <typename Engine, typename Layout>
+inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+                                         const int max_seqlen_k, const int row_idx_offset_,
+                                         const int max_seqlen_q, const int warp_row_stride) {
+    // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
+    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_,
+                                          max_seqlen_q, warp_row_stride, -1, 0);
+}
+
 template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
 inline __device__ void apply_mask_causal_w_idx(
     Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
...
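To make the masking rule in softmax.h concrete, here is a hedged pure-Python reference (illustration only, not part of the commit) of the per-element test in apply_mask_local: a key column j is visible from a query row i only if it lies between the left and right limits derived from the window sizes, and causal masking is the special case of an unbounded left window with window_size_right = 0.

    # Reference mask builder for sliding-window ("local") attention; illustration only.
    def local_attention_mask(seqlen_q, seqlen_k, window_size_left, window_size_right):
        """allowed[i][j] is True where attention is permitted, mirroring apply_mask_local."""
        allowed = [[False] * seqlen_k for _ in range(seqlen_q)]
        for i in range(seqlen_q):
            # Same index arithmetic as col_idx_limit_left / col_idx_limit_right above.
            left = max(0, i + seqlen_k - seqlen_q - window_size_left)
            right = min(seqlen_k, i + 1 + seqlen_k - seqlen_q + window_size_right)
            for j in range(left, right):
                allowed[i][j] = True
        return allowed

    # Causal masking == local masking with an unbounded left window and window_size_right = 0.
    causal = local_attention_mask(4, 4, window_size_left=10**9, window_size_right=0)
    assert causal == [[True, False, False, False],
                      [True, True,  False, False],
                      [True, True,  True,  False],
                      [True, True,  True,  True ]]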
flash_attn/flash_attn_interface.py
...
...
@@ -41,11 +41,21 @@ def _get_block_size(device, head_dim, is_dropout, is_causal):
         return (128, 64) if is_sm80 else (64, 64)


-def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax):
+def _flash_attn_forward(
+    q, k, v, dropout_p, softmax_scale, causal, window_size, return_softmax
+):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
-        q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None
+        q,
+        k,
+        v,
+        None,
+        dropout_p,
+        softmax_scale,
+        causal,
+        window_size[0],
+        window_size[1],
+        return_softmax,
+        None,
     )
     return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
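A short, hedged usage sketch of the new Python-level parameter (illustrative shapes and dtypes; it assumes a CUDA GPU and a flash-attn build that includes this change): window_size is threaded from the public functions through _flash_attn_forward into the CUDA extension as two separate integers.

    # Illustrative usage of the new window_size argument.
    import torch
    from flash_attn import flash_attn_func

    q = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")
    k = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")
    v = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")

    # Default (-1, -1) means an infinite context window, i.e. no local restriction.
    out_full = flash_attn_func(q, k, v, causal=True)
    # Sliding-window local attention: each query attends to itself and at most 127 keys to its left.
    out_local = flash_attn_func(q, k, v, causal=True, window_size=(127, 0))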
...
...
@@ -61,6 +71,7 @@ def _flash_attn_varlen_forward(
     dropout_p,
     softmax_scale,
     causal,
+    window_size,
     return_softmax,
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
...
...
@@ -78,6 +89,8 @@ def _flash_attn_varlen_forward(
         softmax_scale,
         False,
         causal,
+        window_size[0],
+        window_size[1],
         return_softmax,
         None,
     )
...
...
@@ -87,7 +100,20 @@
 def _flash_attn_backward(
-    dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, rng_state=None
+    dout,
+    q,
+    k,
+    v,
+    out,
+    softmax_lse,
+    dq,
+    dk,
+    dv,
+    dropout_p,
+    softmax_scale,
+    causal,
+    window_size,
+    rng_state=None,
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     # dq, dk, dv are allocated by us so they should already be contiguous
...
...
@@ -105,6 +131,8 @@ def _flash_attn_backward(
         dropout_p,
         softmax_scale,
         causal,
+        window_size[0],
+        window_size[1],
         None,
         rng_state,
     )
...
...
@@ -128,6 +156,7 @@ def _flash_attn_varlen_backward(
     dropout_p,
     softmax_scale,
     causal,
+    window_size,
     rng_state=None,
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
...
...
@@ -151,6 +180,8 @@ def _flash_attn_varlen_backward(
         softmax_scale,
         False,
         causal,
+        window_size[0],
+        window_size[1],
         None,
         rng_state,
     )
...
...
@@ -161,7 +192,7 @@ def _flash_attn_varlen_backward(
 class FlashAttnQKVPackedFunc(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax):
+    def forward(ctx, qkv, dropout_p, softmax_scale, causal, window_size, return_softmax):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
         out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
...
...
@@ -171,12 +202,14 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
             dropout_p,
             softmax_scale,
             causal=causal,
+            window_size=window_size,
             return_softmax=return_softmax and dropout_p > 0,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.window_size = window_size
         return out if not return_softmax else (out, softmax_lse, S_dmask)

     @staticmethod
...
...
@@ -197,15 +230,26 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
             ctx.dropout_p,
             ctx.softmax_scale,
             ctx.causal,
+            ctx.window_size,
             rng_state=rng_state,
         )
         dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
-        return dqkv, None, None, None, None
+        return dqkv, None, None, None, None, None


 class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
+    def forward(
+        ctx,
+        qkv,
+        cu_seqlens,
+        max_seqlen,
+        dropout_p,
+        softmax_scale,
+        causal,
+        window_size,
+        return_softmax,
+    ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
         out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
...
...
@@ -219,6 +263,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             dropout_p,
             softmax_scale,
             causal=causal,
+            window_size=window_size,
             return_softmax=return_softmax and dropout_p > 0,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
...
...
@@ -226,6 +271,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
         ctx.max_seqlen = max_seqlen
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.window_size = window_size
         return out if not return_softmax else (out, softmax_lse, S_dmask)

     @staticmethod
...
...
@@ -250,15 +296,16 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             ctx.dropout_p,
             ctx.softmax_scale,
             ctx.causal,
+            ctx.window_size,
             rng_state=rng_state,
         )
         dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
-        return dqkv, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None


 class FlashAttnKVPackedFunc(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax):
+    def forward(ctx, q, kv, dropout_p, softmax_scale, causal, window_size, return_softmax):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
...
...
@@ -268,12 +315,14 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
             dropout_p,
             softmax_scale,
             causal=causal,
+            window_size=window_size,
             return_softmax=return_softmax and dropout_p > 0,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.window_size = window_size
         return out if not return_softmax else (out, softmax_lse, S_dmask)

     @staticmethod
...
...
@@ -295,11 +344,12 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
             ctx.dropout_p,
             ctx.softmax_scale,
             ctx.causal,
+            ctx.window_size,
             rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., : dout.shape[-1]]
-        return dq, dkv, None, None, None, None
+        return dq, dkv, None, None, None, None, None


 class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
...
...
@@ -315,6 +365,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         dropout_p,
         softmax_scale,
         causal,
+        window_size,
         return_softmax,
     ):
         if softmax_scale is None:
...
...
@@ -330,6 +381,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             dropout_p,
             softmax_scale,
             causal=causal,
+            window_size=window_size,
             return_softmax=return_softmax and dropout_p > 0,
         )
         ctx.save_for_backward(
...
...
@@ -340,6 +392,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         ctx.max_seqlen_k = max_seqlen_k
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.window_size = window_size
         return out if not return_softmax else (out, softmax_lse, S_dmask)

     @staticmethod
...
...
@@ -365,16 +418,17 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             ctx.dropout_p,
             ctx.softmax_scale,
             ctx.causal,
+            ctx.window_size,
             rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., : dout.shape[-1]]
-        return dq, dkv, None, None, None, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None, None, None, None


 class FlashAttnFunc(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax):
+    def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, return_softmax):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
...
...
@@ -384,12 +438,14 @@ class FlashAttnFunc(torch.autograd.Function):
            dropout_p,
            softmax_scale,
            causal=causal,
+           window_size=window_size,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
+       ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
...
@@ -409,12 +465,13 @@ class FlashAttnFunc(torch.autograd.Function):
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
+           ctx.window_size,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
-       return dq, dk, dv, None, None, None, None, None, None, None, None
+       return dq, dk, dv, None, None, None, None, None, None, None, None, None


class FlashAttnVarlenFunc(torch.autograd.Function):
...
@@ -431,6 +488,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
        dropout_p,
        softmax_scale,
        causal,
+       window_size,
        return_softmax,
    ):
        if softmax_scale is None:
...
@@ -446,6 +504,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
            dropout_p,
            softmax_scale,
            causal=causal,
+           window_size=window_size,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(
...
@@ -456,6 +515,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
+       ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
...
@@ -479,16 +539,22 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
+           ctx.window_size,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
-       return dq, dk, dv, None, None, None, None, None, None, None, None
+       return dq, dk, dv, None, None, None, None, None, None, None, None, None


def flash_attn_qkvpacked_func(
-   qkv, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False
+   qkv,
+   dropout_p=0.0,
+   softmax_scale=None,
+   causal=False,
+   window_size=(-1, -1),  # -1 means infinite context window
+   return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
...
@@ -497,12 +563,16 @@ def flash_attn_qkvpacked_func(
For multi-query and grouped-query attention (MQA/GQA), please see
flash_attn_kvpacked_func and flash_attn_func.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
Arguments:
qkv: (batch_size, seqlen, 3, nheads, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
...
...
@@ -515,11 +585,19 @@ def flash_attn_qkvpacked_func(
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
-   return FlashAttnQKVPackedFunc.apply(qkv, dropout_p, softmax_scale, causal, return_attn_probs)
+   return FlashAttnQKVPackedFunc.apply(
+       qkv, dropout_p, softmax_scale, causal, window_size, return_attn_probs
+   )
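The docstring above describes the window in terms of the keys a query may attend to. As an illustration only (not part of this commit), here is a minimal PyTorch sketch of that mask for self-attention (seqlen_q == seqlen_k); the helper name sliding_window_mask is hypothetical, and True marks positions that are masked out:

# Illustrative sketch, not code from this commit.
import torch

def sliding_window_mask(seqlen, window_left, window_right):
    # Query i may attend to keys j with i - window_left <= j <= i + window_right.
    # A negative window value means "unbounded" on that side.
    row = torch.arange(seqlen).unsqueeze(-1)  # query index i
    col = torch.arange(seqlen)                # key index j
    no_mask = torch.zeros(seqlen, seqlen, dtype=torch.bool)
    too_far_right = col > row + window_right if window_right >= 0 else no_mask
    too_far_left = col < row - window_left if window_left >= 0 else no_mask
    return too_far_right | too_far_left

# window_size=(2, 0) behaves like causal attention restricted to the last 3 positions.
print(sliding_window_mask(5, 2, 0))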
def flash_attn_kvpacked_func(
-   q, kv, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False
+   q,
+   kv,
+   dropout_p=0.0,
+   softmax_scale=None,
+   causal=False,
+   window_size=(-1, -1),  # -1 means infinite context window
+   return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
...
@@ -542,6 +620,10 @@ def flash_attn_kvpacked_func(
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
kv: (batch_size, seqlen, 2, nheads_k, headdim)
...
...
@@ -549,6 +631,7 @@ def flash_attn_kvpacked_func(
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
...
...
@@ -561,11 +644,20 @@ def flash_attn_kvpacked_func(
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
-   return FlashAttnKVPackedFunc.apply(q, kv, dropout_p, softmax_scale, causal, return_attn_probs)
+   return FlashAttnKVPackedFunc.apply(
+       q, kv, dropout_p, softmax_scale, causal, window_size, return_attn_probs
+   )
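When seqlen_q != seqlen_k, the docstring above aligns the window to the bottom-right of the attention matrix. As an illustration only (the helper allowed_key_range is hypothetical, not part of this commit), the inclusive key range for query i can be sketched as:

# Illustrative sketch of the docstring formula, not code from this commit.
def allowed_key_range(i, seqlen_q, seqlen_k, window_size):
    left, right = window_size
    center = i + seqlen_k - seqlen_q
    lo = 0 if left < 0 else max(0, center - left)
    hi = seqlen_k - 1 if right < 0 else min(seqlen_k - 1, center + right)
    return lo, hi

# The last query (i = seqlen_q - 1) is centered on the last key:
print(allowed_key_range(7, seqlen_q=8, seqlen_k=32, window_size=(4, 0)))  # (27, 31)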
def flash_attn_func(
-   q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False
+   q,
+   k,
+   v,
+   dropout_p=0.0,
+   softmax_scale=None,
+   causal=False,
+   window_size=(-1, -1),  # -1 means infinite context window
+   return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
...
@@ -585,6 +677,10 @@ def flash_attn_func(
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
...
...
@@ -593,6 +689,7 @@ def flash_attn_func(
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
...
...
@@ -605,7 +702,9 @@ def flash_attn_func(
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
-   return FlashAttnFunc.apply(q, k, v, dropout_p, softmax_scale, causal, return_attn_probs)
+   return FlashAttnFunc.apply(
+       q, k, v, dropout_p, softmax_scale, causal, window_size, return_attn_probs
+   )
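A minimal usage sketch of the new parameter (assuming a CUDA device, fp16 inputs, and that flash_attn_func is importable from the package's top level as in this version); shapes follow the docstring above:

# Usage sketch, assumptions as stated in the lead-in.
import torch
from flash_attn import flash_attn_func

q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)

# Causal attention restricted to the previous 256 tokens (sliding-window local attention).
out = flash_attn_func(q, k, v, causal=True, window_size=(256, 0))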
def flash_attn_varlen_qkvpacked_func(
...
@@ -615,6 +714,7 @@ def flash_attn_varlen_qkvpacked_func(
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
+   window_size=(-1, -1),  # -1 means infinite context window
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
...
@@ -624,6 +724,9 @@ def flash_attn_varlen_qkvpacked_func(
For multi-query and grouped-query attention (MQA/GQA), please see
flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
Arguments:
qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
...
...
@@ -633,6 +736,7 @@ def flash_attn_varlen_qkvpacked_func(
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
...
...
@@ -646,7 +750,14 @@ def flash_attn_varlen_qkvpacked_func(
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
    return FlashAttnVarlenQKVPackedFunc.apply(
-       qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs
+       qkv,
+       cu_seqlens,
+       max_seqlen,
+       dropout_p,
+       softmax_scale,
+       causal,
+       window_size,
+       return_attn_probs,
    )
...
@@ -660,6 +771,7 @@ def flash_attn_varlen_kvpacked_func(
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
+   window_size=(-1, -1),  # -1 means infinite context window
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
...
@@ -683,6 +795,10 @@ def flash_attn_varlen_kvpacked_func(
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
...
...
@@ -696,6 +812,7 @@ def flash_attn_varlen_kvpacked_func(
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
...
...
@@ -718,6 +835,7 @@ def flash_attn_varlen_kvpacked_func(
        dropout_p,
        softmax_scale,
        causal,
+       window_size,
        return_attn_probs,
    )
...
@@ -733,6 +851,7 @@ def flash_attn_varlen_func(
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
+   window_size=(-1, -1),  # -1 means infinite context window
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
...
@@ -753,6 +872,10 @@ def flash_attn_varlen_func(
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
...
...
@@ -767,6 +890,7 @@ def flash_attn_varlen_func(
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
...
...
@@ -790,6 +914,7 @@ def flash_attn_varlen_func(
        dropout_p,
        softmax_scale,
        causal,
+       window_size,
        return_attn_probs,
    )
...
@@ -805,6 +930,7 @@ def flash_attn_with_kvcache(
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    softmax_scale=None,
    causal=False,
+   window_size=(-1, -1),  # -1 means infinite context window
    rotary_interleaved=True,
    num_splits=0,
):
...
@@ -818,11 +944,12 @@ def flash_attn_with_kvcache(
For example, the KV cache could be pre-allocated with the max sequence length, and you can use
cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
-   Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be rotated
-   by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
-   If causal, the query @q will be rotated by rotary_cos and rotary_sin at indices cache_seqlens,
-   cache_seqlens + 1, etc. If not causal, the query @q will be rotated by rotary_cos and rotary_sin
-   at indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
+   Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
+   rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
+   If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
+   and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
+   If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
+   indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
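As an illustration of the rotary behaviour described above (not code from this commit; the helper name q_rotary_positions is hypothetical), the per-token rotary position index for @q can be sketched as:

# Illustrative sketch of the rotary position indices, not part of this commit.
import torch

def q_rotary_positions(cache_seqlen, seqlen_q, causal, local):
    if causal or local:
        # token t of @q is treated as sitting at position cache_seqlen + t
        return cache_seqlen + torch.arange(seqlen_q)
    # otherwise every token of @q is rotated as if it were at position cache_seqlen
    return torch.full((seqlen_q,), cache_seqlen)

print(q_rotary_positions(10, 4, causal=True, local=False))   # tensor([10, 11, 12, 13])
print(q_rotary_positions(10, 4, causal=False, local=False))  # tensor([10, 10, 10, 10])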
...
...
@@ -843,6 +970,10 @@ def flash_attn_with_kvcache(
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Note: Does not support backward pass.
Arguments:
...
...
@@ -860,6 +991,7 @@ def flash_attn_with_kvcache(
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
...
...
@@ -894,6 +1026,8 @@ def flash_attn_with_kvcache(
        None,
        softmax_scale,
        causal,
+       window_size[0],
+       window_size[1],
        rotary_interleaved,
        num_splits,
    )
...
tests/test_flash_attn.py  View file @ 083e8f52
...
...
@@ -150,8 +150,13 @@ def generate_qkv(
    )


-def construct_causal_mask(
-   seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None
+def construct_local_mask(
+   seqlen_q,
+   seqlen_k,
+   window_size=(-1, -1),  # -1 means infinite window size
+   query_padding_mask=None,
+   key_padding_mask=None,
+   device=None,
):
    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
...
@@ -165,7 +170,14 @@ def construct_causal_mask(
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
-   return col_idx > row_idx + sk - sq
+   if window_size[0] < 0:
+       return col_idx > row_idx + sk - sq + window_size[1]
+   else:
+       sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
+       return torch.logical_or(
+           col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
+           col_idx < row_idx + sk - sq - window_size[0],
+       )
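Illustrative check only (not part of the test suite; it assumes the helper above and its imports are in scope): with window_size=(-1, 0) the new construct_local_mask reduces to the old construct_causal_mask, while a finite left window additionally masks distant past keys.

# Illustrative sketch, not part of this commit.
mask_causal = construct_local_mask(4, 4, window_size=(-1, 0), device="cpu")
mask_window = construct_local_mask(4, 4, window_size=(1, 0), device="cpu")
# mask_causal[i, j] is True (masked out) whenever j > i;
# mask_window additionally masks keys more than 1 position behind the query.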


def attention_ref(
...
@@ -177,6 +189,7 @@ def attention_ref(
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
+   window_size=(-1, -1),  # -1 means infinite window size
    upcast=True,
    reorder_ops=False,
):
...
@@ -189,6 +202,8 @@ def attention_ref(
key_padding_mask: (batch_size, seqlen_k)
dropout_p: float
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
causal: whether to apply causal masking
window_size: (int, int), left and right window size
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
output back to fp16/bf16.
reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.)
...
...
@@ -198,6 +213,8 @@ def attention_ref(
output: (batch_size, seqlen_q, nheads, head_dim)
attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
"""
+   if causal:
+       window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
...
@@ -211,17 +228,24 @@ def attention_ref(
    scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
-   if causal:
-       # causal_mask = torch.triu(
-       #     torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1
-       # )
-       causal_mask = construct_causal_mask(
-           seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device
+   if window_size[0] >= 0 or window_size[1] >= 0:
+       local_mask = construct_local_mask(
+           seqlen_q,
+           seqlen_k,
+           window_size,
+           query_padding_mask,
+           key_padding_mask,
+           q.device,
        )
-       scores.masked_fill_(causal_mask, float("-inf"))
+       scores.masked_fill_(local_mask, float("-inf"))
    attention = torch.softmax(scores, dim=-1)
-   if causal:
-       # Some rows are completely masked out so we fill them with zero instead of NaN
-       attention = attention.masked_fill(torch.all(causal_mask, dim=-1, keepdim=True), 0.0)
+   # Some rows might be completely masked out so we fill them with zero instead of NaN
+   if window_size[0] >= 0 or window_size[1] >= 0:
+       attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
    # We want to mask here so that the attention matrix doesn't have any NaNs
    # Otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    dropout_scaling = 1.0 / (1 - dropout_p)
    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
...
@@ -232,7 +256,6 @@ def attention_ref(
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
-       attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
...
@@ -244,6 +267,7 @@ def attention_kvpacked_ref(
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
+   window_size=(-1, -1),  # -1 means infinite window size
    upcast=True,
    reorder_ops=False,
):
...
@@ -257,6 +281,7 @@ def attention_kvpacked_ref(
        dropout_mask,
        upcast=upcast,
        causal=causal,
+       window_size=window_size,
        reorder_ops=reorder_ops,
    )
...
@@ -267,6 +292,7 @@ def attention_qkvpacked_ref(
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
+   window_size=(-1, -1),  # -1 means infinite window size
    upcast=True,
    reorder_ops=False,
):
...
@@ -280,6 +306,7 @@ def attention_qkvpacked_ref(
        dropout_mask,
        upcast=upcast,
        causal=causal,
+       window_size=window_size,
        reorder_ops=reorder_ops,
    )
...
@@ -327,7 +354,15 @@ def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask
def convert_flash_attn_S_to_softmax(
-   S, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, head_dim, is_dropout, causal=False
+   S,
+   seqlen_q,
+   seqlen_k,
+   query_padding_mask,
+   key_padding_mask,
+   head_dim,
+   is_dropout,
+   causal=False,
+   window_size=(-1, -1),  # -1 means infinite window size
):
    """FlashAttention stores the S matrix in a different way.
    Arguments:
...
@@ -335,6 +370,8 @@ def convert_flash_attn_S_to_softmax(
query_padding_mask: (batch_size, seqlen_q_rounded)
key_padding_mask: (batch_size, seqlen_k_rounded)
"""
+   if causal:
+       window_size = (window_size[0], 0)
    seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
    warps_n = 4
    blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, is_dropout, causal)
...
@@ -359,19 +396,21 @@ def convert_flash_attn_S_to_softmax(
        four=4,
    )
-   if causal:
-       # causal_mask = torch.triu(
-       #     torch.ones(seqlen_q_rounded, seqlen_k_rounded, dtype=torch.bool, device=q.device), 1
-       # )
-       causal_mask = construct_causal_mask(
-           seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, S.device
+   if window_size[0] >= 0 or window_size[1] >= 0:
+       local_mask = construct_local_mask(
+           seqlen_q,
+           seqlen_k,
+           window_size,
+           query_padding_mask,
+           key_padding_mask,
+           S.device,
        )
-       causal_mask = F.pad(
-           causal_mask,
+       local_mask = F.pad(
+           local_mask,
            (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
            value=True,
        )
-       S_converted.masked_fill_(causal_mask, 0.0)
+       S_converted.masked_fill_(local_mask, 0.0)
    # Need to zero out things not in attention_mask in case S was initialized with random values
    # and some of those values aren't overwritten.
...
@@ -399,6 +438,7 @@ def normalize_flash_attn_S(
    key_padding_mask=None,
    is_dropout=False,
    causal=False,
+   window_size=(-1, -1),  # -1 means infinite window size
):
    """
    Arguments:
...
@@ -409,20 +449,24 @@ def normalize_flash_attn_S(
softmax_lse: (batch_size, nheads, seqlen_q)
softmax_max: (batch_size, nheads, seqlen_q)
"""
+   if causal:
+       window_size = (window_size[0], 0)
    q, k, v = q.float(), k.float(), v.float()
    _, seqlen_q, _, head_dim = q.shape
    seqlen_k = k.shape[1]
    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k)
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
-   if causal:
-       # causal_mask = torch.triu(
-       #     torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1
-       # )
-       causal_mask = construct_causal_mask(
-           seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device
+   if window_size[0] >= 0 or window_size[1] >= 0:
+       local_mask = construct_local_mask(
+           seqlen_q,
+           seqlen_k,
+           window_size,
+           query_padding_mask,
+           key_padding_mask,
+           q.device,
        )
-       scores.masked_fill_(causal_mask, float("-inf"))
+       scores.masked_fill_(local_mask, float("-inf"))
    _, block_size_n = _get_block_size(scores.device, head_dim, is_dropout, causal)
    scores_block = scores.split(block_size_n, dim=-1)
    lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
...
@@ -446,79 +490,84 @@ def normalize_flash_attn_S(
def get_dropout_fraction(
-   dropout_mask, query_padding_mask=None, key_padding_mask=None, causal=False
+   dropout_mask,
+   query_padding_mask=None,
+   key_padding_mask=None,
+   causal=False,
+   window_size=(-1, -1),  # -1 means infinite window size
):
    """
    dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop.
    query_padding_mask: (batch_size, seqlen_q)
    key_padding_mask: (batch_size, seqlen_k)
    """
+   if causal:
+       window_size = (window_size[0], 0)
    batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape
    dropped = ~dropout_mask
+   valid = torch.ones_like(dropout_mask)
    if query_padding_mask is not None:
        dropped.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
+       valid.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
    if key_padding_mask is not None:
        dropped.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
+       valid.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
-   if causal:
-       # causal_mask = torch.triu(
-       #     torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=dropout_mask.device), 1
-       # )
-       causal_mask = construct_causal_mask(
-           seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, dropout_mask.device
+   if window_size[0] >= 0 or window_size[1] >= 0:
+       local_mask = construct_local_mask(
+           seqlen_q,
+           seqlen_k,
+           window_size,
+           query_padding_mask,
+           key_padding_mask,
+           dropout_mask.device,
        )
-       dropped.masked_fill_(causal_mask, False)
-   dropped_total = dropped.sum()
-   query_lengths = (
-       query_padding_mask.sum(dim=-1)
-       if query_padding_mask is not None
-       else torch.full((batch_size,), seqlen_q, device=dropout_mask.device)
-   )
-   key_lengths = (
-       key_padding_mask.sum(dim=-1)
-       if key_padding_mask is not None
-       else torch.full((batch_size,), seqlen_k, device=dropout_mask.device)
-   )
-   if not causal:
-       numel_per_batch = query_lengths * key_lengths
-   else:
-       numel_per_batch = torch.where(
-           key_lengths <= query_lengths,
-           key_lengths * (key_lengths + 1) / 2,
-           query_lengths * key_lengths - (query_lengths * (query_lengths - 1) / 2),
-       )
-   return dropped_total / (numel_per_batch.sum() * nheads)
+       dropped.masked_fill_(local_mask, False)
+       valid.masked_fill_(local_mask, False)
+   return dropped.sum() / valid.sum()
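With a local window, the number of unmasked scores per row no longer has the simple closed form used for the causal case, so the expected dropout fraction is now measured as dropped.sum() / valid.sum() over positions that survive padding and the local mask. Tiny illustrative sketch (not part of the test suite):

# Illustrative sketch of the new ratio, not part of this commit.
import torch

valid = torch.tensor([[True, True, False], [True, False, False]])    # positions kept by the masks
dropped = torch.tensor([[True, False, False], [False, False, False]])  # positions zeroed by dropout
print((dropped.sum() / valid.sum()).item())  # 1 dropped out of 3 valid -> ~0.333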
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
-# @pytest.mark.parametrize('dtype', [torch.float16])
+# @pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("local", [False, True])
+# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
-# @pytest.mark.parametrize('causal', [False])
+# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
-# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128])
-# @pytest.mark.parametrize('d', [64])
+# @pytest.mark.parametrize("d", [64])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
-# @pytest.mark.parametrize('seqlen', [128])
+# @pytest.mark.parametrize("seqlen", [128])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
-# @pytest.mark.parametrize('dropout_p', [0.0])
+# @pytest.mark.parametrize("dropout_p", [0.0])
-def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
+def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
    if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
-   batch_size = 16
+   batch_size = 13
    nheads = 9
+   window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
    out, lse, S_dmask = flash_attn_qkvpacked_func(
-       qkv, dropout_p, return_attn_probs=True, causal=causal
+       qkv, dropout_p, causal=causal, window_size=window_size, return_attn_probs=True
    )
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
-           S_dmask, seqlen, seqlen, None, None, d, dropout_p > 0.0, causal=causal
+           S_dmask,
+           seqlen,
+           seqlen,
+           None,
+           None,
+           d,
+           dropout_p > 0.0,
+           causal=causal,
+           window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
...
@@ -531,15 +580,27 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
            None,
            dropout_p > 0.0,
            causal=causal,
+           window_size=window_size,
        )
-       dropout_fraction = get_dropout_fraction(dropout_mask, None, None, causal=causal).item()
+       dropout_fraction = get_dropout_fraction(
+           dropout_mask, None, None, causal=causal, window_size=window_size
+       ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

-   out_ref, attn_ref = attention_qkvpacked_ref(qkv, None, dropout_p, dropout_mask, causal=causal)
-   out_pt, attn_pt = attention_qkvpacked_ref(
-       qkv, None, dropout_p, dropout_mask, causal=causal, upcast=False, reorder_ops=True
+   out_ref, attn_ref = attention_qkvpacked_ref(
+       qkv, None, dropout_p, dropout_mask, causal=causal, window_size=window_size
+   )
+   out_pt, attn_pt = attention_qkvpacked_ref(
+       qkv,
+       None,
+       dropout_p,
+       dropout_mask,
+       causal=causal,
+       window_size=window_size,
+       upcast=False,
+       reorder_ops=True,
    )
    # v = qkv[:, :, 2].float()
    # qk = torch.einsum('bshd,bthd->bhst', qkv[:, :, 0], qkv[:, :, 1]).float()
...
@@ -590,7 +651,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
-       assert abs(dropout_fraction - dropout_p) <= 0.01
+       assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
...
@@ -598,15 +659,18 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
+@pytest.mark.parametrize("local", [False, True])
+# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
-def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
+def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
    if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
...
@@ -614,6 +678,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
    torch.random.manual_seed(0)
    batch_size = 5
    nheads = 6
+   window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
...
@@ -626,7 +691,13 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
    )
    out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func(
-       qkv_unpad, cu_seqlens, max_seqlen, dropout_p, return_attn_probs=True, causal=causal
+       qkv_unpad,
+       cu_seqlens,
+       max_seqlen,
+       dropout_p,
+       causal=causal,
+       window_size=window_size,
+       return_attn_probs=True,
    )
    out = output_pad_fn(out_unpad)
    if dropout_p > 0.0:
...
@@ -639,6 +710,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
            d,
            dropout_p > 0.0,
            causal=causal,
+           window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
...
@@ -651,16 +723,17 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
            key_padding_mask,
            dropout_p > 0.0,
            causal=causal,
+           window_size=window_size,
        )
        dropout_fraction = get_dropout_fraction(
-           dropout_mask, key_padding_mask, key_padding_mask, causal=causal
+           dropout_mask, key_padding_mask, key_padding_mask, causal=causal, window_size=window_size
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    out_ref, attn_ref = attention_qkvpacked_ref(
-       qkv, key_padding_mask, dropout_p, dropout_mask, causal=causal
+       qkv, key_padding_mask, dropout_p, dropout_mask, causal=causal, window_size=window_size
    )
    out_pt, attn_pt = attention_qkvpacked_ref(
        qkv,
...
dropout_p
,
dropout_mask
,
causal
=
causal
,
window_size
=
window_size
,
upcast
=
False
,
reorder_ops
=
True
,
)
...
...
@@ -700,7 +774,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
-       assert abs(dropout_fraction - dropout_p) <= 0.01
+       assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
...
@@ -712,10 +786,12 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
+@pytest.mark.parametrize("local", [False, True])
+# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
-# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
...
@@ -738,7 +814,9 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.17])
-def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype, kvpacked):
+def test_flash_attn_output(
+   seqlen_q, seqlen_k, d, dropout_p, causal, local, mha_type, dtype, kvpacked
+):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
...
@@ -747,10 +825,11 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
-   batch_size = 16
+   batch_size = 13
    nheads = 9
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
+   window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if kvpacked:
        kv = torch.randn(
...
@@ -766,15 +845,23 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
    if kvpacked:
        out, lse, S_dmask = flash_attn_kvpacked_func(
-           q, kv, dropout_p, return_attn_probs=True, causal=causal
+           q, kv, dropout_p, causal=causal, window_size=window_size, return_attn_probs=True
        )
    else:
        out, lse, S_dmask = flash_attn_func(
-           q, k, v, dropout_p, return_attn_probs=True, causal=causal
+           q, k, v, dropout_p, causal=causal, window_size=window_size, return_attn_probs=True
        )
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
-           S_dmask, seqlen_q, seqlen_k, None, None, d, dropout_p > 0.0, causal=causal
+           S_dmask,
+           seqlen_q,
+           seqlen_k,
+           None,
+           None,
+           d,
+           dropout_p > 0.0,
+           causal=causal,
+           window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
...
@@ -785,16 +872,33 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
        k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        attn = normalize_flash_attn_S(
-           attn_unnorm, q, k_rep, v_rep, None, None, dropout_p > 0.0, causal=causal
+           attn_unnorm,
+           q,
+           k_rep,
+           v_rep,
+           None,
+           None,
+           dropout_p > 0.0,
+           causal=causal,
+           window_size=window_size,
        )
-       dropout_fraction = get_dropout_fraction(dropout_mask, None, None, causal=causal).item()
+       dropout_fraction = get_dropout_fraction(
+           dropout_mask, None, None, causal=causal, window_size=window_size
+       ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    if kvpacked:
        out_ref, attn_ref = attention_kvpacked_ref(
-           q, kv, None, None, dropout_p, dropout_mask, causal=causal
+           q,
+           kv,
+           None,
+           None,
+           dropout_p,
+           dropout_mask,
+           causal=causal,
+           window_size=window_size,
        )
        out_pt, attn_pt = attention_kvpacked_ref(
            q,
...
dropout_p
,
dropout_mask
,
causal
=
causal
,
window_size
=
window_size
,
upcast
=
False
,
reorder_ops
=
True
,
)
else
:
out_ref
,
attn_ref
=
attention_ref
(
q
,
k
,
v
,
None
,
None
,
dropout_p
,
dropout_mask
,
causal
=
causal
q
,
k
,
v
,
None
,
None
,
dropout_p
,
dropout_mask
,
causal
=
causal
,
window_size
=
window_size
,
)
out_pt
,
attn_pt
=
attention_ref
(
q
,
...
...
@@ -820,6 +933,7 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
dropout_p
,
dropout_mask
,
causal
=
causal
,
window_size
=
window_size
,
upcast
=
False
,
reorder_ops
=
True
,
)
...
...
@@ -886,7 +1000,7 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
-       assert abs(dropout_fraction - dropout_p) <= 0.01
+       assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
...
@@ -900,10 +1014,12 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize('mha_type', ["mqa"])
+@pytest.mark.parametrize("local", [False, True])
+# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
-# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
...
@
pytest
.
mark
.
parametrize
(
"dropout_p"
,
[
0.0
,
0.17
])
# @pytest.mark.parametrize('dropout_p', [0.0])
def
test_flash_attn_varlen_output
(
seqlen_q
,
seqlen_k
,
d
,
dropout_p
,
causal
,
mha_type
,
dtype
,
kvpacked
seqlen_q
,
seqlen_k
,
d
,
dropout_p
,
causal
,
local
,
mha_type
,
dtype
,
kvpacked
):
if
(
max
(
seqlen_q
,
seqlen_k
)
>=
2048
...
...
@@ -935,10 +1051,11 @@ def test_flash_attn_varlen_output(
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
-   batch_size = 16
+   batch_size = 13
    nheads = 9
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
+   window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if kvpacked:
        kv = torch.randn(
...
@@ -980,6 +1097,7 @@ def test_flash_attn_varlen_output(
            dropout_p,
            return_attn_probs=True,
            causal=causal,
+           window_size=window_size,
        )
    else:
        (
...
@@ -1008,6 +1126,7 @@ def test_flash_attn_varlen_output(
            dropout_p,
            return_attn_probs=True,
            causal=causal,
+           window_size=window_size,
        )
    out = output_pad_fn(out_unpad)
    if dropout_p > 0.0:
...
@@ -1020,6 +1139,7 @@ def test_flash_attn_varlen_output(
            d,
            dropout_p > 0.0,
            causal=causal,
+           window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
...
@@ -1038,9 +1158,14 @@ def test_flash_attn_varlen_output(
            key_padding_mask,
            dropout_p > 0.0,
            causal=causal,
+           window_size=window_size,
        )
        dropout_fraction = get_dropout_fraction(
-           dropout_mask, query_padding_mask, key_padding_mask, causal=causal
+           dropout_mask,
+           query_padding_mask,
+           key_padding_mask,
+           causal=causal,
+           window_size=window_size,
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
...
@@ -1048,7 +1173,14 @@ def test_flash_attn_varlen_output(
    if kvpacked:
        out_ref, attn_ref = attention_kvpacked_ref(
-           q, kv, query_padding_mask, key_padding_mask, dropout_p, dropout_mask, causal=causal
+           q,
+           kv,
+           query_padding_mask,
+           key_padding_mask,
+           dropout_p,
+           dropout_mask,
+           causal=causal,
+           window_size=window_size,
        )
        out_pt, attn_pt = attention_kvpacked_ref(
            q,
...
@@ -1058,12 +1190,21 @@ def test_flash_attn_varlen_output(
            dropout_p,
            dropout_mask,
            causal=causal,
+           window_size=window_size,
            upcast=False,
            reorder_ops=True,
        )
    else:
        out_ref, attn_ref = attention_ref(
-           q, k, v, query_padding_mask, key_padding_mask, dropout_p, dropout_mask, causal=causal
+           q,
+           k,
+           v,
+           query_padding_mask,
+           key_padding_mask,
+           dropout_p,
+           dropout_mask,
+           causal=causal,
+           window_size=window_size,
        )
        out_pt, attn_pt = attention_ref(
            q,
...
dropout_p
,
dropout_mask
,
causal
=
causal
,
window_size
=
window_size
,
upcast
=
False
,
reorder_ops
=
True
,
)
...
...
@@ -1142,7 +1284,7 @@ def test_flash_attn_varlen_output(
    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
-       assert abs(dropout_fraction - dropout_p) <= 0.01
+       assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
...
@@ -1152,8 +1294,10 @@ def test_flash_attn_varlen_output(
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("local", [False, True])
+# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
-# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
...
@@ -1176,7 +1320,7 @@ def test_flash_attn_varlen_output(
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
-def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
+def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
...
@@ -1188,13 +1332,16 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
    causal = True
    # set seed
    torch.random.manual_seed(0)
-   batch_size = 16
+   batch_size = 13
    nheads = 9
+   window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
-   out = flash_attn_func(q, k, v, 0.0, causal=causal)
-   out_ref, attn_ref = attention_ref(q, k, v, None, None, 0.0, None, causal=causal)
+   out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size)
+   out_ref, attn_ref = attention_ref(
+       q, k, v, None, None, 0.0, None, causal=causal, window_size=window_size
+   )
    out_pt, attn_pt = attention_ref(
        q,
        k,
...
0.0
,
None
,
causal
=
causal
,
window_size
=
window_size
,
upcast
=
False
,
reorder_ops
=
True
,
)
...
...
@@ -1256,12 +1404,14 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("local", [False, True])
+# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
-# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
-# @pytest.mark.parametrize("d", [128])
+# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [True])
@pytest.mark.parametrize(
...
@@ -1280,7 +1430,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
    ],
)
# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
-def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
+def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
...
@@ -1292,8 +1442,9 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
    causal = True
    # set seed
    torch.random.manual_seed(0)
-   batch_size = 16
+   batch_size = 13
    nheads = 9
+   window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
...
@@ -1324,10 +1475,19 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
        max_seqlen_k,
        0.0,
        causal=causal,
+       window_size=window_size,
    )
    out = output_pad_fn(out_unpad)
    out_ref, attn_ref = attention_ref(
-       q, k, v, query_padding_mask, key_padding_mask, 0.0, None, causal=causal
+       q,
+       k,
+       v,
+       query_padding_mask,
+       key_padding_mask,
+       0.0,
+       None,
+       causal=causal,
+       window_size=window_size,
    )
    out_pt, attn_pt = attention_ref(
        q,
...
0.0
,
None
,
causal
=
causal
,
window_size
=
window_size
,
upcast
=
False
,
reorder_ops
=
True
,
)
...
...
@@ -1393,6 +1554,8 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("local", [False, True])
+# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
...
@@ -1418,7 +1581,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
-def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
+def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
...
@@ -1426,11 +1589,16 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
    torch.random.manual_seed(0)
    batch_size = 1
    nheads = 12
+   window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
-   out, lse, _ = flash_attn_func(q, k, v, 0.0, causal=causal, return_attn_probs=True)
-   out_ref, attn_ref = attention_ref(q, k, v, None, None, 0.0, None, causal=causal)
+   out, lse, _ = flash_attn_func(
+       q, k, v, 0.0, causal=causal, window_size=window_size, return_attn_probs=True
+   )
+   out_ref, attn_ref = attention_ref(
+       q, k, v, None, None, 0.0, None, causal=causal, window_size=window_size
+   )
    out_pt, attn_pt = attention_ref(
        q,
        k,
...
0.0
,
None
,
causal
=
causal
,
window_size
=
window_size
,
upcast
=
False
,
reorder_ops
=
True
,
)
...
...
@@ -1498,6 +1667,8 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [True])
+@pytest.mark.parametrize("local", [False, True])
+# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
...
@@ -1506,7 +1677,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
# @pytest.mark.parametrize("rotary_interleaved", [False])
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
# @pytest.mark.parametrize("rotary_fraction", [0.0])
-@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
+@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
...
@@ -1536,6 +1707,7 @@ def test_flash_attn_kvcache(
    rotary_interleaved,
    seqlen_new_eq_seqlen_q,
    causal,
+   local,
    new_kv,
    mha_type,
    num_splits,
...
@@ -1554,6 +1726,7 @@ def test_flash_attn_kvcache(
    rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
+   window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
    seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
    if new_kv:
...
@@ -1566,7 +1739,7 @@ def test_flash_attn_kvcache(
    cache_seqlens = torch.randint(
        0,
        # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
-       (seqlen_k - (seqlen_q if causal and rotary_dim > 1 else seqlen_new) + 1)
+       (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
        if new_kv
        else (seqlen_k + 1),
        (batch_size,),
...
@@ -1578,7 +1751,7 @@ def test_flash_attn_kvcache(
        angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi
        cos = torch.cos(angle).to(dtype=dtype)
        sin = torch.sin(angle).to(dtype=dtype)
-       if causal:
+       if causal or local:
            q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
...
@@ -1624,11 +1797,14 @@ def test_flash_attn_kvcache(
        sin,
        cache_seqlens,
        causal=causal,
+       window_size=window_size,
        rotary_interleaved=rotary_interleaved,
        num_splits=num_splits,
    )
-   # out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal)
-   # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal)
+   # out = flash_attn_with_kvcache(
+   #     q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
+   # )
+   # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
    # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
    # m = qk.amax(-1, keepdim=True)
    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
...
@@ -1637,7 +1813,15 @@ def test_flash_attn_kvcache(
    # probs = torch.softmax(qk, dim=-1)
    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
    out_ref, _ = attention_ref(
-       q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal
+       q_ro,
+       k_cache_rep,
+       v_cache_rep,
+       None,
+       key_padding_mask,
+       0.0,
+       None,
+       causal=causal,
+       window_size=window_size,
    )
    out_pt, _ = attention_ref(
        q_ro,
...
@@ -1648,6 +1832,7 @@ def test_flash_attn_kvcache(
        0.0,
        None,
        causal=causal,
+       window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )
...