Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
eefbf4a6
Unverified
Commit
eefbf4a6
authored
Aug 01, 2025
by
Wentao Ye
Committed by
GitHub
Aug 01, 2025
Browse files
[Perf] Optimize `reshape_and_cache_flash` CUDA Kernel (#22036)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
88faa466
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
225 additions
and
23 deletions
+225
-23
benchmarks/kernels/benchmark_reshape_and_cache_flash.py
benchmarks/kernels/benchmark_reshape_and_cache_flash.py
+156
-0
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+69
-23
No files found.
benchmarks/kernels/benchmark_reshape_and_cache_flash.py
0 → 100644
View file @
eefbf4a6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
import
random
import
time
import
torch
from
tabulate
import
tabulate
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
,
create_kv_caches_with_random_flash
,
)
logger
=
init_logger
(
__name__
)
@
torch
.
inference_mode
()
def
run_benchmark
(
num_tokens
:
int
,
num_heads
:
int
,
head_size
:
int
,
block_size
:
int
,
num_blocks
:
int
,
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
str
,
kv_cache_layout
:
str
,
num_iters
:
int
,
device
:
str
=
"cuda"
,
)
->
float
:
"""Return latency (seconds) for given num_tokens."""
if
kv_cache_dtype
==
"fp8"
and
head_size
%
16
:
raise
ValueError
(
"fp8 kv-cache requires head_size to be a multiple of 16."
)
current_platform
.
seed_everything
(
42
)
torch
.
set_default_device
(
device
)
# create random key / value tensors [T, H, D].
key
=
torch
.
randn
(
num_tokens
,
num_heads
,
head_size
,
dtype
=
dtype
,
device
=
device
)
value
=
torch
.
randn_like
(
key
)
# prepare the slot mapping.
# each token is assigned a unique slot in the KV-cache.
num_slots
=
block_size
*
num_blocks
if
num_tokens
>
num_slots
:
raise
ValueError
(
"num_tokens cannot exceed the total number of cache slots"
)
slot_mapping_lst
=
random
.
sample
(
range
(
num_slots
),
num_tokens
)
slot_mapping
=
torch
.
tensor
(
slot_mapping_lst
,
dtype
=
torch
.
long
,
device
=
device
)
key_caches
,
value_caches
=
create_kv_caches_with_random_flash
(
num_blocks
,
block_size
,
1
,
# num_layers
num_heads
,
head_size
,
kv_cache_dtype
,
dtype
,
device
=
device
,
cache_layout
=
kv_cache_layout
,
)
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
# compute per-kernel scaling factors for fp8 conversion (if used).
k_scale
=
(
key
.
amax
()
/
64.0
).
to
(
torch
.
float32
)
v_scale
=
(
value
.
amax
()
/
64.0
).
to
(
torch
.
float32
)
def
run_cuda_benchmark
(
n_iters
:
int
)
->
float
:
nonlocal
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
torch
.
cuda
.
synchronize
()
start
=
time
.
perf_counter
()
for
_
in
range
(
n_iters
):
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
,
)
torch
.
cuda
.
synchronize
()
end
=
time
.
perf_counter
()
return
(
end
-
start
)
/
n_iters
# warm-up
run_cuda_benchmark
(
3
)
lat
=
run_cuda_benchmark
(
num_iters
)
# free tensors to mitigate OOM when sweeping
del
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
torch
.
cuda
.
empty_cache
()
return
lat
def
main
(
args
):
rows
=
[]
for
layout
in
[
"NHD"
,
"HND"
]:
for
exp
in
range
(
1
,
17
):
n_tok
=
2
**
exp
lat
=
run_benchmark
(
num_tokens
=
n_tok
,
num_heads
=
args
.
num_heads
,
head_size
=
args
.
head_size
,
block_size
=
args
.
block_size
,
num_blocks
=
args
.
num_blocks
,
dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
args
.
dtype
],
kv_cache_dtype
=
args
.
kv_cache_dtype
,
kv_cache_layout
=
layout
,
num_iters
=
args
.
iters
,
device
=
"cuda"
,
)
rows
.
append
([
n_tok
,
layout
,
f
"
{
lat
*
1e6
:.
3
f
}
"
])
print
(
tabulate
(
rows
,
headers
=
[
"num_tokens"
,
"layout"
,
"latency (µs)"
]))
if
__name__
==
"__main__"
:
parser
=
FlexibleArgumentParser
()
parser
.
add_argument
(
"--num-heads"
,
type
=
int
,
default
=
128
)
parser
.
add_argument
(
"--head-size"
,
type
=
int
,
choices
=
[
64
,
80
,
96
,
112
,
120
,
128
,
192
,
256
],
default
=
128
,
)
parser
.
add_argument
(
"--block-size"
,
type
=
int
,
choices
=
[
16
,
32
],
default
=
16
)
parser
.
add_argument
(
"--num-blocks"
,
type
=
int
,
default
=
128
*
512
)
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
choices
=
[
"half"
,
"bfloat16"
,
"float"
],
default
=
"bfloat16"
,
)
parser
.
add_argument
(
"--kv-cache-dtype"
,
type
=
str
,
choices
=
[
"auto"
,
"fp8"
],
default
=
"auto"
,
)
parser
.
add_argument
(
"--iters"
,
type
=
int
,
default
=
100
)
args
=
parser
.
parse_args
()
main
(
args
)
csrc/cache_kernels.cu
View file @
eefbf4a6
...
...
@@ -5,6 +5,7 @@
#include "cuda_utils.h"
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/vectorization_utils.cuh"
#ifdef USE_ROCM
#include "quantization/fp8/amd/quant_utils.cuh"
...
...
@@ -261,14 +262,26 @@ __global__ void reshape_and_cache_kernel(
}
}
// Used by vectorization_utils to copy/convert one element
template
<
typename
OutT
,
typename
InT
,
Fp8KVCacheDataType
kv_dt
>
struct
CopyWithScaleOp
{
float
scale
;
__device__
__forceinline__
void
operator
()(
OutT
&
dst
,
const
InT
src
)
const
{
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kAuto
)
{
dst
=
static_cast
<
OutT
>
(
src
);
}
else
{
dst
=
fp8
::
scaled_convert
<
OutT
,
InT
,
kv_dt
>
(
src
,
scale
);
}
}
};
template
<
typename
scalar_t
,
typename
cache_t
,
Fp8KVCacheDataType
kv_dt
>
__global__
void
reshape_and_cache_flash_kernel
(
const
scalar_t
*
__restrict__
key
,
// [num_tokens, num_heads, head_size]
const
scalar_t
*
__restrict__
value
,
// [num_tokens, num_heads, head_size]
cache_t
*
__restrict__
key_cache
,
// [num_blocks, block_size, num_heads,
// head_size]
cache_t
*
__restrict__
value_cache
,
// [num_blocks, block_size, num_heads,
// head_size]
cache_t
*
__restrict__
key_cache
,
// NHD or HND, shape see comments below
cache_t
*
__restrict__
value_cache
,
// same above
const
int64_t
*
__restrict__
slot_mapping
,
// [num_tokens]
const
int64_t
block_stride
,
const
int64_t
page_stride
,
const
int64_t
head_stride
,
const
int64_t
key_stride
,
...
...
@@ -282,25 +295,58 @@ __global__ void reshape_and_cache_flash_kernel(
}
const
int64_t
block_idx
=
slot_idx
/
block_size
;
const
int64_t
block_offset
=
slot_idx
%
block_size
;
const
int
n
=
num_heads
*
head_size
;
for
(
int
i
=
threadIdx
.
x
;
i
<
n
;
i
+=
blockDim
.
x
)
{
const
int64_t
src_key_idx
=
token_idx
*
key_stride
+
i
;
const
int64_t
src_value_idx
=
token_idx
*
value_stride
+
i
;
const
int
head_idx
=
i
/
head_size
;
const
int
head_offset
=
i
%
head_size
;
const
int64_t
tgt_key_value_idx
=
block_idx
*
block_stride
+
block_offset
*
page_stride
+
head_idx
*
head_stride
+
head_offset
;
scalar_t
tgt_key
=
key
[
src_key_idx
];
scalar_t
tgt_value
=
value
[
src_value_idx
];
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kAuto
)
{
key_cache
[
tgt_key_value_idx
]
=
tgt_key
;
value_cache
[
tgt_key_value_idx
]
=
tgt_value
;
}
else
{
key_cache
[
tgt_key_value_idx
]
=
fp8
::
scaled_convert
<
cache_t
,
scalar_t
,
kv_dt
>
(
tgt_key
,
*
k_scale
);
value_cache
[
tgt_key_value_idx
]
=
fp8
::
scaled_convert
<
cache_t
,
scalar_t
,
kv_dt
>
(
tgt_value
,
*
v_scale
);
const
int
n_elems
=
num_heads
*
head_size
;
// pointers to the beginning of the source row for this token.
const
scalar_t
*
__restrict__
key_src
=
key
+
token_idx
*
key_stride
;
const
scalar_t
*
__restrict__
value_src
=
value
+
token_idx
*
value_stride
;
// find the start position inside the kv-cache for this token.
cache_t
*
__restrict__
key_dst
=
key_cache
+
block_idx
*
block_stride
+
block_offset
*
page_stride
;
cache_t
*
__restrict__
value_dst
=
value_cache
+
block_idx
*
block_stride
+
block_offset
*
page_stride
;
// this is true for the NHD layout where `head_stride == head_size`
const
bool
is_contiguous_heads
=
(
head_stride
==
head_size
);
float
k_scale_val
=
(
kv_dt
==
Fp8KVCacheDataType
::
kAuto
)
?
0.
f
:
*
k_scale
;
float
v_scale_val
=
(
kv_dt
==
Fp8KVCacheDataType
::
kAuto
)
?
0.
f
:
*
v_scale
;
constexpr
int
VEC_SIZE
=
(
sizeof
(
scalar_t
)
==
2
)
?
8
:
4
;
CopyWithScaleOp
<
cache_t
,
scalar_t
,
kv_dt
>
k_op
{
k_scale_val
};
CopyWithScaleOp
<
cache_t
,
scalar_t
,
kv_dt
>
v_op
{
v_scale_val
};
if
(
is_contiguous_heads
)
{
// NHD layout
// kv cache: [num_blocks, block_size, num_heads, head_size]
vectorize_with_alignment
<
VEC_SIZE
>
(
key_src
,
key_dst
,
n_elems
,
threadIdx
.
x
,
blockDim
.
x
,
k_op
);
vectorize_with_alignment
<
VEC_SIZE
>
(
value_src
,
value_dst
,
n_elems
,
threadIdx
.
x
,
blockDim
.
x
,
v_op
);
}
else
{
// HND layout: heads are strided, but each head_size segment is contiguous
// kv cache: [num_blocks, num_heads, block_size, head_size]
const
int
lane
=
threadIdx
.
x
&
31
;
// 0..31 within warp
const
int
warp_id
=
threadIdx
.
x
>>
5
;
// warp index within block
const
int
warps_per_block
=
blockDim
.
x
>>
5
;
for
(
int
head
=
warp_id
;
head
<
num_heads
;
head
+=
warps_per_block
)
{
const
scalar_t
*
__restrict__
k_src_h
=
key_src
+
head
*
head_size
;
const
scalar_t
*
__restrict__
v_src_h
=
value_src
+
head
*
head_size
;
cache_t
*
__restrict__
k_dst_h
=
key_dst
+
static_cast
<
int64_t
>
(
head
)
*
head_stride
;
cache_t
*
__restrict__
v_dst_h
=
value_dst
+
static_cast
<
int64_t
>
(
head
)
*
head_stride
;
// within each head, let the 32 threads of the warp perform the vector
// copy
vectorize_with_alignment
<
VEC_SIZE
>
(
k_src_h
,
k_dst_h
,
head_size
,
lane
,
32
,
k_op
);
vectorize_with_alignment
<
VEC_SIZE
>
(
v_src_h
,
v_dst_h
,
head_size
,
lane
,
32
,
v_op
);
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment