Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
88f5b19f
Unverified
Commit
88f5b19f
authored
Nov 19, 2025
by
Yongye Zhu
Committed by
GitHub
Nov 19, 2025
Browse files
[DeepSeek] Fix DeepSeek V3.2 Rope Embedding (#28968)
Signed-off-by:
Yongye Zhu
<
zyy1102000@gmail.com
>
parent
613abb50
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
3 deletions
+17
-3
vllm/model_executor/layers/mla.py
vllm/model_executor/layers/mla.py
+5
-1
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+12
-2
No files found.
vllm/model_executor/layers/mla.py
View file @
88f5b19f
...
...
@@ -24,6 +24,7 @@ class MLAModules:
q_b_proj
:
torch
.
nn
.
Module
|
None
q_proj
:
torch
.
nn
.
Module
|
None
indexer
:
torch
.
nn
.
Module
|
None
indexer_rotary_emb
:
torch
.
nn
.
Module
|
None
is_sparse
:
bool
topk_indices_buffer
:
torch
.
Tensor
|
None
...
...
@@ -80,6 +81,7 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
self
.
rotary_emb
=
mla_modules
.
rotary_emb
self
.
o_proj
=
mla_modules
.
o_proj
self
.
indexer
=
mla_modules
.
indexer
self
.
indexer_rope_emb
=
mla_modules
.
indexer_rotary_emb
self
.
is_sparse
=
mla_modules
.
is_sparse
if
self
.
indexer
is
not
None
:
...
...
@@ -153,7 +155,9 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
)
if
self
.
indexer
and
self
.
is_sparse
:
_topk_indices
=
self
.
indexer
(
hidden_states
,
q_c
,
positions
,
self
.
rotary_emb
)
_topk_indices
=
self
.
indexer
(
hidden_states
,
q_c
,
positions
,
self
.
indexer_rope_emb
)
attn_out
=
self
.
mla_attn
(
q
,
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
88f5b19f
...
...
@@ -837,8 +837,8 @@ class Indexer(nn.Module):
)
q_pe
,
k_pe
=
rotary_emb
(
positions
,
q_pe
,
k_pe
.
unsqueeze
(
1
))
q
=
torch
.
cat
([
q_pe
,
q_nope
],
dim
=-
1
)
k
=
torch
.
cat
([
k_pe
.
squeeze
(
1
),
k_nope
],
dim
=-
1
)
q
=
torch
.
cat
([
q_pe
.
squeeze
(
0
)
,
q_nope
],
dim
=-
1
)
k
=
torch
.
cat
([
k_pe
.
squeeze
(
(
0
,
2
)
),
k_nope
],
dim
=-
1
)
# we only quant q here since k quant is fused with cache insertion
q
=
q
.
view
(
-
1
,
self
.
head_dim
)
...
...
@@ -987,6 +987,14 @@ class DeepseekV2MLAAttention(nn.Module):
self
.
is_v32
=
hasattr
(
config
,
"index_topk"
)
if
self
.
is_v32
:
self
.
indexer_rope_emb
=
get_rope
(
qk_rope_head_dim
,
rotary_dim
=
qk_rope_head_dim
,
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
is_neox_style
=
True
,
)
self
.
indexer
=
Indexer
(
vllm_config
,
config
,
...
...
@@ -998,6 +1006,7 @@ class DeepseekV2MLAAttention(nn.Module):
f
"
{
prefix
}
.indexer"
,
)
else
:
self
.
indexer_rope_emb
=
None
self
.
indexer
=
None
mla_modules
=
MLAModules
(
...
...
@@ -1015,6 +1024,7 @@ class DeepseekV2MLAAttention(nn.Module):
q_b_proj
=
self
.
q_b_proj
if
self
.
q_lora_rank
is
not
None
else
None
,
q_proj
=
self
.
q_proj
if
self
.
q_lora_rank
is
None
else
None
,
indexer
=
self
.
indexer
,
indexer_rotary_emb
=
self
.
indexer_rope_emb
,
is_sparse
=
self
.
is_v32
,
topk_indices_buffer
=
topk_indices_buffer
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment