"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "6e9ce183232648858d12f6c8f4061c0e83af92d3"
Unverified commit 57ab7769, authored by Ke Bao, committed by GitHub

Fuse sorted_token_ids padding to moe_align_block_size kernel (#7437)

parent 112b496a
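In short, this change lets the CUDA kernel write the padding sentinel into sorted_token_ids itself, instead of requiring the caller to pre-fill the buffer on the host. A minimal before/after sketch of the call pattern, based on the Python wrapper changed in this diff (sgl_moe_align_block_size is the import alias used by the benchmark below; tensor allocation is omitted here):

# Before: the caller pre-fills the output with the sentinel value topk_ids.numel()
sorted_ids.fill_(topk_ids.numel())
sgl_moe_align_block_size(
    topk_ids, num_experts, block_size,
    sorted_ids, expert_ids, num_tokens_post_pad,
    token_cnts_buffer, cumsum_buffer,
)

# After: the fill is fused into the kernel via the new flag
sgl_moe_align_block_size(
    topk_ids, num_experts, block_size,
    sorted_ids, expert_ids, num_tokens_post_pad,
    token_cnts_buffer, cumsum_buffer,
    pad_sorted_token_ids=True,
)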
@@ -5,7 +5,11 @@ import torch
 import triton
 import triton.language as tl
 from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
-from vllm import _custom_ops as ops
+try:
+    from vllm import _custom_ops as ops
+except ImportError:
+    ops = None

 USE_RANDOM_PERM = False
@@ -208,7 +212,7 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
         )
         print(f"✅ VLLM implementation works with {num_experts} experts!")
         vllm_works = True
-    except RuntimeError as e:
+    except Exception as e:
         print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
         vllm_works = False
@@ -257,13 +261,47 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
     return topk_ids


+def sgl_moe_align_block_size_with_empty(
+    topk_ids,
+    num_experts,
+    block_size,
+    sorted_ids,
+    expert_ids,
+    num_tokens_post_pad,
+    pad_sorted_token_ids=False,
+):
+    if not pad_sorted_token_ids:
+        sorted_ids.fill_(topk_ids.numel())
+
+    token_cnts_buffer = torch.empty(
+        (num_experts + 1) * num_experts,
+        dtype=torch.int32,
+        device=topk_ids.device,
+    )
+    cumsum_buffer = torch.empty(
+        num_experts + 1, dtype=torch.int32, device=topk_ids.device
+    )
+
+    sgl_moe_align_block_size(
+        topk_ids,
+        num_experts,
+        block_size,
+        sorted_ids.clone(),
+        expert_ids.clone(),
+        num_tokens_post_pad.clone(),
+        token_cnts_buffer,
+        cumsum_buffer,
+        pad_sorted_token_ids,
+    )
+
+
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["num_tokens", "num_experts", "topk"],
         x_vals=configs,
         line_arg="provider",
-        line_vals=["sgl", "triton", "vllm"],
-        line_names=["SGL", "Triton", "VLLM"],
+        line_vals=["sgl", "sgl_fusion", "triton"],
+        line_names=["SGL", "SGL Fusion", "Triton"],
         styles=[("blue", "-"), ("red", "-"), ("green", "-")],
         ylabel="us",
         plot_name="moe-align-block-size-performance",
@@ -288,7 +326,6 @@ def benchmark(num_tokens, num_experts, topk, provider):
     sorted_ids = torch.empty(
         (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
-    sorted_ids.fill_(topk_ids.numel())
     max_num_m_blocks = max_num_tokens_padded // block_size
     expert_ids = torch.empty(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
@@ -297,35 +334,18 @@ def benchmark(num_tokens, num_experts, topk, provider):
     quantiles = [0.5, 0.2, 0.8]

     if provider == "sgl":
-
-        def sgl_moe_align_block_size_with_empty(
-            topk_ids,
-            num_experts,
-            block_size,
-            sorted_ids,
-            expert_ids,
-            num_tokens_post_pad,
-        ):
-            token_cnts_buffer = torch.empty(
-                (num_experts + 1) * num_experts,
-                dtype=torch.int32,
-                device=topk_ids.device,
-            )
-            cumsum_buffer = torch.empty(
-                num_experts + 1, dtype=torch.int32, device=topk_ids.device
-            )
-            sgl_moe_align_block_size(
-                topk_ids,
-                num_experts,
-                block_size,
-                sorted_ids.clone(),
-                expert_ids.clone(),
-                num_tokens_post_pad.clone(),
-                token_cnts_buffer,
-                cumsum_buffer,
-            )
-
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: sgl_moe_align_block_size_with_empty(
+                topk_ids,
+                num_experts,
+                block_size,
+                sorted_ids,
+                expert_ids,
+                num_tokens_post_pad,
+            ),
+            quantiles=quantiles,
+        )
+    elif provider == "sgl_fusion":
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: sgl_moe_align_block_size_with_empty(
                 topk_ids,
@@ -334,10 +354,12 @@ def benchmark(num_tokens, num_experts, topk, provider):
                 sorted_ids,
                 expert_ids,
                 num_tokens_post_pad,
+                pad_sorted_token_ids=True,
             ),
             quantiles=quantiles,
         )
     elif provider == "triton":
+        sorted_ids.fill_(topk_ids.numel())
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: moe_align_block_size_triton(
                 topk_ids,
@@ -349,23 +371,6 @@ def benchmark(num_tokens, num_experts, topk, provider):
             ),
             quantiles=quantiles,
         )
-    else:  # vllm
-        try:
-            ms, min_ms, max_ms = triton.testing.do_bench(
-                lambda: ops.moe_align_block_size(
-                    topk_ids,
-                    num_experts,
-                    block_size,
-                    sorted_ids.clone(),
-                    expert_ids.clone(),
-                    num_tokens_post_pad.clone(),
-                ),
-                quantiles=quantiles,
-            )
-        except RuntimeError as e:
-            print(f"❌ VLLM benchmark failed with {num_experts} experts: {e}")
-            # Return extreme values to indicate failure in the chart
-            return float("inf"), float("inf"), float("inf")

     return 1000 * ms, 1000 * max_ms, 1000 * min_ms
...
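As a reminder of the output contract this benchmark relies on (my reading of the kernel's existing behavior, not something added by this commit): every entry of sorted_ids in the first num_tokens_post_pad slots is either a flattened index into topk_ids or the sentinel value topk_ids.numel(), and num_tokens_post_pad is a multiple of block_size. A small, purely illustrative sanity check one could run after either SGL variant:

# Hypothetical post-call check (not part of this diff)
total = int(num_tokens_post_pad.item())
pad_sentinel = topk_ids.numel()
assert total % block_size == 0
# every real token-expert pair occupies exactly one non-sentinel slot
assert (sorted_ids[:total] != pad_sentinel).sum().item() == topk_ids.numel()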
@@ -160,7 +160,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   */
  m.def(
      "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
-      "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
+      "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer, bool "
+      "pad_sorted_token_ids) -> ()");
  m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
  m.def(
...
@@ -21,8 +21,17 @@ limitations under the License.

 #include "utils.h"

+template <typename T, int N, int Alignment = sizeof(T) * N>
+class alignas(Alignment) AlignedArray {
+ public:
+  T data[N];
+};
+
 #define WARP_SIZE 32

+#define VEC_SIZE 4
+using Vec = AlignedArray<int32_t, VEC_SIZE>;
+
 template <typename scalar_t>
 __global__ void count_and_sort_expert_tokens_kernel(
     const scalar_t* __restrict__ topk_ids,
@@ -50,7 +59,8 @@ __global__ void moe_align_block_size_kernel(
     int32_t experts_per_warp,
     int32_t block_size,
     size_t numel,
-    int32_t* __restrict__ cumsum) {
+    int32_t* __restrict__ cumsum,
+    bool pad_sorted_token_ids) {
   extern __shared__ int32_t shared_counts[];

   const int warp_id = threadIdx.x / WARP_SIZE;
@@ -96,6 +106,24 @@ __global__ void moe_align_block_size_kernel(
       expert_ids[i / block_size] = threadIdx.x;
     }
   }
+
+  if (pad_sorted_token_ids) {
+    int32_t fill_val = static_cast<int32_t>(numel);
+    int32_t total = *total_tokens_post_pad;
+
+    Vec fill_vec;
+#pragma unroll
+    for (int i = 0; i < VEC_SIZE; ++i) {
+      fill_vec.data[i] = fill_val;
+    }
+
+    int32_t total_vec_count = (total + VEC_SIZE - 1) / VEC_SIZE;
+    Vec* out_ptr = reinterpret_cast<Vec*>(sorted_token_ids);
+
+    for (int32_t idx = tid; idx < total_vec_count; idx += stride) {
+      out_ptr[idx] = fill_vec;
+    }
+  }
 }

 template <typename scalar_t>
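Conceptually, the block added above does on the device what callers previously did with sorted_ids.fill_() on the host, restricted to the first *total_tokens_post_pad entries and issued as aligned vector stores of VEC_SIZE (4) int32 values per iteration. A rough Python stand-in for the same semantics, assuming the total has already been computed (illustration only, not the kernel itself):

# Reference-only sketch of the fused padding.
# The CUDA version above rounds the number of vector stores up so that
# ceil(total / VEC_SIZE) stores cover the first `total` slots.
def pad_sorted_token_ids_reference(sorted_token_ids, total, numel):
    # numel == topk_ids.numel() is the sentinel meaning "padding slot"
    sorted_token_ids[:total] = numel
    return sorted_token_ids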
@@ -106,7 +134,8 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
     int32_t* __restrict__ total_tokens_post_pad,
     int32_t num_experts,
     int32_t block_size,
-    size_t numel) {
+    size_t numel,
+    bool pad_sorted_token_ids) {
   const size_t tid = threadIdx.x;
   const size_t stride = blockDim.x;
@@ -149,6 +178,26 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
     }
   }
+
+  if (pad_sorted_token_ids) {
+    int32_t fill_val = static_cast<int32_t>(numel);
+    int32_t total = *total_tokens_post_pad;
+
+    Vec fill_vec;
+#pragma unroll
+    for (int i = 0; i < VEC_SIZE; ++i) {
+      fill_vec.data[i] = fill_val;
+    }
+
+    int32_t total_vec_count = (total + VEC_SIZE - 1) / VEC_SIZE;
+    Vec* out_ptr = reinterpret_cast<Vec*>(sorted_token_ids);
+
+    for (int32_t idx = tid; idx < total_vec_count; idx += stride) {
+      out_ptr[idx] = fill_vec;
+    }
+  }
+
+  __syncthreads();
+
   for (size_t i = tid; i < numel; i += stride) {
     int32_t expert_id = topk_ids[i];
     int32_t rank_post_pad = tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id];
@@ -165,7 +214,8 @@ void moe_align_block_size(
     torch::Tensor experts_ids,
     torch::Tensor num_tokens_post_pad,
     torch::Tensor token_cnts_buffer,
-    torch::Tensor cumsum_buffer) {
+    torch::Tensor cumsum_buffer,
+    bool pad_sorted_token_ids) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

   int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
@@ -190,7 +240,8 @@ void moe_align_block_size(
         num_tokens_post_pad.data_ptr<int32_t>(),
         num_experts,
         block_size,
-        topk_ids.numel());
+        topk_ids.numel(),
+        pad_sorted_token_ids);
   } else {
     auto align_kernel = moe_align_block_size_kernel<scalar_t>;
@@ -207,7 +258,8 @@ void moe_align_block_size(
         experts_per_warp,
         block_size,
         topk_ids.numel(),
-        cumsum_buffer.data_ptr<int32_t>());
+        cumsum_buffer.data_ptr<int32_t>(),
+        pad_sorted_token_ids);

     const int block_threads = std::min(256, (int)threads);
     const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
...
@@ -59,7 +59,8 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   */
  m.def(
      "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
-      "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
+      "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer, bool "
+      "pad_sorted_token_ids) -> ()");
  m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
  m.def(
...
@@ -212,7 +212,8 @@ void moe_align_block_size(
    torch::Tensor experts_ids,
    torch::Tensor num_tokens_post_pad,
    torch::Tensor token_cnts_buffer,
-    torch::Tensor cumsum_buffer);
+    torch::Tensor cumsum_buffer,
+    bool pad_sorted_token_ids);

void topk_softmax(
    torch::Tensor& topk_weights,
...
@@ -12,6 +12,7 @@ def moe_align_block_size(
     num_tokens_post_pad,
     token_cnts_buffer,
     cumsum_buffer,
+    pad_sorted_token_ids=False,
 ):
     torch.ops.sgl_kernel.moe_align_block_size.default(
         topk_ids,
@@ -22,6 +23,7 @@ def moe_align_block_size(
         num_tokens_post_pad,
         token_cnts_buffer,
         cumsum_buffer,
+        pad_sorted_token_ids,
     )
...
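For reference, a self-contained sketch of calling the updated wrapper with the new flag. Buffer sizing follows the benchmark and test in this diff; the random topk_ids and the shape of num_tokens_post_pad are illustrative assumptions, not code from this commit.

import torch
from sgl_kernel import moe_align_block_size

num_tokens, topk, num_experts, block_size = 16, 8, 256, 128
# Illustrative routing result; real callers get this from a top-k over router logits.
topk_ids = torch.randint(
    0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda"
)

max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty(max_num_tokens_padded, dtype=torch.int32, device="cuda")
expert_ids = torch.empty(max_num_tokens_padded // block_size, dtype=torch.int32, device="cuda")
num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device="cuda")
token_cnts_buffer = torch.empty((num_experts + 1) * num_experts, dtype=torch.int32, device="cuda")
cumsum_buffer = torch.empty(num_experts + 1, dtype=torch.int32, device="cuda")

# pad_sorted_token_ids=True: the kernel writes the sentinel itself, so the
# host-side sorted_ids.fill_(topk_ids.numel()) is no longer needed.
moe_align_block_size(
    topk_ids, num_experts, block_size,
    sorted_ids, expert_ids, num_tokens_post_pad,
    token_cnts_buffer, cumsum_buffer,
    pad_sorted_token_ids=True,
)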
@@ -138,33 +138,32 @@ def moe_align_block_size_triton(
 @pytest.mark.parametrize(
-    "block_size,num_tokens,topk,num_experts",
+    "block_size,num_tokens,topk,num_experts,pad_sorted_token_ids",
     list(
         itertools.product(
             [32, 64, 128, 256],  # block_size
             [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],  # num_tokens
             [1, 2, 4, 8, 16, 32, 64],  # topk
             [64, 160, 256, 257, 260, 264],  # num_experts
+            [True, False],  # pad_sorted_token_ids
         )
     ),
 )
 def test_moe_align_block_size_compare_implementations(
-    block_size, num_tokens, topk, num_experts
+    block_size, num_tokens, topk, num_experts, pad_sorted_token_ids
 ):
-    topk_ids = torch.stack(
-        [
-            torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
-            for _ in range(num_tokens)
-        ]
-    )
+    topk_ids = torch.argsort(torch.rand(num_tokens, num_experts, device="cuda"), dim=1)[
+        :, :topk
+    ]

     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
     sorted_ids_cuda = torch.empty(
         (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
-    sorted_ids_cuda.fill_(topk_ids.numel())
+    if not pad_sorted_token_ids:
+        sorted_ids_cuda.fill_(topk_ids.numel())
     max_num_m_blocks = max_num_tokens_padded // block_size
     expert_ids_cuda = torch.zeros(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
@@ -195,6 +194,7 @@ def test_moe_align_block_size_compare_implementations(
         num_tokens_post_pad_cuda,
         token_cnts_buffer,
         cumsum_buffer,
+        pad_sorted_token_ids,
     )

     moe_align_block_size_triton(
@@ -206,20 +206,51 @@ def test_moe_align_block_size_compare_implementations(
         num_tokens_post_pad_triton,
     )

-    assert torch.allclose(expert_ids_cuda, expert_ids_triton), (
+    assert torch.allclose(expert_ids_cuda, expert_ids_triton, atol=0, rtol=0), (
         f"Expert IDs mismatch for block_size={block_size}, "
         f"num_tokens={num_tokens}, topk={topk}\n"
         f"CUDA expert_ids: {expert_ids_cuda}\n"
         f"Triton expert_ids: {expert_ids_triton}"
     )

-    assert torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_triton), (
+    assert torch.allclose(
+        num_tokens_post_pad_cuda, num_tokens_post_pad_triton, atol=0, rtol=0
+    ), (
         f"Num tokens post pad mismatch for block_size={block_size}, "
         f"num_tokens={num_tokens}, topk={topk}\n"
         f"CUDA num_tokens_post_pad: {num_tokens_post_pad_cuda}\n"
         f"Triton num_tokens_post_pad: {num_tokens_post_pad_triton}"
     )

+    # Select an expert to check
+    expert_idx = expert_ids_cuda.max().item()
+    # Get the first and last block id where expert_ids_cuda == expert_idx
+    matching_indices = torch.where(expert_ids_cuda == expert_idx)[0]
+    block_sorted_start = matching_indices[0].item() * block_size
+    block_sorted_end = min(
+        (matching_indices[-1].item() + 1) * block_size, max_num_tokens_padded
+    )
+
+    selected_sorted_ids_cuda = sorted_ids_cuda[
+        block_sorted_start:block_sorted_end
+    ].sort()[0]
+    selected_sorted_ids_triton = sorted_ids_triton[
+        block_sorted_start:block_sorted_end
+    ].sort()[0]
+
+    assert torch.allclose(
+        selected_sorted_ids_cuda,
+        selected_sorted_ids_triton,
+        atol=0,
+        rtol=0,
+    ), (
+        f"Sorted IDs mismatch for block_size={block_size}, "
+        f"num_tokens={num_tokens}, topk={topk}\n"
+        f"CUDA sorted_ids: {selected_sorted_ids_cuda}\n"
+        f"Triton sorted_ids: {selected_sorted_ids_triton}"
+    )


 if __name__ == "__main__":
     pytest.main([__file__])