Unverified commit 616a3e20, authored by Yuan Luo and committed by GitHub

[sgl-kernel] Support moe_sum_reduce cuda kernel (#10321)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
parent ac2a723b
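The new `moe_sum_reduce` op collapses a routed-experts activation of shape [token_num, topk_num, hidden_dim] into [token_num, hidden_dim], scaling the sum by `routed_scaling_factor`. A minimal usage sketch (the shapes and scaling value are illustrative, and it assumes an sgl-kernel build that includes this commit):

```python
import torch
from sgl_kernel import moe_sum_reduce

token_num, topk_num, hidden_dim = 1024, 8, 4096
x = torch.randn(token_num, topk_num, hidden_dim, dtype=torch.bfloat16, device="cuda")
out = torch.empty(token_num, hidden_dim, dtype=torch.bfloat16, device="cuda")

# out[t, d] = routed_scaling_factor * sum_k x[t, k, d]; both tensors must be contiguous.
moe_sum_reduce(x, out, routed_scaling_factor=2.5)
```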
import torch
import triton
import triton.language as tl
+from sgl_kernel import moe_sum_reduce as moe_sum_reduce_cuda
from triton.testing import do_bench
@@ -57,7 +58,7 @@ def _moe_sum_reduce_kernel(
# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
-def moe_sum_reduce(
+def moe_sum_reduce_triton(
    input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
):
    assert input.is_contiguous()
@@ -117,9 +118,9 @@ def get_benchmark():
        x_names=["num_tokens"],
        x_vals=num_tokens_range,
        line_arg="version",
-       line_vals=["baseline", "compiled", "triton"],
-       line_names=["Original", "TorchCompile", "TritonKernel"],
-       styles=[("blue", "-"), ("green", "-"), ("red", "-")],
+       line_vals=["baseline", "compiled", "triton", "cuda"],
+       line_names=["Original", "TorchCompile", "TritonKernel", "CudaKernel"],
+       styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("yellow", "-")],
        ylabel="us",
        plot_name="sum_scaled_performance",
        args={},
@@ -140,8 +141,10 @@ def get_benchmark():
            compute_sum_scaled_baseline(x, out, scaling_factor)
        elif version == "compiled":
            compute_sum_scaled_compiled(x, out, scaling_factor)
+       elif version == "triton":
+           moe_sum_reduce_triton(x, out, scaling_factor)
        else:
-           moe_sum_reduce(x, out, scaling_factor)
+           moe_sum_reduce_cuda(x, out, scaling_factor)

        # Benchmark
        quantiles = [0.5, 0.2, 0.8]
@@ -155,9 +158,15 @@ def get_benchmark():
                lambda: compute_sum_scaled_compiled(x, out, scaling_factor),
                quantiles=quantiles,
            )
+       elif version == "triton":
+           ms, min_ms, max_ms = do_bench(
+               lambda: moe_sum_reduce_triton(x, out, scaling_factor),
+               quantiles=quantiles,
+           )
        else:
            ms, min_ms, max_ms = do_bench(
-               lambda: moe_sum_reduce(x, out, scaling_factor), quantiles=quantiles
+               lambda: moe_sum_reduce_cuda(x, out, scaling_factor),
+               quantiles=quantiles,
            )

        return 1000 * ms, 1000 * max_ms, 1000 * min_ms
@@ -176,11 +185,16 @@ def verify_correctness(num_tokens=1024):
    compute_sum_scaled_compiled(x, out_compiled, scaling_factor)

    out_triton = torch.empty_like(out_baseline)
-   moe_sum_reduce(x, out_triton, scaling_factor)
+   moe_sum_reduce_triton(x, out_triton, scaling_factor)

+   out_cuda = torch.empty_like(out_baseline)
+   moe_sum_reduce_cuda(x, out_cuda, scaling_factor)

-   if torch.allclose(
-       out_baseline, out_compiled, atol=1e-2, rtol=1e-2
-   ) and torch.allclose(out_baseline, out_triton, atol=1e-2, rtol=1e-2):
+   if (
+       torch.allclose(out_baseline, out_compiled, atol=1e-2, rtol=1e-2)
+       and torch.allclose(out_baseline, out_triton, atol=1e-2, rtol=1e-2)
+       and torch.allclose(out_baseline, out_cuda, atol=1e-2, rtol=1e-2)
+   ):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")
@@ -188,6 +202,7 @@ def verify_correctness(num_tokens=1024):
        f"Baseline vs Compiled: {(out_baseline - out_compiled).abs().max().item()}"
    )
    print(f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}")
+   print(f"Baseline vs Cuda: {(out_baseline - out_cuda).abs().max().item()}")

if __name__ == "__main__":
...
@@ -309,6 +309,7 @@ set(SOURCES
    "csrc/moe/marlin_moe_wna16/ops.cu"
    "csrc/moe/moe_align_kernel.cu"
    "csrc/moe/moe_fused_gate.cu"
+   "csrc/moe/moe_sum_reduce.cu"
    "csrc/moe/moe_topk_softmax_kernels.cu"
    "csrc/moe/nvfp4_blockwise_moe.cu"
    "csrc/moe/fp8_blockwise_moe_kernel.cu"
...
@@ -217,6 +217,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  m.def("topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize) -> ()");
  m.impl("topk_softmax", torch::kCUDA, &topk_softmax);

+ m.def("moe_sum_reduce(Tensor input, Tensor output, float routed_scaling_factor) -> ()");
+ m.impl("moe_sum_reduce", torch::kCUDA, &moe_sum_reduce);

  m.def(
      "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
      "num_fused_shared_experts, float routed_scaling_factor, bool apply_routed_scaling_factor_on_output) -> "
...
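With the schema registered above, the op is also reachable directly through the dispatcher. A minimal sketch (illustrative shapes; it assumes importing `sgl_kernel` loads the extension that registers the `sgl_kernel` library fragment):

```python
import torch
import sgl_kernel  # assumption: importing the package registers the sgl_kernel ops

x = torch.randn(32, 8, 256, dtype=torch.float16, device="cuda")
out = torch.empty(32, 256, dtype=torch.float16, device="cuda")
torch.ops.sgl_kernel.moe_sum_reduce(x, out, 1.0)
```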
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include <iostream>
#include <type_traits>
#include "cutlass/array.h"
#include "utils.h"
template <typename T>
__device__ __forceinline__ float to_float(T x) {
return static_cast<float>(x);
}
template <>
__device__ __forceinline__ float to_float<half>(half x) {
return __half2float(x);
}
template <typename T>
__device__ __forceinline__ T from_float(float x) {
return static_cast<T>(x);
}
template <>
__device__ __forceinline__ half from_float<half>(float x) {
return __float2half_rn(x);
}
template <typename T>
__device__ __forceinline__ T ldg_cg(const T* p) {
return __ldg(p);
}
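// 16-byte pack: eight bf16 values viewed as a single uint4 so loads and stores use one 128-bit transaction.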
union Pack16B {
uint4 v;
__nv_bfloat16 u16[8];
};
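// bf16 fast path: one warp per token (selected via blockIdx.y and the warp id); each lane accumulates a
// 16-element chunk across the topk slices in fp32, applies the scale, and writes back as two 16-byte stores.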
template <int WARPS_PER_BLOCK>
__global__ void moe_sum_reduce_warp_per_token_vec_kernel(
const at::BFloat16* __restrict__ x,
at::BFloat16* __restrict__ y,
const int64_t token_num,
const int64_t hidden_dim,
const int64_t topk_num,
const int64_t stride_token, // in elements
const int64_t stride_topk, // in elements
const int64_t out_stride_token, // in elements
const float scale) {
constexpr int VEC = 16;
constexpr int PACKS = VEC / 8;
const int warp_id = threadIdx.x / 32;
const int lane = threadIdx.x % 32;
const int64_t t = (int64_t)blockIdx.y * WARPS_PER_BLOCK + warp_id;
if (t >= token_num) return;
const int64_t n_chunks = hidden_dim / VEC;
for (int64_t chunk = (int64_t)blockIdx.x * 32 + lane; chunk < n_chunks; chunk += (int64_t)gridDim.x * 32) {
const int64_t d = chunk * VEC;
const int64_t base = t * stride_token + d;
float acc[VEC];
#pragma unroll
for (int i = 0; i < VEC; ++i)
acc[i] = 0.f;
#pragma unroll
for (int k = 0; k < topk_num; ++k) {
#pragma unroll
for (int p = 0; p < PACKS; ++p) {
const int64_t offset = base + (int64_t)k * stride_topk + p * 8;
Pack16B pack = {ldg_cg(reinterpret_cast<const uint4*>(x + offset))};
#pragma unroll
for (int i = 0; i < 8; ++i) {
acc[p * 8 + i] += __bfloat162float(pack.u16[i]);
}
}
}
#pragma unroll
for (int i = 0; i < VEC; ++i)
acc[i] *= scale;
#pragma unroll
for (int p = 0; p < PACKS; ++p) {
Pack16B outp;
#pragma unroll
for (int i = 0; i < 8; ++i) {
outp.u16[i] = __float2bfloat16_rn(acc[p * 8 + i]);
}
const int64_t dst = t * out_stride_token + d + p * 8;
*reinterpret_cast<uint4*>(y + dst) = outp.v;
}
}
}
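// One warp per token with topk fixed at compile time (TOPK), so the per-expert loop fully unrolls;
// lanes stride over hidden_dim one element at a time.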
template <typename scalar_t, int TOPK, int WARPS_PER_BLOCK>
__global__ void moe_sum_reduce_kernel_warp_token_topk(
const scalar_t* __restrict__ x,
scalar_t* __restrict__ y,
const int64_t token_num,
const int64_t hidden_dim,
const int64_t stride_token,
const int64_t stride_topk,
const int64_t out_stride_token,
const float scale) {
const int warp_id = threadIdx.x / 32;
const int lane = threadIdx.x % 32;
const int64_t t = (int64_t)blockIdx.y * WARPS_PER_BLOCK + warp_id;
if (t >= token_num) return;
for (int64_t d = (int64_t)blockIdx.x * 32 + lane; d < hidden_dim; d += (int64_t)gridDim.x * 32) {
float acc = 0.f;
const int64_t base = t * stride_token + d;
#pragma unroll
for (int k = 0; k < TOPK; ++k) {
acc += to_float<scalar_t>(ldg_cg(&x[base + (int64_t)k * stride_topk]));
}
acc *= scale;
y[t * out_stride_token + d] = from_float<scalar_t>(acc);
}
}
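// Generic fallback: grid-stride loops over tokens (blockIdx.y) and hidden_dim (blockIdx.x * blockDim.x);
// chosen by the launcher when token_num is small.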
template <typename scalar_t, int TOPK>
__global__ void moe_sum_reduce_kernel(
const scalar_t* __restrict__ x,
scalar_t* __restrict__ y,
const int64_t token_num,
const int64_t hidden_dim,
const int64_t stride_token,
const int64_t stride_topk,
const int64_t out_stride_token,
const float scale) {
for (int t = blockIdx.y; t < token_num; t += gridDim.y) {
for (int d = blockIdx.x * blockDim.x + threadIdx.x; d < hidden_dim; d += blockDim.x * gridDim.x) {
const int64_t base = t * stride_token + d;
float acc = 0.f;
#pragma unroll
for (int k = 0; k < TOPK; ++k) {
acc += to_float<scalar_t>(x[base + (int64_t)k * stride_topk]);
}
acc *= scale;
y[t * out_stride_token + d] = from_float<scalar_t>(acc);
}
}
}
void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling_factor) {
TORCH_CHECK(input.is_cuda(), "input must be CUDA tensor");
TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor");
TORCH_CHECK(input.dim() == 3, "input must be a 3D tensor like [token_num, topk_num, hidden_dim]");
TORCH_CHECK(output.dim() == 2, "output must be [token_num, hidden_dim]");
TORCH_CHECK(input.size(0) == output.size(0), "token dim mismatch");
TORCH_CHECK(input.size(2) == output.size(1), "hidden_dim mismatch");
TORCH_CHECK(input.is_contiguous(), "expect input to be contiguous");
TORCH_CHECK(output.is_contiguous(), "expect output to be contiguous");
const int64_t token_num = input.size(0);
const int64_t topk_num = input.size(1);
const int64_t hidden_dim = input.size(2);
const int64_t in_stride_token = input.stride(0);
const int64_t in_stride_topk = input.stride(1);
const int64_t out_stride_token = output.stride(0);
const float scale = static_cast<float>(routed_scaling_factor);
auto stream = at::cuda::getCurrentCUDAStream();
const bool fast_bf16_vec_ok = (input.scalar_type() == at::kBFloat16) && (token_num > 256) && (hidden_dim % 16 == 0);
// Fast path: vectorized bf16 kernel (one warp per token, 16-element chunks per lane).
if (fast_bf16_vec_ok) {
constexpr int WARPS_PER_BLOCK = 8;
constexpr int THREADS = WARPS_PER_BLOCK * 32;
const int64_t n_chunks = hidden_dim / 16;  // matches VEC (16 bf16 elements per lane) in the kernel
int64_t grid_x = (n_chunks + 32 - 1) / 32;
if (grid_x > 65535) grid_x = 65535;
int64_t grid_y = (token_num + WARPS_PER_BLOCK - 1) / WARPS_PER_BLOCK;
if (grid_y > 65535) grid_y = 65535;
dim3 block(THREADS);
dim3 grid(static_cast<unsigned>(grid_x), static_cast<unsigned>(grid_y));
moe_sum_reduce_warp_per_token_vec_kernel<WARPS_PER_BLOCK><<<grid, block, 0, stream>>>(
reinterpret_cast<const at::BFloat16*>(input.data_ptr<at::BFloat16>()),
reinterpret_cast<at::BFloat16*>(output.data_ptr<at::BFloat16>()),
token_num,
hidden_dim,
topk_num,
in_stride_token,
in_stride_topk,
out_stride_token,
scale);
TORCH_CHECK(cudaGetLastError() == cudaSuccess, "moe_sum_reduce CUDA kernel launch failed");
return;
}
const bool per_token_use_one_warp = (token_num > 128);
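// Map the runtime topk_num onto a compile-time constant (2, 4, 8 or 9) so the reduction loop can unroll.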
auto dispatch_topk = [&](auto&& launch_kernel) {
switch (topk_num) {
case 2:
launch_kernel(std::integral_constant<int, 2>{});
break;
case 4:
launch_kernel(std::integral_constant<int, 4>{});
break;
case 8:
launch_kernel(std::integral_constant<int, 8>{});
break;
case 9:
launch_kernel(std::integral_constant<int, 9>{});
break;
default:
launch_kernel(std::integral_constant<int, -1>{});
break;
}
};
if (!per_token_use_one_warp) {
// ---------- small-token ----------
const int block_size = 256;
int64_t grid_x = (hidden_dim + block_size - 1) / block_size;
grid_x = grid_x > 65535 ? 65535 : grid_x;
int64_t grid_y = token_num < 65535 ? token_num : 65535;
dim3 block(block_size);
dim3 grid(static_cast<unsigned>(grid_x), static_cast<unsigned>(grid_y));
AT_DISPATCH_FLOATING_TYPES_AND2(
at::kHalf, at::kBFloat16, input.scalar_type(), "moe_sum_reduce_cuda_small_token", [&] {
using scalar_t_ = scalar_t;
auto launch_small_token_kernel = [&](auto topk_c) {
constexpr int TK = decltype(topk_c)::value;
moe_sum_reduce_kernel<scalar_t_, TK><<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t_>(),
output.data_ptr<scalar_t_>(),
token_num,
hidden_dim,
in_stride_token,
in_stride_topk,
out_stride_token,
scale);
};
dispatch_topk(launch_small_token_kernel);
});
} else {
// ---------- warp-token ----------
constexpr int WARPS_PER_BLOCK = 4;
constexpr int THREADS = WARPS_PER_BLOCK * 32;
int64_t gx = (hidden_dim + 32 - 1) / 32;
gx = gx > 65535 ? 65535 : gx;
int64_t gy = (token_num + WARPS_PER_BLOCK - 1) / WARPS_PER_BLOCK;
gy = gy > 65535 ? 65535 : gy;
dim3 block(THREADS);
dim3 grid(static_cast<unsigned>(gx), static_cast<unsigned>(gy));
AT_DISPATCH_FLOATING_TYPES_AND2(
at::kHalf, at::kBFloat16, input.scalar_type(), "moe_sum_reduce_cuda_large_token", [&] {
using scalar_t_ = scalar_t;
auto launch_large_token_kernel = [&](auto topk_c) {
constexpr int TK = decltype(topk_c)::value;
moe_sum_reduce_kernel_warp_token_topk<scalar_t_, TK, WARPS_PER_BLOCK><<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t_>(),
output.data_ptr<scalar_t_>(),
token_num,
hidden_dim,
in_stride_token,
in_stride_topk,
out_stride_token,
scale);
};
dispatch_topk(launch_large_token_kernel);
});
}
TORCH_CHECK(cudaGetLastError() == cudaSuccess, "CUDA kernel launch failed");
}
@@ -293,6 +293,8 @@ void moe_align_block_size(
void topk_softmax(
    torch::Tensor& topk_weights, torch::Tensor& topk_indices, torch::Tensor& gating_output, bool renormalize);

+void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling_factor);

std::vector<at::Tensor> moe_fused_gate(
    at::Tensor& input,
    at::Tensor& bias,
...
@@ -112,6 +112,7 @@ from sgl_kernel.moe import (
    fp8_blockwise_scaled_grouped_mm,
    moe_align_block_size,
    moe_fused_gate,
+   moe_sum_reduce,
    prepare_moe_input,
    topk_softmax,
)
...
@@ -36,6 +36,18 @@ def topk_softmax(
    )

+def moe_sum_reduce(
+    input_tensor,
+    output_tensor,
+    routed_scaling_factor=0,
+):
+    torch.ops.sgl_kernel.moe_sum_reduce.default(
+        input_tensor,
+        output_tensor,
+        routed_scaling_factor,
+    )

def moe_fused_gate(
    input_tensor,
    bias,
...