sglang · Commits · 82cfcd3b

Commit 82cfcd3b (unverified), authored Oct 30, 2025 by Xinyuan Tong, committed via GitHub on Oct 31, 2025
[Refactor] tuning_fused_moe for MLLM and small refactor (#11224)
Co-authored-by: Cursor Agent <cursoragent@cursor.com>

Parent: 6c1a3f0c
Changes: 1 changed file with 60 additions and 56 deletions

benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py (+60 / -56)
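The main refactor in the diff below replaces the repeated "2 * intermediate_size // (args.tp_size // args.ep_size)" expressions with a single _calculate_shard_intermediate_size helper: in expert-parallel mode each rank holds whole experts, so the intermediate size is left unchanged, while in tensor-parallel mode the fused gate/up weight (hence the factor 2) is split across TP ranks. A minimal standalone sketch of that rule, rewritten here to take the parallel sizes as explicit parameters purely for illustration:

def shard_intermediate_size(intermediate_size: int, tp_size: int, ep_size: int) -> int:
    # EP mode: a rank owns whole experts, so each expert keeps its full width.
    if ep_size > 1:
        return intermediate_size
    # TP mode: gate and up projections are fused (factor 2) and sharded over TP ranks.
    return 2 * intermediate_size // tp_size

# Example with a 768-wide expert MLP (numbers are illustrative only).
assert shard_intermediate_size(768, tp_size=4, ep_size=1) == 384  # TP-4 sharding
assert shard_intermediate_size(768, tp_size=1, ep_size=8) == 768  # EP keeps the full width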
@@ -419,54 +419,73 @@ def get_filename(
 def main(args: argparse.Namespace):
     print(args)
 
+    def _calculate_shard_intermediate_size(intermediate_size: int) -> int:
+        # In EP mode, use original intermediate_size; otherwise apply TP sharding
+        return (
+            intermediate_size
+            if args.ep_size > 1
+            else 2 * intermediate_size // args.tp_size
+        )
+
+    # Check EP mode constraint: tp_size must be 1 when ep_size > 1
+    if args.ep_size > 1 and args.tp_size != 1:
+        raise ValueError(
+            f"When using Expert Parallelism (ep_size={args.ep_size}), "
+            f"tp_size must be set to 1, but got tp_size={args.tp_size}. "
+            f"Please set --tp-size 1 when using --ep-size > 1."
+        )
+
     config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
-    if config.architectures[0] == "DbrxForCausalLM":
-        E = config.ffn_config.moe_num_experts // args.ep_size
+
+    # Determine block shape for quantization
+    block_shape = None
+    if (
+        hasattr(config, "quantization_config")
+        and "weight_block_size" in config.quantization_config
+    ):
+        block_shape = config.quantization_config["weight_block_size"]
+        assert len(block_shape) == 2
+
+    architecture = config.architectures[0]
+    # replace config with text_config for encoder-decoder models after getting block_shape and architecture
+    if hasattr(config, "text_config"):
+        config = config.get_text_config()
+
+    if architecture == "DbrxForCausalLM":
+        E = config.ffn_config.moe_num_experts
         topk = config.ffn_config.moe_top_k
         intermediate_size = config.ffn_config.ffn_hidden_size
-        shard_intermediate_size = (
-            2 * intermediate_size // (args.tp_size // args.ep_size)
-        )
-    elif config.architectures[0] == "JambaForCausalLM":
-        E = config.num_experts // args.ep_size
+        shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
+    elif architecture == "JambaForCausalLM":
+        E = config.num_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
-        shard_intermediate_size = (
-            2 * intermediate_size // (args.tp_size // args.ep_size)
-        )
-    elif config.architectures[0] in [
+        shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
+    elif architecture in [
         "Qwen2MoeForCausalLM",
         "Qwen3MoeForCausalLM",
         "Qwen3NextForCausalLM",
+        "Qwen3VLMoeForConditionalGeneration",
     ]:
         E = config.num_experts // args.ep_size
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = (
-            2 * intermediate_size // (args.tp_size // args.ep_size)
-        )
-    elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
-        E = (config.n_routed_experts // args.ep_size) + (
-            0
-            if args.disable_shared_experts_fusion
-            or config.architectures[0] not in ["DeepseekV3ForCausalLM"]
-            else 1
-        )
+        shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
+    elif architecture in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
+        E = (
+            config.n_routed_experts + (0 if args.disable_shared_experts_fusion else 1)
+            if architecture == "DeepseekV3ForCausalLM"
+            else config.n_routed_experts
+        )
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = (
-            2 * intermediate_size // (args.tp_size // args.ep_size)
-        )
-    elif config.architectures[0] == "Llama4ForConditionalGeneration":
-        E = config.text_config.num_local_experts // args.ep_size + (
-            0 if args.disable_shared_experts_fusion else 1
-        )
-        topk = config.text_config.num_experts_per_tok
-        intermediate_size = config.text_config.intermediate_size
-        shard_intermediate_size = (
-            2 * intermediate_size // (args.tp_size // args.ep_size)
-        )
-    elif config.architectures[0] in [
+        shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
+    elif architecture == "Llama4ForConditionalGeneration":
+        E = config.num_local_experts + (0 if args.disable_shared_experts_fusion else 1)
+        topk = config.num_experts_per_tok
+        intermediate_size = config.intermediate_size
+        shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
+    elif architecture in [
         "Grok1ForCausalLM",
         "Grok1ImgGen",
         "Grok1AForCausalLM",
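The MLLM part of the hunk above is the text_config normalization: the architecture string and the quantization block shape are read from the root config first, and the config is then swapped for its text sub-config, so the per-architecture branches read MoE fields the same way for multimodal and text-only checkpoints. A minimal sketch of that pattern, assuming a transformers version that provides PretrainedConfig.get_text_config(); the checkpoint name is only illustrative:

from transformers import AutoConfig

# Illustrative multimodal MoE checkpoint; any Qwen3-VL MoE config has the same layout.
config = AutoConfig.from_pretrained(
    "Qwen/Qwen3-VL-30B-A3B-Instruct", trust_remote_code=True
)
architecture = config.architectures[0]  # e.g. "Qwen3VLMoeForConditionalGeneration"
if hasattr(config, "text_config"):
    config = config.get_text_config()  # descend into the language-model sub-config

# MoE hyperparameters now resolve exactly as they do for text-only MoE models.
print(architecture, config.num_experts, config.num_experts_per_tok, config.moe_intermediate_size)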
@@ -474,10 +493,8 @@ def main(args: argparse.Namespace):
         E = config.num_local_experts // args.ep_size
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = (
-            2 * intermediate_size // (args.tp_size // args.ep_size)
-        )
-    elif config.architectures[0] in [
+        shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
+    elif architecture in [
         "BailingMoEForCausalLM",
         "BailingMoeForCausalLM",
         "BailingMoeV2ForCausalLM",
@@ -485,38 +502,25 @@ def main(args: argparse.Namespace):
         E = config.num_experts // args.ep_size
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = (
-            2 * intermediate_size // (args.tp_size // args.ep_size)
-        )
-    elif config.architectures[0] in ["Glm4MoeForCausalLM"]:
-        E = config.n_routed_experts // args.ep_size
+        shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
+    elif architecture in ["Glm4MoeForCausalLM"]:
+        E = config.n_routed_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = (
-            2 * intermediate_size // (args.tp_size // args.ep_size)
-        )
+        shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
     else:
         # Default: Mixtral
         E = config.num_local_experts // args.ep_size
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
-        shard_intermediate_size = (
-            2 * intermediate_size // (args.tp_size // args.ep_size)
-        )
+        shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
 
-    hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size
+    hidden_size = config.hidden_size
     dtype = config.torch_dtype
     use_fp8_w8a8 = args.dtype == "fp8_w8a8"
     use_int8_w8a8 = args.dtype == "int8_w8a8"
     use_int8_w8a16 = args.dtype == "int8_w8a16"
     per_channel_quant = args.per_channel_quant
-    block_shape = None
-    if (
-        hasattr(config, "quantization_config")
-        and "weight_block_size" in config.quantization_config
-    ):
-        block_shape = config.quantization_config["weight_block_size"]
-        assert len(block_shape) == 2
 
     if args.batch_size is None:
         batch_sizes = [
...
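A related ordering detail: the weight_block_size lookup is added early in the first hunk and deleted from its old position in the last hunk, so the quantization block shape (and the architecture string) is captured from the root config before the config is swapped for its text sub-config, as the diff's own comment calls out. A small sketch of that lookup, using a plain namespace as a stand-in for a real AutoConfig object and an illustrative 128x128 block size:

from types import SimpleNamespace

# Stand-in for a root multimodal config: quantization info at the top level,
# MoE hyperparameters nested under text_config.
config = SimpleNamespace(
    quantization_config={"weight_block_size": [128, 128]},
    text_config=SimpleNamespace(moe_intermediate_size=768),
)

block_shape = None
if hasattr(config, "quantization_config") and "weight_block_size" in config.quantization_config:
    block_shape = config.quantization_config["weight_block_size"]
    assert len(block_shape) == 2

print(block_shape)  # [128, 128], captured before descending into config.text_config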