sglang · Commit f6ebba53 (Unverified)

Support both approximate and exact expert distribution collection (#6964)

Authored Jun 10, 2025 by fzyzcjy; committed by GitHub on Jun 09, 2025
Parent: 6716b417

Showing 4 changed files with 101 additions and 71 deletions.
- python/sglang/srt/managers/expert_distribution.py (+67 -43)
- python/sglang/srt/models/deepseek_v2.py (+19 -16)
- python/sglang/srt/models/qwen3_moe.py (+14 -11)
- python/sglang/srt/server_args.py (+1 -1)
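For context on how the new mode is exercised: the recorder mode is a `ServerArgs` field (see the `server_args.py` change below), so it can be selected at launch and driven with the recorder's usual record/dump workflow. A rough sketch follows; the kebab-case CLI flag, port, and HTTP endpoint names are assumptions based on the existing recorder, not part of this diff:

```python
# Sketch only: flag spelling, port, and endpoint names are assumptions.
# Server launch (shell):
#   python -m sglang.launch_server --model-path <model> \
#       --expert-distribution-recorder-mode stat_approx
import requests

BASE = "http://localhost:30000"  # assumed default sglang port

requests.post(f"{BASE}/start_expert_distribution_record")
# ... run some generation traffic against the server here ...
requests.post(f"{BASE}/stop_expert_distribution_record")
resp = requests.post(f"{BASE}/dump_expert_distribution_record")
print(resp.text)  # per-expert statistics gathered in stat_approx mode
```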
python/sglang/srt/managers/expert_distribution.py
```diff
@@ -264,15 +264,23 @@ class _SinglePassGatherer(ABC):
             return _DetailSinglePassGatherer(
                 server_args, expert_location_metadata, rank
             )
+        if server_args.expert_distribution_recorder_mode == "stat_approx":
+            if server_args.enable_deepep_moe and (server_args.deepep_mode == "normal"):
+                return _DeepepNormalSinglePassGatherer(
+                    expert_location_metadata, rank
+                )
+            else:
+                raise NotImplementedError
+
         if server_args.enable_deepep_moe:
             if server_args.deepep_mode == "normal":
-                return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
+                return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
             elif server_args.deepep_mode == "low_latency":
                 return _DeepepLowLatencySinglePassGatherer(
                     expert_location_metadata, rank
                 )
             else:
                 raise NotImplementedError

         return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)

     def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int):
```
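Restating the selection logic from this hunk as a standalone sketch (class names returned as strings purely for illustration): `stat_approx` reuses the DeepEP-normal dispatch counters, which are approximate, while the other modes now go through the exact `_SelectExpertsSinglePassGatherer` even when DeepEP normal mode is enabled.

```python
# Illustrative restatement of the gatherer selection above; not the actual code path.
def pick_gatherer(mode: str, enable_deepep_moe: bool, deepep_mode: str) -> str:
    if mode == "stat_approx":
        if enable_deepep_moe and deepep_mode == "normal":
            return "_DeepepNormalSinglePassGatherer"  # approximate, from dispatch counts
        raise NotImplementedError
    if enable_deepep_moe:
        if deepep_mode == "normal":
            return "_SelectExpertsSinglePassGatherer"  # exact, from router top-k ids
        if deepep_mode == "low_latency":
            return "_DeepepLowLatencySinglePassGatherer"
        raise NotImplementedError
    return "_SelectExpertsSinglePassGatherer"


assert pick_gatherer("stat", True, "normal") == "_SelectExpertsSinglePassGatherer"
assert pick_gatherer("stat_approx", True, "normal") == "_DeepepNormalSinglePassGatherer"
```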
```diff
@@ -347,7 +355,9 @@ class _DetailSinglePassGatherer(_SinglePassGatherer):
         )

     def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
-        self._topk_ids_of_layer[layer_idx, : topk_ids.shape[0], :] = topk_ids
+        self._topk_ids_of_layer[layer_idx, : topk_ids.shape[0], : topk_ids.shape[1]] = (
+            topk_ids
+        )

     def on_deepep_dispatch_normal(
         self,
```
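The one-line change above handles the case where the current pass uses a smaller `top_k` than the preallocated buffer width; slicing only the token dimension would then fail with a shape mismatch. A toy reproduction of the idea (buffer sizes are made up):

```python
import torch

# One layer's slice of a preallocated topk-ids buffer: 4 token slots, width 8.
buffer = torch.full((4, 8), -1)
topk_ids = torch.tensor([[2, 5], [7, 1]])  # this pass: 2 tokens, top_k=2

# buffer[: topk_ids.shape[0], :] = topk_ids            # would raise: 2 columns vs 8
buffer[: topk_ids.shape[0], : topk_ids.shape[1]] = topk_ids  # matches the fix above
```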
```diff
@@ -380,7 +390,7 @@ class _DetailSinglePassGatherer(_SinglePassGatherer):
         )


-class _LayerBasedSinglePassGatherer(_SinglePassGatherer):
+class _LayerBasedCpuSinglePassGatherer(_SinglePassGatherer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._objects_of_layer = {}
```
```diff
@@ -409,29 +419,63 @@ def _list_sum(a: List, b: List) -> List:
     return [x + y for x, y in zip(a, b, strict=True)]


-class _SelectExpertsSinglePassGatherer(_LayerBasedSinglePassGatherer):
-    # pretty slow, but we will use the DeepEP Gatherer in production
-    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
-        topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist()
-        torch.cuda.synchronize()
-
-        global_physical_count = [
-            0
-        ] * self._expert_location_metadata.num_physical_experts
-        for token_record in topk_ids_list:
-            for global_physical_expert_idx in token_record:
-                global_physical_count[global_physical_expert_idx] += 1
-
-        self._on_layer_data(layer_idx, global_physical_count)
-
-    def collect(self) -> Dict:
-        global_physical_count = super()._collect_objects(
-            pad_len=self._expert_location_metadata.num_physical_experts
-        )
-        return dict(global_physical_count=global_physical_count)
-
-
-class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer):
+class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
+    def __init__(self, *args, enable_global_physical_experts: bool, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._enable_global_physical_experts = enable_global_physical_experts
+        self._data = torch.zeros(
+            (
+                self._expert_location_metadata.num_layers,
+                (
+                    self._expert_location_metadata.num_physical_experts
+                    if enable_global_physical_experts
+                    else self._expert_location_metadata.num_local_physical_experts
+                ),
+            ),
+            dtype=torch.int,
+            device="cuda",
+        )
+
+    def reset(self):
+        self._data[...] = 0
+
+    def collect(self) -> Dict:
+        if self._enable_global_physical_experts:
+            global_physical_count = self._data
+        else:
+            # Can optimize if bottleneck
+            global_physical_count = _convert_local_to_global_physical_count(
+                self._data,
+                rank=self._rank,
+                num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
+                num_physical_experts=self._expert_location_metadata.num_physical_experts,
+            )
+
+        return dict(global_physical_count=global_physical_count)
+
+
+class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs, enable_global_physical_experts=True)
+
+    # can optimize (e.g. fuse / compile)
+    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
+        topk_ids = topk_ids.flatten()
+        mask = topk_ids != -1
+        self._data[layer_idx, :].scatter_add_(
+            dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int()
+        )
+
+
+class _DeepepNormalSinglePassGatherer(_LayerBasedCpuSinglePassGatherer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if torch.distributed.get_rank() == 0:
+            logger.info(
+                "DeepepNormalSinglePassGatherer gathers approximate statistics. "
+                "If used with small batch size, consider using expert_distribution_recorder_mode=stat."
+            )
+
     def on_deepep_dispatch_normal(
         self,
         layer_idx: int,
```
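The new GPU path counts expert hits entirely on-device instead of copying `topk_ids` to the CPU. A small CPU-runnable sketch of the same `scatter_add_` trick: padded slots are marked with -1, so their index is clamped to 0 and their `src` contribution is masked to 0.

```python
import torch

num_physical_experts = 8
data = torch.zeros(num_physical_experts, dtype=torch.int)  # one layer's counters

# topk_ids for 3 tokens with top_k=2; the last token is padding (-1).
topk_ids = torch.tensor([[0, 3], [3, 5], [-1, -1]])

flat = topk_ids.flatten()
mask = flat != -1
data.scatter_add_(dim=0, index=flat.masked_fill(~mask, 0).long(), src=mask.int())
print(data.tolist())  # [1, 0, 0, 2, 0, 1, 0, 0] -> expert 3 was selected twice
```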
```diff
@@ -456,17 +500,9 @@ class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer):
         return dict(global_physical_count=global_physical_count)


-class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer):
+class _DeepepLowLatencySinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
     def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._data = torch.zeros(
-            (
-                self._expert_location_metadata.num_layers,
-                self._expert_location_metadata.num_local_physical_experts,
-            ),
-            dtype=torch.int,
-            device="cuda",
-        )
+        super().__init__(*args, **kwargs, enable_global_physical_experts=False)

     def on_deepep_dispatch_low_latency(
         self, layer_idx: int, local_physical_count_of_layer: torch.Tensor
```
```diff
@@ -474,19 +510,6 @@ class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer):
         # Most naive implementation, can optimize later
         self._data[layer_idx, :] += local_physical_count_of_layer

-    def reset(self):
-        self._data[...] = 0
-
-    def collect(self) -> Dict:
-        # Can optimize if bottleneck
-        global_physical_count = _convert_local_to_global_physical_count(
-            self._data,
-            rank=self._rank,
-            num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
-            num_physical_experts=self._expert_location_metadata.num_physical_experts,
-        )
-        return dict(global_physical_count=global_physical_count)
-

 def _convert_local_to_global_physical_count(
     local_physical_count: torch.Tensor,
```
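The `reset`/`collect` logic removed here now lives in the shared `_LayerBasedGpuSinglePassGatherer.collect` shown earlier, which still calls `_convert_local_to_global_physical_count` when a gatherer only tracks local experts. The helper itself is unchanged and not shown in this diff; the sketch below is a guess at its semantics (each rank writes its local counters into its own slice of the global expert axis), not the actual implementation:

```python
import torch

def convert_local_to_global_sketch(
    local_physical_count: torch.Tensor,  # [num_layers, num_local_physical_experts]
    rank: int,
    num_local_physical_experts: int,
    num_physical_experts: int,
) -> torch.Tensor:
    # Hypothetical: place this rank's local counts at its offset; other slots stay 0
    # and are filled in when counts are reduced across ranks.
    num_layers = local_physical_count.shape[0]
    out = torch.zeros(num_layers, num_physical_experts, dtype=local_physical_count.dtype)
    start = rank * num_local_physical_experts
    out[:, start : start + num_local_physical_experts] = local_physical_count
    return out
```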
```diff
@@ -525,6 +548,7 @@ class _Accumulator(ABC):
     def get_class(server_args: ServerArgs) -> Type["_Accumulator"]:
         return {
             "stat": _StatAccumulator,
+            "stat_approx": _StatAccumulator,
             "per_pass": _DetailAccumulator,
             "per_token": _DetailAccumulator,
         }[server_args.expert_distribution_recorder_mode]
```
python/sglang/srt/models/deepseek_v2.py
```diff
@@ -460,22 +460,25 @@ class DeepseekV2MoE(nn.Module):
         hidden_states = state.hidden_states_mlp_input
         if router_logits is not None:
-            state.topk_weights_local, state.topk_idx_local = select_experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits,
-                top_k=self.top_k,
-                use_grouped_topk=True,
-                renormalize=self.renormalize,
-                topk_group=self.topk_group,
-                num_expert_group=self.num_expert_group,
-                num_fused_shared_experts=self.num_fused_shared_experts,
-                correction_bias=self.correction_bias,
-                routed_scaling_factor=self.routed_scaling_factor,
-                num_token_non_padded=state.forward_batch.num_token_non_padded,
-                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
-                    layer_id=self.layer_id,
-                ),
-            )
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
+                state.topk_weights_local, state.topk_idx_local = select_experts(
+                    hidden_states=hidden_states,
+                    router_logits=router_logits,
+                    top_k=self.top_k,
+                    use_grouped_topk=True,
+                    renormalize=self.renormalize,
+                    topk_group=self.topk_group,
+                    num_expert_group=self.num_expert_group,
+                    num_fused_shared_experts=self.num_fused_shared_experts,
+                    correction_bias=self.correction_bias,
+                    routed_scaling_factor=self.routed_scaling_factor,
+                    num_token_non_padded=state.forward_batch.num_token_non_padded,
+                    expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                        layer_id=self.layer_id,
+                    ),
+                )
         else:
             state.topk_idx_local = torch.full(
                 (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
```
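Both model changes wrap `select_experts` in `with_current_layer(self.layer_id)` so that the recorder hooks fired inside `select_experts` know which layer the captured `topk_ids` belong to, without threading a layer id through every call. A minimal, hypothetical sketch of that pattern (not sglang's actual recorder class):

```python
from contextlib import contextmanager

class ToyRecorder:
    def __init__(self):
        self._current_layer = None
        self.counts = {}

    @contextmanager
    def with_current_layer(self, layer_id: int):
        # Remember the active layer for the duration of the block.
        self._current_layer = layer_id
        try:
            yield
        finally:
            self._current_layer = None

    def on_select_experts(self, topk_ids):
        # Attributed to whichever layer is currently active.
        self.counts.setdefault(self._current_layer, []).append(topk_ids)

recorder = ToyRecorder()
with recorder.with_current_layer(3):
    recorder.on_select_experts([0, 5])
print(recorder.counts)  # {3: [[0, 5]]}
```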
python/sglang/srt/models/qwen3_moe.py
```diff
@@ -255,17 +255,20 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         router_logits = state.pop("router_logits")
         hidden_states = state.hidden_states_mlp_input
         if router_logits is not None:
-            state.topk_weights_local, state.topk_idx_local = select_experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits,
-                top_k=self.top_k,
-                use_grouped_topk=False,
-                renormalize=self.renormalize,
-                num_token_non_padded=state.forward_batch.num_token_non_padded,
-                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
-                    layer_id=self.layer_id,
-                ),
-            )
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
+                state.topk_weights_local, state.topk_idx_local = select_experts(
+                    hidden_states=hidden_states,
+                    router_logits=router_logits,
+                    top_k=self.top_k,
+                    use_grouped_topk=False,
+                    renormalize=self.renormalize,
+                    num_token_non_padded=state.forward_batch.num_token_non_padded,
+                    expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                        layer_id=self.layer_id,
+                    ),
+                )
         else:
             state.topk_idx_local = torch.full(
                 (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
```
python/sglang/srt/server_args.py
```diff
@@ -182,7 +182,7 @@ class ServerArgs:
     eplb_rebalance_num_iterations: int = 1000
     eplb_rebalance_layers_per_chunk: Optional[int] = None
     expert_distribution_recorder_mode: Optional[
-        Literal["stat", "per_pass", "per_token"]
+        Literal["stat", "stat_approx", "per_pass", "per_token"]
     ] = None
     expert_distribution_recorder_buffer_size: Optional[int] = None
     enable_expert_distribution_metrics: bool = False
```
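The only server-args change is widening the `Literal` to accept the new value. A tiny illustration of the accepted modes (the alias and helper below are just for this example, not part of sglang):

```python
from typing import Literal, Optional, get_args

# Mirrors the widened annotation above.
ExpertDistributionRecorderMode = Literal["stat", "stat_approx", "per_pass", "per_token"]

def check_mode(mode: Optional[str]) -> None:
    if mode is not None and mode not in get_args(ExpertDistributionRecorderMode):
        raise ValueError(f"unknown expert_distribution_recorder_mode: {mode!r}")

check_mode("stat_approx")  # accepted after this change
check_mode(None)           # recorder disabled, also fine
```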