sglang · Commit 1ed1abfd (unverified), authored by Chen1022 on Oct 30, 2025; committed by GitHub on Oct 30, 2025

feat: add EP support in tuning (#12012)

Parent: ecb9fa14
Showing 1 changed file with 41 additions and 21 deletions.

benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py (+41, −21)
@@ -421,69 +421,88 @@ def main(args: argparse.Namespace):
     config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
     if config.architectures[0] == "DbrxForCausalLM":
-        E = config.ffn_config.moe_num_experts
+        E = config.ffn_config.moe_num_experts // args.ep_size
         topk = config.ffn_config.moe_top_k
         intermediate_size = config.ffn_config.ffn_hidden_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        shard_intermediate_size = (
+            2 * intermediate_size // (args.tp_size // args.ep_size)
+        )
     elif config.architectures[0] == "JambaForCausalLM":
-        E = config.num_experts
+        E = config.num_experts // args.ep_size
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        shard_intermediate_size = (
+            2 * intermediate_size // (args.tp_size // args.ep_size)
+        )
     elif config.architectures[0] in [
         "Qwen2MoeForCausalLM",
         "Qwen3MoeForCausalLM",
         "Qwen3NextForCausalLM",
     ]:
-        E = config.num_experts
+        E = config.num_experts // args.ep_size
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        shard_intermediate_size = (
+            2 * intermediate_size // (args.tp_size // args.ep_size)
+        )
     elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
-        E = (
-            config.n_routed_experts + (0 if args.disable_shared_experts_fusion else 1)
-            if config.architectures[0] in ["DeepseekV3ForCausalLM"]
-            else config.n_routed_experts
+        E = (config.n_routed_experts // args.ep_size) + (
+            0
+            if args.disable_shared_experts_fusion
+            or config.architectures[0] not in ["DeepseekV3ForCausalLM"]
+            else 1
         )
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        shard_intermediate_size = (
+            2 * intermediate_size // (args.tp_size // args.ep_size)
+        )
     elif config.architectures[0] == "Llama4ForConditionalGeneration":
-        E = config.text_config.num_local_experts + (
+        E = config.text_config.num_local_experts // args.ep_size + (
             0 if args.disable_shared_experts_fusion else 1
         )
         topk = config.text_config.num_experts_per_tok
         intermediate_size = config.text_config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        shard_intermediate_size = (
+            2 * intermediate_size // (args.tp_size // args.ep_size)
+        )
     elif config.architectures[0] in [
         "Grok1ForCausalLM",
         "Grok1ImgGen",
         "Grok1AForCausalLM",
     ]:
-        E = config.num_local_experts
+        E = config.num_local_experts // args.ep_size
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        shard_intermediate_size = (
+            2 * intermediate_size // (args.tp_size // args.ep_size)
+        )
     elif config.architectures[0] in [
         "BailingMoEForCausalLM",
         "BailingMoeForCausalLM",
         "BailingMoeV2ForCausalLM",
     ]:
-        E = config.num_experts
+        E = config.num_experts // args.ep_size
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        shard_intermediate_size = (
+            2 * intermediate_size // (args.tp_size // args.ep_size)
+        )
     elif config.architectures[0] in ["Glm4MoeForCausalLM"]:
-        E = config.n_routed_experts
+        E = config.n_routed_experts // args.ep_size
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        shard_intermediate_size = (
+            2 * intermediate_size // (args.tp_size // args.ep_size)
+        )
     else:
         # Default: Mixtral
-        E = config.num_local_experts
+        E = config.num_local_experts // args.ep_size
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        shard_intermediate_size = (
+            2 * intermediate_size // (args.tp_size // args.ep_size)
+        )
     hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size
     dtype = config.torch_dtype
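For context on the arithmetic above (this note is not part of the commit): with expert parallelism, each EP rank holds only num_experts // ep_size experts, and the tensor-parallel degree left to split each expert's intermediate dimension shrinks to tp_size // ep_size. A minimal sketch of that math, using assumed Mixtral-8x7B-style config values (8 experts, top-2 routing, intermediate size 14336; illustrative numbers, not read from this diff):

# Illustrative sketch of the per-rank shape math from the hunk above.
# Config values are assumed Mixtral-8x7B-style numbers, not taken from the diff.
num_local_experts = 8
num_experts_per_tok = 2
intermediate_size = 14336

tp_size, ep_size = 4, 2  # hypothetical layout; ep_size should divide tp_size
assert tp_size % ep_size == 0, "EP groups must evenly partition the TP ranks"

E = num_local_experts // ep_size  # experts held per EP rank -> 4
topk = num_experts_per_tok        # routing top-k is unchanged by parallelism -> 2
# The factor of 2 accounts for the fused gate/up projection weight; the
# remaining TP degree inside one EP group is tp_size // ep_size.
shard_intermediate_size = 2 * intermediate_size // (tp_size // ep_size)  # -> 14336

print(E, topk, shard_intermediate_size)  # 4 2 14336

With ep_size = 1 (the default added in the second hunk below) both divisions are no-ops, so existing tuning runs reproduce the old shapes.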
@@ -626,6 +645,7 @@ if __name__ == "__main__":
         "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
     )
     parser.add_argument("--tp-size", "--tp", type=int, default=2)
+    parser.add_argument("--ep-size", "--ep", type=int, default=1)
     parser.add_argument(
         "--dtype",
         type=str,
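A standalone sketch of how the new flag parses, mirroring only the two add_argument calls visible in this hunk (the real script defines many more options, which are omitted here):

import argparse

parser = argparse.ArgumentParser()
# The two definitions from the diff: --tp/--tp-size and the new --ep/--ep-size.
parser.add_argument("--tp-size", "--tp", type=int, default=2)
parser.add_argument("--ep-size", "--ep", type=int, default=1)

# Hypothetical invocation: tune for TP=8 split into 2 EP groups of 4 ranks each.
args = parser.parse_args(["--tp", "8", "--ep", "2"])
print(args.tp_size, args.ep_size)  # 8 2

argparse derives the attribute name from the first long option string, so the values land on args.tp_size and args.ep_size, exactly the names the tuning code reads.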