Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
77830a26
Unverified
Commit
77830a26
authored
Sep 25, 2025
by
lukec
Committed by
GitHub
Sep 25, 2025
Browse files
Add fuse_moe per-channel tune (#10915)
parent
fce17048
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
1 deletion
+22
-1
benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
...hmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
+22
-1
No files found.
benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
View file @
77830a26
...
@@ -47,6 +47,7 @@ def benchmark_config(
...
@@ -47,6 +47,7 @@ def benchmark_config(
use_fp8_w8a8
:
bool
,
use_fp8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int8_w8a16
:
bool
,
per_channel_quant
:
bool
,
block_shape
:
List
[
int
]
=
None
,
block_shape
:
List
[
int
]
=
None
,
num_iters
:
int
=
100
,
num_iters
:
int
=
100
,
)
->
float
:
)
->
float
:
...
@@ -152,6 +153,7 @@ def benchmark_config(
...
@@ -152,6 +153,7 @@ def benchmark_config(
w2_scale
=
w2_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
a2_scale
=
a2_scale
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
,
block_shape
=
block_shape
,
)
)
...
@@ -261,6 +263,7 @@ class BenchmarkWorker:
...
@@ -261,6 +263,7 @@ class BenchmarkWorker:
use_fp8_w8a8
:
bool
,
use_fp8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int8_w8a16
:
bool
,
per_channel_quant
:
bool
,
block_shape
:
List
[
int
],
block_shape
:
List
[
int
],
)
->
Tuple
[
Dict
[
str
,
int
],
float
]:
)
->
Tuple
[
Dict
[
str
,
int
],
float
]:
torch
.
cuda
.
manual_seed_all
(
0
)
torch
.
cuda
.
manual_seed_all
(
0
)
...
@@ -272,7 +275,12 @@ class BenchmarkWorker:
...
@@ -272,7 +275,12 @@ class BenchmarkWorker:
block_n
=
block_shape
[
0
]
if
block_shape
else
0
block_n
=
block_shape
[
0
]
if
block_shape
else
0
block_k
=
block_shape
[
1
]
if
block_shape
else
0
block_k
=
block_shape
[
1
]
if
block_shape
else
0
op_config
=
get_moe_configs
(
op_config
=
get_moe_configs
(
num_experts
,
shard_intermediate_size
//
2
,
dtype_str
,
block_n
,
block_k
num_experts
,
shard_intermediate_size
//
2
,
dtype_str
,
block_n
,
block_k
,
per_channel_quant
,
)
)
if
op_config
is
None
:
if
op_config
is
None
:
config
=
get_default_config
(
config
=
get_default_config
(
...
@@ -299,6 +307,7 @@ class BenchmarkWorker:
...
@@ -299,6 +307,7 @@ class BenchmarkWorker:
use_fp8_w8a8
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int8_w8a16
,
per_channel_quant
,
block_shape
,
block_shape
,
)
)
return
config
,
kernel_time
return
config
,
kernel_time
...
@@ -314,6 +323,7 @@ class BenchmarkWorker:
...
@@ -314,6 +323,7 @@ class BenchmarkWorker:
use_fp8_w8a8
:
bool
,
use_fp8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int8_w8a16
:
bool
,
per_channel_quant
:
bool
,
block_shape
:
List
[
int
],
block_shape
:
List
[
int
],
search_space
:
List
[
Dict
[
str
,
int
]],
search_space
:
List
[
Dict
[
str
,
int
]],
)
->
Dict
[
str
,
int
]:
)
->
Dict
[
str
,
int
]:
...
@@ -333,6 +343,7 @@ class BenchmarkWorker:
...
@@ -333,6 +343,7 @@ class BenchmarkWorker:
use_fp8_w8a8
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int8_w8a16
,
per_channel_quant
,
block_shape
,
block_shape
,
num_iters
=
10
,
num_iters
=
10
,
)
)
...
@@ -373,6 +384,7 @@ def save_configs(
...
@@ -373,6 +384,7 @@ def save_configs(
use_fp8_w8a8
:
bool
,
use_fp8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int8_w8a16
:
bool
,
per_channel_quant
:
bool
,
block_shape
:
List
[
int
],
block_shape
:
List
[
int
],
)
->
None
:
)
->
None
:
dtype_str
=
get_config_dtype_str
(
dtype_str
=
get_config_dtype_str
(
...
@@ -389,6 +401,7 @@ def save_configs(
...
@@ -389,6 +401,7 @@ def save_configs(
shard_intermediate_size
//
2
,
shard_intermediate_size
//
2
,
dtype_str
,
dtype_str
,
block_shape
,
block_shape
,
per_channel_quant
,
)
)
print
(
f
"Writing best config to
{
filename
}
..."
)
print
(
f
"Writing best config to
{
filename
}
..."
)
...
@@ -471,6 +484,7 @@ def main(args: argparse.Namespace):
...
@@ -471,6 +484,7 @@ def main(args: argparse.Namespace):
use_fp8_w8a8
=
args
.
dtype
==
"fp8_w8a8"
use_fp8_w8a8
=
args
.
dtype
==
"fp8_w8a8"
use_int8_w8a8
=
args
.
dtype
==
"int8_w8a8"
use_int8_w8a8
=
args
.
dtype
==
"int8_w8a8"
use_int8_w8a16
=
args
.
dtype
==
"int8_w8a16"
use_int8_w8a16
=
args
.
dtype
==
"int8_w8a16"
per_channel_quant
=
args
.
per_channel_quant
block_shape
=
None
block_shape
=
None
if
(
if
(
hasattr
(
config
,
"quantization_config"
)
hasattr
(
config
,
"quantization_config"
)
...
@@ -543,6 +557,7 @@ def main(args: argparse.Namespace):
...
@@ -543,6 +557,7 @@ def main(args: argparse.Namespace):
use_fp8_w8a8
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int8_w8a16
,
per_channel_quant
,
block_shape
,
block_shape
,
search_space
,
search_space
,
)
)
...
@@ -562,6 +577,7 @@ def main(args: argparse.Namespace):
...
@@ -562,6 +577,7 @@ def main(args: argparse.Namespace):
use_fp8_w8a8
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int8_w8a16
,
per_channel_quant
,
block_shape
,
block_shape
,
)
)
end
=
time
.
perf_counter
()
end
=
time
.
perf_counter
()
...
@@ -580,6 +596,7 @@ def main(args: argparse.Namespace):
...
@@ -580,6 +596,7 @@ def main(args: argparse.Namespace):
use_fp8_w8a8
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int8_w8a16
,
per_channel_quant
,
block_shape
,
block_shape
,
)
)
for
batch_size
in
batch_sizes
for
batch_size
in
batch_sizes
...
@@ -603,6 +620,10 @@ if __name__ == "__main__":
...
@@ -603,6 +620,10 @@ if __name__ == "__main__":
choices
=
[
"auto"
,
"fp8_w8a8"
,
"int8_w8a16"
,
"int8_w8a8"
],
choices
=
[
"auto"
,
"fp8_w8a8"
,
"int8_w8a16"
,
"int8_w8a8"
],
default
=
"auto"
,
default
=
"auto"
,
)
)
parser
.
add_argument
(
"--per-channel-quant"
,
action
=
"store_true"
,
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
required
=
False
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
required
=
False
)
parser
.
add_argument
(
"--tune"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--tune"
,
action
=
"store_true"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment