Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f5ed68ef
Unverified
Commit
f5ed68ef
authored
Oct 15, 2025
by
Yongye Zhu
Committed by
GitHub
Oct 15, 2025
Browse files
[Deepseek-V3.2][Kernel] Integrate cuda indexer k cache gather (#26456)
Signed-off-by:
Yongye Zhu
<
zyy1102000@gmail.com
>
parent
efdef57b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
68 deletions
+6
-68
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+6
-68
No files found.
vllm/model_executor/models/deepseek_v2.py
View file @
f5ed68ef
...
@@ -75,7 +75,7 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -75,7 +75,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from
vllm.model_executor.models.utils
import
sequence_parallel_chunk
from
vllm.model_executor.models.utils
import
sequence_parallel_chunk
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
cdiv
,
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils.deep_gemm
import
fp8_mqa_logits
,
fp8_paged_mqa_logits
from
vllm.utils.deep_gemm
import
fp8_mqa_logits
,
fp8_paged_mqa_logits
from
vllm.v1.attention.backends.mla.indexer
import
(
from
vllm.v1.attention.backends.mla.indexer
import
(
DeepseekV32IndexerBackend
,
DeepseekV32IndexerBackend
,
...
@@ -483,69 +483,6 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
...
@@ -483,69 +483,6 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
return
DeepseekV32IndexerBackend
return
DeepseekV32IndexerBackend
@
torch
.
inference_mode
()
def
cp_gather_indexer_k_quant_cache
(
kv_cache
,
# [num_blocks, block_size, head_dim + 1]
dst_value
,
# [cu_seq_lens[-1], head_dim]
dst_scale
,
# [cu_seq_lens[-1], 4]
block_table
,
# [batch_size, num_blocks]
cu_seq_lens
,
# [batch_size + 1, ]
batch_size
,
):
num_blocks
,
block_size
,
_
=
kv_cache
.
shape
head_dim
=
dst_value
.
shape
[
-
1
]
kv_cache
=
kv_cache
.
view
(
num_blocks
,
-
1
)
expected_value
=
[]
expected_scale
=
[]
for
b
in
range
(
batch_size
):
s
=
cu_seq_lens
[
b
+
1
]
-
cu_seq_lens
[
b
]
if
s
==
0
:
continue
tot
=
cdiv
(
s
,
block_size
)
blocks
=
block_table
[
b
,
:
tot
]
value
=
[]
scale
=
[]
full_block
=
torch
.
arange
(
tot
-
1
,
device
=
kv_cache
.
device
,
dtype
=
torch
.
int32
)
non_remaining_value
=
kv_cache
[
blocks
[
full_block
],
:
block_size
*
head_dim
].
view
(
-
1
,
head_dim
)
non_remaining_scale
=
kv_cache
[
blocks
[
full_block
],
block_size
*
head_dim
:
].
view
(
-
1
,
4
)
remaining
=
s
-
(
tot
-
1
)
*
block_size
value
=
torch
.
cat
(
[
non_remaining_value
,
kv_cache
[
blocks
[
-
1
],
:
remaining
*
head_dim
].
view
(
-
1
,
head_dim
),
],
dim
=
0
,
)
scale
=
torch
.
cat
(
[
non_remaining_scale
,
kv_cache
[
blocks
[
-
1
],
block_size
*
head_dim
:
block_size
*
head_dim
+
remaining
*
4
,
].
view
(
-
1
,
4
),
],
dim
=
0
,
)
expected_value
.
append
(
value
)
expected_scale
.
append
(
scale
)
gather_value
=
torch
.
cat
(
expected_value
,
dim
=
0
).
view
(
-
1
,
head_dim
)
gather_scale
=
torch
.
cat
(
expected_scale
,
dim
=
0
).
view
(
-
1
,
4
)
gather_value
=
gather_value
.
view
(
torch
.
float8_e4m3fn
)
gather_scale
=
gather_scale
.
view
(
torch
.
float32
)
dst_value
.
copy_
(
gather_value
)
dst_scale
.
copy_
(
gather_scale
)
def
sparse_attn_indexer
(
def
sparse_attn_indexer
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
k_cache_prefix
:
str
,
k_cache_prefix
:
str
,
...
@@ -605,19 +542,20 @@ def sparse_attn_indexer(
...
@@ -605,19 +542,20 @@ def sparse_attn_indexer(
dtype
=
torch
.
float8_e4m3fn
,
dtype
=
torch
.
float8_e4m3fn
,
)
)
k_scale
=
torch
.
empty
(
k_scale
=
torch
.
empty
(
[
chunk
.
total_seq_lens
,
1
],
device
=
k
.
device
,
dtype
=
torch
.
float32
[
chunk
.
total_seq_lens
,
4
],
device
=
k
.
device
,
dtype
=
torch
.
uint8
,
)
)
cp_gather_indexer_k_quant_cache
(
ops
.
cp_gather_indexer_k_quant_cache
(
kv_cache
,
kv_cache
,
k_fp8
,
k_fp8
,
k_scale
,
k_scale
,
chunk
.
block_table
,
chunk
.
block_table
,
chunk
.
cu_seq_lens
,
chunk
.
cu_seq_lens
,
chunk
.
num_reqs
,
)
)
logits
=
fp8_mqa_logits
(
logits
=
fp8_mqa_logits
(
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
],
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
],
(
k_fp8
,
k_scale
),
(
k_fp8
,
k_scale
.
view
(
torch
.
float32
)
),
weights
[
chunk
.
token_start
:
chunk
.
token_end
],
weights
[
chunk
.
token_start
:
chunk
.
token_end
],
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
chunk
.
cu_seqlen_ke
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment