gaoqiong / flash-attention · Commits

Commit 5ab9b366, authored Dec 21, 2023 by Tri Dao
Clean up alibi, implement non-causal alibi
Parent: bc28eacc

Showing 11 changed files with 309 additions and 1199 deletions (+309 / -1199)
Files changed:

  csrc/flash_attn/flash_api.cpp                     +27   -32
  csrc/flash_attn/src/alibi.h                       +27   -19
  csrc/flash_attn/src/flash.h                       +0    -4
  csrc/flash_attn/src/flash_bwd_kernel.h            +10   -30
  csrc/flash_attn/src/flash_bwd_launch_template.h   +3    -4
  csrc/flash_attn/src/flash_fwd_kernel.h            +8    -40
  csrc/flash_attn/src/flash_fwd_launch_template.h   +2    -2
  csrc/flash_attn/src/softmax.h                     +5    -7
  flash_attn/flash_attn_interface.py                +25   -6
  tests/test_alibi.py                               +0    -1006
  tests/test_flash_attn.py                          +202  -49
csrc/flash_attn/flash_api.cpp

@@ -253,12 +253,12 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &k,              // batch_size x seqlen_k x num_heads_k x head_size
         const at::Tensor &v,              // batch_size x seqlen_k x num_heads_k x head_size
         c10::optional<at::Tensor> &out_,  // batch_size x seqlen_q x num_heads x head_size
+        c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
         const float p_dropout,
         const float softmax_scale,
         bool is_causal,
         const int window_size_left,
         int window_size_right,
-        c10::optional<at::Tensor> &alibi_slopes_, // batch_size x num_heads
         const bool return_softmax,
         c10::optional<at::Generator> gen_) {

@@ -297,13 +297,13 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

-    if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
+    // causal=true is the same as causal=false in this case
+    if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
     if (is_causal) { window_size_right = 0; }

     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    // TODO: how to make "seqlenq_ngroups_swapped" and ALiBi work together?
-    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !(alibi_slopes_.has_value());
     if (seqlenq_ngroups_swapped) {
         const int ngroups = num_heads / num_heads_k;
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);

@@ -416,12 +416,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
         CHECK_DEVICE(alibi_slopes);
         TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
-        params.has_alibi = true;
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
-        params.has_alibi = false;
+        params.alibi_slopes_ptr = nullptr;
     }

     if (seqlen_k > 0) {

@@ -456,6 +455,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
         const at::Tensor &cu_seqlens_q,       // b+1
         const at::Tensor &cu_seqlens_k,       // b+1
         c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
+        c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
         const int max_seqlen_q,
         const int max_seqlen_k,
         const float p_dropout,

@@ -464,7 +464,6 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
         const bool is_causal,
         const int window_size_left,
         int window_size_right,
-        c10::optional<at::Tensor> &alibi_slopes_, // b x num_heads
         const bool return_softmax,
         c10::optional<at::Generator> gen_) {

@@ -612,12 +611,11 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
         TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
         CHECK_DEVICE(alibi_slopes);
         TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
-        params.has_alibi = true;
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
-        params.has_alibi = false;
+        params.alibi_slopes_ptr = nullptr;
     }

     auto stream = at::cuda::getCurrentCUDAStream().stream();

@@ -664,12 +662,12 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
         c10::optional<at::Tensor> &dq_,   // batch_size x seqlen_q x num_heads x head_size
         c10::optional<at::Tensor> &dk_,   // batch_size x seqlen_k x num_heads_k x head_size
         c10::optional<at::Tensor> &dv_,   // batch_size x seqlen_k x num_heads_k x head_size
+        c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
         const float p_dropout,            // probability to drop
         const float softmax_scale,
         const bool is_causal,
         const int window_size_left,
         int window_size_right,
-        c10::optional<at::Tensor> &alibi_slopes_, // batch_size x num_heads
         c10::optional<at::Generator> gen_,
         c10::optional<at::Tensor> &rng_state) {

@@ -848,12 +846,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
         TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
         CHECK_DEVICE(alibi_slopes);
         TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
-        params.has_alibi = true;
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
-        params.has_alibi = false;
+        params.alibi_slopes_ptr = nullptr;
     }

     if (seqlen_q > 0) {

@@ -891,6 +888,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
         c10::optional<at::Tensor> &dv_,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
         const at::Tensor &cu_seqlens_q,  // b+1
         const at::Tensor &cu_seqlens_k,  // b+1
+        c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
         const int max_seqlen_q,
         const int max_seqlen_k,          // max sequence length to choose the kernel
         const float p_dropout,           // probability to drop

@@ -899,7 +897,6 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
         const bool is_causal,
         const int window_size_left,
         int window_size_right,
-        c10::optional<at::Tensor> &alibi_slopes_, // b x num_heads
         c10::optional<at::Generator> gen_,
         c10::optional<at::Tensor> &rng_state) {

@@ -1094,12 +1091,11 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
         TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
         CHECK_DEVICE(alibi_slopes);
         TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
-        params.has_alibi = true;
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
-        params.has_alibi = false;
+        params.alibi_slopes_ptr = nullptr;
     }

     launch(params, stream, /*configure=*/false);

@@ -1128,14 +1124,14 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         c10::optional<const at::Tensor> &rotary_cos_,      // seqlen_ro x (rotary_dim / 2)
         c10::optional<const at::Tensor> &rotary_sin_,      // seqlen_ro x (rotary_dim / 2)
         c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
+        c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
         c10::optional<at::Tensor> &out_,                   // batch_size x seqlen_q x num_heads x head_size
         const float softmax_scale,
         bool is_causal,
         const int window_size_left,
         int window_size_right,
         bool is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
-        int num_splits,
-        c10::optional<at::Tensor> &alibi_slopes_ // batch_size x num_heads
+        int num_splits
         ) {
     auto dprops = at::cuda::getCurrentDeviceProperties();

@@ -1174,13 +1170,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

-    if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
+    // causal=true is the same as causal=false in this case
+    if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
     if (is_causal) { window_size_right = 0; }

     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    // TODO: how to make "seqlenq_ngroups_swapped" and ALiBi work together?
-    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !(alibi_slopes_.has_value());
     if (seqlenq_ngroups_swapped) {
         const int ngroups = num_heads / num_heads_k;
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);

@@ -1347,12 +1343,11 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
         CHECK_DEVICE(alibi_slopes);
         TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
-        params.has_alibi = true;
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
-        params.has_alibi = false;
+        params.alibi_slopes_ptr = nullptr;
     }

     auto stream = at::cuda::getCurrentCUDAStream().stream();
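Stepping back from the C++ above: the forward and backward entry points now accept alibi_slopes of shape (num_heads) or (batch_size, num_heads) and derive the batch stride from the tensor rank. As an orientation aid, here is a minimal Python sketch of those checks; the helper name is ours, not part of the commit.

import torch

def check_alibi_slopes(alibi_slopes: torch.Tensor, batch_size: int, num_heads: int) -> int:
    """Mirror of the flash_api.cpp validation; returns the batch stride the kernels would use."""
    assert alibi_slopes.dtype == torch.float32, "ALiBi slopes must have dtype fp32"
    assert alibi_slopes.is_cuda, "slopes must live on the same GPU as q/k/v"
    assert alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"
    # Either one slope per head (shared across the batch) or one slope per (batch, head).
    assert tuple(alibi_slopes.shape) in {(num_heads,), (batch_size, num_heads)}
    # A 1-D tensor is broadcast over the batch, so its batch stride is 0.
    return alibi_slopes.stride(0) if alibi_slopes.dim() == 2 else 0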
csrc/flash_attn/src/alibi.h

@@ -13,37 +13,45 @@ using namespace cute;

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template <typename Engine, typename Layout>
+template <bool Is_causal, typename Engine, typename Layout>
 inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
                                    const int col_idx_offset_,
                                    const int max_seqlen_k,
-                                   const int row_idx_offset_,
+                                   const int row_idx_offset,
                                    const int max_seqlen_q,
                                    const int warp_row_stride,
-                                   const int head_idx,
-                                   const float softmax_scale,
                                    const float alibi_slope) {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
     const int lane_id = threadIdx.x % 32;
-    const int row_idx_offset = row_idx_offset_;
     const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
-    const float alibi_slope_unscaled = alibi_slope / softmax_scale;
-    #pragma unroll
-    for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
-        #pragma unroll
-        for (int i = 0; i < size<0, 0>(tensor); ++i) {
-            const int row_idx = row_idx_base + i * 8;
-            #pragma unroll
-            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-                const int col_idx_base = col_idx_offset + nj * 8;
-                #pragma unroll
-                for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                    const int col_idx = col_idx_base + j;
-                    const float alibi = alibi_slope_unscaled * col_idx;
-                    if (col_idx < max_seqlen_k && row_idx < max_seqlen_q) {
-                        tensor(make_coord(i, mi), make_coord(j, nj)) += alibi;
-                    }
-                }
-            }
-        }
-    }
+    if constexpr (Is_causal) {  // Simpler, we add the same bias vector to all rows
+        #pragma unroll
+        for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+            const int col_idx_base = col_idx_offset + nj * 8;
+            #pragma unroll
+            for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                const int col_idx = col_idx_base + j;
+                #pragma unroll
+                for (int mi = 0; mi < size<0>(tensor); ++mi) {
+                    tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
+                }
+            }
+        }
+    } else {  // Bias depends on both row_idx and col_idx
+        #pragma unroll
+        for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+            const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+            #pragma unroll
+            for (int i = 0; i < size<0, 0>(tensor); ++i) {
+                const int row_idx = row_idx_base + i * 8;
+                #pragma unroll
+                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                    const int col_idx_base = col_idx_offset + nj * 8;
+                    #pragma unroll
+                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                        const int col_idx = col_idx_base + j;
+                        tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
+                    }
+                }
+            }
+        }
+    }
 }

@@ -51,4 +59,4 @@ inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
     }
 }

 }  // namespace flash
\ No newline at end of file
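The rewritten apply_alibi above has two paths: a causal path that adds the same slope * col_idx vector to every row (under a causal mask this differs from -slope * (i - j) only by a per-row constant, which softmax cancels), and a non-causal path that subtracts slope * |i + seqlen_k - seqlen_q - j|. A hedged PyTorch reference of that bias for one head, assuming a (seqlen_q, seqlen_k) score matrix; this is a readability sketch, not the CUDA kernel:

import torch

def alibi_bias(seqlen_q: int, seqlen_k: int, slope: float, causal: bool) -> torch.Tensor:
    col = torch.arange(seqlen_k)
    if causal:
        # Same vector slope * j added to every row; softmax removes the per-row constant
        # that distinguishes this from -slope * (i - j) under the causal mask.
        return slope * col.float().expand(seqlen_q, seqlen_k)
    row = torch.arange(seqlen_q).unsqueeze(-1)
    # Non-causal: bias depends on both indices, with queries aligned to the
    # bottom-right of the key sequence (hence the seqlen_k - seqlen_q shift).
    return -slope * (row + seqlen_k - seqlen_q - col).abs().float()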
csrc/flash_attn/src/flash.h

@@ -131,10 +131,6 @@ struct Flash_fwd_params : public Qkv_params {

     int num_splits;  // For split-KV version

-    // float alibi_start;
-    // float alibi_ratio;
-    bool has_alibi;
     void * __restrict__ alibi_slopes_ptr;
     index_t alibi_slopes_batch_stride;
 };
csrc/flash_attn/src/flash_bwd_kernel.h

@@ -753,8 +753,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     #pragma unroll
     for (int mi = 0; mi < size(lse); ++mi) {
         const int row = get<0>(taccScS_row(mi));
-        lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0;
+        lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
     }
+    // We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
+    // and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply
+    // with V (which would be zero), we're fine. However, with ALiBi, we might modify these
+    // scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0.

     // Tensor tKrK = make_fragment_like(tKsK);
     // // cute::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, 0), tKrK);

@@ -792,18 +796,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     clear(acc_dv);
     clear(acc_dk);

-    float alibi_slope = 0.0f;
+    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
-    if (Has_alibi) {
-        Tensor gAS = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr) + bidb * params.alibi_slopes_batch_stride + bidh), Shape<_1>{});
-        Tensor rAS = make_fragment_like(gAS);
-        cute::copy(gAS, rAS);
-        alibi_slope = rAS(0);
-    }

     for (; m_block >= m_block_min; --m_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_N, MMA_N)

@@ -830,18 +823,17 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // if (cute::thread(32, 0)) { print(scores); }

         if (Has_alibi) {
-            flash::apply_alibi(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
-                               binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
-                               binfo.actual_seqlen_q, AtomLayoutMS * 16,
-                               bidh, params.scale_softmax, alibi_slope);
+            flash::apply_alibi<Is_causal>(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
+                                          binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
+                                          binfo.actual_seqlen_q, AtomLayoutMS * 16,
+                                          alibi_slope);
         }

         // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond
         // actual_seqlen_k, because acc_s would be some finite value for those indices.
         // In the end when we multiply with K to get dQ, the corresponding values of K would be 0,

@@ -1403,18 +1395,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
     clear(acc_dq);

-    float alibi_slope = 0.0f;
+    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
-    if (Has_alibi) {
-        Tensor gAS = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr) + bidb * params.alibi_slopes_batch_stride + bidh), Shape<_1>{});
-        Tensor rAS = make_fragment_like(gAS);
-        cute::copy(gAS, rAS);
-        alibi_slope = rAS(0);
-    }

     for (; n_block >= 0; --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M_SdP, MMA_N)

@@ -1429,14 +1410,13 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));

         if (Has_alibi) {
-            flash::apply_alibi(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
-                               binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
-                               binfo.actual_seqlen_q, AtomLayoutMS * 16,
-                               bidh, params.scale_softmax, alibi_slope);
+            flash::apply_alibi<Is_causal>(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
+                                          binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
+                                          binfo.actual_seqlen_q, AtomLayoutMS * 16,
+                                          alibi_slope);
         }
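The new comment block above explains the switch from LSE = 0 to LSE = INFINITY for out-of-bounds rows: with zeroed Q/K the raw scores are zero, so LSE = 0 used to give harmless all-ones probabilities, but ALiBi can perturb those scores and turn the probabilities into garbage or NaN. A toy sketch of the effect (ours, not kernel code):

import torch

scores = torch.full((4,), -3.0)                   # an out-of-bounds row after ALiBi has touched it
probs_lse_zero = torch.exp(scores - 0.0)          # old behaviour: LSE = 0 leaves nonzero "probabilities"
probs_lse_inf = torch.exp(scores - float("inf"))  # new behaviour: LSE = inf forces the row to exactly 0
print(probs_lse_zero)  # tensor([0.0498, 0.0498, 0.0498, 0.0498])
print(probs_lse_inf)   # tensor([0., 0., 0., 0.])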
csrc/flash_attn/src/flash_bwd_launch_template.h

@@ -64,12 +64,11 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
     BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
         BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
             BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
-                BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
+                BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                     // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                     // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                     // If Is_local, set Is_causal to false
-                    auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal && !Is_local, Is_local, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
+                    auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
-                    // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
                     if (smem_size_dq_dk_dv >= 48 * 1024)  {
                         C10_CUDA_CHECK(cudaFuncSetAttribute(
                             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));

@@ -109,7 +108,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
+                BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                     // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                     auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Has_alibi, IsEvenNConst && IsEvenKConst, IsEvenKConst>;
                     // auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
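Both launch templates now derive the Has_alibi template flag from params.alibi_slopes_ptr != nullptr instead of the removed has_alibi field, so the pointer alone selects the specialized kernel. A loose Python analogue of that BOOL_SWITCH-style dispatch, with illustrative names only:

# Pretend these are the pre-compiled template specializations BOOL_SWITCH picks between.
KERNELS = {
    (False, False): "kernel<Has_alibi=false, Is_local=false>",
    (False, True): "kernel<Has_alibi=false, Is_local=true>",
    (True, False): "kernel<Has_alibi=true, Is_local=false>",
    (True, True): "kernel<Has_alibi=true, Is_local=true>",
}

def select_kernel(alibi_slopes_ptr, is_local: bool) -> str:
    has_alibi = alibi_slopes_ptr is not None  # the flag is now implied by the pointer
    return KERNELS[(has_alibi, is_local)]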
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -322,28 +322,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     clear(acc_o);

+    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;

     // For performance reason, we separate out two kinds of iterations:
     // those that need masking on S, and those that don't.
     // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
     // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
     // We will have at least 1 "masking" iteration.
-    float alibi_slope = 0.0f;
-    if (Has_alibi) {
-        Tensor gAS = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr) + bidb * params.alibi_slopes_batch_stride + bidh), Shape<_1>{});
-        Tensor rAS = make_fragment_like(gAS);
-        cute::copy(gAS, rAS);
-        alibi_slope = rAS(0);
-        // if (m_block == 0 && tidx == 0) {
-        //     printf("%d,%d,%f\n", bidb, bidh, alibi_slope);
-        // }
-    }

     // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
     // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
     constexpr int n_masking_steps = (!Is_causal && !Is_local)

@@ -382,14 +368,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // can produce Inf / NaN.

         if (Has_alibi) {
-            flash::apply_alibi(scores, n_block * kBlockN, binfo.actual_seqlen_k,
-                               m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                               binfo.actual_seqlen_q, kNWarps * 16,
-                               bidh, params.scale_softmax, alibi_slope);
+            flash::apply_alibi<Is_causal>(scores, n_block * kBlockN, binfo.actual_seqlen_k,
+                                          m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                                          binfo.actual_seqlen_q, kNWarps * 16,
+                                          alibi_slope);
         }

@@ -500,14 +485,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));

         if (Has_alibi) {
-            flash::apply_alibi(scores, n_block * kBlockN, binfo.actual_seqlen_k,
-                               m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                               binfo.actual_seqlen_q, kNWarps * 16,
-                               bidh, params.scale_softmax, alibi_slope);
+            flash::apply_alibi<Is_causal>(scores, n_block * kBlockN, binfo.actual_seqlen_k,
+                                          m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                                          binfo.actual_seqlen_q, kNWarps * 16,
+                                          alibi_slope);
         }

@@ -950,28 +934,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     clear(acc_o);

+    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;

     // For performance reason, we separate out two kinds of iterations:
     // those that need masking on S, and those that don't.
     // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
     // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
     // We will have at least 1 "masking" iteration.
-    float alibi_slope = 0.0f;
-    if (Has_alibi) {
-        Tensor gAS = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr) + bidb * params.alibi_slopes_batch_stride + bidh), Shape<_1>{});
-        Tensor rAS = make_fragment_like(gAS);
-        cute::copy(gAS, rAS);
-        alibi_slope = rAS(0);
-        // if (m_block == 0 && tidx == 0) {
-        //     printf("%d,%d,%f\n", bidb, bidh, alibi_slope);
-        // }
-    }

     // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
     // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
     constexpr int n_masking_steps = (!Is_causal && !Is_local)

@@ -1006,14 +976,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));

         if (Has_alibi) {
-            flash::apply_alibi(scores, n_block * kBlockN, binfo.actual_seqlen_k,
-                               m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                               binfo.actual_seqlen_q, kNWarps * 16,
-                               bidh, params.scale_softmax, alibi_slope);
+            flash::apply_alibi<Is_causal>(scores, n_block * kBlockN, binfo.actual_seqlen_k,
+                                          m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                                          binfo.actual_seqlen_q, kNWarps * 16,
+                                          alibi_slope);
         }

@@ -1099,14 +1068,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));

         if (Has_alibi) {
-            flash::apply_alibi(scores, n_block * kBlockN, binfo.actual_seqlen_k,
-                               m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                               binfo.actual_seqlen_q, kNWarps * 16,
-                               bidh, params.scale_softmax, alibi_slope);
+            flash::apply_alibi<Is_causal>(scores, n_block * kBlockN, binfo.actual_seqlen_k,
+                                          m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                                          binfo.actual_seqlen_q, kNWarps * 16,
+                                          alibi_slope);
         }
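In the forward and backward kernels the per-(batch, head) slope is now read directly from global memory and divided by params.scale_softmax at load time, instead of passing softmax_scale into apply_alibi. The division compensates for the raw scores being multiplied by the softmax scale only later, so the effective bias stays slope * distance. A small Python sketch of that lookup, assuming a contiguous slopes tensor (helper name ours):

import torch

def load_alibi_slope(alibi_slopes: torch.Tensor, bidb: int, bidh: int, scale_softmax: float) -> float:
    # (nheads,) is shared across the batch (stride 0); (batch, nheads) also indexes by batch.
    batch_stride = alibi_slopes.stride(0) if alibi_slopes.dim() == 2 else 0
    flat = alibi_slopes.contiguous().view(-1)
    return flat[bidb * batch_stride + bidh].item() / scale_softmax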
csrc/flash_attn/src/flash_fwd_launch_template.h

@@ -45,7 +45,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
         BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
             BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                 BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
-                    BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
+                    BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                         // Will only return softmax if dropout, to reduce compilation time.
                         // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                         // If return_softmax, set IsEvenMNConst to false to reduce number of templates

@@ -86,7 +86,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
         BOOL_SWITCH(params.num_splits > 1, Split, [&] {
             BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
-                BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
+                BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                     // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
                     // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                     // If Is_local, set Is_causal to false
csrc/flash_attn/src/softmax.h

@@ -141,14 +141,12 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_
 template <bool HasWSLeft=true, typename Engine, typename Layout>
 inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
-                                        const int max_seqlen_k, const int row_idx_offset_,
+                                        const int max_seqlen_k, const int row_idx_offset,
                                         const int max_seqlen_q, const int warp_row_stride,
                                         const int window_size_left, const int window_size_right) {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
     const int lane_id = threadIdx.x % 32;
-    // const int row_idx_offset = row_idx_offset_ + lane_id / 4;
-    const int row_idx_offset = row_idx_offset_;
     const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
     #pragma unroll
     for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {

@@ -180,17 +178,17 @@ inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const in
 template <typename Engine, typename Layout>
 inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
-                                         const int max_seqlen_k, const int row_idx_offset_,
+                                         const int max_seqlen_k, const int row_idx_offset,
                                          const int max_seqlen_q, const int warp_row_stride) {
     // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
-    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_,
+    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
                                           max_seqlen_q, warp_row_stride, -1, 0);
 }

 template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
 inline __device__ void apply_mask_causal_w_idx(
     Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
-    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_)
+    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
 {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout0::rank == 2, "Only support 2D Tensor");

@@ -199,7 +197,7 @@ inline __device__ void apply_mask_causal_w_idx(
     CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
     #pragma unroll
     for (int mi = 0; mi < size<0>(tensor); ++mi) {
-        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
+        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
         #pragma unroll
         for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
             if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
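softmax.h only renames row_idx_offset_ to row_idx_offset here, but the surrounding masking logic is worth restating: as the comment above notes, causal masking is local masking with window (infinite, 0), and rows are aligned to the bottom-right of the key sequence. A rough Python sketch of the allowed column range per row under those assumptions (function name ours, not library code):

def allowed_cols(row_idx, max_seqlen_q, max_seqlen_k, window_left, window_right):
    # Queries are aligned to the bottom-right of the key sequence.
    shift = max_seqlen_k - max_seqlen_q
    hi = min(max_seqlen_k, row_idx + shift + window_right + 1)  # exclusive upper bound
    lo = 0 if window_left < 0 else max(0, row_idx + shift - window_left)
    return lo, hi

# Causal = window (-1 meaning "infinite" on the left, 0 on the right):
print(allowed_cols(row_idx=2, max_seqlen_q=4, max_seqlen_k=4, window_left=-1, window_right=0))  # (0, 3)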
flash_attn/flash_attn_interface.py
View file @
5ab9b366
...
@@ -53,12 +53,12 @@ def _flash_attn_forward(
...
@@ -53,12 +53,12 @@ def _flash_attn_forward(
k
,
k
,
v
,
v
,
None
,
None
,
alibi_slopes
,
dropout_p
,
dropout_p
,
softmax_scale
,
softmax_scale
,
causal
,
causal
,
window_size
[
0
],
window_size
[
0
],
window_size
[
1
],
window_size
[
1
],
alibi_slopes
,
return_softmax
,
return_softmax
,
None
,
None
,
)
)
...
@@ -90,6 +90,7 @@ def _flash_attn_varlen_forward(
...
@@ -90,6 +90,7 @@ def _flash_attn_varlen_forward(
cu_seqlens_q
,
cu_seqlens_q
,
cu_seqlens_k
,
cu_seqlens_k
,
None
,
None
,
alibi_slopes
,
max_seqlen_q
,
max_seqlen_q
,
max_seqlen_k
,
max_seqlen_k
,
dropout_p
,
dropout_p
,
...
@@ -98,7 +99,6 @@ def _flash_attn_varlen_forward(
...
@@ -98,7 +99,6 @@ def _flash_attn_varlen_forward(
causal
,
causal
,
window_size
[
0
],
window_size
[
0
],
window_size
[
1
],
window_size
[
1
],
alibi_slopes
,
return_softmax
,
return_softmax
,
None
,
None
,
)
)
...
@@ -137,12 +137,12 @@ def _flash_attn_backward(
...
@@ -137,12 +137,12 @@ def _flash_attn_backward(
dq
,
dq
,
dk
,
dk
,
dv
,
dv
,
alibi_slopes
,
dropout_p
,
dropout_p
,
softmax_scale
,
softmax_scale
,
causal
,
causal
,
window_size
[
0
],
window_size
[
0
],
window_size
[
1
],
window_size
[
1
],
alibi_slopes
,
None
,
None
,
rng_state
,
rng_state
,
)
)
...
@@ -185,6 +185,7 @@ def _flash_attn_varlen_backward(
...
@@ -185,6 +185,7 @@ def _flash_attn_varlen_backward(
dv
,
dv
,
cu_seqlens_q
,
cu_seqlens_q
,
cu_seqlens_k
,
cu_seqlens_k
,
alibi_slopes
,
max_seqlen_q
,
max_seqlen_q
,
max_seqlen_k
,
max_seqlen_k
,
dropout_p
,
dropout_p
,
...
@@ -193,7 +194,6 @@ def _flash_attn_varlen_backward(
...
@@ -193,7 +194,6 @@ def _flash_attn_varlen_backward(
causal
,
causal
,
window_size
[
0
],
window_size
[
0
],
window_size
[
1
],
window_size
[
1
],
alibi_slopes
,
None
,
None
,
rng_state
,
rng_state
,
)
)
...
@@ -613,6 +613,8 @@ def flash_attn_qkvpacked_func(
...
@@ -613,6 +613,8 @@ def flash_attn_qkvpacked_func(
Default to 1 / sqrt(headdim).
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
the attention score of query i and key j.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
(they might not have the right scaling).
...
@@ -673,6 +675,9 @@ def flash_attn_kvpacked_func(
...
@@ -673,6 +675,9 @@ def flash_attn_kvpacked_func(
Default to 1 / sqrt(headdim).
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
(they might not have the right scaling).
...
@@ -732,6 +737,9 @@ def flash_attn_func(
...
@@ -732,6 +737,9 @@ def flash_attn_func(
Default to 1 / sqrt(headdim).
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
(they might not have the right scaling).
...
@@ -780,6 +788,8 @@ def flash_attn_varlen_qkvpacked_func(
...
@@ -780,6 +788,8 @@ def flash_attn_varlen_qkvpacked_func(
Default to 1 / sqrt(headdim).
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
is added to the attention score of query i and key j.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
(they might not have the right scaling).
...
@@ -858,6 +868,9 @@ def flash_attn_varlen_kvpacked_func(
...
@@ -858,6 +868,9 @@ def flash_attn_varlen_kvpacked_func(
Default to 1 / sqrt(headdim).
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
(they might not have the right scaling).
...
@@ -938,6 +951,9 @@ def flash_attn_varlen_func(
...
@@ -938,6 +951,9 @@ def flash_attn_varlen_func(
Default to 1 / sqrt(headdim).
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
(they might not have the right scaling).
...
@@ -981,8 +997,8 @@ def flash_attn_with_kvcache(
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    rotary_interleaved=True,
-   num_splits=0,
-   alibi_slopes=None,
+   alibi_slopes=None,
+   num_splits=0,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
...
@@ -1050,6 +1066,9 @@ def flash_attn_with_kvcache(
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
+       alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+           (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+           is added to the attention score of query i and key j.
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
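For orientation, a minimal decoding-style usage sketch of this interface; the keyword names come from the signature and docstring in this diff, while the shapes and values are illustrative assumptions:

import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, max_seqlen = 2, 8, 128, 4096
q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)
k_new = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
v_new = torch.randn_like(k_new)
cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device="cuda")
alibi_slopes = torch.rand(batch, nheads, device="cuda", dtype=torch.float32) * 0.3

# k_new/v_new are appended in-place to the caches at position cache_seqlens, then
# attention is computed over the updated caches with the ALiBi bias described above.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k_new, v_new,
    cache_seqlens=cache_seqlens,
    causal=True,
    alibi_slopes=alibi_slopes,
    num_splits=0,  # 0 lets the heuristic pick the number of splits
)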
...
@@ -1080,6 +1099,7 @@ def flash_attn_with_kvcache(
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
+       alibi_slopes,
        None,
        softmax_scale,
        causal,
...
@@ -1087,6 +1107,5 @@ def flash_attn_with_kvcache(
        window_size[1],
        rotary_interleaved,
        num_splits,
-       alibi_slopes,
    )
    return out
tests/test_alibi.py deleted 100644 → 0
import math

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn import (
    flash_attn_func,
    flash_attn_kvpacked_func,
    flash_attn_qkvpacked_func,
    flash_attn_varlen_func,
    flash_attn_varlen_kvpacked_func,
    flash_attn_varlen_qkvpacked_func,
    flash_attn_with_kvcache,
)
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import _get_block_size
from flash_attn.flash_attn_triton import flash_attn_func as flash_attn_func_triton
from flash_attn.layers.rotary import apply_rotary_emb

MAX_HEADDIM_SM8x = 192

is_sm75 = torch.cuda.get_device_capability("cuda") == (7, 5)
is_sm8x = torch.cuda.get_device_capability("cuda")[0] == 8
is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0)
is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0)
def generate_alibi(
    max_seq_len,
    num_attention_heads,
    tp_world_size,
    tp_index,
    key_padding_mask=None,
    device="cuda",
):
    def get_slopes(n):
        def get_slopes_power_of_2(n):
            start = 2 ** (-2 ** -(math.log2(n) - 3))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return (
                get_slopes_power_of_2(closest_power_of_2)
                + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
            )

    slopes = torch.tensor(get_slopes(num_attention_heads)).to(device=device)
    # Select the part of the tensor that corresponds to our tensor parallel index.
    assert (
        num_attention_heads / tp_world_size
    ).is_integer(), "it works only when (num_attention_heads/tp_world_size) is integer"
    nh_tp = num_attention_heads // tp_world_size
    slopes = slopes[nh_tp * tp_index : nh_tp * (tp_index + 1)]

    if key_padding_mask is None:
        arange_tensor = rearrange(torch.arange(max_seq_len), "sqk -> 1 sqk").to(device=device)
    else:
        arange_tensor = (
            (key_padding_mask.cumsum(dim=-1, dtype=slopes.dtype) - 1)
            .masked_fill_(~key_padding_mask, torch.finfo(torch.float).min)
            .to(device=device)
        )
        arange_tensor = rearrange(arange_tensor, "b sqk -> b 1 1 sqk")

    # (1, nheads, 1, seqlen_k) or (batch, nheads, 1, seqlen_k)
    alibi_tensor = rearrange(slopes, "nh -> 1 nh 1 1") * arange_tensor
    return alibi_tensor, slopes
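# Usage sketch (illustrative shapes, not part of the test logic): generate_alibi returns
# both the full bias tensor used by the Triton reference and the per-head slopes that the
# tests below pass to flash_attn_func via alibi_slopes.
#
#     alibi_tensor, alibi_slopes = generate_alibi(
#         max_seq_len=512, num_attention_heads=16, tp_world_size=1, tp_index=0,
#         key_padding_mask=None, device="cuda",
#     )
#     # alibi_tensor broadcasts to (1, nheads, 1, seqlen_k); alibi_slopes has shape (nheads,).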
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", right_padding=True):
    assert mode in ["full", "random", "third"]
    if mode == "full":
        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    elif mode == "random":
        lengths = torch.randint(
            max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
        )
    elif mode == "third":
        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)

    if right_padding:
        padding_mask = (
            repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
        )
    else:
        padding_mask = (
            repeat(
                torch.arange(start=max_seqlen - 1, end=-1, step=-1, device=device),
                "s -> b s",
                b=batch_size,
            )
            < lengths
        )
    return padding_mask
def
generate_qkv
(
q
,
k
,
v
,
query_padding_mask
=
None
,
key_padding_mask
=
None
,
kvpacked
=
False
,
qkvpacked
=
False
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, d)
k: (batch_size, seqlen_k, nheads_k, d)
v: (batch_size, seqlen_k, nheads_k, d)
query_padding_mask: (batch_size, seqlen), bool
key_padding_mask: (batch_size, seqlen), bool
"""
assert
not
(
kvpacked
and
qkvpacked
)
batch_size
,
seqlen_q
,
nheads
,
d
=
q
.
shape
_
,
seqlen_k
,
nheads_k
,
_
=
k
.
shape
assert
k
.
shape
==
(
batch_size
,
seqlen_k
,
nheads_k
,
d
)
assert
v
.
shape
==
(
batch_size
,
seqlen_k
,
nheads_k
,
d
)
if
query_padding_mask
is
not
None
:
q_unpad
,
indices_q
,
cu_seqlens_q
,
max_seqlen_q
=
unpad_input
(
q
,
query_padding_mask
)
def
output_pad_fn
(
output_unpad
):
return
pad_input
(
output_unpad
,
indices_q
,
batch_size
,
seqlen_q
)
else
:
q_unpad
=
rearrange
(
q
,
"b s h d -> (b s) h d"
)
cu_seqlens_q
=
torch
.
arange
(
0
,
(
batch_size
+
1
)
*
seqlen_q
,
step
=
seqlen_q
,
dtype
=
torch
.
int32
,
device
=
q_unpad
.
device
)
max_seqlen_q
=
seqlen_q
def
output_pad_fn
(
output_unpad
):
return
rearrange
(
output_unpad
,
"(b s) h d -> b s h d"
,
b
=
batch_size
)
if
key_padding_mask
is
not
None
:
k_unpad
,
indices_k
,
cu_seqlens_k
,
max_seqlen_k
=
unpad_input
(
k
,
key_padding_mask
)
v_unpad
,
_
,
_
,
_
=
unpad_input
(
v
,
key_padding_mask
)
else
:
k_unpad
=
rearrange
(
k
,
"b s h d -> (b s) h d"
)
v_unpad
=
rearrange
(
v
,
"b s h d -> (b s) h d"
)
cu_seqlens_k
=
torch
.
arange
(
0
,
(
batch_size
+
1
)
*
seqlen_k
,
step
=
seqlen_k
,
dtype
=
torch
.
int32
,
device
=
k_unpad
.
device
)
max_seqlen_k
=
seqlen_k
if
qkvpacked
:
assert
(
query_padding_mask
==
key_padding_mask
).
all
()
assert
nheads
==
nheads_k
qkv_unpad
=
torch
.
stack
([
q_unpad
,
k_unpad
,
v_unpad
],
dim
=
1
)
qkv
=
torch
.
stack
([
q
,
k
,
v
],
dim
=
2
)
if
query_padding_mask
is
not
None
:
def
dqkv_pad_fn
(
dqkv_unpad
):
return
pad_input
(
dqkv_unpad
,
indices_q
,
batch_size
,
seqlen_q
)
else
:
def
dqkv_pad_fn
(
dqkv_unpad
):
return
rearrange
(
dqkv_unpad
,
"(b s) t h d -> b s t h d"
,
b
=
batch_size
)
return
(
qkv_unpad
.
detach
().
requires_grad_
(),
cu_seqlens_q
,
max_seqlen_q
,
qkv
.
detach
().
requires_grad_
(),
output_pad_fn
,
dqkv_pad_fn
,
)
elif
kvpacked
:
kv_unpad
=
torch
.
stack
([
k_unpad
,
v_unpad
],
dim
=
1
)
kv
=
torch
.
stack
([
k
,
v
],
dim
=
2
)
dq_pad_fn
=
output_pad_fn
if
key_padding_mask
is
not
None
:
def
dkv_pad_fn
(
dkv_unpad
):
return
pad_input
(
dkv_unpad
,
indices_k
,
batch_size
,
seqlen_k
)
else
:
def
dkv_pad_fn
(
dkv_unpad
):
return
rearrange
(
dkv_unpad
,
"(b s) t h d -> b s t h d"
,
b
=
batch_size
)
return
(
q_unpad
.
detach
().
requires_grad_
(),
kv_unpad
.
detach
().
requires_grad_
(),
cu_seqlens_q
,
cu_seqlens_k
,
max_seqlen_q
,
max_seqlen_k
,
q
.
detach
().
requires_grad_
(),
kv
.
detach
().
requires_grad_
(),
output_pad_fn
,
dq_pad_fn
,
dkv_pad_fn
,
)
else
:
dq_pad_fn
=
output_pad_fn
if
key_padding_mask
is
not
None
:
def
dk_pad_fn
(
dk_unpad
):
return
pad_input
(
dk_unpad
,
indices_k
,
batch_size
,
seqlen_k
)
else
:
def
dk_pad_fn
(
dk_unpad
):
return
rearrange
(
dk_unpad
,
"(b s) h d -> b s h d"
,
b
=
batch_size
)
return
(
q_unpad
.
detach
().
requires_grad_
(),
k_unpad
.
detach
().
requires_grad_
(),
v_unpad
.
detach
().
requires_grad_
(),
cu_seqlens_q
,
cu_seqlens_k
,
max_seqlen_q
,
max_seqlen_k
,
q
.
detach
().
requires_grad_
(),
k
.
detach
().
requires_grad_
(),
v
.
detach
().
requires_grad_
(),
output_pad_fn
,
dq_pad_fn
,
dk_pad_fn
,
)
def
construct_local_mask
(
seqlen_q
,
seqlen_k
,
window_size
=
(
-
1
,
-
1
),
# -1 means infinite window size
query_padding_mask
=
None
,
key_padding_mask
=
None
,
device
=
None
,
):
row_idx
=
rearrange
(
torch
.
arange
(
seqlen_q
,
device
=
device
,
dtype
=
torch
.
long
),
"s -> s 1"
)
col_idx
=
torch
.
arange
(
seqlen_k
,
device
=
device
,
dtype
=
torch
.
long
)
sk
=
(
seqlen_k
if
key_padding_mask
is
None
else
rearrange
(
key_padding_mask
.
sum
(
-
1
),
"b -> b 1 1 1"
)
)
sq
=
(
seqlen_q
if
query_padding_mask
is
None
else
rearrange
(
query_padding_mask
.
sum
(
-
1
),
"b -> b 1 1 1"
)
)
if
window_size
[
0
]
<
0
:
return
col_idx
>
row_idx
+
sk
-
sq
+
window_size
[
1
]
else
:
sk
=
torch
.
full_like
(
col_idx
,
seqlen_k
)
if
key_padding_mask
is
None
else
sk
return
torch
.
logical_or
(
col_idx
>
torch
.
minimum
(
row_idx
+
sk
-
sq
+
window_size
[
1
],
sk
),
col_idx
<
row_idx
+
sk
-
sq
-
window_size
[
0
],
)
def
attention_ref
(
q
,
k
,
v
,
query_padding_mask
=
None
,
key_padding_mask
=
None
,
dropout_p
=
0.0
,
dropout_mask
=
None
,
causal
=
False
,
window_size
=
(
-
1
,
-
1
),
# -1 means infinite window size
upcast
=
True
,
reorder_ops
=
False
,
bias
=
None
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k: (batch_size, seqlen_k, nheads_k, head_dim)
v: (batch_size, seqlen_k, nheads_k, head_dim)
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
dropout_p: float
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
causal: whether to apply causal masking
window_size: (int, int), left and right window size
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
output back to fp16/bf16.
reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.)
without changing the math. This is to estimate the numerical error from operation
reordering.
Output:
output: (batch_size, seqlen_q, nheads, head_dim)
attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
"""
if
causal
:
window_size
=
(
window_size
[
0
],
0
)
dtype_og
=
q
.
dtype
if
upcast
:
q
,
k
,
v
=
q
.
float
(),
k
.
float
(),
v
.
float
()
seqlen_q
,
seqlen_k
=
q
.
shape
[
1
],
k
.
shape
[
1
]
k
=
repeat
(
k
,
"b s h d -> b s (h g) d"
,
g
=
q
.
shape
[
2
]
//
k
.
shape
[
2
])
v
=
repeat
(
v
,
"b s h d -> b s (h g) d"
,
g
=
q
.
shape
[
2
]
//
v
.
shape
[
2
])
d
=
q
.
shape
[
-
1
]
if
not
reorder_ops
:
scores
=
torch
.
einsum
(
"bthd,bshd->bhts"
,
q
/
math
.
sqrt
(
d
),
k
)
else
:
scores
=
torch
.
einsum
(
"bthd,bshd->bhts"
,
q
,
k
/
math
.
sqrt
(
d
))
if
bias
is
not
None
:
bias
=
bias
.
to
(
scores
.
dtype
)
scores
+=
bias
if
key_padding_mask
is
not
None
:
scores
.
masked_fill_
(
rearrange
(
~
key_padding_mask
,
"b s -> b 1 1 s"
),
float
(
"-inf"
))
if
window_size
[
0
]
>=
0
or
window_size
[
1
]
>=
0
:
local_mask
=
construct_local_mask
(
seqlen_q
,
seqlen_k
,
window_size
,
query_padding_mask
,
key_padding_mask
,
q
.
device
,
)
scores
.
masked_fill_
(
local_mask
,
float
(
"-inf"
))
attention
=
torch
.
softmax
(
scores
,
dim
=-
1
)
# Some rows might be completely masked out so we fill them with zero instead of NaN
if
window_size
[
0
]
>=
0
or
window_size
[
1
]
>=
0
:
attention
=
attention
.
masked_fill
(
torch
.
all
(
local_mask
,
dim
=-
1
,
keepdim
=
True
),
0.0
)
# We want to mask here so that the attention matrix doesn't have any NaNs
# Otherwise we'll get NaN in dV
if
query_padding_mask
is
not
None
:
attention
=
attention
.
masked_fill
(
rearrange
(
~
query_padding_mask
,
"b s -> b 1 s 1"
),
0.0
)
dropout_scaling
=
1.0
/
(
1
-
dropout_p
)
# attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
# output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
if
dropout_mask
is
not
None
:
attention_drop
=
attention
.
masked_fill
(
~
dropout_mask
,
0.0
)
else
:
attention_drop
=
attention
output
=
torch
.
einsum
(
"bhts,bshd->bthd"
,
attention_drop
,
v
*
dropout_scaling
)
if
query_padding_mask
is
not
None
:
output
.
masked_fill_
(
rearrange
(
~
query_padding_mask
,
"b s -> b s 1 1"
),
0.0
)
return
output
.
to
(
dtype
=
dtype_og
),
attention
.
to
(
dtype
=
dtype_og
)
def
attention_kvpacked_ref
(
q
,
kv
,
query_padding_mask
=
None
,
key_padding_mask
=
None
,
dropout_p
=
0.0
,
dropout_mask
=
None
,
causal
=
False
,
window_size
=
(
-
1
,
-
1
),
# -1 means infinite window size
upcast
=
True
,
reorder_ops
=
False
,
):
return
attention_ref
(
q
,
kv
[:,
:,
0
],
kv
[:,
:,
1
],
query_padding_mask
,
key_padding_mask
,
dropout_p
,
dropout_mask
,
upcast
=
upcast
,
causal
=
causal
,
window_size
=
window_size
,
reorder_ops
=
reorder_ops
,
)
def
attention_qkvpacked_ref
(
qkv
,
key_padding_mask
=
None
,
dropout_p
=
0.0
,
dropout_mask
=
None
,
causal
=
False
,
window_size
=
(
-
1
,
-
1
),
# -1 means infinite window size
upcast
=
True
,
reorder_ops
=
False
,
):
return
attention_ref
(
qkv
[:,
:,
0
],
qkv
[:,
:,
1
],
qkv
[:,
:,
2
],
key_padding_mask
,
key_padding_mask
,
dropout_p
,
dropout_mask
,
upcast
=
upcast
,
causal
=
causal
,
window_size
=
window_size
,
reorder_ops
=
reorder_ops
,
)
def
generate_sparsity_mask
(
seqlen
,
sparsity
=
0.3
):
repeats
=
seqlen
//
16
//
2
# mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda'),
# torch.tensor([0, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
# mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda'),
# torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
# mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
# mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
nrow
,
ncol
=
seqlen
//
16
,
seqlen
//
256
mask
=
torch
.
rand
(
nrow
,
ncol
,
device
=
"cuda"
)
<
sparsity
return
mask
def
attention_blocksparse_ref
(
qkv
,
blockmask
,
attn_mask
,
dropout_p
,
dropout_mask
):
"""
Arguments:
qkv: (batch_size, seqlen, 3, nheads, head_dim)
blockmask: (seqlen / 16, seqlen / 256)
attn_mask: (batch_size, seqlen)
dropout_p: float
dropout_mask: (batch_size, nheads, seqlen, seqlen)
Output:
output: (batch_size, seqlen, nheads, head_dim)
attention: softmax after dropout
"""
q
,
k
,
v
=
qkv
.
float
().
unbind
(
dim
=
2
)
d
=
qkv
.
shape
[
-
1
]
seqlen
=
qkv
.
shape
[
1
]
scores
=
torch
.
einsum
(
"bthd,bshd->bhts"
,
q
/
math
.
sqrt
(
d
),
k
)
scores
.
masked_fill_
(
rearrange
(
~
attn_mask
,
"b s -> b 1 1 s"
),
float
(
"-inf"
))
blockmask
=
repeat
(
blockmask
,
"s_16 s_256 -> (s_16 16) (s_256 256)"
)
blockmask
=
blockmask
[:
seqlen
,
:
seqlen
]
scores
.
masked_fill_
(
rearrange
(
~
blockmask
,
"t s -> 1 1 t s"
),
float
(
"-inf"
))
attention
=
torch
.
softmax
(
scores
,
dim
=-
1
)
attention
=
attention
.
masked_fill
(
rearrange
(
~
attn_mask
,
"b s -> b 1 s 1"
),
0.0
)
attention
=
attention
.
masked_fill_
(
rearrange
(
~
blockmask
,
"t s -> 1 1 t s"
),
0.0
)
attention_drop
=
attention
.
masked_fill
(
~
dropout_mask
,
0.0
)
/
(
1
-
dropout_p
)
output
=
torch
.
einsum
(
"bhts,bshd->bthd"
,
attention_drop
,
v
)
output
.
masked_fill_
(
rearrange
(
~
attn_mask
,
"b s -> b s 1 1"
),
0
)
return
output
.
to
(
dtype
=
qkv
.
dtype
),
attention
.
to
(
dtype
=
qkv
.
dtype
)
def
convert_flash_attn_S_to_softmax
(
S
,
seqlen_q
,
seqlen_k
,
query_padding_mask
,
key_padding_mask
,
head_dim
,
is_dropout
,
causal
=
False
,
window_size
=
(
-
1
,
-
1
),
# -1 means infinite window size
):
"""FlashAttention stores the S matrix in a different way.
Arguments:
S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)
query_padding_mask: (batch_size, seqlen_q_rounded)
key_padding_mask: (batch_size, seqlen_k_rounded)
"""
if
causal
:
window_size
=
(
window_size
[
0
],
0
)
seqlen_q_rounded
,
seqlen_k_rounded
=
S
.
shape
[
-
2
:]
warps_n
=
4
blocksize_m
,
blocksize_n
=
_get_block_size
(
S
.
device
,
head_dim
,
is_dropout
,
causal
)
nblocks_n
=
(
seqlen_k_rounded
+
blocksize_n
-
1
)
//
blocksize_n
nblocks_m
=
(
seqlen_q_rounded
+
blocksize_m
-
1
)
//
blocksize_m
mmas_n
=
(
blocksize_n
+
16
-
1
)
//
16
S_flat
=
rearrange
(
S
,
"b h (nblocks_m blocksize_m) (nblocks_n blocksize_n) -> b h nblocks_m nblocks_n (blocksize_m blocksize_n)"
,
blocksize_m
=
blocksize_m
,
blocksize_n
=
blocksize_n
,
)
S_converted
=
rearrange
(
S_flat
,
"b h nblocks_m nblocks_n (mmas_n mmas_m warps_n eight four c2 c1 c0) -> b h (nblocks_m mmas_m warps_n c1 eight) (nblocks_n mmas_n c2 four c0)"
,
mmas_n
=
mmas_n
,
warps_n
=
warps_n
,
eight
=
8
,
c0
=
2
,
c1
=
2
,
c2
=
2
,
four
=
4
,
)
if
window_size
[
0
]
>=
0
or
window_size
[
1
]
>=
0
:
local_mask
=
construct_local_mask
(
seqlen_q
,
seqlen_k
,
window_size
,
query_padding_mask
,
key_padding_mask
,
S
.
device
,
)
local_mask
=
F
.
pad
(
local_mask
,
(
0
,
seqlen_k_rounded
-
seqlen_k
,
0
,
seqlen_q_rounded
-
seqlen_q
),
value
=
True
,
)
S_converted
.
masked_fill_
(
local_mask
,
0.0
)
# Need to zero out things not in attention_mask in case S was initialized with random values
# and some of those values aren't overwritten.
seqlen_q_og
=
(
query_padding_mask
.
shape
[
-
1
]
if
query_padding_mask
is
not
None
else
seqlen_q_rounded
)
if
query_padding_mask
is
not
None
:
query_padding_mask
=
F
.
pad
(
query_padding_mask
,
(
0
,
seqlen_q_rounded
-
seqlen_q_og
))
S_converted
=
S_converted
.
masked_fill
(
rearrange
(
~
query_padding_mask
,
"b s -> b 1 s 1"
),
0.0
)
seqlen_k_og
=
key_padding_mask
.
shape
[
-
1
]
if
key_padding_mask
is
not
None
else
seqlen_k
if
key_padding_mask
is
not
None
:
key_padding_mask
=
F
.
pad
(
key_padding_mask
,
(
0
,
seqlen_k_rounded
-
seqlen_k_og
))
S_converted
=
S_converted
.
masked_fill
(
rearrange
(
~
key_padding_mask
,
"b s -> b 1 1 s"
),
0.0
)
S_converted
=
F
.
pad
(
S_converted
,
(
0
,
0
,
0
,
seqlen_q_og
-
seqlen_q_rounded
))
S_converted
=
F
.
pad
(
S_converted
,
(
0
,
seqlen_k_og
-
seqlen_k_rounded
))
return
S_converted
[:,
:,
:
seqlen_q
,
:
seqlen_k
]
def
normalize_flash_attn_S
(
attn_unnorm
,
q
,
k
,
v
,
query_padding_mask
=
None
,
key_padding_mask
=
None
,
is_dropout
=
False
,
causal
=
False
,
window_size
=
(
-
1
,
-
1
),
# -1 means infinite window size
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k, v: (batch_size, seqlen_k, nheads, head_dim)
key_padding_mask: (batch_size, seqlen_q)
Output:
softmax_lse: (batch_size, nheads, seqlen_q)
softmax_max: (batch_size, nheads, seqlen_q)
"""
if
causal
:
window_size
=
(
window_size
[
0
],
0
)
q
,
k
,
v
=
q
.
float
(),
k
.
float
(),
v
.
float
()
_
,
seqlen_q
,
_
,
head_dim
=
q
.
shape
seqlen_k
=
k
.
shape
[
1
]
scores
=
torch
.
einsum
(
"bthd,bshd->bhts"
,
q
/
math
.
sqrt
(
head_dim
),
k
)
if
key_padding_mask
is
not
None
:
scores
.
masked_fill_
(
rearrange
(
~
key_padding_mask
,
"b s -> b 1 1 s"
),
float
(
"-inf"
))
if
window_size
[
0
]
>=
0
or
window_size
[
1
]
>=
0
:
local_mask
=
construct_local_mask
(
seqlen_q
,
seqlen_k
,
window_size
,
query_padding_mask
,
key_padding_mask
,
q
.
device
,
)
scores
.
masked_fill_
(
local_mask
,
float
(
"-inf"
))
_
,
block_size_n
=
_get_block_size
(
scores
.
device
,
head_dim
,
is_dropout
,
causal
)
scores_block
=
scores
.
split
(
block_size_n
,
dim
=-
1
)
lse_block
=
torch
.
stack
([
torch
.
logsumexp
(
s
,
dim
=-
1
)
for
s
in
scores_block
],
dim
=-
1
)
lse
=
torch
.
logsumexp
(
lse_block
,
dim
=-
1
)
# lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf
# so that when we do torch.exp(m - lse), we get 0.0 instead of NaN.
lse
[
lse
==
float
(
"-inf"
)]
=
float
(
"inf"
)
scores_max_block
=
torch
.
stack
(
[
torch
.
amax
(
s
,
dim
=-
1
)
for
s
in
scores_block
],
dim
=-
1
)
cummax_block
=
torch
.
cummax
(
scores_max_block
.
flip
(
-
1
),
dim
=-
1
).
values
.
flip
(
-
1
).
unbind
(
dim
=-
1
)
attn_unnorm_block
=
attn_unnorm
.
split
(
block_size_n
,
dim
=-
1
)
attn_norm
=
torch
.
cat
(
[
a
*
rearrange
(
torch
.
exp
(
m
-
lse
),
"b h s -> b h s 1"
)
for
a
,
m
in
zip
(
attn_unnorm_block
,
cummax_block
)
],
dim
=-
1
,
)
if
query_padding_mask
is
not
None
:
attn_norm
.
masked_fill_
(
rearrange
(
~
query_padding_mask
,
"b s -> b 1 s 1"
),
0.0
)
return
attn_norm
.
to
(
dtype
=
attn_unnorm
.
dtype
)
def
get_dropout_fraction
(
dropout_mask
,
query_padding_mask
=
None
,
key_padding_mask
=
None
,
causal
=
False
,
window_size
=
(
-
1
,
-
1
),
# -1 means infinite window size
):
"""
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop.
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
"""
if
causal
:
window_size
=
(
window_size
[
0
],
0
)
batch_size
,
nheads
,
seqlen_q
,
seqlen_k
=
dropout_mask
.
shape
dropped
=
~
dropout_mask
valid
=
torch
.
ones_like
(
dropout_mask
)
if
query_padding_mask
is
not
None
:
dropped
.
masked_fill_
(
rearrange
(
~
query_padding_mask
,
"b s -> b 1 s 1"
),
False
)
valid
.
masked_fill_
(
rearrange
(
~
query_padding_mask
,
"b s -> b 1 s 1"
),
False
)
if
key_padding_mask
is
not
None
:
dropped
.
masked_fill_
(
rearrange
(
~
key_padding_mask
,
"b s -> b 1 1 s"
),
False
)
valid
.
masked_fill_
(
rearrange
(
~
key_padding_mask
,
"b s -> b 1 1 s"
),
False
)
if
window_size
[
0
]
>=
0
or
window_size
[
1
]
>=
0
:
local_mask
=
construct_local_mask
(
seqlen_q
,
seqlen_k
,
window_size
,
query_padding_mask
,
key_padding_mask
,
dropout_mask
.
device
,
)
dropped
.
masked_fill_
(
local_mask
,
False
)
valid
.
masked_fill_
(
local_mask
,
False
)
dropped_total
=
dropped
.
sum
()
return
dropped
.
sum
()
/
valid
.
sum
()
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
]
)
@
pytest
.
mark
.
parametrize
(
"b_sq"
,
[
(
32
,
512
),
(
16
,
1024
),
(
8
,
2048
),
(
4
,
4096
),
(
2
,
8192
),
(
1
,
16384
)
]
)
@
pytest
.
mark
.
parametrize
(
"nh_hd"
,
[
(
32
,
64
),
(
16
,
128
),
(
40
,
128
)
# non power of 2 nh
]
)
@
pytest
.
mark
.
parametrize
(
"tp_world_size"
,
[
1
,
2
,
4
]
)
def
test_flash_attn_func
(
b_sq
,
nh_hd
,
tp_world_size
,
dtype
):
b
,
sq
=
b_sq
nh
,
hd
=
nh_hd
nh_tp
=
nh
//
tp_world_size
q
,
k
,
v
=
[
torch
.
randn
(
b
,
sq
,
nh_tp
,
hd
,
device
=
"cuda"
,
dtype
=
dtype
,
requires_grad
=
True
)
for
_
in
range
(
3
)]
dout
=
torch
.
rand_like
(
q
)
for
tp_index
in
range
(
tp_world_size
):
alibi
,
alibi_slopes
=
generate_alibi
(
max_seq_len
=
sq
,
num_attention_heads
=
nh
,
tp_world_size
=
tp_world_size
,
tp_index
=
tp_index
,
key_padding_mask
=
None
,
device
=
"cuda"
)
triton_out
=
flash_attn_func_triton
(
q
,
k
,
v
,
alibi
,
True
,
hd
**
(
-
0.5
))
triton_out
.
backward
(
dout
)
triton_dq
,
q
.
grad
=
q
.
grad
.
clone
(),
None
triton_dk
,
k
.
grad
=
k
.
grad
.
clone
(),
None
triton_dv
,
v
.
grad
=
v
.
grad
.
clone
(),
None
flash_out
=
flash_attn_func
(
q
,
k
,
v
,
causal
=
True
,
alibi_slopes
=
repeat
(
alibi_slopes
,
"nh -> b nh"
,
b
=
b
))
flash_out
.
backward
(
dout
)
flash_dq
,
q
.
grad
=
q
.
grad
.
clone
(),
None
flash_dk
,
k
.
grad
=
k
.
grad
.
clone
(),
None
flash_dv
,
v
.
grad
=
v
.
grad
.
clone
(),
None
assert
torch
.
allclose
(
flash_out
,
triton_out
,
atol
=
1e-2
,
rtol
=
0.
)
assert
torch
.
allclose
(
flash_dq
,
triton_dq
,
atol
=
1e-2
,
rtol
=
0.
)
assert
torch
.
allclose
(
flash_dk
,
triton_dk
,
atol
=
1e-2
,
rtol
=
0.
)
assert
torch
.
allclose
(
flash_dv
,
triton_dv
,
atol
=
1e-2
,
rtol
=
0.
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
]
)
@
pytest
.
mark
.
parametrize
(
"right_padding"
,
[
True
,
False
]
)
@
pytest
.
mark
.
parametrize
(
"b_sq"
,
[
(
32
,
512
),
(
16
,
1024
),
(
8
,
2048
),
(
4
,
4096
),
(
2
,
8192
),
(
1
,
16384
)
]
)
@
pytest
.
mark
.
parametrize
(
"nh_hd"
,
[
(
32
,
64
),
(
16
,
128
),
(
40
,
128
)
# non power of 2 nh
]
)
@
pytest
.
mark
.
parametrize
(
"tp_world_size"
,
[
1
,
2
,
4
]
)
def
test_flash_attn_varlen_func
(
b_sq
,
nh_hd
,
tp_world_size
,
right_padding
,
dtype
):
b
,
sqk
=
b_sq
nh
,
hd
=
nh_hd
nh_tp
=
nh
//
tp_world_size
# flash_attn_func_triton(), flash-attention v2 (above v2.1) causal logic are different
# so only (seqlen_q == 1, causal=False to triton ver.) shows correct results
# https://github.com/huggingface/text-generation-inference/blob/v1.1.1/server/text_generation_server/models/custom_modeling/mpt_modeling.py#L53-L63
q
=
torch
.
randn
(
b
,
1
,
nh_tp
,
hd
,
device
=
"cuda"
,
dtype
=
dtype
,
requires_grad
=
True
)
k
,
v
=
[
torch
.
randn
(
b
,
sqk
,
nh_tp
,
hd
,
device
=
"cuda"
,
dtype
=
dtype
,
requires_grad
=
True
)
for
_
in
range
(
2
)]
dout
=
torch
.
rand_like
(
q
)
padding_mask
=
generate_random_padding_mask
(
sqk
,
b
,
"cuda"
,
"random"
,
right_padding
)
(
q_unpad
,
k_unpad
,
v_unpad
,
cu_seqlens_q
,
cu_seqlens_k
,
max_seqlen_q
,
max_seqlen_k
,
q
,
k
,
v
,
output_pad_fn
,
dq_pad_fn
,
dk_pad_fn
,
)
=
generate_qkv
(
q
,
k
,
v
,
None
,
padding_mask
,
kvpacked
=
False
)
for
tp_index
in
range
(
tp_world_size
):
alibi
,
alibi_slopes
=
generate_alibi
(
max_seq_len
=
sqk
,
num_attention_heads
=
nh
,
tp_world_size
=
tp_world_size
,
tp_index
=
tp_index
,
key_padding_mask
=
padding_mask
,
device
=
"cuda"
)
triton_out
=
flash_attn_func_triton
(
q
,
k
,
v
,
alibi
,
False
,
hd
**
(
-
0.5
))
triton_out
.
backward
(
dout
)
triton_dq
,
q
.
grad
=
q
.
grad
.
clone
(),
None
triton_dk
,
k
.
grad
=
k
.
grad
.
clone
(),
None
triton_dv
,
v
.
grad
=
v
.
grad
.
clone
(),
None
flash_out_unpad
=
flash_attn_varlen_func
(
q_unpad
,
k_unpad
,
v_unpad
,
cu_seqlens_q
,
cu_seqlens_k
,
max_seqlen_q
,
max_seqlen_k
,
causal
=
True
,
alibi_slopes
=
repeat
(
alibi_slopes
,
"nh -> b nh"
,
b
=
b
)
)
flash_out
=
output_pad_fn
(
flash_out_unpad
)
flash_out
.
backward
(
dout
)
flash_dq_unpad
,
q_unpad
.
grad
=
q_unpad
.
grad
.
clone
(),
None
flash_dk_unpad
,
k_unpad
.
grad
=
k_unpad
.
grad
.
clone
(),
None
flash_dv_unpad
,
v_unpad
.
grad
=
v_unpad
.
grad
.
clone
(),
None
flash_dq
=
dq_pad_fn
(
flash_dq_unpad
)
flash_dk
=
dk_pad_fn
(
flash_dk_unpad
)
flash_dv
=
dk_pad_fn
(
flash_dv_unpad
)
assert
torch
.
allclose
(
flash_out
,
triton_out
,
atol
=
1e-2
,
rtol
=
0.
)
assert
torch
.
allclose
(
flash_dq
,
triton_dq
,
atol
=
1e-2
,
rtol
=
0.
)
assert
torch
.
allclose
(
flash_dk
,
triton_dk
,
atol
=
1e-2
,
rtol
=
0.
)
assert
torch
.
allclose
(
flash_dv
,
triton_dv
,
atol
=
1e-2
,
rtol
=
0.
)
@
pytest
.
mark
.
parametrize
(
"alibi"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
([
torch
.
float16
]
if
is_sm75
else
[
torch
.
float16
,
torch
.
bfloat16
]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@
pytest
.
mark
.
parametrize
(
"num_splits"
,
[
1
,
0
])
# @pytest.mark.parametrize("num_splits", [0])
@
pytest
.
mark
.
parametrize
(
"mha_type"
,
[
"mha"
,
"mqa"
,
"gqa"
])
# @pytest.mark.parametrize("mha_type", ["mha"])
@
pytest
.
mark
.
parametrize
(
"new_kv"
,
[
False
,
True
])
# @pytest.mark.parametrize("new_kv", [True])
# @pytest.mark.parametrize("local", [False, True])
@
pytest
.
mark
.
parametrize
(
"local"
,
[
False
])
# @pytest.mark.parametrize("causal", [False, True])
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"seqlen_new_eq_seqlen_q"
,
[
True
,
False
])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
@
pytest
.
mark
.
parametrize
(
"rotary_interleaved"
,
[
False
,
True
])
# @pytest.mark.parametrize("rotary_interleaved", [False])
@
pytest
.
mark
.
parametrize
(
"rotary_fraction"
,
[
0.0
,
0.5
,
1.0
])
# @pytest.mark.parametrize("rotary_fraction", [0.0])
@
pytest
.
mark
.
parametrize
(
"has_batch_idx"
,
[
False
,
True
])
# @pytest.mark.parametrize("has_batch_idx", [True])
@
pytest
.
mark
.
parametrize
(
"d"
,
[
32
,
59
,
64
,
80
,
96
,
128
,
160
,
192
,
224
,
256
])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [128])
@
pytest
.
mark
.
parametrize
(
"seqlen_q,seqlen_k"
,
[
(
1
,
128
),
(
1
,
339
),
(
3
,
1024
),
(
64
,
800
),
(
64
,
256
),
(
3
,
799
),
(
64
,
2048
),
(
16
,
20000
),
(
1
,
128
*
1024
),
(
16
,
128
*
1024
),
(
128
,
128
),
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def
test_flash_attn_kvcache
(
seqlen_q
,
seqlen_k
,
d
,
has_batch_idx
,
rotary_fraction
,
rotary_interleaved
,
seqlen_new_eq_seqlen_q
,
causal
,
local
,
new_kv
,
mha_type
,
num_splits
,
dtype
,
alibi
,
):
if
seqlen_q
>
seqlen_k
and
new_kv
:
pytest
.
skip
()
if
not
new_kv
and
rotary_fraction
>
0.0
:
pytest
.
skip
()
device
=
"cuda"
# set seed
torch
.
random
.
manual_seed
(
0
)
batch_size
=
2
batch_size_cache
=
batch_size
if
not
has_batch_idx
else
batch_size
*
2
nheads
=
8
# rotary_dim must be a multiple of 16, and must be <= d
rotary_dim
=
math
.
floor
(
int
(
rotary_fraction
*
d
)
/
16
)
*
16
nheads_k
=
nheads
if
mha_type
==
"mha"
else
(
1
if
mha_type
==
"mqa"
else
4
)
assert
nheads
%
nheads_k
==
0
window_size
=
(
-
1
,
-
1
)
if
not
local
else
torch
.
randint
(
0
,
seqlen_k
,
(
2
,))
q
=
torch
.
randn
(
batch_size
,
seqlen_q
,
nheads
,
d
,
device
=
device
,
dtype
=
dtype
)
seqlen_new
=
seqlen_q
if
seqlen_new_eq_seqlen_q
else
torch
.
randint
(
1
,
seqlen_q
+
1
,
(
1
,)).
item
()
if
new_kv
:
k
=
torch
.
randn
(
batch_size
,
seqlen_new
,
nheads_k
,
d
,
device
=
device
,
dtype
=
dtype
)
v
=
torch
.
randn
(
batch_size
,
seqlen_new
,
nheads_k
,
d
,
device
=
device
,
dtype
=
dtype
)
else
:
k
,
v
=
None
,
None
k_cache
=
torch
.
randn
(
batch_size_cache
,
seqlen_k
,
nheads_k
,
d
,
device
=
device
,
dtype
=
dtype
)
v_cache
=
torch
.
randn
(
batch_size_cache
,
seqlen_k
,
nheads_k
,
d
,
device
=
device
,
dtype
=
dtype
)
cache_seqlens
=
torch
.
randint
(
0
,
# If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
(
seqlen_k
-
(
seqlen_q
if
(
causal
or
local
)
and
rotary_dim
>
1
else
seqlen_new
)
+
1
)
if
new_kv
else
(
seqlen_k
+
1
),
(
batch_size
,),
dtype
=
torch
.
int32
,
device
=
device
,
)
if
has_batch_idx
:
cache_batch_idx
=
torch
.
randperm
(
batch_size_cache
,
dtype
=
torch
.
int32
,
device
=
device
)[:
batch_size
]
else
:
cache_batch_idx
=
None
# cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
if
rotary_dim
>
0
:
angle
=
torch
.
rand
(
seqlen_k
,
rotary_dim
//
2
,
device
=
device
)
*
2
*
math
.
pi
cos
=
torch
.
cos
(
angle
).
to
(
dtype
=
dtype
)
sin
=
torch
.
sin
(
angle
).
to
(
dtype
=
dtype
)
if
causal
or
local
:
q_ro
=
apply_rotary_emb
(
q
,
cos
,
sin
,
seqlen_offsets
=
cache_seqlens
,
interleaved
=
rotary_interleaved
)
else
:
q_ro
=
rearrange
(
apply_rotary_emb
(
rearrange
(
q
,
"b s h d -> b 1 (s h) d"
),
cos
,
sin
,
seqlen_offsets
=
cache_seqlens
,
interleaved
=
rotary_interleaved
,
),
"b 1 (s h) d -> b s h d"
,
s
=
seqlen_q
,
)
# q_ro = q
k_ro
=
apply_rotary_emb
(
k
,
cos
,
sin
,
seqlen_offsets
=
cache_seqlens
,
interleaved
=
rotary_interleaved
)
else
:
cos
,
sin
=
None
,
None
q_ro
,
k_ro
=
q
,
k
# k_cache[:, 64:] = -1
k_cache_ref
=
(
k_cache
if
not
has_batch_idx
else
k_cache
[
cache_batch_idx
]).
clone
()
v_cache_ref
=
(
v_cache
if
not
has_batch_idx
else
v_cache
[
cache_batch_idx
]).
clone
()
arange
=
rearrange
(
torch
.
arange
(
seqlen_k
,
device
=
device
),
"s -> 1 s"
)
cache_seqlens_expanded
=
rearrange
(
cache_seqlens
,
"b -> b 1"
)
if
new_kv
:
update_mask
=
torch
.
logical_and
(
cache_seqlens_expanded
<=
arange
,
arange
<
cache_seqlens_expanded
+
seqlen_new
)
k_cache_ref
[
update_mask
]
=
rearrange
(
k_ro
,
"b s ... -> (b s) ..."
)
v_cache_ref
[
update_mask
]
=
rearrange
(
v
,
"b s ... -> (b s) ..."
)
k_cache_rep
=
repeat
(
k_cache_ref
,
"b s h d -> b s (h g) d"
,
g
=
nheads
//
nheads_k
)
v_cache_rep
=
repeat
(
v_cache_ref
,
"b s h d -> b s (h g) d"
,
g
=
nheads
//
nheads_k
)
if
alibi
:
seqlen_alibi
=
k_cache_rep
.
shape
[
1
]
alibi_tensor
,
alibi_slopes
=
generate_alibi
(
max_seq_len
=
seqlen_alibi
,
num_attention_heads
=
nheads
,
tp_world_size
=
1
,
tp_index
=
0
,
key_padding_mask
=
None
,
device
=
"cuda"
)
# alibi_tensor = alibi_tensor.expand(batch_size, -1, seqlen_q, -1)
alibi_slopes
=
repeat
(
alibi_slopes
,
"nh -> b nh"
,
b
=
batch_size
)
if
alibi_tensor
.
abs
().
max
().
item
()
>=
torch
.
finfo
(
dtype
).
max
:
pytest
.
skip
()
else
:
alibi_tensor
,
alibi_slopes
=
None
,
None
out
=
flash_attn_with_kvcache
(
q
,
k_cache
,
v_cache
,
k
,
v
,
cos
,
sin
,
cache_seqlens
,
cache_batch_idx
,
causal
=
causal
,
window_size
=
window_size
,
rotary_interleaved
=
rotary_interleaved
,
num_splits
=
num_splits
,
alibi_slopes
=
alibi_slopes
)
# out = flash_attn_with_kvcache(
# q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
# )
# out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
# qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
# m = qk.amax(-1, keepdim=True)
# s_tmp = torch.exp((qk - m) / math.sqrt(d))
# o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
# lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
# probs = torch.softmax(qk, dim=-1)
key_padding_mask
=
arange
<
cache_seqlens_expanded
+
\
(
seqlen_new
if
new_kv
else
0
)
out_ref
,
_
=
attention_ref
(
q_ro
,
k_cache_rep
,
v_cache_rep
,
None
,
key_padding_mask
,
0.0
,
None
,
causal
=
causal
,
window_size
=
window_size
,
bias
=
alibi_tensor
)
out_pt
,
_
=
attention_ref
(
q_ro
,
k_cache_rep
,
v_cache_rep
,
None
,
key_padding_mask
,
0.0
,
None
,
causal
=
causal
,
window_size
=
window_size
,
upcast
=
False
,
reorder_ops
=
True
,
bias
=
alibi_tensor
)
print
(
f
"Output max diff:
{
(
out
-
out_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"Output mean diff:
{
(
out
-
out_ref
).
abs
().
mean
().
item
()
}
"
)
print
(
f
"Pytorch max diff:
{
(
out_pt
-
out_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"Pytorch mean diff:
{
(
out_pt
-
out_ref
).
abs
().
mean
().
item
()
}
"
)
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
if
new_kv
:
k_cache_select
=
k_cache
if
not
has_batch_idx
else
k_cache
[
cache_batch_idx
]
v_cache_select
=
v_cache
if
not
has_batch_idx
else
v_cache
[
cache_batch_idx
]
assert
torch
.
allclose
(
k_cache_select
,
k_cache_ref
,
rtol
=
1e-3
,
atol
=
1e-3
)
assert
torch
.
equal
(
v_cache_select
,
v_cache_ref
)
assert
(
out
-
out_ref
).
abs
().
max
().
item
()
<=
3
*
\
(
out_pt
-
out_ref
).
abs
().
max
().
item
()
+
1e-5
tests/test_flash_attn.py
...
@@ -26,6 +26,31 @@ is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0)
is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0)
+def attn_bias_from_alibi_slopes(
+    slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False
+):
+    batch, nheads = slopes.shape
+    device = slopes.device
+    slopes = rearrange(slopes, "b h -> b h 1 1")
+    if causal:
+        return torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32) * slopes
+    else:
+        row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
+        col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+        sk = (
+            seqlen_k
+            if key_padding_mask is None
+            else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
+        )
+        sq = (
+            seqlen_q
+            if query_padding_mask is None
+            else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
+        )
+        relative_pos = torch.abs(row_idx + sk - sq - col_idx)
+        return -slopes * relative_pos.to(dtype=slopes.dtype)
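# Usage sketch (illustrative values): this is the reference bias that the tests below
# compare against the fused alibi_slopes path, e.g.
#     slopes = torch.rand(batch_size, nheads, device="cuda", dtype=torch.float32) * 0.3
#     attn_bias = attn_bias_from_alibi_slopes(slopes, seqlen_q, seqlen_k, causal=causal)
# attn_bias broadcasts to (batch_size, nheads, seqlen_q, seqlen_k) and is added to the
# scores inside attention_ref.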
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
    assert mode in ["full", "random", "third"]
    if mode == "full":
...
@@ -186,6 +211,7 @@ def attention_ref(
    v,
    query_padding_mask=None,
    key_padding_mask=None,
+   attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
...
@@ -200,6 +226,7 @@ def attention_ref(
        v: (batch_size, seqlen_k, nheads_k, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
+       attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
...
@@ -238,7 +265,9 @@ def attention_ref(
            q.device,
        )
        scores.masked_fill_(local_mask, float("-inf"))
-   attention = torch.softmax(scores, dim=-1)
+   if attn_bias is not None:
+       scores = scores + attn_bias
+   attention = torch.softmax(scores, dim=-1).to(v.dtype)
    # Some rows might be completely masked out so we fill them with zero instead of NaN
    if window_size[0] >= 0 or window_size[1] >= 0:
        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
...
@@ -264,6 +293,7 @@ def attention_kvpacked_ref(
...
@@ -264,6 +293,7 @@ def attention_kvpacked_ref(
kv
,
kv
,
query_padding_mask
=
None
,
query_padding_mask
=
None
,
key_padding_mask
=
None
,
key_padding_mask
=
None
,
attn_bias
=
None
,
dropout_p
=
0.0
,
dropout_p
=
0.0
,
dropout_mask
=
None
,
dropout_mask
=
None
,
causal
=
False
,
causal
=
False
,
...
@@ -277,6 +307,7 @@ def attention_kvpacked_ref(
...
@@ -277,6 +307,7 @@ def attention_kvpacked_ref(
kv
[:,
:,
1
],
kv
[:,
:,
1
],
query_padding_mask
,
query_padding_mask
,
key_padding_mask
,
key_padding_mask
,
attn_bias
,
dropout_p
,
dropout_p
,
dropout_mask
,
dropout_mask
,
upcast
=
upcast
,
upcast
=
upcast
,
...
@@ -289,6 +320,7 @@ def attention_kvpacked_ref(
...
@@ -289,6 +320,7 @@ def attention_kvpacked_ref(
def
attention_qkvpacked_ref
(
def
attention_qkvpacked_ref
(
qkv
,
qkv
,
key_padding_mask
=
None
,
key_padding_mask
=
None
,
attn_bias
=
None
,
dropout_p
=
0.0
,
dropout_p
=
0.0
,
dropout_mask
=
None
,
dropout_mask
=
None
,
causal
=
False
,
causal
=
False
,
...
@@ -302,6 +334,7 @@ def attention_qkvpacked_ref(
...
@@ -302,6 +334,7 @@ def attention_qkvpacked_ref(
qkv
[:,
:,
2
],
qkv
[:,
:,
2
],
key_padding_mask
,
key_padding_mask
,
key_padding_mask
,
key_padding_mask
,
attn_bias
,
dropout_p
,
dropout_p
,
dropout_mask
,
dropout_mask
,
upcast
=
upcast
,
upcast
=
upcast
,
...
@@ -436,6 +469,7 @@ def normalize_flash_attn_S(
...
@@ -436,6 +469,7 @@ def normalize_flash_attn_S(
v
,
v
,
query_padding_mask
=
None
,
query_padding_mask
=
None
,
key_padding_mask
=
None
,
key_padding_mask
=
None
,
attn_bias
=
None
,
is_dropout
=
False
,
is_dropout
=
False
,
causal
=
False
,
causal
=
False
,
window_size
=
(
-
1
,
-
1
),
# -1 means infinite window size
window_size
=
(
-
1
,
-
1
),
# -1 means infinite window size
...
@@ -445,6 +479,7 @@ def normalize_flash_attn_S(
...
@@ -445,6 +479,7 @@ def normalize_flash_attn_S(
q: (batch_size, seqlen_q, nheads, head_dim)
q: (batch_size, seqlen_q, nheads, head_dim)
k, v: (batch_size, seqlen_k, nheads, head_dim)
k, v: (batch_size, seqlen_k, nheads, head_dim)
key_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_q)
attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
Output:
Output:
softmax_lse: (batch_size, nheads, seqlen_q)
softmax_lse: (batch_size, nheads, seqlen_q)
softmax_max: (batch_size, nheads, seqlen_q)
softmax_max: (batch_size, nheads, seqlen_q)
...
@@ -467,6 +502,8 @@ def normalize_flash_attn_S(
...
@@ -467,6 +502,8 @@ def normalize_flash_attn_S(
q
.
device
,
q
.
device
,
)
)
scores
.
masked_fill_
(
local_mask
,
float
(
"-inf"
))
scores
.
masked_fill_
(
local_mask
,
float
(
"-inf"
))
if
attn_bias
is
not
None
:
scores
=
scores
+
attn_bias
.
to
(
dtype
=
scores
.
dtype
)
_
,
block_size_n
=
_get_block_size
(
scores
.
device
,
head_dim
,
is_dropout
,
causal
)
_
,
block_size_n
=
_get_block_size
(
scores
.
device
,
head_dim
,
is_dropout
,
causal
)
scores_block
=
scores
.
split
(
block_size_n
,
dim
=-
1
)
scores_block
=
scores
.
split
(
block_size_n
,
dim
=-
1
)
lse_block
=
torch
.
stack
([
torch
.
logsumexp
(
s
,
dim
=-
1
)
for
s
in
scores_block
],
dim
=-
1
)
lse_block
=
torch
.
stack
([
torch
.
logsumexp
(
s
,
dim
=-
1
)
for
s
in
scores_block
],
dim
=-
1
)
...
@@ -529,6 +566,8 @@ def get_dropout_fraction(
...
@@ -529,6 +566,8 @@ def get_dropout_fraction(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
([
torch
.
float16
]
if
is_sm75
else
[
torch
.
float16
,
torch
.
bfloat16
]))
@
pytest
.
mark
.
parametrize
(
"dtype"
,
([
torch
.
float16
]
if
is_sm75
else
[
torch
.
float16
,
torch
.
bfloat16
]))
# @pytest.mark.parametrize("dtype", [torch.float16])
# @pytest.mark.parametrize("dtype", [torch.float16])
@
pytest
.
mark
.
parametrize
(
"alibi"
,
[
False
,
True
])
# @pytest.mark.parametrize("alibi", [True])
@
pytest
.
mark
.
parametrize
(
"local"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"local"
,
[
False
,
True
])
# @pytest.mark.parametrize("local", [True])
# @pytest.mark.parametrize("local", [True])
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
False
,
True
])
...
@@ -538,24 +577,34 @@ def get_dropout_fraction(
...
@@ -538,24 +577,34 @@ def get_dropout_fraction(
# @pytest.mark.parametrize('d', [32, 64, 96, 128])
# @pytest.mark.parametrize('d', [32, 64, 96, 128])
# @pytest.mark.parametrize("d", [64])
# @pytest.mark.parametrize("d", [64])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@
pytest
.
mark
.
parametrize
(
"seqlen"
,
[
97
,
128
,
200
,
256
,
257
,
384
,
512
,
768
,
1024
,
1025
,
2048
])
@
pytest
.
mark
.
parametrize
(
"seqlen"
,
[
97
,
128
,
200
,
384
,
768
,
1024
,
1025
,
2048
])
# @pytest.mark.parametrize("seqlen", [
128
])
# @pytest.mark.parametrize("seqlen", [
97
])
@
pytest
.
mark
.
parametrize
(
"dropout_p"
,
[
0.0
,
0.17
])
@
pytest
.
mark
.
parametrize
(
"dropout_p"
,
[
0.0
,
0.17
])
# @pytest.mark.parametrize("dropout_p", [0.0])
# @pytest.mark.parametrize("dropout_p", [0.0])
def
test_flash_attn_qkvpacked
(
seqlen
,
d
,
dropout_p
,
causal
,
local
,
dtype
):
def
test_flash_attn_qkvpacked
(
seqlen
,
d
,
dropout_p
,
causal
,
local
,
alibi
,
dtype
):
if
seqlen
>=
2048
and
torch
.
cuda
.
get_device_properties
(
"cuda"
).
total_memory
<=
16
*
2
**
30
:
if
seqlen
>=
2048
and
torch
.
cuda
.
get_device_properties
(
"cuda"
).
total_memory
<=
16
*
2
**
30
:
pytest
.
skip
()
# Reference implementation OOM
pytest
.
skip
()
# Reference implementation OOM
device
=
"cuda"
device
=
"cuda"
# set seed
# set seed
torch
.
random
.
manual_seed
(
0
)
torch
.
random
.
manual_seed
(
0
)
batch_size
=
13
batch_size
=
8
nheads
=
9
nheads
=
9
window_size
=
(
-
1
,
-
1
)
if
not
local
else
torch
.
randint
(
0
,
seqlen
,
(
2
,))
window_size
=
(
-
1
,
-
1
)
if
not
local
else
torch
.
randint
(
0
,
seqlen
,
(
2
,))
qkv
=
torch
.
randn
(
qkv
=
torch
.
randn
(
batch_size
,
seqlen
,
3
,
nheads
,
d
,
device
=
device
,
dtype
=
dtype
,
requires_grad
=
True
batch_size
,
seqlen
,
3
,
nheads
,
d
,
device
=
device
,
dtype
=
dtype
,
requires_grad
=
True
)
)
if
alibi
:
alibi_slopes
=
torch
.
rand
(
batch_size
,
nheads
,
device
=
device
,
dtype
=
torch
.
float32
)
*
0.3
attn_bias
=
attn_bias_from_alibi_slopes
(
alibi_slopes
,
seqlen
,
seqlen
,
causal
=
causal
)
else
:
alibi_slopes
,
attn_bias
=
None
,
None
out
,
lse
,
S_dmask
=
flash_attn_qkvpacked_func
(
out
,
lse
,
S_dmask
=
flash_attn_qkvpacked_func
(
qkv
,
dropout_p
,
causal
=
causal
,
window_size
=
window_size
,
return_attn_probs
=
True
qkv
,
dropout_p
,
causal
=
causal
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
return_attn_probs
=
True
,
)
)
if
dropout_p
>
0.0
:
if
dropout_p
>
0.0
:
S_dmask_converted
=
convert_flash_attn_S_to_softmax
(
S_dmask_converted
=
convert_flash_attn_S_to_softmax
(
...
@@ -578,6 +627,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
...
@@ -578,6 +627,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
qkv
[:,
:,
2
],
qkv
[:,
:,
2
],
None
,
None
,
None
,
None
,
attn_bias
,
dropout_p
>
0.0
,
dropout_p
>
0.0
,
causal
=
causal
,
causal
=
causal
,
window_size
=
window_size
,
window_size
=
window_size
,
...
@@ -590,11 +640,12 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
...
@@ -590,11 +640,12 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
dropout_mask
=
None
dropout_mask
=
None
out_ref
,
attn_ref
=
attention_qkvpacked_ref
(
out_ref
,
attn_ref
=
attention_qkvpacked_ref
(
qkv
,
None
,
dropout_p
,
dropout_mask
,
causal
=
causal
,
window_size
=
window_size
qkv
,
None
,
attn_bias
,
dropout_p
,
dropout_mask
,
causal
=
causal
,
window_size
=
window_size
)
)
out_pt
,
attn_pt
=
attention_qkvpacked_ref
(
out_pt
,
attn_pt
=
attention_qkvpacked_ref
(
qkv
,
qkv
,
None
,
None
,
attn_bias
,
dropout_p
,
dropout_p
,
dropout_mask
,
dropout_mask
,
causal
=
causal
,
causal
=
causal
,
...
@@ -651,7 +702,9 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
...
@@ -651,7 +702,9 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
if
dropout_p
>
0.0
:
if
dropout_p
>
0.0
:
assert
(
attn
-
attn_ref
).
abs
().
max
().
item
()
<=
2
*
(
attn_pt
-
attn_ref
).
abs
().
max
().
item
()
assert
(
attn
-
attn_ref
).
abs
().
max
().
item
()
<=
2
*
(
attn_pt
-
attn_ref
).
abs
().
max
().
item
()
assert
abs
(
dropout_fraction
-
dropout_p
)
<=
(
0.01
if
not
local
else
0.025
)
# With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
if
not
alibi
:
assert
abs
(
dropout_fraction
-
dropout_p
)
<=
(
0.01
if
not
local
else
0.025
)
if
d
<=
MAX_HEADDIM_SM8x
or
(
is_sm80
or
is_sm90
):
if
d
<=
MAX_HEADDIM_SM8x
or
(
is_sm80
or
is_sm90
):
assert
(
dqkv
-
dqkv_ref
).
abs
().
max
().
item
()
<=
2
*
(
dqkv_pt
-
dqkv_ref
).
abs
().
max
().
item
()
assert
(
dqkv
-
dqkv_ref
).
abs
().
max
().
item
()
<=
2
*
(
dqkv_pt
-
dqkv_ref
).
abs
().
max
().
item
()
...
@@ -659,18 +712,20 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
...
@@ -659,18 +712,20 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
@
pytest
.
mark
.
parametrize
(
"dtype"
,
([
torch
.
float16
]
if
is_sm75
else
[
torch
.
float16
,
torch
.
bfloat16
]))
@
pytest
.
mark
.
parametrize
(
"dtype"
,
([
torch
.
float16
]
if
is_sm75
else
[
torch
.
float16
,
torch
.
bfloat16
]))
# @pytest.mark.parametrize('dtype', [torch.float16])
# @pytest.mark.parametrize('dtype', [torch.float16])
@
pytest
.
mark
.
parametrize
(
"alibi"
,
[
False
,
True
])
# @pytest.mark.parametrize("alibi", [True])
@
pytest
.
mark
.
parametrize
(
"local"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"local"
,
[
False
,
True
])
# @pytest.mark.parametrize("local", [True])
# @pytest.mark.parametrize("local", [True])
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
False
,
True
])
# @pytest.mark.parametrize('causal', [False])
# @pytest.mark.parametrize('causal', [False])
@
pytest
.
mark
.
parametrize
(
"d"
,
[
32
,
40
,
59
,
64
,
80
,
96
,
111
,
128
,
160
,
192
,
224
,
256
])
@
pytest
.
mark
.
parametrize
(
"d"
,
[
32
,
59
,
64
,
80
,
96
,
128
,
160
,
192
,
224
,
256
])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('d', [64])
@
pytest
.
mark
.
parametrize
(
"seqlen"
,
[
97
,
128
,
200
,
256
,
257
,
384
,
512
,
768
,
1024
,
1025
,
2048
])
@
pytest
.
mark
.
parametrize
(
"seqlen"
,
[
97
,
128
,
200
,
257
,
384
,
512
,
768
,
1025
,
2048
])
# @pytest.mark.parametrize('seqlen', [128])
# @pytest.mark.parametrize('seqlen', [128])
@
pytest
.
mark
.
parametrize
(
"dropout_p"
,
[
0.0
,
0.17
])
@
pytest
.
mark
.
parametrize
(
"dropout_p"
,
[
0.0
,
0.17
])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
def
test_flash_attn_varlen_qkvpacked
(
seqlen
,
d
,
dropout_p
,
causal
,
local
,
dtype
):
def
test_flash_attn_varlen_qkvpacked
(
seqlen
,
d
,
dropout_p
,
causal
,
local
,
alibi
,
dtype
):
if
seqlen
>=
2048
and
torch
.
cuda
.
get_device_properties
(
"cuda"
).
total_memory
<=
16
*
2
**
30
:
if
seqlen
>=
2048
and
torch
.
cuda
.
get_device_properties
(
"cuda"
).
total_memory
<=
16
*
2
**
30
:
pytest
.
skip
()
# Reference implementation OOM
pytest
.
skip
()
# Reference implementation OOM
device
=
"cuda"
device
=
"cuda"
...
@@ -685,6 +740,13 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
...
@@ -685,6 +740,13 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
key_padding_mask
=
generate_random_padding_mask
(
seqlen
,
batch_size
,
device
,
mode
=
"random"
)
key_padding_mask
=
generate_random_padding_mask
(
seqlen
,
batch_size
,
device
,
mode
=
"random"
)
# key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
# key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
if
alibi
:
alibi_slopes
=
torch
.
rand
(
batch_size
,
nheads
,
device
=
device
,
dtype
=
torch
.
float32
)
*
0.3
attn_bias
=
attn_bias_from_alibi_slopes
(
alibi_slopes
,
seqlen
,
seqlen
,
key_padding_mask
,
key_padding_mask
,
causal
=
causal
)
else
:
alibi_slopes
,
attn_bias
=
None
,
None
qkv_unpad
,
cu_seqlens
,
max_seqlen
,
qkv
,
output_pad_fn
,
dqkv_pad_fn
=
generate_qkv
(
qkv_unpad
,
cu_seqlens
,
max_seqlen
,
qkv
,
output_pad_fn
,
dqkv_pad_fn
=
generate_qkv
(
*
qkv
.
unbind
(
dim
=
2
),
key_padding_mask
,
key_padding_mask
,
qkvpacked
=
True
*
qkv
.
unbind
(
dim
=
2
),
key_padding_mask
,
key_padding_mask
,
qkvpacked
=
True
...
@@ -697,6 +759,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
...
@@ -697,6 +759,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
dropout_p
,
dropout_p
,
causal
=
causal
,
causal
=
causal
,
window_size
=
window_size
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
return_attn_probs
=
True
,
return_attn_probs
=
True
,
)
)
out
=
output_pad_fn
(
out_unpad
)
out
=
output_pad_fn
(
out_unpad
)
...
@@ -721,6 +784,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
...
@@ -721,6 +784,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
qkv
[:,
:,
2
],
qkv
[:,
:,
2
],
key_padding_mask
,
key_padding_mask
,
key_padding_mask
,
key_padding_mask
,
attn_bias
,
dropout_p
>
0.0
,
dropout_p
>
0.0
,
causal
=
causal
,
causal
=
causal
,
window_size
=
window_size
,
window_size
=
window_size
,
...
@@ -733,11 +797,18 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
...
@@ -733,11 +797,18 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
dropout_mask
=
None
dropout_mask
=
None
out_ref
,
attn_ref
=
attention_qkvpacked_ref
(
out_ref
,
attn_ref
=
attention_qkvpacked_ref
(
qkv
,
key_padding_mask
,
dropout_p
,
dropout_mask
,
causal
=
causal
,
window_size
=
window_size
qkv
,
key_padding_mask
,
attn_bias
,
dropout_p
,
dropout_mask
,
causal
=
causal
,
window_size
=
window_size
,
)
)
out_pt
,
attn_pt
=
attention_qkvpacked_ref
(
out_pt
,
attn_pt
=
attention_qkvpacked_ref
(
qkv
,
qkv
,
key_padding_mask
,
key_padding_mask
,
attn_bias
,
dropout_p
,
dropout_p
,
dropout_mask
,
dropout_mask
,
causal
=
causal
,
causal
=
causal
,
...
@@ -774,7 +845,9 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
...
@@ -774,7 +845,9 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
if
dropout_p
>
0.0
:
if
dropout_p
>
0.0
:
assert
(
attn
-
attn_ref
).
abs
().
max
().
item
()
<=
2
*
(
attn_pt
-
attn_ref
).
abs
().
max
().
item
()
assert
(
attn
-
attn_ref
).
abs
().
max
().
item
()
<=
2
*
(
attn_pt
-
attn_ref
).
abs
().
max
().
item
()
assert
abs
(
dropout_fraction
-
dropout_p
)
<=
(
0.01
if
not
local
else
0.025
)
# With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
if
not
alibi
:
assert
abs
(
dropout_fraction
-
dropout_p
)
<=
(
0.01
if
not
local
else
0.025
)
if
d
<=
MAX_HEADDIM_SM8x
or
(
is_sm80
or
is_sm90
):
if
d
<=
MAX_HEADDIM_SM8x
or
(
is_sm80
or
is_sm90
):
assert
(
dqkv
-
dqkv_ref
).
abs
().
max
().
item
()
<=
2
*
(
dqkv_pt
-
dqkv_ref
).
abs
().
max
().
item
()
assert
(
dqkv
-
dqkv_ref
).
abs
().
max
().
item
()
<=
2
*
(
dqkv_pt
-
dqkv_ref
).
abs
().
max
().
item
()
...
@@ -786,11 +859,13 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
...
@@ -786,11 +859,13 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@
pytest
.
mark
.
parametrize
(
"mha_type"
,
[
"mha"
,
"mqa"
,
"gqa"
])
@
pytest
.
mark
.
parametrize
(
"mha_type"
,
[
"mha"
,
"mqa"
,
"gqa"
])
# @pytest.mark.parametrize("mha_type", ["mha"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@
pytest
.
mark
.
parametrize
(
"alibi"
,
[
False
,
True
])
# @pytest.mark.parametrize("alibi", [True])
@
pytest
.
mark
.
parametrize
(
"local"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"local"
,
[
False
,
True
])
# @pytest.mark.parametrize("local", [True])
# @pytest.mark.parametrize("local", [True])
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
False
,
True
])
# @pytest.mark.parametrize("causal", [True])
# @pytest.mark.parametrize("causal", [True])
@
pytest
.
mark
.
parametrize
(
"d"
,
[
32
,
40
,
59
,
64
,
80
,
96
,
111
,
128
,
160
,
192
,
224
,
256
])
@
pytest
.
mark
.
parametrize
(
"d"
,
[
32
,
40
,
59
,
64
,
96
,
111
,
128
,
160
,
192
,
224
,
256
])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
...
@@ -815,7 +890,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
...
@@ -815,7 +890,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
@
pytest
.
mark
.
parametrize
(
"dropout_p"
,
[
0.0
,
0.17
])
@
pytest
.
mark
.
parametrize
(
"dropout_p"
,
[
0.0
,
0.17
])
# @pytest.mark.parametrize("dropout_p", [0.17])
# @pytest.mark.parametrize("dropout_p", [0.17])
def
test_flash_attn_output
(
def
test_flash_attn_output
(
seqlen_q
,
seqlen_k
,
d
,
dropout_p
,
causal
,
local
,
mha_type
,
dtype
,
kvpacked
seqlen_q
,
seqlen_k
,
d
,
dropout_p
,
causal
,
local
,
alibi
,
mha_type
,
dtype
,
kvpacked
):
):
if
(
if
(
max
(
seqlen_q
,
seqlen_k
)
>=
2048
max
(
seqlen_q
,
seqlen_k
)
>=
2048
...
@@ -825,7 +900,7 @@ def test_flash_attn_output(
...
@@ -825,7 +900,7 @@ def test_flash_attn_output(
device
=
"cuda"
device
=
"cuda"
# set seed
# set seed
torch
.
random
.
manual_seed
(
0
)
torch
.
random
.
manual_seed
(
0
)
batch_size
=
13
batch_size
=
8
nheads
=
9
nheads
=
9
nheads_k
=
nheads
if
mha_type
==
"mha"
else
(
1
if
mha_type
==
"mqa"
else
3
)
nheads_k
=
nheads
if
mha_type
==
"mha"
else
(
1
if
mha_type
==
"mqa"
else
3
)
assert
nheads
%
nheads_k
==
0
assert
nheads
%
nheads_k
==
0
...
@@ -842,14 +917,32 @@ def test_flash_attn_output(
...
@@ -842,14 +917,32 @@ def test_flash_attn_output(
     v = torch.randn(
         batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
     )
+    if alibi:
+        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
+        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
+    else:
+        alibi_slopes, attn_bias = None, None
     if kvpacked:
         out, lse, S_dmask = flash_attn_kvpacked_func(
-            q, kv, dropout_p, causal=causal, window_size=window_size, return_attn_probs=True
+            q,
+            kv,
+            dropout_p,
+            causal=causal,
+            window_size=window_size,
+            alibi_slopes=alibi_slopes,
+            return_attn_probs=True,
         )
     else:
         out, lse, S_dmask = flash_attn_func(
-            q, k, v, dropout_p, causal=causal, window_size=window_size, return_attn_probs=True
+            q,
+            k,
+            v,
+            dropout_p,
+            causal=causal,
+            window_size=window_size,
+            alibi_slopes=alibi_slopes,
+            return_attn_probs=True,
        )
     if dropout_p > 0.0:
         S_dmask_converted = convert_flash_attn_S_to_softmax(
...
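The new `if alibi:` branch draws random per-head slopes and turns them into an explicit additive bias with the test helper attn_bias_from_alibi_slopes, so the reference attention can be compared against the kernel, which only receives alibi_slopes. As a rough, minimal sketch of the idea only (not the repo's helper; the shapes and the relative-position convention here are assumptions), an ALiBi bias can be built from per-head slopes like this:

# Illustration only: additive ALiBi bias of shape (batch, nheads, seqlen_q, seqlen_k).
# Causal/local masking of invalid positions is assumed to happen elsewhere.
import torch

def alibi_bias_sketch(slopes, seqlen_q, seqlen_k, causal=False):
    # slopes: (batch, nheads), float32
    row = torch.arange(seqlen_q, device=slopes.device).view(-1, 1)  # query index i
    col = torch.arange(seqlen_k, device=slopes.device).view(1, -1)  # key index j
    rel = row + seqlen_k - seqlen_q - col  # bottom-right-aligned relative position
    if not causal:
        rel = rel.abs()  # non-causal: penalize by absolute distance in both directions
    return -slopes[:, :, None, None] * rel.to(slopes.dtype)

bias = alibi_bias_sketch(torch.rand(2, 9) * 0.3, seqlen_q=128, seqlen_k=113, causal=False)

Scaling the random slopes by 0.3 in the test presumably just keeps the bias magnitudes moderate for the comparison against the reference implementation.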
@@ -878,6 +971,7 @@ def test_flash_attn_output(
             v_rep,
             None,
             None,
+            attn_bias,
             dropout_p > 0.0,
             causal=causal,
             window_size=window_size,
...
@@ -895,6 +989,7 @@ def test_flash_attn_output(
             kv,
             None,
             None,
+            attn_bias,
             dropout_p,
             dropout_mask,
             causal=causal,
...
@@ -905,6 +1000,7 @@ def test_flash_attn_output(
             kv,
             None,
             None,
+            attn_bias,
             dropout_p,
             dropout_mask,
             causal=causal,
...
@@ -919,6 +1015,7 @@ def test_flash_attn_output(
             v,
             None,
             None,
+            attn_bias,
             dropout_p,
             dropout_mask,
             causal=causal,
...
@@ -930,6 +1027,7 @@ def test_flash_attn_output(
             v,
             None,
             None,
+            attn_bias,
             dropout_p,
             dropout_mask,
             causal=causal,
...
@@ -1000,7 +1098,9 @@ def test_flash_attn_output(
     if dropout_p > 0.0:
         assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
-        assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
+        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
+        if not alibi:
+            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
     if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
...
@@ -1014,11 +1114,13 @@ def test_flash_attn_output(
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
 # @pytest.mark.parametrize('mha_type', ["mqa"])
+@pytest.mark.parametrize("alibi", [False, True])
+# @pytest.mark.parametrize("alibi", [True])
 @pytest.mark.parametrize("local", [False, True])
 # @pytest.mark.parametrize("local", [True])
 @pytest.mark.parametrize("causal", [False, True])
 # @pytest.mark.parametrize('causal', [True])
-@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
+@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [64])
 @pytest.mark.parametrize(
...
@@ -1041,7 +1143,7 @@ def test_flash_attn_output(
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
 # @pytest.mark.parametrize('dropout_p', [0.0])
 def test_flash_attn_varlen_output(
-    seqlen_q, seqlen_k, d, dropout_p, causal, local, mha_type, dtype, kvpacked
+    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, mha_type, dtype, kvpacked
 ):
     if (
         max(seqlen_q, seqlen_k) >= 2048
...
@@ -1051,7 +1153,7 @@ def test_flash_attn_varlen_output(
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 13
+    batch_size = 8
     nheads = 9
     nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
     assert nheads % nheads_k == 0
...
@@ -1072,6 +1174,13 @@ def test_flash_attn_varlen_output(
     query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
     key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
     # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
+    if alibi:
+        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
+        attn_bias = attn_bias_from_alibi_slopes(
+            alibi_slopes, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, causal=causal
+        )
+    else:
+        alibi_slopes, attn_bias = None, None
     if kvpacked:
         (
...
...
@@ -1095,9 +1204,10 @@ def test_flash_attn_varlen_output(
max_seqlen_q
,
max_seqlen_q
,
max_seqlen_k
,
max_seqlen_k
,
dropout_p
,
dropout_p
,
return_attn_probs
=
True
,
causal
=
causal
,
causal
=
causal
,
window_size
=
window_size
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
return_attn_probs
=
True
,
)
)
else
:
else
:
(
(
...
@@ -1124,9 +1234,10 @@ def test_flash_attn_varlen_output(
...
@@ -1124,9 +1234,10 @@ def test_flash_attn_varlen_output(
max_seqlen_q
,
max_seqlen_q
,
max_seqlen_k
,
max_seqlen_k
,
dropout_p
,
dropout_p
,
return_attn_probs
=
True
,
causal
=
causal
,
causal
=
causal
,
window_size
=
window_size
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
return_attn_probs
=
True
,
)
)
out
=
output_pad_fn
(
out_unpad
)
out
=
output_pad_fn
(
out_unpad
)
if
dropout_p
>
0.0
:
if
dropout_p
>
0.0
:
...
@@ -1156,6 +1267,7 @@ def test_flash_attn_varlen_output(
             v_rep,
             query_padding_mask,
             key_padding_mask,
+            attn_bias,
             dropout_p > 0.0,
             causal=causal,
             window_size=window_size,
...
@@ -1177,6 +1289,7 @@ def test_flash_attn_varlen_output(
             kv,
             query_padding_mask,
             key_padding_mask,
+            attn_bias,
             dropout_p,
             dropout_mask,
             causal=causal,
...
@@ -1187,6 +1300,7 @@ def test_flash_attn_varlen_output(
             kv,
             query_padding_mask,
             key_padding_mask,
+            attn_bias,
             dropout_p,
             dropout_mask,
             causal=causal,
...
@@ -1201,6 +1315,7 @@ def test_flash_attn_varlen_output(
             v,
             query_padding_mask,
             key_padding_mask,
+            attn_bias,
             dropout_p,
             dropout_mask,
             causal=causal,
...
@@ -1212,6 +1327,7 @@ def test_flash_attn_varlen_output(
             v,
             query_padding_mask,
             key_padding_mask,
+            attn_bias,
             dropout_p,
             dropout_mask,
             causal=causal,
...
@@ -1284,12 +1400,14 @@ def test_flash_attn_varlen_output(
     if dropout_p > 0.0:
         assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
-        assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
+        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
+        if not alibi:
+            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
     if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
-        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
-        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
-        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
+        assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
+        assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
+        assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
...
@@ -1332,7 +1450,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
     causal = True
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 13
+    batch_size = 8
     nheads = 9
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
...
@@ -1340,7 +1458,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
     v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
     out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size)
     out_ref, attn_ref = attention_ref(
-        q, k, v, None, None, 0.0, None, causal=causal, window_size=window_size
+        q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size
     )
     out_pt, attn_pt = attention_ref(
         q,
...
@@ -1348,6 +1466,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
         v,
         None,
         None,
+        None,
         0.0,
         None,
         causal=causal,
...
@@ -1442,7 +1561,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
     causal = True
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 13
+    batch_size = 8
     nheads = 9
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
...
@@ -1484,6 +1603,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
         v,
         query_padding_mask,
         key_padding_mask,
+        None,
         0.0,
         None,
         causal=causal,
...
@@ -1495,6 +1615,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
         v,
         query_padding_mask,
         key_padding_mask,
+        None,
         0.0,
         None,
         causal=causal,
...
@@ -1554,8 +1675,10 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("alibi", [False, True])
+# @pytest.mark.parametrize("alibi", [True])
 @pytest.mark.parametrize("local", [False, True])
-# @pytest.mark.parametrize("local", [True])
+# @pytest.mark.parametrize("local", [False])
 @pytest.mark.parametrize("causal", [False, True])
 # @pytest.mark.parametrize("causal", [True])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
...
@@ -1581,7 +1704,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
     ],
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
-def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
+def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, dtype):
     if swap_sq_sk:
         seqlen_q, seqlen_k = seqlen_k, seqlen_q
     device = "cuda"
...
@@ -1593,11 +1716,23 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
     k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
     v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    if alibi:
+        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
+        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
+    else:
+        alibi_slopes, attn_bias = None, None
     out, lse, _ = flash_attn_func(
-        q, k, v, 0.0, causal=causal, window_size=window_size, return_attn_probs=True
+        q,
+        k,
+        v,
+        0.0,
+        causal=causal,
+        window_size=window_size,
+        alibi_slopes=alibi_slopes,
+        return_attn_probs=True,
     )
     out_ref, attn_ref = attention_ref(
-        q, k, v, None, None, 0.0, None, causal=causal, window_size=window_size
+        q, k, v, None, None, attn_bias, 0.0, None, causal=causal, window_size=window_size
     )
     out_pt, attn_pt = attention_ref(
         q,
...
@@ -1605,6 +1740,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
         v,
         None,
         None,
+        attn_bias,
         0.0,
         None,
         causal=causal,
...
@@ -1653,24 +1789,27 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
     # of a Pytorch implementation.
     assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
+    mult = 2 if not alibi else 8
     if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
-        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 2e-4
-        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 2e-4
-        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 2e-4
+        assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4
+        assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4
+        assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4
-@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
-# @pytest.mark.parametrize("dtype", [torch.float16])
+# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
+@pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("num_splits", [1, 0])
-# @pytest.mark.parametrize("num_splits", [0])
+# @pytest.mark.parametrize("num_splits", [1])
 @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
 # @pytest.mark.parametrize("mha_type", ["mha"])
 @pytest.mark.parametrize("new_kv", [False, True])
-# @pytest.mark.parametrize("new_kv", [True])
+# @pytest.mark.parametrize("new_kv", [False])
+@pytest.mark.parametrize("alibi", [False, True])
+# @pytest.mark.parametrize("alibi", [True])
 @pytest.mark.parametrize("local", [False, True])
 # @pytest.mark.parametrize("local", [False])
 @pytest.mark.parametrize("causal", [False, True])
-# @pytest.mark.parametrize("causal", [True])
+# @pytest.mark.parametrize("causal", [False])
 @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
 # @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
 @pytest.mark.parametrize("rotary_interleaved", [False, True])
...
@@ -1678,7 +1817,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
 @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
 @pytest.mark.parametrize("has_batch_idx", [False, True])
-# @pytest.mark.parametrize("has_batch_idx", [True])
+# @pytest.mark.parametrize("has_batch_idx", [False])
 @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
...
@@ -1711,6 +1850,7 @@ def test_flash_attn_kvcache(
     seqlen_new_eq_seqlen_q,
     causal,
     local,
+    alibi,
     new_kv,
     mha_type,
     num_splits,
...
@@ -1750,10 +1890,22 @@ def test_flash_attn_kvcache(
         dtype=torch.int32,
         device=device,
     )
+    arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
+    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
+    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
     if has_batch_idx:
-        cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[:batch_size]
+        cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
+            :batch_size
+        ]
     else:
         cache_batch_idx = None
+    if alibi:
+        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
+        attn_bias = attn_bias_from_alibi_slopes(
+            alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal
+        )
+    else:
+        alibi_slopes, attn_bias = None, None
     # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
     if rotary_dim > 0:
         angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi
...
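The kvcache test now derives key_padding_mask directly from cache_seqlens by broadcasting: key position j of batch element b is treated as valid iff j < cache_seqlens[b], plus seqlen_new when new keys/values are being appended. A tiny standalone example of the same broadcast, with made-up numbers for illustration:

import torch

cache_seqlens = torch.tensor([3, 5])          # per-batch number of cached tokens (made-up values)
seqlen_k, seqlen_new, new_kv = 6, 1, True
arange = torch.arange(seqlen_k).view(1, -1)   # same reshape as rearrange(..., "s -> 1 s")
key_padding_mask = arange < cache_seqlens.view(-1, 1) + (seqlen_new if new_kv else 0)
# tensor([[ True,  True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True,  True]])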
@@ -1785,8 +1937,6 @@ def test_flash_attn_kvcache(
     # k_cache[:, 64:] = -1
     k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone()
     v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone()
-    arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
-    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
     if new_kv:
         update_mask = torch.logical_and(
             cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
...
...
@@ -1808,6 +1958,7 @@ def test_flash_attn_kvcache(
causal
=
causal
,
causal
=
causal
,
window_size
=
window_size
,
window_size
=
window_size
,
rotary_interleaved
=
rotary_interleaved
,
rotary_interleaved
=
rotary_interleaved
,
alibi_slopes
=
alibi_slopes
,
num_splits
=
num_splits
,
num_splits
=
num_splits
,
)
)
# out = flash_attn_with_kvcache(
# out = flash_attn_with_kvcache(
...
@@ -1820,13 +1971,13 @@ def test_flash_attn_kvcache(
     # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
     # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
     # probs = torch.softmax(qk, dim=-1)
-    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
     out_ref, _ = attention_ref(
         q_ro,
         k_cache_rep,
         v_cache_rep,
         None,
         key_padding_mask,
+        attn_bias,
         0.0,
         None,
         causal=causal,
...
@@ -1838,6 +1989,7 @@ def test_flash_attn_kvcache(
         v_cache_rep,
         None,
         key_padding_mask,
+        attn_bias,
         0.0,
         None,
         causal=causal,
...
@@ -1857,7 +2009,8 @@ def test_flash_attn_kvcache(
     v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx]
     assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
     assert torch.equal(v_cache_select, v_cache_ref)
-    assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5
+    mult = 3 if not alibi else 5
+    assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5
 # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
...