sglang / Commits / 9effeb5b

Commit 9effeb5b (unverified), authored Jul 29, 2025 by Cheng Wan, committed via GitHub on Jul 29, 2025
Parent: 1992ef9b

Support EPLB in FusedMoE (#8448)
Showing 15 changed files with 107 additions and 11 deletions (+107, -11)
python/sglang/srt/eplb/expert_distribution.py            +5   -0
python/sglang/srt/eplb/expert_location.py                +17  -6
python/sglang/srt/eplb/expert_location_dispatch.py       +1   -0
python/sglang/srt/eplb/expert_location_updater.py        +2   -0
python/sglang/srt/layers/moe/ep_moe/layer.py             +16  -3
python/sglang/srt/layers/moe/fused_moe_triton/layer.py   +44  -1
python/sglang/srt/models/deepseek_v2.py                  +2   -0
python/sglang/srt/models/glm4_moe.py                     +3   -1
python/sglang/srt/models/granitemoe.py                   +3   -0
python/sglang/srt/models/grok.py                         +3   -0
python/sglang/srt/models/hunyuan.py                      +1   -0
python/sglang/srt/models/llama4.py                       +3   -0
python/sglang/srt/models/mixtral.py                      +3   -0
python/sglang/srt/models/olmoe.py                        +3   -0
python/sglang/srt/models/phimoe.py                       +1   -0
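
Taken together, these changes do two things: the MoE weight loaders now consult the global expert-location metadata (EPLB) so a logical expert's checkpoint weight is written into every physical expert that hosts it, with a fallback to plain loading when no metadata is available, and a layer_id argument is threaded through the MoE constructors of the affected models so the per-layer logical-to-physical expert map can be indexed. As a rough illustration of the kind of mapping involved (the names and layout below are illustrative assumptions, not sglang's API):

# Illustrative sketch of a logical -> physical expert mapping with redundant
# experts; names and layout are assumptions for exposition, not sglang's API.
from typing import Dict, List

num_logical_experts = 4
ep_num_redundant_experts = 2  # extra physical copies used for load balancing
num_physical_experts = num_logical_experts + ep_num_redundant_experts

# Trivial placement: physical expert i hosts logical expert i % num_logical_experts.
physical_to_logical = [i % num_logical_experts for i in range(num_physical_experts)]

# Invert it: a logical expert may own several physical replicas.
logical_to_all_physical: Dict[int, List[int]] = {}
for physical_id, logical_id in enumerate(physical_to_logical):
    logical_to_all_physical.setdefault(logical_id, []).append(physical_id)

print(logical_to_all_physical)  # {0: [0, 4], 1: [1, 5], 2: [2], 3: [3]}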
python/sglang/srt/eplb/expert_distribution.py

@@ -47,6 +47,11 @@ class ExpertDistributionRecorder(ABC):
        rank: int,
    ):
        if server_args.expert_distribution_recorder_mode is not None:
+           assert expert_location_metadata is not None, (
+               "ExpertLocationMetadata is required for expert distribution recording. One possible "
+               "reason is that you are using a model that does not support expert distribution "
+               "recording. Try setting `get_model_config_for_expert_location` in your model."
+           )
            return _ExpertDistributionRecorderReal(server_args, expert_location_metadata, rank)
python/sglang/srt/eplb/expert_location.py

@@ -82,6 +82,10 @@ class ExpertLocationMetadata:
    def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
        """Trivial location - logical expert i corresponds to physical expert i"""
        common = ExpertLocationMetadata._init_common(server_args, model_config)
+       if common is None:
+           return None
        num_physical_experts = common["num_physical_experts"]
        model_config_for_expert_location = common["model_config_for_expert_location"]
        num_layers = model_config_for_expert_location.num_layers

@@ -109,6 +113,10 @@ class ExpertLocationMetadata:
        physical_to_logical_map = physical_to_logical_map.to(server_args.device)
        common = ExpertLocationMetadata._init_common(server_args, model_config)
+       if common is None:
+           return None
        model_config_for_expert_location = common["model_config_for_expert_location"]
        logical_to_all_physical_map = _compute_logical_to_all_physical_map(
            physical_to_logical_map,

@@ -133,6 +141,10 @@ class ExpertLocationMetadata:
        logical_count = logical_count.to(server_args.device)
        common = ExpertLocationMetadata._init_common(server_args, model_config)
+       if common is None:
+           return None
        model_config_for_expert_location = common["model_config_for_expert_location"]
        num_physical_experts = common["num_physical_experts"]
        num_groups = model_config_for_expert_location.num_groups

@@ -168,6 +180,9 @@ class ExpertLocationMetadata:
            ModelConfigForExpertLocation.from_model_config(model_config)
        )
+       if model_config_for_expert_location is None:
+           return None
        num_physical_experts = (
            model_config_for_expert_location.num_logical_experts
            + server_args.ep_num_redundant_experts

@@ -398,10 +413,6 @@ class ModelConfigForExpertLocation:
    num_logical_experts: int
    num_groups: Optional[int] = None

-   @staticmethod
-   def init_dummy():
-       return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1)

    @staticmethod
    def from_model_config(model_config: ModelConfig):
        model_class, _ = get_model_architecture(model_config)

@@ -410,12 +421,12 @@ class ModelConfigForExpertLocation:
                model_config.hf_config
            )
        else:
-           return ModelConfigForExpertLocation.init_dummy()
+           return None


def compute_initial_expert_location_metadata(
    server_args: ServerArgs, model_config: ModelConfig
-) -> ExpertLocationMetadata:
+) -> Optional[ExpertLocationMetadata]:
    data = server_args.init_expert_location
    if data == "trivial":
        return ExpertLocationMetadata.init_trivial(server_args, model_config)
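
With these changes expert-location construction fails soft: from_model_config returns None for models without an expert-location config (the init_dummy fallback is gone), and compute_initial_expert_location_metadata is now typed as returning Optional[ExpertLocationMetadata]. A minimal sketch of the caller-side guard this implies; build_expert_location_metadata below is a hypothetical stand-in, not the sglang function:

# Hypothetical stand-in showing how an Optional expert-location result is handled.
from typing import Optional

class ExpertLocationMetadataStub:  # illustration only
    pass

def build_expert_location_metadata(model_supports_eplb: bool) -> Optional[ExpertLocationMetadataStub]:
    # Mirrors the new behaviour: None when the model defines no expert-location config.
    return ExpertLocationMetadataStub() if model_supports_eplb else None

metadata = build_expert_location_metadata(model_supports_eplb=False)
if metadata is None:
    print("no expert-location metadata; EPLB-dependent features stay disabled")
else:
    print("EPLB metadata available")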
python/sglang/srt/eplb/expert_location_dispatch.py

@@ -36,6 +36,7 @@ class ExpertLocationDispatchInfo:
    def init_new(cls, layer_id: int):
        ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"]
        expert_location_metadata = get_global_expert_location_metadata()
+       assert expert_location_metadata is not None
        if ep_dispatch_algorithm is None:
            return None
python/sglang/srt/eplb/expert_location_updater.py

@@ -50,6 +50,8 @@ class ExpertLocationUpdater:
        torch.cuda.empty_cache()
        old_expert_location_metadata = get_global_expert_location_metadata()
+       assert old_expert_location_metadata is not None
+
        _update_expert_weights(
            routed_experts_weights_of_layer=routed_experts_weights_of_layer,
            old_expert_location_metadata=old_expert_location_metadata,
python/sglang/srt/layers/moe/ep_moe/layer.py

@@ -183,6 +183,7 @@ class EPMoE(FusedMoE):
        hidden_size: int,
        intermediate_size: int,
        layer_id: int,
        num_fused_shared_experts: int = 0,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,

@@ -196,6 +197,7 @@ class EPMoE(FusedMoE):
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            top_k=top_k,
            num_fused_shared_experts=num_fused_shared_experts,
+           layer_id=layer_id,
            params_dtype=params_dtype,
            quant_config=quant_config,

@@ -728,10 +730,19 @@ class EPMoE(FusedMoE):
        shard_id: str,
        expert_id: int,
    ) -> None:
-       physical_expert_ids = (
-           get_global_expert_location_metadata().logical_to_all_physical(
-               self.layer_id, expert_id
+       global_expert_location_metadata = get_global_expert_location_metadata()
+       if global_expert_location_metadata is None:
+           self._weight_loader_impl(
+               param=param,
+               loaded_weight=loaded_weight,
+               weight_name=weight_name,
+               shard_id=shard_id,
+               expert_id=expert_id,
+           )
+           return
+       physical_expert_ids = global_expert_location_metadata.logical_to_all_physical(
+           self.layer_id, expert_id
        )
        for physical_expert_id in physical_expert_ids:
            self._weight_loader_physical(

@@ -778,6 +789,7 @@ class DeepEPMoE(EPMoE):
        hidden_size: int,
        intermediate_size: int,
        layer_id: int,
        num_fused_shared_experts: int = 0,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,

@@ -792,6 +804,7 @@ class DeepEPMoE(EPMoE):
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
+           layer_id=layer_id,
            num_fused_shared_experts=num_fused_shared_experts,
            params_dtype=params_dtype,
            quant_config=quant_config,
            tp_size=tp_size,
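
The reworked EPMoE.weight_loader above degrades gracefully: with no global expert-location metadata it loads the checkpoint weight once through _weight_loader_impl, and with metadata it writes the same tensor into every physical replica returned by logical_to_all_physical. A toy sketch of that replication idea with plain tensors; the buffer shapes and helper below are assumptions, not sglang internals:

# Toy replication sketch; buffer layout and helper are invented for illustration.
import torch

hidden_size, intermediate_size = 8, 16
num_physical_experts = 6

# One weight slot per *physical* expert.
w13_weight = torch.zeros(num_physical_experts, 2 * intermediate_size, hidden_size)

def load_logical_expert(loaded_weight: torch.Tensor, physical_expert_ids: list) -> None:
    # The same checkpoint tensor is copied into each physical replica slot.
    for physical_expert_id in physical_expert_ids:
        w13_weight[physical_expert_id].copy_(loaded_weight)

# Suppose logical expert 1 is hosted by physical experts 1 and 4.
load_logical_expert(torch.randn(2 * intermediate_size, hidden_size), [1, 4])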
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -11,6 +11,7 @@ from sglang.srt.distributed import (
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
+from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,

@@ -62,8 +63,9 @@ class FusedMoE(torch.nn.Module):
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
+       layer_id: int,
        top_k: Optional[int] = None,
-       layer_id: Optional[int] = None,
+       num_fused_shared_experts: int = 0,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = False,
        quant_config: Optional[QuantizationConfig] = None,

@@ -84,6 +86,7 @@ class FusedMoE(torch.nn.Module):
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

+       self.layer_id = layer_id
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.tp_size = (

@@ -91,6 +94,7 @@ class FusedMoE(torch.nn.Module):
        )
        self.tp_rank = get_tensor_model_parallel_rank()
        self.num_experts = num_experts
+       self.num_fused_shared_experts = num_fused_shared_experts
        self.expert_map = None
        if enable_flashinfer_cutlass_moe and quant_config is None:

@@ -375,6 +379,45 @@ class FusedMoE(torch.nn.Module):
        shard_id: str,
        expert_id: int,
    ) -> None:
+       global_expert_location_metadata = get_global_expert_location_metadata()
+       if global_expert_location_metadata is None:
+           self._weight_loader_impl(
+               param=param,
+               loaded_weight=loaded_weight,
+               weight_name=weight_name,
+               shard_id=shard_id,
+               expert_id=expert_id,
+           )
+           return
+
+       if expert_id >= self.num_experts - self.num_fused_shared_experts:
+           # This is a shared expert.
+           physical_expert_ids = [expert_id]
+       else:
+           physical_expert_ids = (
+               global_expert_location_metadata.logical_to_all_physical(
+                   self.layer_id, expert_id
+               )
+           )
+
+       for physical_expert_id in physical_expert_ids:
+           self._weight_loader_physical(
+               param=param,
+               loaded_weight=loaded_weight,
+               weight_name=weight_name,
+               shard_id=shard_id,
+               expert_id=physical_expert_id,
+           )
+
+   def _weight_loader_physical(
+       self,
+       param: torch.nn.Parameter,
+       loaded_weight: torch.Tensor,
+       weight_name: str,
+       shard_id: str,
+       expert_id: int,
+   ) -> None:
        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
        if expert_id == -1:
            return
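
FusedMoE.weight_loader also special-cases fused shared experts: any expert_id at or above num_experts - num_fused_shared_experts is loaded to itself, while routed experts are expanded through logical_to_all_physical. A small worked check of that boundary; the counts are illustrative, only the comparison mirrors the code above:

# Illustrative boundary check; only the >= comparison mirrors FusedMoE.weight_loader.
num_routed = 8
num_fused_shared_experts = 1
ep_num_redundant_experts = 2
num_experts = num_routed + num_fused_shared_experts + ep_num_redundant_experts  # 11

def is_shared_expert(expert_id: int) -> bool:
    return expert_id >= num_experts - num_fused_shared_experts  # ids >= 10

assert not is_shared_expert(7)   # routed: expanded via logical_to_all_physical
assert is_shared_expert(10)      # fused shared expert: loaded to its own slot
print("boundary at expert_id", num_experts - num_fused_shared_experts)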
python/sglang/srt/models/deepseek_v2.py

@@ -325,6 +325,7 @@ class DeepseekV2MoE(nn.Module):
            num_experts=config.n_routed_experts
            + self.num_fused_shared_experts
            + global_server_args_dict["ep_num_redundant_experts"],
            num_fused_shared_experts=self.num_fused_shared_experts,
            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,

@@ -2112,6 +2113,7 @@ class DeepseekV2ForCausalLM(nn.Module):
        if disable_reason is not None:
            global_server_args_dict["disable_shared_experts_fusion"] = True
            self.num_fused_shared_experts = 0
            log_info_on_rank0(
                logger,
                f"{disable_reason} Shared experts fusion optimization is disabled.",
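
For DeepSeek-V2 the expert count handed to the MoE layer now folds in both the fused shared experts and EPLB's redundant experts, and top_k grows by the number of fused shared experts. A quick arithmetic sketch with made-up config values:

# Made-up numbers; the two formulas mirror the DeepseekV2MoE call above.
n_routed_experts = 64
num_fused_shared_experts = 2
ep_num_redundant_experts = 8
num_experts_per_tok = 6

num_experts = n_routed_experts + num_fused_shared_experts + ep_num_redundant_experts
top_k = num_experts_per_tok + num_fused_shared_experts
print(num_experts, top_k)  # 74 8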
python/sglang/srt/models/glm4_moe.py

@@ -434,6 +434,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
            num_experts=config.n_routed_experts
            + self.num_fused_shared_experts
            + global_server_args_dict["ep_num_redundant_experts"],
            num_fused_shared_experts=self.num_fused_shared_experts,
            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,

@@ -740,10 +741,11 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
            global_server_args_dict["enable_deepep_moe"]
            or global_server_args_dict["enable_ep_moe"]
        ):
-           disable_reason = "Deepseek GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+           disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
        if disable_reason is not None:
            global_server_args_dict["disable_shared_experts_fusion"] = True
            self.num_fused_shared_experts = 0
            log_info_on_rank0(
                logger,
                f"{disable_reason} Shared experts fusion optimization is disabled.",
python/sglang/srt/models/granitemoe.py

@@ -43,6 +43,7 @@ class GraniteMoeMoE(nn.Module):
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
+       layer_id: int,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,

@@ -71,6 +72,7 @@ class GraniteMoeMoE(nn.Module):
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
+           layer_id=layer_id,
            params_dtype=params_dtype,
            reduce_results=True,
            quant_config=quant_config,

@@ -203,6 +205,7 @@ class GraniteMoeDecoderLayer(nn.Module):
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
+           layer_id=layer_id,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe",
        )
python/sglang/srt/models/grok.py

@@ -78,6 +78,7 @@ class Grok1MoE(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
+       layer_id: int,
        num_experts: int,
        top_k: int,
        hidden_size: int,

@@ -128,6 +129,7 @@ class Grok1MoE(nn.Module):
        self.experts = MoEImpl(
            num_experts=num_experts,
            top_k=top_k,
+           layer_id=layer_id,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            params_dtype=params_dtype,

@@ -331,6 +333,7 @@ class Grok1DecoderLayer(nn.Module):
        )
        self.block_sparse_moe = Grok1MoE(
            config=config,
+           layer_id=layer_id,
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
python/sglang/srt/models/hunyuan.py

@@ -163,6 +163,7 @@ class HunYuanSparseMoeBlock(nn.Module):
            hidden_size=config.hidden_size,
            intermediate_size=intermediate_size,
            reduce_results=False,
+           layer_id=layer_id,
            quant_config=quant_config,
        )
python/sglang/srt/models/llama4.py

@@ -87,6 +87,7 @@ class Llama4MoE(nn.Module):
    def __init__(
        self,
        config: Llama4TextConfig,
+       layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):

@@ -114,6 +115,7 @@ class Llama4MoE(nn.Module):
            num_experts=config.num_local_experts,
            hidden_size=config.hidden_size,
            intermediate_size=intermediate_size_moe,
+           layer_id=layer_id,
            reduce_results=False,
            quant_config=quant_config,
            apply_router_weight_on_input=True,

@@ -373,6 +375,7 @@ class Llama4DecoderLayer(nn.Module):
        if is_moe_layer:
            self.feed_forward = Llama4MoE(
                config=config,
+               layer_id=layer_id,
                quant_config=quant_config,
                prefix=add_prefix("feed_forward", prefix),
            )
python/sglang/srt/models/mixtral.py

@@ -69,6 +69,7 @@ class MixtralMoE(nn.Module):
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
+       layer_id: int,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,

@@ -97,6 +98,7 @@ class MixtralMoE(nn.Module):
        self.experts = MoEImpl(
            num_experts=num_experts,
            top_k=top_k,
+           layer_id=layer_id,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            params_dtype=params_dtype,

@@ -226,6 +228,7 @@ class MixtralDecoderLayer(nn.Module):
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
+           layer_id=layer_id,
            quant_config=quant_config,
            prefix=add_prefix("block_sparse_moe", prefix),
        )
python/sglang/srt/models/olmoe.py

@@ -63,6 +63,7 @@ class OlmoeMoE(nn.Module):
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
+       layer_id: int = 0,
        prefix: str = "",
    ):
        super().__init__()

@@ -89,6 +90,7 @@ class OlmoeMoE(nn.Module):
            reduce_results=True,
            quant_config=quant_config,
            tp_size=tp_size,
+           layer_id=layer_id,
            prefix=add_prefix("experts", prefix),
        )

@@ -224,6 +226,7 @@ class OlmoeDecoderLayer(nn.Module):
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
+           layer_id=layer_id,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )
python/sglang/srt/models/phimoe.py

@@ -210,6 +210,7 @@ class PhiMoE(nn.Module):
        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=top_k,
+           layer_id=layer_id,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            reduce_results=True,
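
Across the remaining model files (GraniteMoE, Grok-1, HunYuan, Llama 4, Mixtral, OLMoE, PhiMoE), the change is pure plumbing: each MoE block now receives and forwards layer_id so the EPLB-aware weight loader can look up that layer's expert map. A schematic of the call-site pattern; the classes below are placeholders, not sglang code:

# Placeholder classes showing layer_id flowing from the decoder layer into the experts.
from dataclasses import dataclass

@dataclass
class ToyConfig:
    num_local_experts: int = 8
    num_experts_per_tok: int = 2
    hidden_size: int = 1024
    intermediate_size: int = 4096

class ToyExperts:
    def __init__(self, *, num_experts: int, top_k: int, hidden_size: int,
                 intermediate_size: int, layer_id: int):
        self.layer_id = layer_id  # later used to index the per-layer expert map

class ToyMoEBlock:
    def __init__(self, config: ToyConfig, layer_id: int):
        self.experts = ToyExperts(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            layer_id=layer_id,
        )

block = ToyMoEBlock(ToyConfig(), layer_id=3)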