change / sglang · Commits · 72bfb0ba

Unverified commit 72bfb0ba, authored May 19, 2025 by fzyzcjy, committed by GitHub on May 18, 2025
Refactor DeepSeek MoE layer to unify the two forward branches (#6325)
Parent: 15521495

Showing 1 changed file with 53 additions and 49 deletions.

python/sglang/srt/models/deepseek_v2.py  (+53, -49)
@@ -194,6 +194,14 @@ class MoEGate(nn.Module):
         return logits


+def is_non_idle_and_non_empty(forward_mode, hidden_states):
+    return (
+        (forward_mode is not None)
+        and not forward_mode.is_idle()
+        and hidden_states.shape[0] > 0
+    )
+
+
 class DeepseekV2MoE(nn.Module):
     def __init__(
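The hunk above introduces a small module-level guard that the refactored forward relies on further down. As a quick illustration of its behavior, here is a self-contained sketch with stand-in objects (not sglang's ForwardMode or real hidden-state tensors):

# Illustration only: _StubMode and _StubHidden stand in for sglang's
# ForwardMode and the hidden_states tensor; only is_idle() and shape[0] matter.
class _StubMode:
    def __init__(self, idle: bool):
        self._idle = idle

    def is_idle(self) -> bool:
        return self._idle


class _StubHidden:
    def __init__(self, num_tokens: int):
        self.shape = (num_tokens, 1)


def is_non_idle_and_non_empty(forward_mode, hidden_states):
    # Copied from the hunk above: work exists only if a mode is set,
    # the mode is not idle, and the batch has at least one token.
    return (
        (forward_mode is not None)
        and not forward_mode.is_idle()
        and hidden_states.shape[0] > 0
    )


print(is_non_idle_and_non_empty(None, _StubHidden(4)))                   # False
print(is_non_idle_and_non_empty(_StubMode(idle=True), _StubHidden(4)))   # False
print(is_non_idle_and_non_empty(_StubMode(idle=False), _StubHidden(0)))  # False
print(is_non_idle_and_non_empty(_StubMode(idle=False), _StubHidden(4)))  # True

The guard is true only when a forward mode is present, that mode is not idle, and the batch actually contains tokens; the unified forward below uses it to decide whether routing and shared-expert work should run at all.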
@@ -259,11 +267,12 @@ class DeepseekV2MoE(nn.Module):
             ),
         )

+        self.top_k = config.num_experts_per_tok
+
         if global_server_args_dict["enable_deepep_moe"]:
             # TODO: we will support tp < ep in the future
             self.ep_size = get_tensor_model_parallel_world_size()
             self.num_experts = config.n_routed_experts
-            self.top_k = config.num_experts_per_tok
             self.renormalize = config.norm_topk_prob
             self.topk_group = config.topk_group
             self.num_expert_group = config.n_group
@@ -286,41 +295,30 @@ class DeepseekV2MoE(nn.Module):
                 return_recv_hook=True,
             )

+    @property
+    def _enable_deepep_moe(self):
+        return global_server_args_dict["enable_deepep_moe"]
+
     def forward(
         self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
     ) -> torch.Tensor:
-        if not global_server_args_dict["enable_deepep_moe"]:
-            return self.forward_normal(hidden_states)
-        else:
-            return self.forward_deepep(hidden_states, forward_batch)
-
-    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        shared_output = self._forward_shared_experts(hidden_states)
+        if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
+            forward_mode, hidden_states
+        ):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
-        final_hidden_states *= self.routed_scaling_factor
-        if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
-        return final_hidden_states
+        else:
+            router_logits = None

-    def forward_deepep(
-        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
-    ) -> torch.Tensor:
-        forward_mode = forward_batch.forward_mode
-        shared_output = None
-        if (
-            forward_mode is not None
-            and not forward_mode.is_idle()
-            and hidden_states.shape[0] > 0
+        if (self.n_share_experts_fusion == 0) and (
+            (not self._enable_deepep_moe)
+            or is_non_idle_and_non_empty(forward_mode, hidden_states)
         ):
-            # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(hidden_states)
-            shared_output = self._forward_shared_experts(hidden_states)
+            shared_output = self.shared_experts(hidden_states)
+        else:
+            shared_output = None
+
+        if self._enable_deepep_moe and (router_logits is not None):
             topk_weights, topk_idx = select_experts(
                 hidden_states=hidden_states,
                 router_logits=router_logits,
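The hunk above is the core of the refactor: forward_normal() and forward_deepep() disappear, and a single forward() gates each stage on the new _enable_deepep_moe property plus the idle/empty guard. A minimal sketch of that pattern with a toy module (not DeepseekV2MoE; the FLAGS dict and the layer sizes are invented for the example):

# Toy illustration of "one forward(), per-stage gating" -- not sglang code.
import torch
from torch import nn

FLAGS = {"enable_fast_path": False}  # stand-in for global_server_args_dict


class ToyMoE(nn.Module):
    def __init__(self, dim: int = 8, n_experts: int = 4):
        super().__init__()
        self.gate = nn.Linear(dim, n_experts)
        self.experts = nn.Linear(dim, dim)
        self.shared_experts = nn.Linear(dim, dim)

    @property
    def _enable_fast_path(self) -> bool:
        # Same idea as the new _enable_deepep_moe property: one place that
        # reads the process-wide flag, so every stage below tests it uniformly.
        return FLAGS["enable_fast_path"]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Stage 1: routing, skipped on the fast path when there is nothing to do.
        if (not self._enable_fast_path) or hidden_states.shape[0] > 0:
            router_logits = self.gate(hidden_states)
        else:
            router_logits = None

        # Stage 2: shared experts, computed the same way on both paths.
        shared_output = self.shared_experts(hidden_states)

        # Stage 3: expert computation (a real MoE would dispatch/combine here).
        out = self.experts(hidden_states)
        if router_logits is not None:
            out = out * torch.softmax(router_logits, dim=-1).max(dim=-1, keepdim=True).values

        # Stage 4: epilogue shared by both paths.
        return out + shared_output


m = ToyMoE()
x = torch.randn(3, 8)
y_dense = m(x)
FLAGS["enable_fast_path"] = True
y_fast = m(x)  # same entry point; only the gated stages differ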
@@ -340,7 +338,8 @@ class DeepseekV2MoE(nn.Module):
             topk_weights = torch.empty(
                 (0, self.top_k), dtype=torch.float32, device=hidden_states.device
             )
-        if self.ep_size > 1:
+
+        if self._enable_deepep_moe and (self.ep_size > 1):
             # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
             (
                 hidden_states,
@@ -357,6 +356,8 @@ class DeepseekV2MoE(nn.Module):
                 topk_weights,
                 forward_mode=forward_mode,
             )
+
+        if self._enable_deepep_moe:
             final_hidden_states = self.experts(
                 hidden_states=hidden_states,
                 topk_idx=topk_idx,
@@ -368,25 +369,28 @@ class DeepseekV2MoE(nn.Module):
                 num_recv_tokens_per_expert=num_recv_tokens_per_expert,
                 forward_mode=forward_mode,
             )
-        if self.ep_size > 1:
+        else:
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states, router_logits=router_logits
+            )
+
+        if self._enable_deepep_moe and (self.ep_size > 1):
             final_hidden_states = self.deepep_dispatcher.combine(
                 final_hidden_states,
                 topk_idx,
                 topk_weights,
                 forward_mode,
             )
         final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
-        return final_hidden_states
-
-    def _forward_shared_experts(self, hidden_states):
-        if self.n_share_experts_fusion == 0:
-            return self.shared_experts(hidden_states)
-        else:
-            return None
+        if (not self._enable_deepep_moe) and (self.tp_size > 1):
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states


 def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
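Taken together, the last three hunks gate the DeepEP-specific stages (dispatch before the expert call, combine after it) on _enable_deepep_moe and ep_size > 1, keep the dense experts call as the else branch, and restrict the tensor-parallel all-reduce to the non-DeepEP path. A rough, self-contained sketch of that control flow, in plain Python with a stub dispatcher (none of these names are sglang APIs):

class _StubDispatcher:
    def dispatch(self, tokens):
        # A real dispatcher would route tokens to the ranks owning their experts.
        return tokens

    def combine(self, tokens):
        # ...and gather the expert outputs back afterwards.
        return tokens


def moe_epilogue(tokens, experts, *, enable_deepep, ep_size, tp_size,
                 dispatcher, all_reduce):
    if enable_deepep and ep_size > 1:
        tokens = dispatcher.dispatch(tokens)
    out = experts(tokens)
    if enable_deepep and ep_size > 1:
        out = dispatcher.combine(out)
    if (not enable_deepep) and tp_size > 1:
        out = all_reduce(out)  # only the dense tensor-parallel path reduces
    return out


result = moe_epilogue(
    [1.0, 2.0, 3.0],
    experts=lambda t: [2 * v for v in t],
    enable_deepep=False,
    ep_size=1,
    tp_size=1,
    dispatcher=_StubDispatcher(),
    all_reduce=lambda t: t,
)
print(result)  # [2.0, 4.0, 6.0]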