gaoqiong / flash-attention · Commits

Commit 5ab9b366, authored Dec 21, 2023 by Tri Dao
Parent: bc28eacc

    Clean up alibi, implement non-causal alibi

Showing 11 changed files with 309 additions and 1199 deletions (+309 / -1199).
csrc/flash_attn/flash_api.cpp                     +27   -32
csrc/flash_attn/src/alibi.h                       +27   -19
csrc/flash_attn/src/flash.h                        +0    -4
csrc/flash_attn/src/flash_bwd_kernel.h            +10   -30
csrc/flash_attn/src/flash_bwd_launch_template.h    +3    -4
csrc/flash_attn/src/flash_fwd_kernel.h             +8   -40
csrc/flash_attn/src/flash_fwd_launch_template.h    +2    -2
csrc/flash_attn/src/softmax.h                      +5    -7
flash_attn/flash_attn_interface.py                +25    -6
tests/test_alibi.py                                +0 -1006
tests/test_flash_attn.py                         +202   -49
csrc/flash_attn/flash_api.cpp

@@ -253,12 +253,12 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &k,                       // batch_size x seqlen_k x num_heads_k x head_size
         const at::Tensor &v,                       // batch_size x seqlen_k x num_heads_k x head_size
         c10::optional<at::Tensor> &out_,           // batch_size x seqlen_q x num_heads x head_size
+        c10::optional<at::Tensor> &alibi_slopes_,  // num_heads or batch_size x num_heads
         const float p_dropout,
         const float softmax_scale,
         bool is_causal,
         const int window_size_left,
         int window_size_right,
-        c10::optional<at::Tensor> &alibi_slopes_,  // batch_size x num_heads
         const bool return_softmax,
         c10::optional<at::Generator> gen_) {

@@ -297,13 +297,13 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

-    if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
+    if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
     if (is_causal) { window_size_right = 0; }

     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    // TODO: how to make "seqlenq_ngroups_swapped" and ALiBi work together?
-    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !(alibi_slopes_.has_value());
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
     if (seqlenq_ngroups_swapped) {
         const int ngroups = num_heads / num_heads_k;
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);

@@ -416,12 +416,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
         CHECK_DEVICE(alibi_slopes);
         TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
-        params.has_alibi = true;
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
-        params.has_alibi = false;
         params.alibi_slopes_ptr = nullptr;
     }

     if (seqlen_k > 0) {

@@ -456,6 +455,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                const at::Tensor &cu_seqlens_q,  // b+1
                const at::Tensor &cu_seqlens_k,  // b+1
                c10::optional<at::Tensor> &seqused_k,  // b. If given, only this many elements of each batch element's keys are used.
+               c10::optional<at::Tensor> &alibi_slopes_,  // num_heads or b x num_heads
                const int max_seqlen_q,
                const int max_seqlen_k,
                const float p_dropout,

@@ -464,7 +464,6 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                const bool is_causal,
                const int window_size_left,
                int window_size_right,
-               c10::optional<at::Tensor> &alibi_slopes_,  // b x num_heads
                const bool return_softmax,
                c10::optional<at::Generator> gen_) {

@@ -612,12 +611,11 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
         TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
         CHECK_DEVICE(alibi_slopes);
         TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
-        params.has_alibi = true;
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
-        params.has_alibi = false;
         params.alibi_slopes_ptr = nullptr;
     }

     auto stream = at::cuda::getCurrentCUDAStream().stream();

@@ -664,12 +662,12 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
         c10::optional<at::Tensor> &dq_,            // batch_size x seqlen_q x num_heads x head_size
         c10::optional<at::Tensor> &dk_,            // batch_size x seqlen_k x num_heads_k x head_size
         c10::optional<at::Tensor> &dv_,            // batch_size x seqlen_k x num_heads_k x head_size
+        c10::optional<at::Tensor> &alibi_slopes_,  // num_heads or batch_size x num_heads
         const float p_dropout,                     // probability to drop
         const float softmax_scale,
         const bool is_causal,
         const int window_size_left,
         int window_size_right,
-        c10::optional<at::Tensor> &alibi_slopes_,  // batch_size x num_heads
         c10::optional<at::Generator> gen_,
         c10::optional<at::Tensor> &rng_state) {

@@ -848,12 +846,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
         TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
         CHECK_DEVICE(alibi_slopes);
         TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
-        params.has_alibi = true;
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
-        params.has_alibi = false;
         params.alibi_slopes_ptr = nullptr;
     }

     if (seqlen_q > 0) {

@@ -891,6 +888,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
                c10::optional<at::Tensor> &dv_,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor &cu_seqlens_q,  // b+1
                const at::Tensor &cu_seqlens_k,  // b+1
+               c10::optional<at::Tensor> &alibi_slopes_,  // num_heads or b x num_heads
                const int max_seqlen_q,
                const int max_seqlen_k,          // max sequence length to choose the kernel
                const float p_dropout,           // probability to drop

@@ -899,7 +897,6 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
                const bool is_causal,
                const int window_size_left,
                int window_size_right,
-               c10::optional<at::Tensor> &alibi_slopes_,  // b x num_heads
                c10::optional<at::Generator> gen_,
                c10::optional<at::Tensor> &rng_state) {

@@ -1094,12 +1091,11 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
         TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
         CHECK_DEVICE(alibi_slopes);
         TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
-        params.has_alibi = true;
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
-        params.has_alibi = false;
         params.alibi_slopes_ptr = nullptr;
     }

     launch(params, stream, /*configure=*/false);

@@ -1128,14 +1124,14 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                 c10::optional<const at::Tensor> &rotary_cos_,      // seqlen_ro x (rotary_dim / 2)
                 c10::optional<const at::Tensor> &rotary_sin_,      // seqlen_ro x (rotary_dim / 2)
                 c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
+                c10::optional<at::Tensor> &alibi_slopes_,          // num_heads or batch_size x num_heads
                 c10::optional<at::Tensor> &out_,                   // batch_size x seqlen_q x num_heads x head_size
                 const float softmax_scale,
                 bool is_causal,
                 const int window_size_left,
                 int window_size_right,
                 bool is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
-                int num_splits,
-                c10::optional<at::Tensor> &alibi_slopes_           // batch_size x num_heads
+                int num_splits
                 ) {

     auto dprops = at::cuda::getCurrentDeviceProperties();

@@ -1174,13 +1170,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

-    if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
+    if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
     if (is_causal) { window_size_right = 0; }

     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    // TODO: how to make "seqlenq_ngroups_swapped" and ALiBi work together?
-    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !(alibi_slopes_.has_value());
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
     if (seqlenq_ngroups_swapped) {
         const int ngroups = num_heads / num_heads_k;
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);

@@ -1347,12 +1343,11 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
         CHECK_DEVICE(alibi_slopes);
         TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
-        params.has_alibi = true;
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
-        params.has_alibi = false;
         params.alibi_slopes_ptr = nullptr;
     }

     auto stream = at::cuda::getCurrentCUDAStream().stream();
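The validation blocks above now accept alibi_slopes of shape (num_heads,) or (batch_size, num_heads), and set alibi_slopes_batch_stride to 0 in the 1-D case so every batch element reads the same slope vector. A minimal Python sketch of the equivalent shape/stride logic (resolve_alibi_slopes is a hypothetical helper, for illustration only; it is not part of the library):

import torch

def resolve_alibi_slopes(alibi_slopes, batch_size, num_heads):
    # Mirrors the checks above: fp32, contiguous last dim, shape (nheads,) or (batch, nheads).
    assert alibi_slopes.dtype == torch.float32, "ALiBi slopes must have dtype fp32"
    assert alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"
    assert tuple(alibi_slopes.shape) in [(num_heads,), (batch_size, num_heads)]
    # With a 1-D tensor the batch stride is 0, i.e. all batch elements share one slope vector.
    batch_stride = alibi_slopes.stride(0) if alibi_slopes.dim() == 2 else 0
    return alibi_slopes.data_ptr(), batch_stride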
csrc/flash_attn/src/alibi.h

@@ -13,37 +13,45 @@ using namespace cute;

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template <typename Engine, typename Layout>
+template <bool Is_causal, typename Engine, typename Layout>
 inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
                                    const int col_idx_offset_,
                                    const int max_seqlen_k,
-                                   const int row_idx_offset_,
+                                   const int row_idx_offset,
                                    const int max_seqlen_q,
                                    const int warp_row_stride,
-                                   const int head_idx,
-                                   const float softmax_scale,
                                    const float alibi_slope) {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
     const int lane_id = threadIdx.x % 32;
-    const int row_idx_offset = row_idx_offset_;
     const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
-    const float alibi_slope_unscaled = alibi_slope / softmax_scale;
-    #pragma unroll
-    for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+    if constexpr (Is_causal) {  // Simpler, we add the same bias vector to all rows
         #pragma unroll
-        for (int i = 0; i < size<0, 0>(tensor); ++i) {
-            const int row_idx = row_idx_base + i * 8;
+        for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+            const int col_idx_base = col_idx_offset + nj * 8;
             #pragma unroll
-            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-                const int col_idx_base = col_idx_offset + nj * 8;
+            for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                const int col_idx = col_idx_base + j;
                 #pragma unroll
-                for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                    const int col_idx = col_idx_base + j;
-                    const float alibi = alibi_slope_unscaled * col_idx;
-                    if (col_idx < max_seqlen_k && row_idx < max_seqlen_q) {
-                        tensor(make_coord(i, mi), make_coord(j, nj)) += alibi;
+                for (int mi = 0; mi < size<0>(tensor); ++mi) {
+                    tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
+                }
+            }
+        }
+    } else {  // Bias depends on both row_idx and col_idx
+        #pragma unroll
+        for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+            const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+            #pragma unroll
+            for (int i = 0; i < size<0, 0>(tensor); ++i) {
+                const int row_idx = row_idx_base + i * 8;
+                #pragma unroll
+                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                    const int col_idx_base = col_idx_offset + nj * 8;
+                    #pragma unroll
+                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                        const int col_idx = col_idx_base + j;
+                        tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
                     }
                 }
             }
         }
     }

@@ -51,4 +59,4 @@ inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
         }
     }
 }
-}  // namespace flash
\ No newline at end of file
+}  // namespace flash
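For reference, the bias applied by the two branches of apply_alibi can be written down on a dense (seqlen_q, seqlen_k) score matrix. The sketch below is a minimal PyTorch restatement of the math, not the kernel's tiled layout: in the causal case, adding alibi_slope * col_idx to every row differs from the usual -slope * |i + seqlen_k - seqlen_q - j| bias only by a per-row constant, which softmax cancels on the causal region; the non-causal branch applies the absolute-distance bias explicitly.

import torch

def alibi_bias(slope, seqlen_q, seqlen_k, causal):
    # Reference bias for one head, mirroring apply_alibi's two branches on a dense score matrix.
    row = torch.arange(seqlen_q).view(-1, 1)   # query index i
    col = torch.arange(seqlen_k).view(1, -1)   # key index j
    if causal:
        # Same bias vector added to every row; softmax is shift-invariant per row, so on the
        # (bottom-right aligned) causal region this matches -slope * (i + seqlen_k - seqlen_q - j).
        return slope * col.float().expand(seqlen_q, seqlen_k)
    # Non-causal: bias depends on both indices via the absolute relative position.
    return -slope * (row + seqlen_k - seqlen_q - col).abs().float()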
csrc/flash_attn/src/flash.h

@@ -131,10 +131,6 @@ struct Flash_fwd_params : public Qkv_params {
     int num_splits;  // For split-KV version

-    // float alibi_start;
-    // float alibi_ratio;
-    bool has_alibi;
     void * __restrict__ alibi_slopes_ptr;
     index_t alibi_slopes_batch_stride;
 };
csrc/flash_attn/src/flash_bwd_kernel.h

@@ -753,8 +753,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     #pragma unroll
     for (int mi = 0; mi < size(lse); ++mi) {
         const int row = get<0>(taccScS_row(mi));
-        lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0;
+        lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
     }
+    // We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
+    // and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply
+    // with V (which would be zero), we're fine. However, with ALiBi, we might modify these
+    // scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0.

     // Tensor tKrK = make_fragment_like(tKsK);
     // // cute::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, 0), tKrK);
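The key fact behind the new comment can be checked in one line: with LSE = inf, P = exp(S - LSE) is exactly 0 for any finite score, so whatever bias ALiBi added to an out-of-bounds row, that row contributes nothing downstream; with LSE = 0 the probabilities are not forced to zero. A tiny PyTorch sketch of that arithmetic (illustration only, not kernel code):

import torch

scores = torch.tensor([0.0, -3.5, 12.0])      # arbitrary finite (possibly ALiBi-shifted) scores
print(torch.exp(scores - float("inf")))        # tensor([0., 0., 0.])  -> OOB rows are inert
print(torch.exp(scores - 0.0))                 # with LSE = 0 the "probs" are not zeroed out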
@@ -792,18 +796,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     clear(acc_dv);
     clear(acc_dk);

-    float alibi_slope = 0.0f;
-    if (Has_alibi) {
-        Tensor gAS = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr)
-                                               + bidb * params.alibi_slopes_batch_stride + bidh), Shape<_1>{});
-        Tensor rAS = make_fragment_like(gAS);
-        cute::copy(gAS, rAS);
-        alibi_slope = rAS(0);
-    }
+    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;

     for (; m_block >= m_block_min; --m_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_N, MMA_N)

@@ -830,18 +823,17 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // if (cute::thread(32, 0)) { print(scores); }

         if (Has_alibi) {
-            flash::apply_alibi(
+            flash::apply_alibi<Is_causal>(
                 scores,
                 n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + get<0>(taccScS_row(0)),
                 binfo.actual_seqlen_q,
                 AtomLayoutMS * 16,
-                bidh,
-                params.scale_softmax,
                 alibi_slope
             );
         }

         // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond
         // actual_seqlen_k, because acc_s would be some finite value for those indices.
         // In the end when we multiply with K to get dQ, the corresponding values of K would be 0,

@@ -1403,18 +1395,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
     clear(acc_dq);

-    float alibi_slope = 0.0f;
-    if (Has_alibi) {
-        Tensor gAS = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr)
-                                               + bidb * params.alibi_slopes_batch_stride + bidh), Shape<_1>{});
-        Tensor rAS = make_fragment_like(gAS);
-        cute::copy(gAS, rAS);
-        alibi_slope = rAS(0);
-    }
+    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;

     for (; n_block >= 0; --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M_SdP, MMA_N)

@@ -1429,14 +1410,13 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));

         if (Has_alibi) {
-            flash::apply_alibi(
+            flash::apply_alibi<Is_causal>(
                 scores,
                 n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + get<0>(taccScS_row(0)),
                 binfo.actual_seqlen_q,
                 AtomLayoutMS * 16,
-                bidh,
-                params.scale_softmax,
                 alibi_slope
             );
         }
csrc/flash_attn/src/flash_bwd_launch_template.h

@@ -64,12 +64,11 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
     BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
         BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
             BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
-                BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
+                BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                     // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                     // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                     // If Is_local, set Is_causal to false
-                    auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal && !Is_local, Is_local, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
+                    auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
                     // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
                     if (smem_size_dq_dk_dv >= 48 * 1024)  {
                         C10_CUDA_CHECK(cudaFuncSetAttribute(
                             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));

@@ -109,7 +108,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
+                BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                     // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                     auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Has_alibi, IsEvenNConst && IsEvenKConst, IsEvenKConst>;
                     // auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -322,28 +322,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     clear(acc_o);

+    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+
     // For performance reason, we separate out two kinds of iterations:
     // those that need masking on S, and those that don't.
     // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
     // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
     // We will have at least 1 "masking" iteration.

-    float alibi_slope = 0.0f;
-    if (Has_alibi) {
-        Tensor gAS = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr)
-                                               + bidb * params.alibi_slopes_batch_stride + bidh), Shape<_1>{});
-        Tensor rAS = make_fragment_like(gAS);
-        cute::copy(gAS, rAS);
-        alibi_slope = rAS(0);
-        // if (m_block == 0 && tidx == 0) {
-        //     printf("%d,%d,%f\n", bidb, bidh, alibi_slope);
-        // }
-    }
-
     // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
     // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
     constexpr int n_masking_steps = (!Is_causal && !Is_local)

@@ -382,14 +368,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // can produce Inf / NaN.

         if (Has_alibi) {
-            flash::apply_alibi(
+            flash::apply_alibi<Is_causal>(
                 scores,
                 n_block * kBlockN,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
                 binfo.actual_seqlen_q,
                 kNWarps * 16,
-                bidh,
-                params.scale_softmax,
                 alibi_slope
             );
         }

@@ -500,14 +485,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));

         if (Has_alibi) {
-            flash::apply_alibi(
+            flash::apply_alibi<Is_causal>(
                 scores,
                 n_block * kBlockN,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
                 binfo.actual_seqlen_q,
                 kNWarps * 16,
-                bidh,
-                params.scale_softmax,
                 alibi_slope
             );
         }

@@ -950,28 +934,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     clear(acc_o);

+    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+
     // For performance reason, we separate out two kinds of iterations:
     // those that need masking on S, and those that don't.
     // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
     // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
     // We will have at least 1 "masking" iteration.

-    float alibi_slope = 0.0f;
-    if (Has_alibi) {
-        Tensor gAS = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr)
-                                               + bidb * params.alibi_slopes_batch_stride + bidh), Shape<_1>{});
-        Tensor rAS = make_fragment_like(gAS);
-        cute::copy(gAS, rAS);
-        alibi_slope = rAS(0);
-        // if (m_block == 0 && tidx == 0) {
-        //     printf("%d,%d,%f\n", bidb, bidh, alibi_slope);
-        // }
-    }
-
     // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
     // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
     constexpr int n_masking_steps = (!Is_causal && !Is_local)

@@ -1006,14 +976,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));

         if (Has_alibi) {
-            flash::apply_alibi(
+            flash::apply_alibi<Is_causal>(
                 scores,
                 n_block * kBlockN,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
                 binfo.actual_seqlen_q,
                 kNWarps * 16,
-                bidh,
-                params.scale_softmax,
                 alibi_slope
             );
         }

@@ -1099,14 +1068,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));

         if (Has_alibi) {
-            flash::apply_alibi(
+            flash::apply_alibi<Is_causal>(
                 scores,
                 n_block * kBlockN,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
                 binfo.actual_seqlen_q,
                 kNWarps * 16,
-                bidh,
-                params.scale_softmax,
                 alibi_slope
             );
         }
csrc/flash_attn/src/flash_fwd_launch_template.h

@@ -45,7 +45,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
         BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
             BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
-                BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
+                BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                     // Will only return softmax if dropout, to reduce compilation time.
                     // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                     // If return_softmax, set IsEvenMNConst to false to reduce number of templates

@@ -86,7 +86,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
         BOOL_SWITCH(params.num_splits > 1, Split, [&] {
             BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
-                BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
+                BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                     // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
                     // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                     // If Is_local, set Is_causal to false
csrc/flash_attn/src/softmax.h

@@ -141,14 +141,12 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_
 template <bool HasWSLeft=true, typename Engine, typename Layout>
 inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor,
                                         const int col_idx_offset_,
-                                        const int max_seqlen_k, const int row_idx_offset_,
+                                        const int max_seqlen_k, const int row_idx_offset,
                                         const int max_seqlen_q, const int warp_row_stride,
                                         const int window_size_left, const int window_size_right) {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
     const int lane_id = threadIdx.x % 32;
-    // const int row_idx_offset = row_idx_offset_ + lane_id / 4;
-    const int row_idx_offset = row_idx_offset_;
     const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
     #pragma unroll
     for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {

@@ -180,17 +178,17 @@ inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const in
 template <typename Engine, typename Layout>
 inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor,
                                          const int col_idx_offset_,
-                                         const int max_seqlen_k, const int row_idx_offset_,
+                                         const int max_seqlen_k, const int row_idx_offset,
                                          const int max_seqlen_q, const int warp_row_stride) {
     // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
-    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_,
+    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
                                           max_seqlen_q, warp_row_stride, -1, 0);
 }

 template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
 inline __device__ void apply_mask_causal_w_idx(
     Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
-    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_)
+    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
 {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout0::rank == 2, "Only support 2D Tensor");

@@ -199,7 +197,7 @@ inline __device__ void apply_mask_causal_w_idx(
     CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
     #pragma unroll
     for (int mi = 0; mi < size<0>(tensor); ++mi) {
-        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
+        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
         #pragma unroll
         for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
             if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
flash_attn/flash_attn_interface.py

@@ -53,12 +53,12 @@ def _flash_attn_forward(
         k,
         v,
         None,
+        alibi_slopes,
         dropout_p,
         softmax_scale,
         causal,
         window_size[0],
         window_size[1],
-        alibi_slopes,
         return_softmax,
         None,
     )

@@ -90,6 +90,7 @@ def _flash_attn_varlen_forward(
         cu_seqlens_q,
         cu_seqlens_k,
         None,
+        alibi_slopes,
         max_seqlen_q,
         max_seqlen_k,
         dropout_p,

@@ -98,7 +99,6 @@ def _flash_attn_varlen_forward(
         causal,
         window_size[0],
         window_size[1],
-        alibi_slopes,
         return_softmax,
         None,
     )

@@ -137,12 +137,12 @@ def _flash_attn_backward(
         dq,
         dk,
         dv,
+        alibi_slopes,
         dropout_p,
         softmax_scale,
         causal,
         window_size[0],
         window_size[1],
-        alibi_slopes,
         None,
         rng_state,
     )

@@ -185,6 +185,7 @@ def _flash_attn_varlen_backward(
         dv,
         cu_seqlens_q,
         cu_seqlens_k,
+        alibi_slopes,
         max_seqlen_q,
         max_seqlen_k,
         dropout_p,

@@ -193,7 +194,6 @@ def _flash_attn_varlen_backward(
         causal,
         window_size[0],
         window_size[1],
-        alibi_slopes,
         None,
         rng_state,
     )

@@ -613,6 +613,8 @@ def flash_attn_qkvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
+            the attention score of query i and key j.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).

@@ -673,6 +675,9 @@ def flash_attn_kvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).

@@ -732,6 +737,9 @@ def flash_attn_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).

@@ -780,6 +788,8 @@ def flash_attn_varlen_qkvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
+            is added to the attention score of query i and key j.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).

@@ -858,6 +868,9 @@ def flash_attn_varlen_kvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).

@@ -938,6 +951,9 @@ def flash_attn_varlen_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).

@@ -981,8 +997,8 @@ def flash_attn_with_kvcache(
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
     rotary_interleaved=True,
-    num_splits=0,
     alibi_slopes=None,
+    num_splits=0,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from

@@ -1050,6 +1066,9 @@ def flash_attn_with_kvcache(
             If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
             rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
             (i.e. GPT-NeoX style).
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j.
         num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.

@@ -1080,6 +1099,7 @@ def flash_attn_with_kvcache(
         rotary_cos,
         rotary_sin,
         cache_batch_idx,
+        alibi_slopes,
         None,
         softmax_scale,
         causal,

@@ -1087,6 +1107,5 @@ def flash_attn_with_kvcache(
         window_size[1],
         rotary_interleaved,
         num_splits,
-        alibi_slopes,
     )
     return out
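Taken together, the Python-side change is only that alibi_slopes is now passed to the C++ ops right after the tensor arguments instead of after window_size, and documented in every public function. A minimal usage sketch of the public API with ALiBi (shapes follow the docstrings above; the geometric slope values here are just an example, not a library default):

import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 8, 64
q, k, v = [
    torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    for _ in range(3)
]
# One fp32 slope per head; a (batch_size, nheads) tensor is also accepted for per-sequence slopes.
alibi_slopes = torch.tensor(
    [2 ** (-8 * (h + 1) / nheads) for h in range(nheads)],
    device="cuda", dtype=torch.float32,
)
out_causal = flash_attn_func(q, k, v, causal=True, alibi_slopes=alibi_slopes)
out_noncausal = flash_attn_func(q, k, v, causal=False, alibi_slopes=alibi_slopes)  # non-causal ALiBi, enabled by this commit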
tests/test_alibi.py  (deleted, 100644 → 0)

import math

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn import (flash_attn_func, flash_attn_kvpacked_func,
                        flash_attn_qkvpacked_func, flash_attn_varlen_func,
                        flash_attn_varlen_kvpacked_func,
                        flash_attn_varlen_qkvpacked_func,
                        flash_attn_with_kvcache)
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import _get_block_size
from flash_attn.flash_attn_triton import \
    flash_attn_func as flash_attn_func_triton
from flash_attn.layers.rotary import apply_rotary_emb

MAX_HEADDIM_SM8x = 192


is_sm75 = torch.cuda.get_device_capability("cuda") == (7, 5)
is_sm8x = torch.cuda.get_device_capability("cuda")[0] == 8
is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0)
is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0)


def generate_alibi(
    max_seq_len,
    num_attention_heads,
    tp_world_size,
    tp_index,
    key_padding_mask=None,
    device="cuda",
):
    def get_slopes(n):
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return (
                get_slopes_power_of_2(closest_power_of_2)
                + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
            )

    slopes = torch.tensor(get_slopes(num_attention_heads)).to(device=device)
    # Select the part of the tensor that corresponds to our tensor parallel index.
    assert (
        num_attention_heads / tp_world_size
    ).is_integer(), "it works only when (num_attention_heads/tp_world_size) is integer"
    nh_tp = num_attention_heads // tp_world_size
    slopes = slopes[nh_tp * tp_index : nh_tp * (tp_index + 1)]

    if key_padding_mask is None:
        arange_tensor = rearrange(torch.arange(max_seq_len), "sqk -> 1 sqk").to(device=device)
    else:
        arange_tensor = (
            (key_padding_mask.cumsum(dim=-1, dtype=slopes.dtype) - 1)
            .masked_fill_(~key_padding_mask, torch.finfo(torch.float).min)
            .to(device=device)
        )
        arange_tensor = rearrange(arange_tensor, 'b sqk -> b 1 1 sqk')

    # (1, nheads, 1, seqlen_k) or (batch, nheads, 1, seqlen_k)
    alibi_tensor = rearrange(slopes, 'nh -> 1 nh 1 1') * arange_tensor

    return alibi_tensor, slopes
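As a concrete check of get_slopes above: for num_attention_heads = 8, start = 2 ** -(2 ** -(log2(8) - 3)) = 2 ** -1 = 0.5 and ratio = start, so the slopes are the geometric sequence [2^-1, 2^-2, ..., 2^-8]; with tp_world_size = 2 and tp_index = 1 the function keeps the second half of that list. A short worked sketch:

import math

def get_slopes_power_of_2(n):
    # Same formula as in generate_alibi above.
    start = 2 ** (-(2 ** -(math.log2(n) - 3)))
    return [start * start**i for i in range(n)]

slopes = get_slopes_power_of_2(8)
print(slopes)       # [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
print(slopes[4:8])  # the shard kept for tp_world_size=2, tp_index=1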
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", right_padding=True):
    assert mode in ["full", "random", "third"]
    if mode == "full":
        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    elif mode == "random":
        lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device)
    elif mode == "third":
        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
    if right_padding:
        padding_mask = (
            repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
        )
    else:
        padding_mask = (
            repeat(
                torch.arange(start=max_seqlen - 1, end=-1, step=-1, device=device),
                "s -> b s",
                b=batch_size,
            )
            < lengths
        )
    return padding_mask


def generate_qkv(
    q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen), bool
        key_padding_mask: (batch_size, seqlen), bool
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d)

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)

        def output_pad_fn(output_unpad):
            return pad_input(output_unpad, indices_q, batch_size, seqlen_q)

    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
        )
        max_seqlen_q = seqlen_q

        def output_pad_fn(output_unpad):
            return rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size)

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
        )
        max_seqlen_k = seqlen_k

    if qkvpacked:
        assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:

            def dqkv_pad_fn(dqkv_unpad):
                return pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)

        else:

            def dqkv_pad_fn(dqkv_unpad):
                return rearrange(dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)

        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:

            def dkv_pad_fn(dkv_unpad):
                return pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)

        else:

            def dkv_pad_fn(dkv_unpad):
                return rearrange(dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)

        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )
    else:
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:

            def dk_pad_fn(dk_unpad):
                return pad_input(dk_unpad, indices_k, batch_size, seqlen_k)

        else:

            def dk_pad_fn(dk_unpad):
                return rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)

        return (
            q_unpad.detach().requires_grad_(),
            k_unpad.detach().requires_grad_(),
            v_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            k.detach().requires_grad_(),
            v.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        )
def construct_local_mask(
    seqlen_q,
    seqlen_k,
    window_size=(-1, -1),  # -1 means infinite window size
    query_padding_mask=None,
    key_padding_mask=None,
    device=None,
):
    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    sk = (
        seqlen_k
        if key_padding_mask is None
        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    if window_size[0] < 0:
        return col_idx > row_idx + sk - sq + window_size[1]
    else:
        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
        return torch.logical_or(
            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
            col_idx < row_idx + sk - sq - window_size[0],
        )


def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    upcast=True,
    reorder_ops=False,
    bias=None,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
        window_size: (int, int), left and right window size
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
    """
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
    if bias is not None:
        bias = bias.to(scores.dtype)
        scores += bias
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            q.device,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    attention = torch.softmax(scores, dim=-1)
    # Some rows might be completely masked out so we fill them with zero instead of NaN
    if window_size[0] >= 0 or window_size[1] >= 0:
        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
    # We want to mask here so that the attention matrix doesn't have any NaNs
    # Otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    dropout_scaling = 1.0 / (1 - dropout_p)
    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
def attention_kvpacked_ref(
    q,
    kv,
    query_padding_mask=None,
    key_padding_mask=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    upcast=True,
    reorder_ops=False,
):
    return attention_ref(
        q,
        kv[:, :, 0],
        kv[:, :, 1],
        query_padding_mask,
        key_padding_mask,
        dropout_p,
        dropout_mask,
        upcast=upcast,
        causal=causal,
        window_size=window_size,
        reorder_ops=reorder_ops,
    )


def attention_qkvpacked_ref(
    qkv,
    key_padding_mask=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    upcast=True,
    reorder_ops=False,
):
    return attention_ref(
        qkv[:, :, 0],
        qkv[:, :, 1],
        qkv[:, :, 2],
        key_padding_mask,
        key_padding_mask,
        dropout_p,
        dropout_mask,
        upcast=upcast,
        causal=causal,
        window_size=window_size,
        reorder_ops=reorder_ops,
    )


def generate_sparsity_mask(seqlen, sparsity=0.3):
    repeats = seqlen // 16 // 2
    # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda'),
    #                     torch.tensor([0, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
    # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda'),
    #                     torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
    # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
    # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
    nrow, ncol = seqlen // 16, seqlen // 256
    mask = torch.rand(nrow, ncol, device="cuda") < sparsity
    return mask


def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
        blockmask: (seqlen / 16, seqlen / 256)
        attn_mask: (batch_size, seqlen)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen, seqlen)
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
        attention: softmax after dropout
    """
    q, k, v = qkv.float().unbind(dim=2)
    d = qkv.shape[-1]
    seqlen = qkv.shape[1]
    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf"))
    blockmask = repeat(blockmask, "s_16 s_256 -> (s_16 16) (s_256 256)")
    blockmask = blockmask[:seqlen, :seqlen]
    scores.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), float("-inf"))
    attention = torch.softmax(scores, dim=-1)
    attention = attention.masked_fill(rearrange(~attn_mask, "b s -> b 1 s 1"), 0.0)
    attention = attention.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), 0.0)
    attention_drop = attention.masked_fill(~dropout_mask, 0.0) / (1 - dropout_p)
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
    output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1 1"), 0)
    return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype)
def convert_flash_attn_S_to_softmax(
    S,
    seqlen_q,
    seqlen_k,
    query_padding_mask,
    key_padding_mask,
    head_dim,
    is_dropout,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
):
    """FlashAttention stores the S matrix in a different way.
    Arguments:
        S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)
        query_padding_mask: (batch_size, seqlen_q_rounded)
        key_padding_mask: (batch_size, seqlen_k_rounded)
    """
    if causal:
        window_size = (window_size[0], 0)
    seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
    warps_n = 4
    blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, is_dropout, causal)
    nblocks_n = (seqlen_k_rounded + blocksize_n - 1) // blocksize_n
    nblocks_m = (seqlen_q_rounded + blocksize_m - 1) // blocksize_m
    mmas_n = (blocksize_n + 16 - 1) // 16
    S_flat = rearrange(
        S,
        "b h (nblocks_m blocksize_m) (nblocks_n blocksize_n) -> b h nblocks_m nblocks_n (blocksize_m blocksize_n)",
        blocksize_m=blocksize_m,
        blocksize_n=blocksize_n,
    )
    S_converted = rearrange(
        S_flat,
        "b h nblocks_m nblocks_n (mmas_n mmas_m warps_n eight four c2 c1 c0) -> b h (nblocks_m mmas_m warps_n c1 eight) (nblocks_n mmas_n c2 four c0)",
        mmas_n=mmas_n,
        warps_n=warps_n,
        eight=8,
        c0=2,
        c1=2,
        c2=2,
        four=4,
    )

    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            S.device,
        )
        local_mask = F.pad(
            local_mask,
            (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
            value=True,
        )
        S_converted.masked_fill_(local_mask, 0.0)

    # Need to zero out things not in attention_mask in case S was initialized with random values
    # and some of those values aren't overwritten.
    seqlen_q_og = (
        query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
    )
    if query_padding_mask is not None:
        query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))
        S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
    if key_padding_mask is not None:
        key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og))
        S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
    S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded))
    S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
    return S_converted[:, :, :seqlen_q, :seqlen_k]


def normalize_flash_attn_S(
    attn_unnorm,
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    is_dropout=False,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k, v: (batch_size, seqlen_k, nheads, head_dim)
        key_padding_mask: (batch_size, seqlen_q)
    Output:
        softmax_lse: (batch_size, nheads, seqlen_q)
        softmax_max: (batch_size, nheads, seqlen_q)
    """
    if causal:
        window_size = (window_size[0], 0)
    q, k, v = q.float(), k.float(), v.float()
    _, seqlen_q, _, head_dim = q.shape
    seqlen_k = k.shape[1]
    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k)
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            q.device,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    _, block_size_n = _get_block_size(scores.device, head_dim, is_dropout, causal)
    scores_block = scores.split(block_size_n, dim=-1)
    lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
    lse = torch.logsumexp(lse_block, dim=-1)
    # lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf
    # so that when we do torch.exp(m - lse), we get 0.0 instead of NaN.
    lse[lse == float("-inf")] = float("inf")
    scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1)
    cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)
    attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1)
    attn_norm = torch.cat(
        [
            a * rearrange(torch.exp(m - lse), "b h s -> b h s 1")
            for a, m in zip(attn_unnorm_block, cummax_block)
        ],
        dim=-1,
    )
    if query_padding_mask is not None:
        attn_norm.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    return attn_norm.to(dtype=attn_unnorm.dtype)
def get_dropout_fraction(
    dropout_mask,
    query_padding_mask=None,
    key_padding_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
):
    """
    dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop.
    query_padding_mask: (batch_size, seqlen_q)
    key_padding_mask: (batch_size, seqlen_k)
    """
    if causal:
        window_size = (window_size[0], 0)
    batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape
    dropped = ~dropout_mask
    valid = torch.ones_like(dropout_mask)
    if query_padding_mask is not None:
        dropped.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
        valid.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
    if key_padding_mask is not None:
        dropped.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
        valid.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            dropout_mask.device,
        )
        dropped.masked_fill_(local_mask, False)
        valid.masked_fill_(local_mask, False)
    dropped_total = dropped.sum()
    return dropped.sum() / valid.sum()
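
As a sanity reference (illustrative only, not part of the test suite), the dropped fraction of an unmasked Bernoulli keep-mask tracks the dropout probability closely at these tensor sizes:

import torch

p = 0.17
keep = torch.rand(2, 4, 128, 128) >= p        # True means keep, False means drop
dropped_fraction = (~keep).sum() / keep.numel()
assert abs(dropped_fraction.item() - p) <= 0.01
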
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize(
    "b_sq", [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
)
@pytest.mark.parametrize(
    "nh_hd",
    [
        (32, 64),
        (16, 128),
        (40, 128),  # non power of 2 nh
    ],
)
@pytest.mark.parametrize("tp_world_size", [1, 2, 4])
def test_flash_attn_func(b_sq, nh_hd, tp_world_size, dtype):
    b, sq = b_sq
    nh, hd = nh_hd
    nh_tp = nh // tp_world_size
    q, k, v = [
        torch.randn(b, sq, nh_tp, hd, device="cuda", dtype=dtype, requires_grad=True)
        for _ in range(3)
    ]
    dout = torch.rand_like(q)
    for tp_index in range(tp_world_size):
        alibi, alibi_slopes = generate_alibi(
            max_seq_len=sq,
            num_attention_heads=nh,
            tp_world_size=tp_world_size,
            tp_index=tp_index,
            key_padding_mask=None,
            device="cuda",
        )
        triton_out = flash_attn_func_triton(q, k, v, alibi, True, hd ** (-0.5))
        triton_out.backward(dout)
        triton_dq, q.grad = q.grad.clone(), None
        triton_dk, k.grad = k.grad.clone(), None
        triton_dv, v.grad = v.grad.clone(), None
        flash_out = flash_attn_func(
            q, k, v, causal=True, alibi_slopes=repeat(alibi_slopes, "nh -> b nh", b=b)
        )
        flash_out.backward(dout)
        flash_dq, q.grad = q.grad.clone(), None
        flash_dk, k.grad = k.grad.clone(), None
        flash_dv, v.grad = v.grad.clone(), None
        assert torch.allclose(flash_out, triton_out, atol=1e-2, rtol=0.0)
        assert torch.allclose(flash_dq, triton_dq, atol=1e-2, rtol=0.0)
        assert torch.allclose(flash_dk, triton_dk, atol=1e-2, rtol=0.0)
        assert torch.allclose(flash_dv, triton_dv, atol=1e-2, rtol=0.0)
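
generate_alibi (defined elsewhere in this file) also returns the per-head slopes that get sharded across tensor-parallel ranks. For orientation, a hedged sketch of the standard ALiBi slope schedule, shown only for power-of-two head counts and not a reproduction of generate_alibi itself:

import torch

def reference_alibi_slopes(nheads: int) -> torch.Tensor:
    # Standard ALiBi schedule for power-of-two head counts: 2^(-8/n), 2^(-16/n), ..., 2^-8
    start = 2.0 ** (-8.0 / nheads)
    return torch.tensor([start ** (i + 1) for i in range(nheads)], dtype=torch.float32)

print(reference_alibi_slopes(8))  # tensor([0.5000, 0.2500, 0.1250, ..., 0.0039])
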
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("right_padding", [True, False])
@pytest.mark.parametrize(
    "b_sq", [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
)
@pytest.mark.parametrize(
    "nh_hd",
    [
        (32, 64),
        (16, 128),
        (40, 128),  # non power of 2 nh
    ],
)
@pytest.mark.parametrize("tp_world_size", [1, 2, 4])
def test_flash_attn_varlen_func(b_sq, nh_hd, tp_world_size, right_padding, dtype):
    b, sqk = b_sq
    nh, hd = nh_hd
    nh_tp = nh // tp_world_size
    # The causal logic of flash_attn_func_triton() and flash-attention v2 (>= v2.1) differ,
    # so only seqlen_q == 1 (with causal=False passed to the triton version) gives matching results.
    # https://github.com/huggingface/text-generation-inference/blob/v1.1.1/server/text_generation_server/models/custom_modeling/mpt_modeling.py#L53-L63
    q = torch.randn(b, 1, nh_tp, hd, device="cuda", dtype=dtype, requires_grad=True)
    k, v = [
        torch.randn(b, sqk, nh_tp, hd, device="cuda", dtype=dtype, requires_grad=True)
        for _ in range(2)
    ]
    dout = torch.rand_like(q)
    padding_mask = generate_random_padding_mask(sqk, b, "cuda", "random", right_padding)
    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
    ) = generate_qkv(q, k, v, None, padding_mask, kvpacked=False)
    for tp_index in range(tp_world_size):
        alibi, alibi_slopes = generate_alibi(
            max_seq_len=sqk,
            num_attention_heads=nh,
            tp_world_size=tp_world_size,
            tp_index=tp_index,
            key_padding_mask=padding_mask,
            device="cuda",
        )
        triton_out = flash_attn_func_triton(q, k, v, alibi, False, hd ** (-0.5))
        triton_out.backward(dout)
        triton_dq, q.grad = q.grad.clone(), None
        triton_dk, k.grad = k.grad.clone(), None
        triton_dv, v.grad = v.grad.clone(), None
        flash_out_unpad = flash_attn_varlen_func(
            q_unpad,
            k_unpad,
            v_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            causal=True,
            alibi_slopes=repeat(alibi_slopes, "nh -> b nh", b=b),
        )
        flash_out = output_pad_fn(flash_out_unpad)
        flash_out.backward(dout)
        flash_dq_unpad, q_unpad.grad = q_unpad.grad.clone(), None
        flash_dk_unpad, k_unpad.grad = k_unpad.grad.clone(), None
        flash_dv_unpad, v_unpad.grad = v_unpad.grad.clone(), None
        flash_dq = dq_pad_fn(flash_dq_unpad)
        flash_dk = dk_pad_fn(flash_dk_unpad)
        flash_dv = dk_pad_fn(flash_dv_unpad)
        assert torch.allclose(flash_out, triton_out, atol=1e-2, rtol=0.0)
        assert torch.allclose(flash_dq, triton_dq, atol=1e-2, rtol=0.0)
        assert torch.allclose(flash_dk, triton_dk, atol=1e-2, rtol=0.0)
        assert torch.allclose(flash_dv, triton_dv, atol=1e-2, rtol=0.0)
@pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("num_splits", [1, 0])
# @pytest.mark.parametrize("num_splits", [0])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [True])
# @pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
@pytest.mark.parametrize("rotary_interleaved", [False, True])
# @pytest.mark.parametrize("rotary_interleaved", [False])
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
# @pytest.mark.parametrize("rotary_fraction", [0.0])
@pytest.mark.parametrize("has_batch_idx", [False, True])
# @pytest.mark.parametrize("has_batch_idx", [True])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 128),
        (1, 339),
        (3, 1024),
        (64, 800),
        (64, 256),
        (3, 799),
        (64, 2048),
        (16, 20000),
        (1, 128 * 1024),
        (16, 128 * 1024),
        (128, 128),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_kvcache(
    seqlen_q,
    seqlen_k,
    d,
    has_batch_idx,
    rotary_fraction,
    rotary_interleaved,
    seqlen_new_eq_seqlen_q,
    causal,
    local,
    new_kv,
    mha_type,
    num_splits,
    dtype,
    alibi,
):
    if seqlen_q > seqlen_k and new_kv:
        pytest.skip()
    if not new_kv and rotary_fraction > 0.0:
        pytest.skip()
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
    nheads = 8
    # rotary_dim must be a multiple of 16, and must be <= d
    rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 4)
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
    seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
    if new_kv:
        k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
        v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
    else:
        k, v = None, None
    k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
    v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
    cache_seqlens = torch.randint(
        0,
        # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
        (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
        if new_kv
        else (seqlen_k + 1),
        (batch_size,),
        dtype=torch.int32,
        device=device,
    )
    if has_batch_idx:
        cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[:batch_size]
    else:
        cache_batch_idx = None
    # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
    if rotary_dim > 0:
        angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi
        cos = torch.cos(angle).to(dtype=dtype)
        sin = torch.sin(angle).to(dtype=dtype)
        if causal or local:
            q_ro = apply_rotary_emb(
                q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
            )
        else:
            q_ro = rearrange(
                apply_rotary_emb(
                    rearrange(q, "b s h d -> b 1 (s h) d"),
                    cos,
                    sin,
                    seqlen_offsets=cache_seqlens,
                    interleaved=rotary_interleaved,
                ),
                "b 1 (s h) d -> b s h d",
                s=seqlen_q,
            )
        # q_ro = q
        k_ro = apply_rotary_emb(
            k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
        )
    else:
        cos, sin = None, None
        q_ro, k_ro = q, k
    # k_cache[:, 64:] = -1
    k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone()
    v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone()
    arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
    if new_kv:
        update_mask = torch.logical_and(
            cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
        )
        k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
        v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
    k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
    v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
    if alibi:
        seqlen_alibi = k_cache_rep.shape[1]
        alibi_tensor, alibi_slopes = generate_alibi(
            max_seq_len=seqlen_alibi,
            num_attention_heads=nheads,
            tp_world_size=1,
            tp_index=0,
            key_padding_mask=None,
            device="cuda",
        )
        # alibi_tensor = alibi_tensor.expand(batch_size, -1, seqlen_q, -1)
        alibi_slopes = repeat(alibi_slopes, "nh -> b nh", b=batch_size)
        if alibi_tensor.abs().max().item() >= torch.finfo(dtype).max:
            pytest.skip()
    else:
        alibi_tensor, alibi_slopes = None, None
    out = flash_attn_with_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cos,
        sin,
        cache_seqlens,
        cache_batch_idx,
        causal=causal,
        window_size=window_size,
        rotary_interleaved=rotary_interleaved,
        num_splits=num_splits,
        alibi_slopes=alibi_slopes,
    )
    # out = flash_attn_with_kvcache(
    #     q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
    # )
    # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
    # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
    # m = qk.amax(-1, keepdim=True)
    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
    # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
    # probs = torch.softmax(qk, dim=-1)
    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
    out_ref, _ = attention_ref(
        q_ro,
        k_cache_rep,
        v_cache_rep,
        None,
        key_padding_mask,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
        bias=alibi_tensor,
    )
    out_pt, _ = attention_ref(
        q_ro,
        k_cache_rep,
        v_cache_rep,
        None,
        key_padding_mask,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
        bias=alibi_tensor,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    if new_kv:
        k_cache_select = k_cache if not has_batch_idx else k_cache[cache_batch_idx]
        v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx]
        assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
        assert torch.equal(v_cache_select, v_cache_ref)
    assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5
tests/test_flash_attn.py View file @ 5ab9b366
...
@@ -26,6 +26,31 @@ is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0)
is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0)


def attn_bias_from_alibi_slopes(
    slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False
):
    batch, nheads = slopes.shape
    device = slopes.device
    slopes = rearrange(slopes, "b h -> b h 1 1")
    if causal:
        return torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32) * slopes
    else:
        row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
        col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
        sk = (
            seqlen_k
            if key_padding_mask is None
            else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
        )
        sq = (
            seqlen_q
            if query_padding_mask is None
            else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
        )
        relative_pos = torch.abs(row_idx + sk - sq - col_idx)
        return -slopes * relative_pos.to(dtype=slopes.dtype)


def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
    assert mode in ["full", "random", "third"]
    if mode == "full":
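
A shape-level usage sketch of the new helper (illustrative; assumes attn_bias_from_alibi_slopes is in scope, e.g. when run inside this test module): the causal branch returns a bias that depends only on the key position and broadcasts over queries, while the non-causal branch produces a full (batch, nheads, seqlen_q, seqlen_k) bias.

import torch

slopes = torch.rand(2, 8) * 0.3                                    # (batch, nheads)
bias_causal = attn_bias_from_alibi_slopes(slopes, 4, 6, causal=True)
bias_full = attn_bias_from_alibi_slopes(slopes, 4, 6, causal=False)
print(bias_causal.shape)  # torch.Size([2, 8, 1, 6]) -- broadcasts over the query dimension
print(bias_full.shape)    # torch.Size([2, 8, 4, 6])
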
...
@@ -186,6 +211,7 @@ def attention_ref(
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
...
@@ -200,6 +226,7 @@ def attention_ref(
        v: (batch_size, seqlen_k, nheads_k, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
...
@@ -238,7 +265,9 @@ def attention_ref(
            q.device,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    attention = torch.softmax(scores, dim=-1)
    if attn_bias is not None:
        scores = scores + attn_bias
    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    # Some rows might be completely masked out so we fill them with zero instead of NaN
    if window_size[0] >= 0 or window_size[1] >= 0:
        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
...
@@ -264,6 +293,7 @@ def attention_kvpacked_ref(
    kv,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
...
@@ -277,6 +307,7 @@ def attention_kvpacked_ref(
        kv[:, :, 1],
        query_padding_mask,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        upcast=upcast,
...
@@ -289,6 +320,7 @@ def attention_kvpacked_ref(
def attention_qkvpacked_ref(
    qkv,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
...
@@ -302,6 +334,7 @@ def attention_qkvpacked_ref(
        qkv[:, :, 2],
        key_padding_mask,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        upcast=upcast,
...
@@ -436,6 +469,7 @@ def normalize_flash_attn_S(
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    is_dropout=False,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
...
@@ -445,6 +479,7 @@ def normalize_flash_attn_S(
        q: (batch_size, seqlen_q, nheads, head_dim)
        k, v: (batch_size, seqlen_k, nheads, head_dim)
        key_padding_mask: (batch_size, seqlen_q)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
    Output:
        softmax_lse: (batch_size, nheads, seqlen_q)
        softmax_max: (batch_size, nheads, seqlen_q)
...
@@ -467,6 +502,8 @@ def normalize_flash_attn_S(
            q.device,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias.to(dtype=scores.dtype)
    _, block_size_n = _get_block_size(scores.device, head_dim, is_dropout, causal)
    scores_block = scores.split(block_size_n, dim=-1)
    lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
...
@@ -529,6 +566,8 @@ def get_dropout_fraction(
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
...
@@ -538,24 +577,34 @@ def get_dropout_fraction(
# @pytest.mark.parametrize('d', [32, 64, 96, 128])
# @pytest.mark.parametrize("d", [64])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize("seqlen", [128])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize("seqlen", [97])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype):
    if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 13
    batch_size = 8
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal)
    else:
        alibi_slopes, attn_bias = None, None
    out, lse, S_dmask = flash_attn_qkvpacked_func(
        qkv, dropout_p, causal=causal, window_size=window_size, return_attn_probs=True
        qkv,
        dropout_p,
        causal=causal,
        window_size=window_size,
        alibi_slopes=alibi_slopes,
        return_attn_probs=True,
    )
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
...
@@ -578,6 +627,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
            qkv[:, :, 2],
            None,
            None,
            attn_bias,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
...
@@ -590,11 +640,12 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
        dropout_mask = None
    out_ref, attn_ref = attention_qkvpacked_ref(
        qkv, None, dropout_p, dropout_mask, causal=causal, window_size=window_size
        qkv, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size,
    )
    out_pt, attn_pt = attention_qkvpacked_ref(
        qkv,
        None,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
...
@@ -651,7 +702,9 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
...
@@ -659,18 +712,20 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 257, 384, 512, 768, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype):
    if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
...
@@ -685,6 +740,13 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random")
    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen, seqlen, key_padding_mask, key_padding_mask, causal=causal
        )
    else:
        alibi_slopes, attn_bias = None, None
    qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
        *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True
...
@@ -697,6 +759,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
        dropout_p,
        causal=causal,
        window_size=window_size,
        alibi_slopes=alibi_slopes,
        return_attn_probs=True,
    )
    out = output_pad_fn(out_unpad)
...
@@ -721,6 +784,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
            qkv[:, :, 2],
            key_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
...
@@ -733,11 +797,18 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
        dropout_mask = None
    out_ref, attn_ref = attention_qkvpacked_ref(
        qkv, key_padding_mask, dropout_p, dropout_mask, causal=causal, window_size=window_size
        qkv,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
        window_size=window_size,
    )
    out_pt, attn_pt = attention_qkvpacked_ref(
        qkv,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
...
@@ -774,7 +845,9 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
...
@@ -786,11 +859,13 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
...
@@ -815,7 +890,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.17])
def test_flash_attn_output(
    seqlen_q, seqlen_k, d, dropout_p, causal, local, mha_type, dtype, kvpacked
    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, mha_type, dtype, kvpacked
):
    if (
        max(seqlen_q, seqlen_k) >= 2048
...
@@ -825,7 +900,7 @@ def test_flash_attn_output(
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 13
    batch_size = 8
    nheads = 9
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
...
@@ -842,14 +917,32 @@ def test_flash_attn_output(
        v = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True)
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
    else:
        alibi_slopes, attn_bias = None, None
    if kvpacked:
        out, lse, S_dmask = flash_attn_kvpacked_func(
            q, kv, dropout_p, causal=causal, window_size=window_size, return_attn_probs=True
            q,
            kv,
            dropout_p,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_attn_probs=True,
        )
    else:
        out, lse, S_dmask = flash_attn_func(
            q, k, v, dropout_p, causal=causal, window_size=window_size, return_attn_probs=True
            q,
            k,
            v,
            dropout_p,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_attn_probs=True,
        )
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
...
@@ -878,6 +971,7 @@ def test_flash_attn_output(
            v_rep,
            None,
            None,
            attn_bias,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
...
@@ -895,6 +989,7 @@ def test_flash_attn_output(
            kv,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
...
@@ -905,6 +1000,7 @@ def test_flash_attn_output(
            kv,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
...
@@ -919,6 +1015,7 @@ def test_flash_attn_output(
            v,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
...
@@ -930,6 +1027,7 @@ def test_flash_attn_output(
            v,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
...
@@ -1000,7 +1098,9 @@ def test_flash_attn_output(
    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
...
@@ -1014,11 +1114,13 @@ def test_flash_attn_output(
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize('mha_type', ["mqa"])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize(
...
@@ -1041,7 +1143,7 @@ def test_flash_attn_output(
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_output(
    seqlen_q, seqlen_k, d, dropout_p, causal, local, mha_type, dtype, kvpacked
    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, mha_type, dtype, kvpacked
):
    if (
        max(seqlen_q, seqlen_k) >= 2048
...
@@ -1051,7 +1153,7 @@ def test_flash_attn_varlen_output(
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 13
    batch_size = 8
    nheads = 9
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
...
@@ -1072,6 +1174,13 @@ def test_flash_attn_varlen_output(
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, causal=causal
        )
    else:
        alibi_slopes, attn_bias = None, None
    if kvpacked:
        (
...
@@ -1095,9 +1204,10 @@ def test_flash_attn_varlen_output(
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            return_attn_probs=True,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_attn_probs=True,
        )
    else:
        (
...
@@ -1124,9 +1234,10 @@ def test_flash_attn_varlen_output(
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            return_attn_probs=True,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_attn_probs=True,
        )
    out = output_pad_fn(out_unpad)
    if dropout_p > 0.0:
...
@@ -1156,6 +1267,7 @@ def test_flash_attn_varlen_output(
            v_rep,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
...
@@ -1177,6 +1289,7 @@ def test_flash_attn_varlen_output(
            kv,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
...
@@ -1187,6 +1300,7 @@ def test_flash_attn_varlen_output(
            kv,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
...
@@ -1201,6 +1315,7 @@ def test_flash_attn_varlen_output(
            v,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
...
@@ -1212,6 +1327,7 @@ def test_flash_attn_varlen_output(
            v,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
...
@@ -1284,12 +1400,14 @@ def test_flash_attn_varlen_output(
    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
        assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
...
@@ -1332,7 +1450,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
    causal = True
    # set seed
    torch.random.manual_seed(0)
    batch_size = 13
    batch_size = 8
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
...
@@ -1340,7 +1458,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size)
    out_ref, attn_ref = attention_ref(
        q, k, v, None, None, 0.0, None, causal=causal, window_size=window_size
        q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size
    )
    out_pt, attn_pt = attention_ref(
        q,
...
@@ -1348,6 +1466,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
        v,
        None,
        None,
        None,
        0.0,
        None,
        causal=causal,
...
@@ -1442,7 +1561,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
    causal = True
    # set seed
    torch.random.manual_seed(0)
    batch_size = 13
    batch_size = 8
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
...
@@ -1484,6 +1603,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
        v,
        query_padding_mask,
        key_padding_mask,
        None,
        0.0,
        None,
        causal=causal,
...
@@ -1495,6 +1615,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
        v,
        query_padding_mask,
        key_padding_mask,
        None,
        0.0,
        None,
        causal=causal,
...
@@ -1554,8 +1675,10 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
...
@@ -1581,7 +1704,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, dtype):
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
...
@@ -1593,11 +1716,23 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
    else:
        alibi_slopes, attn_bias = None, None
    out, lse, _ = flash_attn_func(
        q, k, v, 0.0, causal=causal, window_size=window_size, return_attn_probs=True
        q,
        k,
        v,
        0.0,
        causal=causal,
        window_size=window_size,
        alibi_slopes=alibi_slopes,
        return_attn_probs=True,
    )
    out_ref, attn_ref = attention_ref(
        q, k, v, None, None, 0.0, None, causal=causal, window_size=window_size
        q, k, v, None, None, attn_bias, 0.0, None, causal=causal, window_size=window_size
    )
    out_pt, attn_pt = attention_ref(
        q,
...
@@ -1605,6 +1740,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
        v,
        None,
        None,
        attn_bias,
        0.0,
        None,
        causal=causal,
...
@@ -1653,24 +1789,27 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
    mult = 2 if not alibi else 8
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 2e-4
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 2e-4
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 2e-4
        assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4
        assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4
        assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("num_splits", [1, 0])
# @pytest.mark.parametrize("num_splits", [0])
# @pytest.mark.parametrize("num_splits", [1])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [True])
# @pytest.mark.parametrize("new_kv", [False])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
@pytest.mark.parametrize("rotary_interleaved", [False, True])
...
@@ -1678,7 +1817,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
# @pytest.mark.parametrize("rotary_fraction", [0.0])
@pytest.mark.parametrize("has_batch_idx", [False, True])
# @pytest.mark.parametrize("has_batch_idx", [True])
# @pytest.mark.parametrize("has_batch_idx", [False])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
...
@@ -1711,6 +1850,7 @@ def test_flash_attn_kvcache(
    seqlen_new_eq_seqlen_q,
    causal,
    local,
    alibi,
    new_kv,
    mha_type,
    num_splits,
...
@@ -1750,10 +1890,22 @@ def test_flash_attn_kvcache(
        dtype=torch.int32,
        device=device,
    )
    arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
    if has_batch_idx:
        cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[:batch_size]
        cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
            :batch_size
        ]
    else:
        cache_batch_idx = None
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal
        )
    else:
        alibi_slopes, attn_bias = None, None
    # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
    if rotary_dim > 0:
        angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi
...
@@ -1785,8 +1937,6 @@ def test_flash_attn_kvcache(
    # k_cache[:, 64:] = -1
    k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone()
    v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone()
    arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
    if new_kv:
        update_mask = torch.logical_and(
            cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
...
@@ -1808,6 +1958,7 @@ def test_flash_attn_kvcache(
        causal=causal,
        window_size=window_size,
        rotary_interleaved=rotary_interleaved,
        alibi_slopes=alibi_slopes,
        num_splits=num_splits,
    )
    # out = flash_attn_with_kvcache(
...
@@ -1820,13 +1971,13 @@ def test_flash_attn_kvcache(
    # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
    # probs = torch.softmax(qk, dim=-1)
    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
    out_ref, _ = attention_ref(
        q_ro,
        k_cache_rep,
        v_cache_rep,
        None,
        key_padding_mask,
        attn_bias,
        0.0,
        None,
        causal=causal,
...
@@ -1838,6 +1989,7 @@ def test_flash_attn_kvcache(
        v_cache_rep,
        None,
        key_padding_mask,
        attn_bias,
        0.0,
        None,
        causal=causal,
...
@@ -1857,7 +2009,8 @@ def test_flash_attn_kvcache(
        v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx]
        assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
        assert torch.equal(v_cache_select, v_cache_ref)
    assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5
    mult = 3 if not alibi else 5
    assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5


# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
...
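
The updated tests exercise the alibi_slopes argument that this commit adds to the public API in flash_attn/flash_attn_interface.py. A minimal end-to-end sketch (assumes a CUDA GPU and a build of the flash_attn package from this commit):

import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, d = 2, 128, 8, 64
q, k, v = [
    torch.randn(batch, seqlen, nheads, d, device="cuda", dtype=torch.float16) for _ in range(3)
]
alibi_slopes = torch.rand(batch, nheads, device="cuda", dtype=torch.float32) * 0.3
out = flash_attn_func(q, k, v, causal=True, alibi_slopes=alibi_slopes)
print(out.shape)  # torch.Size([2, 128, 8, 64])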