gaoqiong/flash-attention, commit 083e8f52
Authored Sep 24, 2023 by Tri Dao

Implement local attention

Co-authored-by: Timothee Lacroix <t@mistral.ai>

Parent: 4c8ff915
Showing 9 changed files with 706 additions and 255 deletions (+706 -255)
csrc/flash_attn/flash_api.cpp                    +45  -13
csrc/flash_attn/src/flash.h                      +3   -0
csrc/flash_attn/src/flash_bwd_kernel.h           +59  -7
csrc/flash_attn/src/flash_bwd_launch_template.h  +16  -12
csrc/flash_attn/src/flash_fwd_kernel.h           +69  -45
csrc/flash_attn/src/flash_fwd_launch_template.h  +40  -34
csrc/flash_attn/src/softmax.h                    +17  -6
flash_attn/flash_attn_interface.py               +159 -25
tests/test_flash_attn.py                         +298 -113
csrc/flash_attn/flash_api.cpp
@@ -40,7 +40,8 @@ void set_params_fprop(Flash_fwd_params &params,
                       void *softmax_lse_d,
                       float p_dropout,
                       float softmax_scale,
-                      bool is_causal) {
+                      int window_size_left,
+                      int window_size_right) {
     // Reset the parameters
     memset(&params, 0, sizeof(params));
@@ -105,7 +106,15 @@ void set_params_fprop(Flash_fwd_params &params,
     params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
     TORCH_CHECK(p_dropout < 1.f);
-    params.is_causal = is_causal;
+    // Causal is the special case where window_size_right == 0 and window_size_left < 0.
+    // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
+    params.is_causal = window_size_left < 0 && window_size_right == 0;
+    if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }
+    if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }
+    params.window_size_left = window_size_left;
+    params.window_size_right = window_size_right;
     params.is_seqlens_k_cumulative = true;
 }
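Note: a minimal host-side sketch of the window normalization introduced above, for reference. The function name is illustrative and not part of the commit; the behavior mirrors set_params_fprop: a negative size means "unlimited" on that side, and causal attention is exactly window_size_left < 0 with window_size_right == 0.

    #include <utility>

    // Hypothetical helper mirroring the normalization in set_params_fprop.
    std::pair<int, int> normalize_window(int window_size_left, int window_size_right, int seqlen_k) {
        if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }   // unlimited on the left
        if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }  // unlimited on the right
        return {window_size_left, window_size_right};
    }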
@@ -138,7 +147,8 @@ void set_params_dgrad(Flash_bwd_params &params,
                       void *dsoftmax_sum_d,
                       float p_dropout,
                       float softmax_scale,
-                      bool is_causal) {
+                      int window_size_left,
+                      int window_size_right) {
     set_params_fprop(params,
                      b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
@@ -149,7 +159,8 @@ void set_params_dgrad(Flash_bwd_params &params,
                      softmax_lse_d,
                      p_dropout,
                      softmax_scale,
-                     is_causal);
+                     window_size_left,
+                     window_size_right);

     // Set the pointers and strides.
     params.do_ptr = dout.data_ptr();
@@ -242,6 +253,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         const float p_dropout,
         const float softmax_scale,
         bool is_causal,
+        const int window_size_left,
+        int window_size_right,
         const bool return_softmax,
         c10::optional<at::Generator> gen_) {
@@ -281,10 +294,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
     if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
+    if (is_causal) { window_size_right = 0; }

     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size_og % 8 == 0;
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0;
     if (seqlenq_ngroups_swapped) {
         const int ngroups = num_heads / num_heads_k;
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
@@ -353,7 +367,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
                      softmax_lse.data_ptr(),
                      p_dropout,
                      softmax_scale,
-                     is_causal);
+                     window_size_left,
+                     window_size_right);

     // This needs to match with run_mha_fwd_splitkv_dispatch
     const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
@@ -421,9 +436,12 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                const float softmax_scale,
                const bool zero_tensors,
                const bool is_causal,
+               const int window_size_left,
+               int window_size_right,
                const bool return_softmax,
                c10::optional<at::Generator> gen_) {
+    if (is_causal) { window_size_right = 0; }
     auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
@@ -534,7 +552,8 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                      softmax_lse.data_ptr(),
                      p_dropout,
                      softmax_scale,
-                     is_causal);
+                     window_size_left,
+                     window_size_right);

     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
@@ -600,8 +619,12 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
         const float p_dropout,          // probability to drop
         const float softmax_scale,
         const bool is_causal,
+        const int window_size_left,
+        int window_size_right,
         c10::optional<at::Generator> gen_,
         c10::optional<at::Tensor> &rng_state) {
+    if (is_causal) { window_size_right = 0; }
     auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
@@ -748,7 +771,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
                      softmax_d.data_ptr(),
                      p_dropout,
                      softmax_scale,
-                     is_causal);
+                     window_size_left,
+                     window_size_right);

     auto launch = &run_mha_bwd;
     // launch(params, stream, /*configure=*/true);
@@ -804,9 +828,12 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
                const float softmax_scale,
                const bool zero_tensors,
                const bool is_causal,
+               const int window_size_left,
+               int window_size_right,
                c10::optional<at::Generator> gen_,
                c10::optional<at::Tensor> &rng_state) {
+    if (is_causal) { window_size_right = 0; }
     auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
@@ -969,7 +996,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
                      softmax_d.data_ptr(),
                      p_dropout,
                      softmax_scale,
-                     is_causal);
+                     window_size_left,
+                     window_size_right);

     auto launch = &run_mha_bwd;
     // launch(params, stream, /*configure=*/true);
@@ -1019,6 +1047,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                 c10::optional<at::Tensor> &out_,     // batch_size x seqlen_q x num_heads x head_size
                 const float softmax_scale,
                 bool is_causal,
+                const int window_size_left,
+                int window_size_right,
                 bool is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
                 int num_splits
                 ) {
@@ -1059,10 +1089,11 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
     if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
+    if (is_causal) { window_size_right = 0; }

     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && head_size_og % 8 == 0;
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0;
     if (seqlenq_ngroups_swapped) {
         const int ngroups = num_heads / num_heads_k;
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
@@ -1125,7 +1156,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                      softmax_lse.data_ptr(),
                      /*p_dropout=*/0.f,
                      softmax_scale,
-                     is_causal);
+                     window_size_left,
+                     window_size_right);

     at::Tensor k, v, k_padded, v_padded;
     if (k_.has_value()) {
csrc/flash_attn/src/flash.h
@@ -105,6 +105,9 @@ struct Flash_fwd_params : public Qkv_params {
     float rp_dropout;
     float scale_softmax_rp_dropout;

+    // Local window size
+    int window_size_left, window_size_right;
+
     // Random state.
     at::PhiloxCudaState philox_args;
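Note: to make the meaning of these two fields concrete, here is a reference predicate for the sliding-window mask the kernels implement. This is a sketch inferred from the block-bound arithmetic and apply_mask_local calls in the hunks below, not code from the commit; indices are 0-based and the window is aligned to the bottom-right of the attention matrix, as in the causal case.

    // Hypothetical reference check: does query i (of seqlen_q) attend to key j (of seqlen_k)?
    // A negative window size means "unlimited" on that side.
    bool in_local_window(int i, int j, int seqlen_q, int seqlen_k,
                         int window_size_left, int window_size_right) {
        const int shift = seqlen_k - seqlen_q;  // aligns the diagonal to the bottom-right corner
        const bool left_ok  = window_size_left  < 0 || j >= i + shift - window_size_left;
        const bool right_ok = window_size_right < 0 || j <= i + shift + window_size_right;
        return j < seqlen_k && left_ok && right_ok;
    }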
csrc/flash_attn/src/flash_bwd_kernel.h
@@ -422,7 +422,7 @@ inline __device__ void convert_dKV(const Params &params) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
 inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {

     using Element = typename Kernel_traits::Element;
@@ -447,6 +447,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     if (n_block * kBlockN >= binfo.actual_seqlen_k || binfo.actual_seqlen_q == 0) return;

     int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
+    if (Is_local) {
+        m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM));
+    }

     const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
         + (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
@@ -655,14 +658,53 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
...
@@ -655,14 +658,53 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
tdQgdQaccum
.
data
()
=
tdQgdQaccum
.
data
()
+
kBlockM
*
params
.
h
*
params
.
d_rounded
;
tdQgdQaccum
.
data
()
=
tdQgdQaccum
.
data
()
+
kBlockM
*
params
.
h
*
params
.
d_rounded
;
int
m_block
=
m_block_max
-
1
;
int
m_block
=
m_block_max
-
1
;
int
m_block_min
=
!
Is_causal
?
0
:
std
::
max
(
0
,
(
n_block
*
kBlockN
+
binfo
.
actual_seqlen_q
-
binfo
.
actual_seqlen_k
)
/
kBlockM
);
int
m_block_min
=
(
!
Is_causal
&&
!
Is_local
)
// We're guaranteed that m_block_min <= m_block:
?
0
:
std
::
max
(
0
,
(
n_block
*
kBlockN
+
binfo
.
actual_seqlen_q
-
binfo
.
actual_seqlen_k
-
params
.
window_size_right
)
/
kBlockM
);
// If not local, we're guaranteed that m_block_min <= m_block:
// We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case,
// We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case,
// n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q.
// n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q.
// So m_block_min <= (actual_seqlen_q - 1) / kBlockM.
// So m_block_min <= (actual_seqlen_q - 1) / kBlockM.
// Recall that m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM) = (actual_seqlen_q + kBlockM - 1) / kBlockM.
// Recall that m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM) = (actual_seqlen_q + kBlockM - 1) / kBlockM.
// So m_block_m - 1 = (actual_seqlen_q - 1) / kBlockM.
// So m_block_m - 1 = (actual_seqlen_q - 1) / kBlockM.
// We conclude that m_block_min <= m_block, so we will always have at least 1 iteration of the for loop.
// We conclude that m_block_min <= m_block, so we will always have at least 1 iteration of the for loop.
// However, if local, then this possible to have some blocks of K & V not attending to any query.
// We might need to exit early and write 0 to dK and dV for those blocks.
// Otherwise we get wrong result for the case where we don't enter the for loop.
// And we might read OOB elements from gQ and gdO.
if
(
Is_local
&&
m_block
<
m_block_min
)
{
const
index_t
row_offset_dk
=
binfo
.
k_offset
(
params
.
dk_batch_stride
,
params
.
dk_row_stride
,
bidb
)
+
n_block
*
kBlockN
*
params
.
dk_row_stride
+
bidh
*
params
.
dk_head_stride
;
const
index_t
row_offset_dv
=
binfo
.
k_offset
(
params
.
dv_batch_stride
,
params
.
dv_row_stride
,
bidb
)
+
n_block
*
kBlockN
*
params
.
dv_row_stride
+
bidh
*
params
.
dv_head_stride
;
Tensor
gdK
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
dk_ptr
)
+
row_offset_dk
),
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDim
>>
{},
make_stride
(
params
.
dk_row_stride
,
_1
{}));
Tensor
gdV
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
dv_ptr
)
+
row_offset_dv
),
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDim
>>
{},
make_stride
(
params
.
dv_row_stride
,
_1
{}));
typename
Kernel_traits
::
GmemTiledCopydKV
gmem_tiled_copy_dKV
;
auto
gmem_thr_copy_dKV
=
gmem_tiled_copy_dKV
.
get_thread_slice
(
tidx
);
Tensor
tdKgdK
=
gmem_thr_copy_dKV
.
partition_D
(
gdK
);
Tensor
tdVgdV
=
gmem_thr_copy_dKV
.
partition_D
(
gdV
);
Tensor
tdKrdK
=
make_tensor
<
Element
>
(
shape
(
tdKgdK
));
Tensor
tdVrdV
=
make_tensor
<
Element
>
(
shape
(
tdVgdV
));
clear
(
tdKrdK
);
clear
(
tdVrdV
);
Tensor
cdKV
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
gdK
),
size
<
1
>
(
gdK
)));
// (BLK_N,BLK_K) -> (blk_n,blk_k)
Tensor
tdKVcdKV
=
gmem_thr_copy_dKV
.
partition_D
(
cdKV
);
Tensor
tdKVpdKV
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tdKgdK
)));
#pragma unroll
for
(
int
k
=
0
;
k
<
size
(
tdKVpdKV
);
++
k
)
{
tdKVpdKV
(
k
)
=
get
<
1
>
(
tdKVcdKV
(
0
,
0
,
k
))
<
params
.
d
;
}
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash
::
copy
<
Is_even_MN
,
Is_even_K
,
/*Clear_OOB_MN=*/
false
,
/*Clear_OOB_K=*/
false
>
(
gmem_tiled_copy_dKV
,
tdKrdK
,
tdKgdK
,
tdKVcdKV
,
tdKVpdKV
,
binfo
.
actual_seqlen_k
-
n_block
*
kBlockN
);
flash
::
copy
<
Is_even_MN
,
Is_even_K
,
/*Clear_OOB_MN=*/
false
,
/*Clear_OOB_K=*/
false
>
(
gmem_tiled_copy_dKV
,
tdVrdV
,
tdVgdV
,
tdKVcdKV
,
tdKVpdKV
,
binfo
.
actual_seqlen_k
-
n_block
*
kBlockN
);
return
;
}
if
(
Double_buffer
&&
m_block
%
2
==
1
)
{
// Double buffer for sQ
if
(
Double_buffer
&&
m_block
%
2
==
1
)
{
// Double buffer for sQ
tQsQ
.
data
()
=
tQsQ
.
data
()
+
size
(
sQ
);
tQsQ
.
data
()
=
tQsQ
.
data
()
+
size
(
sQ
);
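Note: putting the two bounds above together, a host-side sketch of the query-block range one K/V block visits in the backward pass. This is illustrative only (it assumes local attention with window sizes already normalized to seqlen_k where negative, as done in flash_api.cpp); when the range is empty, the kernel writes zero dK/dV for the block and returns, as in the early-exit path above.

    #include <algorithm>

    constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

    // Hypothetical illustration of [m_block_min, m_block_max) for a fixed n_block.
    void bwd_m_block_range(int n_block, int seqlen_q, int seqlen_k,
                           int window_size_left, int window_size_right,
                           int kBlockM, int kBlockN,
                           int &m_block_min, int &m_block_max) {
        m_block_max = std::min(ceil_div(seqlen_q, kBlockM),
            ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + window_size_left, kBlockM));
        m_block_min = std::max(0,
            (n_block * kBlockN + seqlen_q - seqlen_k - window_size_right) / kBlockM);
    }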
@@ -777,12 +819,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // However, it's possible that the values in acc_s are so large that they overflow
         // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ.
         // So we need to mask out the elements beyond actual_seqlen_k.
-        if (!Is_causal) {
+        if (!Is_causal && !Is_local) {
             if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) {
                 flash::apply_mask(scores, binfo.actual_seqlen_k,
                                   n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16);
             }
-        } else {
+        } else if (Is_causal) {
             // Putting this causal masking right after acc_s is *much* slower for some reason.
             // TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short
             // (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking.
@@ -795,6 +837,16 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                                      // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
                                      AtomLayoutMS * 16);
             }
+        } else if (Is_local) {
+            if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right
+                || (m_block + 1) * kBlockM >= n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left
+                || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
+                flash::apply_mask_local(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
+                                        binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
+                                        binfo.actual_seqlen_q, AtomLayoutMS * 16,
+                                        params.window_size_left, params.window_size_right);
+            }
         }
         // if (cute::thread(32, 0)) { print(scores); }
         // Compute the exponential value.
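Note: the condition guarding apply_mask_local above is easier to read factored out. The sketch below is illustrative only and mirrors the same arithmetic: a (m_block, n_block) tile needs element-wise masking only when it straddles the right window edge, the left window edge, or the end of the key sequence; fully interior tiles skip the mask.

    // Hypothetical tile-level check mirroring the Is_local branch above.
    bool bwd_tile_needs_local_mask(int m_block, int n_block, int seqlen_q, int seqlen_k,
                                   int window_size_left, int window_size_right,
                                   int kBlockM, int kBlockN, bool even_MN) {
        const int q_minus_k = seqlen_q - seqlen_k;
        return m_block * kBlockM < (n_block + 1) * kBlockN + q_minus_k - window_size_right   // touches the right edge
            || (m_block + 1) * kBlockM >= n_block * kBlockN + q_minus_k + window_size_left   // touches the left edge
            || (!even_MN && (n_block + 1) * kBlockN >= seqlen_k);                            // last, partial K block
    }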
@@ -1510,7 +1562,7 @@ inline __device__ void compute_dq_dk_dv(const Params &params) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, typename Params>
 inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {

     const int n_block = blockIdx.x;
@@ -1519,7 +1571,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
     // The block index for the head.
     const int bidh = blockIdx.z;

-    compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
+    compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////
csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -23,9 +23,10 @@ __global__ void flash_bwd_dq_dk_dv_loop_kernel(Flash_bwd_params params) {
     flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K>(params);
 }

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K>
 __global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) {
-    flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K>(params);
+    static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
+    flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K>(params);
 }

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K>
@@ -62,9 +63,11 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
     BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
         BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
+                BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] {
                     // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                     // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
-                    auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
+                    // If Is_local, set Is_causal to false
+                    auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
                     // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
                     if (smem_size_dq_dk_dv >= 48 * 1024)  {
                         C10_CUDA_CHECK(cudaFuncSetAttribute(
@@ -75,6 +78,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
+                });
             });
         });
     });
     auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
     if (Kernel_traits::kSmemdQSize >= 48 * 1024)  {
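Note: a condensed sketch of the dispatch decision made in run_flash_bwd_seqk_parallel above. BOOL_SWITCH turns the runtime check window_size_left >= 0 || window_size_right >= 0 into the compile-time Is_local flag, and Is_causal is forced to false whenever Is_local is set (the static_assert in the kernel wrapper enforces the exclusivity). The helper below is illustrative only, not part of the commit.

    // Illustrative only: how the runtime window sizes pick the template flags.
    struct KernelFlags { bool is_causal; bool is_local; };

    KernelFlags pick_bwd_flags(bool is_causal, int window_size_left, int window_size_right) {
        const bool is_local = window_size_left >= 0 || window_size_right >= 0;
        return { is_causal && !is_local, is_local };
    }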
csrc/flash_attn/src/flash_fwd_kernel.h
@@ -71,7 +71,7 @@ inline __device__ void write_softmax_to_gmem(
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {

     using Element = typename Kernel_traits::Element;
@@ -93,16 +93,17 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
     if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;

+    const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
     int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
-    if (Is_causal) {
+    if (Is_causal || Is_local) {
         n_block_max = std::min(n_block_max,
-                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN));
+                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
         // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
         //     printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
         // }
         // We exit early and write 0 to gO and gLSE.
         // Otherwise we might read OOB elements from gK and gV.
-        if (n_block_max <= 0) {
+        if (n_block_max <= n_block_min) {
             // Save seed and offset for backward. If we don't have this here, the 0-th thread block might
             // exit early and no one saves the rng state.
             if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
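Note: for reference, a host-side sketch of the key-block range [n_block_min, n_block_max) that one query block visits in the forward pass under local attention. Illustrative only, with the same tile-size assumptions as the kernel; an empty range is the case handled by the early-exit path above, which writes 0 to gO and gLSE.

    #include <algorithm>

    constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

    // Hypothetical illustration of the forward-pass key-block range for one m_block.
    void fwd_n_block_range(int m_block, int seqlen_q, int seqlen_k,
                           int window_size_left, int window_size_right,
                           int kBlockM, int kBlockN,
                           int &n_block_min, int &n_block_max) {
        n_block_min = std::max(0,
            (m_block * kBlockM + seqlen_k - seqlen_q - window_size_left) / kBlockN);
        n_block_max = std::min(ceil_div(seqlen_k, kBlockN),
            ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q + window_size_right, kBlockN));
    }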
@@ -145,6 +146,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             return;
         }
     }
+    // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); }

     // We iterate over the blocks in reverse order. This is because the last block is the only one
     // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
@@ -326,9 +328,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
     // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
-    constexpr int n_masking_steps = !Is_causal
+    constexpr int n_masking_steps = (!Is_causal && !Is_local)
         ? 1
-        : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
+        : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
     #pragma unroll
     for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
@@ -356,11 +358,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
-        // if (cute::thread0()) { print(scores); }
+        // if (cute::thread0()) { print_tensor(scores); }
         // We don't put the masking before the matmul S = Q K^T because we don't clear sK
         // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
         // can produce Inf / NaN.
-        if (!Is_causal) {
+        if (!Is_causal && !Is_local) {
             if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
         } else {
             // Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});    // (BLK_M,BLK_N) -> (blk_m,blk_n)
@@ -374,18 +376,21 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             // Idk why it's get<1> and not get<0> of the stride.
             // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
             // I can't get the stride from idx_row
-            flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
+            flash::apply_mask_local</*HasWSLeft=*/Is_local>(scores, n_block * kBlockN, binfo.actual_seqlen_k,
                                      // m_block * kBlockM + get<0>(idx_row(0)),
                                      m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
                                      binfo.actual_seqlen_q,
-                                     kNWarps * 16);
-                                     // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
-                                     // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
+                                     kNWarps * 16, params.window_size_left, params.window_size_right
+                                     // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16
+                                     // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16
+                                     );
+            // if (cute::thread0()) { print_tensor(scores); }
         }

         flash::cp_async_wait<0>();
         __syncthreads();
-        if (n_block > 0) {
+        if (n_block > n_block_min) {
             // Advance gK
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
             flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
@@ -396,8 +401,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // TODO: when we have key_padding_mask we'll need to Check_inf
         masking_step == 0
-            ? softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
-            : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+            ? softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
+            : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);

         // Convert scores from fp32 to fp16/bf16
         Tensor rP = flash::convert_type<Element>(scores);
@@ -426,14 +431,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // if (cute::thread0()) { print(scores); }

         // This check is at the end of the loop since we always have at least 1 iteration
-        if (n_masking_steps > 1 && n_block <= 0) {
+        if (n_masking_steps > 1 && n_block <= n_block_min) {
             --n_block;
             break;
         }
     }

     // These are the iterations where we don't need masking on S
-    for (; n_block >= 0; --n_block) {
+    for (; n_block >= n_block_min; --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
         clear(acc_s);
         flash::cp_async_wait<0>();
@@ -450,7 +455,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         flash::cp_async_wait<0>();
         __syncthreads();
-        if (n_block > 0) {
+        if (n_block > n_block_min) {
             // Advance gK
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
             flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
@@ -461,7 +466,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
-        softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+        if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
+            flash::apply_mask_local(
+                scores, n_block * kBlockN, binfo.actual_seqlen_k,
+                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                binfo.actual_seqlen_q, kNWarps * 16,
+                params.window_size_left, params.window_size_right
+            );
+        }
+        softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);

         Tensor rP = flash::convert_type<Element>(scores);
         // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
@@ -568,7 +581,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
 inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {

     using Element = typename Kernel_traits::Element;
@@ -599,11 +612,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

     const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
-    const int n_block_min = n_split_idx * n_blocks_per_split;
+    const int n_block_min = !Is_local
+        ? n_split_idx * n_blocks_per_split
+        : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
     int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split);
-    if (Is_causal) {
+    if (Is_causal || Is_local) {
         n_block_max = std::min(n_block_max,
-                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN));
+                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
     }
     if (n_block_min >= n_block_max) {  // This also covers the case where n_block_max <= 0
         // We exit early and write 0 to gOaccum and -inf to gLSEaccum.
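Note: with split-KV, each split owns a contiguous range of key blocks, and the local-attention lower bound is combined with the split's own start via std::max, exactly as above. A minimal illustrative helper (not part of the commit):

    #include <algorithm>

    // Hypothetical sketch of the split-KV lower bound under local attention.
    int splitkv_n_block_min(int n_split_idx, int n_blocks_per_split,
                            int m_block, int seqlen_q, int seqlen_k,
                            int window_size_left, int kBlockM, int kBlockN,
                            bool is_local) {
        const int split_start = n_split_idx * n_blocks_per_split;
        if (!is_local) { return split_start; }
        return std::max(split_start,
            (m_block * kBlockM + seqlen_k - seqlen_q - window_size_left) / kBlockN);
    }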
@@ -842,21 +857,21 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                            binfo.actual_seqlen_q - m_block * kBlockM);
     } else {
-        const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
+        const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
         // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache.
         // We do this by setting the row stride of gCos / gSin to 0.
         Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
                                   Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
-                                  make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
+                                  make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
         Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
                                   Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
-                                  make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
+                                  make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
         Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
                                       Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                                      make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
+                                      make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
         Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
                                       Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                                      make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
+                                      make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
         Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
         Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
         Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
@@ -895,9 +910,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
     // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
-    constexpr int n_masking_steps = !Is_causal
+    constexpr int n_masking_steps = (!Is_causal && !Is_local)
         ? 1
-        : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
+        : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
     #pragma unroll
     for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
@@ -929,13 +944,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // We don't put the masking before the matmul S = Q K^T because we don't clear sK
         // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
         // can produce Inf / NaN.
-        if (!Is_causal) {
+        if (!Is_causal && !Is_local) {
             if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
         } else {
-            flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
+            flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k,
                                     m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
                                     binfo.actual_seqlen_q,
-                                    kNWarps * 16);
+                                    kNWarps * 16,
+                                    params.window_size_left, params.window_size_right);
         }

         flash::cp_async_wait<0>();
@@ -954,8 +970,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // We have key_padding_mask so we'll need to Check_inf
         masking_step == 0
-            ? softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
-            : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+            ? softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
+            : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
         // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }

         // Convert scores from fp32 to fp16/bf16
@@ -1003,7 +1019,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
-        softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+        if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
+            flash::apply_mask_local(
+                scores, n_block * kBlockN, binfo.actual_seqlen_k,
+                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                binfo.actual_seqlen_q, kNWarps * 16,
+                params.window_size_left, params.window_size_right
+            );
+        }
+        softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);

         Tensor rP = flash::convert_type<Element>(scores);
         // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
@@ -1106,7 +1130,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
 inline __device__ void compute_attn(const Params &params) {
     const int m_block = blockIdx.x;
     // The block index for the batch.
@@ -1122,12 +1146,12 @@ inline __device__ void compute_attn(const Params &params) {
     // the attention matrix. This way, as long as we have the batch, head, and the location of
     // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.

-    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
+    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
 inline __device__ void compute_attn_splitkv(const Params &params) {
     const int m_block = blockIdx.x;
     // The block index for the batch.
@@ -1136,7 +1160,7 @@ inline __device__ void compute_attn_splitkv(const Params &params) {
     const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
     const int n_split_idx = Split ? blockIdx.y : 0;
     const int num_n_splits = Split ? gridDim.y : 1;
-    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
+    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////
csrc/flash_attn/src/flash_fwd_launch_template.h View file @ 083e8f52
...
@@ -10,14 +10,15 @@
#include "flash.h"
#include "flash_fwd_kernel.h"

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
-    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params);
+    static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
+    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Return_softmax>(params);
}

-template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
+template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
-    flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Split, Append_KV>(params);
+    flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Is_even_MN, Is_even_K, Split, Append_KV>(params);
}

template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K>
...
@@ -42,13 +43,14 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    const bool return_softmax = params.p_ptr != nullptr;
    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
        BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
+            BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] {
                BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
                    // Will only return softmax if dropout, to reduce compilation time.
                    // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                    // If return_softmax, set IsEvenMNConst to false to reduce number of templates
                    // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
-                    auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst && IsEvenKConst && (!ReturnSoftmaxConst) && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
+                    // If Is_local, set Is_causal to false
+                    // auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, true, ReturnSoftmaxConst && Is_dropout>;
+                    auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
                    if (smem_size >= 48 * 1024) {
                        C10_CUDA_CHECK(cudaFuncSetAttribute(
                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
...
@@ -62,6 +64,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
            });
        });
    });
+    });
}

template<typename Kernel_traits>
...
@@ -76,11 +79,13 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
            BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
+                BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] {
                    BOOL_SWITCH(params.num_splits > 1, Split, [&] {
                        BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
                            // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
                            // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                            auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
+                            // If Is_local, set Is_causal to false
+                            auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
                            // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
                            // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
                            if (smem_size >= 48 * 1024) {
...
@@ -94,6 +99,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
                });
            });
        });
+        });
    if (params.num_splits > 1) {
        // We want kBlockM to be as small as possible for more parallelism.
        // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
...
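To make the new dispatch easier to follow, here is a rough Python sketch (not part of the commit; the function name is made up) of how the Is_causal / Is_local template flags are derived from the runtime parameters in run_flash_fwd; the real selection happens in C++ through the nested BOOL_SWITCH macros.

# Hypothetical helper mirroring the dispatch logic above.
def pick_fwd_kernel_flags(is_causal, window_size_left, window_size_right):
    # Is_local is turned on as soon as either window bound is non-negative.
    is_local = window_size_left >= 0 or window_size_right >= 0
    # The kernel is instantiated with Is_causal && !Is_local, so the local-window path
    # takes precedence and the two flags are never both true (see the static_assert
    # in flash_fwd_kernel above).
    return is_causal and not is_local, is_local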
csrc/flash_attn/src/softmax.h View file @ 083e8f52
...
@@ -139,10 +139,11 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_
    }
}

-template <typename Engine, typename Layout>
-inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
-                                         const int max_seqlen_k, const int row_idx_offset_,
-                                         const int max_seqlen_q, const int warp_row_stride) {
+template <bool HasWSLeft=true, typename Engine, typename Layout>
+inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+                                        const int max_seqlen_k, const int row_idx_offset_,
+                                        const int max_seqlen_q, const int warp_row_stride,
+                                        const int window_size_left, const int window_size_right) {
    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    const int lane_id = threadIdx.x % 32;
...
@@ -155,14 +156,15 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const i
        #pragma unroll
        for (int i = 0; i < size<0, 0>(tensor); ++i) {
            const int row_idx = row_idx_base + i * 8;
-            const int col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q);
+            const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
+            const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
            #pragma unroll
            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                const int col_idx_base = col_idx_offset + nj * 8;
                #pragma unroll
                for (int j = 0; j < size<1, 0>(tensor); ++j) {
                    const int col_idx = col_idx_base + j;
-                    if (col_idx >= col_idx_limit) {
+                    if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                    }
                }
            }
...
@@ -176,6 +178,15 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const i
    }
}

+template <typename Engine, typename Layout>
+inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+                                         const int max_seqlen_k, const int row_idx_offset_,
+                                         const int max_seqlen_q, const int warp_row_stride) {
+    // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
+    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_,
+                                          max_seqlen_q, warp_row_stride, -1, 0);
+}
+
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void apply_mask_causal_w_idx(
    Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
...
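For orientation, the following is a minimal Python sketch (assumed helper, not part of the diff) of the masking rule apply_mask_local implements, ignoring the warp/MMA tiling and assuming both window sizes are non-negative: a score at (row, col) is dropped when col falls outside the per-row window.

import torch

def local_mask_reference(seqlen_q, seqlen_k, window_size_left, window_size_right):
    mask = torch.zeros(seqlen_q, seqlen_k, dtype=torch.bool)
    for row in range(seqlen_q):
        # Same limits as col_idx_limit_left / col_idx_limit_right in apply_mask_local.
        left = max(0, row + seqlen_k - seqlen_q - window_size_left)
        right = min(seqlen_k, row + 1 + seqlen_k - seqlen_q + window_size_right)
        mask[row, :left] = True   # only applied when a left window is in effect (HasWSLeft)
        mask[row, right:] = True
    return mask  # True marks positions that are set to -inf

# Causal masking is the special case of no left limit and window_size_right = 0,
# which is exactly what the apply_mask_causal wrapper above forwards to apply_mask_local.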
flash_attn/flash_attn_interface.py View file @ 083e8f52
...
@@ -41,11 +41,21 @@ def _get_block_size(device, head_dim, is_dropout, is_causal):
    return (128, 64) if is_sm80 else (64, 64)


-def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax):
+def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, return_softmax):
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
-        q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None
+        q, k, v, None, dropout_p, softmax_scale, causal, window_size[0], window_size[1], return_softmax, None,
    )
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
...
@@ -61,6 +71,7 @@ def _flash_attn_varlen_forward(
    dropout_p,
    softmax_scale,
    causal,
+    window_size,
    return_softmax,
):
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
...
@@ -78,6 +89,8 @@ def _flash_attn_varlen_forward(
        softmax_scale,
        False,
        causal,
+        window_size[0],
+        window_size[1],
        return_softmax,
        None,
    )
...
@@ -87,7 +100,20 @@ def _flash_attn_varlen_forward(
def _flash_attn_backward(
-    dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, rng_state=None
+    dout,
+    q,
+    k,
+    v,
+    out,
+    softmax_lse,
+    dq,
+    dk,
+    dv,
+    dropout_p,
+    softmax_scale,
+    causal,
+    window_size,
+    rng_state=None,
):
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
    # dq, dk, dv are allocated by us so they should already be contiguous
...
@@ -105,6 +131,8 @@ def _flash_attn_backward(
        dropout_p,
        softmax_scale,
        causal,
+        window_size[0],
+        window_size[1],
        None,
        rng_state,
    )
...
@@ -128,6 +156,7 @@ def _flash_attn_varlen_backward(
    dropout_p,
    softmax_scale,
    causal,
+    window_size,
    rng_state=None,
):
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
...
@@ -151,6 +180,8 @@ def _flash_attn_varlen_backward(
        softmax_scale,
        False,
        causal,
+        window_size[0],
+        window_size[1],
        None,
        rng_state,
    )
...
@@ -161,7 +192,7 @@ def _flash_attn_varlen_backward(
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    @staticmethod
-    def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax):
+    def forward(ctx, qkv, dropout_p, softmax_scale, causal, window_size, return_softmax):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
...
@@ -171,12 +202,14 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
            dropout_p,
            softmax_scale,
            causal=causal,
+            window_size=window_size,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
+        ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
...
@@ -197,15 +230,26 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
+            ctx.window_size,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
-        return dqkv, None, None, None, None
+        return dqkv, None, None, None, None, None


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    @staticmethod
-    def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
+    def forward(
+        ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, window_size, return_softmax,
+    ):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
...
@@ -219,6 +263,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
            dropout_p,
            softmax_scale,
            causal=causal,
+            window_size=window_size,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
...
@@ -226,6 +271,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
+        ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
...
@@ -250,15 +296,16 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
+            ctx.window_size,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
-        return dqkv, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None


class FlashAttnKVPackedFunc(torch.autograd.Function):
    @staticmethod
-    def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax):
+    def forward(ctx, q, kv, dropout_p, softmax_scale, causal, window_size, return_softmax):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
...
@@ -268,12 +315,14 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
            dropout_p,
            softmax_scale,
            causal=causal,
+            window_size=window_size,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
+        ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
...
@@ -295,11 +344,12 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
+            ctx.window_size,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
-        return dq, dkv, None, None, None, None
+        return dq, dkv, None, None, None, None, None


class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
...
@@ -315,6 +365,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
        dropout_p,
        softmax_scale,
        causal,
+        window_size,
        return_softmax,
    ):
        if softmax_scale is None:
...
@@ -330,6 +381,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
            dropout_p,
            softmax_scale,
            causal=causal,
+            window_size=window_size,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(
...
@@ -340,6 +392,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
+        ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
...
@@ -365,16 +418,17 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
+            ctx.window_size,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
-        return dq, dkv, None, None, None, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None, None, None, None


class FlashAttnFunc(torch.autograd.Function):
    @staticmethod
-    def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax):
+    def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, return_softmax):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
...
@@ -384,12 +438,14 @@ class FlashAttnFunc(torch.autograd.Function):
            dropout_p,
            softmax_scale,
            causal=causal,
+            window_size=window_size,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
+        ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
...
@@ -409,12 +465,13 @@ class FlashAttnFunc(torch.autograd.Function):
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
+            ctx.window_size,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None


class FlashAttnVarlenFunc(torch.autograd.Function):
...
@@ -431,6 +488,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
        dropout_p,
        softmax_scale,
        causal,
+        window_size,
        return_softmax,
    ):
        if softmax_scale is None:
...
@@ -446,6 +504,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
            dropout_p,
            softmax_scale,
            causal=causal,
+            window_size=window_size,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(
...
@@ -456,6 +515,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
+        ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
...
@@ -479,16 +539,22 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
+            ctx.window_size,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None


def flash_attn_qkvpacked_func(
-    qkv, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False
+    qkv,
+    dropout_p=0.0,
+    softmax_scale=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite context window
+    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
...
@@ -497,12 +563,16 @@ def flash_attn_qkvpacked_func(
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_kvpacked_func and flash_attn_func.

+    If window_size != (-1, -1), implements sliding window local attention. Query at position i
+    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
...
@@ -515,11 +585,19 @@ def flash_attn_qkvpacked_func(
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
-    return FlashAttnQKVPackedFunc.apply(qkv, dropout_p, softmax_scale, causal, return_attn_probs)
+    return FlashAttnQKVPackedFunc.apply(qkv, dropout_p, softmax_scale, causal, window_size, return_attn_probs)


def flash_attn_kvpacked_func(
-    q, kv, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False
+    q,
+    kv,
+    dropout_p=0.0,
+    softmax_scale=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite context window
+    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
...
@@ -542,6 +620,10 @@ def flash_attn_kvpacked_func(
        1 1
    If the row of the mask is all zero, the output will be zero.

+    If window_size != (-1, -1), implements sliding window local attention. Query at position i
+    will only attend to keys between
+    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        kv: (batch_size, seqlen, 2, nheads_k, headdim)
...
@@ -549,6 +631,7 @@ def flash_attn_kvpacked_func(
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
...
@@ -561,11 +644,20 @@ def flash_attn_kvpacked_func(
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
-    return FlashAttnKVPackedFunc.apply(q, kv, dropout_p, softmax_scale, causal, return_attn_probs)
+    return FlashAttnKVPackedFunc.apply(q, kv, dropout_p, softmax_scale, causal, window_size, return_attn_probs)


def flash_attn_func(
-    q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False
+    q,
+    k,
+    v,
+    dropout_p=0.0,
+    softmax_scale=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite context window
+    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
...
@@ -585,6 +677,10 @@ def flash_attn_func(
        1 1
    If the row of the mask is all zero, the output will be zero.

+    If window_size != (-1, -1), implements sliding window local attention. Query at position i
+    will only attend to keys between
+    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
...
@@ -593,6 +689,7 @@ def flash_attn_func(
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
...
@@ -605,7 +702,9 @@ def flash_attn_func(
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
-    return FlashAttnFunc.apply(q, k, v, dropout_p, softmax_scale, causal, return_attn_probs)
+    return FlashAttnFunc.apply(q, k, v, dropout_p, softmax_scale, causal, window_size, return_attn_probs)
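A minimal usage sketch of the new window_size argument (not part of the commit); it assumes a CUDA device and fp16 inputs, which the kernels require, and the import path the package normally exposes.

import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)

# Each query attends to at most 256 keys to its left and none to its right,
# which for seqlen_q == seqlen_k behaves like causal attention with a 256-token lookback.
out = flash_attn_func(q, k, v, window_size=(256, 0))  # out: (batch, seqlen, nheads, headdim)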
def flash_attn_varlen_qkvpacked_func(
...
@@ -615,6 +714,7 @@ def flash_attn_varlen_qkvpacked_func(
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
+    window_size=(-1, -1),  # -1 means infinite context window
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
...
@@ -624,6 +724,9 @@ def flash_attn_varlen_qkvpacked_func(
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.

+    If window_size != (-1, -1), implements sliding window local attention. Query at position i
+    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
...
@@ -633,6 +736,7 @@ def flash_attn_varlen_qkvpacked_func(
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
...
@@ -646,7 +750,14 @@ def flash_attn_varlen_qkvpacked_func(
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnVarlenQKVPackedFunc.apply(
-        qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs
+        qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, window_size, return_attn_probs,
    )
...
@@ -660,6 +771,7 @@ def flash_attn_varlen_kvpacked_func(
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
+    window_size=(-1, -1),  # -1 means infinite context window
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
...
@@ -683,6 +795,10 @@ def flash_attn_varlen_kvpacked_func(
        1 1
    If the row of the mask is all zero, the output will be zero.

+    If window_size != (-1, -1), implements sliding window local attention. Query at position i
+    will only attend to keys between
+    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
...
@@ -696,6 +812,7 @@ def flash_attn_varlen_kvpacked_func(
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
...
@@ -718,6 +835,7 @@ def flash_attn_varlen_kvpacked_func(
        dropout_p,
        softmax_scale,
        causal,
+        window_size,
        return_attn_probs,
    )
...
@@ -733,6 +851,7 @@ def flash_attn_varlen_func(
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
+    window_size=(-1, -1),  # -1 means infinite context window
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
...
@@ -753,6 +872,10 @@ def flash_attn_varlen_func(
        1 1
    If the row of the mask is all zero, the output will be zero.

+    If window_size != (-1, -1), implements sliding window local attention. Query at position i
+    will only attend to keys between
+    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
...
@@ -767,6 +890,7 @@ def flash_attn_varlen_func(
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
...
@@ -790,6 +914,7 @@ def flash_attn_varlen_func(
        dropout_p,
        softmax_scale,
        causal,
+        window_size,
        return_attn_probs,
    )
...
@@ -805,6 +930,7 @@ def flash_attn_with_kvcache(
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    softmax_scale=None,
    causal=False,
+    window_size=(-1, -1),  # -1 means infinite context window
    rotary_interleaved=True,
    num_splits=0,
):
...
@@ -818,11 +944,12 @@ def flash_attn_with_kvcache(
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

-    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be rotated
-    by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
-    If causal, the query @q will be rotated by rotary_cos and rotary_sin at indices cache_seqlens,
-    cache_seqlens + 1, etc. If not causal, the query @q will be rotated by rotary_cos and rotary_sin
-    at indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
+    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
+    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
+    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
+    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
+    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
+    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
...
@@ -843,6 +970,10 @@ def flash_attn_with_kvcache(
        1 1
    If the row of the mask is all zero, the output will be zero.

+    If window_size != (-1, -1), implements sliding window local attention. Query at position i
+    will only attend to keys between
+    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
...
@@ -860,6 +991,7 @@ def flash_attn_with_kvcache(
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
...
@@ -894,6 +1026,8 @@ def flash_attn_with_kvcache(
        None,
        softmax_scale,
        causal,
+        window_size[0],
+        window_size[1],
        rotary_interleaved,
        num_splits,
    )
...
tests/test_flash_attn.py View file @ 083e8f52
...
@@ -150,8 +150,13 @@ def generate_qkv(
    )


-def construct_causal_mask(
-    seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None
+def construct_local_mask(
+    seqlen_q,
+    seqlen_k,
+    window_size=(-1, -1),  # -1 means infinite window size
+    query_padding_mask=None,
+    key_padding_mask=None,
+    device=None,
):
    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
...
@@ -165,7 +170,14 @@ def construct_causal_mask(
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
-    return col_idx > row_idx + sk - sq
+    if window_size[0] < 0:
+        return col_idx > row_idx + sk - sq + window_size[1]
+    else:
+        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
+        return torch.logical_or(
+            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
+            col_idx < row_idx + sk - sq - window_size[0],
+        )
attention_ref
(
def
attention_ref
(
...
@@ -177,6 +189,7 @@ def attention_ref(
...
@@ -177,6 +189,7 @@ def attention_ref(
dropout_p
=
0.0
,
dropout_p
=
0.0
,
dropout_mask
=
None
,
dropout_mask
=
None
,
causal
=
False
,
causal
=
False
,
window_size
=
(
-
1
,
-
1
),
# -1 means infinite window size
upcast
=
True
,
upcast
=
True
,
reorder_ops
=
False
,
reorder_ops
=
False
,
):
):
...
@@ -189,6 +202,8 @@ def attention_ref(
...
@@ -189,6 +202,8 @@ def attention_ref(
key_padding_mask: (batch_size, seqlen_k)
key_padding_mask: (batch_size, seqlen_k)
dropout_p: float
dropout_p: float
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
causal: whether to apply causal masking
window_size: (int, int), left and right window size
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
output back to fp16/bf16.
output back to fp16/bf16.
reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.)
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
@@ -198,6 +213,8 @@ def attention_ref(
...
@@ -198,6 +213,8 @@ def attention_ref(
output: (batch_size, seqlen_q, nheads, head_dim)
output: (batch_size, seqlen_q, nheads, head_dim)
attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
"""
"""
if
causal
:
window_size
=
(
window_size
[
0
],
0
)
dtype_og
=
q
.
dtype
dtype_og
=
q
.
dtype
if
upcast
:
if
upcast
:
q
,
k
,
v
=
q
.
float
(),
k
.
float
(),
v
.
float
()
q
,
k
,
v
=
q
.
float
(),
k
.
float
(),
v
.
float
()
...
@@ -211,17 +228,24 @@ def attention_ref(
...
@@ -211,17 +228,24 @@ def attention_ref(
scores
=
torch
.
einsum
(
"bthd,bshd->bhts"
,
q
,
k
/
math
.
sqrt
(
d
))
scores
=
torch
.
einsum
(
"bthd,bshd->bhts"
,
q
,
k
/
math
.
sqrt
(
d
))
if
key_padding_mask
is
not
None
:
if
key_padding_mask
is
not
None
:
scores
.
masked_fill_
(
rearrange
(
~
key_padding_mask
,
"b s -> b 1 1 s"
),
float
(
"-inf"
))
scores
.
masked_fill_
(
rearrange
(
~
key_padding_mask
,
"b s -> b 1 1 s"
),
float
(
"-inf"
))
if
causal
:
if
window_size
[
0
]
>=
0
or
window_size
[
1
]
>=
0
:
# causal_mask = torch.triu(
local_mask
=
construct_local_mask
(
# torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1
seqlen_q
,
# )
seqlen_k
,
causal_mask
=
construct_causal_mask
(
window_size
,
seqlen_q
,
seqlen_k
,
query_padding_mask
,
key_padding_mask
,
q
.
device
query_padding_mask
,
key_padding_mask
,
q
.
device
,
)
)
scores
.
masked_fill_
(
caus
al_mask
,
float
(
"-inf"
))
scores
.
masked_fill_
(
loc
al_mask
,
float
(
"-inf"
))
attention
=
torch
.
softmax
(
scores
,
dim
=-
1
)
attention
=
torch
.
softmax
(
scores
,
dim
=-
1
)
if
causal
:
# Some rows are completely masked out so we fill them with zero instead of NaN
# Some rows might be completely masked out so we fill them with zero instead of NaN
attention
=
attention
.
masked_fill
(
torch
.
all
(
causal_mask
,
dim
=-
1
,
keepdim
=
True
),
0.0
)
if
window_size
[
0
]
>=
0
or
window_size
[
1
]
>=
0
:
attention
=
attention
.
masked_fill
(
torch
.
all
(
local_mask
,
dim
=-
1
,
keepdim
=
True
),
0.0
)
# We want to mask here so that the attention matrix doesn't have any NaNs
# Otherwise we'll get NaN in dV
if
query_padding_mask
is
not
None
:
attention
=
attention
.
masked_fill
(
rearrange
(
~
query_padding_mask
,
"b s -> b 1 s 1"
),
0.0
)
dropout_scaling
=
1.0
/
(
1
-
dropout_p
)
dropout_scaling
=
1.0
/
(
1
-
dropout_p
)
# attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
# attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
# output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
# output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
...
@@ -232,7 +256,6 @@ def attention_ref(
     output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
     if query_padding_mask is not None:
         output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
-        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
     return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
...
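Note (added for this review, not part of the commit): the reference functions above call construct_local_mask, which the commit defines earlier in tests/test_flash_attn.py, outside the hunks shown here. Below is a minimal sketch of the masking rule it implements, assuming the window is aligned to the bottom-right corner the way causal masking is, and ignoring the per-sequence key lengths that the real helper derives from key_padding_mask. True means "masked out"; -1 means unbounded on that side.

import torch

def sliding_window_mask(seqlen_q, seqlen_k, window_size=(-1, -1), device=None):
    # Query i may only attend to keys j with
    #   i + (seqlen_k - seqlen_q) - window_size[0] <= j <= i + (seqlen_k - seqlen_q) + window_size[1]
    row = torch.arange(seqlen_q, device=device).unsqueeze(-1)   # query indices, shape (seqlen_q, 1)
    col = torch.arange(seqlen_k, device=device)                 # key indices, shape (seqlen_k,)
    shift = seqlen_k - seqlen_q                                 # bottom-right alignment
    mask = torch.zeros(seqlen_q, seqlen_k, dtype=torch.bool, device=device)
    if window_size[1] >= 0:
        mask |= col > row + shift + window_size[1]              # too far to the right
    if window_size[0] >= 0:
        mask |= col < row + shift - window_size[0]              # too far to the left
    return mask

With window_size=(-1, 0) this reduces to the usual causal mask, which is why the reference functions rewrite causal=True as window_size=(window_size[0], 0).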
@@ -244,6 +267,7 @@ def attention_kvpacked_ref(
     dropout_p=0.0,
     dropout_mask=None,
     causal=False,
+    window_size=(-1, -1),  # -1 means infinite window size
     upcast=True,
     reorder_ops=False,
 ):
...
@@ -257,6 +281,7 @@ def attention_kvpacked_ref(
         dropout_mask,
         upcast=upcast,
         causal=causal,
+        window_size=window_size,
         reorder_ops=reorder_ops,
     )
...
@@ -267,6 +292,7 @@ def attention_qkvpacked_ref(
     dropout_p=0.0,
     dropout_mask=None,
     causal=False,
+    window_size=(-1, -1),  # -1 means infinite window size
     upcast=True,
     reorder_ops=False,
 ):
...
@@ -280,6 +306,7 @@ def attention_qkvpacked_ref(
         dropout_mask,
         upcast=upcast,
         causal=causal,
+        window_size=window_size,
         reorder_ops=reorder_ops,
     )
...
@@ -327,7 +354,15 @@ def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask
 def convert_flash_attn_S_to_softmax(
-    S, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, head_dim, is_dropout, causal=False
+    S,
+    seqlen_q,
+    seqlen_k,
+    query_padding_mask,
+    key_padding_mask,
+    head_dim,
+    is_dropout,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite window size
 ):
     """FlashAttention stores the S matrix in a different way.
     Arguments:
...
@@ -335,6 +370,8 @@ def convert_flash_attn_S_to_softmax(
         query_padding_mask: (batch_size, seqlen_q_rounded)
         key_padding_mask: (batch_size, seqlen_k_rounded)
     """
+    if causal:
+        window_size = (window_size[0], 0)
     seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
     warps_n = 4
     blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, is_dropout, causal)
...
@@ -359,19 +396,21 @@ def convert_flash_attn_S_to_softmax(
         four=4,
     )
-    if causal:
-        # causal_mask = torch.triu(
-        #     torch.ones(seqlen_q_rounded, seqlen_k_rounded, dtype=torch.bool, device=q.device), 1
-        # )
-        causal_mask = construct_causal_mask(
-            seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, S.device
+    if window_size[0] >= 0 or window_size[1] >= 0:
+        local_mask = construct_local_mask(
+            seqlen_q,
+            seqlen_k,
+            window_size,
+            query_padding_mask,
+            key_padding_mask,
+            S.device,
         )
-        causal_mask = F.pad(
-            causal_mask,
+        local_mask = F.pad(
+            local_mask,
             (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
             value=True,
         )
-        S_converted.masked_fill_(causal_mask, 0.0)
+        S_converted.masked_fill_(local_mask, 0.0)
     # Need to zero out things not in attention_mask in case S was initialized with random values
     # and some of those values aren't overwritten.
...
@@ -399,6 +438,7 @@ def normalize_flash_attn_S(
     key_padding_mask=None,
     is_dropout=False,
     causal=False,
+    window_size=(-1, -1),  # -1 means infinite window size
 ):
     """
     Arguments:
...
@@ -409,20 +449,24 @@ def normalize_flash_attn_S(
         softmax_lse: (batch_size, nheads, seqlen_q)
         softmax_max: (batch_size, nheads, seqlen_q)
     """
+    if causal:
+        window_size = (window_size[0], 0)
     q, k, v = q.float(), k.float(), v.float()
     _, seqlen_q, _, head_dim = q.shape
     seqlen_k = k.shape[1]
     scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k)
     if key_padding_mask is not None:
         scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
-    if causal:
-        # causal_mask = torch.triu(
-        #     torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1
-        # )
-        causal_mask = construct_causal_mask(
-            seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device
+    if window_size[0] >= 0 or window_size[1] >= 0:
+        local_mask = construct_local_mask(
+            seqlen_q,
+            seqlen_k,
+            window_size,
+            query_padding_mask,
+            key_padding_mask,
+            q.device,
         )
-        scores.masked_fill_(causal_mask, float("-inf"))
+        scores.masked_fill_(local_mask, float("-inf"))
     _, block_size_n = _get_block_size(scores.device, head_dim, is_dropout, causal)
     scores_block = scores.split(block_size_n, dim=-1)
     lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
...
@@ -446,79 +490,84 @@ def normalize_flash_attn_S(
 def get_dropout_fraction(
-    dropout_mask, query_padding_mask=None, key_padding_mask=None, causal=False
+    dropout_mask,
+    query_padding_mask=None,
+    key_padding_mask=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite window size
 ):
     """
     dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop.
     query_padding_mask: (batch_size, seqlen_q)
     key_padding_mask: (batch_size, seqlen_k)
     """
+    if causal:
+        window_size = (window_size[0], 0)
     batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape
     dropped = ~dropout_mask
+    valid = torch.ones_like(dropout_mask)
     if query_padding_mask is not None:
         dropped.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
+        valid.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
     if key_padding_mask is not None:
         dropped.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
-    if causal:
-        # causal_mask = torch.triu(
-        #     torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=dropout_mask.device), 1
-        # )
-        causal_mask = construct_causal_mask(
-            seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, dropout_mask.device
+        valid.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
+    if window_size[0] >= 0 or window_size[1] >= 0:
+        local_mask = construct_local_mask(
+            seqlen_q,
+            seqlen_k,
+            window_size,
+            query_padding_mask,
+            key_padding_mask,
+            dropout_mask.device,
         )
-        dropped.masked_fill_(causal_mask, False)
+        dropped.masked_fill_(local_mask, False)
+        valid.masked_fill_(local_mask, False)
     dropped_total = dropped.sum()
-    query_lengths = (
-        query_padding_mask.sum(dim=-1)
-        if query_padding_mask is not None
-        else torch.full((batch_size,), seqlen_q, device=dropout_mask.device)
-    )
-    key_lengths = (
-        key_padding_mask.sum(dim=-1)
-        if key_padding_mask is not None
-        else torch.full((batch_size,), seqlen_k, device=dropout_mask.device)
-    )
-    if not causal:
-        numel_per_batch = query_lengths * key_lengths
-    else:
-        numel_per_batch = torch.where(
-            key_lengths <= query_lengths,
-            key_lengths * (key_lengths + 1) / 2,
-            query_lengths * key_lengths - (query_lengths * (query_lengths - 1) / 2),
-        )
-    return dropped_total / (numel_per_batch.sum() * nheads)
+    return dropped.sum() / valid.sum()


 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
-# @pytest.mark.parametrize('dtype', [torch.float16])
+# @pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("local", [False, True])
+# @pytest.mark.parametrize("local", [True])
 @pytest.mark.parametrize("causal", [False, True])
-# @pytest.mark.parametrize('causal', [False])
+# @pytest.mark.parametrize("causal", [False])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
-# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128])
-# @pytest.mark.parametrize('d', [64])
+# @pytest.mark.parametrize("d", [64])
 # @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
 @pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
-# @pytest.mark.parametrize('seqlen', [128])
+# @pytest.mark.parametrize("seqlen", [128])
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
-# @pytest.mark.parametrize('dropout_p', [0.0])
+# @pytest.mark.parametrize("dropout_p", [0.0])
-def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
+def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
     if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
         pytest.skip()  # Reference implementation OOM
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 16
+    batch_size = 13
     nheads = 9
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
     qkv = torch.randn(
         batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
     )
     out, lse, S_dmask = flash_attn_qkvpacked_func(
-        qkv, dropout_p, return_attn_probs=True, causal=causal
+        qkv, dropout_p, causal=causal, window_size=window_size, return_attn_probs=True
     )
     if dropout_p > 0.0:
         S_dmask_converted = convert_flash_attn_S_to_softmax(
-            S_dmask, seqlen, seqlen, None, None, d, dropout_p > 0.0, causal=causal
+            S_dmask,
+            seqlen,
+            seqlen,
+            None,
+            None,
+            d,
+            dropout_p > 0.0,
+            causal=causal,
+            window_size=window_size,
         )
         dropout_mask = S_dmask_converted >= 0
         attn_unnorm = S_dmask_converted.abs()
...
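A minimal usage sketch mirroring the test above (the shapes, dtype, and the (64, 0) window are illustrative choices of this review, not values taken from the commit; it needs a CUDA GPU):

import torch
from flash_attn import flash_attn_qkvpacked_func

qkv = torch.randn(2, 1024, 3, 8, 64, device="cuda", dtype=torch.float16)
# Each query attends to at most 64 keys to its left and 0 keys to its right.
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=False, window_size=(64, 0))
print(out.shape)  # torch.Size([2, 1024, 8, 64])

Passing window_size=(-1, -1), the default, recovers ordinary full attention; causal=True simply clamps the right side of the window to 0.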
@@ -531,15 +580,27 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
             None,
             dropout_p > 0.0,
             causal=causal,
+            window_size=window_size,
         )
-        dropout_fraction = get_dropout_fraction(dropout_mask, None, None, causal=causal).item()
+        dropout_fraction = get_dropout_fraction(
+            dropout_mask, None, None, causal=causal, window_size=window_size
+        ).item()
         print(f"Actual dropout fraction: {dropout_fraction}")
     else:
         dropout_mask = None
-    out_ref, attn_ref = attention_qkvpacked_ref(qkv, None, dropout_p, dropout_mask, causal=causal)
+    out_ref, attn_ref = attention_qkvpacked_ref(
+        qkv, None, dropout_p, dropout_mask, causal=causal, window_size=window_size
+    )
     out_pt, attn_pt = attention_qkvpacked_ref(
-        qkv, None, dropout_p, dropout_mask, causal=causal, upcast=False, reorder_ops=True
+        qkv,
+        None,
+        dropout_p,
+        dropout_mask,
+        causal=causal,
+        window_size=window_size,
+        upcast=False,
+        reorder_ops=True,
     )
     # v = qkv[:, :, 2].float()
     # qk = torch.einsum('bshd,bthd->bhst', qkv[:, :, 0], qkv[:, :, 1]).float()
...
...
@@ -590,7 +651,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
if
dropout_p
>
0.0
:
if
dropout_p
>
0.0
:
assert
(
attn
-
attn_ref
).
abs
().
max
().
item
()
<=
2
*
(
attn_pt
-
attn_ref
).
abs
().
max
().
item
()
assert
(
attn
-
attn_ref
).
abs
().
max
().
item
()
<=
2
*
(
attn_pt
-
attn_ref
).
abs
().
max
().
item
()
assert
abs
(
dropout_fraction
-
dropout_p
)
<=
0.01
assert
abs
(
dropout_fraction
-
dropout_p
)
<=
(
0.01
if
not
local
else
0.025
)
if
d
<=
MAX_HEADDIM_SM8x
or
(
is_sm80
or
is_sm90
):
if
d
<=
MAX_HEADDIM_SM8x
or
(
is_sm80
or
is_sm90
):
assert
(
dqkv
-
dqkv_ref
).
abs
().
max
().
item
()
<=
2
*
(
dqkv_pt
-
dqkv_ref
).
abs
().
max
().
item
()
assert
(
dqkv
-
dqkv_ref
).
abs
().
max
().
item
()
<=
2
*
(
dqkv_pt
-
dqkv_ref
).
abs
().
max
().
item
()
...
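The relaxed tolerance for local attention above (0.025 instead of 0.01) reflects that fewer scores survive the window, so the measured dropout fraction is noisier. A self-contained toy illustration of the accounting get_dropout_fraction now does (this is review-added code, not code from the test file):

import torch

torch.manual_seed(0)
seqlen, window_left, p = 128, 16, 0.17
dropout_mask = torch.rand(1, 2, seqlen, seqlen) > p              # True = kept
i = torch.arange(seqlen).unsqueeze(-1)
j = torch.arange(seqlen)
local_mask = (j > i) | (j < i - window_left)                     # True = outside the window
dropped = ~dropout_mask & ~local_mask                            # dropped and inside the window
valid = (~local_mask).expand(1, 2, seqlen, seqlen)               # positions that count at all
print(float(dropped.sum() / valid.sum()))                        # close to 0.17

With a window there is no simple closed form for the number of unmasked positions per batch, which is why the helper now divides by valid.sum() instead of the old triangular-count formula.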
@@ -598,15 +659,18 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.float16])
+@pytest.mark.parametrize("local", [False, True])
+# @pytest.mark.parametrize("local", [True])
 @pytest.mark.parametrize("causal", [False, True])
 # @pytest.mark.parametrize('causal', [False])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [64])
 @pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
 # @pytest.mark.parametrize('seqlen', [128])
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
 # @pytest.mark.parametrize('dropout_p', [0.0])
-def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
+def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
     if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
         pytest.skip()  # Reference implementation OOM
     device = "cuda"
...
@@ -614,6 +678,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
     torch.random.manual_seed(0)
     batch_size = 5
     nheads = 6
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
     qkv = torch.randn(
         batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
     )
...
@@ -626,7 +691,13 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
     )
     out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func(
-        qkv_unpad, cu_seqlens, max_seqlen, dropout_p, return_attn_probs=True, causal=causal
+        qkv_unpad,
+        cu_seqlens,
+        max_seqlen,
+        dropout_p,
+        causal=causal,
+        window_size=window_size,
+        return_attn_probs=True,
     )
     out = output_pad_fn(out_unpad)
     if dropout_p > 0.0:
...
@@ -639,6 +710,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
             d,
             dropout_p > 0.0,
             causal=causal,
+            window_size=window_size,
         )
         dropout_mask = S_dmask_converted >= 0
         attn_unnorm = S_dmask_converted.abs()
...
@@ -651,16 +723,17 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
             key_padding_mask,
             dropout_p > 0.0,
             causal=causal,
+            window_size=window_size,
         )
         dropout_fraction = get_dropout_fraction(
-            dropout_mask, key_padding_mask, key_padding_mask, causal=causal
+            dropout_mask, key_padding_mask, key_padding_mask, causal=causal, window_size=window_size
         ).item()
         print(f"Actual dropout fraction: {dropout_fraction}")
     else:
         dropout_mask = None
     out_ref, attn_ref = attention_qkvpacked_ref(
-        qkv, key_padding_mask, dropout_p, dropout_mask, causal=causal
+        qkv, key_padding_mask, dropout_p, dropout_mask, causal=causal, window_size=window_size
     )
     out_pt, attn_pt = attention_qkvpacked_ref(
         qkv,
...
@@ -668,6 +741,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
         dropout_p,
         dropout_mask,
         causal=causal,
+        window_size=window_size,
         upcast=False,
         reorder_ops=True,
     )
...
@@ -700,7 +774,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
     if dropout_p > 0.0:
         assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
-        assert abs(dropout_fraction - dropout_p) <= 0.01
+        assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

     if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
         assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
...
@@ -712,10 +786,12 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
 # @pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
 # @pytest.mark.parametrize("mha_type", ["mha"])
+@pytest.mark.parametrize("local", [False, True])
+# @pytest.mark.parametrize("local", [True])
 @pytest.mark.parametrize("causal", [False, True])
 # @pytest.mark.parametrize("causal", [True])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
-# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
...
@@ -738,7 +814,9 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
 # @pytest.mark.parametrize("dropout_p", [0.17])
-def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype, kvpacked):
+def test_flash_attn_output(
+    seqlen_q, seqlen_k, d, dropout_p, causal, local, mha_type, dtype, kvpacked
+):
     if (
         max(seqlen_q, seqlen_k) >= 2048
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
...
@@ -747,10 +825,11 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 16
+    batch_size = 13
     nheads = 9
     nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
     assert nheads % nheads_k == 0
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
     if kvpacked:
         kv = torch.randn(
...
@@ -766,15 +845,23 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
     if kvpacked:
         out, lse, S_dmask = flash_attn_kvpacked_func(
-            q, kv, dropout_p, return_attn_probs=True, causal=causal
+            q, kv, dropout_p, causal=causal, window_size=window_size, return_attn_probs=True
         )
     else:
         out, lse, S_dmask = flash_attn_func(
-            q, k, v, dropout_p, return_attn_probs=True, causal=causal
+            q, k, v, dropout_p, causal=causal, window_size=window_size, return_attn_probs=True
         )
     if dropout_p > 0.0:
         S_dmask_converted = convert_flash_attn_S_to_softmax(
-            S_dmask, seqlen_q, seqlen_k, None, None, d, dropout_p > 0.0, causal=causal
+            S_dmask,
+            seqlen_q,
+            seqlen_k,
+            None,
+            None,
+            d,
+            dropout_p > 0.0,
+            causal=causal,
+            window_size=window_size,
         )
         dropout_mask = S_dmask_converted >= 0
         attn_unnorm = S_dmask_converted.abs()
...
@@ -785,16 +872,33 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
         k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
         v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
         attn = normalize_flash_attn_S(
-            attn_unnorm, q, k_rep, v_rep, None, None, dropout_p > 0.0, causal=causal
+            attn_unnorm,
+            q,
+            k_rep,
+            v_rep,
+            None,
+            None,
+            dropout_p > 0.0,
+            causal=causal,
+            window_size=window_size,
         )
-        dropout_fraction = get_dropout_fraction(dropout_mask, None, None, causal=causal).item()
+        dropout_fraction = get_dropout_fraction(
+            dropout_mask, None, None, causal=causal, window_size=window_size
+        ).item()
         print(f"Actual dropout fraction: {dropout_fraction}")
     else:
         dropout_mask = None

     if kvpacked:
         out_ref, attn_ref = attention_kvpacked_ref(
-            q, kv, None, None, dropout_p, dropout_mask, causal=causal
+            q,
+            kv,
+            None,
+            None,
+            dropout_p,
+            dropout_mask,
+            causal=causal,
+            window_size=window_size,
         )
         out_pt, attn_pt = attention_kvpacked_ref(
             q,
...
@@ -804,12 +908,21 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
             dropout_p,
             dropout_mask,
             causal=causal,
+            window_size=window_size,
             upcast=False,
             reorder_ops=True,
         )
     else:
         out_ref, attn_ref = attention_ref(
-            q, k, v, None, None, dropout_p, dropout_mask, causal=causal
+            q,
+            k,
+            v,
+            None,
+            None,
+            dropout_p,
+            dropout_mask,
+            causal=causal,
+            window_size=window_size,
         )
         out_pt, attn_pt = attention_ref(
             q,
...
@@ -820,6 +933,7 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
             dropout_p,
             dropout_mask,
             causal=causal,
+            window_size=window_size,
             upcast=False,
             reorder_ops=True,
         )
...
@@ -886,7 +1000,7 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
     if dropout_p > 0.0:
         assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
-        assert abs(dropout_fraction - dropout_p) <= 0.01
+        assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

     if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
...
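One property this parametrization exercises is that causal=True and window_size=(-1, 0) describe the same mask, so the two call styles should agree. A quick check along those lines (review-added snippet, not part of the test suite; needs a CUDA GPU):

import torch
from flash_attn import flash_attn_func

q, k, v = (torch.randn(1, 256, 4, 64, device="cuda", dtype=torch.float16) for _ in range(3))
out_causal = flash_attn_func(q, k, v, causal=True)
out_window = flash_attn_func(q, k, v, causal=False, window_size=(-1, 0))
print((out_causal - out_window).abs().max().item())  # expected to be 0.0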
@@ -900,10 +1014,12 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
 # @pytest.mark.parametrize('mha_type', ["mqa"])
+@pytest.mark.parametrize("local", [False, True])
+# @pytest.mark.parametrize("local", [True])
 @pytest.mark.parametrize("causal", [False, True])
 # @pytest.mark.parametrize('causal', [True])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
-# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [64])
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
...
@@ -925,7 +1041,7 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
 # @pytest.mark.parametrize('dropout_p', [0.0])
 def test_flash_attn_varlen_output(
-    seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype, kvpacked
+    seqlen_q, seqlen_k, d, dropout_p, causal, local, mha_type, dtype, kvpacked
 ):
     if (
         max(seqlen_q, seqlen_k) >= 2048
...
@@ -935,10 +1051,11 @@ def test_flash_attn_varlen_output(
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 16
+    batch_size = 13
     nheads = 9
     nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
     assert nheads % nheads_k == 0
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
     if kvpacked:
         kv = torch.randn(
...
@@ -980,6 +1097,7 @@ def test_flash_attn_varlen_output(
             dropout_p,
             return_attn_probs=True,
             causal=causal,
+            window_size=window_size,
         )
     else:
         (
...
@@ -1008,6 +1126,7 @@ def test_flash_attn_varlen_output(
             dropout_p,
             return_attn_probs=True,
             causal=causal,
+            window_size=window_size,
         )
     out = output_pad_fn(out_unpad)
     if dropout_p > 0.0:
...
@@ -1020,6 +1139,7 @@ def test_flash_attn_varlen_output(
             d,
             dropout_p > 0.0,
             causal=causal,
+            window_size=window_size,
         )
         dropout_mask = S_dmask_converted >= 0
         attn_unnorm = S_dmask_converted.abs()
...
@@ -1038,9 +1158,14 @@ def test_flash_attn_varlen_output(
             key_padding_mask,
             dropout_p > 0.0,
             causal=causal,
+            window_size=window_size,
         )
         dropout_fraction = get_dropout_fraction(
-            dropout_mask, query_padding_mask, key_padding_mask, causal=causal
+            dropout_mask,
+            query_padding_mask,
+            key_padding_mask,
+            causal=causal,
+            window_size=window_size,
         ).item()
         print(f"Actual dropout fraction: {dropout_fraction}")
     else:
...
@@ -1048,7 +1173,14 @@ def test_flash_attn_varlen_output(
     if kvpacked:
         out_ref, attn_ref = attention_kvpacked_ref(
-            q, kv, query_padding_mask, key_padding_mask, dropout_p, dropout_mask, causal=causal
+            q,
+            kv,
+            query_padding_mask,
+            key_padding_mask,
+            dropout_p,
+            dropout_mask,
+            causal=causal,
+            window_size=window_size,
         )
         out_pt, attn_pt = attention_kvpacked_ref(
             q,
...
@@ -1058,12 +1190,21 @@ def test_flash_attn_varlen_output(
             dropout_p,
             dropout_mask,
             causal=causal,
+            window_size=window_size,
             upcast=False,
             reorder_ops=True,
         )
     else:
         out_ref, attn_ref = attention_ref(
-            q, k, v, query_padding_mask, key_padding_mask, dropout_p, dropout_mask, causal=causal
+            q,
+            k,
+            v,
+            query_padding_mask,
+            key_padding_mask,
+            dropout_p,
+            dropout_mask,
+            causal=causal,
+            window_size=window_size,
         )
         out_pt, attn_pt = attention_ref(
             q,
...
@@ -1074,6 +1215,7 @@ def test_flash_attn_varlen_output(
             dropout_p,
             dropout_mask,
             causal=causal,
+            window_size=window_size,
             upcast=False,
             reorder_ops=True,
         )
...
@@ -1142,7 +1284,7 @@ def test_flash_attn_varlen_output(
     if dropout_p > 0.0:
         assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
-        assert abs(dropout_fraction - dropout_p) <= 0.01
+        assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

     if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
...
@@ -1152,8 +1294,10 @@ def test_flash_attn_varlen_output(
 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("local", [False, True])
+# @pytest.mark.parametrize("local", [True])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
-# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
...
@@ -1176,7 +1320,7 @@ def test_flash_attn_varlen_output(
     ],
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
-def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
+def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
     if (
         max(seqlen_q, seqlen_k) >= 2048
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
...
@@ -1188,13 +1332,16 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
     causal = True
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 16
+    batch_size = 13
     nheads = 9
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
     k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
     v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
-    out = flash_attn_func(q, k, v, 0.0, causal=causal)
-    out_ref, attn_ref = attention_ref(q, k, v, None, None, 0.0, None, causal=causal)
+    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size)
+    out_ref, attn_ref = attention_ref(
+        q, k, v, None, None, 0.0, None, causal=causal, window_size=window_size
+    )
     out_pt, attn_pt = attention_ref(
         q,
         k,
...
@@ -1204,6 +1351,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
         0.0,
         None,
         causal=causal,
+        window_size=window_size,
         upcast=False,
         reorder_ops=True,
     )
...
@@ -1256,12 +1404,14 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("local", [False, True])
+# @pytest.mark.parametrize("local", [True])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
-# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
-# @pytest.mark.parametrize("d", [128])
+# @pytest.mark.parametrize("d", [64])
 @pytest.mark.parametrize("swap_sq_sk", [False, True])
 # @pytest.mark.parametrize("swap_sq_sk", [True])
 @pytest.mark.parametrize(
...
@@ -1280,7 +1430,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
     ],
 )
 # @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
-def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
+def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
     if (
         max(seqlen_q, seqlen_k) >= 2048
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
...
@@ -1292,8 +1442,9 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
     causal = True
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 16
+    batch_size = 13
     nheads = 9
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
     k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
     v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
...
@@ -1324,10 +1475,19 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
         max_seqlen_k,
         0.0,
         causal=causal,
+        window_size=window_size,
     )
     out = output_pad_fn(out_unpad)
     out_ref, attn_ref = attention_ref(
-        q, k, v, query_padding_mask, key_padding_mask, 0.0, None, causal=causal
+        q,
+        k,
+        v,
+        query_padding_mask,
+        key_padding_mask,
+        0.0,
+        None,
+        causal=causal,
+        window_size=window_size,
     )
     out_pt, attn_pt = attention_ref(
         q,
...
@@ -1338,6 +1498,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
         0.0,
         None,
         causal=causal,
+        window_size=window_size,
         upcast=False,
         reorder_ops=True,
     )
...
@@ -1393,6 +1554,8 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("local", [False, True])
+# @pytest.mark.parametrize("local", [True])
 @pytest.mark.parametrize("causal", [False, True])
 # @pytest.mark.parametrize("causal", [True])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
...
@@ -1418,7 +1581,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
     ],
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
-def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
+def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
     if swap_sq_sk:
         seqlen_q, seqlen_k = seqlen_k, seqlen_q
     device = "cuda"
...
@@ -1426,11 +1589,16 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
     torch.random.manual_seed(0)
     batch_size = 1
     nheads = 12
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
     k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
     v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
-    out, lse, _ = flash_attn_func(q, k, v, 0.0, causal=causal, return_attn_probs=True)
-    out_ref, attn_ref = attention_ref(q, k, v, None, None, 0.0, None, causal=causal)
+    out, lse, _ = flash_attn_func(
+        q, k, v, 0.0, causal=causal, window_size=window_size, return_attn_probs=True
+    )
+    out_ref, attn_ref = attention_ref(
+        q, k, v, None, None, 0.0, None, causal=causal, window_size=window_size
+    )
     out_pt, attn_pt = attention_ref(
         q,
         k,
...
@@ -1440,6 +1608,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
         0.0,
         None,
         causal=causal,
+        window_size=window_size,
         upcast=False,
         reorder_ops=True,
     )
...
@@ -1498,6 +1667,8 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
 # @pytest.mark.parametrize("mha_type", ["mha"])
 @pytest.mark.parametrize("new_kv", [False, True])
 # @pytest.mark.parametrize("new_kv", [True])
+@pytest.mark.parametrize("local", [False, True])
+# @pytest.mark.parametrize("local", [True])
 @pytest.mark.parametrize("causal", [False, True])
 # @pytest.mark.parametrize("causal", [True])
 @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
...
@@ -1506,7 +1677,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
 # @pytest.mark.parametrize("rotary_interleaved", [False])
 @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
-@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
+@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
...
@@ -1536,6 +1707,7 @@ def test_flash_attn_kvcache(
     rotary_interleaved,
     seqlen_new_eq_seqlen_q,
     causal,
+    local,
     new_kv,
     mha_type,
     num_splits,
...
@@ -1554,6 +1726,7 @@ def test_flash_attn_kvcache(
     rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
     nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
     assert nheads % nheads_k == 0
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
     seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
     if new_kv:
...
@@ -1566,7 +1739,7 @@ def test_flash_attn_kvcache(
     cache_seqlens = torch.randint(
         0,
         # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
-        (seqlen_k - (seqlen_q if causal and rotary_dim > 1 else seqlen_new) + 1)
+        (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
         if new_kv
         else (seqlen_k + 1),
         (batch_size,),
...
@@ -1578,7 +1751,7 @@ def test_flash_attn_kvcache(
         angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi
         cos = torch.cos(angle).to(dtype=dtype)
         sin = torch.sin(angle).to(dtype=dtype)
-        if causal:
+        if causal or local:
             q_ro = apply_rotary_emb(
                 q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
             )
...
@@ -1624,11 +1797,14 @@ def test_flash_attn_kvcache(
         sin,
         cache_seqlens,
         causal=causal,
+        window_size=window_size,
         rotary_interleaved=rotary_interleaved,
         num_splits=num_splits,
     )
-    # out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal)
-    # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal)
+    # out = flash_attn_with_kvcache(
+    #     q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
+    # )
+    # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
     # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
     # m = qk.amax(-1, keepdim=True)
     # s_tmp = torch.exp((qk - m) / math.sqrt(d))
...
@@ -1637,7 +1813,15 @@ def test_flash_attn_kvcache(
     # probs = torch.softmax(qk, dim=-1)
     key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
     out_ref, _ = attention_ref(
-        q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal
+        q_ro,
+        k_cache_rep,
+        v_cache_rep,
+        None,
+        key_padding_mask,
+        0.0,
+        None,
+        causal=causal,
+        window_size=window_size,
     )
     out_pt, _ = attention_ref(
         q_ro,
...
@@ -1648,6 +1832,7 @@ def test_flash_attn_kvcache(
         0.0,
         None,
         causal=causal,
+        window_size=window_size,
         upcast=False,
         reorder_ops=True,
     )
...
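For completeness, a decoding-style sketch of the kv-cache path exercised above, with a sliding window limiting how far back the query looks (the 256-token window, shapes, and names such as k_new/v_new are illustrative choices of this review, not values from the commit; it needs a CUDA GPU):

import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, d = 2, 8, 64
k_cache = torch.zeros(batch, 4096, nheads, d, device="cuda", dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((batch,), 300, dtype=torch.int32, device="cuda")
q = torch.randn(batch, 1, nheads, d, device="cuda", dtype=torch.float16)
k_new = torch.randn(batch, 1, nheads, d, device="cuda", dtype=torch.float16)
v_new = torch.randn_like(k_new)
out = flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k_new,
    v_new,
    cache_seqlens=cache_seqlens,
    causal=True,
    window_size=(256, 0),  # attend to at most the 256 most recent cached tokens
)
print(out.shape)  # torch.Size([2, 1, 8, 64])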