change / sglang · Commit 9effeb5b
"git@developer.sourcefind.cn:change/sglang.git" did not exist on "256801e9737919ceb018898c0ecb69012e660a19"
Unverified commit 9effeb5b, authored Jul 29, 2025 by Cheng Wan, committed by GitHub on Jul 29, 2025

Support EPLB in FusedMoE (#8448)
Parent: 1992ef9b
Showing 15 changed files with 107 additions and 11 deletions (+107, -11)
python/sglang/srt/eplb/expert_distribution.py (+5, -0)
python/sglang/srt/eplb/expert_location.py (+17, -6)
python/sglang/srt/eplb/expert_location_dispatch.py (+1, -0)
python/sglang/srt/eplb/expert_location_updater.py (+2, -0)
python/sglang/srt/layers/moe/ep_moe/layer.py (+16, -3)
python/sglang/srt/layers/moe/fused_moe_triton/layer.py (+44, -1)
python/sglang/srt/models/deepseek_v2.py (+2, -0)
python/sglang/srt/models/glm4_moe.py (+3, -1)
python/sglang/srt/models/granitemoe.py (+3, -0)
python/sglang/srt/models/grok.py (+3, -0)
python/sglang/srt/models/hunyuan.py (+1, -0)
python/sglang/srt/models/llama4.py (+3, -0)
python/sglang/srt/models/mixtral.py (+3, -0)
python/sglang/srt/models/olmoe.py (+3, -0)
python/sglang/srt/models/phimoe.py (+1, -0)
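Background for readers of this diff (not part of the commit itself): EPLB distinguishes logical experts (what the model checkpoint defines) from physical experts (the slots actually placed on devices, possibly including redundant replicas). The core of this change is that FusedMoE now knows its layer_id and, when expert-location metadata is registered, replicates each logical expert's checkpoint weight into every physical slot assigned to it at load time. A minimal, self-contained sketch of that fan-out, using a toy mapping class and made-up tensor shapes:

from typing import Dict, List

import torch


class ToyExpertLocationMetadata:
    """Toy stand-in: maps (layer_id, logical expert id) -> physical expert slots."""

    def __init__(self, mapping: Dict[int, Dict[int, List[int]]]):
        self._mapping = mapping

    def logical_to_all_physical(self, layer_id: int, logical_id: int) -> List[int]:
        return self._mapping[layer_id][logical_id]


def load_expert_weight(
    physical_weights: torch.Tensor,       # [num_physical_experts, out_dim, in_dim]
    loaded_weight: torch.Tensor,          # [out_dim, in_dim] for one logical expert
    metadata: ToyExpertLocationMetadata,
    layer_id: int,
    logical_expert_id: int,
) -> None:
    # A logical expert may own several physical slots (redundant experts);
    # the same checkpoint weight is copied into each of them.
    for physical_id in metadata.logical_to_all_physical(layer_id, logical_expert_id):
        physical_weights[physical_id].copy_(loaded_weight)


# Logical expert 0 of layer 0 is replicated onto physical slots 0 and 2.
metadata = ToyExpertLocationMetadata({0: {0: [0, 2], 1: [1]}})
physical = torch.zeros(3, 4, 8)
load_expert_weight(physical, torch.ones(4, 8), metadata, layer_id=0, logical_expert_id=0)
assert torch.equal(physical[0], physical[2])  # both replicas received the same weight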
python/sglang/srt/eplb/expert_distribution.py

@@ -47,6 +47,11 @@ class ExpertDistributionRecorder(ABC):
         rank: int,
     ):
         if server_args.expert_distribution_recorder_mode is not None:
+            assert (
+                expert_location_metadata is not None
+            ), "ExpertLocationMetadata is required for expert distribution recording. One possible"
+            "reason is that you are using a model that does not support expert distribution"
+            "recording. Try setting `get_model_config_for_expert_location` in your model."
             return _ExpertDistributionRecorderReal(
                 server_args, expert_location_metadata, rank
             )
python/sglang/srt/eplb/expert_location.py

@@ -82,6 +82,10 @@ class ExpertLocationMetadata:
     def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
         """Trivial location - logical expert i corresponds to physical expert i"""
         common = ExpertLocationMetadata._init_common(server_args, model_config)
+        if common is None:
+            return None
+
         num_physical_experts = common["num_physical_experts"]
         model_config_for_expert_location = common["model_config_for_expert_location"]
         num_layers = model_config_for_expert_location.num_layers
@@ -109,6 +113,10 @@ class ExpertLocationMetadata:
         physical_to_logical_map = physical_to_logical_map.to(server_args.device)
         common = ExpertLocationMetadata._init_common(server_args, model_config)
+        if common is None:
+            return None
+
         model_config_for_expert_location = common["model_config_for_expert_location"]
         logical_to_all_physical_map = _compute_logical_to_all_physical_map(
             physical_to_logical_map,
@@ -133,6 +141,10 @@ class ExpertLocationMetadata:
         logical_count = logical_count.to(server_args.device)
         common = ExpertLocationMetadata._init_common(server_args, model_config)
+        if common is None:
+            return None
+
         model_config_for_expert_location = common["model_config_for_expert_location"]
         num_physical_experts = common["num_physical_experts"]
         num_groups = model_config_for_expert_location.num_groups
@@ -168,6 +180,9 @@ class ExpertLocationMetadata:
             ModelConfigForExpertLocation.from_model_config(model_config)
         )
+        if model_config_for_expert_location is None:
+            return None
+
         num_physical_experts = (
             model_config_for_expert_location.num_logical_experts
             + server_args.ep_num_redundant_experts
@@ -398,10 +413,6 @@ class ModelConfigForExpertLocation:
     num_logical_experts: int
     num_groups: Optional[int] = None
 
-    @staticmethod
-    def init_dummy():
-        return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1)
-
     @staticmethod
     def from_model_config(model_config: ModelConfig):
         model_class, _ = get_model_architecture(model_config)
@@ -410,12 +421,12 @@ class ModelConfigForExpertLocation:
                 model_config.hf_config
             )
         else:
-            return ModelConfigForExpertLocation.init_dummy()
+            return None
 
 def compute_initial_expert_location_metadata(
     server_args: ServerArgs, model_config: ModelConfig
-) -> ExpertLocationMetadata:
+) -> Optional[ExpertLocationMetadata]:
     data = server_args.init_expert_location
     if data == "trivial":
         return ExpertLocationMetadata.init_trivial(server_args, model_config)
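A note on the expert_location.py change above: compute_initial_expert_location_metadata is now annotated Optional[ExpertLocationMetadata], and from_model_config returns None instead of a dummy config for models that do not declare expert-location support, so callers must handle the missing-metadata case (which is exactly what the new FusedMoE weight-loader fallback further down does). As the assertion message in expert_distribution.py suggests, a model opts in by providing get_model_config_for_expert_location. A hedged sketch of what that hook can look like; the hf_config attribute names used here are illustrative assumptions:

from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelConfigForExpertLocation:
    # Field names follow the dataclass shown in the diff above.
    num_layers: int
    num_logical_experts: int
    num_groups: Optional[int] = None


class MyMoEForCausalLM:
    # Hypothetical model class; only the hook name comes from the diff's
    # assertion message, the config attributes below are assumptions.
    @classmethod
    def get_model_config_for_expert_location(cls, hf_config):
        return ModelConfigForExpertLocation(
            num_layers=hf_config.num_hidden_layers,
            num_logical_experts=hf_config.n_routed_experts,
        )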
python/sglang/srt/eplb/expert_location_dispatch.py

@@ -36,6 +36,7 @@ class ExpertLocationDispatchInfo:
     def init_new(cls, layer_id: int):
         ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"]
         expert_location_metadata = get_global_expert_location_metadata()
+        assert expert_location_metadata is not None
         if ep_dispatch_algorithm is None:
             return None
python/sglang/srt/eplb/expert_location_updater.py

@@ -50,6 +50,8 @@ class ExpertLocationUpdater:
         torch.cuda.empty_cache()
         old_expert_location_metadata = get_global_expert_location_metadata()
+        assert old_expert_location_metadata is not None
+
         _update_expert_weights(
             routed_experts_weights_of_layer=routed_experts_weights_of_layer,
             old_expert_location_metadata=old_expert_location_metadata,
python/sglang/srt/layers/moe/ep_moe/layer.py

@@ -183,6 +183,7 @@ class EPMoE(FusedMoE):
         hidden_size: int,
         intermediate_size: int,
         layer_id: int,
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -196,6 +197,7 @@ class EPMoE(FusedMoE):
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             top_k=top_k,
+            num_fused_shared_experts=num_fused_shared_experts,
             layer_id=layer_id,
             params_dtype=params_dtype,
             quant_config=quant_config,
@@ -728,10 +730,19 @@ class EPMoE(FusedMoE):
         shard_id: str,
         expert_id: int,
     ) -> None:
-        physical_expert_ids = (
-            get_global_expert_location_metadata().logical_to_all_physical(
-                self.layer_id, expert_id
+        global_expert_location_metadata = get_global_expert_location_metadata()
+        if global_expert_location_metadata is None:
+            self._weight_loader_impl(
+                param=param,
+                loaded_weight=loaded_weight,
+                weight_name=weight_name,
+                shard_id=shard_id,
+                expert_id=expert_id,
             )
+            return
+
+        physical_expert_ids = global_expert_location_metadata.logical_to_all_physical(
+            self.layer_id, expert_id
         )
         for physical_expert_id in physical_expert_ids:
             self._weight_loader_physical(
@@ -778,6 +789,7 @@ class DeepEPMoE(EPMoE):
         hidden_size: int,
         intermediate_size: int,
         layer_id: int,
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -792,6 +804,7 @@ class DeepEPMoE(EPMoE):
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             layer_id=layer_id,
+            num_fused_shared_experts=num_fused_shared_experts,
             params_dtype=params_dtype,
             quant_config=quant_config,
             tp_size=tp_size,
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -11,6 +11,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -62,8 +63,9 @@ class FusedMoE(torch.nn.Module):
         num_experts: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         top_k: Optional[int] = None,
-        layer_id: Optional[int] = None,
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
@@ -84,6 +86,7 @@ class FusedMoE(torch.nn.Module):
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
+        self.layer_id = layer_id
         self.top_k = top_k
         self.hidden_size = hidden_size
         self.tp_size = (
@@ -91,6 +94,7 @@ class FusedMoE(torch.nn.Module):
         )
         self.tp_rank = get_tensor_model_parallel_rank()
         self.num_experts = num_experts
+        self.num_fused_shared_experts = num_fused_shared_experts
         self.expert_map = None
         if enable_flashinfer_cutlass_moe and quant_config is None:
@@ -375,6 +379,45 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         expert_id: int,
     ) -> None:
+        global_expert_location_metadata = get_global_expert_location_metadata()
+        if global_expert_location_metadata is None:
+            self._weight_loader_impl(
+                param=param,
+                loaded_weight=loaded_weight,
+                weight_name=weight_name,
+                shard_id=shard_id,
+                expert_id=expert_id,
+            )
+            return
+
+        if expert_id >= self.num_experts - self.num_fused_shared_experts:
+            # This is a shared expert.
+            physical_expert_ids = [expert_id]
+        else:
+            physical_expert_ids = (
+                global_expert_location_metadata.logical_to_all_physical(
+                    self.layer_id, expert_id
+                )
+            )
+
+        for physical_expert_id in physical_expert_ids:
+            self._weight_loader_physical(
+                param=param,
+                loaded_weight=loaded_weight,
+                weight_name=weight_name,
+                shard_id=shard_id,
+                expert_id=physical_expert_id,
+            )
+
+    def _weight_loader_physical(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+        expert_id: int,
+    ) -> None:
         expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
         if expert_id == -1:
             return
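Restating the control flow that the new FusedMoE weight loader follows (a standalone paraphrase of the diff above, with stand-in callables for the private _weight_loader_impl / _weight_loader_physical methods): if no global expert-location metadata is registered, loading proceeds by logical expert id as before; fused shared experts keep their slot unchanged; all other routed experts are fanned out to every physical replica reported by logical_to_all_physical.

from typing import Callable, List


def route_expert_weight_load(
    expert_id: int,
    num_experts: int,
    num_fused_shared_experts: int,
    layer_id: int,
    metadata,                              # ExpertLocationMetadata-like object, or None
    load_logical: Callable[[int], None],   # stand-in for FusedMoE._weight_loader_impl
    load_physical: Callable[[int], None],  # stand-in for FusedMoE._weight_loader_physical
) -> None:
    if metadata is None:
        # EPLB not in use: fall back to the plain per-logical-expert loader.
        load_logical(expert_id)
        return
    if expert_id >= num_experts - num_fused_shared_experts:
        # Fused shared experts are not remapped by EPLB.
        physical_expert_ids: List[int] = [expert_id]
    else:
        physical_expert_ids = metadata.logical_to_all_physical(layer_id, expert_id)
    for physical_expert_id in physical_expert_ids:
        load_physical(physical_expert_id)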
python/sglang/srt/models/deepseek_v2.py

@@ -325,6 +325,7 @@ class DeepseekV2MoE(nn.Module):
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
+            num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
@@ -2112,6 +2113,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
+            self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
                 f"{disable_reason} Shared experts fusion optimization is disabled.",
python/sglang/srt/models/glm4_moe.py

@@ -434,6 +434,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
+            num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
@@ -740,10 +741,11 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
             global_server_args_dict["enable_deepep_moe"]
             or global_server_args_dict["enable_ep_moe"]
         ):
-            disable_reason = "Deepseek GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+            disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
+            self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
                 f"{disable_reason} Shared experts fusion optimization is disabled.",
python/sglang/srt/models/granitemoe.py

@@ -43,6 +43,7 @@ class GraniteMoeMoE(nn.Module):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -71,6 +72,7 @@ class GraniteMoeMoE(nn.Module):
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
+            layer_id=layer_id,
             params_dtype=params_dtype,
             reduce_results=True,
             quant_config=quant_config,
@@ -203,6 +205,7 @@ class GraniteMoeDecoderLayer(nn.Module):
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
+            layer_id=layer_id,
             quant_config=quant_config,
             prefix=f"{prefix}.block_sparse_moe",
         )
python/sglang/srt/models/grok.py

@@ -78,6 +78,7 @@ class Grok1MoE(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        layer_id: int,
         num_experts: int,
         top_k: int,
         hidden_size: int,
@@ -128,6 +129,7 @@ class Grok1MoE(nn.Module):
         self.experts = MoEImpl(
             num_experts=num_experts,
             top_k=top_k,
+            layer_id=layer_id,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
@@ -331,6 +333,7 @@ class Grok1DecoderLayer(nn.Module):
         )
         self.block_sparse_moe = Grok1MoE(
             config=config,
+            layer_id=layer_id,
             num_experts=config.num_local_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
python/sglang/srt/models/hunyuan.py

@@ -163,6 +163,7 @@ class HunYuanSparseMoeBlock(nn.Module):
             hidden_size=config.hidden_size,
             intermediate_size=intermediate_size,
             reduce_results=False,
+            layer_id=layer_id,
             quant_config=quant_config,
         )
python/sglang/srt/models/llama4.py

@@ -87,6 +87,7 @@ class Llama4MoE(nn.Module):
     def __init__(
         self,
         config: Llama4TextConfig,
+        layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -114,6 +115,7 @@ class Llama4MoE(nn.Module):
             num_experts=config.num_local_experts,
             hidden_size=config.hidden_size,
             intermediate_size=intermediate_size_moe,
+            layer_id=layer_id,
             reduce_results=False,
             quant_config=quant_config,
             apply_router_weight_on_input=True,
@@ -373,6 +375,7 @@ class Llama4DecoderLayer(nn.Module):
         if is_moe_layer:
             self.feed_forward = Llama4MoE(
                 config=config,
+                layer_id=layer_id,
                 quant_config=quant_config,
                 prefix=add_prefix("feed_forward", prefix),
             )
python/sglang/srt/models/mixtral.py

@@ -69,6 +69,7 @@ class MixtralMoE(nn.Module):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -97,6 +98,7 @@ class MixtralMoE(nn.Module):
         self.experts = MoEImpl(
             num_experts=num_experts,
             top_k=top_k,
+            layer_id=layer_id,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
@@ -226,6 +228,7 @@ class MixtralDecoderLayer(nn.Module):
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
+            layer_id=layer_id,
             quant_config=quant_config,
             prefix=add_prefix("block_sparse_moe", prefix),
         )
python/sglang/srt/models/olmoe.py

@@ -63,6 +63,7 @@ class OlmoeMoE(nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
+        layer_id: int = 0,
         prefix: str = "",
     ):
         super().__init__()
@@ -89,6 +90,7 @@ class OlmoeMoE(nn.Module):
             reduce_results=True,
             quant_config=quant_config,
             tp_size=tp_size,
+            layer_id=layer_id,
             prefix=add_prefix("experts", prefix),
         )
@@ -224,6 +226,7 @@ class OlmoeDecoderLayer(nn.Module):
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
+            layer_id=layer_id,
             quant_config=quant_config,
             prefix=add_prefix("mlp", prefix),
         )
python/sglang/srt/models/phimoe.py

@@ -210,6 +210,7 @@ class PhiMoE(nn.Module):
         self.experts = FusedMoE(
             num_experts=num_experts,
             top_k=top_k,
+            layer_id=layer_id,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             reduce_results=True,
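The per-model diffs above all follow one pattern: each decoder layer threads its layer_id down into its MoE block so the FusedMoE/EPMoE weight loader can look up that layer's expert placement. A schematic of the wiring, with toy class names standing in for the sglang modules:

import torch.nn as nn


class ToyMoEBlock(nn.Module):
    # Stand-in for a model's sparse-MoE block; in sglang this would construct
    # FusedMoE/EPMoE and forward layer_id to it.
    def __init__(self, num_experts: int, layer_id: int):
        super().__init__()
        self.layer_id = layer_id  # needed for logical -> physical lookup at load time
        self.num_experts = num_experts


class ToyDecoderLayer(nn.Module):
    def __init__(self, config, layer_id: int):
        super().__init__()
        self.mlp = ToyMoEBlock(num_experts=config.num_local_experts, layer_id=layer_id)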