Unverified commit 3844feb9, authored by Xiaoyu Zhang, committed by GitHub

Add a unittest for fused_moe (#2416)

parent 27f7bed7
@@ -10,7 +10,7 @@ Example usage:
```bash
# Tune Qwen2-57B with FP8 and TP=4
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
-    --model Qwen/Qwen2-57B-A14B-Instruct-FP8 \
+    --model Qwen/Qwen2-57B-A14B-Instruct \
    --tp-size 4 \
    --dtype fp8_w8a8 \
    --tune
@@ -34,7 +34,7 @@ python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_tri
# Compare with FP8 mode for Qwen2-57B
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
-    --model Qwen/Qwen2-57B-A14B-Instruct-FP8 \
+    --model Qwen/Qwen2-57B-A14B-Instruct \
    --use-fp8

# Compare with custom TP size
@@ -43,3 +43,7 @@ python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_tri
```
The benchmark results will be saved as plots and data files in the specified output directory (default: `./configs/benchmark_ops/vllm_sglang_fused_moe/`).
+- `benchmark_torch_compile_fused_moe.py`: A tool for benchmarking the fused MoE kernel compiled with `torch.compile` against the original fused MoE kernel.
+  Usage is the same as `benchmark_vllm_vs_sglang_fused_moe_triton.py`.
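For orientation, the sketch below shows the general shape of such a comparison: time an eager MoE-style forward and its `torch.compile`-d counterpart with `triton.testing.do_bench`. The `moe_like_forward` helper, the tensor shapes, and the timing setup are illustrative assumptions, not the actual contents of `benchmark_torch_compile_fused_moe.py`.

```python
# Hedged sketch: eager vs. torch.compile timing for an MoE-like forward.
# The helper below is a hypothetical stand-in for one expert's computation
# (gate/up projection, SiLU-and-mul activation, down projection).
import torch
import triton.testing


def moe_like_forward(a, w1, w2):
    # a: (m, k), w1: (2n, k), w2: (k, n)
    gate, up = (a @ w1.t()).chunk(2, dim=-1)
    return (torch.nn.functional.silu(gate) * up) @ w2.t()


def main():
    m, n, k = 1024, 2048, 1024  # illustrative shapes only
    a = torch.randn(m, k, device="cuda", dtype=torch.float16)
    w1 = torch.randn(2 * n, k, device="cuda", dtype=torch.float16)
    w2 = torch.randn(k, n, device="cuda", dtype=torch.float16)

    compiled = torch.compile(moe_like_forward)
    compiled(a, w1, w2)  # warm up so compilation happens outside the timed region

    eager_ms = triton.testing.do_bench(lambda: moe_like_forward(a, w1, w2))
    compiled_ms = triton.testing.do_bench(lambda: compiled(a, w1, w2))
    print(f"eager: {eager_ms:.3f} ms, torch.compile: {compiled_ms:.3f} ms")


if __name__ == "__main__":
    main()
```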
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 5
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 8,
        "num_stages": 5
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "24": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2
    },
    "96": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "128": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "256": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 4
    },
    "2048": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    }
}
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "2": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 5
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 2
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "32": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "48": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "96": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "128": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "256": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 8,
        "num_stages": 3
    },
    "2048": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    }
}
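The two JSON blocks above are tuned Triton configurations for the fused MoE kernel, keyed by what appears to be the token count M, with each entry giving the block sizes, group size, warp count, and pipeline stages for that batch size. As a hedged illustration of how such a table might be consumed, the sketch below loads one file and selects the entry whose key is nearest to the runtime token count; the file name and the nearest-key rule are assumptions for illustration, not the exact selection logic used by `fused_moe.py`.

```python
# Hedged sketch: read a tuned fused-MoE config file and pick the entry whose
# batch-size key is closest to the current token count. The path below is
# hypothetical and the nearest-key rule is an illustrative assumption.
import json


def load_moe_config(path: str) -> dict[int, dict]:
    with open(path) as f:
        raw = json.load(f)
    # JSON keys are strings ("1", "2", ...); convert to ints for lookup.
    return {int(m): cfg for m, cfg in raw.items()}


def pick_config(configs: dict[int, dict], num_tokens: int) -> dict:
    # Choose the tuned entry whose key is nearest to num_tokens.
    nearest = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[nearest]


if __name__ == "__main__":
    configs = load_moe_config("E=8,N=2048,device_name=EXAMPLE_GPU.json")  # hypothetical file name
    print(pick_config(configs, num_tokens=200))  # selects the "256" entry under this rule
```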
@@ -15,6 +15,7 @@ suites = {
        "test_double_sparsity.py",
        "test_embedding_openai_server.py",
        "test_eval_accuracy_mini.py",
+       "test_fused_moe.py",
        "test_get_weights_by_name.py",
        "test_gguf.py",
        "test_input_embeddings.py",
...
import unittest

import torch
from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe


class TestFusedMOE(unittest.TestCase):
    NUM_EXPERTS = [8, 64]
    TOP_KS = [2, 6]

    def torch_naive_moe(self, a, w1, w2, score, topk):
        B, D = a.shape
        a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
        out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
        score = torch.softmax(score, dim=-1, dtype=torch.float32)
        topk_weight, topk_ids = torch.topk(score, topk)
        topk_weight = topk_weight.view(-1)
        topk_ids = topk_ids.view(-1)
        for i in range(w1.shape[0]):
            mask = topk_ids == i
            if mask.sum():
                out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[
                    i
                ].transpose(0, 1)
        return (
            out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
        ).sum(dim=1)

    def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False):
        if use_fp8_w8a8:
            # AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
            capability = torch.cuda.get_device_capability()
            if not (capability[0] >= 9 or capability == (8, 9)):
                return

            a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
            w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
            w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
            w1 = w1.to(torch.float8_e4m3fn)
            w2 = w2.to(torch.float8_e4m3fn)
            score = torch.randn((m, e), device="cuda", dtype=dtype)

            w1_scale = torch.randn(e, dtype=torch.float32, device="cuda")
            w2_scale = torch.randn(e, dtype=torch.float32, device="cuda")
            a1_scale = torch.randn(1, dtype=torch.float32, device="cuda")
            a2_scale = torch.randn(1, dtype=torch.float32, device="cuda")

            sglang_output = fused_moe(
                a,
                w1,
                w2,
                score,
                topk,
                renormalize=False,
                use_fp8_w8a8=True,
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
            )

            vllm_output = fused_moe_vllm(
                a,
                w1,
                w2,
                score,
                topk,
                renormalize=False,
                use_fp8_w8a8=True,
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
            )

            torch.testing.assert_close(sglang_output, vllm_output, atol=2e-2, rtol=0)
        else:
            a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
            w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
            w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
            score = torch.randn((m, e), device="cuda", dtype=dtype)

            triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
            torch_output = self.torch_naive_moe(a, w1, w2, score, topk)
            torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)

    def test_various_configurations(self):
        m_values = [1, 33, 64, 222, 1024 * 128]
        n_values = [128, 1024, 2048]
        k_values = [128, 511, 1024]
        dtypes = [torch.float16, torch.bfloat16]
        fp8_modes = [False, True]

        for m in m_values:
            for n in n_values:
                for k in k_values:
                    for e in self.NUM_EXPERTS:
                        for topk in self.TOP_KS:
                            for dtype in dtypes:
                                for use_fp8_w8a8 in fp8_modes:
                                    with self.subTest(
                                        m=m,
                                        n=n,
                                        k=k,
                                        e=e,
                                        topk=topk,
                                        dtype=dtype,
                                        fp8=use_fp8_w8a8,
                                    ):
                                        self._test_case(
                                            m,
                                            n,
                                            k,
                                            e,
                                            topk,
                                            dtype,
                                            use_fp8_w8a8=use_fp8_w8a8,
                                        )


if __name__ == "__main__":
    unittest.main()