fix: enable multi-GPU Triton fused MoE tuning (#6295)

Commit a3b810eb (Unverified), parent 94959237
Authored Aug 19, 2025 by mpashkovskiy; committed via GitHub on Aug 19, 2025
Showing 1 changed file with 43 additions and 37 deletions.

benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py (+43, -37) @ a3b810eb
@@ -3,6 +3,7 @@ import argparse
 import json
 import time
 from datetime import datetime
+from contextlib import nullcontext
 from typing import Any, Dict, List, Tuple, TypedDict
 
 import ray
@@ -21,7 +22,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
 )
 from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
 from sglang.srt.layers.moe.topk import TopKConfig, select_experts
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_hip, is_rocm
 
 _is_hip = is_hip()
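The two new imports above set up a conditional device context: contextlib.nullcontext is a built-in no-op context manager, so the same with-statement can switch GPUs on ROCm and do nothing elsewhere. A minimal sketch of the idiom (the run_on helper is hypothetical, not part of the commit):

```python
# Minimal sketch of the nullcontext idiom; run_on() is a hypothetical helper.
from contextlib import nullcontext
from typing import Optional

import torch


def run_on(device_id: Optional[int]) -> torch.Tensor:
    # Switch to the requested GPU only when one is given; nullcontext()
    # enters and exits without any side effect.
    ctx = torch.cuda.device(device_id) if device_id is not None else nullcontext()
    with ctx:
        # Allocations made with device="cuda" follow the current device.
        return torch.zeros(1, device="cuda")
```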
@@ -245,6 +246,9 @@ class BenchmarkWorker:
         torch.set_default_device("cuda")
         torch.cuda.manual_seed_all(0)
         self.seed = seed
+        # Get the device ID to allocate tensors and kernels
+        # on the respective GPU.
+        self.device_id = int(ray.get_gpu_ids()[0])
 
     def benchmark(
         self,
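The new device_id attribute records which GPU Ray assigned to this worker. A minimal sketch of how an actor discovers its GPU (the @ray.remote(num_gpus=1) declaration and the Probe class are illustrative assumptions, not the script's own definitions):

```python
# Minimal sketch; Probe is a stand-in actor, not the script's BenchmarkWorker.
import ray


@ray.remote(num_gpus=1)
class Probe:
    def __init__(self) -> None:
        # ray.get_gpu_ids() lists the GPU IDs reserved for this worker process,
        # so each actor can pin its work to its own device.
        self.device_id = int(ray.get_gpu_ids()[0])

    def gpu(self) -> int:
        return self.device_id
```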
@@ -283,19 +287,20 @@ class BenchmarkWorker:
             )
         else:
             config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
-        kernel_time = benchmark_config(
-            config,
-            num_tokens,
-            num_experts,
-            shard_intermediate_size,
-            hidden_size,
-            topk,
-            dtype,
-            use_fp8_w8a8,
-            use_int8_w8a8,
-            use_int8_w8a16,
-            block_shape,
-        )
+        with torch.cuda.device(self.device_id) if is_rocm() else nullcontext():
+            kernel_time = benchmark_config(
+                config,
+                num_tokens,
+                num_experts,
+                shard_intermediate_size,
+                hidden_size,
+                topk,
+                dtype,
+                use_fp8_w8a8,
+                use_int8_w8a8,
+                use_int8_w8a16,
+                block_shape,
+            )
         return config, kernel_time
 
     def tune(
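In the hunk above, the kernel benchmark is wrapped in torch.cuda.device(self.device_id) on ROCm and a no-op context otherwise, so launches land on the worker's assigned GPU. A minimal sketch of what torch.cuda.device does (assumes a machine with at least two visible GPUs):

```python
# Minimal sketch of torch.cuda.device; requires >= 2 visible GPUs to run.
import torch

if torch.cuda.device_count() >= 2:
    before = torch.cuda.current_device()
    with torch.cuda.device(1):
        # Inside the block, "cuda" resolves to device 1.
        x = torch.zeros(1, device="cuda")
        assert x.device.index == 1
    # The previously selected device is restored on exit.
    assert torch.cuda.current_device() == before
```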
@@ -314,29 +319,30 @@ class BenchmarkWorker:
     ) -> Dict[str, int]:
         best_config = None
         best_time = float("inf")
-        for config in tqdm(search_space):
-            try:
-                kernel_time = benchmark_config(
-                    config,
-                    num_tokens,
-                    num_experts,
-                    shard_intermediate_size,
-                    hidden_size,
-                    topk,
-                    dtype,
-                    use_fp8_w8a8,
-                    use_int8_w8a8,
-                    use_int8_w8a16,
-                    block_shape,
-                    num_iters=10,
-                )
-            except triton.runtime.autotuner.OutOfResources:
-                # Some configurations may be invalid and fail to compile.
-                continue
-            if kernel_time < best_time:
-                best_time = kernel_time
-                best_config = config
+        with torch.cuda.device(self.device_id) if is_rocm() else nullcontext():
+            for config in tqdm(search_space):
+                try:
+                    kernel_time = benchmark_config(
+                        config,
+                        num_tokens,
+                        num_experts,
+                        shard_intermediate_size,
+                        hidden_size,
+                        topk,
+                        dtype,
+                        use_fp8_w8a8,
+                        use_int8_w8a8,
+                        use_int8_w8a16,
+                        block_shape,
+                        num_iters=10,
+                    )
+                except triton.runtime.autotuner.OutOfResources:
+                    # Some configurations may be invalid and fail to compile.
+                    continue
+                if kernel_time < best_time:
+                    best_time = kernel_time
+                    best_config = config
 
         now = datetime.now()
         print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
         assert best_config is not None
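With each worker pinned to its Ray-assigned GPU, tuning different batch sizes can proceed on different GPUs in parallel, which is presumably the point of this fix. A rough sketch of that fan-out pattern (the Tuner actor, its tune() signature, and the batch sizes are illustrative assumptions, not the script's actual CLI or API):

```python
# Rough sketch of fanning tuning work out across GPUs; Tuner is hypothetical.
import ray


@ray.remote(num_gpus=1)
class Tuner:
    def tune(self, batch_size: int) -> float:
        # Placeholder for per-batch-size kernel tuning on this worker's GPU.
        return float(batch_size)


if __name__ == "__main__":
    ray.init()
    num_gpus = int(ray.available_resources().get("GPU", 0))
    assert num_gpus > 0, "no GPUs visible to Ray"
    workers = [Tuner.remote() for _ in range(num_gpus)]
    batch_sizes = [1, 8, 64, 512]
    # Round-robin the batch sizes over the available workers.
    futures = [workers[i % num_gpus].tune.remote(b) for i, b in enumerate(batch_sizes)]
    print(ray.get(futures))
```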