Commit f6ebba53 (unverified)
Author: fzyzcjy, authored Jun 10, 2025; committed by GitHub, Jun 09, 2025
Parent: 6716b417

    Support both approximate and exact expert distribution collection (#6964)
Showing 4 changed files with 101 additions and 71 deletions (+101 -71):

    python/sglang/srt/managers/expert_distribution.py   +67  -43
    python/sglang/srt/models/deepseek_v2.py             +19  -16
    python/sglang/srt/models/qwen3_moe.py               +14  -11
    python/sglang/srt/server_args.py                    +1   -1
python/sglang/srt/managers/expert_distribution.py (view file @ f6ebba53)

```diff
@@ -264,15 +264,23 @@ class _SinglePassGatherer(ABC):
             return _DetailSinglePassGatherer(
                 server_args, expert_location_metadata, rank
             )

+        if server_args.expert_distribution_recorder_mode == "stat_approx":
+            if server_args.enable_deepep_moe and (server_args.deepep_mode == "normal"):
+                return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
+            else:
+                raise NotImplementedError
+
         if server_args.enable_deepep_moe:
             if server_args.deepep_mode == "normal":
-                return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
+                return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
             elif server_args.deepep_mode == "low_latency":
                 return _DeepepLowLatencySinglePassGatherer(
                     expert_location_metadata, rank
                 )
             else:
                 raise NotImplementedError

         return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)

     def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int):
```
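For readability, here is the selection logic the rewritten factory implements, restated as a standalone sketch. The mode strings, flags, and gatherer class names come from the hunk above; the function itself is illustrative and not part of sglang's API (the enclosing factory method is not visible in this hunk).

```python
# Sketch only: the dispatch rule from the hunk above, restated standalone.
def choose_gatherer(mode: str, enable_deepep_moe: bool, deepep_mode: str) -> str:
    if mode == "stat_approx":
        # Approximate counts, reusing DeepEP normal-mode dispatch metadata.
        if enable_deepep_moe and deepep_mode == "normal":
            return "_DeepepNormalSinglePassGatherer"
        raise NotImplementedError
    # Default path: count directly from the router's top-k expert ids.
    if enable_deepep_moe:
        if deepep_mode == "normal":
            return "_SelectExpertsSinglePassGatherer"
        elif deepep_mode == "low_latency":
            return "_DeepepLowLatencySinglePassGatherer"
        raise NotImplementedError
    return "_SelectExpertsSinglePassGatherer"
```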
```diff
@@ -347,7 +355,9 @@ class _DetailSinglePassGatherer(_SinglePassGatherer):
         )

     def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
-        self._topk_ids_of_layer[layer_idx, : topk_ids.shape[0], :] = topk_ids
+        self._topk_ids_of_layer[layer_idx, : topk_ids.shape[0], : topk_ids.shape[1]] = (
+            topk_ids
+        )

     def on_deepep_dispatch_normal(
         self,
```
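Slicing the second dimension matters when `topk_ids` is narrower than the preallocated buffer. A minimal sketch of the difference, with hypothetical shapes (the real `_topk_ids_of_layer` buffer also has a leading layer dimension):

```python
import torch

# Hypothetical buffer shapes, for illustration only.
max_tokens, max_topk = 4, 8
buffer = torch.full((max_tokens, max_topk), -1)

topk_ids = torch.tensor([[2, 5], [1, 7], [0, 3]])  # 3 tokens, k=2 < max_topk

# Old form: buffer[:3, :] = topk_ids -> RuntimeError, shapes (3, 8) vs (3, 2).
# New form slices both dims, so narrower top-k tensors are stored correctly:
buffer[: topk_ids.shape[0], : topk_ids.shape[1]] = topk_ids
```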
```diff
@@ -380,7 +390,7 @@ class _DetailSinglePassGatherer(_SinglePassGatherer):
     )


-class _LayerBasedSinglePassGatherer(_SinglePassGatherer):
+class _LayerBasedCpuSinglePassGatherer(_SinglePassGatherer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._objects_of_layer = {}
```
```diff
@@ -409,29 +419,63 @@ def _list_sum(a: List, b: List) -> List:
     return [x + y for x, y in zip(a, b, strict=True)]


-class _SelectExpertsSinglePassGatherer(_LayerBasedSinglePassGatherer):
-    # pretty slow, but we will use the DeepEP Gatherer in production
-    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
-        topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist()
-        torch.cuda.synchronize()
-
-        global_physical_count = [
-            0
-        ] * self._expert_location_metadata.num_physical_experts
-        for token_record in topk_ids_list:
-            for global_physical_expert_idx in token_record:
-                global_physical_count[global_physical_expert_idx] += 1
-
-        self._on_layer_data(layer_idx, global_physical_count)
-
-    def collect(self) -> Dict:
-        global_physical_count = super()._collect_objects(
-            pad_len=self._expert_location_metadata.num_physical_experts
-        )
-        return dict(global_physical_count=global_physical_count)
-
-
-class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer):
+class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
+    def __init__(self, *args, enable_global_physical_experts: bool, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._enable_global_physical_experts = enable_global_physical_experts
+        self._data = torch.zeros(
+            (
+                self._expert_location_metadata.num_layers,
+                (
+                    self._expert_location_metadata.num_physical_experts
+                    if enable_global_physical_experts
+                    else self._expert_location_metadata.num_local_physical_experts
+                ),
+            ),
+            dtype=torch.int,
+            device="cuda",
+        )
+
+    def reset(self):
+        self._data[...] = 0
+
+    def collect(self) -> Dict:
+        if self._enable_global_physical_experts:
+            global_physical_count = self._data
+        else:
+            # Can optimize if bottleneck
+            global_physical_count = _convert_local_to_global_physical_count(
+                self._data,
+                rank=self._rank,
+                num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
+                num_physical_experts=self._expert_location_metadata.num_physical_experts,
+            )
+
+        return dict(global_physical_count=global_physical_count)
+
+
+class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs, enable_global_physical_experts=True)
+
+    # can optimize (e.g. fuse / compile)
+    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
+        topk_ids = topk_ids.flatten()
+        mask = topk_ids != -1
+        self._data[layer_idx, :].scatter_add_(
+            dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int()
+        )
+
+
+class _DeepepNormalSinglePassGatherer(_LayerBasedCpuSinglePassGatherer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if torch.distributed.get_rank() == 0:
+            logger.info(
+                "DeepepNormalSinglePassGatherer gathers approximate statistics. "
+                "If used with small batch size, consider using expert_distribution_recorder_mode=stat."
+            )
+
     def on_deepep_dispatch_normal(
         self,
         layer_idx: int,
```
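The new GPU-side `_SelectExpertsSinglePassGatherer` replaces the old per-token Python loop, which needed a device-to-host copy plus `torch.cuda.synchronize()` on every layer, with a single masked `scatter_add_` that stays on device. The counting trick: padded slots (`-1`) are redirected to index 0 but carry a source value of 0, so they contribute nothing. A self-contained sketch of that histogram technique (run on CPU here for illustration; the real buffer lives on CUDA):

```python
import torch

num_physical_experts = 8
counts = torch.zeros(num_physical_experts, dtype=torch.int)

topk_ids = torch.tensor([[0, 3, 5], [3, 3, -1]])  # -1 marks a padded slot

flat = topk_ids.flatten()
mask = flat != -1
# Padded entries scatter into index 0 with src=0, adding nothing.
counts.scatter_add_(dim=0, index=flat.masked_fill(~mask, 0).long(), src=mask.int())

print(counts)  # tensor([1, 0, 0, 3, 0, 1, 0, 0], dtype=torch.int32)
```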
```diff
@@ -456,17 +500,9 @@ class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer):
         return dict(global_physical_count=global_physical_count)


-class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer):
+class _DeepepLowLatencySinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
     def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._data = torch.zeros(
-            (
-                self._expert_location_metadata.num_layers,
-                self._expert_location_metadata.num_local_physical_experts,
-            ),
-            dtype=torch.int,
-            device="cuda",
-        )
+        super().__init__(*args, **kwargs, enable_global_physical_experts=False)

     def on_deepep_dispatch_low_latency(
         self, layer_idx: int, local_physical_count_of_layer: torch.Tensor
```
```diff
@@ -474,19 +510,6 @@ class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer):
         # Most naive implementation, can optimize later
         self._data[layer_idx, :] += local_physical_count_of_layer

-    def reset(self):
-        self._data[...] = 0
-
-    def collect(self) -> Dict:
-        # Can optimize if bottleneck
-        global_physical_count = _convert_local_to_global_physical_count(
-            self._data,
-            rank=self._rank,
-            num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
-            num_physical_experts=self._expert_location_metadata.num_physical_experts,
-        )
-        return dict(global_physical_count=global_physical_count)
-

 def _convert_local_to_global_physical_count(
     local_physical_count: torch.Tensor,
```
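The duplicated `reset`/`collect` logic moves into the shared `_LayerBasedGpuSinglePassGatherer` base, which calls `_convert_local_to_global_physical_count` when a gatherer only tracks this rank's local experts. The function body is not part of this diff; the following is a plausible sketch of its contract, assuming each rank owns a contiguous block of `num_local_physical_experts` slots on the global expert axis:

```python
import torch

# Hypothetical reimplementation for illustration; not the sglang source.
def convert_local_to_global_physical_count(
    local_physical_count: torch.Tensor,  # (num_layers, num_local_physical_experts)
    rank: int,
    num_local_physical_experts: int,
    num_physical_experts: int,
) -> torch.Tensor:
    num_layers = local_physical_count.shape[0]
    global_count = torch.zeros(
        (num_layers, num_physical_experts), dtype=local_physical_count.dtype
    )
    # Place this rank's counts into its slice of the global expert axis.
    start = rank * num_local_physical_experts
    global_count[:, start : start + num_local_physical_experts] = local_physical_count
    return global_count
```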
```diff
@@ -525,6 +548,7 @@ class _Accumulator(ABC):
     def get_class(server_args: ServerArgs) -> Type["_Accumulator"]:
         return {
             "stat": _StatAccumulator,
+            "stat_approx": _StatAccumulator,
             "per_pass": _DetailAccumulator,
             "per_token": _DetailAccumulator,
         }[server_args.expert_distribution_recorder_mode]
```
python/sglang/srt/models/deepseek_v2.py (view file @ f6ebba53)

```diff
@@ -460,6 +460,9 @@ class DeepseekV2MoE(nn.Module):
         hidden_states = state.hidden_states_mlp_input
         if router_logits is not None:
-            state.topk_weights_local, state.topk_idx_local = select_experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits,
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
+                state.topk_weights_local, state.topk_idx_local = select_experts(
+                    hidden_states=hidden_states,
+                    router_logits=router_logits,
```
python/sglang/srt/models/qwen3_moe.py (view file @ f6ebba53)

```diff
@@ -255,6 +255,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         router_logits = state.pop("router_logits")
         hidden_states = state.hidden_states_mlp_input
         if router_logits is not None:
-            state.topk_weights_local, state.topk_idx_local = select_experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits,
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
+                state.topk_weights_local, state.topk_idx_local = select_experts(
+                    hidden_states=hidden_states,
+                    router_logits=router_logits,
```
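The two model hunks above make the same change: `select_experts` now runs inside the recorder's `with_current_layer` context, so callbacks fired during routing (such as `on_select_experts`) can attribute the top-k ids to the correct layer. A minimal sketch of how such a layer-scoping context manager can work, assuming a simple attribute-based design rather than sglang's actual implementation:

```python
from contextlib import contextmanager

class _RecorderSketch:
    """Illustrative stand-in for the global expert distribution recorder."""

    def __init__(self):
        self._current_layer = None

    @contextmanager
    def with_current_layer(self, layer_id):
        # Remember the active layer so callbacks fired inside the block
        # know which layer their data belongs to; restore it afterwards.
        previous, self._current_layer = self._current_layer, layer_id
        try:
            yield
        finally:
            self._current_layer = previous
```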
python/sglang/srt/server_args.py (view file @ f6ebba53)

```diff
@@ -182,7 +182,7 @@ class ServerArgs:
     eplb_rebalance_num_iterations: int = 1000
     eplb_rebalance_layers_per_chunk: Optional[int] = None
     expert_distribution_recorder_mode: Optional[
-        Literal["stat", "per_pass", "per_token"]
+        Literal["stat", "stat_approx", "per_pass", "per_token"]
     ] = None
     expert_distribution_recorder_buffer_size: Optional[int] = None
     enable_expert_distribution_metrics: bool = False
```
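With the `Literal` extended, `"stat_approx"` becomes an accepted recorder mode. A hedged usage sketch: the field names come from this commit's diffs, while the model path is a placeholder and other required fields are omitted.

```python
from sglang.srt.server_args import ServerArgs

# "stat_approx" trades exactness for overhead: it reuses DeepEP dispatch
# metadata instead of counting the router's top-k ids directly.
server_args = ServerArgs(
    model_path="deepseek-ai/DeepSeek-V2-Lite",  # placeholder model choice
    enable_deepep_moe=True,   # per the factory hunk, stat_approx needs DeepEP
    deepep_mode="normal",     # and specifically the "normal" dispatch mode
    expert_distribution_recorder_mode="stat_approx",
)
```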