gaoqiong / flash-attention / Commits / 37c6e054

Commit 37c6e054, authored Sep 04, 2023 by Tri Dao

Implement flash_attn_with_kvcache
parent 4976650f

Showing 10 changed files, with 663 additions and 108 deletions (+663 −108).
csrc/flash_attn/flash_api.cpp                      +192  −3
csrc/flash_attn/src/block_info.h                   +7    −2
csrc/flash_attn/src/flash.h                        +16   −0
csrc/flash_attn/src/flash_fwd_kernel.h             +150  −61
csrc/flash_attn/src/flash_fwd_launch_template.h    +41   −31
csrc/flash_attn/src/utils.h                        +69   −2
flash_attn/__init__.py                             +1    −0
flash_attn/flash_attn_interface.py                 +72   −0
flash_attn/utils/generation.py                     +25   −9
tests/test_flash_attn.py                           +90   −0
csrc/flash_attn/flash_api.cpp (+192 −3)

@@ -102,6 +102,7 @@ void set_params_fprop(Flash_fwd_params &params,
     TORCH_CHECK(p_dropout < 1.f);

     params.is_causal = is_causal;
+    params.is_seqlens_k_cumulative = true;
 }

 void set_params_dgrad(Flash_bwd_params &params,
@@ -175,10 +176,10 @@ void set_params_dgrad(Flash_bwd_params &params,
     params.dsoftmax_sum = dsoftmax_sum_d;
 }

-void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
     FP16_SWITCH(!params.is_bf16, [&] {
         FWD_HEADDIM_SWITCH(params.d, [&] {
-            if (params.num_splits <= 1) {  // If we don't set it num_splits == 0
+            if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it num_splits == 0
                 run_mha_fwd_<elem_type, kHeadDim>(params, stream);
             } else {
                 run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
@@ -350,7 +351,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
         const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
         params.num_splits = 1;
         if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
-            params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 64);
+            params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
             if (params.num_splits > 1) {
                 at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
                 at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
@@ -990,10 +991,198 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     return { dq, dk, dv, softmax_d };
 }

+std::vector<at::Tensor>
+mha_fwd_kvcache(const at::Tensor &q,                 // batch_size x seqlen_q x num_heads x head_size
+                const at::Tensor &kcache,            // batch_size x seqlen_k x num_heads_k x head_size
+                const at::Tensor &vcache,            // batch_size x seqlen_k x num_heads_k x head_size
+                c10::optional<const at::Tensor> &k_, // batch_size x seqlen_q x num_heads_k x head_size
+                c10::optional<const at::Tensor> &v_, // batch_size x seqlen_q x num_heads_k x head_size
+                c10::optional<const at::Tensor> &seqlens_k_, // batch_size
+                c10::optional<at::Tensor> &out_,     // batch_size x seqlen_q x num_heads x head_size
+                const float softmax_scale,
+                const bool is_causal,
+                int num_splits) {
+
+    auto dprops = at::cuda::getCurrentDeviceProperties();
+    // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
+    bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
+    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+    TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
+    // We will support Turing in the near future
+    // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
+
+    auto q_dtype = q.dtype();
+    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
+                "FlashAttention only support fp16 and bf16 data type");
+    if (q_dtype == torch::kBFloat16) {
+        TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
+    }
+    TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
+    TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");
+
+    TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
+    TORCH_CHECK(kcache.is_cuda(), "Input tensor must be on CUDA device");
+    TORCH_CHECK(vcache.is_cuda(), "Input tensor must be on CUDA device");
+
+    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+
+    const auto sizes = q.sizes();
+
+    const int batch_size = sizes[0];
+    const int seqlen_q = sizes[1];
+    const int num_heads = sizes[2];
+    const int head_size_og = sizes[3];
+    const int seqlen_k = kcache.size(1);
+    const int num_heads_k = kcache.size(2);
+    TORCH_CHECK(batch_size > 0, "batch size must be postive");
+    TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
+    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+
+    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
+    CHECK_SHAPE(kcache, batch_size, seqlen_k, num_heads_k, head_size_og);
+    CHECK_SHAPE(vcache, batch_size, seqlen_k, num_heads_k, head_size_og);
+
+    at::Tensor q_padded, kcache_padded, vcache_padded;
+    if (head_size_og % 8 != 0) {
+        q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+        kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+        vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+    } else {
+        q_padded = q;
+        kcache_padded = kcache;
+        vcache_padded = vcache;
+    }
+
+    at::Tensor out;
+    if (out_.has_value()) {
+        out = out_.value();
+        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
+        TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device");
+        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
+        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
+        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
+    } else {
+        out = torch::empty_like(q_padded);
+    }
+
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+    const int head_size = round_multiple(head_size_og, 8);
+    const int head_size_rounded = round_multiple(head_size, 32);
+    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
+    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
+
+    // Otherwise the kernel will be launched from cuda:0 device
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+
+    auto opts = q.options();
+
+    auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
+
+    Flash_fwd_params params;
+    set_params_fprop(params,
+                     batch_size,
+                     seqlen_q, seqlen_k,
+                     seqlen_q_rounded, seqlen_k_rounded,
+                     num_heads, num_heads_k,
+                     head_size, head_size_rounded,
+                     q_padded, kcache_padded, vcache_padded, out,
+                     /*cu_seqlens_q_d=*/nullptr,
+                     /*cu_seqlens_k_d=*/nullptr,
+                     /*p_ptr=*/nullptr,
+                     softmax_lse.data_ptr(),
+                     /*p_dropout=*/0.f,
+                     softmax_scale,
+                     is_causal);
+
+    at::Tensor k, v, k_padded, v_padded;
+    if (k_.has_value()) {
+        TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in");
+        TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in");
+        TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache");
+        k = k_.value();
+        v = v_.value();
+        TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
+        TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
+        TORCH_CHECK(k.is_cuda(), "Key tensor must be on CUDA device");
+        TORCH_CHECK(v.is_cuda(), "Value tensor must be on CUDA device");
+        TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
+        TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
+        CHECK_SHAPE(k, batch_size, seqlen_q, num_heads_k, head_size_og);
+        CHECK_SHAPE(v, batch_size, seqlen_q, num_heads_k, head_size_og);
+        if (head_size_og % 8 != 0) {
+            k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+            v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+        } else {
+            k_padded = k;
+            v_padded = v;
+        }
+        params.knew_ptr = k_padded.data_ptr();
+        params.vnew_ptr = v_padded.data_ptr();
+        // All stride are in elements, not bytes.
+        params.knew_batch_stride = k_padded.stride(0);
+        params.vnew_batch_stride = v_padded.stride(0);
+        params.knew_row_stride = k_padded.stride(-3);
+        params.vnew_row_stride = v_padded.stride(-3);
+        params.knew_head_stride = k_padded.stride(-2);
+        params.vnew_head_stride = v_padded.stride(-2);
+    }
+
+    if (seqlens_k_.has_value()) {
+        auto seqlens_k = seqlens_k_.value();
+        TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
+        TORCH_CHECK(seqlens_k.is_cuda(), "seqlens_k must be on CUDA device");
+        TORCH_CHECK(seqlens_k.is_contiguous(), "seqlens_k must be contiguous");
+        CHECK_SHAPE(seqlens_k, batch_size);
+        params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
+    }
+    params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());
+
+    // This needs to match with run_mha_fwd_splitkv_dispatch
+    const int block_n = is_sm90 || is_sm8x
+        ? (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64))
+        : (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64));
+    const int num_n_blocks = (seqlen_k + (params.knew_ptr == nullptr ? 0 : seqlen_q) + block_n - 1) / block_n;
+    // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
+    // In any case we don't expect seqlen_q to be larger than 64 for inference.
+    const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
+    params.num_splits = num_splits;
+    if (num_splits < 1) {
+        params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
+    }
+    if (params.num_splits > 1) {
+        at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
+        at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
+        params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
+        params.oaccum_ptr = out_accum.data_ptr();
+    }
+
+    auto stream = at::cuda::getCurrentCUDAStream().stream();
+    // Only split kernel supports appending to KV cache
+    run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value());
+
+    if (head_size_og % 8 != 0) {
+        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
+        if (out_.has_value()) { out_.value().copy_(out); }
+        if (k_.has_value()) {
+            // It's expensive to copy the KV cache here for the case where head size not divisible by 8,
+            // but we don't expect to get this case in practice. This is just so that the code works for that case.
+            kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
+            vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
+        }
+    }
+
+    return {out, softmax_lse};
+}
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.doc() = "FlashAttention";
     m.def("fwd", &mha_fwd, "Forward pass");
     m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
     m.def("bwd", &mha_bwd, "Backward pass");
     m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
+    m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
 }
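A note on the split-count selection in mha_fwd_kvcache above: block_n is chosen from the head size, the launch's work is counted in key blocks and query blocks, and only when the caller passed num_splits < 1 is num_splits_heuristic (defined elsewhere in flash_api.cpp, not part of this diff) consulted. The Python sketch below mirrors only that shape math; the num_splits_heuristic stub at the bottom is a placeholder assumption, not the real heuristic.

def pick_num_splits(batch_size, num_heads, seqlen_q, seqlen_k, head_size,
                    num_sms, appending_new_kv, requested_num_splits=0):
    # Must line up with run_mha_fwd_splitkv_dispatch's kBlockN choice (Ampere/Hopper path).
    block_n = 256 if head_size <= 64 else (128 if head_size <= 160 else 64)
    # When new K/V are appended, the kernel also scans seqlen_q extra key rows.
    extra = seqlen_q if appending_new_kv else 0
    num_n_blocks = (seqlen_k + extra + block_n - 1) // block_n
    # kBlockM = 64 for the split-KV kernels; seqlen_q is expected to be small at inference.
    num_m_blocks = (seqlen_q + 64 - 1) // 64
    if requested_num_splits >= 1:
        return requested_num_splits  # caller forced a split count
    return num_splits_heuristic(batch_size * num_heads * num_m_blocks,
                                num_sms, num_n_blocks, 128)

def num_splits_heuristic(num_blocks, num_sms, num_n_blocks, max_splits):
    # Placeholder only: split further when the un-split launch would leave most SMs idle.
    if num_blocks >= 0.8 * num_sms:
        return 1
    return max(1, min(max_splits, num_n_blocks, num_sms // max(num_blocks, 1)))

When params.num_splits ends up greater than 1, the softmax_lse_accum and out_accum buffers allocated above hold the per-split partial results that the combine kernel reduces afterwards.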
csrc/flash_attn/src/block_info.h (+7 −2)

@@ -14,9 +14,12 @@ struct BlockInfo {
     template<typename Params>
     __device__ BlockInfo(const Params &params, const int bidb)
         : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
-        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb])
+        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
         , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
-        , actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k)
+        // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+        // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+        , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
+        , actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_q))
         {
         }
@@ -33,6 +36,8 @@ struct BlockInfo {
     const int sum_s_q;
     const int sum_s_k;
     const int actual_seqlen_q;
+    // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
+    const int seqlen_k_cache;
     const int actual_seqlen_k;
 };
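The comment block above is the crux of the BlockInfo change: cu_seqlens_k now plays two roles, selected by is_seqlens_k_cumulative. A small Python sketch of the same resolution, using a dict in place of the Params struct (the key names simply mirror the C++ members; nothing here is a new API):

def resolve_seqlens_k(params, bidb, varlen):
    # seqlen_k_cache: how many K/V rows are already in the cache for this batch element.
    if not varlen or params["cu_seqlens_k"] is None:
        seqlen_k_cache = params["seqlen_k"]
    elif params["is_seqlens_k_cumulative"]:
        # varlen path: cu_seqlens_k holds cumulative offsets.
        seqlen_k_cache = params["cu_seqlens_k"][bidb + 1] - params["cu_seqlens_k"][bidb]
    else:
        # kvcache path: cu_seqlens_k holds per-batch cache lengths directly.
        seqlen_k_cache = params["cu_seqlens_k"][bidb]
    # actual_seqlen_k additionally counts the seqlen_q new rows being appended, if any.
    appended = 0 if params["knew_ptr"] is None else params["seqlen_q"]
    return seqlen_k_cache, seqlen_k_cache + appended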
csrc/flash_attn/src/flash.h (+16 −0)

@@ -80,6 +80,18 @@ struct Flash_fwd_params : public Qkv_params {
     int * __restrict__ blockmask;

+    // The K_new and V_new matrices.
+    void * __restrict__ knew_ptr;
+    void * __restrict__ vnew_ptr;
+
+    // The stride between rows of the Q, K and V matrices.
+    index_t knew_batch_stride;
+    index_t vnew_batch_stride;
+    index_t knew_row_stride;
+    index_t vnew_row_stride;
+    index_t knew_head_stride;
+    index_t vnew_head_stride;
+
     // The dropout probability (probability of keeping an activation).
     float p_dropout;
     // uint32_t p_dropout_in_uint;
@@ -99,6 +111,10 @@ struct Flash_fwd_params : public Qkv_params {
     bool is_bf16;
     bool is_causal;

+    // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+    // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+    bool is_seqlens_k_cumulative;
+
     int num_splits;  // For split-KV version
 };
csrc/flash_attn/src/flash_fwd_kernel.h (+150 −61)

@@ -617,7 +617,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
 inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {

     using Element = typename Kernel_traits::Element;
@@ -635,7 +635,16 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     constexpr int kHeadDim = Kernel_traits::kHeadDim;
     constexpr int kNWarps = Kernel_traits::kNWarps;

+    using GmemTiledCopyO = std::conditional_t<
+        !Split,
+        typename Kernel_traits::GmemTiledCopyO,
+        typename Kernel_traits::GmemTiledCopyOaccum
+    >;
+    using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
+
     const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
+    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); }
+    // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_q = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_q)); }
     if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

     const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
@@ -649,19 +658,21 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // We exit early and write 0 to gOaccum and -inf to gLSEaccum.
         // Otherwise we might read OOB elements from gK and gV,
         // or get wrong results when we combine gOaccum from different blocks.
+        const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+            + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
         const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded;
         const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
-        Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
-                                     Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                                     Stride<Int<kHeadDim>, _1>{});
-        Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lseaccum),
+        Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
+                                     Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                     make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
+        Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
                                        Shape<Int<kBlockM>>{}, Stride<_1>{});
-        typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
+        GmemTiledCopyO gmem_tiled_copy_Oaccum;
         auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
         Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
-        Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
+        Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
         clear(tOrOaccum);
         // Construct identity layout for sO
         Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
@@ -679,7 +690,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         #pragma unroll
         for (int m = 0; m < size<1>(tOgOaccum); ++m) {
             const int row = get<0>(tOcO(0, m, 0));
-            if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = -INFINITY; }
+            if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = Split ? -INFINITY : INFINITY; }
         }
         return;
     }
@@ -695,6 +706,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
     const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
         + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
+    const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+        + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
+    const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+        + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;

     Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -702,15 +717,26 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
                             make_stride(params.k_row_stride, _1{}));
+    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); }
     Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
                             make_stride(params.v_row_stride, _1{}));
+    // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
+    // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
+    // This maps to accessing the first 64 rows of knew_ptr.
+    Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.knew_ptr)
+                                             + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
+                               Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                               make_stride(params.knew_row_stride, _1{}));
+    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
+    Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.vnew_ptr)
+                                             + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
+                               Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                               make_stride(params.vnew_row_stride, _1{}));

     Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                             typename Kernel_traits::SmemLayoutQ{});
-    // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
-    Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
-                            typename Kernel_traits::SmemLayoutKV{});
+    Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{});
     Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
     Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
     Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
@@ -721,8 +747,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
     Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
     Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)
+    Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew);  // (KCPY, KCPY_N, KCPY_K)
     Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
     Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
+    Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew);  // (VCPY, VCPY_N, VCPY_K)
     Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);

     typename Kernel_traits::TiledMma tiled_mma;
@@ -787,32 +815,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
     flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                        binfo.actual_seqlen_q - m_block * kBlockM);
-    if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
-
-    if (Kernel_traits::Share_Q_K_smem) {
-        flash::cp_async_wait<0>();
-        __syncthreads();
-        Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
-        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));            // M
-        cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
-        __syncthreads();
-    }

     int n_block = n_block_max - 1;
     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
-    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
-                                       binfo.actual_seqlen_k - n_block * kBlockN);
+    flash::copy_2_sources</*Is_2_sources=*/Append_KV, Is_even_MN, Is_even_K>(
+        gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV,
+        binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+    );
     cute::cp_async_fence();
+    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
+    // __syncthreads();

-    if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
-        flash::cp_async_wait<1>();
-        __syncthreads();
-        Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
-        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));            // M
-        cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
-    }
-
     // flash::cp_async_wait<0>();
     // __syncthreads();
     // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); }
     // __syncthreads();

     clear(acc_o);
@@ -834,19 +849,37 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         flash::cp_async_wait<0>();
         __syncthreads();

+        if constexpr (Append_KV) {
+            // if (cute::thread0()) { print(tKgK); }
+            // if (cute::thread0()) { print(tKsK); }
+            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
+            if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
+                flash::copy_w_min_idx<Is_even_K>(
+                    tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+                );
+            }
+            // __syncthreads();
+            // if (cute::thread0()) { print(tKgK); }
+            // __syncthreads();
+        }
+
         // Advance gV
         if (masking_step > 0) {
             tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+            if (Append_KV) { tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); }
+            flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
+                gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN
+            );
         } else {
             // Clear the smem tiles to account for predicated off loads
-            flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-                gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
-            );
+            flash::copy_2_sources</*Is_2_sources=*/Append_KV, Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+                gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV,
+                binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+            );
         }
         cute::cp_async_fence();

-        flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
+        flash::gemm(
             acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
             smem_thr_copy_Q, smem_thr_copy_K
         );
@@ -869,19 +902,39 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         flash::cp_async_wait<0>();
         __syncthreads();
         // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); }
         // __syncthreads();

+        // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("n_block = %d, n_block_min = %d\n", n_block, n_block_min); }
+        if constexpr (Append_KV) {
+            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
+            if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
+                flash::copy_w_min_idx<Is_even_K>(
+                    tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+                );
+            }
+        }
+
         if (n_block > n_block_min) {
             // Advance gK
+            // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p\n", tKgKnew.data()); }
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+            if (Append_KV) { tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); }
+            // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p, row_idx_switch = %d\n", tKgKnew.data(), binfo.seqlen_k_cache - (n_block - 1) * kBlockN); }
+            flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
+                gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - (n_block - 1) * kBlockN
+            );
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
         }

-        // TODO: when we have key_padding_mask we'll need to Check_inf
+        // We have key_padding_mask so we'll need to Check_inf
         masking_step == 0
-            ? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
-            : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+            ? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
+            : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
         // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }

         // Convert scores from fp32 to fp16/bf16
         Tensor rP = flash::convert_type<Element>(scores);
@@ -905,22 +958,45 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         clear(acc_s);
         flash::cp_async_wait<0>();
         __syncthreads();

+        if constexpr (Append_KV) {
+            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
+            if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
+                flash::copy_w_min_idx<Is_even_K>(
+                    tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+                );
+            }
+        }
+
         // Advance gV
         tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+        if (Append_KV) { tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); }
+        flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
+            gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN
+        );
         cute::cp_async_fence();

-        flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
+        flash::gemm(
             acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
             smem_thr_copy_Q, smem_thr_copy_K
         );

         flash::cp_async_wait<0>();
         __syncthreads();
+        if constexpr (Append_KV) {
+            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
+            if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
+                flash::copy_w_min_idx<Is_even_K>(
+                    tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+                );
+            }
+        }
         if (n_block > n_block_min) {
             // Advance gK
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+            if (Append_KV) { tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); }
+            flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
+                gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - (n_block - 1) * kBlockN
+            );
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
@@ -942,49 +1018,60 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
     Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
     // if (cute::thread0()) { print(acc_o_rowcol); }
     Tensor lse = make_fragment_like(scores_sum);
     #pragma unroll
     for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
         float sum = scores_sum(mi);
         float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
-        lse(mi) = (sum == 0.f || sum != sum) ? -INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum);
+        lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : scores_max(mi) * params.scale_softmax + __logf(sum);
         float scale = inv_sum;
         #pragma unroll
         for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
     }
     // if (cute::thread0()) { print(lse); }
     // if (cute::thread0()) { print(acc_o_rowcol); }

-    Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementAccum *>(smem_)), typename Kernel_traits::SmemLayoutO{});    // (SMEM_M,SMEM_N)
+    Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(smem_)), typename Kernel_traits::SmemLayoutO{});    // (SMEM_M,SMEM_N)
     // Partition sO to match the accumulator partitioning
-    auto smem_tiled_copy_Oaccum = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomOaccum{}, tiled_mma);
+    using SmemTiledCopyO = std::conditional_t<
+        !Split,
+        typename Kernel_traits::SmemCopyAtomO,
+        typename Kernel_traits::SmemCopyAtomOaccum
+    >;
+    auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma);
     auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
-    Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(acc_o);        // ((Atom,AtomNum), MMA_M, MMA_N)
+    Tensor rO = flash::convert_type<ElementO>(acc_o);
+    Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO);        // ((Atom,AtomNum), MMA_M, MMA_N)
     Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum);     // ((Atom,AtomNum),PIPE_M,PIPE_N)

-    // sO has the same size as sQ, so we don't need to sync here.
-    if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }
+    // sOaccum is larger than sQ, so we need to syncthreads here
+    // TODO: allocate enough smem for sOaccum
+    if constexpr (Split) { __syncthreads(); }

     cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);

+    const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+        + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
     const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded;
     const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;

-    Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
-                                 Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                                 Stride<Int<kHeadDim>, _1>{});
-    Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lseaccum),
+    Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
+                                 Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                 make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
+    Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
                                    Shape<Int<kBlockM>>{}, Stride<_1>{});
     // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); }

-    typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
+    GmemTiledCopyO gmem_tiled_copy_Oaccum;
     auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
     Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum);        // ((Atom,AtomNum),ATOM_M,ATOM_N)
     Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);

     __syncthreads();

-    Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
+    Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
     cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);

     Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
@@ -1014,6 +1101,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
         gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
     );
+    // __syncthreads();
+    // if (cute::thread0()) { print(tOgOaccum); }
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1039,16 +1128,16 @@ inline __device__ void compute_attn(const Params &params) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
 inline __device__ void compute_attn_splitkv(const Params &params) {
     const int m_block = blockIdx.x;
     // The block index for the batch.
-    const int bidb = blockIdx.z / params.h;
+    const int bidb = Split ? blockIdx.z / params.h : blockIdx.y;
     // The block index for the head.
-    const int bidh = blockIdx.z - bidb * params.h;
-    const int n_split_idx = blockIdx.y;
-    const int num_n_splits = gridDim.y;
-    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
+    const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
+    const int n_split_idx = Split ? blockIdx.y : 0;
+    const int num_n_splits = Split ? gridDim.y : 1;
+    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////
csrc/flash_attn/src/flash_fwd_launch_template.h (+41 −31)

@@ -15,9 +15,9 @@ __global__ void flash_fwd_kernel(Flash_fwd_params params) {
     flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params);
 }

-template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K>
+template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
 __global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
-    flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K>(params);
+    flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Split, Append_KV>(params);
 }

 template<typename Kernel_traits, int Log_max_splits, bool Is_even_K>
@@ -63,45 +63,55 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {

 template<typename Kernel_traits>
 void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+    static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs");
+    static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem");
     constexpr size_t smem_size = Kernel_traits::kSmemSize;
     const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
-    dim3 grid(num_m_block, params.num_splits, params.b * params.h);
+    dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
     const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
     // TODO: do we want to guarantee that seqlen_q <= seqlen_k? That would simplify the kernel a bit.
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst, IsEvenKConst>;
-                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
-                if (smem_size >= 48 * 1024) {
-                    C10_CUDA_CHECK(cudaFuncSetAttribute(
-                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-                }
-                kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
-                C10_CUDA_KERNEL_LAUNCH_CHECK();
+                BOOL_SWITCH(params.num_splits > 1, Split, [&] {
+                    BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
+                        // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
+                        // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr);
+                        auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV, IsEvenKConst, Split, Append_KV>;
+                        // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
+                        // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
+                        if (smem_size >= 48 * 1024) {
+                            C10_CUDA_CHECK(cudaFuncSetAttribute(
+                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                        }
+                        kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+                        C10_CUDA_KERNEL_LAUNCH_CHECK();
+                    });
+                });
             });
         });
     });
-    dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16);
-    BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-        if (params.num_splits <= 2) {
-            flash_fwd_splitkv_combine_kernel<Kernel_traits, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
-        } else if (params.num_splits <= 4) {
-            flash_fwd_splitkv_combine_kernel<Kernel_traits, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
-        } else if (params.num_splits <= 8) {
-            flash_fwd_splitkv_combine_kernel<Kernel_traits, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
-        } else if (params.num_splits <= 16) {
-            flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
-        } else if (params.num_splits <= 32) {
-            flash_fwd_splitkv_combine_kernel<Kernel_traits, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
-        } else if (params.num_splits <= 64) {
-            flash_fwd_splitkv_combine_kernel<Kernel_traits, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
-        // } else if (params.num_splits <= 128) {
-        //     flash_fwd_splitkv_combine_kernel<Kernel_traits, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
-        }
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
-    });
+    if (params.num_splits > 1) {
+        dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16);
+        BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
+            if (params.num_splits <= 2) {
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+            } else if (params.num_splits <= 4) {
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+            } else if (params.num_splits <= 8) {
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+            } else if (params.num_splits <= 16) {
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+            } else if (params.num_splits <= 32) {
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+            } else if (params.num_splits <= 64) {
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+            } else if (params.num_splits <= 128) {
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+            }
+            C10_CUDA_KERNEL_LAUNCH_CHECK();
+        });
+    }
 }

 template<typename T, int Headdim>
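One consequence of the new grid line above: with more than one split, gridDim.y carries the split index and gridDim.z carries batch * heads; with a single split, batch moves to gridDim.y and heads to gridDim.z. The compute_attn_splitkv changes in flash_fwd_kernel.h decode the indices accordingly. A small sketch of both sides of that mapping, in plain Python and with no assumptions beyond the expressions in this commit:

def launch_grid(num_m_block, num_splits, batch, heads):
    # Mirrors: dim3 grid(num_m_block, num_splits > 1 ? num_splits : b, num_splits > 1 ? b * h : h)
    if num_splits > 1:
        return (num_m_block, num_splits, batch * heads)
    return (num_m_block, batch, heads)

def decode_block_indices(block_idx, num_splits, heads):
    bx, by, bz = block_idx  # (blockIdx.x, blockIdx.y, blockIdx.z)
    split = num_splits > 1
    bidb = bz // heads if split else by
    bidh = bz - bidb * heads if split else bz
    n_split_idx = by if split else 0
    num_n_splits = num_splits if split else 1
    return bx, bidb, bidh, n_split_idx, num_n_splits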
csrc/flash_attn/src/utils.h (+69 −2)

@@ -291,7 +291,7 @@ template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bo
           typename Engine2, typename Layout2, typename Engine3, typename Layout3>
 inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
                             Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
-                            Tensor<Engine3, Layout3> const &predicate_K, int max_MN=0) {
+                            Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
     CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
     CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
     CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
@@ -355,4 +355,71 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-}  // namespace flash
+template <bool Is_2_sources=false, bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
+          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+inline __device__ void copy_2_sources(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S0, Tensor<Engine0, Layout0> const &S1,
+                                      Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
+                                      Tensor<Engine3, Layout3> const &predicate_K,
+                                      const int max_MN=0, const int row_idx_switch=0) {
+    CUTE_STATIC_ASSERT_V(rank(S0) == Int<3>{} && rank(S1) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(size<0>(S0) == size<0>(D) && size<0>(S1) == size<0>(D));                     // MMA
+    CUTE_STATIC_ASSERT_V(size<1>(S0) == size<1>(D) && size<1>(S1) == size<1>(D));                     // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S0) == size<2>(D) && size<2>(S1) == size<2>(D));                     // MMA_K
+    // There's no case where !Clear_OOB_K && Clear_OOB_MN
+    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
+    // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", Is_2_sources, max_MN, row_idx_switch); }
+    // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", blockIdx.y, Is_2_sources, max_MN, row_idx_switch); }
+    #pragma unroll
+    for (int m = 0; m < size<1>(S0); ++m) {
+        auto &S = !Is_2_sources || get<0>(identity_MN(0, m, 0)) < row_idx_switch ? S0 : S1;
+        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
+            #pragma unroll
+            for (int k = 0; k < size<2>(S0); ++k) {
+                if (Is_even_K || predicate_K(k)) {
+                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
+                } else if (Clear_OOB_K) {
+                    cute::clear(D(_, m, k));
+                }
+            }
+        } else if (Clear_OOB_MN) {
+            cute::clear(D(_, m, _));
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_even_K=true,
+          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
+                                      Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
+                                      Tensor<Engine3, Layout3> const &predicate_K,
+                                      const int max_MN=0, const int min_MN=0) {
+    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
+    // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
+    #pragma unroll
+    for (int m = 0; m < size<1>(S); ++m) {
+        // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
+        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
+            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
+            #pragma unroll
+            for (int k = 0; k < size<2>(S); ++k) {
+                if (Is_even_K || predicate_K(k)) {
+                    cute::copy(S(_, m, k), D(_, m, k));
+                }
+            }
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace flash
\ No newline at end of file
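For reference, the row-selection rule that copy_2_sources implements can be stated in a few lines of NumPy: destination rows whose index is below row_idx_switch come from the first source (the existing KV cache), the remaining in-bounds rows come from the second source (the new K/V), and out-of-bounds rows are skipped or cleared. The sketch below works on plain (rows, cols) arrays rather than per-thread cute tiles, and it assumes S1 has already been shifted so the same row index addresses both sources, exactly as gKnew is offset by seqlen_k_cache in flash_fwd_kernel.h.

import numpy as np

def copy_2_sources_rows(S0, S1, D, max_MN, row_idx_switch,
                        is_2_sources=True, clear_oob_mn=False):
    # Row-level model of flash::copy_2_sources; S0, S1, D are np.ndarray of shape (rows, cols).
    for m in range(D.shape[0]):
        src = S0 if (not is_2_sources or m < row_idx_switch) else S1
        if m < max_MN:
            D[m, :] = src[m, :]
        elif clear_oob_mn:
            D[m, :] = 0
    return D

copy_w_min_idx is the simpler counterpart used for the write-back: it copies only rows with min_MN <= row < max_MN, which is how the newly appended K/V rows are written from shared memory back into the cache.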
flash_attn/__init__.py (+1 −0)

@@ -7,4 +7,5 @@ from flash_attn.flash_attn_interface import (
     flash_attn_varlen_func,
     flash_attn_varlen_kvpacked_func,
     flash_attn_varlen_qkvpacked_func,
+    flash_attn_with_kvcache,
 )
flash_attn/flash_attn_interface.py (+72 −0)

@@ -5,6 +5,7 @@ from einops import rearrange

 # isort: off
 # We need to import the CUDA kernels after importing torch
 import flash_attn_2_cuda as flash_attn_cuda
+
 # isort: on
@@ -790,3 +791,74 @@ def flash_attn_varlen_func(
         causal,
         return_attn_probs,
     )
+
+
+def flash_attn_with_kvcache(
+    q,
+    k_cache,
+    v_cache,
+    k=None,
+    v=None,
+    cache_seqlens=None,
+    softmax_scale=None,
+    causal=False,
+    num_splits=0,
+):
+    """
+    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
+    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
+    the previous step, and update them with the new keys/values from the current step, and do
+    attention with the updated cache, all in 1 kernel.
+
+    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
+    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
+    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
+
+    Does not support backward pass.
+
+    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
+    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
+    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

+    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+        1 1 1 1 0
+        1 1 1 1 1
+    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+        0 0
+        0 0
+        0 0
+        1 0
+        1 1
+    If the row of the mask is all zero, the output will be zero.
+
+    Arguments:
+        q: (batch_size, seqlen, nheads, headdim)
+        k_cache: (batch_size, seqlen_cache, nheads_k, headdim)
+        v_cache: (batch_size, seqlen_cache, nheads_k, headdim)
+        k [optional]: (batch_size, seqlen, nheads_k, headdim). If not None, we concatenate k with
+            k_cache, starting at the indices specified by cache_seqlens.
+        v [optional]: (batch_size, seqlen, nheads_k, headdim). Similar to k.
+        cache_seqlens: (batch_size,), dtype torch.int32. The sequence lengths of the KV cache.
+        softmax_scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
+        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
+            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
+            to automatically determine the number of splits.
+            Don't change this unless you know what you are doing.
+
+    Return:
+        out: (batch_size, seqlen, nheads, headdim).
+    """
+    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
+    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
+    maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
+    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
+    if softmax_scale is None:
+        softmax_scale = q.shape[-1] ** (-0.5)
+    out, softmax_lse = flash_attn_cuda.fwd_kvcache(
+        q, k_cache, v_cache, k, v, cache_seqlens, None, softmax_scale, causal, num_splits
+    )
+    return out
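Taken together with the docstring above, typical use during incremental decoding looks like the sketch below. The shapes, the loop, and the random q/k/v stand in for a real model's projections; they are illustrative assumptions, not part of this commit.

import torch
from flash_attn import flash_attn_with_kvcache

# Pre-allocate the KV cache at the maximum sequence length, then append one
# token per step while attending to everything cached so far.
batch, max_seqlen, nheads, nheads_k, headdim = 2, 2048, 16, 16, 128
device, dtype = "cuda", torch.float16
k_cache = torch.zeros(batch, max_seqlen, nheads_k, headdim, device=device, dtype=dtype)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.zeros(batch, dtype=torch.int32, device=device)

for step in range(16):
    # In a real model these come from the attention layer's Q/K/V projections.
    q = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
    k = torch.randn(batch, 1, nheads_k, headdim, device=device, dtype=dtype)
    v = torch.randn(batch, 1, nheads_k, headdim, device=device, dtype=dtype)
    # Appends k/v into the caches at position cache_seqlens and attends to the
    # updated cache in one kernel; k_cache and v_cache are modified in place.
    out = flash_attn_with_kvcache(
        q, k_cache, v_cache, k, v, cache_seqlens=cache_seqlens, causal=True
    )
    cache_seqlens += 1

num_splits is left at its default of 0 here, so the heuristic in mha_fwd_kvcache picks the split count.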
flash_attn/utils/generation.py (+25 −9)

@@ -348,8 +348,14 @@ def decode_speculative(
     )

     def sample_tokens(
-        input_ids, model, inference_params, sample_fn, num_tokens=1, cg=False, decoding=True, last_token_logits=False
+        input_ids,
+        model,
+        inference_params,
+        sample_fn,
+        num_tokens=1,
+        cg=False,
+        decoding=True,
+        last_token_logits=False,
     ):
         """Sample `num_tokens` tokens from the model, given the previous logits.
         Also return the logits of the sampled tokens.
@@ -374,12 +380,18 @@ def decode_speculative(
         sequences = []
         if decoding:
             assert seqlen == 1
-            position_ids = torch.full(
-                (batch_size, 1),
-                inference_params.sequence_len_offset,
-                dtype=torch.long,
-                device=input_ids.device,
-            )
+            position_ids = repeat(
+                torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
+                + inference_params.sequence_len_offset,
+                "s -> b s",
+                b=batch_size,
+            )
+            # position_ids = torch.full(
+            #     (batch_size, 1),
+            #     inference_params.sequence_len_offset,
+            #     dtype=torch.long,
+            #     device=input_ids.device,
+            # )
         else:
             position_ids = None
         logits = logits_postprocess_fn(
@@ -399,7 +411,11 @@ def decode_speculative(
         )
         logits = logits_postprocess_fn(
             logits_forward_fn(
-                model, rearrange(next_token, "b -> b 1"), position_ids, inference_params, cg=cg
+                model,
+                rearrange(next_token, "b -> b 1"),
+                position_ids,
+                inference_params,
+                cg=cg,
             )
         )
         inference_params.sequence_len_offset += 1
@@ -420,7 +436,7 @@ def decode_speculative(
         sample_fn=sample_fn,
         last_token_logits=True,
         inference_params=inference_params_draft,
-        cg=cg
+        cg=cg,
     )
     if debug:
tests/test_flash_attn.py (+90 −0)

@@ -11,6 +11,7 @@ from flash_attn import (
     flash_attn_varlen_func,
     flash_attn_varlen_kvpacked_func,
     flash_attn_varlen_qkvpacked_func,
+    flash_attn_with_kvcache,
 )
 from flash_attn.bert_padding import pad_input, unpad_input
 from flash_attn.flash_attn_interface import _get_block_size
@@ -1465,6 +1466,95 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
     assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 2e-4
     assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 2e-4


+@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
+# @pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("num_splits", [1, 0])
+# @pytest.mark.parametrize("num_splits", [0])
+@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
+# @pytest.mark.parametrize("mha_type", ["mqa"])
+@pytest.mark.parametrize("new_kv", [False, True])
+# @pytest.mark.parametrize("new_kv", [False])
+@pytest.mark.parametrize("causal", [False, True])
+# @pytest.mark.parametrize("causal", [True])
+@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
+# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
+# @pytest.mark.parametrize('d', [56, 80])
+# @pytest.mark.parametrize("d", [64])
+@pytest.mark.parametrize(
+    "seqlen_q,seqlen_k",
+    [
+        (1, 128),
+        (1, 339),
+        (3, 1024),
+        (64, 800),
+        (64, 256),
+        (3, 799),
+        (64, 2048),
+        (16, 20000),
+        (1, 128 * 1024),
+        (16, 128 * 1024),
+        (128, 128),
+    ],
+)
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
+def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num_splits, dtype):
+    if seqlen_q > seqlen_k and new_kv:
+        pytest.skip()
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 2
+    nheads = 6
+    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
+    assert nheads % nheads_k == 0
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
+    if new_kv:
+        k = torch.randn(batch_size, seqlen_q, nheads_k, d, device=device, dtype=dtype)
+        v = torch.randn(batch_size, seqlen_q, nheads_k, d, device=device, dtype=dtype)
+    else:
+        k, v = None, None
+    k_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+    v_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+    cache_seqlens = torch.randint(
+        0, (seqlen_k - seqlen_q + 1) if new_kv else (seqlen_k + 1), (batch_size,), dtype=torch.int32, device=device
+    )
+    # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
+    # k_cache[:, 64:] = -1
+    k_cache_ref = k_cache.clone()
+    v_cache_ref = v_cache.clone()
+    arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
+    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
+    if new_kv:
+        update_mask = torch.logical_and(
+            cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_q
+        )
+        k_cache_ref[update_mask] = rearrange(k, "b s ... -> (b s) ...")
+        v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
+    k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
+    v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
+    out = flash_attn_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens, causal=causal, num_splits=num_splits)
+    # out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal)
+    # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal)
+    # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
+    # m = qk.amax(-1, keepdim=True)
+    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
+    # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
+    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
+    # probs = torch.softmax(qk, dim=-1)
+    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_q if new_kv else 0)
+    out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal)
+    out_pt, _ = attention_ref(
+        q,
+        k_cache_rep,
+        v_cache_rep,
+        None,
+        key_padding_mask,
+        0.0,
+        None,
+        causal=causal,
+        upcast=False,
+        reorder_ops=True,
+    )
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
+    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
+    # Check that FlashAttention's numerical error is at most twice the numerical error
+    # of a Pytorch implementation.
+    assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5
+    if new_kv:
+        assert torch.equal(k_cache, k_cache_ref)
+        assert torch.equal(v_cache, v_cache_ref)


 # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 @pytest.mark.parametrize("dtype", [torch.float16])