"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "7aac77affa17b6b504b0a406aacb471c5226b36d"
Unverified Commit 3c8ac78d authored by Xiaoyu Zhang, committed by GitHub

optimize test_fused_moe style (#3268)

parent 455bfe8d
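The commit is a style cleanup of test_fused_moe: ad-hoc `torch.randn(...) / 10` initializations are replaced by a shared `create_random_cuda_tensor` helper, hard-coded `assert_close` tolerances by a per-dtype `get_tolerance` lookup, and the configuration sweep gains a tqdm progress bar. One behavioral nuance worth noting: `torch.randn(shape) / 10` samples with standard deviation 0.1, while the helper's default `std=0.01` is ten times tighter, so the new test tensors are smaller in magnitude. A minimal CPU sketch illustrating the difference (printed values are approximate):

import torch

old_style = torch.randn((1024, 1024)) / 10               # randn has std 1, so this has std ~0.1
new_style = torch.empty((1024, 1024)).normal_(0, 0.01)   # same initializer the helper uses, std ~0.01
print(old_style.std().item())  # ~0.1
print(new_style.std().item())  # ~0.01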
 import unittest

 import torch
+import torch.nn.functional as F
+from tqdm import tqdm
 from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm

 from sglang.srt.layers.activation import SiluAndMul

@@ -11,6 +13,37 @@ class TestFusedMOE(unittest.TestCase):
     NUM_EXPERTS = [8, 64]
     TOP_KS = [2, 6]
+
+    @staticmethod
+    def create_random_cuda_tensor(shape, dtype, mean=0, std=0.01):
+        """Create a random CUDA tensor
+
+        Args:
+            shape: Tensor shape
+            dtype: Data type
+            mean: Mean value
+            std: Standard deviation
+
+        Returns:
+            torch.Tensor: Randomly initialized CUDA tensor
+        """
+        return torch.empty(shape, dtype=dtype, device="cuda").normal_(mean, std)
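A hypothetical call site for the new helper (shapes chosen only for illustration; a CUDA device is required, since the helper hard-codes device="cuda"):

# Illustrative only, not part of the commit.
hidden = TestFusedMOE.create_random_cuda_tensor((4, 128), torch.float16)
print(hidden.shape, hidden.dtype, hidden.device)  # torch.Size([4, 128]) torch.float16 cuda:0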
+
+    def get_tolerance(self, dtype):
+        """Get tolerance values for different data types
+
+        Args:
+            dtype: Data type
+
+        Returns:
+            tuple: (relative tolerance, absolute tolerance)
+        """
+        if dtype == torch.float32:
+            return 1e-3, 1e-5
+        elif dtype in [torch.float16, torch.bfloat16]:
+            return 1e-1, 1e-2
+        else:
+            return 1e-2, 1e-2  # Default values for other types
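The looser 16-bit bounds reflect the short mantissas of those formats (10 explicit bits for float16, 7 for bfloat16), so round-off alone produces errors that the float32 tolerances would flag. A small CPU demonstration of the representation error (values approximate):

import torch

bf16 = torch.tensor(0.1, dtype=torch.bfloat16)
fp16 = torch.tensor(0.1, dtype=torch.float16)
print(abs(bf16.item() - 0.1))  # ~1e-4: bfloat16 cannot represent 0.1 exactly
print(abs(fp16.item() - 0.1))  # ~2e-5: float16 is closer but still inexact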

     def torch_naive_moe(self, a, w1, w2, score, topk):
         B, D = a.shape
         a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
@@ -30,23 +63,25 @@ class TestFusedMOE(unittest.TestCase):
         ).sum(dim=1)

     def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False):
+        rtol, atol = self.get_tolerance(dtype)
+
         if use_fp8_w8a8:
             # AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
             capability = torch.cuda.get_device_capability()
             if not (capability[0] >= 9 or capability == (8, 9)):
                 return

-            a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
-            w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
-            w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
+            a = self.create_random_cuda_tensor((m, k), dtype)
+            w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype)
+            w2 = self.create_random_cuda_tensor((e, k, n), dtype)
             w1 = w1.to(torch.float8_e4m3fn)
             w2 = w2.to(torch.float8_e4m3fn)
-            score = torch.randn((m, e), device="cuda", dtype=dtype)
-            w1_scale = torch.randn(e, dtype=torch.float32, device="cuda")
-            w2_scale = torch.randn(e, dtype=torch.float32, device="cuda")
-            a1_scale = torch.randn(1, dtype=torch.float32, device="cuda")
-            a2_scale = torch.randn(1, dtype=torch.float32, device="cuda")
+            score = self.create_random_cuda_tensor((m, e), dtype)
+            w1_scale = self.create_random_cuda_tensor(e, torch.float32)
+            w2_scale = self.create_random_cuda_tensor(e, torch.float32)
+            a1_scale = self.create_random_cuda_tensor(1, torch.float32)
+            a2_scale = self.create_random_cuda_tensor(1, torch.float32)

             sglang_output = fused_moe(
                 a,
@@ -76,17 +111,19 @@ class TestFusedMOE(unittest.TestCase):
                 a2_scale=a2_scale,
             )

-            torch.testing.assert_close(sglang_output, vllm_output, atol=2e-2, rtol=0)
+            torch.testing.assert_close(sglang_output, vllm_output, rtol=rtol, atol=atol)
         else:
-            a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
-            w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
-            w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
-            score = torch.randn((m, e), device="cuda", dtype=dtype)
+            a = self.create_random_cuda_tensor((m, k), dtype)
+            w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype)
+            w2 = self.create_random_cuda_tensor((e, k, n), dtype)
+            score = self.create_random_cuda_tensor((m, e), dtype)

             triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
             torch_output = self.torch_naive_moe(a, w1, w2, score, topk)
-            torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
+            torch.testing.assert_close(
+                triton_output, torch_output, rtol=rtol, atol=atol
+            )
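For context, `torch.testing.assert_close` treats elements as equal when |actual - expected| <= atol + rtol * |expected|; the old `rtol=0` calls made the check purely absolute, whereas the refactored calls scale the allowance with the magnitude of the expected value. A standalone illustration:

import torch

expected = torch.tensor([1.0, 100.0])
actual = torch.tensor([1.005, 100.5])

# Old style (purely absolute): fails, since |100.5 - 100.0| = 0.5 > 2e-2.
# torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0)

# New style (relative + absolute): passes, since 0.5 <= 1e-2 + 1e-1 * 100.0.
torch.testing.assert_close(actual, expected, rtol=1e-1, atol=1e-2)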

     def test_various_configurations(self):
         m_values = [1, 33, 64, 222, 1024 * 128]
@@ -95,31 +132,45 @@ class TestFusedMOE(unittest.TestCase):
         dtypes = [torch.float16, torch.bfloat16]
         fp8_modes = [False, True]

-        for m in m_values:
-            for n in n_values:
-                for k in k_values:
-                    for e in self.NUM_EXPERTS:
-                        for topk in self.TOP_KS:
-                            for dtype in dtypes:
-                                for use_fp8_w8a8 in fp8_modes:
-                                    with self.subTest(
-                                        m=m,
-                                        n=n,
-                                        k=k,
-                                        e=e,
-                                        topk=topk,
-                                        dtype=dtype,
-                                        fp8=use_fp8_w8a8,
-                                    ):
-                                        self._test_case(
-                                            m,
-                                            n,
-                                            k,
-                                            e,
-                                            topk,
-                                            dtype,
-                                            use_fp8_w8a8=use_fp8_w8a8,
-                                        )
+        # Calculate total number of tests
+        total_tests = (
+            len(m_values)
+            * len(n_values)
+            * len(k_values)
+            * len(self.NUM_EXPERTS)
+            * len(self.TOP_KS)
+            * len(dtypes)
+            * len(fp8_modes)
+        )
+
+        # Create progress bar
+        with tqdm(total=total_tests, desc="Running MoE tests") as pbar:
+            for m in m_values:
+                for n in n_values:
+                    for k in k_values:
+                        for e in self.NUM_EXPERTS:
+                            for topk in self.TOP_KS:
+                                for dtype in dtypes:
+                                    for use_fp8_w8a8 in fp8_modes:
+                                        with self.subTest(
+                                            m=m,
+                                            n=n,
+                                            k=k,
+                                            e=e,
+                                            topk=topk,
+                                            dtype=dtype,
+                                            fp8=use_fp8_w8a8,
+                                        ):
+                                            self._test_case(
+                                                m,
+                                                n,
+                                                k,
+                                                e,
+                                                topk,
+                                                dtype,
+                                                use_fp8_w8a8=use_fp8_w8a8,
+                                            )
+                                            pbar.update(1)
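As a style aside (not part of the commit), the seven nested loops could be flattened with `itertools.product`, which makes the total count fall out of the same iterable; a hypothetical sketch using the test's local names:

import itertools

# Hypothetical alternative to the nested loops above.
configs = list(
    itertools.product(
        m_values, n_values, k_values,
        self.NUM_EXPERTS, self.TOP_KS, dtypes, fp8_modes,
    )
)
with tqdm(total=len(configs), desc="Running MoE tests") as pbar:
    for m, n, k, e, topk, dtype, use_fp8_w8a8 in configs:
        with self.subTest(m=m, n=n, k=k, e=e, topk=topk, dtype=dtype, fp8=use_fp8_w8a8):
            self._test_case(m, n, k, e, topk, dtype, use_fp8_w8a8=use_fp8_w8a8)
            pbar.update(1)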

 if __name__ == "__main__":
...