Unverified Commit 968e1818 authored by Yuan Luo, committed by GitHub

Fix triton_fused_moe unit test and benchmark (#9276)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
parent d08663ee
@@ -17,6 +17,8 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
 from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
     triton_kernel_moe_forward,
 )
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import TopK, TopKConfig, select_experts
 
 
 def get_model_config(model_name: str, tp_size: int):
@@ -80,13 +82,26 @@ def fused_moe_triton_api(
     input_gating,
     topk,
 ):
+    topk_op = TopK(
+        top_k=topk,
+        renormalize=False,
+        use_grouped_topk=False,
+    )
+    topk_op.use_triton_kernels = True
+    triton_topk_output = topk_op.forward_cuda(
+        hidden_states=x,
+        router_logits=input_gating,
+    )
+    moe_runner_config = MoeRunnerConfig(
+        inplace=False,
+    )
     return triton_kernel_moe_forward(
         x,
         w1,
         w2,
-        input_gating,
-        topk,
-        renormalize=False,
+        triton_topk_output,
+        moe_runner_config,
     )
@@ -103,14 +118,16 @@ def fused_moe_sglang_api(
     a2_scale=None,
     block_shape=None,
 ):
+    topk_output = select_experts(
+        hidden_states=x,
+        router_logits=input_gating,
+        topk_config=TopKConfig(top_k=topk, renormalize=False),
+    )
     return fused_moe_sglang(
         x,
         w1,
         w2,
-        input_gating,
-        topk,
-        renormalize=False,
-        inplace=True,
+        topk_output,
         use_fp8_w8a8=use_fp8_w8a8,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
...
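For readers tracking the API change, here is a minimal sketch of the new sglang fused_moe calling convention used in the benchmark above: routing is resolved up front with select_experts and passed as a single topk_output argument, replacing the old input_gating/topk arguments and the renormalize/inplace flags. The fused_moe entry point is passed in as a parameter here because its import line is truncated in the diff; that parameterization and any keyword arguments not visible above are assumptions, not the confirmed API.

from sglang.srt.layers.moe.topk import TopKConfig, select_experts

def run_sglang_fused_moe(fused_moe_fn, x, w1, w2, input_gating, topk):
    # Top-k routing is now computed before the kernel call.
    topk_output = select_experts(
        hidden_states=x,
        router_logits=input_gating,
        topk_config=TopKConfig(top_k=topk, renormalize=False),
    )
    # fused_moe_fn stands in for the entry point the benchmark imports as
    # fused_moe_sglang (assumed; its import is truncated in the diff above).
    # renormalize/inplace are no longer passed as loose flags.
    return fused_moe_fn(x, w1, w2, topk_output)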
@@ -8,6 +8,8 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
     triton_kernel_moe_forward,
 )
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import TopK
 from sglang.test.test_utils import CustomTestCase
@@ -92,8 +94,22 @@ class TestFusedMOE(CustomTestCase):
         w2_tri = w2_tri.transpose(-2, -1).contiguous()
         score = self.create_random_cuda_tensor((m, e), dtype)
 
+        topk_op = TopK(
+            top_k=topk,
+            renormalize=False,
+            use_grouped_topk=False,
+        )
+        topk_op.use_triton_kernels = True
+        triton_topk_output = topk_op.forward_cuda(
+            hidden_states=a,
+            router_logits=score,
+        )
+        moe_runner_config = MoeRunnerConfig(
+            inplace=False,
+        )
         triton_output = triton_kernel_moe_forward(
-            a, w1_tri, w2_tri, score, topk, renormalize=False
+            a, w1_tri, w2_tri, triton_topk_output, moe_runner_config
         )
         torch_output = self.torch_naive_moe(a, w1, w2, score, topk)
         torch.testing.assert_close(triton_output, torch_output, rtol=rtol, atol=atol)
...
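Similarly, a sketch of the Triton-kernels path exercised by both the benchmark and the unit test: a TopK module with use_triton_kernels enabled produces the routing output, and a MoeRunnerConfig object replaces the old inplace/renormalize flags. Names and argument order follow the diff above; anything beyond that is an assumption.

from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
    triton_kernel_moe_forward,
)
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopK

def run_triton_moe(x, w1, w2, router_logits, topk):
    # Build the router and compute top-k routing with the Triton kernels.
    topk_op = TopK(top_k=topk, renormalize=False, use_grouped_topk=False)
    topk_op.use_triton_kernels = True
    topk_output = topk_op.forward_cuda(
        hidden_states=x,
        router_logits=router_logits,
    )
    # The runner config now carries options such as inplace.
    moe_runner_config = MoeRunnerConfig(inplace=False)
    # The kernel consumes the precomputed routing plus the runner config.
    return triton_kernel_moe_forward(x, w1, w2, topk_output, moe_runner_config)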