Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
968e1818
Unverified
Commit
968e1818
authored
Aug 18, 2025
by
Yuan Luo
Committed by
GitHub
Aug 18, 2025
Browse files
Fix triton_fused_moe unit test and benchmark (#9276)
Co-authored-by:
luoyuan.luo
<
luoyuan.luo@antgroup.com
>
parent
d08663ee
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
41 additions
and
8 deletions
+41
-8
benchmark/kernels/fused_moe_triton/benchmark_sglang_fused_moe_triton.py
...els/fused_moe_triton/benchmark_sglang_fused_moe_triton.py
+24
-7
test/srt/test_triton_fused_moe.py
test/srt/test_triton_fused_moe.py
+17
-1
No files found.
benchmark/kernels/fused_moe_triton/benchmark_sglang_fused_moe_triton.py
View file @
968e1818
...
@@ -17,6 +17,8 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
...
@@ -17,6 +17,8 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
from
sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe
import
(
from
sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe
import
(
triton_kernel_moe_forward
,
triton_kernel_moe_forward
,
)
)
from
sglang.srt.layers.moe.moe_runner
import
MoeRunnerConfig
from
sglang.srt.layers.moe.topk
import
TopK
,
TopKConfig
,
select_experts
def
get_model_config
(
model_name
:
str
,
tp_size
:
int
):
def
get_model_config
(
model_name
:
str
,
tp_size
:
int
):
...
@@ -80,13 +82,26 @@ def fused_moe_triton_api(
...
@@ -80,13 +82,26 @@ def fused_moe_triton_api(
input_gating
,
input_gating
,
topk
,
topk
,
):
):
topk_op
=
TopK
(
top_k
=
topk
,
renormalize
=
False
,
use_grouped_topk
=
False
,
)
topk_op
.
use_triton_kernels
=
True
triton_topk_output
=
topk_op
.
forward_cuda
(
hidden_states
=
x
,
router_logits
=
input_gating
,
)
moe_runner_config
=
MoeRunnerConfig
(
inplace
=
False
,
)
return
triton_kernel_moe_forward
(
return
triton_kernel_moe_forward
(
x
,
x
,
w1
,
w1
,
w2
,
w2
,
input_gating
,
triton_topk_output
,
topk
,
moe_runner_config
,
renormalize
=
False
,
)
)
...
@@ -103,14 +118,16 @@ def fused_moe_sglang_api(
...
@@ -103,14 +118,16 @@ def fused_moe_sglang_api(
a2_scale
=
None
,
a2_scale
=
None
,
block_shape
=
None
,
block_shape
=
None
,
):
):
topk_output
=
select_experts
(
hidden_states
=
x
,
router_logits
=
input_gating
,
topk_config
=
TopKConfig
(
top_k
=
topk
,
renormalize
=
False
),
)
return
fused_moe_sglang
(
return
fused_moe_sglang
(
x
,
x
,
w1
,
w1
,
w2
,
w2
,
input_gating
,
topk_output
,
topk
,
renormalize
=
False
,
inplace
=
True
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
w1_scale
=
w1_scale
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w2_scale
=
w2_scale
,
...
...
test/srt/test_triton_fused_moe.py
View file @
968e1818
...
@@ -8,6 +8,8 @@ from sglang.srt.layers.activation import SiluAndMul
...
@@ -8,6 +8,8 @@ from sglang.srt.layers.activation import SiluAndMul
from
sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe
import
(
from
sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe
import
(
triton_kernel_moe_forward
,
triton_kernel_moe_forward
,
)
)
from
sglang.srt.layers.moe.moe_runner
import
MoeRunnerConfig
from
sglang.srt.layers.moe.topk
import
TopK
from
sglang.test.test_utils
import
CustomTestCase
from
sglang.test.test_utils
import
CustomTestCase
...
@@ -92,8 +94,22 @@ class TestFusedMOE(CustomTestCase):
...
@@ -92,8 +94,22 @@ class TestFusedMOE(CustomTestCase):
w2_tri
=
w2_tri
.
transpose
(
-
2
,
-
1
).
contiguous
()
w2_tri
=
w2_tri
.
transpose
(
-
2
,
-
1
).
contiguous
()
score
=
self
.
create_random_cuda_tensor
((
m
,
e
),
dtype
)
score
=
self
.
create_random_cuda_tensor
((
m
,
e
),
dtype
)
topk_op
=
TopK
(
top_k
=
topk
,
renormalize
=
False
,
use_grouped_topk
=
False
,
)
topk_op
.
use_triton_kernels
=
True
triton_topk_output
=
topk_op
.
forward_cuda
(
hidden_states
=
a
,
router_logits
=
score
,
)
moe_runner_config
=
MoeRunnerConfig
(
inplace
=
False
,
)
triton_output
=
triton_kernel_moe_forward
(
triton_output
=
triton_kernel_moe_forward
(
a
,
w1_tri
,
w2_tri
,
score
,
topk
,
renormalize
=
False
a
,
w1_tri
,
w2_tri
,
triton_topk_output
,
moe_runner_config
)
)
torch_output
=
self
.
torch_naive_moe
(
a
,
w1
,
w2
,
score
,
topk
)
torch_output
=
self
.
torch_naive_moe
(
a
,
w1
,
w2
,
score
,
topk
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
rtol
=
rtol
,
atol
=
atol
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
rtol
=
rtol
,
atol
=
atol
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment