change / sglang · Commits · 38216cf0

Commit 38216cf0 (unverified) · authored Jul 16, 2025 by Albert · committed by GitHub Jul 15, 2025

concurrently load weights of DeepseekV2ForCausalLM (#7943)

Signed-off-by: Tianyu Zhou <albert.zty@antgroup.com>

parent 4a883795
Showing 1 changed file with 148 additions and 127 deletions

python/sglang/srt/models/deepseek_v2.py  (view file @ 38216cf0)  +148  -127
...
@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
 """Inference-only DeepseekV2 model."""
+import concurrent.futures
 import logging
 import os
 from enum import IntEnum, auto
...
@@ -2436,154 +2437,174 @@ class DeepseekV2ForCausalLM(nn.Module):
             assert self.num_fused_shared_experts == 1
             log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")
 
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            futures = []
             params_dict = dict(self.named_parameters())
             weight_names = []
             for name, loaded_weight in weights:
                 if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
                     name = name.replace(
                         "mlp.shared_experts",
                         f"mlp.experts.{self.config.n_routed_experts}",
                     )
                 weight_names.append(name)
 
                 if not is_nextn:
                     if hasattr(self.config, "num_nextn_predict_layers"):
                         num_nextn_layers = self.config.num_nextn_predict_layers
                         if num_nextn_layers > 0 and name.startswith("model.layers"):
                             name_list = name.split(".")
                             if (
                                 len(name_list) >= 3
                                 and int(name_list[2]) >= self.config.num_hidden_layers
                             ):
                                 continue
                 else:
                     if not name.startswith(nextn_layer_prefix):
                         continue
 
                     # Use shared head and embed weights from target model
                     if "shared_head.head" in name or "embed_tokens" in name:
                         continue
 
                     is_decoder = True
                     # For nextn specific weights
                     for weight_name in nextn_spec_weight_names:
                         if weight_name in name:
                             name = name.replace(nextn_layer_prefix, "model")
                             is_decoder = False
                             break
                     # For decoder layer weights
                     if is_decoder:
                         name = name.replace(nextn_layer_prefix, "model.decoder")
 
                 if "rotary_emb.inv_freq" in name:
                     continue
                 for param_name, weight_name, shard_id in stacked_params_mapping:
                     # Skip non-stacked layers and experts (experts handled below).
                     if weight_name not in name:
                         continue
                     # We have mlp.experts[0].gate_proj in the checkpoint.
                     # Since we handle the experts below in expert_params_mapping,
                     # we need to skip here BEFORE we update the name, otherwise
                     # name will be updated to mlp.experts[0].gate_up_proj, which
                     # will then be updated below in expert_params_mapping
                     # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                     if ("mlp.experts." in name) and name not in params_dict:
                         continue
                     name = name.replace(weight_name, param_name)
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
                     param = params_dict[name]
                     weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
+                    futures.append(
+                        executor.submit(weight_loader, param, loaded_weight, shard_id)
+                    )
                     break
                 else:
                     for mapping in expert_params_mapping:
                         param_name, weight_name, expert_id, shard_id = mapping
                         if weight_name not in name:
                             continue
                         name = name.replace(weight_name, param_name)
                         param = params_dict[name]
                         weight_loader = param.weight_loader
-                    weight_loader(
-                        param,
-                        loaded_weight,
-                        name,
-                        shard_id=shard_id,
-                        expert_id=expert_id,
-                    )
+                        futures.append(
+                            executor.submit(
+                                weight_loader,
+                                param,
+                                loaded_weight,
+                                name,
+                                shard_id=shard_id,
+                                expert_id=expert_id,
+                            )
+                        )
                         break
                     else:
                         # Skip loading extra bias for GPTQ models.
                         if name.endswith(".bias") and name not in params_dict:
                             continue
                         if fuse_qkv_a_proj and (
                             "q_a_proj" in name or "kv_a_proj_with_mqa" in name
                         ):
                             cached_a_proj[name] = loaded_weight
                             q_a_proj_name = (
                                 name
                                 if "q_a_proj" in name
                                 else name.replace("kv_a_proj_with_mqa", "q_a_proj")
                             )
                             kv_a_proj_name = (
                                 name
                                 if "kv_a_proj_with_mqa" in name
                                 else name.replace("q_a_proj", "kv_a_proj_with_mqa")
                             )
 
                             # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
                             if (
                                 q_a_proj_name in cached_a_proj
                                 and kv_a_proj_name in cached_a_proj
                             ):
                                 q_a_proj_weight = cached_a_proj[q_a_proj_name]
                                 kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
                                 cat_dim = 0
                                 if self.quant_config is not None and (
                                     self.quant_config.get_name() == "awq"
                                     or self.quant_config.get_name() == "moe_wna16"
                                 ):
                                     cat_dim = 1
                                 fused_weight = torch.cat(
                                     [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
                                 )
                                 param_name = (
                                     name.replace(
                                         "q_a_proj", "fused_qkv_a_proj_with_mqa"
                                     )
                                     if "q_a_proj" in name
                                     else name.replace(
                                         "kv_a_proj_with_mqa",
                                         "fused_qkv_a_proj_with_mqa",
                                     )
                                 )
                                 param = params_dict[param_name]
 
                                 weight_loader = getattr(
                                     param, "weight_loader", default_weight_loader
                                 )
-                            weight_loader(param, fused_weight)
+                                futures.append(
+                                    executor.submit(weight_loader, param, fused_weight)
+                                )
                                 cached_a_proj.pop(q_a_proj_name)
                                 cached_a_proj.pop(kv_a_proj_name)
                         else:
                             if (
                                 "k_scale" in name or "v_scale" in name
                             ) and name not in params_dict:
                                 # modelopt attn kv scale is named differently
                                 for scale in ["k_scale", "v_scale"]:
                                     if scale in name:
                                         name = name.replace(
                                             f"{scale[0]}_proj", "attn_mqa"
                                         )
                                         break
                             if name not in params_dict:
                                 # modelopt ckpt contains not needed weights for MTP module:
                                 # model.decoder.self_attn.attn_mqa.v_scale and
                                 # model.decoder.self_attn.attn_mqa.k_scale
                                 logger.warning(f"{name} not found in params_dict.")
                                 continue
                             param = params_dict[name]
                             weight_loader = getattr(
                                 param, "weight_loader", default_weight_loader
                             )
-                        weight_loader(param, loaded_weight)
+                            futures.append(
+                                executor.submit(weight_loader, param, loaded_weight)
+                            )
+
+            # Wait for all tasks to complete and raise any exceptions.
+            for future in concurrent.futures.as_completed(futures):
+                future.result()
 
         self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
 ...
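
The heart of the change is handing each weight_loader call to a ThreadPoolExecutor and then draining the futures so that any exception raised inside a loader surfaces in the main thread before post_load_weights runs. Below is a minimal, self-contained sketch of that pattern; toy_weight_loader and load_weights_concurrently are illustrative stand-ins, not sglang APIs.

import concurrent.futures

import torch


def toy_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    # Stand-in for a model parameter's weight_loader: copy checkpoint data in place.
    param.copy_(loaded_weight)


def load_weights_concurrently(params_dict, weights):
    # Submit one loader call per weight, then wait for all of them and
    # re-raise the first exception, mirroring the loop added in this commit.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = []
        for name, loaded_weight in weights:
            param = params_dict[name]
            futures.append(executor.submit(toy_weight_loader, param, loaded_weight))
        # Wait for all tasks to complete and raise any exceptions.
        for future in concurrent.futures.as_completed(futures):
            future.result()


if __name__ == "__main__":
    params_dict = {"w1": torch.empty(2, 2), "w2": torch.empty(3)}
    checkpoint = [("w1", torch.ones(2, 2)), ("w2", torch.zeros(3))]
    load_weights_concurrently(params_dict, checkpoint)
    print(params_dict["w1"].sum().item(), params_dict["w2"].sum().item())  # 4.0 0.0

In the real load_weights, the same submit/as_completed structure wraps the existing stacked-projection, expert, and fused q/kv projection branches shown in the diff above.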