Unverified commit 38216cf0, authored Jul 16, 2025 by Albert, committed by GitHub Jul 15, 2025

concurrently load weights of DeepseekV2ForCausalLM (#7943)

Signed-off-by: Tianyu Zhou <albert.zty@antgroup.com>

parent 4a883795
Changes: 1 changed file with 148 additions and 127 deletions

python/sglang/srt/models/deepseek_v2.py (+148, -127)
@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
 """Inference-only DeepseekV2 model."""

+import concurrent.futures
 import logging
 import os
 from enum import IntEnum, auto
@@ -2436,6 +2437,8 @@ class DeepseekV2ForCausalLM(nn.Module):
             assert self.num_fused_shared_experts == 1
             log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")

+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            futures = []
         params_dict = dict(self.named_parameters())
         weight_names = []
         for name, loaded_weight in weights:
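The hunks below rewrite each direct `weight_loader(...)` call as `executor.submit(...)`, collecting the returned futures so failures can be surfaced after the loop. Note that the whole loop body moves under the new `with` block; the @@ counts in a few hunks exceed the visible changes, which appears to be whitespace-only re-indentation hidden by the diff view's whitespace toggle, and is why the surrounding context lines look unchanged. A minimal standalone sketch of the pattern, with a hypothetical `load_one` standing in for the real `param.weight_loader` callables:

```python
import concurrent.futures

def load_one(name: str, tensor) -> None:
    # Stand-in for param.weight_loader(param, loaded_weight, ...);
    # real loaders are mostly tensor copies, which release the GIL,
    # so worker threads can genuinely overlap.
    print(f"loaded {name}")

weights = [("a.weight", None), ("b.weight", None)]

with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = [executor.submit(load_one, name, w) for name, w in weights]
    # Block until every submitted load has finished, re-raising worker errors.
    for future in concurrent.futures.as_completed(futures):
        future.result()
```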
@@ -2496,7 +2499,9 @@ class DeepseekV2ForCausalLM(nn.Module):
                         continue
                     param = params_dict[name]
                     weight_loader = param.weight_loader
-                    weight_loader(param, loaded_weight, shard_id)
+                    futures.append(
+                        executor.submit(weight_loader, param, loaded_weight, shard_id)
+                    )
                     break
                 else:
                     for mapping in expert_params_mapping:
@@ -2506,13 +2511,16 @@ class DeepseekV2ForCausalLM(nn.Module):
                         name = name.replace(weight_name, param_name)
                         param = params_dict[name]
                         weight_loader = param.weight_loader
-                        weight_loader(
-                            param,
-                            loaded_weight,
-                            name,
-                            shard_id=shard_id,
-                            expert_id=expert_id,
-                        )
+                        futures.append(
+                            executor.submit(
+                                weight_loader,
+                                param,
+                                loaded_weight,
+                                name,
+                                shard_id=shard_id,
+                                expert_id=expert_id,
+                            )
+                        )
                         break
                     else:
                         # Skip loading extra bias for GPTQ models.
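This works because `Executor.submit(fn, *args, **kwargs)` forwards keyword arguments to the callable, so `shard_id=` and `expert_id=` reach the expert weight loader exactly as in the old direct call. A toy illustration with dummy values:

```python
from concurrent.futures import ThreadPoolExecutor

def loader(param, weight, name, shard_id=None, expert_id=None):
    # Dummy loader that just echoes what it was called with.
    return (name, shard_id, expert_id)

with ThreadPoolExecutor() as executor:
    fut = executor.submit(loader, "param", "weight", "w13_weight",
                          shard_id="w1", expert_id=3)
    assert fut.result() == ("w13_weight", "w1", 3)
```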
@@ -2550,10 +2558,13 @@
                             [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
                         )
                         param_name = (
                             name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
                             if "q_a_proj" in name
                             else name.replace(
-                                "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
+                                "kv_a_proj_with_mqa",
+                                "fused_qkv_a_proj_with_mqa",
                             )
                         )
                         param = params_dict[param_name]
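For context, this branch fuses the cached `q_a_proj` and `kv_a_proj_with_mqa` weights into a single `fused_qkv_a_proj_with_mqa` tensor before handing it to the loader. A minimal sketch of the concatenation step, using made-up shapes and assuming `cat_dim=0` (the real values come from the model config and quantization layout):

```python
import torch

# Made-up shapes; the real dimensions come from the DeepSeek config
# (q_lora_rank, kv_lora_rank + qk_rope_head_dim, hidden_size).
q_a_proj_weight = torch.randn(1536, 7168)
kv_a_proj_weight = torch.randn(576, 7168)

cat_dim = 0  # assumed: stack the output rows of the two projections
fused_weight = torch.cat([q_a_proj_weight, kv_a_proj_weight], dim=cat_dim)
assert fused_weight.shape == (1536 + 576, 7168)
```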
@@ -2561,7 +2572,9 @@
                         weight_loader = getattr(
                             param, "weight_loader", default_weight_loader
                         )
-                        weight_loader(param, fused_weight)
+                        futures.append(
+                            executor.submit(weight_loader, param, fused_weight)
+                        )
                         cached_a_proj.pop(q_a_proj_name)
                         cached_a_proj.pop(kv_a_proj_name)
                     else:
@@ -2571,7 +2584,9 @@
                 # modelopt attn kv scale is named differently
                 for scale in ["k_scale", "v_scale"]:
                     if scale in name:
                         name = name.replace(f"{scale[0]}_proj", "attn_mqa")
                         break
                 if name not in params_dict:
                     # modelopt ckpt contains not needed weights for MTP module:
@@ -2583,7 +2598,13 @@
                 weight_loader = getattr(
                     param, "weight_loader", default_weight_loader
                 )
-                weight_loader(param, loaded_weight)
+                futures.append(
+                    executor.submit(weight_loader, param, loaded_weight)
+                )
+
+            # Wait for all tasks to complete and raise any exceptions.
+            for future in concurrent.futures.as_completed(futures):
+                future.result()


         self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
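The `as_completed`/`result()` loop added at the end is what makes the concurrency safe: `Future.result()` re-raises any exception thrown on a worker thread, so a corrupt or missing shard still aborts `load_weights` instead of failing silently, as the added comment says. A quick demonstration of that propagation:

```python
import concurrent.futures

def boom():
    raise ValueError("bad checkpoint shard")

with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = [executor.submit(boom)]
    try:
        for future in concurrent.futures.as_completed(futures):
            future.result()  # re-raises the worker's ValueError here
    except ValueError as exc:
        print(f"propagated from worker thread: {exc}")
```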