gaoqiong / flash-attention / Commits

Commit 5ab9b366, authored Dec 21, 2023 by Tri Dao

    Clean up alibi, implement non-causal alibi

Parent: bc28eacc
Showing 11 changed files with 309 additions and 1199 deletions (+309, -1199):

csrc/flash_attn/flash_api.cpp                      +27   -32
csrc/flash_attn/src/alibi.h                        +27   -19
csrc/flash_attn/src/flash.h                        +0    -4
csrc/flash_attn/src/flash_bwd_kernel.h             +10   -30
csrc/flash_attn/src/flash_bwd_launch_template.h    +3    -4
csrc/flash_attn/src/flash_fwd_kernel.h             +8    -40
csrc/flash_attn/src/flash_fwd_launch_template.h    +2    -2
csrc/flash_attn/src/softmax.h                      +5    -7
flash_attn/flash_attn_interface.py                 +25   -6
tests/test_alibi.py                                +0    -1006
tests/test_flash_attn.py                           +202  -49
csrc/flash_attn/flash_api.cpp

@@ -253,12 +253,12 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &k,         // batch_size x seqlen_k x num_heads_k x head_size
         const at::Tensor &v,         // batch_size x seqlen_k x num_heads_k x head_size
         c10::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
+        c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
         const float p_dropout,
         const float softmax_scale,
         bool is_causal,
         const int window_size_left,
         int window_size_right,
-        c10::optional<at::Tensor> &alibi_slopes_, // batch_size x num_heads
         const bool return_softmax,
         c10::optional<at::Generator> gen_) {

@@ -297,13 +297,13 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

-    if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
+    if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
     if (is_causal) { window_size_right = 0; }

     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !(alibi_slopes_.has_value());
+    // TODO: how to make "seqlenq_ngroups_swapped" and ALiBi work together?
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
     if (seqlenq_ngroups_swapped) {
         const int ngroups = num_heads / num_heads_k;
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);

@@ -416,12 +416,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
         CHECK_DEVICE(alibi_slopes);
         TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
-        params.has_alibi = true;
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
-        params.has_alibi = false;
+        params.alibi_slopes_ptr = nullptr;
     }

     if (seqlen_k > 0) {

@@ -456,6 +455,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                const at::Tensor &cu_seqlens_q,  // b+1
                const at::Tensor &cu_seqlens_k,  // b+1
                c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
+               c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
                const int max_seqlen_q,
                const int max_seqlen_k,
                const float p_dropout,

@@ -464,7 +464,6 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                const bool is_causal,
                const int window_size_left,
                int window_size_right,
-               c10::optional<at::Tensor> &alibi_slopes_, // b x num_heads
                const bool return_softmax,
                c10::optional<at::Generator> gen_) {

@@ -612,12 +611,11 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
         TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
         CHECK_DEVICE(alibi_slopes);
         TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
-        params.has_alibi = true;
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
-        params.has_alibi = false;
+        params.alibi_slopes_ptr = nullptr;
     }

     auto stream = at::cuda::getCurrentCUDAStream().stream();

@@ -664,12 +662,12 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
         c10::optional<at::Tensor> &dq_,   // batch_size x seqlen_q x num_heads x head_size
         c10::optional<at::Tensor> &dk_,   // batch_size x seqlen_k x num_heads_k x head_size
         c10::optional<at::Tensor> &dv_,   // batch_size x seqlen_k x num_heads_k x head_size
+        c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
         const float p_dropout,         // probability to drop
         const float softmax_scale,
         const bool is_causal,
         const int window_size_left,
         int window_size_right,
-        c10::optional<at::Tensor> &alibi_slopes_, // batch_size x num_heads
         c10::optional<at::Generator> gen_,
         c10::optional<at::Tensor> &rng_state) {

@@ -848,12 +846,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
         TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
         CHECK_DEVICE(alibi_slopes);
         TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
-        params.has_alibi = true;
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
-        params.has_alibi = false;
+        params.alibi_slopes_ptr = nullptr;
     }

     if (seqlen_q > 0) {

@@ -891,6 +888,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
                c10::optional<at::Tensor> &dv_,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor &cu_seqlens_q,  // b+1
                const at::Tensor &cu_seqlens_k,  // b+1
+               c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
               const int max_seqlen_q,
               const int max_seqlen_k,          // max sequence length to choose the kernel
               const float p_dropout,           // probability to drop

@@ -899,7 +897,6 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
               const bool is_causal,
               const int window_size_left,
               int window_size_right,
-              c10::optional<at::Tensor> &alibi_slopes_, // b x num_heads
              c10::optional<at::Generator> gen_,
              c10::optional<at::Tensor> &rng_state) {

@@ -1094,12 +1091,11 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
         TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
         CHECK_DEVICE(alibi_slopes);
         TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
-        params.has_alibi = true;
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
-        params.has_alibi = false;
+        params.alibi_slopes_ptr = nullptr;
     }

     launch(params, stream, /*configure=*/false);

@@ -1128,14 +1124,14 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                 c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
                 c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
                 c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
+                c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
                 c10::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
                 const float softmax_scale,
                 bool is_causal,
                 const int window_size_left,
                 int window_size_right,
                 bool is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
-                int num_splits,
-                c10::optional<at::Tensor> &alibi_slopes_ // batch_size x num_heads
+                int num_splits
                 ) {

     auto dprops = at::cuda::getCurrentDeviceProperties();

@@ -1174,13 +1170,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

-    if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
+    if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
     if (is_causal) { window_size_right = 0; }

     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !(alibi_slopes_.has_value());
+    // TODO: how to make "seqlenq_ngroups_swapped" and ALiBi work together?
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
     if (seqlenq_ngroups_swapped) {
         const int ngroups = num_heads / num_heads_k;
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);

@@ -1347,12 +1343,11 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
         CHECK_DEVICE(alibi_slopes);
         TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
-        params.has_alibi = true;
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
-        params.has_alibi = false;
+        params.alibi_slopes_ptr = nullptr;
     }

     auto stream = at::cuda::getCurrentCUDAStream().stream();
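Reviewer note: the shape and stride handling repeated in the hunks above can be mirrored in plain PyTorch. The sketch below is illustrative only (the function name and return convention are not part of the diff); it accepts slopes of shape (nheads,) or (batch_size, nheads), applies the same dtype/device/contiguity checks as the C++ code, and returns a batch stride of 0 for the 1-D case so every batch element reads the same row of slopes.

import torch

def normalize_alibi_slopes(alibi_slopes, batch_size, num_heads):
    # Illustrative Python mirror of the checks in mha_fwd / mha_bwd (not from the diff).
    if alibi_slopes is None:
        return None, 0  # corresponds to params.alibi_slopes_ptr = nullptr
    assert alibi_slopes.dtype == torch.float32, "ALiBi slopes must have dtype fp32"
    assert alibi_slopes.is_cuda, "CHECK_DEVICE: slopes must live on the GPU"
    assert alibi_slopes.stride(-1) == 1, "last dimension must be contiguous"
    assert alibi_slopes.shape in [(num_heads,), (batch_size, num_heads)]
    # A 1-D tensor is broadcast over the batch by giving it batch stride 0.
    batch_stride = alibi_slopes.stride(0) if alibi_slopes.dim() == 2 else 0
    return alibi_slopes, batch_stride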
csrc/flash_attn/src/alibi.h

@@ -13,22 +13,32 @@ using namespace cute;

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template <typename Engine, typename Layout>
+template <bool Is_causal, typename Engine, typename Layout>
 inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
                                    const int col_idx_offset_,
                                    const int max_seqlen_k,
-                                   const int row_idx_offset_,
+                                   const int row_idx_offset,
                                    const int max_seqlen_q,
                                    const int warp_row_stride,
-                                   const int head_idx,
-                                   const float softmax_scale,
                                    const float alibi_slope) {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
     const int lane_id = threadIdx.x % 32;
-    const int row_idx_offset = row_idx_offset_;
     const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
-    const float alibi_slope_unscaled = alibi_slope / softmax_scale;
+    if constexpr (Is_causal) {  // Simpler, we add the same bias vector to all rows
+        #pragma unroll
+        for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+            const int col_idx_base = col_idx_offset + nj * 8;
+            #pragma unroll
+            for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                const int col_idx = col_idx_base + j;
+                #pragma unroll
+                for (int mi = 0; mi < size<0>(tensor); ++mi) {
+                    tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
+                }
+            }
+        }
+    } else {  // Bias depends on both row_idx and col_idx
+        #pragma unroll
+        for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+            const int row_idx_base = row_idx_offset + mi * warp_row_stride;

@@ -41,9 +51,7 @@ inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
                 #pragma unroll
                 for (int j = 0; j < size<1, 0>(tensor); ++j) {
                     const int col_idx = col_idx_base + j;
-                    const float alibi = alibi_slope_unscaled * col_idx;
-                    if (col_idx < max_seqlen_k && row_idx < max_seqlen_q) {
-                        tensor(make_coord(i, mi), make_coord(j, nj)) += alibi;
-                    }
+                    tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
                 }
             }
         }
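Reviewer note: the two branches of apply_alibi can be expressed as a dense reference in Python. This is an illustrative sketch, not part of the diff. In the causal branch the kernel only adds alibi_slope * col_idx: inside the causal region j <= i + seqlen_k - seqlen_q the full bias -slope * |i + seqlen_k - seqlen_q - j| differs from slope * j by a per-row constant, and per-row constants cancel in softmax. In the non-causal branch the bias genuinely depends on both indices, so the absolute-distance term is applied directly.

import torch

def alibi_bias_reference(slopes, seqlen_q, seqlen_k, causal):
    # Dense ALiBi bias matching apply_alibi's math (illustrative sketch, not from the diff).
    # slopes: (nheads,) fp32. Returns a bias of shape (nheads, seqlen_q, seqlen_k).
    i = torch.arange(seqlen_q).view(1, -1, 1)   # query index
    j = torch.arange(seqlen_k).view(1, 1, -1)   # key index
    slopes = slopes.view(-1, 1, 1).float()
    if causal:
        # Same bias vector for every row; the per-row constant dropped here is
        # irrelevant under softmax, and masked positions are handled by the causal mask.
        return (slopes * j.float()).expand(-1, seqlen_q, -1)
    # Non-causal: bias depends on both i and j, so use the absolute distance.
    return -slopes * (i + seqlen_k - seqlen_q - j).abs().float()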
csrc/flash_attn/src/flash.h

@@ -131,10 +131,6 @@ struct Flash_fwd_params : public Qkv_params {
     int num_splits;  // For split-KV version

-    // float alibi_start;
-    // float alibi_ratio;
-    bool has_alibi;
-
     void * __restrict__ alibi_slopes_ptr;
     index_t alibi_slopes_batch_stride;
 };
csrc/flash_attn/src/flash_bwd_kernel.h

@@ -753,8 +753,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     #pragma unroll
     for (int mi = 0; mi < size(lse); ++mi) {
         const int row = get<0>(taccScS_row(mi));
-        lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0;
+        lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
     }
+    // We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
+    // and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply
+    // with V (which would be zero), we're fine. However, with ALiBi, we might modify these
+    // scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0.

     // Tensor tKrK = make_fragment_like(tKsK);
     // // cute::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, 0), tKrK);
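Reviewer note: the comment above can be checked numerically. The sketch below (illustrative, not part of the diff) mimics an out-of-bounds row whose V values are zero and that should contribute nothing: with LSE = 0 the recomputed probabilities exp(score - 0) can overflow once ALiBi has shifted the scores, and inf * 0 turns into NaN; with LSE = inf every probability is exp(-inf) = 0 regardless of the bias.

import torch

# Hypothetical values for an OOB row after a causal-ALiBi shift (not from the diff).
scores = torch.tensor([0.0, 80.0, 160.0])   # e.g. alibi_slope * col_idx can be large
v = torch.zeros(3)                           # the OOB row's V values are zero

lse_zero = torch.tensor(0.0)
probs_bad = torch.exp(scores - lse_zero)     # exp(160) overflows to inf in fp32
print((probs_bad * v).sum())                 # inf * 0 -> nan

lse_inf = torch.tensor(float("inf"))
probs_good = torch.exp(scores - lse_inf)     # exp(-inf) == 0 for every entry
print((probs_good * v).sum())                # 0: the OOB row contributes nothing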
@@ -792,18 +796,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     clear(acc_dv);
     clear(acc_dk);

-    float alibi_slope = 0.0f;
-    if (Has_alibi) {
-        Tensor gAS = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr) + bidb * params.alibi_slopes_batch_stride + bidh), Shape<_1>{});
-        Tensor rAS = make_fragment_like(gAS);
-        cute::copy(gAS, rAS);
-        alibi_slope = rAS(0);
-    }
+    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;

     for (; m_block >= m_block_min; --m_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_N, MMA_N)

@@ -830,14 +823,13 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // if (cute::thread(32, 0)) { print(scores); }

         if (Has_alibi) {
-            flash::apply_alibi(
+            flash::apply_alibi<Is_causal>(
                 scores,
                 n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + get<0>(taccScS_row(0)),
                 binfo.actual_seqlen_q,
                 AtomLayoutMS * 16,
-                bidh,
-                params.scale_softmax,
                 alibi_slope
             );
         }

@@ -1403,18 +1395,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
     clear(acc_dq);

-    float alibi_slope = 0.0f;
-    if (Has_alibi) {
-        Tensor gAS = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr) + bidb * params.alibi_slopes_batch_stride + bidh), Shape<_1>{});
-        Tensor rAS = make_fragment_like(gAS);
-        cute::copy(gAS, rAS);
-        alibi_slope = rAS(0);
-    }
+    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;

     for (; n_block >= 0; --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M_SdP, MMA_N)

@@ -1429,14 +1410,13 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));

         if (Has_alibi) {
-            flash::apply_alibi(
+            flash::apply_alibi<Is_causal>(
                 scores,
                 n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + get<0>(taccScS_row(0)),
                 binfo.actual_seqlen_q,
                 AtomLayoutMS * 16,
-                bidh,
-                params.scale_softmax,
                 alibi_slope
             );
         }
csrc/flash_attn/src/flash_bwd_launch_template.h

@@ -64,12 +64,11 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
         BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
                 BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
-                    BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
+                    BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                         // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                         // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                         // If Is_local, set Is_causal to false
-                        auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal && !Is_local, Is_local, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
+                        auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
                         // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
                         if (smem_size_dq_dk_dv >= 48 * 1024)  {
                             C10_CUDA_CHECK(cudaFuncSetAttribute(
                                 kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));

@@ -109,7 +108,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
+                BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                     // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                     auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Has_alibi, IsEvenNConst && IsEvenKConst, IsEvenKConst>;
                     // auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -322,28 +322,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     clear(acc_o);

+    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+
     // For performance reason, we separate out two kinds of iterations:
     // those that need masking on S, and those that don't.
     // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
     // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
     // We will have at least 1 "masking" iteration.
-    float alibi_slope = 0.0f;
-    if (Has_alibi) {
-        Tensor gAS = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr) + bidb * params.alibi_slopes_batch_stride + bidh), Shape<_1>{});
-        Tensor rAS = make_fragment_like(gAS);
-        cute::copy(gAS, rAS);
-        alibi_slope = rAS(0);
-        // if (m_block == 0 && tidx == 0) {
-        //     printf("%d,%d,%f\n", bidb, bidh, alibi_slope);
-        // }
-    }

     // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
     // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
     constexpr int n_masking_steps = (!Is_causal && !Is_local)

@@ -382,14 +368,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // can produce Inf / NaN.

         if (Has_alibi) {
-            flash::apply_alibi(
+            flash::apply_alibi<Is_causal>(
                 scores,
                 n_block * kBlockN,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
                 binfo.actual_seqlen_q,
                 kNWarps * 16,
-                bidh,
-                params.scale_softmax,
                 alibi_slope
             );
         }

@@ -500,14 +485,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));

         if (Has_alibi) {
-            flash::apply_alibi(
+            flash::apply_alibi<Is_causal>(
                 scores,
                 n_block * kBlockN,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
                 binfo.actual_seqlen_q,
                 kNWarps * 16,
-                bidh,
-                params.scale_softmax,
                 alibi_slope
             );
         }

@@ -950,28 +934,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     clear(acc_o);

+    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+
     // For performance reason, we separate out two kinds of iterations:
     // those that need masking on S, and those that don't.
     // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
     // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
     // We will have at least 1 "masking" iteration.
-    float alibi_slope = 0.0f;
-    if (Has_alibi) {
-        Tensor gAS = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr) + bidb * params.alibi_slopes_batch_stride + bidh), Shape<_1>{});
-        Tensor rAS = make_fragment_like(gAS);
-        cute::copy(gAS, rAS);
-        alibi_slope = rAS(0);
-        // if (m_block == 0 && tidx == 0) {
-        //     printf("%d,%d,%f\n", bidb, bidh, alibi_slope);
-        // }
-    }

     // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
     // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
     constexpr int n_masking_steps = (!Is_causal && !Is_local)

@@ -1006,14 +976,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));

         if (Has_alibi) {
-            flash::apply_alibi(
+            flash::apply_alibi<Is_causal>(
                 scores,
                 n_block * kBlockN,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
                 binfo.actual_seqlen_q,
                 kNWarps * 16,
-                bidh,
-                params.scale_softmax,
                 alibi_slope
             );
         }

@@ -1099,14 +1068,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));

         if (Has_alibi) {
-            flash::apply_alibi(
+            flash::apply_alibi<Is_causal>(
                 scores,
                 n_block * kBlockN,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
                 binfo.actual_seqlen_q,
                 kNWarps * 16,
-                bidh,
-                params.scale_softmax,
                 alibi_slope
             );
         }
csrc/flash_attn/src/flash_fwd_launch_template.h

@@ -45,7 +45,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
                 BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                     BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
-                        BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
+                        BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                             // Will only return softmax if dropout, to reduce compilation time.
                             // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                             // If return_softmax, set IsEvenMNConst to false to reduce number of templates

@@ -86,7 +86,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
         BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
             BOOL_SWITCH(params.num_splits > 1, Split, [&] {
                 BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
-                    BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
+                    BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                         // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
                         // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                         // If Is_local, set Is_causal to false
csrc/flash_attn/src/softmax.h

@@ -141,14 +141,12 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_
 template <bool HasWSLeft=true, typename Engine, typename Layout>
 inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
-                                        const int max_seqlen_k, const int row_idx_offset_,
+                                        const int max_seqlen_k, const int row_idx_offset,
                                         const int max_seqlen_q, const int warp_row_stride,
                                         const int window_size_left, const int window_size_right) {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
     const int lane_id = threadIdx.x % 32;
-    // const int row_idx_offset = row_idx_offset_ + lane_id / 4;
-    const int row_idx_offset = row_idx_offset_;
     const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
     #pragma unroll
     for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {

@@ -180,17 +178,17 @@ inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const in
 template <typename Engine, typename Layout>
 inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
-                                         const int max_seqlen_k, const int row_idx_offset_,
+                                         const int max_seqlen_k, const int row_idx_offset,
                                          const int max_seqlen_q, const int warp_row_stride) {
     // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
-    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_,
+    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
                                           max_seqlen_q, warp_row_stride, -1, 0);
 }

 template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
 inline __device__ void apply_mask_causal_w_idx(
     Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
-    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_)
+    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
 {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout0::rank == 2, "Only support 2D Tensor");

@@ -199,7 +197,7 @@ inline __device__ void apply_mask_causal_w_idx(
     CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
     #pragma unroll
     for (int mi = 0; mi < size<0>(tensor); ++mi) {
-        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
+        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
         #pragma unroll
         for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
             if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
flash_attn/flash_attn_interface.py

@@ -53,12 +53,12 @@ def _flash_attn_forward(
         k,
         v,
         None,
+        alibi_slopes,
         dropout_p,
         softmax_scale,
         causal,
         window_size[0],
         window_size[1],
-        alibi_slopes,
         return_softmax,
         None,
     )

@@ -90,6 +90,7 @@ def _flash_attn_varlen_forward(
         cu_seqlens_q,
         cu_seqlens_k,
         None,
+        alibi_slopes,
         max_seqlen_q,
         max_seqlen_k,
         dropout_p,

@@ -98,7 +99,6 @@ def _flash_attn_varlen_forward(
         causal,
         window_size[0],
         window_size[1],
-        alibi_slopes,
         return_softmax,
         None,
     )

@@ -137,12 +137,12 @@ def _flash_attn_backward(
         dq,
         dk,
         dv,
+        alibi_slopes,
         dropout_p,
         softmax_scale,
         causal,
         window_size[0],
         window_size[1],
-        alibi_slopes,
         None,
         rng_state,
     )

@@ -185,6 +185,7 @@ def _flash_attn_varlen_backward(
         dv,
         cu_seqlens_q,
         cu_seqlens_k,
+        alibi_slopes,
         max_seqlen_q,
         max_seqlen_k,
         dropout_p,

@@ -193,7 +194,6 @@ def _flash_attn_varlen_backward(
         causal,
         window_size[0],
         window_size[1],
-        alibi_slopes,
         None,
         rng_state,
     )

@@ -613,6 +613,8 @@ def flash_attn_qkvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
+            the attention score of query i and key j.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).

@@ -673,6 +675,9 @@ def flash_attn_kvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).

@@ -732,6 +737,9 @@ def flash_attn_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).
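Reviewer note: the docstrings above take the slopes as a free-form fp32 tensor and do not prescribe a schedule. For completeness, the helper below sketches the standard geometric slope schedule from the ALiBi paper; it is an assumption layered on top of this diff, not something the diff defines, and it only covers power-of-two head counts (the paper interpolates for other counts).

import torch

def get_alibi_slopes(nheads):
    # Standard geometric ALiBi slopes 2^(-8(h+1)/nheads); illustrative, not from the diff.
    # Non-power-of-two head counts use an interpolated variant in the paper (omitted here).
    ratio = 2.0 ** (-8.0 / nheads)
    return torch.tensor([ratio ** (h + 1) for h in range(nheads)], dtype=torch.float32)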
@@ -780,6 +788,8 @@ def flash_attn_varlen_qkvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
+            is added to the attention score of query i and key j.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).

@@ -858,6 +868,9 @@ def flash_attn_varlen_kvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).

@@ -938,6 +951,9 @@ def flash_attn_varlen_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).

@@ -981,8 +997,8 @@ def flash_attn_with_kvcache(
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
     rotary_interleaved=True,
-    num_splits=0,
     alibi_slopes=None,
+    num_splits=0,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from

@@ -1050,6 +1066,9 @@ def flash_attn_with_kvcache(
             If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
             rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
             (i.e. GPT-NeoX style).
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j.
         num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
             If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
             to automatically determine the number of splits.

@@ -1080,6 +1099,7 @@ def flash_attn_with_kvcache(
         rotary_cos,
         rotary_sin,
         cache_batch_idx,
+        alibi_slopes,
         None,
         softmax_scale,
         causal,

@@ -1087,6 +1107,5 @@ def flash_attn_with_kvcache(
         window_size[1],
         rotary_interleaved,
         num_splits,
-        alibi_slopes,
     )
     return out
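Reviewer note: a hedged end-to-end usage sketch against this revision (shapes and values are placeholders, not from the diff). It passes alibi_slopes as a keyword argument per the updated signatures, once with causal=True and once with causal=False, the latter being the non-causal ALiBi path this commit adds.

import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 512, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
# Any fp32 tensor of shape (nheads,) or (batch, nheads) is accepted; this uses 2^-(h+1).
slopes = torch.tensor([2.0 ** (-(h + 1)) for h in range(nheads)],
                      device="cuda", dtype=torch.float32)

out_causal = flash_attn_func(q, k, v, causal=True, alibi_slopes=slopes)
out_bidir = flash_attn_func(q, k, v, causal=False, alibi_slopes=slopes)  # non-causal ALiBi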
tests/test_alibi.py    deleted    100644 → 0    (diff collapsed, not shown)
tests/test_flash_attn.py    (diff collapsed, not shown)