Unverified Commit ffb2cd6b authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[Perf] Optimize `moe_align_block_size` CUDA kernel (#19572)


Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
Co-authored-by: default avatarmgoin <mgoin64@gmail.com>
parent ca94d7fa
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import itertools
import torch
import triton
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size_triton,
)
def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
return torch.stack(
[
torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
for _ in range(num_tokens)
]
)
def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8):
"""
Verifies vllm vs. Triton
"""
topk_ids = get_topk_ids(num_tokens, num_experts, topk)
# 1. malloc space for triton and vllm
# malloc enough space (max_num_tokens_padded) for the sorted ids
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids_triton = torch.empty(
(max_num_tokens_padded,), dtype=torch.int32, device="cuda"
)
sorted_ids_triton.fill_(topk_ids.numel()) # fill with sentinel value
expert_ids_triton = torch.zeros(
(max_num_tokens_padded // block_size,), dtype=torch.int32, device="cuda"
)
num_tokens_post_pad_triton = torch.empty((1,), dtype=torch.int32, device="cuda")
sorted_ids_vllm = torch.empty_like(sorted_ids_triton)
sorted_ids_vllm.fill_(topk_ids.numel())
expert_ids_vllm = torch.zeros_like(expert_ids_triton)
num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_triton)
# 2. run implementations
moe_align_block_size_triton(
topk_ids,
num_experts,
block_size,
sorted_ids_triton,
expert_ids_triton,
num_tokens_post_pad_triton,
)
ops.moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids_vllm,
expert_ids_vllm,
num_tokens_post_pad_vllm,
)
print(f"✅ VLLM implementation works with {num_experts} experts!")
# 3. compare results
if torch.allclose(expert_ids_triton, expert_ids_vllm) and torch.allclose(
num_tokens_post_pad_triton, num_tokens_post_pad_vllm
):
print("✅ Triton and VLLM implementations match.")
else:
print("❌ Triton and VLLM implementations DO NOT match.")
print("Triton expert_ids:", expert_ids_triton)
print("VLLM expert_ids:", expert_ids_vllm)
print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)
print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)
# test configurations
num_tokens_range = [1, 16, 256, 4096]
num_experts_range = [16, 64, 224, 256, 280, 512]
topk_range = [1, 2, 8]
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens", "num_experts", "topk"],
x_vals=configs,
line_arg="provider",
line_vals=["vllm", "triton"], # "triton"
line_names=["VLLM", "Triton"], # "Triton"
plot_name="moe-align-block-size-performance",
args={},
)
)
def benchmark(num_tokens, num_experts, topk, provider):
"""Benchmark function for Triton."""
block_size = 256
topk_ids = get_topk_ids(num_tokens, num_experts, topk)
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = max_num_tokens_padded // block_size
expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")
num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")
quantiles = [0.5, 0.2, 0.8]
if provider == "vllm":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: ops.moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids.clone(),
expert_ids.clone(),
num_tokens_post_pad.clone(),
),
quantiles=quantiles,
)
elif provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: moe_align_block_size_triton(
topk_ids,
num_experts,
block_size,
sorted_ids.clone(),
expert_ids.clone(),
num_tokens_post_pad.clone(),
),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--num_experts",
type=int,
default=64,
choices=[8, 16, 32, 64, 128, 256],
)
parser.add_argument(
"--topk",
type=int,
default=8,
choices=[2, 4, 8],
help="Top-k value for correctness check.",
)
args = parser.parse_args()
print("Running correctness check...")
check_correctness(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
benchmark.run(print_data=True, show_plots=True)
This diff is collapsed.
...@@ -12,12 +12,6 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, ...@@ -12,12 +12,6 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids, int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids, torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad); torch::Tensor num_tokens_post_pad);
void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size,
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);
#ifndef USE_ROCM #ifndef USE_ROCM
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
torch::Tensor b_qweight, torch::Tensor b_scales, torch::Tensor b_qweight, torch::Tensor b_scales,
......
...@@ -22,15 +22,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -22,15 +22,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" Tensor! num_tokens_post_pad) -> ()"); " Tensor! num_tokens_post_pad) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
// temporarily adapted from
// https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
m.def(
"sgl_moe_align_block_size(Tensor topk_ids, int num_experts,"
" int block_size, Tensor! sorted_token_ids,"
" Tensor! experts_ids,"
" Tensor! num_tokens_post_pad) -> ()");
m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size);
#ifndef USE_ROCM #ifndef USE_ROCM
m.def( m.def(
"moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, " "moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import pytest
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size_triton)
@pytest.mark.parametrize(
"block_size,num_tokens,topk,num_experts",
list(
itertools.product(
[32, 64, 128, 256], # block_size
[
1,
3,
7,
16,
256,
2256,
4096,
], # num_tokens
[1, 4, 16, 64], # topk
[64, 160, 256, 257, 260, 264], # num_experts
)),
)
def test_moe_align_block_size_compare_implementations(block_size, num_tokens,
topk, num_experts):
topk_ids = torch.stack([
torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
for _ in range(num_tokens)
])
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids_cuda = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids_cuda.fill_(topk_ids.numel())
max_num_m_blocks = max_num_tokens_padded // block_size
expert_ids_cuda = torch.zeros((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad_cuda = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids_triton = torch.empty_like(sorted_ids_cuda)
sorted_ids_triton.fill_(topk_ids.numel())
expert_ids_triton = torch.zeros_like(expert_ids_cuda)
num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)
ops.moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids_cuda,
expert_ids_cuda,
num_tokens_post_pad_cuda,
)
moe_align_block_size_triton(
topk_ids,
num_experts,
block_size,
sorted_ids_triton,
expert_ids_triton,
num_tokens_post_pad_triton,
)
assert torch.allclose(expert_ids_cuda, expert_ids_triton), (
f"Expert IDs mismatch for block_size={block_size}, "
f"num_tokens={num_tokens}, topk={topk}\n"
f"CUDA expert_ids: {expert_ids_cuda}\n"
f"Triton expert_ids: {expert_ids_triton}")
assert torch.allclose(
num_tokens_post_pad_cuda, num_tokens_post_pad_triton), (
f"Num tokens post pad mismatch for block_size={block_size}, "
f"num_tokens={num_tokens}, topk={topk}\n"
f"CUDA num_tokens_post_pad: {num_tokens_post_pad_cuda}\n"
f"Triton num_tokens_post_pad: {num_tokens_post_pad_triton}")
if __name__ == "__main__":
pytest.main([__file__])
...@@ -1524,15 +1524,6 @@ def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, ...@@ -1524,15 +1524,6 @@ def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
num_tokens_post_pad) num_tokens_post_pad)
def sgl_moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
block_size: int, sorted_token_ids: torch.Tensor,
experts_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor) -> None:
torch.ops._moe_C.sgl_moe_align_block_size(topk_ids, num_experts,
block_size, sorted_token_ids,
experts_ids, num_tokens_post_pad)
def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor, def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor,
b_qweight: torch.Tensor, b_scales: torch.Tensor, b_qweight: torch.Tensor, b_scales: torch.Tensor,
b_qzeros: Optional[torch.Tensor], b_qzeros: Optional[torch.Tensor],
......
...@@ -4,7 +4,6 @@ from typing import Optional ...@@ -4,7 +4,6 @@ from typing import Optional
import torch import torch
import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import round_up from vllm.utils import round_up
...@@ -99,6 +98,7 @@ def moe_align_block_size_stage4( ...@@ -99,6 +98,7 @@ def moe_align_block_size_stage4(
# Triton implementation based on: # Triton implementation based on:
# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0 # https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
# TODO(wentao): Deprecated this function in the future.
def moe_align_block_size_triton( def moe_align_block_size_triton(
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
num_experts: int, num_experts: int,
...@@ -220,29 +220,9 @@ def moe_align_block_size( ...@@ -220,29 +220,9 @@ def moe_align_block_size(
num_tokens_post_pad = torch.empty((1), num_tokens_post_pad = torch.empty((1),
dtype=torch.int32, dtype=torch.int32,
device=topk_ids.device) device=topk_ids.device)
if num_experts >= 224:
if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON or num_experts != 256: ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
moe_align_block_size_triton( expert_ids, num_tokens_post_pad)
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
)
else:
# Currently requires num_experts=256
ops.sgl_moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
)
else:
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad)
if expert_map is not None: if expert_map is not None:
expert_ids = expert_map[expert_ids] expert_ids = expert_map[expert_ids]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment