sglang · Commits · a14654dd
Unverified commit a14654dd, authored Apr 24, 2025 by Baizhou Zhang; committed by GitHub on Apr 24, 2025.

Fix weight loading bug for Deepseek v3+nextn (#5684)

parent 5d93a950
Showing 1 changed file with 51 additions and 5 deletions.

python/sglang/srt/models/deepseek_nextn.py (+51, -5)
...
@@ -242,6 +242,12 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
                 num_experts=self.config.n_routed_experts
                 + self.n_share_experts_fusion,
             )
+        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
+        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
+            self.config.q_lora_rank is not None
+        )
+        cached_a_proj = {} if fuse_qkv_a_proj else None
+
         nextn_layer_prefix = "model.layers.0"
         nextn_spec_weight_names = [
             "shared_head.norm",
...
@@ -313,6 +319,46 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Handle fused_qkv_a_proj
+                if fuse_qkv_a_proj and (
+                    "q_a_proj" in name or "kv_a_proj_with_mqa" in name
+                ):
+                    cached_a_proj[name] = loaded_weight
+                    q_a_proj_name = (
+                        name
+                        if "q_a_proj" in name
+                        else name.replace("kv_a_proj_with_mqa", "q_a_proj")
+                    )
+                    kv_a_proj_name = (
+                        name
+                        if "kv_a_proj_with_mqa" in name
+                        else name.replace("q_a_proj", "kv_a_proj_with_mqa")
+                    )
+                    # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
+                    if (
+                        q_a_proj_name in cached_a_proj
+                        and kv_a_proj_name in cached_a_proj
+                    ):
+                        q_a_proj_weight = cached_a_proj[q_a_proj_name]
+                        kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                        fused_weight = torch.cat(
+                            [q_a_proj_weight, kv_a_proj_weight], dim=0
+                        )
+                        param_name = name.replace(
+                            "q_a_proj", "fused_qkv_a_proj_with_mqa"
+                        )
+                        param = params_dict[param_name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, fused_weight)
+                        cached_a_proj.pop(q_a_proj_name)
+                        cached_a_proj.pop(kv_a_proj_name)
+                else:
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
...
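For context: as the diff shows, the checkpoint stores q_a_proj and kv_a_proj_with_mqa as separate tensors, while the nextn module exposes a single fused_qkv_a_proj_with_mqa parameter, so the loader caches the two halves by name and concatenates them along the output dimension before handing the result to the parameter's weight loader. Below is a minimal, self-contained sketch of that fusion step, not the sglang implementation; the layer name, the shapes (q_lora_rank 1536, kv_lora_rank 512 plus rope head dim 64, hidden size 7168), and the plain copy_ standing in for a sharded weight_loader are illustrative assumptions.

# Minimal sketch of the fusion step above (assumed names and shapes, not the sglang code):
# checkpoint tensors for q_a_proj and kv_a_proj_with_mqa are cached by name and, once both
# halves of a layer are present, concatenated along dim 0 into fused_qkv_a_proj_with_mqa.
import torch

hidden_size = 7168      # illustrative DeepSeek-V3-like width
q_lora_rank = 1536      # illustrative q_a_proj output rows
kv_a_rows = 512 + 64    # illustrative kv_a_proj_with_mqa output rows

checkpoint = {
    "model.layers.0.self_attn.q_a_proj.weight": torch.randn(q_lora_rank, hidden_size),
    "model.layers.0.self_attn.kv_a_proj_with_mqa.weight": torch.randn(kv_a_rows, hidden_size),
}
params_dict = {
    "model.layers.0.self_attn.fused_qkv_a_proj_with_mqa.weight": torch.empty(
        q_lora_rank + kv_a_rows, hidden_size
    ),
}

cached_a_proj = {}
for name, loaded_weight in checkpoint.items():
    if "q_a_proj" in name or "kv_a_proj_with_mqa" in name:
        cached_a_proj[name] = loaded_weight
        q_name = name if "q_a_proj" in name else name.replace("kv_a_proj_with_mqa", "q_a_proj")
        kv_name = name if "kv_a_proj_with_mqa" in name else name.replace("q_a_proj", "kv_a_proj_with_mqa")
        if q_name in cached_a_proj and kv_name in cached_a_proj:
            # q rows first, then kv rows: the fused layout the model expects
            fused = torch.cat([cached_a_proj[q_name], cached_a_proj[kv_name]], dim=0)
            fused_name = q_name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
            params_dict[fused_name].copy_(fused)  # stand-in for the param's weight_loader
            cached_a_proj.pop(q_name)
            cached_a_proj.pop(kv_name)

assert params_dict["model.layers.0.self_attn.fused_qkv_a_proj_with_mqa.weight"].shape == (
    q_lora_rank + kv_a_rows,
    hidden_size,
)

The added else branch in the diff keeps the original path (look up params_dict[name] directly and call its weight loader) for every weight that is not part of the fused pair.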