change / sglang · Commits

Unverified commit 38216cf0, authored Jul 16, 2025 by Albert, committed by GitHub Jul 15, 2025
concurrently load weights of DeepseekV2ForCausalLM (#7943)
Signed-off-by: Tianyu Zhou <albert.zty@antgroup.com>
Parent: 4a883795
Showing 1 changed file with 148 additions and 127 deletions.

python/sglang/srt/models/deepseek_v2.py (+148, -127)
@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
 """Inference-only DeepseekV2 model."""
+import concurrent.futures
 import logging
 import os
 from enum import IntEnum, auto
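For readers less familiar with the module this hunk pulls in: the rest of the commit leans on the standard submit/as_completed pattern from concurrent.futures, roughly as sketched below. This is a generic illustration, not code from the diff; the load_one function and the fake (name, tensor) pairs are made up for the example.

```python
import concurrent.futures

def load_one(name, tensor):
    # Stand-in for a per-parameter weight_loader call; any exception
    # raised here is captured by the future and re-raised by result().
    return name

pairs = [("layer.0.weight", None), ("layer.1.weight", None)]  # hypothetical

with concurrent.futures.ThreadPoolExecutor() as executor:
    # submit() schedules each call on a worker thread and returns a Future.
    futures = [executor.submit(load_one, name, t) for name, t in pairs]
    # as_completed() yields futures as they finish; result() blocks until
    # each is done and re-raises any worker exception on the main thread.
    for future in concurrent.futures.as_completed(futures):
        future.result()
```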
@@ -2436,154 +2437,174 @@ class DeepseekV2ForCausalLM(nn.Module):
             assert self.num_fused_shared_experts == 1
             log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")

-        params_dict = dict(self.named_parameters())
-        weight_names = []
-        for name, loaded_weight in weights:
-            if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
-                name = name.replace(
-                    "mlp.shared_experts",
-                    f"mlp.experts.{self.config.n_routed_experts}",
-                )
-
-            weight_names.append(name)
-
-            if not is_nextn:
-                if hasattr(self.config, "num_nextn_predict_layers"):
-                    num_nextn_layers = self.config.num_nextn_predict_layers
-                    if num_nextn_layers > 0 and name.startswith("model.layers"):
-                        name_list = name.split(".")
-                        if (
-                            len(name_list) >= 3
-                            and int(name_list[2]) >= self.config.num_hidden_layers
-                        ):
-                            continue
-            else:
-                if not name.startswith(nextn_layer_prefix):
-                    continue
-
-                # Use shared head and embed weights from target model
-                if "shared_head.head" in name or "embed_tokens" in name:
-                    continue
-
-                is_decoder = True
-                # For nextn specific weights
-                for weight_name in nextn_spec_weight_names:
-                    if weight_name in name:
-                        name = name.replace(nextn_layer_prefix, "model")
-                        is_decoder = False
-                        break
-                # For decoder layer weights
-                if is_decoder:
-                    name = name.replace(nextn_layer_prefix, "model.decoder")
-
-            if "rotary_emb.inv_freq" in name:
-                continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                # Skip non-stacked layers and experts (experts handled below).
-                if weight_name not in name:
-                    continue
-                # We have mlp.experts[0].gate_proj in the checkpoint.
-                # Since we handle the experts below in expert_params_mapping,
-                # we need to skip here BEFORE we update the name, otherwise
-                # name will be updated to mlp.experts[0].gate_up_proj, which
-                # will then be updated below in expert_params_mapping
-                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
-                if ("mlp.experts." in name) and name not in params_dict:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
-                        param,
-                        loaded_weight,
-                        name,
-                        shard_id=shard_id,
-                        expert_id=expert_id,
-                    )
-                    break
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-                    if fuse_qkv_a_proj and (
-                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
-                    ):
-                        cached_a_proj[name] = loaded_weight
-                        q_a_proj_name = (
-                            name
-                            if "q_a_proj" in name
-                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
-                        )
-                        kv_a_proj_name = (
-                            name
-                            if "kv_a_proj_with_mqa" in name
-                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
-                        )
-
-                        # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
-                        if (
-                            q_a_proj_name in cached_a_proj
-                            and kv_a_proj_name in cached_a_proj
-                        ):
-                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
-                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
-                            cat_dim = 0
-                            if self.quant_config is not None and (
-                                self.quant_config.get_name() == "awq"
-                                or self.quant_config.get_name() == "moe_wna16"
-                            ):
-                                cat_dim = 1
-                            fused_weight = torch.cat(
-                                [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
-                            )
-                            param_name = (
-                                name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
-                                if "q_a_proj" in name
-                                else name.replace(
-                                    "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
-                                )
-                            )
-                            param = params_dict[param_name]
-
-                            weight_loader = getattr(
-                                param, "weight_loader", default_weight_loader
-                            )
-                            weight_loader(param, fused_weight)
-                            cached_a_proj.pop(q_a_proj_name)
-                            cached_a_proj.pop(kv_a_proj_name)
-                    else:
-                        if (
-                            "k_scale" in name or "v_scale" in name
-                        ) and name not in params_dict:
-                            # modelopt attn kv scale is named differently
-                            for scale in ["k_scale", "v_scale"]:
-                                if scale in name:
-                                    name = name.replace(f"{scale[0]}_proj", "attn_mqa")
-                                    break
-                        if name not in params_dict:
-                            # modelopt ckpt contains not needed weights for MTP module:
-                            # model.decoder.self_attn.attn_mqa.v_scale and
-                            # model.decoder.self_attn.attn_mqa.k_scale
-                            logger.warning(f"{name} not found in params_dict.")
-                            continue
-                        param = params_dict[name]
-                        weight_loader = getattr(
-                            param, "weight_loader", default_weight_loader
-                        )
-                        weight_loader(param, loaded_weight)
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            futures = []
+            params_dict = dict(self.named_parameters())
+            weight_names = []
+            for name, loaded_weight in weights:
+                if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
+                    name = name.replace(
+                        "mlp.shared_experts",
+                        f"mlp.experts.{self.config.n_routed_experts}",
+                    )
+
+                weight_names.append(name)
+
+                if not is_nextn:
+                    if hasattr(self.config, "num_nextn_predict_layers"):
+                        num_nextn_layers = self.config.num_nextn_predict_layers
+                        if num_nextn_layers > 0 and name.startswith("model.layers"):
+                            name_list = name.split(".")
+                            if (
+                                len(name_list) >= 3
+                                and int(name_list[2]) >= self.config.num_hidden_layers
+                            ):
+                                continue
+                else:
+                    if not name.startswith(nextn_layer_prefix):
+                        continue
+
+                    # Use shared head and embed weights from target model
+                    if "shared_head.head" in name or "embed_tokens" in name:
+                        continue
+
+                    is_decoder = True
+                    # For nextn specific weights
+                    for weight_name in nextn_spec_weight_names:
+                        if weight_name in name:
+                            name = name.replace(nextn_layer_prefix, "model")
+                            is_decoder = False
+                            break
+                    # For decoder layer weights
+                    if is_decoder:
+                        name = name.replace(nextn_layer_prefix, "model.decoder")
+
+                if "rotary_emb.inv_freq" in name:
+                    continue
+                for param_name, weight_name, shard_id in stacked_params_mapping:
+                    # Skip non-stacked layers and experts (experts handled below).
+                    if weight_name not in name:
+                        continue
+                    # We have mlp.experts[0].gate_proj in the checkpoint.
+                    # Since we handle the experts below in expert_params_mapping,
+                    # we need to skip here BEFORE we update the name, otherwise
+                    # name will be updated to mlp.experts[0].gate_up_proj, which
+                    # will then be updated below in expert_params_mapping
+                    # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                    if ("mlp.experts." in name) and name not in params_dict:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    futures.append(
+                        executor.submit(weight_loader, param, loaded_weight, shard_id)
+                    )
+                    break
+                else:
+                    for mapping in expert_params_mapping:
+                        param_name, weight_name, expert_id, shard_id = mapping
+                        if weight_name not in name:
+                            continue
+                        name = name.replace(weight_name, param_name)
+                        param = params_dict[name]
+                        weight_loader = param.weight_loader
+                        futures.append(
+                            executor.submit(
+                                weight_loader,
+                                param,
+                                loaded_weight,
+                                name,
+                                shard_id=shard_id,
+                                expert_id=expert_id,
+                            )
+                        )
+                        break
+                    else:
+                        # Skip loading extra bias for GPTQ models.
+                        if name.endswith(".bias") and name not in params_dict:
+                            continue
+                        if fuse_qkv_a_proj and (
+                            "q_a_proj" in name or "kv_a_proj_with_mqa" in name
+                        ):
+                            cached_a_proj[name] = loaded_weight
+                            q_a_proj_name = (
+                                name
+                                if "q_a_proj" in name
+                                else name.replace("kv_a_proj_with_mqa", "q_a_proj")
+                            )
+                            kv_a_proj_name = (
+                                name
+                                if "kv_a_proj_with_mqa" in name
+                                else name.replace("q_a_proj", "kv_a_proj_with_mqa")
+                            )
+
+                            # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
+                            if (
+                                q_a_proj_name in cached_a_proj
+                                and kv_a_proj_name in cached_a_proj
+                            ):
+                                q_a_proj_weight = cached_a_proj[q_a_proj_name]
+                                kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                                cat_dim = 0
+                                if self.quant_config is not None and (
+                                    self.quant_config.get_name() == "awq"
+                                    or self.quant_config.get_name() == "moe_wna16"
+                                ):
+                                    cat_dim = 1
+                                fused_weight = torch.cat(
+                                    [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
+                                )
+                                param_name = (
+                                    name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
+                                    if "q_a_proj" in name
+                                    else name.replace(
+                                        "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
+                                    )
+                                )
+                                param = params_dict[param_name]
+
+                                weight_loader = getattr(
+                                    param, "weight_loader", default_weight_loader
+                                )
+                                futures.append(
+                                    executor.submit(weight_loader, param, fused_weight)
+                                )
+                                cached_a_proj.pop(q_a_proj_name)
+                                cached_a_proj.pop(kv_a_proj_name)
+                        else:
+                            if (
+                                "k_scale" in name or "v_scale" in name
+                            ) and name not in params_dict:
+                                # modelopt attn kv scale is named differently
+                                for scale in ["k_scale", "v_scale"]:
+                                    if scale in name:
+                                        name = name.replace(
+                                            f"{scale[0]}_proj", "attn_mqa"
+                                        )
+                                        break
+                            if name not in params_dict:
+                                # modelopt ckpt contains not needed weights for MTP module:
+                                # model.decoder.self_attn.attn_mqa.v_scale and
+                                # model.decoder.self_attn.attn_mqa.k_scale
+                                logger.warning(f"{name} not found in params_dict.")
+                                continue
+                            param = params_dict[name]
+                            weight_loader = getattr(
+                                param, "weight_loader", default_weight_loader
+                            )
+                            futures.append(
+                                executor.submit(weight_loader, param, loaded_weight)
+                            )
+
+            # Wait for all tasks to complete and raise any exceptions.
+            for future in concurrent.futures.as_completed(futures):
+                future.result()

         self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
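Taken as a whole, the hunk leaves the name-remapping and skip logic untouched; the functional change is that every direct weight_loader(...) call is now submitted to the thread pool, and the as_completed loop at the end waits for the copies and surfaces any worker exception. A plausible reason threads help here is that the per-tensor work (dequantization, host-to-device copies) spends little time holding the GIL, so several loaders can overlap, though the commit does not state this explicitly. Below is a toy before/after of the same pattern with entirely hypothetical names (copy_into, fake_weights); it is a sketch of the idea, not the sglang loader itself.

```python
import concurrent.futures

fake_weights = {"a": [1.0, 2.0], "b": [3.0, 4.0]}      # stand-ins for checkpoint tensors
params = {name: [0.0, 0.0] for name in fake_weights}   # stand-ins for model parameters

def copy_into(dst, src):
    # Stand-in for weight_loader(param, loaded_weight, ...).
    dst[:] = src

# Before this commit: copies run one after another.
for name, w in fake_weights.items():
    copy_into(params[name], w)

# After this commit: copies are submitted to a thread pool, then awaited.
with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = [
        executor.submit(copy_into, params[name], w)
        for name, w in fake_weights.items()
    ]
    for future in concurrent.futures.as_completed(futures):
        future.result()  # re-raises exceptions from worker threads
```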