gaoqiong / flash-attention · Commits

Commit 56b7fc6e, authored Sep 13, 2023 by Tri Dao

    Simplify the implementation of KVcache attn by appending KV first

Parent: d0032700

Showing 2 changed files with 119 additions and 105 deletions:
    csrc/flash_attn/src/flash_fwd_kernel.h   +51  -85
    tests/test_flash_attn.py                 +68  -20
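In short, the kernel used to stream keys/values from two sources inside the main attention loop (the existing KV cache plus the newly supplied K/V, via flash::copy_2_sources); it now first appends the new K/V to the cache in a prologue and then runs the loop against the cache alone. A minimal PyTorch sketch of that semantics, assuming the shapes used in the test below (k_cache/v_cache of shape (batch, seqlen_k, nheads_k, d), k_new/v_new of shape (batch, seqlen_new, nheads_k, d), cache_seqlens + seqlen_new <= seqlen_k, and nheads == nheads_k); the causal mask and GQA head repetition from the test are omitted and the helper name is hypothetical:

    import torch

    def append_kv_then_attend_reference(q, k_cache, v_cache, k_new, v_new, cache_seqlens):
        """Step 1: write the new keys/values into rows [cache_len, cache_len + seqlen_new) of the cache.
        Step 2: attend over the (now updated) cache only, masking out rows past the valid length."""
        batch, seqlen_k = k_cache.shape[:2]
        seqlen_new = k_new.shape[1]
        arange = torch.arange(seqlen_k, device=k_cache.device).unsqueeze(0)        # (1, seqlen_k)
        cache_lens = cache_seqlens.unsqueeze(1)                                    # (batch, 1)
        update_mask = (cache_lens <= arange) & (arange < cache_lens + seqlen_new)  # rows to overwrite
        k_cache[update_mask] = k_new.reshape(-1, *k_new.shape[2:])
        v_cache[update_mask] = v_new.reshape(-1, *v_new.shape[2:])
        # Attention over the cache only; rows beyond cache_len + seqlen_new are padding.
        key_padding_mask = arange < cache_lens + seqlen_new                        # (batch, seqlen_k)
        scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k_cache.float()) / q.shape[-1] ** 0.5
        scores = scores.masked_fill(~key_padding_mask[:, None, None, :], float("-inf"))
        probs = torch.softmax(scores, dim=-1)
        return torch.einsum("bhqk,bkhd->bqhd", probs, v_cache.float()).to(q.dtype)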
csrc/flash_attn/src/flash_fwd_kernel.h
@@ -657,10 +657,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
                            + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
     const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
                            + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
-    const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
-        + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
-    const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
-        + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
     Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -672,18 +668,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
                             make_stride(params.v_row_stride, _1{}));
-    // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
-    // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
-    // This maps to accessing the first 64 rows of knew_ptr.
-    Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.knew_ptr)
-                                             + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
-                               Shape<Int<kBlockN>, Int<kHeadDim>>{},
-                               make_stride(params.knew_row_stride, _1{}));
-    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
-    Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.vnew_ptr)
-                                             + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
-                               Shape<Int<kBlockN>, Int<kHeadDim>>{},
-                               make_stride(params.vnew_row_stride, _1{}));
     Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                             typename Kernel_traits::SmemLayoutQ{});
@@ -698,10 +682,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
     Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
     Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)
-    Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew);  // (KCPY, KCPY_N, KCPY_K)
     Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
     Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
-    Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew);  // (VCPY, VCPY_N, VCPY_K)
     Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);

     typename Kernel_traits::TiledMma tiled_mma;
@@ -762,6 +744,49 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons

     // Prologue

+    if constexpr (Append_KV) {
+        // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
+        // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
+        // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
+        const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+            + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
+        const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+            + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
+        // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
+        // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
+        // This maps to accessing the first 64 rows of knew_ptr.
+        Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.knew_ptr)
+                                                 + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
+                                   Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                                   make_stride(params.knew_row_stride, _1{}));
+        // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
+        Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.vnew_ptr)
+                                                 + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
+                                   Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                                   make_stride(params.vnew_row_stride, _1{}));
+        Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew);  // (KCPY, KCPY_N, KCPY_K)
+        Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew);  // (VCPY, VCPY_N, VCPY_K)
+        const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
+        for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {
+            flash::copy_w_min_idx<Is_even_K>(
+                tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+            );
+            flash::copy_w_min_idx<Is_even_K>(
+                tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+            );
+            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+            tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
+            tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+            tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
+        }
+        __syncthreads();
+        if (n_block_max > n_block_copy_min) {
+            tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride;
+            tVgV.data() = tVgV.data() + (n_block_max - n_block_copy_min) * kBlockN * params.v_row_stride;
+        }
+    }
+
     Tensor tQrQ = make_fragment_like(tQgQ);
     // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
     flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
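The new prologue above only visits the key blocks that overlap the appended region: n_block_copy_min = max(n_block_min, seqlen_k_cache / kBlockN), and within each visited block flash::copy_w_min_idx writes just the rows between the old cache length and the new total length (for MQA/GQA, every threadblock sharing a KV head writes the same rows, which is the benign race the comment describes). A small Python sketch of that index arithmetic, as an illustration of the loop bounds rather than kernel code (the helper name is made up; kBlockN, seqlen_k_cache and actual_seqlen_k mirror the variables in the diff):

    def appended_row_windows(seqlen_k_cache, actual_seqlen_k, kBlockN, n_block_min=0):
        """Return, per visited key block, the half-open row range [start, end) that the
        prologue overwrites with rows from Knew/Vnew (global row indices into the cache)."""
        n_block_max = (actual_seqlen_k + kBlockN - 1) // kBlockN
        n_block_copy_min = max(n_block_min, seqlen_k_cache // kBlockN)  # first block touching new rows
        windows = {}
        for n_block in range(n_block_max - 1, n_block_copy_min - 1, -1):
            block_start = n_block * kBlockN
            start = max(block_start, seqlen_k_cache)            # skip rows below the old cache length
            end = min(block_start + kBlockN, actual_seqlen_k)   # and rows at/after the new total length
            windows[n_block] = (start, end)
        return windows

    # Example from the comment in the diff: 128 cached rows plus 64 new rows with kBlockN = 128
    # -> only block 1 is visited, and rows [128, 192) are filled from the first 64 rows of Knew.
    print(appended_row_windows(seqlen_k_cache=128, actual_seqlen_k=192, kBlockN=128))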
@@ -769,10 +794,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     int n_block = n_block_max - 1;
     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
-    flash::copy_2_sources</*Is_2_sources=*/Append_KV, Is_even_MN, Is_even_K>(
-        gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV,
-        binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
-    );
+    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+                                       binfo.actual_seqlen_k - n_block * kBlockN);
     cute::cp_async_fence();

     // flash::cp_async_wait<0>();
@@ -800,32 +823,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         flash::cp_async_wait<0>();
         __syncthreads();

-        if constexpr (Append_KV) {
-            // if (cute::thread0()) { print(tKgK); }
-            // if (cute::thread0()) { print(tKsK); }
-            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
-            if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
-                flash::copy_w_min_idx<Is_even_K>(
-                    tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
-                );
-            }
-            // __syncthreads();
-            // if (cute::thread0()) { print(tKgK); }
-            // __syncthreads();
-        }
-
         // Advance gV
         if (masking_step > 0) {
             tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-            if (Append_KV) { tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); }
-            flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
-                gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN
-            );
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
         } else {
             // Clear the smem tiles to account for predicated off loads
-            flash::copy_2_sources</*Is_2_sources=*/Append_KV, Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-                gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV,
-                binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
-            );
+            flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+                gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV,
+                binfo.actual_seqlen_k - n_block * kBlockN
+            );
         }
         cute::cp_async_fence();
@@ -856,26 +861,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); }
         // __syncthreads();

-        // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("n_block = %d, n_block_min = %d\n", n_block, n_block_min); }
-        if constexpr (Append_KV) {
-            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
-            if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
-                flash::copy_w_min_idx<Is_even_K>(
-                    tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
-                );
-            }
-        }
-
         if (n_block > n_block_min) {
             // Advance gK
-            // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p\n", tKgKnew.data()); }
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-            if (Append_KV) { tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); }
-            // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p, row_idx_switch = %d\n", tKgKnew.data(), binfo.seqlen_k_cache - (n_block - 1) * kBlockN); }
-            flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
-                gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - (n_block - 1) * kBlockN
-            );
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
@@ -909,20 +898,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         clear(acc_s);
         flash::cp_async_wait<0>();
         __syncthreads();

-        if constexpr (Append_KV) {
-            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
-            if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
-                flash::copy_w_min_idx<Is_even_K>(
-                    tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
-                );
-            }
-        }
-
         // Advance gV
         tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-        if (Append_KV) { tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); }
-        flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
-            gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN
-        );
+        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
         cute::cp_async_fence();

         flash::gemm(
@@ -932,22 +910,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         flash::cp_async_wait<0>();
         __syncthreads();

-        if constexpr (Append_KV) {
-            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
-            if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
-                flash::copy_w_min_idx<Is_even_K>(
-                    tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
-                );
-            }
-        }
-
         if (n_block > n_block_min) {
             // Advance gK
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-            if (Append_KV) { tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); }
-            flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
-                gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - (n_block - 1) * kBlockN
-            );
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
tests/test_flash_attn.py
@@ -149,8 +149,9 @@ def generate_qkv(
     )


-def construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None,
-                          device=None):
+def construct_causal_mask(
+    seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None
+):
     row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
     col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
     sk = (
@@ -364,12 +365,18 @@ def convert_flash_attn_S_to_softmax(
         causal_mask = construct_causal_mask(
             seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, S.device
         )
-        causal_mask = F.pad(causal_mask, (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q), value=True)
+        causal_mask = F.pad(
+            causal_mask,
+            (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
+            value=True,
+        )
         S_converted.masked_fill_(causal_mask, 0.0)

     # Need to zero out things not in attention_mask in case S was initialized with random values
     # and some of those values aren't overwritten.
-    seqlen_q_og = query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
+    seqlen_q_og = (
+        query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
+    )
     if query_padding_mask is not None:
         query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))
         S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
@@ -623,7 +630,14 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
     out = output_pad_fn(out_unpad)
     if dropout_p > 0.0:
         S_dmask_converted = convert_flash_attn_S_to_softmax(
-            S_dmask, seqlen, seqlen, key_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
+            S_dmask,
+            seqlen,
+            seqlen,
+            key_padding_mask,
+            key_padding_mask,
+            d,
+            dropout_p > 0.0,
+            causal=causal,
         )
         dropout_mask = S_dmask_converted >= 0
         attn_unnorm = S_dmask_converted.abs()
@@ -996,7 +1010,14 @@ def test_flash_attn_varlen_output(
     out = output_pad_fn(out_unpad)
     if dropout_p > 0.0:
         S_dmask_converted = convert_flash_attn_S_to_softmax(
-            S_dmask, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
+            S_dmask,
+            seqlen_q,
+            seqlen_k,
+            query_padding_mask,
+            key_padding_mask,
+            d,
+            dropout_p > 0.0,
+            causal=causal,
         )
         dropout_mask = S_dmask_converted >= 0
         attn_unnorm = S_dmask_converted.abs()
@@ -1466,16 +1487,18 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
     assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 2e-4
     assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 2e-4


 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("num_splits", [1, 0])
 # @pytest.mark.parametrize("num_splits", [0])
 @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
-# @pytest.mark.parametrize("mha_type", ["mqa"])
+# @pytest.mark.parametrize("mha_type", ["mha"])
 @pytest.mark.parametrize("new_kv", [False, True])
-# @pytest.mark.parametrize("new_kv", [False])
+# @pytest.mark.parametrize("new_kv", [True])
 @pytest.mark.parametrize("causal", [False, True])
-# @pytest.mark.parametrize("causal", [True])
+# @pytest.mark.parametrize("causal", [False])
+@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
@@ -1499,7 +1522,9 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
     ],
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
-def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num_splits, dtype):
+def test_flash_attn_kvcache(
+    seqlen_q, seqlen_k, d, seqlen_new_eq_seqlen_q, causal, new_kv, mha_type, num_splits, dtype
+):
     if seqlen_q > seqlen_k and new_kv:
         pytest.skip()
     device = "cuda"
@@ -1510,14 +1535,21 @@ def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num
     nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
     assert nheads % nheads_k == 0
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
+    seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
     if new_kv:
-        k = torch.randn(batch_size, seqlen_q, nheads_k, d, device=device, dtype=dtype)
-        v = torch.randn(batch_size, seqlen_q, nheads_k, d, device=device, dtype=dtype)
+        k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
+        v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
     else:
         k, v = None, None
     k_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
     v_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
-    cache_seqlens = torch.randint(0, (seqlen_k - seqlen_q + 1) if new_kv else (seqlen_k + 1), (batch_size, ), dtype=torch.int32, device=device)
+    cache_seqlens = torch.randint(
+        0,
+        (seqlen_k - seqlen_new + 1) if new_kv else (seqlen_k + 1),
+        (batch_size,),
+        dtype=torch.int32,
+        device=device,
+    )
     # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
     # k_cache[:, 64:] = -1
     k_cache_ref = k_cache.clone()
@@ -1525,12 +1557,16 @@ def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num
     arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
     cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
     if new_kv:
-        update_mask = torch.logical_and(cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_q)
+        update_mask = torch.logical_and(
+            cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
+        )
         k_cache_ref[update_mask] = rearrange(k, "b s ... -> (b s) ...")
         v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
     k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
     v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
-    out = flash_attn_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens, causal=causal, num_splits=num_splits)
+    out = flash_attn_with_kvcache(
+        q, k_cache, v_cache, k, v, cache_seqlens, causal=causal, num_splits=num_splits
+    )
     # out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal)
     # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal)
     # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
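For context, the updated test drives flash_attn_with_kvcache with a new-token chunk whose length seqlen_new may differ from seqlen_q, and it bounds cache_seqlens so the appended tokens always fit in the cache. A usage sketch along the lines of the test's call, assuming a CUDA device and the flash_attn Python package built from this revision; the concrete sizes are illustrative only:

    import torch
    from flash_attn import flash_attn_with_kvcache

    batch, seqlen_q, seqlen_new, seqlen_k = 2, 1, 1, 128
    nheads, nheads_k, d = 8, 2, 64           # nheads must be a multiple of nheads_k (GQA)
    device, dtype = "cuda", torch.float16

    q = torch.randn(batch, seqlen_q, nheads, d, device=device, dtype=dtype)
    k_cache = torch.randn(batch, seqlen_k, nheads_k, d, device=device, dtype=dtype)
    v_cache = torch.randn(batch, seqlen_k, nheads_k, d, device=device, dtype=dtype)
    k = torch.randn(batch, seqlen_new, nheads_k, d, device=device, dtype=dtype)
    v = torch.randn(batch, seqlen_new, nheads_k, d, device=device, dtype=dtype)
    # Leave room for the appended tokens: cache_seqlens + seqlen_new must not exceed seqlen_k.
    cache_seqlens = torch.randint(0, seqlen_k - seqlen_new + 1, (batch,), dtype=torch.int32, device=device)

    # k/v are appended in place into k_cache/v_cache starting at cache_seqlens, then attention
    # runs over the updated cache, which is what the kernel prologue above implements.
    out = flash_attn_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens=cache_seqlens, causal=True)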
@@ -1539,10 +1575,22 @@ def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num
     # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
     # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
     # probs = torch.softmax(qk, dim=-1)
-    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_q if new_kv else 0)
-    out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal)
-    out_pt, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal,
-                              upcast=False, reorder_ops=True)
+    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
+    out_ref, _ = attention_ref(
+        q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal
+    )
+    out_pt, _ = attention_ref(
+        q,
+        k_cache_rep,
+        v_cache_rep,
+        None,
+        key_padding_mask,
+        0.0,
+        None,
+        causal=causal,
+        upcast=False,
+        reorder_ops=True,
+    )
     print(f"Output max diff: {(out - out_ref).abs().max().item()}")
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
@@ -1583,7 +1631,7 @@ def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num
         (1024, 1024),
     ],
 )
-@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
+@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
 # @pytest.mark.parametrize("dropout_p", [0.0])
 def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype):
     device = "cuda"