Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5ee5c86e
Unverified
Commit
5ee5c86e
authored
Feb 11, 2026
by
Kebe
Committed by
GitHub
Feb 10, 2026
Browse files
[Bugfix][DeepSeek-V3.2] fix fp8 kvcache type cast (#33884)
Signed-off-by: Kebe <mail@kebe7jun.com>
parent
b5dcb372
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing 1 changed file with 16 additions and 4 deletions
+16
-4
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+16
-4
No files found.
csrc/cache_kernels.cu
View file @
5ee5c86e
...
@@ -1234,8 +1234,13 @@ void cp_gather_and_upconvert_fp8_kv_cache(
...
@@ -1234,8 +1234,13 @@ void cp_gather_and_upconvert_fp8_kv_cache(
"src_cache and seq_lens must be on the same device"
);
"src_cache and seq_lens must be on the same device"
);
TORCH_CHECK
(
src_cache
.
device
()
==
workspace_starts
.
device
(),
TORCH_CHECK
(
src_cache
.
device
()
==
workspace_starts
.
device
(),
"src_cache and workspace_starts must be on the same device"
);
"src_cache and workspace_starts must be on the same device"
);
auto
dtype
=
src_cache
.
scalar_type
();
TORCH_CHECK
(
src_cache
.
dtype
()
==
torch
::
kUInt8
,
"src_cache must be uint8"
);
TORCH_CHECK
(
dtype
==
at
::
ScalarType
::
Byte
||
// uint8
dtype
==
at
::
ScalarType
::
Float8_e4m3fn
||
// fp8 e4m3
dtype
==
at
::
ScalarType
::
Float8_e5m2
,
// fp8 e5m2
"src_cache must be uint8, float8_e4m3fn, or float8_e5m2, but got "
,
src_cache
.
dtype
());
TORCH_CHECK
(
dst
.
dtype
()
==
torch
::
kBFloat16
,
"dst must be bfloat16"
);
TORCH_CHECK
(
dst
.
dtype
()
==
torch
::
kBFloat16
,
"dst must be bfloat16"
);
TORCH_CHECK
(
head_dim
==
576
,
"head_dim must be 576 for MLA"
);
TORCH_CHECK
(
head_dim
==
576
,
"head_dim must be 576 for MLA"
);
...
@@ -1244,14 +1249,21 @@ void cp_gather_and_upconvert_fp8_kv_cache(
...
@@ -1244,14 +1249,21 @@ void cp_gather_and_upconvert_fp8_kv_cache(
int64_t
cache_entry_stride
=
src_cache
.
stride
(
1
);
int64_t
cache_entry_stride
=
src_cache
.
stride
(
1
);
int64_t
dst_entry_stride
=
dst
.
stride
(
0
);
int64_t
dst_entry_stride
=
dst
.
stride
(
0
);
const
uint8_t
*
src_ptr
=
nullptr
;
if
(
dtype
==
at
::
ScalarType
::
Byte
)
{
src_ptr
=
src_cache
.
data_ptr
<
uint8_t
>
();
}
else
{
// float8_e4m3fn or float8_e5m2
src_ptr
=
reinterpret_cast
<
const
uint8_t
*>
(
src_cache
.
data_ptr
());
}
// Decide on the number of splits based on the batch size
// Decide on the number of splits based on the batch size
int
num_splits
=
batch_size
>
128
?
2
:
batch_size
>
64
?
4
:
16
;
int
num_splits
=
batch_size
>
128
?
2
:
batch_size
>
64
?
4
:
16
;
dim3
grid
(
batch_size
,
num_splits
);
dim3
grid
(
batch_size
,
num_splits
);
dim3
block
(
576
);
dim3
block
(
576
);
vllm
::
cp_gather_and_upconvert_fp8_kv_cache
<<<
grid
,
block
,
0
,
stream
>>>
(
vllm
::
cp_gather_and_upconvert_fp8_kv_cache
<<<
grid
,
block
,
0
,
stream
>>>
(
src_cache
.
data_ptr
<
uint8_t
>
(),
src_ptr
,
reinterpret_cast
<
__nv_bfloat16
*>
(
dst
.
data_ptr
()),
reinterpret_cast
<
__nv_bfloat16
*>
(
dst
.
data_ptr
()),
block_table
.
data_ptr
<
int32_t
>
(),
seq_lens
.
data_ptr
<
int32_t
>
(),
block_table
.
data_ptr
<
int32_t
>
(),
seq_lens
.
data_ptr
<
int32_t
>
(),
workspace_starts
.
data_ptr
<
int32_t
>
(),
block_size
,
head_dim
,
workspace_starts
.
data_ptr
<
int32_t
>
(),
block_size
,
head_dim
,
block_table_stride
,
cache_block_stride
,
cache_entry_stride
,
block_table_stride
,
cache_block_stride
,
cache_entry_stride
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment