gaoqiong / flash-attention · Commits

Commit ccbb14f3, authored Sep 16, 2023 by Tri Dao
Implement rotary embedding in flash_attn_with_kvcache
Parent: 5400fdc4

Showing 9 changed files with 397 additions and 99 deletions
Files changed:
  csrc/flash_attn/flash_api.cpp            +58  -42
  csrc/flash_attn/src/flash.h               +7   -1
  csrc/flash_attn/src/flash_fwd_kernel.h   +90   -8
  csrc/flash_attn/src/kernel_traits.h      +12   -3
  csrc/flash_attn/src/utils.h             +127  -31
  csrc/ft_attention/ft_attention.cpp        +1   -1
  flash_attn/flash_attn_interface.py       +32   -2
  tests/models/test_gpt.py                  +1   -1
  tests/test_flash_attn.py                 +69  -10
csrc/flash_attn/flash_api.cpp

@@ -13,7 +13,9 @@
 #include "flash.h"
 #include "static_switch.h"

+#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
 #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

 void set_params_fprop(Flash_fwd_params &params,

@@ -260,9 +262,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");

-    TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
+    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);

     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");

@@ -299,7 +299,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
     if (out_.has_value()) {
         out = out_.value();
         TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
-        TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device");
+        CHECK_DEVICE(out);
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
         CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
         if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }

@@ -426,17 +426,15 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
     TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");

-    TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(cu_seqlens_q.is_cuda(), "cu_seqlens_q must be on CUDA device");
-    TORCH_CHECK(cu_seqlens_k.is_cuda(), "cu_seqlens_k must be on CUDA device");
+    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
+    CHECK_DEVICE(cu_seqlens_q);
+    CHECK_DEVICE(cu_seqlens_k);

     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
-    TORCH_CHECK(cu_seqlens_q.is_contiguous(), "cu_seqlens_q must be contiguous");
-    TORCH_CHECK(cu_seqlens_k.is_contiguous(), "cu_seqlens_k must be contiguous");
+    CHECK_CONTIGUOUS(cu_seqlens_q);
+    CHECK_CONTIGUOUS(cu_seqlens_k);

     const auto sizes = q.sizes();

@@ -471,7 +469,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     if (out_.has_value()) {
         out = out_.value();
         TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
-        TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device");
+        CHECK_DEVICE(out);
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
         CHECK_SHAPE(out, total_q, num_heads, head_size_og);
         if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }

@@ -610,12 +608,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
     TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
     TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");

-    TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(out.is_cuda(), "out tensor must be on CUDA device");
-    TORCH_CHECK(dout.is_cuda(), "dout tensor must be on CUDA device");
-    TORCH_CHECK(softmax_lse.is_cuda(), "softmax_lse tensor must be on CUDA device");
+    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
+    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);

     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");

@@ -657,7 +651,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
     if (dq_.has_value()) {
         dq = dq_.value();
         TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
-        TORCH_CHECK(dq.is_cuda(), "dq must be on CUDA device");
+        CHECK_DEVICE(dq);
         TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
         CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
     } else {

@@ -666,7 +660,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
     if (dk_.has_value()) {
         dk = dk_.value();
         TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
-        TORCH_CHECK(dk.is_cuda(), "dk must be on CUDA device");
+        CHECK_DEVICE(dk);
         TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
         CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
     } else {

@@ -675,7 +669,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
     if (dv_.has_value()) {
         dv = dv_.value();
         TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
-        TORCH_CHECK(dv.is_cuda(), "dv must be on CUDA device");
+        CHECK_DEVICE(dv);
         TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
         CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
     } else {

@@ -820,22 +814,17 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
     TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");

-    TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(out.is_cuda(), "out tensor must be on CUDA device");
-    TORCH_CHECK(dout.is_cuda(), "dout tensor must be on CUDA device");
-    TORCH_CHECK(softmax_lse.is_cuda(), "softmax_lse tensor must be on CUDA device");
-    TORCH_CHECK(cu_seqlens_q.is_cuda(), "cu_seqlens_q must be on CUDA device");
-    TORCH_CHECK(cu_seqlens_k.is_cuda(), "cu_seqlens_k must be on CUDA device");
+    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
+    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
+    CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);

     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
     TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
-    TORCH_CHECK(cu_seqlens_q.is_contiguous(), "cu_seqlens_q must be contiguous");
-    TORCH_CHECK(cu_seqlens_k.is_contiguous(), "cu_seqlens_k must be contiguous");
+    CHECK_CONTIGUOUS(cu_seqlens_q);
+    CHECK_CONTIGUOUS(cu_seqlens_k);

     const auto sizes = q.sizes();

@@ -873,7 +862,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     if (dq_.has_value()) {
         dq = dq_.value();
         TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
-        TORCH_CHECK(dq.is_cuda(), "dq must be on CUDA device");
+        CHECK_DEVICE(dq);
         TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
         CHECK_SHAPE(dq, total_q, num_heads, head_size);
     } else {

@@ -882,7 +871,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     if (dk_.has_value()) {
         dk = dk_.value();
         TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
-        TORCH_CHECK(dk.is_cuda(), "dk must be on CUDA device");
+        CHECK_DEVICE(dk);
         TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
         CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
     } else {

@@ -891,7 +880,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     if (dv_.has_value()) {
         dv = dv_.value();
         TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
-        TORCH_CHECK(dv.is_cuda(), "dv must be on CUDA device");
+        CHECK_DEVICE(dv);
         TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
         CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
     } else {

@@ -1000,9 +989,12 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                 c10::optional<const at::Tensor> &k_,           // batch_size x seqlen_knew x num_heads_k x head_size
                 c10::optional<const at::Tensor> &v_,           // batch_size x seqlen_knew x num_heads_k x head_size
                 c10::optional<const at::Tensor> &seqlens_k_,   // batch_size
+                c10::optional<const at::Tensor> &rotary_cos_,  // seqlen_ro x (rotary_dim / 2)
+                c10::optional<const at::Tensor> &rotary_sin_,  // seqlen_ro x (rotary_dim / 2)
                 c10::optional<at::Tensor> &out_,               // batch_size x seqlen_q x num_heads x head_size
                 const float softmax_scale,
                 bool is_causal,
+                bool is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
                 int num_splits
                 ) {

@@ -1023,9 +1015,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");

-    TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(kcache.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(vcache.is_cuda(), "Input tensor must be on CUDA device");
+    CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);

     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");

@@ -1071,7 +1061,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     if (out_.has_value()) {
         out = out_.value();
         TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
-        TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device");
+        CHECK_DEVICE(out);
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
         CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
         if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }

@@ -1118,8 +1108,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         v = v_.value();
         TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
         TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
-        TORCH_CHECK(k.is_cuda(), "Key tensor must be on CUDA device");
-        TORCH_CHECK(v.is_cuda(), "Value tensor must be on CUDA device");
+        CHECK_DEVICE(k); CHECK_DEVICE(v);
         TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
         TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
         int seqlen_knew = k.size(1);

@@ -1147,13 +1136,40 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     if (seqlens_k_.has_value()) {
         auto seqlens_k = seqlens_k_.value();
         TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
-        TORCH_CHECK(seqlens_k.is_cuda(), "seqlens_k must be on CUDA device");
-        TORCH_CHECK(seqlens_k.is_contiguous(), "seqlens_k must be contiguous");
+        CHECK_DEVICE(seqlens_k);
+        CHECK_CONTIGUOUS(seqlens_k);
         CHECK_SHAPE(seqlens_k, batch_size);
         params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
     }
     params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());

+    if (rotary_cos_.has_value()) {
+        TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
+        auto rotary_cos = rotary_cos_.value();
+        CHECK_DEVICE(rotary_cos);
+        params.rotary_dim = rotary_cos.size(1) * 2;
+        TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
+        TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
+        const int seqlen_ro = rotary_cos.size(0);
+        TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
+        CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
+        CHECK_CONTIGUOUS(rotary_cos);
+        TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");
+
+        TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
+        auto rotary_sin = rotary_sin_.value();
+        CHECK_DEVICE(rotary_sin);
+        CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
+        CHECK_CONTIGUOUS(rotary_sin);
+        TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");
+        params.rotary_cos_ptr = rotary_cos.data_ptr();
+        params.rotary_sin_ptr = rotary_sin.data_ptr();
+        params.is_rotary_interleaved = is_rotary_interleaved;
+    } else {
+        params.rotary_dim = 0;
+    }
+
     // This needs to match with run_mha_fwd_splitkv_dispatch
     const int block_n = is_sm90 || is_sm8x
         ? (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64))
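Aside (not part of the diff): the checks above pin down the cos/sin format that mha_fwd_kvcache now expects — rotary_cos / rotary_sin of shape (seqlen_ro, rotary_dim / 2), contiguous, on the same CUDA device and with the same dtype as q, with rotary_dim a multiple of 16 that is at most headdim, and seqlen_ro at least the KV-cache length. A minimal Python sketch of tables that would pass these checks; the helper name and the base=10000.0 frequency choice are illustrative assumptions, not something this commit prescribes:

    import torch

    def build_rotary_tables(seqlen_ro, rotary_dim, device="cuda", dtype=torch.float16, base=10000.0):
        # rotary_dim must be divisible by 16 and <= headdim to satisfy the checks above
        inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, device=device, dtype=torch.float32) / rotary_dim))
        t = torch.arange(seqlen_ro, device=device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)              # (seqlen_ro, rotary_dim / 2)
        return torch.cos(freqs).to(dtype), torch.sin(freqs).to(dtype)

Because the kernel requires seqlen_ro >= the cache's seqlen_k, in practice the tables are sized to the maximum cache length up front.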
csrc/flash_attn/src/flash.h

@@ -67,7 +67,7 @@ struct Flash_fwd_params : public Qkv_params {
     void * __restrict__ softmax_lseaccum_ptr;

     // The dimensions.
-    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
+    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;

     // The scaling factors for the kernel.
     float scale_softmax;

@@ -91,6 +91,10 @@ struct Flash_fwd_params : public Qkv_params {
     index_t knew_head_stride;
     index_t vnew_head_stride;

+    // The cos and sin matrices for rotary embedding.
+    void * __restrict__ rotary_cos_ptr;
+    void * __restrict__ rotary_sin_ptr;
+
     // The dropout probability (probability of keeping an activation).
     float p_dropout;
     // uint32_t p_dropout_in_uint;

@@ -114,6 +118,8 @@ struct Flash_fwd_params : public Qkv_params {
     // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
     bool is_seqlens_k_cumulative;

+    bool is_rotary_interleaved;
+
     int num_splits;  // For split-KV version
 };
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -744,10 +744,36 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // Prologue

+    // Copy from Knew to K, optionally apply rotary embedding.
+    typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
+    auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;
+    auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
     if constexpr (Append_KV) {
         // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
         // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
         // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
+        const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
+        Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
+                                  Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
+                                  make_stride(params.rotary_dim / 2, _1{}));
+        Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
+                                  Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
+                                  make_stride(params.rotary_dim / 2, _1{}));
+        Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
+                                      Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                                      make_stride(params.rotary_dim / 2, _1{}));
+        Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
+                                      Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                                      make_stride(params.rotary_dim / 2, _1{}));
+        Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
+        Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
+        Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
+        Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
+        // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); }
+        // if (cute::thread(8, 0)) { print_tensor(gCos); }
+        // if (cute::thread(0, 0)) { print_tensor(tRgCos); }
         const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
             + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
         const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)

@@ -769,17 +795,39 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
         for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {
-            flash::copy_w_min_idx<Is_even_K>(
-                tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
-            );
             flash::copy_w_min_idx<Is_even_K>(
                 tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
             );
-            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-            tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
             tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
             tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
+            if (params.rotary_dim == 0) {
+                flash::copy_w_min_idx<Is_even_K>(
+                    tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+                );
+            } else {
+                if (params.is_rotary_interleaved) {
+                    // Don't clear OOB_K because we're writing to global memory
+                    flash::copy_rotary_interleaved<Is_even_K, /*Clear_OOB_K=*/false>(
+                        tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
+                        binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
+                    );
+                    tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2));
+                    tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2));
+                } else {
+                    // Don't clear OOB_K because we're writing to global memory
+                    flash::copy_rotary_contiguous<Is_even_K, /*Clear_OOB_K=*/false>(
+                        tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
+                        binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
+                    );
+                    tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2));
+                    tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2));
+                }
+            }
+            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+            tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
         }
         // Need this before we can read in K again, so that we'll see the updated K values.
         __syncthreads();
         if (n_block_max > n_block_copy_min) {
             tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride;

@@ -787,10 +835,44 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         }
     }

     // Read Q from gmem to smem, optionally apply rotary embedding.
-    Tensor tQrQ = make_fragment_like(tQgQ);
-    // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
-    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
-                                       binfo.actual_seqlen_q - m_block * kBlockM);
+    if (!Append_KV || params.rotary_dim == 0) {
+        // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
+        flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
+                                           binfo.actual_seqlen_q - m_block * kBlockM);
+    } else {
+        const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
+        // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache.
+        // We do this by setting the row stride of gCos / gSin to 0.
+        Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
+                                  Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
+                                  make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
+        Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
+                                  Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
+                                  make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
+        Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
+                                      Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                      make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
+        Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
+                                      Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                      make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
+        Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
+        Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
+        Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
+        Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
+        if (params.is_rotary_interleaved) {
+            flash::copy_rotary_interleaved<Is_even_K>(
+                tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
+                0, params.d, params.rotary_dim
+            );
+        } else {
+            flash::copy_rotary_contiguous<Is_even_K>(
+                tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
+                0, params.d, params.rotary_dim
+            );
+        }
+    }

     int n_block = n_block_max - 1;
     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
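Aside (not part of the diff): the Q-side branch above encodes the position convention documented in flash_attn_interface.py — with causal attention each new query row gets its own rotary position starting at seqlen_k_cache, while without causal every query row shares position seqlen_k_cache (the kernel realizes this by giving gCos / gSin a row stride of 0). A small Python sketch of that indexing, for intuition only:

    def query_rotary_positions(seqlen_k_cache, seqlen_q, causal):
        # Mirrors row_offset_cossin plus the stride-0 trick above, in scalar form.
        if causal:
            return [seqlen_k_cache + i for i in range(seqlen_q)]
        return [seqlen_k_cache] * seqlen_q

    assert query_rotary_positions(100, 3, causal=True) == [100, 101, 102]
    assert query_rotary_positions(100, 3, causal=False) == [100, 100, 100]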
csrc/flash_attn/src/kernel_traits.h

@@ -142,11 +142,11 @@ struct Flash_fwd_kernel_traits : public Base {
             DefaultCopy
     >;
     using GmemTiledCopyQKV = decltype(
-        make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
+        make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
                         GmemLayoutAtom{},
                         Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
     using GmemTiledCopyO = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                         GmemLayoutAtom{},
                         Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
     static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;

@@ -155,7 +155,7 @@ struct Flash_fwd_kernel_traits : public Base {
                                   Stride<Int<kGmemThreadsPerRowP>, _1>>;
     using GmemTiledCopyP = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                         GmemLayoutAtomP{},
                         Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store

@@ -170,6 +170,15 @@ struct Flash_fwd_kernel_traits : public Base {
         make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                         GmemLayoutAtomOaccum{},
                         Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
+    using GmemLayoutAtomRotcossin = GmemLayoutAtom;
+    using GmemTiledCopyRotcossin = decltype(
+        make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
+                        GmemLayoutAtomRotcossin{},
+                        Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per load
+    using GmemTiledCopyRotcossinCont = decltype(
+        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+                        GmemLayoutAtomRotcossin{},
+                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per load
 };

 // Is_V_in_regs is an option to reduce smem usage, but will increase register pressue.
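Aside (not part of the diff): as the trailing comments note, the interleaved tiled copy moves 4 values per copy through UniversalCopy<uint64_t> — four two-byte fp16/bf16 values pack into exactly one 64-bit word — while the contiguous variant uses DefaultCopy with 8 values per copy. A one-line Python sanity check of the packing arithmetic:

    # 4 half-precision (2-byte) values per load fill exactly one 64-bit word.
    assert 4 * 2 * 8 == 64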
csrc/flash_attn/src/utils.h

@@ -355,65 +355,161 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template <bool Is_2_sources=false, bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
-          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
-          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
-inline __device__ void copy_2_sources(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S0, Tensor<Engine0, Layout0> const &S1,
-                                      Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
-                                      Tensor<Engine3, Layout3> const &predicate_K,
-                                      const int max_MN=0, const int row_idx_switch=0) {
-    CUTE_STATIC_ASSERT_V(rank(S0) == Int<3>{} && rank(S1) == Int<3>{});
-    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
-    CUTE_STATIC_ASSERT_V(size<0>(S0) == size<0>(D) && size<0>(S1) == size<0>(D));  // MMA
-    CUTE_STATIC_ASSERT_V(size<1>(S0) == size<1>(D) && size<1>(S1) == size<1>(D));  // MMA_M
-    CUTE_STATIC_ASSERT_V(size<2>(S0) == size<2>(D) && size<2>(S1) == size<2>(D));  // MMA_K
-    // There's no case where !Clear_OOB_K && Clear_OOB_MN
-    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
-    // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", Is_2_sources, max_MN, row_idx_switch); }
-    // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", blockIdx.y, Is_2_sources, max_MN, row_idx_switch); }
-    #pragma unroll
-    for (int m = 0; m < size<1>(S0); ++m) {
-        auto &S = !Is_2_sources || get<0>(identity_MN(0, m, 0)) < row_idx_switch ? S0 : S1;
-        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
-            #pragma unroll
-            for (int k = 0; k < size<2>(S0); ++k) {
-                if (Is_even_K || predicate_K(k)) {
-                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
-                } else if (Clear_OOB_K) {
-                    cute::clear(D(_, m, k));
-                }
-            }
-        } else if (Clear_OOB_MN) {
-            cute::clear(D(_, m, _));
-        }
-    }
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
 template <bool Is_even_K=true,
           typename Engine0, typename Layout0, typename Engine1, typename Layout1,
           typename Engine2, typename Layout2, typename Engine3, typename Layout3>
 inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
                                       Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                                       Tensor<Engine3, Layout3> const &predicate_K,
                                       const int max_MN=0, const int min_MN=0) {
     CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
     CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
     CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
     CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
     CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
     // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
     #pragma unroll
     for (int m = 0; m < size<1>(S); ++m) {
         // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
         if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
             // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
             #pragma unroll
             for (int k = 0; k < size<2>(S); ++k) {
                 if (Is_even_K || predicate_K(k)) {
                     cute::copy(S(_, m, k), D(_, m, k));
                 }
             }
         }
     }
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////

+template <bool Is_even_K=true, bool Clear_OOB_K=true,
+          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+inline __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
+                                               Tensor<Engine1, Layout1> &D,
+                                               Tensor<Engine2, Layout2> const &Cos,
+                                               Tensor<Engine2, Layout2> const &Sin,
+                                               Tensor<Engine3, Layout3> const &identity_MN,
+                                               const int max_MN, const int min_MN,
+                                               const int dim, const int rotary_dim) {
+    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos));                   // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos));                   // MMA_K
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin));                   // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin));                   // MMA_K
+    CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));                 // MMA_K
+    static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2);
+    static_assert(decltype(size<0>(Cos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32
+    Tensor rCos = make_fragment_like(Cos);
+    Tensor rSin = make_fragment_like(Sin);
+    Tensor rS = make_fragment_like(S);
+    #pragma unroll
+    for (int m = 0; m < size<1>(S); ++m) {
+        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
+            #pragma unroll
+            for (int k = 0; k < size<2>(S); ++k) {
+                if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
+                    cute::copy(S(_, m, k), rS(_, m, k));
+                    if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
+                        cute::copy(Cos(_, m, k), rCos(_, m, k));
+                        cute::copy(Sin(_, m, k), rSin(_, m, k));
+                        Tensor S_fp32 = convert_type<float>(rS(_, m, k));
+                        Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
+                        Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
+                        #pragma unroll
+                        for (int i = 0; i < size<0>(rS) / 2; ++i) {
+                            float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i);
+                            float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i);
+                            S_fp32(2 * i) = real;
+                            S_fp32(2 * i + 1) = imag;
+                        }
+                        // Idk but I need to copy for the convert_type to work
+                        Tensor S_fp32_copy = make_fragment_like(S_fp32);
+                        cute::copy(S_fp32, S_fp32_copy);
+                        using T = typename Engine0::value_type;
+                        Tensor S_og_type = convert_type<T>(S_fp32_copy);
+                        cute::copy(S_og_type, rS(_, m, k));
+                    }
+                    cute::copy(rS(_, m, k), D(_, m, k));
+                } else if (Clear_OOB_K) {
+                    cute::clear(D(_, m, k));
+                }
+            }
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_even_K=true, bool Clear_OOB_K=true,
+          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+inline __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
+                                              Tensor<Engine1, Layout1> &D,
+                                              Tensor<Engine2, Layout2> const &Cos,
+                                              Tensor<Engine2, Layout2> const &Sin,
+                                              Tensor<Engine3, Layout3> const &identity_MN,
+                                              const int max_MN, const int min_MN,
+                                              const int dim, const int rotary_dim) {
+    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos));                   // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos));                   // MMA_K
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin));                   // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin));                   // MMA_K
+    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos));                   // MMA
+    CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));
+    static_assert(decltype(size<0>(Cos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32
+    Tensor rCos = make_fragment_like(Cos);
+    Tensor rSin = make_fragment_like(Sin);
+    Tensor rS = make_fragment_like(S);
+    Tensor rS_other = make_fragment_like(rS(_, 0, 0));
+    #pragma unroll
+    for (int m = 0; m < size<1>(S); ++m) {
+        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
+            #pragma unroll
+            for (int k = 0; k < size<2>(S); ++k) {
+                if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
+                    cute::copy(S(_, m, k), rS(_, m, k));
+                    if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
+                        const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2;
+                        Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout());
+                        cute::copy(gS_other, rS_other);
+                        // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); }
+                        Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout());
+                        Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout());
+                        cute::copy(gCos, rCos(_, m, k));
+                        cute::copy(gSin, rSin(_, m, k));
+                        // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); }
+                        Tensor S_fp32 = convert_type<float>(rS(_, m, k));
+                        Tensor S_other_fp32 = convert_type<float>(rS_other);
+                        Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
+                        Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
+                        #pragma unroll
+                        for (int i = 0; i < size<0>(rS); ++i) {
+                            S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i));
+                        }
+                        // Idk but I need to copy for the convert_type to work
+                        Tensor S_fp32_copy = make_fragment_like(S_fp32);
+                        cute::copy(S_fp32, S_fp32_copy);
+                        using T = typename Engine0::value_type;
+                        Tensor S_og_type = convert_type<T>(S_fp32_copy);
+                        cute::copy(S_og_type, rS(_, m, k));
+                        // if (cute::thread0()) { print_tensor(rS(_, m, k)); }
+                    }
+                    cute::copy(rS(_, m, k), D(_, m, k));
+                } else if (Clear_OOB_K) {
+                    cute::clear(D(_, m, k));
+                }
+            }
+        }
+    }
+}

@@ -422,4 +518,4 @@ inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-}  // namespace flash
\ No newline at end of file
+}  // namespace flash
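Aside (not part of the diff): copy_rotary_interleaved pairs adjacent channels (2i, 2i+1), while copy_rotary_contiguous pairs channel i with channel i + rotary_dim/2 (GPT-NeoX style), flipping the sign on sin for the second half. A hedged PyTorch sketch of the same per-row math on a single head vector, for intuition only — this is not the kernel's code path, and channels past rotary_dim pass through unchanged in both variants:

    import torch

    def rotate_interleaved(x, cos, sin, rotary_dim):
        # x: (headdim,); cos, sin: (rotary_dim // 2,) for one position
        out = x.clone()
        x1, x2 = x[0:rotary_dim:2], x[1:rotary_dim:2]        # channels (2i, 2i+1)
        out[0:rotary_dim:2] = x1 * cos - x2 * sin
        out[1:rotary_dim:2] = x1 * sin + x2 * cos
        return out

    def rotate_contiguous(x, cos, sin, rotary_dim):
        # Pairs channel i with channel i + rotary_dim // 2 (GPT-NeoX style)
        out = x.clone()
        half = rotary_dim // 2
        x1, x2 = x[:half], x[half:rotary_dim]
        out[:half] = x1 * cos - x2 * sin
        out[half:rotary_dim] = x1 * sin + x2 * cos
        return out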
csrc/ft_attention/ft_attention.cpp

@@ -175,7 +175,7 @@ torch::Tensor single_query_attention(const torch::Tensor q,
         TORCH_CHECK(rotary_sin_.has_value());
         auto rotary_sin = rotary_sin_.value();
         CHECK_DEVICE(rotary_sin);
-        CHECK_SHAPE(rotary_cos, batch_size, rotary_embedding_dim / 2);
+        CHECK_SHAPE(rotary_sin, batch_size, rotary_embedding_dim / 2);
         CHECK_CONTIGUOUS(rotary_sin);
         TORCH_CHECK(rotary_sin.scalar_type() == input_type);
     }
flash_attn/flash_attn_interface.py

@@ -800,9 +800,12 @@ def flash_attn_with_kvcache(
     v_cache,
     k=None,
     v=None,
+    rotary_cos=None,
+    rotary_sin=None,
     cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
     softmax_scale=None,
     causal=False,
+    rotary_interleaved=True,
     num_splits=0,
 ):
     """

@@ -815,7 +818,13 @@ def flash_attn_with_kvcache(
     For example, the KV cache could be pre-allocated with the max sequence length, and you can use
     cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

-    Does not support backward pass.
+    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be rotated
+    by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
+    If causal, the query @q will be rotated by rotary_cos and rotary_sin at indices cache_seqlens,
+    cache_seqlens + 1, etc. If not causal, the query @q will be rotated by rotary_cos and rotary_sin
+    at indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
+
     See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
     than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.

@@ -834,6 +843,8 @@ def flash_attn_with_kvcache(
     1 1
     If the row of the mask is all zero, the output will be zero.

+    Note: Does not support backward pass.
+
     Arguments:
         q: (batch_size, seqlen, nheads, headdim)
         k_cache: (batch_size, seqlen_cache, nheads_k, headdim)

@@ -841,11 +852,18 @@ def flash_attn_with_kvcache(
         k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
             k with k_cache, starting at the indices specified by cache_seqlens.
         v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
+        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
+            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
+        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
         cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
             KV cache.
         softmax_scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
+            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
+            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
+            (i.e. GPT-NeoX style).
         num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
             If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
             to automatically determine the number of splits.

@@ -865,6 +883,18 @@ def flash_attn_with_kvcache(
             (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
         )
     out, softmax_lse = flash_attn_cuda.fwd_kvcache(
-        q, k_cache, v_cache, k, v, cache_seqlens, None, softmax_scale, causal, num_splits
+        q,
+        k_cache,
+        v_cache,
+        k,
+        v,
+        cache_seqlens,
+        rotary_cos,
+        rotary_sin,
+        None,
+        softmax_scale,
+        causal,
+        rotary_interleaved,
+        num_splits,
     )
     return out
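Aside (not part of the diff): putting the new arguments together, one decoding step can append the new K/V to the cache and apply rotary in a single call. A hedged usage sketch based on the docstring above — the shapes and argument names are the documented ones; the random-angle cos/sin construction only mirrors the test below and is a placeholder, not a recommended table:

    import math
    import torch
    from flash_attn import flash_attn_with_kvcache

    batch, nheads, headdim, max_seqlen = 2, 8, 128, 4096
    k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
    v_cache = torch.zeros_like(k_cache)
    cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device="cuda")

    # One new token per sequence
    q = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")
    k = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")
    v = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")

    # rotary_cos / rotary_sin: (seqlen_ro, rotary_dim / 2), same dtype as q, seqlen_ro >= cache length
    angle = torch.rand(max_seqlen, headdim // 2, device="cuda") * 2 * math.pi
    rotary_cos = torch.cos(angle).to(torch.float16)
    rotary_sin = torch.sin(angle).to(torch.float16)

    out = flash_attn_with_kvcache(
        q, k_cache, v_cache, k, v,
        rotary_cos=rotary_cos, rotary_sin=rotary_sin,
        cache_seqlens=cache_seqlens, causal=True, rotary_interleaved=True,
    )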
tests/models/test_gpt.py

@@ -280,7 +280,7 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
 @pytest.mark.parametrize("seqlen,maxlen", [(10, 20), (30, 150), (3000, 3400), (14000, 15000)])
 # @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
-@pytest.mark.parametrize("rotary", [None, "interleaved", "block"])
+@pytest.mark.parametrize("rotary", [None, "interleaved", "contiguous"])
 # @pytest.mark.parametrize('rotary', [None])
 @pytest.mark.parametrize("fused_ft_kernel", [False, True])
 # @pytest.mark.parametrize("fused_ft_kernel", [False])
tests/test_flash_attn.py

@@ -15,6 +15,7 @@ from flash_attn import (
 )
 from flash_attn.bert_padding import pad_input, unpad_input
 from flash_attn.flash_attn_interface import _get_block_size
+from flash_attn.layers.rotary import apply_rotary_emb

 MAX_HEADDIM_SM8x = 192

@@ -1497,12 +1498,16 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
 @pytest.mark.parametrize("new_kv", [False, True])
 # @pytest.mark.parametrize("new_kv", [True])
 @pytest.mark.parametrize("causal", [False, True])
-# @pytest.mark.parametrize("causal", [False])
+# @pytest.mark.parametrize("causal", [True])
 @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
 # @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
+@pytest.mark.parametrize("rotary_interleaved", [False, True])
+# @pytest.mark.parametrize("rotary_interleaved", [False])
+@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
+# @pytest.mark.parametrize("rotary_fraction", [1.0])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
 # @pytest.mark.parametrize("d", [64])
 @pytest.mark.parametrize(

@@ -1523,15 +1528,29 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
 def test_flash_attn_kvcache(
-    seqlen_q, seqlen_k, d, seqlen_new_eq_seqlen_q, causal, new_kv, mha_type, num_splits, dtype
+    seqlen_q,
+    seqlen_k,
+    d,
+    rotary_fraction,
+    rotary_interleaved,
+    seqlen_new_eq_seqlen_q,
+    causal,
+    new_kv,
+    mha_type,
+    num_splits,
+    dtype,
 ):
     if seqlen_q > seqlen_k and new_kv:
         pytest.skip()
+    if not new_kv and rotary_fraction > 0.0:
+        pytest.skip()
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
     batch_size = 2
     nheads = 6
+    # rotary_dim must be a multiple of 16, and must be <= d
+    rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
     nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
     assert nheads % nheads_k == 0
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)

@@ -1545,12 +1564,42 @@ def test_flash_attn_kvcache(
     v_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
     cache_seqlens = torch.randint(
         0,
-        (seqlen_k - seqlen_new + 1) if new_kv else (seqlen_k + 1),
+        # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
+        (seqlen_k - (seqlen_q if causal and rotary_dim > 1 else seqlen_new) + 1)
+        if new_kv
+        else (seqlen_k + 1),
         (batch_size,),
         dtype=torch.int32,
         device=device,
     )
     # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
+    if rotary_dim > 0:
+        angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi
+        cos = torch.cos(angle).to(dtype=dtype)
+        sin = torch.sin(angle).to(dtype=dtype)
+        if causal:
+            q_ro = apply_rotary_emb(
+                q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
+            )
+        else:
+            q_ro = rearrange(
+                apply_rotary_emb(
+                    rearrange(q, "b s h d -> b 1 (s h) d"),
+                    cos,
+                    sin,
+                    seqlen_offsets=cache_seqlens,
+                    interleaved=rotary_interleaved,
+                ),
+                "b 1 (s h) d -> b s h d",
+                s=seqlen_q,
+            )
+        # q_ro = q
+        k_ro = apply_rotary_emb(
+            k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
+        )
+    else:
+        cos, sin = None, None
+        q_ro, k_ro = q, k
     # k_cache[:, 64:] = -1
     k_cache_ref = k_cache.clone()
     v_cache_ref = v_cache.clone()

@@ -1560,12 +1609,22 @@ def test_flash_attn_kvcache(
     update_mask = torch.logical_and(
         cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
     )
-    k_cache_ref[update_mask] = rearrange(k, "b s ... -> (b s) ...")
+    k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
     v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
     k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
     v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
     out = flash_attn_with_kvcache(
-        q, k_cache, v_cache, k, v, cache_seqlens, causal=causal, num_splits=num_splits
+        q,
+        k_cache,
+        v_cache,
+        k,
+        v,
+        cos,
+        sin,
+        cache_seqlens,
+        causal=causal,
+        rotary_interleaved=rotary_interleaved,
+        num_splits=num_splits,
     )
     # out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal)
     # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal)

@@ -1577,10 +1636,10 @@ def test_flash_attn_kvcache(
     # probs = torch.softmax(qk, dim=-1)
     key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
     out_ref, _ = attention_ref(
-        q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal
+        q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal
     )
     out_pt, _ = attention_ref(
-        q,
+        q_ro,
         k_cache_rep,
         v_cache_rep,
         None,

@@ -1598,10 +1657,10 @@ def test_flash_attn_kvcache(
     # Check that FlashAttention's numerical error is at most twice the numerical error
     # of a Pytorch implementation.
-    assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5
     if new_kv:
-        assert torch.equal(k_cache, k_cache_ref)
+        assert torch.allclose(k_cache, k_cache_ref, rtol=1e-3, atol=1e-3)
         assert torch.equal(v_cache, v_cache_ref)
+    assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5


 # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
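Aside (not part of the diff): the accuracy check at the end of the test follows the file's usual pattern — compare the kernel output against a reference computed on the rotated query (out_ref) and against a plain PyTorch baseline in the same low-precision dtype (out_pt), and require the kernel's max error to stay within a small multiple of the baseline's error. A compact sketch of that assertion pattern, with a hypothetical helper name:

    def within_tolerance(out, out_pt, out_ref, mult=3, eps=1e-5):
        # Kernel error is allowed up to `mult` times the error of a plain PyTorch
        # implementation run in the same (half-precision) dtype, plus a small epsilon.
        err_kernel = (out - out_ref).abs().max().item()
        err_baseline = (out_pt - out_ref).abs().max().item()
        return err_kernel <= mult * err_baseline + eps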