gaoqiong / flash-attention · Commits · 5ab9b366

Commit 5ab9b366 authored Dec 21, 2023 by Tri Dao

Clean up alibi, implement non-causal alibi

parent bc28eacc

Showing 11 changed files with 309 additions and 1199 deletions
csrc/flash_attn/flash_api.cpp                      +27   -32
csrc/flash_attn/src/alibi.h                        +27   -19
csrc/flash_attn/src/flash.h                         +0    -4
csrc/flash_attn/src/flash_bwd_kernel.h             +10   -30
csrc/flash_attn/src/flash_bwd_launch_template.h     +3    -4
csrc/flash_attn/src/flash_fwd_kernel.h              +8   -40
csrc/flash_attn/src/flash_fwd_launch_template.h     +2    -2
csrc/flash_attn/src/softmax.h                       +5    -7
flash_attn/flash_attn_interface.py                 +25    -6
tests/test_alibi.py                                 +0 -1006
tests/test_flash_attn.py                          +202   -49
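The commit generalizes FlashAttention's ALiBi support from causal-only to non-causal attention and accepts slope tensors of shape (nheads,) as well as (batch_size, nheads). The bias described by the updated docstrings can be written as a plain PyTorch reference; this is a sketch for intuition only, and the function name and shapes are illustrative, not part of the library:

    import torch

    def alibi_bias_ref(slopes: torch.Tensor, seqlen_q: int, seqlen_k: int) -> torch.Tensor:
        """bias[h, i, j] = -slopes[h] * |i + seqlen_k - seqlen_q - j|, per the updated docstrings."""
        i = torch.arange(seqlen_q).view(-1, 1)
        j = torch.arange(seqlen_k).view(1, -1)
        # (nheads, seqlen_q, seqlen_k); queries are aligned to the bottom-right of the key sequence.
        return -slopes.view(-1, 1, 1).float() * (i + seqlen_k - seqlen_q - j).abs()

In the causal case only the column-dependent part of this bias matters, which is what the new apply_alibi<Is_causal> exploits (see the sketch after the alibi.h diff below).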
csrc/flash_attn/flash_api.cpp

@@ -253,12 +253,12 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &k,                   // batch_size x seqlen_k x num_heads_k x head_size
         const at::Tensor &v,                   // batch_size x seqlen_k x num_heads_k x head_size
         c10::optional<at::Tensor> &out_,       // batch_size x seqlen_q x num_heads x head_size
+        c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
         const float p_dropout,
         const float softmax_scale,
         bool is_causal,
         const int window_size_left,
         int window_size_right,
-        c10::optional<at::Tensor> &alibi_slopes_, // batch_size x num_heads
         const bool return_softmax,
         c10::optional<at::Generator> gen_) {

@@ -297,13 +297,13 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

-    if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
+    if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
     if (is_causal) { window_size_right = 0; }

     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
+    // TODO: how to make "seqlenq_ngroups_swapped" and ALiBi work together?
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !(alibi_slopes_.has_value());
     if (seqlenq_ngroups_swapped) {
         const int ngroups = num_heads / num_heads_k;
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);

@@ -416,12 +416,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
         CHECK_DEVICE(alibi_slopes);
         TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
-        params.has_alibi = true;
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
-        params.has_alibi = false;
+        params.alibi_slopes_ptr = nullptr;
     }

     if (seqlen_k > 0) {
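The alibi_slopes_batch_stride = dim() == 2 ? stride(0) : 0 change above is what lets a single (num_heads,) slope vector serve every batch element: the kernels look the slope up as alibi_slopes_ptr[bidb * alibi_slopes_batch_stride + bidh], so a stride of 0 broadcasts the same row to all batches. A minimal Python sketch of the same lookup; slope_for is a hypothetical helper, not a library function, and it assumes a contiguous slopes tensor:

    import torch

    def slope_for(alibi_slopes: torch.Tensor, bidb: int, bidh: int) -> float:
        # Mirrors the kernel-side indexing: ptr[bidb * alibi_slopes_batch_stride + bidh],
        # where batch_stride is stride(0) for a (batch, nheads) tensor and 0 for (nheads,).
        batch_stride = alibi_slopes.stride(0) if alibi_slopes.dim() == 2 else 0
        flat = alibi_slopes.reshape(-1)
        return flat[bidb * batch_stride + bidh].item()

    slopes = torch.rand(8, dtype=torch.float32)        # (nheads,) -> shared across the batch
    assert slope_for(slopes, bidb=3, bidh=5) == slopes[5].item()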
@@ -456,6 +455,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                const at::Tensor &cu_seqlens_q,  // b+1
                const at::Tensor &cu_seqlens_k,  // b+1
                c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
+               c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
                const int max_seqlen_q,
                const int max_seqlen_k,
                const float p_dropout,

@@ -464,7 +464,6 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                const bool is_causal,
                const int window_size_left,
                int window_size_right,
-               c10::optional<at::Tensor> &alibi_slopes_, // b x num_heads
                const bool return_softmax,
                c10::optional<at::Generator> gen_) {

@@ -612,12 +611,11 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
         TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
         CHECK_DEVICE(alibi_slopes);
         TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
-        params.has_alibi = true;
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
-        params.has_alibi = false;
+        params.alibi_slopes_ptr = nullptr;
     }

     auto stream = at::cuda::getCurrentCUDAStream().stream();

@@ -664,12 +662,12 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
         c10::optional<at::Tensor> &dq_,   // batch_size x seqlen_q x num_heads x head_size
         c10::optional<at::Tensor> &dk_,   // batch_size x seqlen_k x num_heads_k x head_size
         c10::optional<at::Tensor> &dv_,   // batch_size x seqlen_k x num_heads_k x head_size
+        c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
         const float p_dropout,         // probability to drop
         const float softmax_scale,
         const bool is_causal,
         const int window_size_left,
         int window_size_right,
-        c10::optional<at::Tensor> &alibi_slopes_, // batch_size x num_heads
         c10::optional<at::Generator> gen_,
         c10::optional<at::Tensor> &rng_state) {

@@ -848,12 +846,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
         TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
         CHECK_DEVICE(alibi_slopes);
         TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
-        params.has_alibi = true;
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
-        params.has_alibi = false;
+        params.alibi_slopes_ptr = nullptr;
     }

     if (seqlen_q > 0) {

@@ -891,6 +888,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
                c10::optional<at::Tensor> &dv_,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor &cu_seqlens_q,  // b+1
                const at::Tensor &cu_seqlens_k,  // b+1
+               c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
                const int max_seqlen_q,
                const int max_seqlen_k,          // max sequence length to choose the kernel
                const float p_dropout,           // probability to drop

@@ -899,7 +897,6 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
                const bool is_causal,
                const int window_size_left,
                int window_size_right,
-               c10::optional<at::Tensor> &alibi_slopes_, // b x num_heads
                c10::optional<at::Generator> gen_,
                c10::optional<at::Tensor> &rng_state) {

@@ -1094,12 +1091,11 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
         TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
         CHECK_DEVICE(alibi_slopes);
         TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
-        params.has_alibi = true;
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
-        params.has_alibi = false;
+        params.alibi_slopes_ptr = nullptr;
     }

     launch(params, stream, /*configure=*/false);

@@ -1128,14 +1124,14 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                 c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
                 c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
                 c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
+                c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
                 c10::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
                 const float softmax_scale,
                 bool is_causal,
                 const int window_size_left,
                 int window_size_right,
                 bool is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
-                int num_splits,
-                c10::optional<at::Tensor> &alibi_slopes_ // batch_size x num_heads
+                int num_splits
                 ) {

     auto dprops = at::cuda::getCurrentDeviceProperties();

@@ -1174,13 +1170,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

-    if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
+    if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
     if (is_causal) { window_size_right = 0; }

     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
+    // TODO: how to make "seqlenq_ngroups_swapped" and ALiBi work together?
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !(alibi_slopes_.has_value());
     if (seqlenq_ngroups_swapped) {
         const int ngroups = num_heads / num_heads_k;
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);

@@ -1347,12 +1343,11 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
         CHECK_DEVICE(alibi_slopes);
         TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
-        params.has_alibi = true;
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
-        params.has_alibi = false;
+        params.alibi_slopes_ptr = nullptr;
     }

     auto stream = at::cuda::getCurrentCUDAStream().stream();
csrc/flash_attn/src/alibi.h

@@ -13,37 +13,45 @@ using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

-template <typename Engine, typename Layout>
+template <bool Is_causal, typename Engine, typename Layout>
 inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
                                    const int col_idx_offset_,
                                    const int max_seqlen_k,
-                                   const int row_idx_offset_,
+                                   const int row_idx_offset,
                                    const int max_seqlen_q,
                                    const int warp_row_stride,
-                                   const int head_idx,
-                                   const float softmax_scale,
                                    const float alibi_slope) {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
     const int lane_id = threadIdx.x % 32;
-    const int row_idx_offset = row_idx_offset_;
     const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
-    const float alibi_slope_unscaled = alibi_slope / softmax_scale;
-    #pragma unroll
-    for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
-        #pragma unroll
-        for (int i = 0; i < size<0, 0>(tensor); ++i) {
-            const int row_idx = row_idx_base + i * 8;
-            #pragma unroll
-            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-                const int col_idx_base = col_idx_offset + nj * 8;
-                #pragma unroll
-                for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                    const int col_idx = col_idx_base + j;
-                    const float alibi = alibi_slope_unscaled * col_idx;
-                    if (col_idx < max_seqlen_k && row_idx < max_seqlen_q) {
-                        tensor(make_coord(i, mi), make_coord(j, nj)) += alibi;
-                    }
-                }
-            }
-        }
-    }
+    if constexpr (Is_causal) {  // Simpler, we add the same bias vector to all rows
+        #pragma unroll
+        for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+            const int col_idx_base = col_idx_offset + nj * 8;
+            #pragma unroll
+            for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                const int col_idx = col_idx_base + j;
+                #pragma unroll
+                for (int mi = 0; mi < size<0>(tensor); ++mi) {
+                    tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
+                }
+            }
+        }
+    } else {  // Bias depends on both row_idx and col_idx
+        #pragma unroll
+        for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+            const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+            #pragma unroll
+            for (int i = 0; i < size<0, 0>(tensor); ++i) {
+                const int row_idx = row_idx_base + i * 8;
+                #pragma unroll
+                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                    const int col_idx_base = col_idx_offset + nj * 8;
+                    #pragma unroll
+                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                        const int col_idx = col_idx_base + j;
+                        tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
+                    }
+                }
+            }
+        }
+    }

@@ -51,4 +59,4 @@ inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
 }

-}  // namespace flash
+}  // namespace flash
\ No newline at end of file
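The split into causal and non-causal branches relies on softmax being invariant to adding a constant to every score in a row: under a causal mask each surviving key index satisfies j <= i + seqlen_k - seqlen_q, so -slope * |i + seqlen_k - seqlen_q - j| differs from +slope * j only by a per-row constant, and the causal branch can add one column-dependent vector to all rows. A small PyTorch check of that equivalence (a sketch; tensor names are illustrative, not kernel identifiers):

    import torch

    torch.manual_seed(0)
    seqlen_q, seqlen_k, slope = 5, 8, 0.25
    scores = torch.randn(seqlen_q, seqlen_k)

    i = torch.arange(seqlen_q).view(-1, 1)
    j = torch.arange(seqlen_k).view(1, -1)
    causal_mask = j > i + seqlen_k - seqlen_q                  # keys "in the future" of each query

    full_bias = -slope * (i + seqlen_k - seqlen_q - j).abs()   # non-causal formula
    col_bias = slope * j                                       # what the causal branch adds

    p_full = torch.softmax((scores + full_bias).masked_fill(causal_mask, float("-inf")), dim=-1)
    p_col = torch.softmax((scores + col_bias).masked_fill(causal_mask, float("-inf")), dim=-1)
    assert torch.allclose(p_full, p_col, atol=1e-6)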
csrc/flash_attn/src/flash.h

@@ -131,10 +131,6 @@ struct Flash_fwd_params : public Qkv_params {

     int num_splits;  // For split-KV version

-    // float alibi_start;
-    // float alibi_ratio;
-    bool has_alibi;
     void * __restrict__ alibi_slopes_ptr;
     index_t alibi_slopes_batch_stride;
 };
csrc/flash_attn/src/flash_bwd_kernel.h

@@ -753,8 +753,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     #pragma unroll
     for (int mi = 0; mi < size(lse); ++mi) {
         const int row = get<0>(taccScS_row(mi));
-        lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0;
+        lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
     }
+    // We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
+    // and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply
+    // with V (which would be zero), we're fine. However, with ALiBi, we might modify these
+    // scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0.

     // Tensor tKrK = make_fragment_like(tKsK);
     // // cute::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, 0), tKrK);
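The new comment is the whole motivation for the INFINITY change: for an out-of-bounds row the scores are all zero, and exp(score - lse) with lse = 0 yields "probabilities" of 1 instead of 0, which only stayed harmless because V was also zero; once ALiBi perturbs those scores the backward pass can misbehave. A toy illustration in plain PyTorch, outside the kernel:

    import torch

    scores = torch.zeros(4)                          # an out-of-bounds row: Q and K are zero, so S = 0
    alibi = torch.tensor([0.0, -0.5, -1.0, -1.5])    # ALiBi now shifts even these scores

    # Old behavior: lse = 0 for OOB rows -> nonzero exp(S + bias - 0), no longer harmless with ALiBi.
    print(torch.exp(scores + alibi - 0.0))           # tensor([1.0000, 0.6065, 0.3679, 0.2231])

    # New behavior: lse = inf for OOB rows -> probabilities are exactly 0 whatever ALiBi adds.
    print(torch.exp(scores + alibi - float("inf")))  # tensor([0., 0., 0., 0.])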
@@ -792,18 +796,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     clear(acc_dv);
     clear(acc_dk);

-    float alibi_slope = 0.0f;
-    if (Has_alibi) {
-        Tensor gAS = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr) + bidb * params.alibi_slopes_batch_stride + bidh), Shape<_1>{});
-        Tensor rAS = make_fragment_like(gAS);
-        cute::copy(gAS, rAS);
-        alibi_slope = rAS(0);
-    }
+    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;

     for (; m_block >= m_block_min; --m_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_N, MMA_N)

@@ -830,18 +823,17 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // if (cute::thread(32, 0)) { print(scores); }

         if (Has_alibi) {
-            flash::apply_alibi(
+            flash::apply_alibi<Is_causal>(
                 scores,
                 n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + get<0>(taccScS_row(0)),
                 binfo.actual_seqlen_q,
                 AtomLayoutMS * 16,
-                bidh,
-                params.scale_softmax,
                 alibi_slope
             );
         }

         // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond
         // actual_seqlen_k, because acc_s would be some finite value for those indices.
         // In the end when we multiply with K to get dQ, the corresponding values of K would be 0,

@@ -1403,18 +1395,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
     clear(acc_dq);

-    float alibi_slope = 0.0f;
-    if (Has_alibi) {
-        Tensor gAS = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr) + bidb * params.alibi_slopes_batch_stride + bidh), Shape<_1>{});
-        Tensor rAS = make_fragment_like(gAS);
-        cute::copy(gAS, rAS);
-        alibi_slope = rAS(0);
-    }
+    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;

     for (; n_block >= 0; --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M_SdP, MMA_N)

@@ -1429,14 +1410,13 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));

         if (Has_alibi) {
-            flash::apply_alibi(
+            flash::apply_alibi<Is_causal>(
                 scores,
                 n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + get<0>(taccScS_row(0)),
                 binfo.actual_seqlen_q,
                 AtomLayoutMS * 16,
-                bidh,
-                params.scale_softmax,
                 alibi_slope
             );
         }
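Both the forward and backward kernels now load the raw slope and pre-divide it by params.scale_softmax. The bias is added to scores that are still multiplied by softmax_scale afterwards, so dividing the slope up front leaves the effective bias equal to slope times the distance. A toy check of that identity (hypothetical names, outside the kernel):

    import torch

    torch.manual_seed(0)
    scale, slope = 0.125, 0.5                    # softmax_scale and a per-head ALiBi slope
    S = torch.randn(4, 6)                        # unscaled attention scores for one tile
    D = -torch.arange(6, dtype=torch.float32)    # some distance-based bias pattern

    # The kernels fold softmax_scale into the scores later, so they pre-divide the slope:
    lhs = torch.softmax(scale * (S + (slope / scale) * D), dim=-1)
    rhs = torch.softmax(scale * S + slope * D, dim=-1)
    assert torch.allclose(lhs, rhs, atol=1e-6)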
csrc/flash_attn/src/flash_bwd_launch_template.h

@@ -64,12 +64,11 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
     BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
         BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
             BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
-                BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
+                BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                     // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                     // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                     // If Is_local, set Is_causal to false
-                    auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal && !Is_local, Is_local, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
-                    // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
+                    auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
                     if (smem_size_dq_dk_dv >= 48 * 1024) {
                         C10_CUDA_CHECK(cudaFuncSetAttribute(
                             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));

@@ -109,7 +108,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
+                BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                     // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                     auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Has_alibi, IsEvenNConst && IsEvenKConst, IsEvenKConst>;
                     // auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -322,28 +322,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     clear(acc_o);

+    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;

     // For performance reason, we separate out two kinds of iterations:
     // those that need masking on S, and those that don't.
     // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
     // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
     // We will have at least 1 "masking" iteration.

-    float alibi_slope = 0.0f;
-    if (Has_alibi) {
-        Tensor gAS = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr) + bidb * params.alibi_slopes_batch_stride + bidh), Shape<_1>{});
-        Tensor rAS = make_fragment_like(gAS);
-        cute::copy(gAS, rAS);
-        alibi_slope = rAS(0);
-        // if (m_block == 0 && tidx == 0) {
-        //     printf("%d,%d,%f\n", bidb, bidh, alibi_slope);
-        // }
-    }

     // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
     // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
     constexpr int n_masking_steps = (!Is_causal && !Is_local)

@@ -382,14 +368,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // can produce Inf / NaN.

         if (Has_alibi) {
-            flash::apply_alibi(
+            flash::apply_alibi<Is_causal>(
                 scores,
                 n_block * kBlockN,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
                 binfo.actual_seqlen_q,
                 kNWarps * 16,
-                bidh,
-                params.scale_softmax,
                 alibi_slope
             );
         }

@@ -500,14 +485,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));

         if (Has_alibi) {
-            flash::apply_alibi(
+            flash::apply_alibi<Is_causal>(
                 scores,
                 n_block * kBlockN,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
                 binfo.actual_seqlen_q,
                 kNWarps * 16,
-                bidh,
-                params.scale_softmax,
                 alibi_slope
             );
         }

@@ -950,28 +934,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     clear(acc_o);

+    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;

     // For performance reason, we separate out two kinds of iterations:
     // those that need masking on S, and those that don't.
     // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
     // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
     // We will have at least 1 "masking" iteration.

-    float alibi_slope = 0.0f;
-    if (Has_alibi) {
-        Tensor gAS = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr) + bidb * params.alibi_slopes_batch_stride + bidh), Shape<_1>{});
-        Tensor rAS = make_fragment_like(gAS);
-        cute::copy(gAS, rAS);
-        alibi_slope = rAS(0);
-        // if (m_block == 0 && tidx == 0) {
-        //     printf("%d,%d,%f\n", bidb, bidh, alibi_slope);
-        // }
-    }

     // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
     // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
     constexpr int n_masking_steps = (!Is_causal && !Is_local)

@@ -1006,14 +976,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));

         if (Has_alibi) {
-            flash::apply_alibi(
+            flash::apply_alibi<Is_causal>(
                 scores,
                 n_block * kBlockN,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
                 binfo.actual_seqlen_q,
                 kNWarps * 16,
-                bidh,
-                params.scale_softmax,
                 alibi_slope
             );
         }

@@ -1099,14 +1068,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));

         if (Has_alibi) {
-            flash::apply_alibi(
+            flash::apply_alibi<Is_causal>(
                 scores,
                 n_block * kBlockN,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
                 binfo.actual_seqlen_q,
                 kNWarps * 16,
-                bidh,
-                params.scale_softmax,
                 alibi_slope
             );
         }
csrc/flash_attn/src/flash_fwd_launch_template.h

@@ -45,7 +45,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
         BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
             BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
-                BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
+                BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                     // Will only return softmax if dropout, to reduce compilation time.
                     // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                     // If return_softmax, set IsEvenMNConst to false to reduce number of templates

@@ -86,7 +86,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
         BOOL_SWITCH(params.num_splits > 1, Split, [&] {
             BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
-                BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
+                BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                     // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
                     // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                     // If Is_local, set Is_causal to false
csrc/flash_attn/src/softmax.h

@@ -141,14 +141,12 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_
 template <bool HasWSLeft=true, typename Engine, typename Layout>
 inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
-                                        const int max_seqlen_k, const int row_idx_offset_,
+                                        const int max_seqlen_k, const int row_idx_offset,
                                         const int max_seqlen_q, const int warp_row_stride,
                                         const int window_size_left, const int window_size_right) {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
     const int lane_id = threadIdx.x % 32;
-    // const int row_idx_offset = row_idx_offset_ + lane_id / 4;
-    const int row_idx_offset = row_idx_offset_;
     const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
     #pragma unroll
     for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {

@@ -180,17 +178,17 @@ inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const in
 template <typename Engine, typename Layout>
 inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
-                                         const int max_seqlen_k, const int row_idx_offset_,
+                                         const int max_seqlen_k, const int row_idx_offset,
                                          const int max_seqlen_q, const int warp_row_stride) {
     // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
-    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_,
+    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
                                           max_seqlen_q, warp_row_stride, -1, 0);
 }

 template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
 inline __device__ void apply_mask_causal_w_idx(
     Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
-    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_)
+    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
 {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout0::rank == 2, "Only support 2D Tensor");

@@ -199,7 +197,7 @@ inline __device__ void apply_mask_causal_w_idx(
     CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
     #pragma unroll
     for (int mi = 0; mi < size<0>(tensor); ++mi) {
-        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
+        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
         #pragma unroll
         for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
             if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
flash_attn/flash_attn_interface.py

@@ -53,12 +53,12 @@ def _flash_attn_forward(
        k,
        v,
        None,
+       alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
-       alibi_slopes,
        return_softmax,
        None,
    )

@@ -90,6 +90,7 @@ def _flash_attn_varlen_forward(
        cu_seqlens_q,
        cu_seqlens_k,
        None,
+       alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,

@@ -98,7 +99,6 @@ def _flash_attn_varlen_forward(
        causal,
        window_size[0],
        window_size[1],
-       alibi_slopes,
        return_softmax,
        None,
    )

@@ -137,12 +137,12 @@ def _flash_attn_backward(
        dq,
        dk,
        dv,
+       alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
-       alibi_slopes,
        None,
        rng_state,
    )

@@ -185,6 +185,7 @@ def _flash_attn_varlen_backward(
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
+       alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,

@@ -193,7 +194,6 @@ def _flash_attn_varlen_backward(
        causal,
        window_size[0],
        window_size[1],
-       alibi_slopes,
        None,
        rng_state,
    )

@@ -613,6 +613,8 @@ def flash_attn_qkvpacked_func(
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+       alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
+           the attention score of query i and key j.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).

@@ -673,6 +675,9 @@ def flash_attn_kvpacked_func(
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+       alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+           (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+           is added to the attention score of query i and key j.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).

@@ -732,6 +737,9 @@ def flash_attn_func(
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+       alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+           (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+           is added to the attention score of query i and key j.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).

@@ -780,6 +788,8 @@ def flash_attn_varlen_qkvpacked_func(
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+       alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
+           is added to the attention score of query i and key j.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).

@@ -858,6 +868,9 @@ def flash_attn_varlen_kvpacked_func(
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+       alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+           (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+           is added to the attention score of query i and key j.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).

@@ -938,6 +951,9 @@ def flash_attn_varlen_func(
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+       alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+           (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+           is added to the attention score of query i and key j.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).

@@ -981,8 +997,8 @@ def flash_attn_with_kvcache(
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    rotary_interleaved=True,
+   alibi_slopes=None,
    num_splits=0,
-   alibi_slopes=None,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from

@@ -1050,6 +1066,9 @@ def flash_attn_with_kvcache(
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
+       alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+           (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+           is added to the attention score of query i and key j.
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.

@@ -1080,6 +1099,7 @@ def flash_attn_with_kvcache(
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
+       alibi_slopes,
        None,
        softmax_scale,
        causal,

@@ -1087,6 +1107,5 @@ def flash_attn_with_kvcache(
        window_size[1],
        rotary_interleaved,
        num_splits,
-       alibi_slopes,
    )
    return out
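With these changes the Python API takes alibi_slopes on every entry point, and the bias is applied for both causal and non-causal attention. A minimal usage sketch; it assumes a CUDA device and a flash-attn build that includes this commit, and the slope schedule below is the common 2^(-8h/nheads) geometric choice from the ALiBi paper, not something this commit prescribes:

    import torch
    from flash_attn import flash_attn_func

    batch, seqlen, nheads, headdim = 2, 1024, 8, 64
    q, k, v = (torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
               for _ in range(3))

    # Per-head slopes; must be fp32, shaped (nheads,) or (batch, nheads).
    slopes = torch.tensor([2 ** (-8 * (h + 1) / nheads) for h in range(nheads)],
                          device="cuda", dtype=torch.float32)

    out = flash_attn_func(q, k, v, causal=True, alibi_slopes=slopes)         # causal ALiBi
    out_bidir = flash_attn_func(q, k, v, causal=False, alibi_slopes=slopes)  # non-causal ALiBi, new in this commit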
tests/test_alibi.py (deleted, 100644 → 0)
This diff is collapsed.

tests/test_flash_attn.py
This diff is collapsed.