gaoqiong / flash-attention · Commits

Commit 37c6e054, authored Sep 04, 2023 by Tri Dao
Parent: 4976650f

Implement flash_attn_with_kvcache

Changes: 10 changed files with 663 additions and 108 deletions (+663, -108)
csrc/flash_attn/flash_api.cpp                      +192   -3
csrc/flash_attn/src/block_info.h                     +7   -2
csrc/flash_attn/src/flash.h                         +16   -0
csrc/flash_attn/src/flash_fwd_kernel.h             +150  -61
csrc/flash_attn/src/flash_fwd_launch_template.h     +41  -31
csrc/flash_attn/src/utils.h                         +69   -2
flash_attn/__init__.py                               +1   -0
flash_attn/flash_attn_interface.py                  +72   -0
flash_attn/utils/generation.py                      +25   -9
tests/test_flash_attn.py                            +90   -0
csrc/flash_attn/flash_api.cpp

@@ -102,6 +102,7 @@ void set_params_fprop(Flash_fwd_params &params,
     TORCH_CHECK(p_dropout < 1.f);

     params.is_causal = is_causal;
+    params.is_seqlens_k_cumulative = true;
 }

 void set_params_dgrad(Flash_bwd_params &params,
@@ -175,10 +176,10 @@ void set_params_dgrad(Flash_bwd_params &params,
     params.dsoftmax_sum = dsoftmax_sum_d;
 }

-void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
     FP16_SWITCH(!params.is_bf16, [&] {
         FWD_HEADDIM_SWITCH(params.d, [&] {
-            if (params.num_splits <= 1) {  // If we don't set it num_splits == 0
+            if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it num_splits == 0
                 run_mha_fwd_<elem_type, kHeadDim>(params, stream);
             } else {
                 run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
@@ -350,7 +351,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
     const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
     params.num_splits = 1;
     if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
-        params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 64);
+        params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
         if (params.num_splits > 1) {
             at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
             at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
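(A note on the hunk above: the cap passed to num_splits_heuristic, i.e. the maximum number of splits it is allowed to pick, goes from 64 to 128. The heuristic itself is defined elsewhere in flash_api.cpp and is not part of this diff. As a rough, purely hypothetical sketch of the idea, not the actual implementation: split the key/value sequence into enough chunks that the grid of thread blocks can occupy most of the GPU's SMs.)

# Hypothetical illustration only; num_splits_heuristic in flash_api.cpp is the real logic.
def pick_num_splits(num_blocks, num_sms, num_n_blocks, max_splits):
    # num_blocks = batch_size * num_heads * num_m_blocks, i.e. the thread blocks
    # available without splitting the KV dimension.
    max_splits = min(max_splits, num_n_blocks)  # cannot split finer than one KV block per split
    for num_splits in range(1, max_splits + 1):
        if num_blocks * num_splits >= num_sms:
            return num_splits
    return max_splits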
@@ -990,10 +991,198 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     return { dq, dk, dv, softmax_d };
 }

(The mha_fwd_kvcache function and the fwd_kvcache binding below are added by this commit.)

std::vector<at::Tensor>
mha_fwd_kvcache(const at::Tensor &q,                 // batch_size x seqlen_q x num_heads x head_size
                const at::Tensor &kcache,            // batch_size x seqlen_k x num_heads_k x head_size
                const at::Tensor &vcache,            // batch_size x seqlen_k x num_heads_k x head_size
                c10::optional<const at::Tensor> &k_, // batch_size x seqlen_q x num_heads_k x head_size
                c10::optional<const at::Tensor> &v_, // batch_size x seqlen_q x num_heads_k x head_size
                c10::optional<const at::Tensor> &seqlens_k_, // batch_size
                c10::optional<at::Tensor> &out_,     // batch_size x seqlen_q x num_heads x head_size
                const float softmax_scale,
                const bool is_causal,
                int num_splits) {

    auto dprops = at::cuda::getCurrentDeviceProperties();
    // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
    bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
    TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
    // We will support Turing in the near future
    // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only support fp16 and bf16 data type");
    if (q_dtype == torch::kBFloat16) {
        TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
    }
    TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");

    TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
    TORCH_CHECK(kcache.is_cuda(), "Input tensor must be on CUDA device");
    TORCH_CHECK(vcache.is_cuda(), "Input tensor must be on CUDA device");

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");

    const auto sizes = q.sizes();

    const int batch_size = sizes[0];
    const int seqlen_q = sizes[1];
    const int num_heads = sizes[2];
    const int head_size_og = sizes[3];
    const int seqlen_k = kcache.size(1);
    const int num_heads_k = kcache.size(2);
    TORCH_CHECK(batch_size > 0, "batch size must be postive");
    TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
    CHECK_SHAPE(kcache, batch_size, seqlen_k, num_heads_k, head_size_og);
    CHECK_SHAPE(vcache, batch_size, seqlen_k, num_heads_k, head_size_og);

    at::Tensor q_padded, kcache_padded, vcache_padded;
    if (head_size_og % 8 != 0) {
        q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
    } else {
        q_padded = q;
        kcache_padded = kcache;
        vcache_padded = vcache;
    }

    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
        TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device");
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
    } else {
        out = torch::empty_like(q_padded);
    }

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size = round_multiple(head_size_og, 8);
    const int head_size_rounded = round_multiple(head_size, 32);
    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto opts = q.options();

    auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));

    Flash_fwd_params params;
    set_params_fprop(params,
                     batch_size,
                     seqlen_q, seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q_padded, kcache_padded, vcache_padded, out,
                     /*cu_seqlens_q_d=*/nullptr,
                     /*cu_seqlens_k_d=*/nullptr,
                     /*p_ptr=*/nullptr,
                     softmax_lse.data_ptr(),
                     /*p_dropout=*/0.f,
                     softmax_scale,
                     is_causal);

    at::Tensor k, v, k_padded, v_padded;
    if (k_.has_value()) {
        TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in");
        TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in");
        TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache");
        k = k_.value();
        v = v_.value();
        TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
        TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
        TORCH_CHECK(k.is_cuda(), "Key tensor must be on CUDA device");
        TORCH_CHECK(v.is_cuda(), "Value tensor must be on CUDA device");
        TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
        TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
        CHECK_SHAPE(k, batch_size, seqlen_q, num_heads_k, head_size_og);
        CHECK_SHAPE(v, batch_size, seqlen_q, num_heads_k, head_size_og);
        if (head_size_og % 8 != 0) {
            k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
            v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        } else {
            k_padded = k;
            v_padded = v;
        }
        params.knew_ptr = k_padded.data_ptr();
        params.vnew_ptr = v_padded.data_ptr();
        // All stride are in elements, not bytes.
        params.knew_batch_stride = k_padded.stride(0);
        params.vnew_batch_stride = v_padded.stride(0);
        params.knew_row_stride = k_padded.stride(-3);
        params.vnew_row_stride = v_padded.stride(-3);
        params.knew_head_stride = k_padded.stride(-2);
        params.vnew_head_stride = v_padded.stride(-2);
    }

    if (seqlens_k_.has_value()) {
        auto seqlens_k = seqlens_k_.value();
        TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
        TORCH_CHECK(seqlens_k.is_cuda(), "seqlens_k must be on CUDA device");
        TORCH_CHECK(seqlens_k.is_contiguous(), "seqlens_k must be contiguous");
        CHECK_SHAPE(seqlens_k, batch_size);
        params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
    }
    params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());

    // This needs to match with run_mha_fwd_splitkv_dispatch
    const int block_n = is_sm90 || is_sm8x
        ? (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64))
        : (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64));
    const int num_n_blocks = (seqlen_k + (params.knew_ptr == nullptr ? 0 : seqlen_q) + block_n - 1) / block_n;
    // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
    // In any case we don't expect seqlen_q to be larger than 64 for inference.
    const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
    params.num_splits = num_splits;
    if (num_splits < 1) {
        params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
    }
    if (params.num_splits > 1) {
        at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
        at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
        params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
        params.oaccum_ptr = out_accum.data_ptr();
    }

    auto stream = at::cuda::getCurrentCUDAStream().stream();
    // Only split kernel supports appending to KV cache
    run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value());

    if (head_size_og % 8 != 0) {
        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        if (out_.has_value()) { out_.value().copy_(out); }
        if (k_.has_value()) {
            // It's expensive to copy the KV cache here for the case where head size not divisible by 8,
            // but we don't expect to get this case in practice. This is just so that the code works for that case.
            kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
            vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
        }
    }

    return {out, softmax_lse};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "FlashAttention";
    m.def("fwd", &mha_fwd, "Forward pass");
    m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
    m.def("bwd", &mha_bwd, "Backward pass");
    m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
    m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
}
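(For readers who want a reference for what the new mha_fwd_kvcache entry point computes, here is a minimal PyTorch sketch of the same semantics, mirroring the reference path used by the new test further down. The function name attn_with_kvcache_ref is ours, not part of the library, and it ignores the MQA/GQA head repetition and head-dimension padding handled by the CUDA code.)

import math
import torch

def attn_with_kvcache_ref(q, k_cache, v_cache, k=None, v=None, cache_seqlens=None, causal=False):
    """Reference semantics only: append k/v into the caches at cache_seqlens, then attend."""
    batch, seqlen_q, nheads, d = q.shape
    seqlen_k = k_cache.shape[1]
    device = q.device
    if cache_seqlens is None:
        cache_seqlens = torch.full((batch,), seqlen_k, dtype=torch.int32, device=device)
    arange = torch.arange(seqlen_k, device=device)
    lens = cache_seqlens.clone()                      # (B,) current length of each cache entry
    if k is not None:
        # In-place append: write k/v at positions [lens, lens + seqlen_q) of the caches.
        update_mask = (lens[:, None] <= arange) & (arange < lens[:, None] + seqlen_q)
        k_cache[update_mask] = k.reshape(-1, nheads, d)
        v_cache[update_mask] = v.reshape(-1, nheads, d)
        lens = lens + seqlen_q
    scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k_cache.float()) / math.sqrt(d)
    # Mask out cache slots beyond each sequence's current length.
    keep = arange[None, :] < lens[:, None]            # (B, Sk)
    if causal:
        # Causal mask aligned to the bottom-right corner of the attention matrix.
        row = torch.arange(seqlen_q, device=device)
        keep = keep[:, None, :] & (arange[None, None, :] <= row[None, :, None] + (lens[:, None, None] - seqlen_q))
        scores = scores.masked_fill(~keep[:, None], float("-inf"))
    else:
        scores = scores.masked_fill(~keep[:, None, None], float("-inf"))
    # Note: for a fully masked row the kernel returns zeros; this naive reference would give NaNs.
    probs = torch.softmax(scores, dim=-1)
    out = torch.einsum("bhqk,bkhd->bqhd", probs, v_cache.float())
    return out.to(q.dtype)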
csrc/flash_attn/src/block_info.h

@@ -14,9 +14,12 @@ struct BlockInfo {
     template<typename Params>
     __device__ BlockInfo(const Params &params, const int bidb)
         : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
-        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb])
+        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
         , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
-        , actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k)
+        // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+        // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+        , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
+        , actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_q))
         {
         }

@@ -33,6 +36,8 @@ struct BlockInfo {
     const int sum_s_q;
     const int sum_s_k;
     const int actual_seqlen_q;
+    // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
+    const int seqlen_k_cache;
     const int actual_seqlen_k;
 };
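(The two interpretations of cu_seqlens_k above are easy to mix up. A short Python sketch of how seqlen_k_cache and actual_seqlen_k are derived for one batch index, mirroring the BlockInfo constructor; the function name and arguments are ours, for illustration only.)

def block_info_k_lengths(bidb, seqlen_k, seqlen_q, cu_seqlens_k=None,
                         is_seqlens_k_cumulative=True, appending_new_kv=False):
    # Mirrors the BlockInfo constructor above for a single batch index `bidb`.
    if cu_seqlens_k is None:
        seqlen_k_cache = seqlen_k                     # fixed-length (non-varlen) path
    elif is_seqlens_k_cumulative:
        # cu_seqlens_k holds cumulative offsets: length = next offset minus this offset
        seqlen_k_cache = cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]
    else:
        # cu_seqlens_k holds per-sequence lengths directly (the KV-cache case)
        seqlen_k_cache = cu_seqlens_k[bidb]
    # When new K/V are being appended (knew_ptr != nullptr), the kernel attends over
    # the cached tokens plus the seqlen_q tokens being appended.
    actual_seqlen_k = seqlen_k_cache + (seqlen_q if appending_new_kv else 0)
    return seqlen_k_cache, actual_seqlen_k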
csrc/flash_attn/src/flash.h

@@ -80,6 +80,18 @@ struct Flash_fwd_params : public Qkv_params {
     int * __restrict__ blockmask;

+    // The K_new and V_new matrices.
+    void * __restrict__ knew_ptr;
+    void * __restrict__ vnew_ptr;
+
+    // The stride between rows of the Q, K and V matrices.
+    index_t knew_batch_stride;
+    index_t vnew_batch_stride;
+    index_t knew_row_stride;
+    index_t vnew_row_stride;
+    index_t knew_head_stride;
+    index_t vnew_head_stride;
+
     // The dropout probability (probability of keeping an activation).
     float p_dropout;
     // uint32_t p_dropout_in_uint;

@@ -99,6 +111,10 @@ struct Flash_fwd_params : public Qkv_params {
     bool is_bf16;
     bool is_causal;

+    // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+    // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+    bool is_seqlens_k_cumulative;
+
     int num_splits;  // For split-KV version
 };
csrc/flash_attn/src/flash_fwd_kernel.h

(This diff is collapsed in the GitLab view; its contents are not shown here.)
csrc/flash_attn/src/flash_fwd_launch_template.h

@@ -15,9 +15,9 @@ __global__ void flash_fwd_kernel(Flash_fwd_params params) {
     flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params);
 }

-template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K>
+template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
 __global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
-    flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K>(params);
+    flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Split, Append_KV>(params);
 }

 template<typename Kernel_traits, int Log_max_splits, bool Is_even_K>

@@ -63,45 +63,55 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename Kernel_traits>
 void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs");
     static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem");
     constexpr size_t smem_size = Kernel_traits::kSmemSize;
     const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
-    dim3 grid(num_m_block, params.num_splits, params.b * params.h);
+    dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
     const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
     // TODO: do we want to guarantee that seqlen_q <= seqlen_k? That would simplify the kernel a bit.
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst, IsEvenKConst>;
-                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
-                if (smem_size >= 48 * 1024) {
-                    C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-                }
-                kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
-                C10_CUDA_KERNEL_LAUNCH_CHECK();
+                BOOL_SWITCH(params.num_splits > 1, Split, [&] {
+                    BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
+                        // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
+                        // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr);
+                        auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV, IsEvenKConst, Split, Append_KV>;
+                        // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
+                        // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
+                        if (smem_size >= 48 * 1024) {
+                            C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                        }
+                        kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+                        C10_CUDA_KERNEL_LAUNCH_CHECK();
+                    });
+                });
             });
         });
     });

(The chain of flash_fwd_splitkv_combine_kernel launches that follows was previously unconditional and stopped at num_splits <= 64, with the num_splits <= 128 branch commented out; it is now wrapped in "if (params.num_splits > 1)" and the 128 branch is enabled:)

     if (params.num_splits > 1) {
         dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16);
         BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
             if (params.num_splits <= 2) {
                 flash_fwd_splitkv_combine_kernel<Kernel_traits, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 4) {
                 flash_fwd_splitkv_combine_kernel<Kernel_traits, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 8) {
                 flash_fwd_splitkv_combine_kernel<Kernel_traits, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 16) {
                 flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 32) {
                 flash_fwd_splitkv_combine_kernel<Kernel_traits, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 64) {
                 flash_fwd_splitkv_combine_kernel<Kernel_traits, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 128) {
                 flash_fwd_splitkv_combine_kernel<Kernel_traits, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             }
             C10_CUDA_KERNEL_LAUNCH_CHECK();
         });
     }
 }

 template<typename T, int Headdim>
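(The combine step, i.e. the flash_fwd_splitkv_combine_kernel selected by Log_max_splits above, merges the per-split partial results that mha_fwd_kvcache allocates as out_accum and softmax_lse_accum. Its CUDA implementation lives in flash_fwd_kernel.h, whose diff is collapsed above; for intuition only, here is a PyTorch sketch of the standard log-sum-exp combine that this step performs, not the kernel's actual code.)

import torch

def combine_splits(out_accum, lse_accum):
    # out_accum: (num_splits, batch, heads, seqlen_q, headdim), partial attention outputs,
    #            each computed over one chunk of the key/value sequence.
    # lse_accum: (num_splits, batch, heads, seqlen_q), log-sum-exp of the scores in that chunk.
    lse = torch.logsumexp(lse_accum, dim=0)               # global log-sum-exp over all chunks
    weights = torch.exp(lse_accum - lse)                   # per-chunk renormalization factors
    out = (weights.unsqueeze(-1) * out_accum).sum(dim=0)   # weighted sum of partial outputs
    return out, lse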
csrc/flash_attn/src/utils.h

@@ -291,7 +291,7 @@ template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bo
           typename Engine2, typename Layout2, typename Engine3, typename Layout3>
 inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
                             Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
-                            Tensor<Engine3, Layout3> const &predicate_K, int max_MN=0) {
+                            Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
     CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
     CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
     CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA

@@ -355,4 +355,71 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const
 ////////////////////////////////////////////////////////////////////////////////////////////////////

(The two device functions below, copy_2_sources and copy_w_min_idx, are added before the closing of namespace flash.)

template <bool Is_2_sources=false, bool Is_even_MN=true, bool Is_even_K=true,
          bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy_2_sources(TiledCopy tiled_copy,
                                      Tensor<Engine0, Layout0> const &S0, Tensor<Engine0, Layout0> const &S1,
                                      Tensor<Engine1, Layout1> &D,
                                      Tensor<Engine2, Layout2> const &identity_MN,
                                      Tensor<Engine3, Layout3> const &predicate_K,
                                      const int max_MN=0, const int row_idx_switch=0) {
    CUTE_STATIC_ASSERT_V(rank(S0) == Int<3>{} && rank(S1) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S0) == size<0>(D) && size<0>(S1) == size<0>(D));  // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S0) == size<1>(D) && size<1>(S1) == size<1>(D));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S0) == size<2>(D) && size<2>(S1) == size<2>(D));  // MMA_K
    // There's no case where !Clear_OOB_K && Clear_OOB_MN
    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
    // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", Is_2_sources, max_MN, row_idx_switch); }
    // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", blockIdx.y, Is_2_sources, max_MN, row_idx_switch); }
    #pragma unroll
    for (int m = 0; m < size<1>(S0); ++m) {
        auto &S = !Is_2_sources || get<0>(identity_MN(0, m, 0)) < row_idx_switch ? S0 : S1;
        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
            #pragma unroll
            for (int k = 0; k < size<2>(S0); ++k) {
                if (Is_even_K || predicate_K(k)) {
                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
                } else if (Clear_OOB_K) {
                    cute::clear(D(_, m, k));
                }
            }
        } else if (Clear_OOB_MN) {
            cute::clear(D(_, m, _));
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_even_K=true,
          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
                                      Tensor<Engine1, Layout1> &D,
                                      Tensor<Engine2, Layout2> const &identity_MN,
                                      Tensor<Engine3, Layout3> const &predicate_K,
                                      const int max_MN=0, const int min_MN=0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
    // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || predicate_K(k)) {
                    cute::copy(S(_, m, k), D(_, m, k));
                }
            }
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace flash
\ No newline at end of file
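(The new copy_2_sources helper lets a tile be gathered from two sources with a per-row switch: rows whose global row index is below row_idx_switch come from S0, the rest from S1, which is presumably how the kernel reads a K/V tile that straddles the boundary between the existing cache and the newly appended keys/values. A much-simplified Python sketch of that row-selection logic, flattening away the (MMA, MMA_M, MMA_K) tile layout and the predicate_K masking; the function name is ours.)

import torch

def copy_2_sources_ref(S0, S1, row_idx_switch, max_rows=None):
    # S0, S1: (rows, cols) tiles covering the same index range; rows with index
    # < row_idx_switch are taken from S0, the remaining rows from S1.
    rows, cols = S0.shape
    D = torch.zeros(rows, cols, dtype=S0.dtype)
    for m in range(rows):
        if max_rows is not None and m >= max_rows:
            continue  # out-of-bounds rows stay cleared (the Clear_OOB_MN case)
        src = S0 if m < row_idx_switch else S1
        D[m] = src[m]
    return D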
flash_attn/__init__.py

@@ -7,4 +7,5 @@ from flash_attn.flash_attn_interface import (
     flash_attn_varlen_func,
     flash_attn_varlen_kvpacked_func,
     flash_attn_varlen_qkvpacked_func,
+    flash_attn_with_kvcache,
 )
flash_attn/flash_attn_interface.py

@@ -5,6 +5,7 @@ from einops import rearrange
 # isort: off
 # We need to import the CUDA kernels after importing torch
 import flash_attn_2_cuda as flash_attn_cuda

 # isort: on

@@ -790,3 +791,74 @@ def flash_attn_varlen_func(
         causal,
         return_attn_probs,
     )

(The flash_attn_with_kvcache function below is added by this commit.)

def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    cache_seqlens=None,
    softmax_scale=None,
    causal=False,
    num_splits=0,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Does not support backward pass.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size, seqlen_cache, nheads_k, headdim)
        v_cache: (batch_size, seqlen_cache, nheads_k, headdim)
        k [optional]: (batch_size, seqlen, nheads_k, headdim). If not None, we concatenate k with
            k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen, nheads_k, headdim). Similar to k.
        cache_seqlens: (batch_size,), dtype torch.int32. The sequence lengths of the KV cache.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
            Don't change this unless you know what you are doing.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out, softmax_lse = flash_attn_cuda.fwd_kvcache(
        q, k_cache, v_cache, k, v, cache_seqlens, None, softmax_scale, causal, num_splits
    )
    return out
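(A sketch of how flash_attn_with_kvcache is typically used for incremental decoding with a pre-allocated cache, as the docstring describes. The shapes and the random q/k/v tensors are placeholders standing in for a model's QKV projection.)

import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, max_seqlen = 2, 16, 128, 2048
device, dtype = "cuda", torch.float16

# Pre-allocate the KV cache once, at the maximum sequence length.
k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device=device, dtype=dtype)
v_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device=device, dtype=dtype)
cache_seqlens = torch.zeros(batch, dtype=torch.int32, device=device)

for step in range(16):
    # q, k, v for the current step would come from the model; random tensors stand in here.
    q = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
    k = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
    v = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
    # Appends k/v into k_cache/v_cache at cache_seqlens (in-place) and attends
    # over the updated cache, all in one kernel.
    out = flash_attn_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens, causal=True)
    cache_seqlens += 1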
flash_attn/utils/generation.py

@@ -348,8 +348,14 @@ def decode_speculative(
     )

     def sample_tokens(
-        input_ids, model, inference_params, sample_fn, num_tokens=1, cg=False, decoding=True, last_token_logits=False
+        input_ids,
+        model,
+        inference_params,
+        sample_fn,
+        num_tokens=1,
+        cg=False,
+        decoding=True,
+        last_token_logits=False,
     ):
         """Sample `num_tokens` tokens from the model, given the previous logits.
         Also return the logits of the sampled tokens.

@@ -374,12 +380,18 @@ def decode_speculative(
         sequences = []
         if decoding:
             assert seqlen == 1
-            position_ids = torch.full(
-                (batch_size, 1),
-                inference_params.sequence_len_offset,
-                dtype=torch.long,
-                device=input_ids.device,
-            )
+            position_ids = repeat(
+                torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
+                + inference_params.sequence_len_offset,
+                "s -> b s",
+                b=batch_size,
+            )
+            # position_ids = torch.full(
+            #     (batch_size, 1),
+            #     inference_params.sequence_len_offset,
+            #     dtype=torch.long,
+            #     device=input_ids.device,
+            # )
         else:
             position_ids = None
         logits = logits_postprocess_fn(

@@ -399,7 +411,11 @@ def decode_speculative(
         )
         logits = logits_postprocess_fn(
             logits_forward_fn(
-                model, rearrange(next_token, "b -> b 1"), position_ids, inference_params, cg=cg
+                model,
+                rearrange(next_token, "b -> b 1"),
+                position_ids,
+                inference_params,
+                cg=cg,
             )
         )
         inference_params.sequence_len_offset += 1

@@ -420,7 +436,7 @@ def decode_speculative(
         sample_fn=sample_fn,
         last_token_logits=True,
         inference_params=inference_params_draft,
-        cg=cg
+        cg=cg,
     )
     if debug:
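(The position_ids change above matters when sample_tokens is called with seqlen > 1, as in speculative decoding: each of the new tokens needs its own position, offset by the number of tokens already in the cache, rather than a single repeated offset. A small illustration with example numbers of our choosing:)

import torch
from einops import repeat

batch_size, seqlen, offset = 2, 3, 10  # 10 tokens already decoded, 3 new tokens this step
position_ids = repeat(
    torch.arange(seqlen, dtype=torch.long) + offset, "s -> b s", b=batch_size
)
# tensor([[10, 11, 12],
#         [10, 11, 12]])
# The old code produced torch.full((batch_size, 1), offset), i.e. [[10], [10]],
# which is only correct when a single token is decoded per step.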
tests/test_flash_attn.py

@@ -11,6 +11,7 @@ from flash_attn import (
     flash_attn_varlen_func,
     flash_attn_varlen_kvpacked_func,
     flash_attn_varlen_qkvpacked_func,
+    flash_attn_with_kvcache,
 )
 from flash_attn.bert_padding import pad_input, unpad_input
 from flash_attn.flash_attn_interface import _get_block_size

@@ -1465,6 +1466,95 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
     assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 2e-4
     assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 2e-4

(The test_flash_attn_kvcache test below is added by this commit.)

@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("num_splits", [1, 0])
# @pytest.mark.parametrize("num_splits", [0])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mqa"])
@pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 128),
        (1, 339),
        (3, 1024),
        (64, 800),
        (64, 256),
        (3, 799),
        (64, 2048),
        (16, 20000),
        (1, 128 * 1024),
        (16, 128 * 1024),
        (128, 128),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num_splits, dtype):
    if seqlen_q > seqlen_k and new_kv:
        pytest.skip()
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    nheads = 6
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
    if new_kv:
        k = torch.randn(batch_size, seqlen_q, nheads_k, d, device=device, dtype=dtype)
        v = torch.randn(batch_size, seqlen_q, nheads_k, d, device=device, dtype=dtype)
    else:
        k, v = None, None
    k_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
    v_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
    cache_seqlens = torch.randint(
        0,
        (seqlen_k - seqlen_q + 1) if new_kv else (seqlen_k + 1),
        (batch_size,),
        dtype=torch.int32,
        device=device,
    )
    # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
    # k_cache[:, 64:] = -1
    k_cache_ref = k_cache.clone()
    v_cache_ref = v_cache.clone()
    arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
    if new_kv:
        update_mask = torch.logical_and(
            cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_q
        )
        k_cache_ref[update_mask] = rearrange(k, "b s ... -> (b s) ...")
        v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
    k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
    v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
    out = flash_attn_with_kvcache(
        q, k_cache, v_cache, k, v, cache_seqlens, causal=causal, num_splits=num_splits
    )
    # out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal)
    # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal)
    # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
    # m = qk.amax(-1, keepdim=True)
    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
    # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
    # probs = torch.softmax(qk, dim=-1)
    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_q if new_kv else 0)
    out_ref, _ = attention_ref(
        q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal
    )
    out_pt, _ = attention_ref(
        q,
        k_cache_rep,
        v_cache_rep,
        None,
        key_padding_mask,
        0.0,
        None,
        causal=causal,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5
    if new_kv:
        assert torch.equal(k_cache, k_cache_ref)
        assert torch.equal(v_cache, v_cache_ref)


# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])