Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8617f867
Unverified
Commit
8617f867
authored
Apr 03, 2026
by
Yongye Zhu
Committed by
GitHub
Apr 03, 2026
Browse files
[Bugfix] Fix DSV32 weight loading (#38870)
Signed-off-by:
Yongye Zhu
<
zyy1102000@gmail.com
>
parent
06fd9ffc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
68 additions
and
27 deletions
+68
-27
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+13
-3
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+55
-24
No files found.
vllm/model_executor/models/deepseek_mtp.py
View file @
8617f867
...
@@ -184,11 +184,16 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
...
@@ -184,11 +184,16 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
vllm_config
.
model_config
.
hf_config
self
.
config
=
vllm_config
.
model_config
.
hf_config
self
.
quant_config
=
vllm_config
.
quant_config
self
.
model
=
DeepSeekMultiTokenPredictor
(
self
.
model
=
DeepSeekMultiTokenPredictor
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
)
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
)
)
)
# Set MoE hyperparameters
# Set MoE hyperparameters
self
.
set_moe_parameters
()
self
.
set_moe_parameters
()
self
.
is_fp4_ckpt
=
(
self
.
quant_config
is
not
None
and
self
.
quant_config
.
get_name
()
==
"modelopt_fp4"
)
def
set_moe_parameters
(
self
):
def
set_moe_parameters
(
self
):
self
.
expert_weights
=
[]
self
.
expert_weights
=
[]
...
@@ -241,11 +246,16 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
...
@@ -241,11 +246,16 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"fused_qkv_a_proj"
,
"q_a_proj"
,
0
),
(
"fused_qkv_a_proj"
,
"q_a_proj"
,
0
),
(
"fused_qkv_a_proj"
,
"kv_a_proj_with_mqa"
,
1
),
(
"fused_qkv_a_proj"
,
"kv_a_proj_with_mqa"
,
1
),
# Fused indexer wk + weights_proj
(
"wk_weights_proj"
,
"wk"
,
0
),
(
"wk_weights_proj"
,
"weights_proj"
,
1
),
]
]
if
self
.
is_fp4_ckpt
:
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
indexer_fused_mapping
=
[
(
"wk_weights_proj"
,
"wk"
,
0
),
(
"wk_weights_proj"
,
"weights_proj"
,
1
),
]
stacked_params_mapping
.
extend
(
indexer_fused_mapping
)
expert_params_mapping
=
SharedFusedMoE
.
make_expert_params_mapping
(
expert_params_mapping
=
SharedFusedMoE
.
make_expert_params_mapping
(
self
,
self
,
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_gate_proj_name
=
"gate_proj"
,
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
8617f867
...
@@ -625,6 +625,11 @@ class Indexer(nn.Module):
...
@@ -625,6 +625,11 @@ class Indexer(nn.Module):
super
().
__init__
()
super
().
__init__
()
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
is_fp4_ckpt
=
(
self
.
quant_config
is
not
None
and
self
.
quant_config
.
get_name
()
==
"modelopt_fp4"
)
# self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"]
# self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"]
self
.
topk_tokens
=
config
.
index_topk
self
.
topk_tokens
=
config
.
index_topk
self
.
n_head
=
config
.
index_n_heads
# 64
self
.
n_head
=
config
.
index_n_heads
# 64
...
@@ -639,18 +644,36 @@ class Indexer(nn.Module):
...
@@ -639,18 +644,36 @@ class Indexer(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.wq_b"
,
prefix
=
f
"
{
prefix
}
.wq_b"
,
)
)
# Fused wk + weights_proj: single GEMM producing [head_dim + n_head].
if
self
.
is_fp4_ckpt
:
# weights_proj does not get quantized, so we run both with quant_config=None
# Fused wk + weights_proj: single GEMM producing [head_dim + n_head].
# wk may be upcasted from the default quant; experiments show fusion is always
# weights_proj does not get quantized,
# faster unless WK proj is in FP4, which is not the case for all known quants.
# so we run both with quant_config=None
self
.
wk_weights_proj
=
MergedColumnParallelLinear
(
# wk may be upcasted from the default quant;
hidden_size
,
# experiments show fusion is always faster unless WK proj is in FP4,
[
self
.
head_dim
,
self
.
n_head
],
# which is not the case for all known quants.
bias
=
False
,
self
.
wk_weights_proj
=
MergedColumnParallelLinear
(
quant_config
=
None
,
hidden_size
,
disable_tp
=
True
,
[
self
.
head_dim
,
self
.
n_head
],
prefix
=
f
"
{
prefix
}
.wk_weights_proj"
,
bias
=
False
,
)
quant_config
=
None
,
disable_tp
=
True
,
prefix
=
f
"
{
prefix
}
.wk_weights_proj"
,
)
else
:
self
.
wk
=
ReplicatedLinear
(
hidden_size
,
self
.
head_dim
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.wk"
,
)
self
.
weights_proj
=
ReplicatedLinear
(
hidden_size
,
self
.
n_head
,
bias
=
False
,
quant_config
=
None
,
prefix
=
f
"
{
prefix
}
.weights_proj"
,
)
self
.
k_norm
=
LayerNorm
(
self
.
head_dim
,
eps
=
1e-6
)
self
.
k_norm
=
LayerNorm
(
self
.
head_dim
,
eps
=
1e-6
)
self
.
softmax_scale
=
self
.
head_dim
**-
0.5
self
.
softmax_scale
=
self
.
head_dim
**-
0.5
...
@@ -691,11 +714,14 @@ class Indexer(nn.Module):
...
@@ -691,11 +714,14 @@ class Indexer(nn.Module):
q_pe
,
q_nope
=
torch
.
split
(
q_pe
,
q_nope
=
torch
.
split
(
q
,
[
self
.
rope_dim
,
self
.
head_dim
-
self
.
rope_dim
],
dim
=-
1
q
,
[
self
.
rope_dim
,
self
.
head_dim
-
self
.
rope_dim
],
dim
=-
1
)
)
if
self
.
is_fp4_ckpt
:
# Fused wk + weights_proj: one GEMM, then split
# Fused wk + weights_proj: one GEMM, then split
kw
,
_
=
self
.
wk_weights_proj
(
hidden_states
)
kw
,
_
=
self
.
wk_weights_proj
(
hidden_states
)
k
=
kw
[:,
:
self
.
head_dim
]
k
=
kw
[:,
:
self
.
head_dim
]
weights_raw
=
kw
[:,
self
.
head_dim
:]
weights
=
kw
[:,
self
.
head_dim
:]
else
:
k
,
_
=
self
.
wk
(
hidden_states
)
weights
,
_
=
self
.
weights_proj
(
hidden_states
)
k
=
self
.
k_norm
(
k
)
k
=
self
.
k_norm
(
k
)
k_pe
,
k_nope
=
torch
.
split
(
k_pe
,
k_nope
=
torch
.
split
(
...
@@ -726,7 +752,7 @@ class Indexer(nn.Module):
...
@@ -726,7 +752,7 @@ class Indexer(nn.Module):
q_scale
=
q_scale
.
view
(
-
1
,
self
.
n_head
,
1
)
q_scale
=
q_scale
.
view
(
-
1
,
self
.
n_head
,
1
)
weights
=
(
weights
=
(
weights
_raw
.
unsqueeze
(
-
1
)
*
q_scale
*
self
.
softmax_scale
*
self
.
n_head
**-
0.5
weights
.
unsqueeze
(
-
1
)
*
q_scale
*
self
.
softmax_scale
*
self
.
n_head
**-
0.5
)
)
weights
=
weights
.
squeeze
(
-
1
)
weights
=
weights
.
squeeze
(
-
1
)
...
@@ -1314,6 +1340,10 @@ class DeepseekV2ForCausalLM(
...
@@ -1314,6 +1340,10 @@ class DeepseekV2ForCausalLM(
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
is_fp4_ckpt
=
(
self
.
quant_config
is
not
None
and
self
.
quant_config
.
get_name
()
==
"modelopt_fp4"
)
qk_nope_head_dim
=
getattr
(
config
,
"qk_nope_head_dim"
,
0
)
qk_nope_head_dim
=
getattr
(
config
,
"qk_nope_head_dim"
,
0
)
qk_rope_head_dim
=
getattr
(
config
,
"qk_rope_head_dim"
,
0
)
qk_rope_head_dim
=
getattr
(
config
,
"qk_rope_head_dim"
,
0
)
...
@@ -1439,12 +1469,13 @@ class DeepseekV2ForCausalLM(
...
@@ -1439,12 +1469,13 @@ class DeepseekV2ForCausalLM(
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
]
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
if
self
.
is_fp4_ckpt
:
indexer_fused_mapping
=
[
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
(
"wk_weights_proj"
,
"wk"
,
0
),
indexer_fused_mapping
=
[
(
"wk_weights_proj"
,
"weights_proj"
,
1
),
(
"wk_weights_proj"
,
"wk"
,
0
),
]
(
"wk_weights_proj"
,
"weights_proj"
,
1
),
stacked_params_mapping
.
extend
(
indexer_fused_mapping
)
]
stacked_params_mapping
.
extend
(
indexer_fused_mapping
)
if
self
.
use_mha
:
if
self
.
use_mha
:
stacked_params_mapping
.
extend
(
mha_params_mapping
)
stacked_params_mapping
.
extend
(
mha_params_mapping
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment