Unverified Commit 87eddedf authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

[ci] fix ci test fused_moe op (#5102)

parent 40652482
...@@ -76,6 +76,7 @@ suites = { ...@@ -76,6 +76,7 @@ suites = {
TestFile("test_create_kvindices.py", 2), TestFile("test_create_kvindices.py", 2),
TestFile("test_hicache.py", 60), TestFile("test_hicache.py", 60),
TestFile("test_hicache_mla.py", 90), TestFile("test_hicache_mla.py", 90),
TestFile("test_fused_moe.py", 30),
TestFile("test_triton_moe_channel_fp8_kernel.py", 25), TestFile("test_triton_moe_channel_fp8_kernel.py", 25),
], ],
"per-commit-2-gpu": [ "per-commit-2-gpu": [
......
...@@ -3,7 +3,6 @@ import unittest ...@@ -3,7 +3,6 @@ import unittest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from tqdm import tqdm from tqdm import tqdm
from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
...@@ -45,7 +44,18 @@ class TestFusedMOE(CustomTestCase): ...@@ -45,7 +44,18 @@ class TestFusedMOE(CustomTestCase):
else: else:
return 1e-2, 1e-2 # Default values for other types return 1e-2, 1e-2 # Default values for other types
def torch_naive_moe(self, a, w1, w2, score, topk): def torch_naive_moe(
self,
a,
w1,
w2,
score,
topk,
w1_scale=None,
w2_scale=None,
a1_scale=None,
a2_scale=None,
):
B, D = a.shape B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
...@@ -53,12 +63,30 @@ class TestFusedMOE(CustomTestCase): ...@@ -53,12 +63,30 @@ class TestFusedMOE(CustomTestCase):
topk_weight, topk_ids = torch.topk(score, topk) topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1) topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1) topk_ids = topk_ids.view(-1)
for i in range(w1.shape[0]):
if w1.dtype == torch.float8_e4m3fn:
w1_compute = w1.to(a.dtype)
w2_compute = w2.to(a.dtype)
if w1_scale is not None:
w1_compute = (w1_compute * w1_scale.view(-1, 1, 1)).to(a.dtype)
if w2_scale is not None:
w2_compute = (w2_compute * w2_scale.view(-1, 1, 1)).to(a.dtype)
if a1_scale is not None:
a = (a * a1_scale).to(a.dtype)
if a2_scale is not None:
a = (a * a2_scale).to(a.dtype)
else:
w1_compute = w1
w2_compute = w2
for i in range(w1_compute.shape[0]):
mask = topk_ids == i mask = topk_ids == i
if mask.sum(): if mask.sum():
out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[ out[mask] = SiluAndMul()(
i a[mask] @ w1_compute[i].transpose(0, 1)
].transpose(0, 1) ) @ w2_compute[i].transpose(0, 1)
return ( return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1) ).sum(dim=1)
...@@ -98,21 +126,12 @@ class TestFusedMOE(CustomTestCase): ...@@ -98,21 +126,12 @@ class TestFusedMOE(CustomTestCase):
a2_scale=a2_scale, a2_scale=a2_scale,
) )
vllm_output = fused_moe_vllm( torch_output = self.torch_naive_moe(
a, a, w1, w2, score, topk, w1_scale, w2_scale, a1_scale, a2_scale
w1, )
w2, torch.testing.assert_close(
score, sglang_output, torch_output, rtol=rtol, atol=atol
topk,
renormalize=False,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
) )
torch.testing.assert_close(sglang_output, vllm_output, rtol=rtol, atol=atol)
else: else:
a = self.create_random_cuda_tensor((m, k), dtype) a = self.create_random_cuda_tensor((m, k), dtype)
...@@ -127,8 +146,8 @@ class TestFusedMOE(CustomTestCase): ...@@ -127,8 +146,8 @@ class TestFusedMOE(CustomTestCase):
) )
def test_various_configurations(self): def test_various_configurations(self):
m_values = [1, 33, 64, 222, 1024 * 128] m_values = [1, 33, 64, 222]
n_values = [128, 1024, 2048] n_values = [128, 1024]
k_values = [128, 511, 1024] k_values = [128, 511, 1024]
dtypes = [torch.float16, torch.bfloat16] dtypes = [torch.float16, torch.bfloat16]
fp8_modes = [False, True] fp8_modes = [False, True]
...@@ -171,6 +190,7 @@ class TestFusedMOE(CustomTestCase): ...@@ -171,6 +190,7 @@ class TestFusedMOE(CustomTestCase):
dtype, dtype,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
) )
torch.cuda.empty_cache()
pbar.update(1) pbar.update(1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment