Unverified commit 4f42c8cd, authored by Yuan Luo and committed by GitHub

[sgl-kernel] Support float64 moe_sum_reduce cuda kernel (#11068)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
parent 3ddd7dc9
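For orientation, a minimal usage sketch of the newly supported float64 path. This is not part of the commit; shapes, the 0.3 scaling factor, and the 1e-12 tolerance mirror the benchmark in the diff below, and it assumes an sgl-kernel build that contains this change plus a CUDA-capable GPU.

```python
# Hedged sketch: exercise the float64 moe_sum_reduce path added by this commit.
import torch
from sgl_kernel import moe_sum_reduce  # same symbol the benchmark imports

num_tokens, topk, hidden_dim = 1024, 9, 4096
scaling_factor = 0.3

# float64 input is the newly supported dtype; bf16/fp16/fp32 follow the same call.
x = torch.randn(num_tokens, topk, hidden_dim, device="cuda", dtype=torch.float64)
out = torch.empty(num_tokens, hidden_dim, device="cuda", dtype=torch.float64)

# Sum over the top-k experts and apply the routed scaling factor (in-place into out).
moe_sum_reduce(x, out, scaling_factor)

# Reference check in plain PyTorch (assumed semantics: reduce dim 1, then scale).
ref = x.sum(dim=1) * scaling_factor
assert torch.allclose(out, ref, atol=1e-12, rtol=1e-12)
```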
+import os
 import torch
 import triton
 import triton.language as tl
 from sgl_kernel import moe_sum_reduce as moe_sum_reduce_cuda
 from triton.testing import do_bench

+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)

 @triton.jit
 def _moe_sum_reduce_kernel(
@@ -38,7 +46,6 @@ def _moe_sum_reduce_kernel(
     base_ptrs = input_ptr + offs_token[:, None] * input_stride_0 + offs_dim[None, :]
     accumulator = tl.zeros((BLOCK_M, BLOCK_DIM), dtype=tl.float32)
     for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
         tile = tl.load(
             base_ptrs + i * input_stride_1,
@@ -110,7 +117,7 @@ def compute_sum_scaled_compiled(
     return out

-def get_benchmark():
+def get_benchmark(dtype=torch.bfloat16):
     num_tokens_range = [2**i for i in range(0, 13)]

     @triton.testing.perf_report(
@@ -122,7 +129,7 @@ def get_benchmark():
             line_names=["Original", "TorchCompile", "TritonKernel", "CudaKernel"],
             styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("yellow", "-")],
             ylabel="us",
-            plot_name="sum_scaled_performance",
+            plot_name=f"sum_scaled_performance_{str(dtype).split('.')[-1]}",
             args={},
         )
     )
@@ -174,8 +181,8 @@ def get_benchmark():
     return benchmark

-def verify_correctness(num_tokens=1024):
-    x = torch.randn(num_tokens, 9, 4096, device="cuda", dtype=torch.bfloat16)
+def verify_correctness(num_tokens=1024, dtype=torch.bfloat16):
+    x = torch.randn(num_tokens, 9, 4096, device="cuda", dtype=dtype)
     scaling_factor = 0.3

     out_baseline = torch.empty_like(x[:, 0])
@@ -184,33 +191,60 @@ def verify_correctness(num_tokens=1024):
     out_compiled = torch.empty_like(out_baseline)
     compute_sum_scaled_compiled(x, out_compiled, scaling_factor)

-    out_triton = torch.empty_like(out_baseline)
-    moe_sum_reduce_triton(x, out_triton, scaling_factor)
-
-    out_cuda = torch.empty_like(out_baseline)
-    moe_sum_reduce_cuda(x, out_cuda, scaling_factor)
+    out_cuda = torch.empty_like(out_baseline)
+    moe_sum_reduce_cuda(x, out_cuda, scaling_factor)
+
+    triton_skipped = dtype == torch.float64
+    if not triton_skipped:
+        out_triton = torch.empty_like(out_baseline)
+        moe_sum_reduce_triton(x, out_triton, scaling_factor)
+
+    if dtype == torch.float64:
+        atol, rtol = 1e-12, 1e-12
+    elif dtype == torch.float32:
+        atol, rtol = 1e-6, 1e-6
+    else:  # bfloat16 / float16
+        atol, rtol = 1e-2, 1e-2
+
+    ok_compiled = torch.allclose(out_baseline, out_compiled, atol=atol, rtol=rtol)
+    ok_cuda = torch.allclose(out_baseline, out_cuda, atol=atol, rtol=rtol)
+    ok_triton = (
+        True
+        if triton_skipped
+        else torch.allclose(out_baseline, out_triton, atol=atol, rtol=rtol)
+    )

-    if (
-        torch.allclose(out_baseline, out_compiled, atol=1e-2, rtol=1e-2)
-        and torch.allclose(out_baseline, out_triton, atol=1e-2, rtol=1e-2)
-        and torch.allclose(out_baseline, out_cuda, atol=1e-2, rtol=1e-2)
-    ):
-        print("✅ All implementations match")
+    if ok_compiled and ok_triton and ok_cuda:
+        msg = "✅ All implementations match"
+        if triton_skipped:
+            msg += " (Triton skipped for float64)"
+        print(msg)
     else:
         print("❌ Implementations differ")
         print(
             f"Baseline vs Compiled: {(out_baseline - out_compiled).abs().max().item()}"
         )
-        print(f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}")
+        if not triton_skipped:
+            print(
+                f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}"
+            )
         print(f"Baseline vs Cuda: {(out_baseline - out_cuda).abs().max().item()}")

 if __name__ == "__main__":
-    print("Running correctness verification...")
-    verify_correctness()
+    print("Running correctness verification for bfloat16...")
+    verify_correctness(dtype=torch.bfloat16)
+
+    # CI environment uses simplified parameters
+    if not IS_CI:
+        print("Running correctness verification for float64...")
+        verify_correctness(dtype=torch.float64)

-    print("\nRunning performance benchmark...")
-    benchmark = get_benchmark()
+    print("\nRunning performance benchmark for bfloat16...")
+    benchmark = get_benchmark(dtype=torch.bfloat16)
     benchmark.run(
         print_data=True,
         # save_path="./configs/benchmark_ops/sum_scaled/"
...
+#include <ATen/OpMathType.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <cuda.h>
@@ -12,25 +13,36 @@
 #include "utils.h"

 template <typename T>
-__device__ __forceinline__ float to_float(T x) {
-  return static_cast<float>(x);
-}
+using opmath_t = at::opmath_type<T>;

-template <>
-__device__ __forceinline__ float to_float<half>(half x) {
-  return __half2float(x);
+template <typename T>
+__device__ __forceinline__ opmath_t<T> to_acc(T x) {
+  return static_cast<opmath_t<T>>(x);
 }

 template <typename T>
-__device__ __forceinline__ T from_float(float x) {
+__device__ __forceinline__ T from_acc(opmath_t<T> x) {
   return static_cast<T>(x);
 }

 template <>
-__device__ __forceinline__ half from_float<half>(float x) {
+__device__ __forceinline__ opmath_t<at::Half> to_acc<at::Half>(at::Half x) {
+  return __half2float(__nv_half(x));
+}
+
+template <>
+__device__ __forceinline__ at::Half from_acc<at::Half>(opmath_t<at::Half> x) {
   return __float2half_rn(x);
 }

+template <>
+__device__ __forceinline__ opmath_t<at::BFloat16> to_acc<at::BFloat16>(at::BFloat16 x) {
+  return __bfloat162float(__nv_bfloat16(x));
+}
+
+template <>
+__device__ __forceinline__ at::BFloat16 from_acc<at::BFloat16>(opmath_t<at::BFloat16> x) {
+  return __float2bfloat16_rn(x);
+}
+
 template <typename T>
 __device__ __forceinline__ T ldg_cg(const T* p) {
   return __ldg(p);
@@ -111,22 +123,22 @@ __global__ void moe_sum_reduce_kernel_warp_token_topk(
     const int64_t stride_token,
     const int64_t stride_topk,
     const int64_t out_stride_token,
-    const float scale) {
+    const opmath_t<scalar_t> scale) {
   const int warp_id = threadIdx.x / 32;
   const int lane = threadIdx.x % 32;
   const int64_t t = (int64_t)blockIdx.y * WARPS_PER_BLOCK + warp_id;
   if (t >= token_num) return;
   for (int64_t d = (int64_t)blockIdx.x * 32 + lane; d < hidden_dim; d += (int64_t)gridDim.x * 32) {
-    float acc = 0.f;
+    opmath_t<scalar_t> acc = opmath_t<scalar_t>(0);
     const int64_t base = t * stride_token + d;
 #pragma unroll
     for (int k = 0; k < TOPK; ++k) {
-      acc += to_float<scalar_t>(ldg_cg(&x[base + (int64_t)k * stride_topk]));
+      acc += to_acc<scalar_t>(x[base + (int64_t)k * stride_topk]);
     }
     acc *= scale;
-    y[t * out_stride_token + d] = from_float<scalar_t>(acc);
+    y[t * out_stride_token + d] = from_acc<scalar_t>(acc);
   }
 }
@@ -139,20 +151,76 @@ __global__ void moe_sum_reduce_kernel(
     const int64_t stride_token,
     const int64_t stride_topk,
     const int64_t out_stride_token,
-    const float scale) {
+    const opmath_t<scalar_t> scale) {
   for (int t = blockIdx.y; t < token_num; t += gridDim.y) {
     for (int d = blockIdx.x * blockDim.x + threadIdx.x; d < hidden_dim; d += blockDim.x * gridDim.x) {
       const int64_t base = t * stride_token + d;
-      float acc = 0.f;
+      opmath_t<scalar_t> acc = opmath_t<scalar_t>(0);
 #pragma unroll
       for (int k = 0; k < TOPK; ++k) {
-        acc += to_float<scalar_t>(x[base + (int64_t)k * stride_topk]);
+        acc += to_acc<scalar_t>(x[base + (int64_t)k * stride_topk]);
       }
       acc *= scale;
-      y[t * out_stride_token + d] = from_float<scalar_t>(acc);
+      y[t * out_stride_token + d] = from_acc<scalar_t>(acc);
     }
   }
 }
+
+// -------------------- general-topk fallback kernels --------------------
+// small-token
+template <typename scalar_t>
+__global__ void moe_sum_reduce_kernel_general(
+    const scalar_t* __restrict__ x,
+    scalar_t* __restrict__ y,
+    const int64_t token_num,
+    const int64_t hidden_dim,
+    const int64_t stride_token,
+    const int64_t stride_topk,
+    const int64_t out_stride_token,
+    const int topk_num,
+    const opmath_t<scalar_t> scale) {
+  for (int t = blockIdx.y; t < token_num; t += gridDim.y) {
+    for (int d = blockIdx.x * blockDim.x + threadIdx.x; d < hidden_dim; d += blockDim.x * gridDim.x) {
+      const int64_t base = t * stride_token + d;
+      opmath_t<scalar_t> acc = opmath_t<scalar_t>(0);
+#pragma unroll 1
+      for (int k = 0; k < topk_num; ++k) {
+        acc += to_acc<scalar_t>(x[base + (int64_t)k * stride_topk]);
+      }
+      acc *= scale;
+      y[t * out_stride_token + d] = from_acc<scalar_t>(acc);
+    }
+  }
+}
+
+// warp-per-token
+template <typename scalar_t, int WARPS_PER_BLOCK>
+__global__ void moe_sum_reduce_kernel_warp_token_general(
+    const scalar_t* __restrict__ x,
+    scalar_t* __restrict__ y,
+    const int64_t token_num,
+    const int64_t hidden_dim,
+    const int64_t stride_token,
+    const int64_t stride_topk,
+    const int64_t out_stride_token,
+    const int topk_num,
+    const opmath_t<scalar_t> scale) {
+  const int warp_id = threadIdx.x / 32;
+  const int lane = threadIdx.x % 32;
+  const int64_t t = (int64_t)blockIdx.y * WARPS_PER_BLOCK + warp_id;
+  if (t >= token_num) return;
+  for (int64_t d = (int64_t)blockIdx.x * 32 + lane; d < hidden_dim; d += (int64_t)gridDim.x * 32) {
+    opmath_t<scalar_t> acc = opmath_t<scalar_t>(0);
+    const int64_t base = t * stride_token + d;
+#pragma unroll 1
+    for (int k = 0; k < topk_num; ++k) {
+      acc += to_acc<scalar_t>(x[base + (int64_t)k * stride_topk]);
+    }
+    acc *= scale;
+    y[t * out_stride_token + d] = from_acc<scalar_t>(acc);
+  }
+}
@@ -175,8 +243,6 @@ void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling
   const int64_t in_stride_topk = input.stride(1);
   const int64_t out_stride_token = output.stride(0);

-  const float scale = static_cast<float>(routed_scaling_factor);
-
   auto stream = at::cuda::getCurrentCUDAStream();

   const bool fast_bf16_vec_ok = (input.scalar_type() == at::kBFloat16) && (token_num > 256) && (hidden_dim % 8 == 0);
@@ -198,6 +264,7 @@ void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling
     auto stream = at::cuda::getCurrentCUDAStream();
+    const float scale = static_cast<float>(routed_scaling_factor);

     moe_sum_reduce_warp_per_token_vec_kernel<WARPS_PER_BLOCK><<<grid, block, 0, stream>>>(
         reinterpret_cast<const at::BFloat16*>(input.data_ptr<at::BFloat16>()),
         reinterpret_cast<at::BFloat16*>(output.data_ptr<at::BFloat16>()),
@@ -209,32 +276,12 @@ void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling
         out_stride_token,
         scale);
-    TORCH_CHECK(cudaGetLastError() == cudaSuccess, "moe_sum_reduce CUDA kernel launch failed");
+    TORCH_CHECK(cudaGetLastError() == cudaSuccess, "moe_sum_reduce CUDA kernel (bf16 vec) launch failed");
     return;
   }
   const bool per_token_use_one_warp = (token_num > 128);

-  auto dispatch_topk = [&](auto&& launch_kernel) {
-    switch (topk_num) {
-      case 2:
-        launch_kernel(std::integral_constant<int, 2>{});
-        break;
-      case 4:
-        launch_kernel(std::integral_constant<int, 4>{});
-        break;
-      case 8:
-        launch_kernel(std::integral_constant<int, 8>{});
-        break;
-      case 9:
-        launch_kernel(std::integral_constant<int, 9>{});
-        break;
-      default:
-        launch_kernel(std::integral_constant<int, -1>{});
-        break;
-    }
-  };
-
   if (!per_token_use_one_warp) {
     // ---------- small-token ----------
     const int block_size = 256;
@@ -245,14 +292,38 @@ void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling
     dim3 block(block_size);
     dim3 grid(static_cast<unsigned>(grid_x), static_cast<unsigned>(grid_y));

+#define LAUNCH_SMALL_TOKEN_KERNEL(TOPK)                                \
+  moe_sum_reduce_kernel<scalar_t_, TOPK><<<grid, block, 0, stream>>>(  \
+      input.data_ptr<scalar_t_>(),                                     \
+      output.data_ptr<scalar_t_>(),                                    \
+      token_num,                                                       \
+      hidden_dim,                                                      \
+      in_stride_token,                                                 \
+      in_stride_topk,                                                  \
+      out_stride_token,                                                \
+      scale);
+
     AT_DISPATCH_FLOATING_TYPES_AND2(
         at::kHalf, at::kBFloat16, input.scalar_type(), "moe_sum_reduce_cuda_small_token", [&] {
           using scalar_t_ = scalar_t;
+          using acc_t_ = opmath_t<scalar_t_>;
+          const acc_t_ scale = static_cast<acc_t_>(routed_scaling_factor);

-          auto lauch_small_token_kernel = [&](auto topk_c) {
-            constexpr int TK = decltype(topk_c)::value;
-
-            moe_sum_reduce_kernel<scalar_t_, TK><<<grid, block, 0, stream>>>(
+          switch (topk_num) {
+            case 2:
+              LAUNCH_SMALL_TOKEN_KERNEL(2);
+              break;
+            case 4:
+              LAUNCH_SMALL_TOKEN_KERNEL(4);
+              break;
+            case 8:
+              LAUNCH_SMALL_TOKEN_KERNEL(8);
+              break;
+            case 9:
+              LAUNCH_SMALL_TOKEN_KERNEL(9);
+              break;
+            default:  // launch general kernel
+              moe_sum_reduce_kernel_general<scalar_t_><<<grid, block, 0, stream>>>(
                   input.data_ptr<scalar_t_>(),
                   output.data_ptr<scalar_t_>(),
                   token_num,
@@ -260,13 +331,16 @@ void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling
                   in_stride_token,
                   in_stride_topk,
                   out_stride_token,
+                  static_cast<int>(topk_num),
                   scale);
-          };
-
-          dispatch_topk(lauch_small_token_kernel);
+          }
         });
+#undef LAUNCH_SMALL_TOKEN_KERNEL
+
+    TORCH_CHECK(cudaGetLastError() == cudaSuccess, "moe_sum_reduce CUDA kernel (small-token) launch failed");
   } else {
-    // ---------- warp-token ----------
+    // ---------- warp-per-token ----------
     constexpr int WARPS_PER_BLOCK = 4;
     constexpr int THREADS = WARPS_PER_BLOCK * 32;
@@ -279,14 +353,38 @@ void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling
     dim3 block(THREADS);
     dim3 grid(static_cast<unsigned>(gx), static_cast<unsigned>(gy));

+#define LAUNCH_WARP_PER_TOKEN_KERNEL(TOPK)                                                               \
+  moe_sum_reduce_kernel_warp_token_topk<scalar_t_, TOPK, WARPS_PER_BLOCK><<<grid, block, 0, stream>>>(   \
+      input.data_ptr<scalar_t_>(),                                                                       \
+      output.data_ptr<scalar_t_>(),                                                                      \
+      token_num,                                                                                         \
+      hidden_dim,                                                                                        \
+      in_stride_token,                                                                                   \
+      in_stride_topk,                                                                                    \
+      out_stride_token,                                                                                  \
+      scale);
+
     AT_DISPATCH_FLOATING_TYPES_AND2(
         at::kHalf, at::kBFloat16, input.scalar_type(), "moe_sum_reduce_cuda_large_token", [&] {
           using scalar_t_ = scalar_t;
+          using acc_t_ = opmath_t<scalar_t_>;
+          const acc_t_ scale = static_cast<acc_t_>(routed_scaling_factor);

-          auto launch_large_token_kernel = [&](auto topk_c) {
-            constexpr int TK = decltype(topk_c)::value;
-
-            moe_sum_reduce_kernel_warp_token_topk<scalar_t_, TK, WARPS_PER_BLOCK><<<grid, block, 0, stream>>>(
+          switch (topk_num) {
+            case 2:
+              LAUNCH_WARP_PER_TOKEN_KERNEL(2);
+              break;
+            case 4:
+              LAUNCH_WARP_PER_TOKEN_KERNEL(4);
+              break;
+            case 8:
+              LAUNCH_WARP_PER_TOKEN_KERNEL(8);
+              break;
+            case 9:
+              LAUNCH_WARP_PER_TOKEN_KERNEL(9);
+              break;
+            default:  // launch general kernel
+              moe_sum_reduce_kernel_warp_token_general<scalar_t_, WARPS_PER_BLOCK><<<grid, block, 0, stream>>>(
                   input.data_ptr<scalar_t_>(),
                   output.data_ptr<scalar_t_>(),
                   token_num,
@@ -294,10 +392,12 @@ void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling
                   in_stride_token,
                   in_stride_topk,
                   out_stride_token,
+                  static_cast<int>(topk_num),
                   scale);
-          };
-
-          dispatch_topk(launch_large_token_kernel);
+          }
         });
+#undef LAUNCH_WARP_PER_TOKEN_KERNEL
+
+    TORCH_CHECK(cudaGetLastError() == cudaSuccess, "moe_sum_reduce CUDA kernel (warp-token) launch failed");
   }
-  TORCH_CHECK(cudaGetLastError() == cudaSuccess, "CUDA kernel launch failed");
 }