gaoqiong / flash-attention · Commit 8f873cc6 (Unverified)
Authored Jul 08, 2024 by Nicolas Patry; committed via GitHub on Jul 08, 2024
Parent: 4e8d6006

Implement softcapping. (#1025)

* Softcap v2 (fwd only).
* Some missing interface + remove overrides in tests.
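In brief, softcapping squashes the pre-softmax attention scores through a scaled tanh so that no logit can exceed +/- softcap. A minimal PyTorch sketch of the forward-only transform this commit implements (illustrative only, not the library's API; as in the diff below, softcap <= 0.0 means disabled):

    import torch

    def softcap_scores(scores: torch.Tensor, softcap: float) -> torch.Tensor:
        # Smoothly bound pre-softmax scores to the open interval (-softcap, softcap).
        if softcap <= 0.0:
            return scores  # convention in this commit: non-positive softcap disables capping
        return softcap * torch.tanh(scores / softcap)

    # Raw scores of +/-100 end up just inside +/-30.
    print(softcap_scores(torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0]), softcap=30.0))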
Showing 8 changed files with 185 additions and 48 deletions (+185, -48).
csrc/flash_attn/flash_api.cpp                     +30   -4
csrc/flash_attn/src/flash.h                        +1   -0
csrc/flash_attn/src/flash_fwd_kernel.h            +35   -6
csrc/flash_attn/src/flash_fwd_launch_template.h   +39  -35
csrc/flash_attn/src/static_switch.h               +10   -0
flash_attn/flash_attn_interface.py                +42   -1
setup.py                                           +1   -0
tests/test_flash_attn.py                          +27   -2
csrc/flash_attn/flash_api.cpp

@@ -43,6 +43,7 @@ void set_params_fprop(Flash_fwd_params &params,
                      float softmax_scale,
                      int window_size_left,
                      int window_size_right,
+                     const float softcap,
                      bool seqlenq_ngroups_swapped=false,
                      const bool unpadded_lse=false) {

@@ -100,8 +101,19 @@ void set_params_fprop(Flash_fwd_params &params,
    params.d_rounded = d_rounded;

    // Set the different scale values.
+   #ifdef FLASHATTENTION_DISABLE_SOFTCAP
+       TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap.");
+   #endif
+   if (softcap > 0.0) {
+       params.softcap = softmax_scale / softcap;
+       params.scale_softmax = softcap;
+       params.scale_softmax_log2 = softcap * M_LOG2E;
+   } else {
+       // Remove potential NaN
+       params.softcap = 0.0;
+       params.scale_softmax = softmax_scale;
+       params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+   }

    // Set this to probability of keeping an element to simplify things.
    params.p_dropout = 1.f - p_dropout;
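A note on the scale folding above: the kernel does not scale scores by softmax_scale and then cap them; instead params.softcap becomes the pre-tanh multiplier softmax_scale / softcap, and the remaining factor of softcap flows through the usual scale_softmax / scale_softmax_log2 path. A small PyTorch check of that algebra, under my reading of this hunk:

    import torch

    softmax_scale, softcap = 0.125, 30.0
    s = torch.randn(8, 8) * 400.0                    # raw q @ k^T scores

    # Naive order: scale, then cap.
    naive = softcap * torch.tanh(s * softmax_scale / softcap)

    # Folded order per this hunk: tanh(s * params.softcap), then the regular
    # softmax scaling applies params.scale_softmax (= softcap).
    pre_tanh = softmax_scale / softcap               # params.softcap
    post_tanh = softcap                              # params.scale_softmax
    folded = post_tanh * torch.tanh(s * pre_tanh)

    assert torch.allclose(naive, folded)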
@@ -172,6 +184,7 @@ void set_params_dgrad(Flash_bwd_params &params,
                      float softmax_scale,
                      int window_size_left,
                      int window_size_right,
+                     const float softcap,
                      bool deterministic,
                      const bool unpadded_lse) {

@@ -187,6 +200,7 @@ void set_params_dgrad(Flash_bwd_params &params,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
+                    softcap,
                     false,  // seqlenq_ngroups_swapped
                     unpadded_lse);

@@ -332,6 +346,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
        bool is_causal,
        int window_size_left,
        int window_size_right,
+       const float softcap,
        const bool return_softmax,
        c10::optional<at::Generator> gen_) {

@@ -453,7 +468,9 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
                     p_dropout,
                     softmax_scale,
                     window_size_left,
-                    window_size_right);
+                    window_size_right,
+                    softcap);
    set_params_splitkv(params, batch_size, num_heads,

@@ -521,6 +538,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
               bool is_causal,
               int window_size_left,
               int window_size_right,
+              const float softcap,
               const bool return_softmax,
               c10::optional<at::Generator> gen_) {

@@ -688,6 +706,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                     softmax_scale,
                     window_size_left,
                     window_size_right,
+                    softcap,
                     seqlenq_ngroups_swapped,
                     /*unpadded_lse*/true);
    params.total_q = total_q;

@@ -776,6 +795,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
        const bool is_causal,
        int window_size_left,
        int window_size_right,
+       const float softcap,
        const bool deterministic,
        c10::optional<at::Generator> gen_,
        c10::optional<at::Tensor> &rng_state) {

@@ -940,6 +960,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
                     softmax_scale,
                     window_size_left,
                     window_size_right,
+                    softcap,
                     deterministic,
                     /*unpadded_lse*/false);
    params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);

@@ -1009,6 +1030,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
               const bool is_causal,
               int window_size_left,
               int window_size_right,
+              const float softcap,
               const bool deterministic,
               c10::optional<at::Generator> gen_,
               c10::optional<at::Tensor> &rng_state) {

@@ -1191,6 +1213,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
                     softmax_scale,
                     window_size_left,
                     window_size_right,
+                    softcap,
                     deterministic,
                     /*unpadded_lse*/true);
    params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);

@@ -1257,6 +1280,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                bool is_causal,
                int window_size_left,
                int window_size_right,
+               const float softcap,
                bool is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
                int num_splits) {

@@ -1392,7 +1416,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                     /*p_dropout=*/0.f,
                     softmax_scale,
                     window_size_left,
-                    window_size_right);
+                    window_size_right,
+                    softcap);
    at::Tensor k, v, k_padded, v_padded;
    if (k_.has_value()) {
csrc/flash_attn/src/flash.h

@@ -118,6 +118,7 @@ struct Flash_fwd_params : public Qkv_params {
    // Local window size
    int window_size_left, window_size_right;

+   float softcap;

    // Random state.
    at::PhiloxCudaState philox_args;
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -22,6 +22,22 @@ namespace flash {

using namespace cute;

+template <typename Engine, typename Layout>
+__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
+    static_assert(Layout::rank == 3, "Only support 3D Tensor");
+    static_assert(decltype(size<0>(tensor))::value == 4, "First dimension must be 4");
+    #pragma unroll
+    for (int i = 0; i < size<0>(tensor); ++i) {  // MMA
+        #pragma unroll
+        for (int mi = 0; mi < size<1>(tensor); ++mi) {
+            #pragma unroll
+            for (int nj = 0; nj < size<2>(tensor); ++nj) {
+                tensor(i, mi, nj) = cutlass::fast_tanh(tensor(i, mi, nj) * softcap);
+            }
+        }
+    }
+}
+
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
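Note that apply_softcap only computes tanh(score * params.softcap) on the rank-3 MMA accumulator fragment (whose leading dimension is asserted to be 4); the trailing multiplication by softcap itself is deferred to scale_softmax as set up in flash_api.cpp. A rough PyTorch stand-in for the elementwise step, with the fragment shape assumed for illustration:

    import torch

    def apply_softcap_ref(acc_s: torch.Tensor, pre_tanh_scale: float) -> torch.Tensor:
        # Mirrors the device loop: rank-3 fragment, leading dim 4, each entry
        # replaced by tanh(entry * pre_tanh_scale). No post-multiply here.
        assert acc_s.dim() == 3 and acc_s.size(0) == 4
        return torch.tanh(acc_s * pre_tanh_scale)

    frag = torch.randn(4, 2, 8)                      # illustrative fragment shape
    capped = apply_softcap_ref(frag, pre_tanh_scale=0.125 / 30.0)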
@@ -45,7 +61,7 @@ __forceinline__ __device__ auto get_lse_tile(const Params &params, const int bid
}

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {

    using Element = typename Kernel_traits::Element;

@@ -318,6 +334,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
            smem_thr_copy_Q, smem_thr_copy_K
        );
        // if (cute::thread0()) { print(acc_s); }
+       if constexpr (Is_softcap){
+           apply_softcap(acc_s, params.softcap);
+       }

        mask.template apply_mask<Is_causal, Is_even_MN>(
            acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16

@@ -381,6 +400,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
            smem_thr_copy_Q, smem_thr_copy_K
        );
+       if constexpr (Is_softcap){
+           apply_softcap(acc_s, params.softcap);
+       }

        flash::cp_async_wait<0>();
        __syncthreads();

@@ -486,7 +508,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, typename Params>
inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {

    using Element = typename Kernel_traits::Element;

@@ -870,6 +892,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
            smem_thr_copy_Q, smem_thr_copy_K
        );
        // if (cute::thread0()) { print(acc_s); }
+       if constexpr (Is_softcap){
+           apply_softcap(acc_s, params.softcap);
+       }

        mask.template apply_mask<Is_causal, Is_even_MN>(
            acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16

@@ -941,6 +967,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
            smem_thr_copy_Q, smem_thr_copy_K
        );
+       if constexpr (Is_softcap){
+           apply_softcap(acc_s, params.softcap);
+       }

        flash::cp_async_wait<0>();
        __syncthreads();

@@ -1054,7 +1083,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
inline __device__ void compute_attn(const Params &params) {
    const int m_block = blockIdx.x;
    // The block index for the batch.

@@ -1070,12 +1099,12 @@ inline __device__ void compute_attn(const Params &params) {
    // the attention matrix. This way, as long as we have the batch, head, and the location of
    // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.

-    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
+    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params, bidb, bidh, m_block);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, typename Params>
inline __device__ void compute_attn_splitkv(const Params &params) {
    const int m_block = blockIdx.x;
    // The block index for the batch.

@@ -1084,7 +1113,7 @@ inline __device__ void compute_attn_splitkv(const Params &params) {
    const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
    const int n_split_idx = Split ? blockIdx.y : 0;
    const int num_n_splits = Split ? gridDim.y : 1;
-    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
+    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
}

////////////////////////////////////////////////////////////////////////////////////////////////////
csrc/flash_attn/src/flash_fwd_launch_template.h

@@ -26,18 +26,18 @@
    template<typename Kernel_traits, __VA_ARGS__> \
    __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)

-DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax) {
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
    #if defined(ARCH_SUPPORTS_FLASH)
        static_assert(!(Is_causal && Is_local)); // Enforce constraints
-       flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
+       flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
    #else
        FLASH_UNSUPPORTED_ARCH
    #endif
}

-DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV) {
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) {
    #if defined(ARCH_SUPPORTS_FLASH)
-       flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
+       flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params);
    #else
        FLASH_UNSUPPORTED_ARCH
    #endif

@@ -67,12 +67,13 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
        BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
            ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+               SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
                    // Will only return softmax if dropout, to reduce compilation time.
                    // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                    // If return_softmax, set IsEvenMNConst to false to reduce number of templates
                    // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                    // If Is_local, set Is_causal to false
-                   auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
+                   auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout>;
                    // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
                    // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
                    // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;

@@ -91,6 +92,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
                });
            });
        });
    });
}

template<typename Kernel_traits>

@@ -109,10 +111,11 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    BOOL_SWITCH(params.num_splits > 1, Split, [&] {
        BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
            ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+               SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
                    // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
                    // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                    // If Is_local, set Is_causal to false
-                   auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
+                   auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV>;
                    // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
                    // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
                    if (smem_size >= 48 * 1024) {

@@ -128,6 +131,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
                });
            });
        });
    });
    if (params.num_splits > 1) {
        // We want kBlockM to be as small as possible for more parallelism.
        // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
csrc/flash_attn/src/static_switch.h

@@ -56,6 +56,16 @@
    #define EVENK_SWITCH BOOL_SWITCH
#endif

+#ifdef FLASHATTENTION_DISABLE_SOFTCAP
+  #define SOFTCAP_SWITCH(COND, CONST_NAME, ...)   \
+  [&] {                                           \
+    constexpr static bool CONST_NAME = false;     \
+    return __VA_ARGS__();                         \
+  }()
+#else
+  #define SOFTCAP_SWITCH BOOL_SWITCH
+#endif
+
#ifdef FLASHATTENTION_DISABLE_LOCAL
  #define LOCAL_SWITCH(COND, CONST_NAME, ...)      \
  [&] {                                            \
flash_attn/flash_attn_interface.py

@@ -44,7 +44,7 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):

def _flash_attn_forward(
-    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax
):
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]

@@ -59,6 +59,7 @@ def _flash_attn_forward(
        causal,
        window_size[0],
        window_size[1],
+       softcap,
        return_softmax,
        None,
    )

@@ -123,6 +124,7 @@ def _flash_attn_backward(
    softmax_scale,
    causal,
    window_size,
+   softcap,
    alibi_slopes,
    deterministic,
    rng_state=None,

@@ -151,6 +153,7 @@ def _flash_attn_backward(
        causal,
        window_size[0],
        window_size[1],
+       softcap,
        deterministic,
        None,
        rng_state,

@@ -176,6 +179,7 @@ def _flash_attn_varlen_backward(
    softmax_scale,
    causal,
    window_size,
+   softcap,
    alibi_slopes,
    deterministic,
    rng_state=None,

@@ -209,6 +213,7 @@ def _flash_attn_varlen_backward(
        causal,
        window_size[0],
        window_size[1],
+       softcap,
        deterministic,
        None,
        rng_state,

@@ -227,6 +232,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
        softmax_scale,
        causal,
        window_size,
+       softcap,
        alibi_slopes,
        deterministic,
        return_softmax,

@@ -241,6 +247,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
            softmax_scale,
            causal=causal,
            window_size=window_size,
+           softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
        )

@@ -249,6 +256,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
+       ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

@@ -272,6 +280,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
+           ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,

@@ -433,6 +442,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
        softmax_scale,
        causal,
        window_size,
+       softcap,
        alibi_slopes,
        deterministic,
        return_softmax,

@@ -451,6 +461,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
            softmax_scale,
            causal=causal,
            window_size=window_size,
+           softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,

@@ -464,6 +475,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
+       ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

@@ -492,6 +504,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
+           ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,

@@ -512,6 +525,7 @@ class FlashAttnFunc(torch.autograd.Function):
        softmax_scale,
        causal,
        window_size,
+       softcap,
        alibi_slopes,
        deterministic,
        return_softmax,

@@ -526,6 +540,7 @@ class FlashAttnFunc(torch.autograd.Function):
            softmax_scale,
            causal=causal,
            window_size=window_size,
+           softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
        )

@@ -534,6 +549,7 @@ class FlashAttnFunc(torch.autograd.Function):
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
+       ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -556,6 +572,7 @@ class FlashAttnFunc(torch.autograd.Function):
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
+           ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,

@@ -581,6 +598,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
        softmax_scale,
        causal,
        window_size,
+       softcap,
        alibi_slopes,
        deterministic,
        return_softmax,

@@ -600,6 +618,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
            softmax_scale,
            causal=causal,
            window_size=window_size,
+           softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=block_table,

@@ -613,6 +632,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
+       ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

@@ -639,6 +659,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
+           ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
@@ -655,6 +676,7 @@ def flash_attn_qkvpacked_func(
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
+   softcap=0.0,  # <=0.0 means deactivate
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,

@@ -676,6 +698,7 @@ def flash_attn_qkvpacked_func(
        Default to 1 / sqrt(headdim).
    causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
    window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+   softcap: float. Anything > 0 activates softcapping attention.
    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
        the attention score of query i and key j.
    deterministic: bool. Whether to use the deterministic implementation of the backward pass,

@@ -698,6 +721,7 @@ def flash_attn_qkvpacked_func(
        softmax_scale,
        causal,
        window_size,
+       softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
@@ -711,6 +735,7 @@ def flash_attn_kvpacked_func(
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
+   softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,

@@ -748,6 +773,7 @@ def flash_attn_kvpacked_func(
        Default to 1 / sqrt(headdim).
    causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
    window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+   softcap: float. Anything > 0 activates softcapping attention.
    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
        (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
        is added to the attention score of query i and key j.

@@ -772,6 +798,7 @@ def flash_attn_kvpacked_func(
        softmax_scale,
        causal,
        window_size,
+       softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,

@@ -786,6 +813,7 @@ def flash_attn_func(
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
+   softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
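For reference, a short usage sketch of the new softcap argument through the public API touched in this diff (shapes and the cap value are illustrative; the commit message notes softcapping is forward-only here, so no backward pass is shown):

    import torch
    from flash_attn import flash_attn_func

    batch, seqlen, nheads, headdim = 2, 128, 4, 64
    q, k, v = (
        torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
        for _ in range(3)
    )

    # softcap > 0 bounds pre-softmax scores to roughly (-30, 30) via tanh.
    out = flash_attn_func(q, k, v, causal=True, softcap=30.0)

    # softcap=0.0 (the default) keeps the previous, uncapped behaviour.
    out_uncapped = flash_attn_func(q, k, v, causal=True)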
@@ -846,6 +874,7 @@ def flash_attn_func(
        softmax_scale,
        causal,
        window_size,
+       softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,

@@ -860,6 +889,7 @@ def flash_attn_varlen_qkvpacked_func(
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
+   softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,

@@ -884,6 +914,7 @@ def flash_attn_varlen_qkvpacked_func(
        Default to 1 / sqrt(headdim).
    causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
    window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+   softcap: float. Anything > 0 activates softcapping attention.
    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
        is added to the attention score of query i and key j.
    deterministic: bool. Whether to use the deterministic implementation of the backward pass,

@@ -908,6 +939,7 @@ def flash_attn_varlen_qkvpacked_func(
        softmax_scale,
        causal,
        window_size,
+       softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,

@@ -925,6 +957,7 @@ def flash_attn_varlen_kvpacked_func(
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
+   softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,

@@ -968,6 +1001,7 @@ def flash_attn_varlen_kvpacked_func(
        Default to 1 / sqrt(headdim).
    causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
    window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+   softcap: float. Anything > 0 activates softcapping attention.
    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
        (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
        is added to the attention score of query i and key j.

@@ -996,6 +1030,7 @@ def flash_attn_varlen_kvpacked_func(
        softmax_scale,
        causal,
        window_size,
+       softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,

@@ -1014,6 +1049,7 @@ def flash_attn_varlen_func(
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
+   softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,

@@ -1056,6 +1092,7 @@ def flash_attn_varlen_func(
        Default to 1 / sqrt(headdim).
    causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
    window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+   softcap: float. Anything > 0 activates softcapping attention.
    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
        (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
        is added to the attention score of query i and key j.

@@ -1085,6 +1122,7 @@ def flash_attn_varlen_func(
        softmax_scale,
        causal,
        window_size,
+       softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,

@@ -1106,6 +1144,7 @@ def flash_attn_with_kvcache(
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
+   softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    alibi_slopes=None,
    num_splits=0,

@@ -1177,6 +1216,7 @@ def flash_attn_with_kvcache(
        Default to 1 / sqrt(headdim).
    causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
    window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+   softcap: float. Anything > 0 activates softcapping attention.
    rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
        If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
        rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1

@@ -1226,6 +1266,7 @@ def flash_attn_with_kvcache(
        causal,
        window_size[0],
        window_size[1],
+       softcap,
        rotary_interleaved,
        num_splits,
    )
setup.py

@@ -203,6 +203,7 @@ if not SKIP_CUDA_BUILD:
            # "-DFLASHATTENTION_DISABLE_BACKWARD",
            # "-DFLASHATTENTION_DISABLE_DROPOUT",
            # "-DFLASHATTENTION_DISABLE_ALIBI",
+           # "-DFLASHATTENTION_DISABLE_SOFTCAP",
            # "-DFLASHATTENTION_DISABLE_UNEVEN_K",
            # "-DFLASHATTENTION_DISABLE_LOCAL",
        ]
tests/test_flash_attn.py

@@ -216,6 +216,7 @@ def attention_ref(
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
+   softcap=0.0,
    upcast=True,
    reorder_ops=False,
):

@@ -253,6 +254,10 @@ def attention_ref(
        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
+   if softcap > 0:
+       scores /= softcap
+       scores = scores.tanh()
+       scores *= softcap
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if window_size[0] >= 0 or window_size[1] >= 0:

@@ -877,8 +882,9 @@ def test_flash_attn_varlen_qkvpacked(
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.17])
+@pytest.mark.parametrize("softcap", [0.0, 50.0])
def test_flash_attn_output(
-    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
+    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
):
    if (
        max(seqlen_q, seqlen_k) >= 2048

@@ -894,6 +900,9 @@ def test_flash_attn_output(
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+   if softcap > 0:
+       # Ensure the values of qk are at least within softcap range.
+       q = q * softcap
    if kvpacked:
        kv = torch.randn(
            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
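Why the tests multiply q by softcap: with q and k drawn from N(0, 1) and the 1/sqrt(d) scaling, the scores sit within a few units of zero, so a cap of 50 would leave tanh in its linear region and the capped and uncapped paths would agree trivially. Scaling q pushes scores past the cap so the nonlinearity is actually exercised. A quick sketch of that effect (dimensions are illustrative):

    import math
    import torch

    d, softcap = 64, 50.0
    q = torch.randn(1, 128, 4, d)
    k = torch.randn(1, 128, 4, d)

    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    print((scores.abs() > softcap).float().mean())           # ~0: cap never active

    scores_scaled = torch.einsum("bthd,bshd->bhts", q * softcap / math.sqrt(d), k)
    print((scores_scaled.abs() > softcap).float().mean())     # a substantial fraction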
@@ -918,6 +927,7 @@ def test_flash_attn_output(
        dropout_p,
        causal=causal,
        window_size=window_size,
+       softcap=softcap,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,

@@ -930,6 +940,7 @@ def test_flash_attn_output(
        dropout_p,
        causal=causal,
        window_size=window_size,
+       softcap=softcap,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,

@@ -984,6 +995,7 @@ def test_flash_attn_output(
        dropout_mask,
        causal=causal,
        window_size=window_size,
+       softcap=softcap,
    )
    out_pt, attn_pt = attention_kvpacked_ref(
        q,

@@ -995,6 +1007,7 @@ def test_flash_attn_output(
        dropout_mask,
        causal=causal,
        window_size=window_size,
+       softcap=softcap,
        upcast=False,
        reorder_ops=True,
    )

@@ -1010,6 +1023,7 @@ def test_flash_attn_output(
        dropout_mask,
        causal=causal,
        window_size=window_size,
+       softcap=softcap,
    )
    out_pt, attn_pt = attention_ref(
        q,

@@ -1022,6 +1036,7 @@ def test_flash_attn_output(
        dropout_mask,
        causal=causal,
        window_size=window_size,
+       softcap=softcap,
        upcast=False,
        reorder_ops=True,
    )

@@ -1133,9 +1148,10 @@ def test_flash_attn_output(
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
+@pytest.mark.parametrize("softcap", [0.0, 50.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_output(
-    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
+    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
):
    if (
        max(seqlen_q, seqlen_k) >= 2048

@@ -1151,6 +1167,9 @@ def test_flash_attn_varlen_output(
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+   if softcap > 0:
+       # Ensure the values of qk are at least within softcap range.
+       q = q * softcap
    if kvpacked:
        kv = torch.randn(

@@ -1199,6 +1218,7 @@ def test_flash_attn_varlen_output(
        dropout_p,
        causal=causal,
        window_size=window_size,
+       softcap=softcap,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,

@@ -1230,6 +1250,7 @@ def test_flash_attn_varlen_output(
        dropout_p,
        causal=causal,
        window_size=window_size,
+       softcap=softcap,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,

@@ -1289,6 +1310,7 @@ def test_flash_attn_varlen_output(
        dropout_mask,
        causal=causal,
        window_size=window_size,
+       softcap=softcap,
    )
    out_pt, attn_pt = attention_kvpacked_ref(
        q,

@@ -1300,6 +1322,7 @@ def test_flash_attn_varlen_output(
        dropout_mask,
        causal=causal,
        window_size=window_size,
+       softcap=softcap,
        upcast=False,
        reorder_ops=True,
    )

@@ -1315,6 +1338,7 @@ def test_flash_attn_varlen_output(
        dropout_mask,
        causal=causal,
        window_size=window_size,
+       softcap=softcap,
    )
    out_pt, attn_pt = attention_ref(
        q,

@@ -1327,6 +1351,7 @@ def test_flash_attn_varlen_output(
        dropout_mask,
        causal=causal,
        window_size=window_size,
+       softcap=softcap,
        upcast=False,
        reorder_ops=True,
    )