sglang · Commits · 9effeb5b

Commit 9effeb5b (unverified) · authored Jul 29, 2025 by Cheng Wan · committed via GitHub Jul 29, 2025
Parent: 1992ef9b

Support EPLB in FusedMoE (#8448)

Showing 15 changed files with 107 additions and 11 deletions (+107 -11)
python/sglang/srt/eplb/expert_distribution.py            +5   -0
python/sglang/srt/eplb/expert_location.py                +17  -6
python/sglang/srt/eplb/expert_location_dispatch.py       +1   -0
python/sglang/srt/eplb/expert_location_updater.py        +2   -0
python/sglang/srt/layers/moe/ep_moe/layer.py             +16  -3
python/sglang/srt/layers/moe/fused_moe_triton/layer.py   +44  -1
python/sglang/srt/models/deepseek_v2.py                  +2   -0
python/sglang/srt/models/glm4_moe.py                     +3   -1
python/sglang/srt/models/granitemoe.py                   +3   -0
python/sglang/srt/models/grok.py                         +3   -0
python/sglang/srt/models/hunyuan.py                      +1   -0
python/sglang/srt/models/llama4.py                       +3   -0
python/sglang/srt/models/mixtral.py                      +3   -0
python/sglang/srt/models/olmoe.py                        +3   -0
python/sglang/srt/models/phimoe.py                       +1   -0
python/sglang/srt/eplb/expert_distribution.py

@@ -47,6 +47,11 @@ class ExpertDistributionRecorder(ABC):
         rank: int,
     ):
         if server_args.expert_distribution_recorder_mode is not None:
+            assert (
+                expert_location_metadata is not None
+            ), "ExpertLocationMetadata is required for expert distribution recording. One possible"
+            "reason is that you are using a model that does not support expert distribution"
+            "recording. Try setting `get_model_config_for_expert_location` in your model."
             return _ExpertDistributionRecorderReal(
                 server_args, expert_location_metadata, rank
             )
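The assertion above points readers at the per-model hook that makes expert-distribution recording possible. As a hedged sketch only: one plausible shape for that hook, assuming it is a classmethod that receives the model's Hugging Face config and returns the `ModelConfigForExpertLocation` dataclass from `python/sglang/srt/eplb/expert_location.py` (the exact signature sglang expects may differ, and `MyMoEForCausalLM` plus the config field names are hypothetical):

from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation


class MyMoEForCausalLM:  # hypothetical model class, for illustration only
    @classmethod
    def get_model_config_for_expert_location(cls, config):
        # Tell the EPLB machinery how many MoE layers the model has and how
        # many logical (routed) experts each layer holds; num_groups keeps its
        # default of None when the model does not use grouped routing.
        return ModelConfigForExpertLocation(
            num_layers=config.num_hidden_layers,
            num_logical_experts=config.num_local_experts,
        )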
python/sglang/srt/eplb/expert_location.py

@@ -82,6 +82,10 @@ class ExpertLocationMetadata:
     def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
         """Trivial location - logical expert i corresponds to physical expert i"""
         common = ExpertLocationMetadata._init_common(server_args, model_config)
+
+        if common is None:
+            return None
+
         num_physical_experts = common["num_physical_experts"]
         model_config_for_expert_location = common["model_config_for_expert_location"]
         num_layers = model_config_for_expert_location.num_layers

@@ -109,6 +113,10 @@ class ExpertLocationMetadata:
         physical_to_logical_map = physical_to_logical_map.to(server_args.device)
         common = ExpertLocationMetadata._init_common(server_args, model_config)
+
+        if common is None:
+            return None
+
         model_config_for_expert_location = common["model_config_for_expert_location"]
         logical_to_all_physical_map = _compute_logical_to_all_physical_map(
             physical_to_logical_map,

@@ -133,6 +141,10 @@ class ExpertLocationMetadata:
         logical_count = logical_count.to(server_args.device)
         common = ExpertLocationMetadata._init_common(server_args, model_config)
+
+        if common is None:
+            return None
+
         model_config_for_expert_location = common["model_config_for_expert_location"]
         num_physical_experts = common["num_physical_experts"]
         num_groups = model_config_for_expert_location.num_groups

@@ -168,6 +180,9 @@ class ExpertLocationMetadata:
             ModelConfigForExpertLocation.from_model_config(model_config)
         )
+        if model_config_for_expert_location is None:
+            return None
+
         num_physical_experts = (
             model_config_for_expert_location.num_logical_experts
             + server_args.ep_num_redundant_experts

@@ -398,10 +413,6 @@ class ModelConfigForExpertLocation:
     num_logical_experts: int
     num_groups: Optional[int] = None
 
-    @staticmethod
-    def init_dummy():
-        return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1)
-
     @staticmethod
     def from_model_config(model_config: ModelConfig):
         model_class, _ = get_model_architecture(model_config)

@@ -410,12 +421,12 @@ class ModelConfigForExpertLocation:
                 model_config.hf_config
             )
         else:
-            return ModelConfigForExpertLocation.init_dummy()
+            return None
 
 
 def compute_initial_expert_location_metadata(
     server_args: ServerArgs, model_config: ModelConfig
-) -> ExpertLocationMetadata:
+) -> Optional[ExpertLocationMetadata]:
     data = server_args.init_expert_location
     if data == "trivial":
         return ExpertLocationMetadata.init_trivial(server_args, model_config)
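The docstring in `init_trivial` pins down the default layout: logical expert i lives at physical expert i. A minimal, self-contained sketch of such a mapping, assuming redundant physical slots simply wrap around modulo the logical expert count (the actual tensor construction inside `init_trivial`/`_init_common` may differ):

import torch


def trivial_physical_to_logical_map(
    num_layers: int, num_logical_experts: int, num_redundant_experts: int
) -> torch.Tensor:
    # Physical slot i hosts logical expert i; any redundant slots repeat the
    # logical experts cyclically. Shape: (num_layers, num_physical_experts).
    num_physical_experts = num_logical_experts + num_redundant_experts
    one_layer = torch.arange(num_physical_experts) % num_logical_experts
    return one_layer.unsqueeze(0).repeat(num_layers, 1)


# 4 logical experts and 2 redundant slots per layer:
print(trivial_physical_to_logical_map(num_layers=2, num_logical_experts=4, num_redundant_experts=2))
# tensor([[0, 1, 2, 3, 0, 1],
#         [0, 1, 2, 3, 0, 1]])

The companion change is the return type: `compute_initial_expert_location_metadata` is now `Optional[ExpertLocationMetadata]` and yields `None` for models without expert-location support, so callers have to guard on `None` instead of receiving a dummy config.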
python/sglang/srt/eplb/expert_location_dispatch.py

@@ -36,6 +36,7 @@ class ExpertLocationDispatchInfo:
     def init_new(cls, layer_id: int):
         ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"]
         expert_location_metadata = get_global_expert_location_metadata()
+        assert expert_location_metadata is not None
 
         if ep_dispatch_algorithm is None:
             return None
python/sglang/srt/eplb/expert_location_updater.py

@@ -50,6 +50,8 @@ class ExpertLocationUpdater:
         torch.cuda.empty_cache()
 
         old_expert_location_metadata = get_global_expert_location_metadata()
+        assert old_expert_location_metadata is not None
+
         _update_expert_weights(
             routed_experts_weights_of_layer=routed_experts_weights_of_layer,
             old_expert_location_metadata=old_expert_location_metadata,
python/sglang/srt/layers/moe/ep_moe/layer.py

@@ -183,6 +183,7 @@ class EPMoE(FusedMoE):
         hidden_size: int,
         intermediate_size: int,
         layer_id: int,
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,

@@ -196,6 +197,7 @@ class EPMoE(FusedMoE):
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             top_k=top_k,
+            num_fused_shared_experts=num_fused_shared_experts,
             layer_id=layer_id,
             params_dtype=params_dtype,
             quant_config=quant_config,

@@ -728,10 +730,19 @@ class EPMoE(FusedMoE):
         shard_id: str,
         expert_id: int,
     ) -> None:
-        physical_expert_ids = (
-            get_global_expert_location_metadata().logical_to_all_physical(
-                self.layer_id, expert_id
+        global_expert_location_metadata = get_global_expert_location_metadata()
+        if global_expert_location_metadata is None:
+            self._weight_loader_impl(
+                param=param,
+                loaded_weight=loaded_weight,
+                weight_name=weight_name,
+                shard_id=shard_id,
+                expert_id=expert_id,
             )
+            return
+
+        physical_expert_ids = global_expert_location_metadata.logical_to_all_physical(
+            self.layer_id, expert_id
         )
         for physical_expert_id in physical_expert_ids:
             self._weight_loader_physical(

@@ -778,6 +789,7 @@ class DeepEPMoE(EPMoE):
         hidden_size: int,
         intermediate_size: int,
         layer_id: int,
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,

@@ -792,6 +804,7 @@ class DeepEPMoE(EPMoE):
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             layer_id=layer_id,
+            num_fused_shared_experts=num_fused_shared_experts,
             params_dtype=params_dtype,
             quant_config=quant_config,
             tp_size=tp_size,
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -11,6 +11,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,

@@ -62,8 +63,9 @@ class FusedMoE(torch.nn.Module):
         num_experts: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         top_k: Optional[int] = None,
-        layer_id: Optional[int] = None,
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         quant_config: Optional[QuantizationConfig] = None,

@@ -84,6 +86,7 @@ class FusedMoE(torch.nn.Module):
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
 
+        self.layer_id = layer_id
         self.top_k = top_k
         self.hidden_size = hidden_size
         self.tp_size = (

@@ -91,6 +94,7 @@ class FusedMoE(torch.nn.Module):
         )
         self.tp_rank = get_tensor_model_parallel_rank()
         self.num_experts = num_experts
+        self.num_fused_shared_experts = num_fused_shared_experts
         self.expert_map = None
 
         if enable_flashinfer_cutlass_moe and quant_config is None:

@@ -375,6 +379,45 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         expert_id: int,
     ) -> None:
+        global_expert_location_metadata = get_global_expert_location_metadata()
+        if global_expert_location_metadata is None:
+            self._weight_loader_impl(
+                param=param,
+                loaded_weight=loaded_weight,
+                weight_name=weight_name,
+                shard_id=shard_id,
+                expert_id=expert_id,
+            )
+            return
+
+        if expert_id >= self.num_experts - self.num_fused_shared_experts:
+            # This is a shared expert.
+            physical_expert_ids = [expert_id]
+        else:
+            physical_expert_ids = (
+                global_expert_location_metadata.logical_to_all_physical(
+                    self.layer_id, expert_id
+                )
+            )
+
+        for physical_expert_id in physical_expert_ids:
+            self._weight_loader_physical(
+                param=param,
+                loaded_weight=loaded_weight,
+                weight_name=weight_name,
+                shard_id=shard_id,
+                expert_id=physical_expert_id,
+            )
+
+    def _weight_loader_physical(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+        expert_id: int,
+    ) -> None:
         expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
         if expert_id == -1:
             return
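The new `weight_loader` above is the heart of the EPLB support in `FusedMoE`: under EPLB a logical expert can be replicated into several physical slots, so a single checkpoint weight must be written into every replica; fused shared experts (the trailing `num_fused_shared_experts` ids) keep an identity mapping; and the loader falls back to the plain `_weight_loader_impl` path when no global metadata has been initialized. A self-contained sketch of that fan-out decision, using a plain dict in place of `ExpertLocationMetadata.logical_to_all_physical` (illustration only, not the real class API):

from typing import Dict, List, Optional, Tuple

# Toy stand-in for the (layer_id, logical_expert_id) -> [physical ids] lookup
# that ExpertLocationMetadata.logical_to_all_physical provides.
ReplicaMap = Dict[Tuple[int, int], List[int]]


def target_physical_experts(
    layer_id: int,
    expert_id: int,
    num_experts: int,
    num_fused_shared_experts: int,
    replica_map: Optional[ReplicaMap],
) -> List[int]:
    # Which physical expert slots should receive this checkpoint weight?
    if replica_map is None:
        # No EPLB metadata: behave like the original loader (single id, unchanged).
        return [expert_id]
    if expert_id >= num_experts - num_fused_shared_experts:
        # Fused shared experts are not rebalanced, so they map to themselves.
        return [expert_id]
    # Routed experts may be replicated across several physical slots.
    return replica_map[(layer_id, expert_id)]


# Logical expert 0 of layer 3 has replicas in physical slots 0 and 5:
replicas = {(3, 0): [0, 5], (3, 1): [1]}
print(target_physical_experts(3, 0, num_experts=3, num_fused_shared_experts=1, replica_map=replicas))
# [0, 5]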
python/sglang/srt/models/deepseek_v2.py

@@ -325,6 +325,7 @@ class DeepseekV2MoE(nn.Module):
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
+            num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,

@@ -2112,6 +2113,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
+            self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
                 f"{disable_reason} Shared experts fusion optimization is disabled.",
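The `num_experts` expression in the first hunk sizes the MoE layer as routed experts plus fused shared experts plus EPLB's redundant slots, and the newly passed `num_fused_shared_experts` is what lets `FusedMoE.weight_loader` tell the shared tail apart from rebalanceable routed experts. A tiny worked example with illustrative numbers (not DeepSeek-V2's real configuration):

# Illustrative numbers only, not the actual DeepSeek-V2 config values.
n_routed_experts = 64          # config.n_routed_experts
num_fused_shared_experts = 1   # shared experts fused into the routed MoE
ep_num_redundant_experts = 8   # extra slots reserved for EPLB replication

num_experts = n_routed_experts + num_fused_shared_experts + ep_num_redundant_experts
print(num_experts)  # 73 expert slots handed to FusedMoE for this layer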
python/sglang/srt/models/glm4_moe.py

@@ -434,6 +434,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
+            num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,

@@ -740,10 +741,11 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
             global_server_args_dict["enable_deepep_moe"]
             or global_server_args_dict["enable_ep_moe"]
         ):
-            disable_reason = "Deepseek GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+            disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
 
         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
+            self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
                 f"{disable_reason} Shared experts fusion optimization is disabled.",
python/sglang/srt/models/granitemoe.py

@@ -43,6 +43,7 @@ class GraniteMoeMoE(nn.Module):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,

@@ -71,6 +72,7 @@ class GraniteMoeMoE(nn.Module):
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
+            layer_id=layer_id,
             params_dtype=params_dtype,
             reduce_results=True,
             quant_config=quant_config,

@@ -203,6 +205,7 @@ class GraniteMoeDecoderLayer(nn.Module):
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
+            layer_id=layer_id,
             quant_config=quant_config,
             prefix=f"{prefix}.block_sparse_moe",
         )
python/sglang/srt/models/grok.py

@@ -78,6 +78,7 @@ class Grok1MoE(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        layer_id: int,
         num_experts: int,
         top_k: int,
         hidden_size: int,

@@ -128,6 +129,7 @@ class Grok1MoE(nn.Module):
         self.experts = MoEImpl(
             num_experts=num_experts,
             top_k=top_k,
+            layer_id=layer_id,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,

@@ -331,6 +333,7 @@ class Grok1DecoderLayer(nn.Module):
         )
         self.block_sparse_moe = Grok1MoE(
             config=config,
+            layer_id=layer_id,
             num_experts=config.num_local_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
python/sglang/srt/models/hunyuan.py

@@ -163,6 +163,7 @@ class HunYuanSparseMoeBlock(nn.Module):
             hidden_size=config.hidden_size,
             intermediate_size=intermediate_size,
             reduce_results=False,
+            layer_id=layer_id,
             quant_config=quant_config,
         )
python/sglang/srt/models/llama4.py

@@ -87,6 +87,7 @@ class Llama4MoE(nn.Module):
     def __init__(
         self,
         config: Llama4TextConfig,
+        layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):

@@ -114,6 +115,7 @@ class Llama4MoE(nn.Module):
             num_experts=config.num_local_experts,
             hidden_size=config.hidden_size,
             intermediate_size=intermediate_size_moe,
+            layer_id=layer_id,
             reduce_results=False,
             quant_config=quant_config,
             apply_router_weight_on_input=True,

@@ -373,6 +375,7 @@ class Llama4DecoderLayer(nn.Module):
         if is_moe_layer:
             self.feed_forward = Llama4MoE(
                 config=config,
+                layer_id=layer_id,
                 quant_config=quant_config,
                 prefix=add_prefix("feed_forward", prefix),
             )
python/sglang/srt/models/mixtral.py

@@ -69,6 +69,7 @@ class MixtralMoE(nn.Module):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,

@@ -97,6 +98,7 @@ class MixtralMoE(nn.Module):
         self.experts = MoEImpl(
             num_experts=num_experts,
             top_k=top_k,
+            layer_id=layer_id,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,

@@ -226,6 +228,7 @@ class MixtralDecoderLayer(nn.Module):
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
+            layer_id=layer_id,
             quant_config=quant_config,
             prefix=add_prefix("block_sparse_moe", prefix),
         )
python/sglang/srt/models/olmoe.py

@@ -63,6 +63,7 @@ class OlmoeMoE(nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
+        layer_id: int = 0,
         prefix: str = "",
     ):
         super().__init__()

@@ -89,6 +90,7 @@ class OlmoeMoE(nn.Module):
             reduce_results=True,
             quant_config=quant_config,
             tp_size=tp_size,
+            layer_id=layer_id,
             prefix=add_prefix("experts", prefix),
         )

@@ -224,6 +226,7 @@ class OlmoeDecoderLayer(nn.Module):
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
+            layer_id=layer_id,
             quant_config=quant_config,
             prefix=add_prefix("mlp", prefix),
         )
python/sglang/srt/models/phimoe.py

@@ -210,6 +210,7 @@ class PhiMoE(nn.Module):
         self.experts = FusedMoE(
             num_experts=num_experts,
             top_k=top_k,
+            layer_id=layer_id,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             reduce_results=True,