sglang commit 5d93a950 (Unverified)
Authored Apr 24, 2025 by Yuhong Guo; committed by GitHub on Apr 24, 2025

[BugFix] Fix combination of MTP and `--n-share-experts-fusion` with R1 (#5707)
Parent: c998d04b

Showing 2 changed files with 68 additions and 15 deletions (+68, -15):

    python/sglang/srt/models/deepseek_nextn.py  (+50, -1)
    python/sglang/srt/models/deepseek_v2.py     (+18, -14)
python/sglang/srt/models/deepseek_nextn.py
@@ -13,12 +13,14 @@
# ==============================================================================
"""Inference-only DeepSeek NextN Speculative Decoding."""
import logging
from typing import Iterable, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -51,6 +53,9 @@ else:
    from vllm._custom_ops import awq_dequantize

logger = logging.getLogger(__name__)


class DeepseekModelNextN(nn.Module):
    def __init__(
        self,
@@ -134,7 +139,9 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
    ) -> None:
        nn.Module.__init__(self)
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        self.determine_n_share_experts_fusion("DeepseekV3ForCausalLMNextN")
        self.model = DeepseekModelNextN(
            config, quant_config, prefix=add_prefix("model", prefix)
@@ -182,6 +189,48 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        if self.n_share_experts_fusion > 0:
            logger.info(
                f"Cloning {self.n_share_experts_fusion} "
                "replicas of the shared expert into MoE for DeepseekV3ForCausalLMNextN"
            )
            weights_list = list(weights)
            weights_dict = dict(weights_list)
            if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
                suffix_list = [
                    "down_proj.weight",
                    "down_proj.weight_scale",
                    "gate_proj.weight",
                    "gate_proj.weight_scale",
                    "up_proj.weight",
                    "up_proj.weight_scale",
                ]
            else:
                suffix_list = [
                    "down_proj.weight",
                    "down_proj.weight_scale_inv",
                    "gate_proj.weight",
                    "gate_proj.weight_scale_inv",
                    "up_proj.weight",
                    "up_proj.weight_scale_inv",
                ]
            names_to_remove = []
            for num_repeat in range(self.n_share_experts_fusion):
                for suffix in suffix_list:
                    shared_expert_weight_name = (
                        f"model.layers.0.mlp.shared_experts.{suffix}"
                    )
                    weights_list.append(
                        (
                            f"model.layers.0."
                            f"mlp.experts."
                            f"{self.config.n_routed_experts + num_repeat}"
                            f".{suffix}",
                            weights_dict[shared_expert_weight_name],
                        )
                    )
                    names_to_remove += [shared_expert_weight_name]
            weights = [w for w in weights_list if w[0] not in names_to_remove]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
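The block above re-keys the shared expert's checkpoint tensors as extra routed experts so the fused MoE loader treats them like any other expert. A minimal, self-contained sketch of the same renaming idea (the helper name, the unquantized suffix list, and the toy tensors are illustrative assumptions, not sglang's API):

    from typing import Dict, Iterable, List, Tuple

    import torch

    def clone_shared_expert_weights(
        weights: Iterable[Tuple[str, torch.Tensor]],
        n_routed_experts: int,
        n_share_experts_fusion: int,
        layer_prefix: str = "model.layers.0",
    ) -> List[Tuple[str, torch.Tensor]]:
        """Re-key shared-expert tensors as extra routed experts (illustrative sketch)."""
        weights_list = list(weights)
        weights_dict: Dict[str, torch.Tensor] = dict(weights_list)
        # Plain weight suffixes only; the real code also remaps the weight scales.
        suffix_list = ["gate_proj.weight", "up_proj.weight", "down_proj.weight"]
        names_to_remove = []
        for num_repeat in range(n_share_experts_fusion):
            for suffix in suffix_list:
                src = f"{layer_prefix}.mlp.shared_experts.{suffix}"
                dst = f"{layer_prefix}.mlp.experts.{n_routed_experts + num_repeat}.{suffix}"
                weights_list.append((dst, weights_dict[src]))
                names_to_remove.append(src)
        # Drop the original shared-expert entries so they are not loaded twice.
        return [w for w in weights_list if w[0] not in names_to_remove]

    # Toy usage: 2 routed experts, 1 fused copy of the shared expert.
    toy = [
        (f"model.layers.0.mlp.shared_experts.{s}", torch.zeros(4, 4))
        for s in ("gate_proj.weight", "up_proj.weight", "down_proj.weight")
    ]
    remapped = clone_shared_expert_weights(toy, n_routed_experts=2, n_share_experts_fusion=1)
    assert remapped[0][0] == "model.layers.0.mlp.experts.2.gate_proj.weight"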
@@ -190,7 +239,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
-           num_experts=self.config.n_routed_experts,
+           num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
        )

        nextn_layer_prefix = "model.layers.0"
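Because of that cloning, the expert-parameter mapping above has to count the extra checkpoint expert ids as well, which is what the one-line num_experts change does. A small worked example of the arithmetic (the fusion count of 8 is an assumption, e.g. one clone per tensor-parallel rank):

    n_routed_experts = 256        # DeepSeek V3/R1 routed experts
    n_share_experts_fusion = 8    # assumed value, e.g. one clone per TP rank

    num_experts = n_routed_experts + n_share_experts_fusion
    # Checkpoint expert ids 0..255 are the routed experts; ids 256..263 are the
    # cloned shared expert produced by the renaming step above.
    fused_ids = list(range(n_routed_experts, num_experts))
    assert num_experts == 264 and fused_ids == list(range(256, 264))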
python/sglang/srt/models/deepseek_v2.py
@@ -1440,11 +1440,27 @@ class DeepseekV2ForCausalLM(nn.Module):
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        self.determine_n_share_experts_fusion()
        self.model = DeepseekV2Model(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
        )
        self.logits_processor = LogitsProcessor(config)
        self.dp_size = get_attention_dp_size()

    def determine_n_share_experts_fusion(
        self, architecture: str = "DeepseekV3ForCausalLM"
    ):
        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
        if self.n_share_experts_fusion > 0:
            # Only Deepseek V3/R1 can use shared experts fusion optimization now.
            if (
-               self.config.architectures[0] != "DeepseekV3ForCausalLM"
+               self.config.architectures[0] != architecture
                or self.config.n_routed_experts != 256
            ):
                self.n_share_experts_fusion = 0
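The eligibility check is now parameterized by the expected architecture name, so the NextN draft model used for MTP can reuse it by passing "DeepseekV3ForCausalLMNextN" (see the constructor change in deepseek_nextn.py above). A simplified standalone sketch of that gating rule (the function name and return style are illustrative, not sglang's API):

    def effective_share_experts_fusion(
        requested: int,
        model_architecture: str,
        n_routed_experts: int,
        expected_architecture: str = "DeepseekV3ForCausalLM",
    ) -> int:
        """Return the fusion count to use, or 0 if the model is not eligible (sketch)."""
        if requested > 0:
            # Only DeepSeek V3/R1-shaped models (256 routed experts) qualify today.
            if model_architecture != expected_architecture or n_routed_experts != 256:
                return 0
        return requested

    # The NextN draft model passes its own architecture string and stays eligible:
    assert effective_share_experts_fusion(
        8, "DeepseekV3ForCausalLMNextN", 256,
        expected_architecture="DeepseekV3ForCausalLMNextN",
    ) == 8
    assert effective_share_experts_fusion(8, "LlamaForCausalLM", 32) == 0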
@@ -1459,7 +1475,7 @@ class DeepseekV2ForCausalLM(nn.Module):
        elif self.n_share_experts_fusion == 0:
            if (
                torch.cuda.get_device_capability("cuda") >= (9, 0)
-               and self.config.architectures[0] == "DeepseekV3ForCausalLM"
+               and self.config.architectures[0] == architecture
                and self.config.n_routed_experts == 256
                and (not global_server_args_dict["enable_deepep_moe"])
            ):
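When fusion was not requested, the branch above auto-enables it only on hardware and model shapes that qualify. A hedged sketch of that decision as a pure function (the helper name is made up; the conditions mirror the diff):

    import torch

    def should_auto_enable_fusion(
        model_architecture: str,
        n_routed_experts: int,
        enable_deepep_moe: bool,
        expected_architecture: str = "DeepseekV3ForCausalLM",
    ) -> bool:
        """Sketch of the auto-enable gate: Hopper-class GPU, V3/R1 shape, DeepEP MoE off."""
        if not torch.cuda.is_available():
            return False
        return (
            torch.cuda.get_device_capability("cuda") >= (9, 0)
            and model_architecture == expected_architecture
            and n_routed_experts == 256
            and not enable_deepep_moe
        )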
@@ -1469,18 +1485,6 @@
                "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled."
            )

-       self.model = DeepseekV2Model(
-           config, quant_config, prefix=add_prefix("model", prefix)
-       )
-       self.lm_head = ParallelLMHead(
-           config.vocab_size,
-           config.hidden_size,
-           quant_config=quant_config,
-           prefix=add_prefix("lm_head", prefix),
-       )
-       self.logits_processor = LogitsProcessor(config)
-       self.dp_size = get_attention_dp_size()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.embed_tokens