// moe_sum_reduce: sum the per-expert (top-k) outputs of a MoE layer over the
// expert dimension and apply the routed scaling factor:
//   y[t, d] = scale * sum_k x[t, k, d]

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <cstring>
#include <type_traits>

#include "cutlass/array.h"
#include "utils.h"

// ---------------------------------------------------------------------------
// Scalar <-> float conversion helpers.
// ---------------------------------------------------------------------------
template <typename T>
__device__ __forceinline__ float to_float(T x) {
  return static_cast<float>(x);
}

template <>
__device__ __forceinline__ float to_float<half>(half x) {
  return __half2float(x);
}

template <typename T>
__device__ __forceinline__ T from_float(float x) {
  return static_cast<T>(x);
}

template <>
__device__ __forceinline__ half from_float<half>(float x) {
  return __float2half_rn(x);
}

// Read-only cached load. __ldg is only overloaded for built-in types, so the
// 2-byte ATen scalar types (at::Half, at::BFloat16) are loaded through an
// equally sized built-in type and bit-copied back.
template <typename T>
__device__ __forceinline__ T ldg_cg(const T* p) {
  if constexpr (sizeof(T) == 2) {
    unsigned short bits = __ldg(reinterpret_cast<const unsigned short*>(p));
    T out;
    memcpy(&out, &bits, sizeof(T));
    return out;
  } else {
    return __ldg(p);
  }
}

// 16-byte pack used for vectorized bf16 loads/stores (8 x bf16 per uint4).
union Pack16B {
  uint4 v;
  __nv_bfloat16 u16[8];
};

// ---------------------------------------------------------------------------
// Fast path: bf16 input, one warp per token. Each lane owns a 16-element
// (32-byte) chunk of the hidden dimension, loaded as two 16-byte packs.
// ---------------------------------------------------------------------------
template <int WARPS_PER_BLOCK>
__global__ void moe_sum_reduce_warp_per_token_vec_kernel(
    const at::BFloat16* __restrict__ x,
    at::BFloat16* __restrict__ y,
    const int64_t token_num,
    const int64_t hidden_dim,
    const int64_t topk_num,
    const int64_t stride_token,      // in elements
    const int64_t stride_topk,       // in elements
    const int64_t out_stride_token,  // in elements
    const float scale) {
  constexpr int VEC = 16;         // bf16 elements accumulated per chunk
  constexpr int PACKS = VEC / 8;  // 16-byte packs per chunk

  const int warp_id = threadIdx.x / 32;
  const int lane = threadIdx.x % 32;
  const int64_t t = (int64_t)blockIdx.y * WARPS_PER_BLOCK + warp_id;
  if (t >= token_num) return;

  const int64_t n_chunks = hidden_dim / VEC;

  for (int64_t chunk = (int64_t)blockIdx.x * 32 + lane; chunk < n_chunks;
       chunk += (int64_t)gridDim.x * 32) {
    const int64_t d = chunk * VEC;
    const int64_t base = t * stride_token + d;

    float acc[VEC];
#pragma unroll
    for (int i = 0; i < VEC; ++i) acc[i] = 0.f;

    // Accumulate the top-k expert outputs for this chunk in fp32.
#pragma unroll
    for (int k = 0; k < topk_num; ++k) {
#pragma unroll
      for (int p = 0; p < PACKS; ++p) {
        const int64_t offset = base + (int64_t)k * stride_topk + p * 8;
        Pack16B pack = {ldg_cg(reinterpret_cast<const uint4*>(x + offset))};
#pragma unroll
        for (int i = 0; i < 8; ++i) {
          acc[p * 8 + i] += __bfloat162float(pack.u16[i]);
        }
      }
    }

#pragma unroll
    for (int i = 0; i < VEC; ++i) acc[i] *= scale;

    // Convert back to bf16 and store as two 16-byte packs.
#pragma unroll
    for (int p = 0; p < PACKS; ++p) {
      Pack16B outp;
#pragma unroll
      for (int i = 0; i < 8; ++i) {
        outp.u16[i] = __float2bfloat16_rn(acc[p * 8 + i]);
      }
      const int64_t dst = t * out_stride_token + d + p * 8;
      *reinterpret_cast<uint4*>(y + dst) = outp.v;
    }
  }
}

// ---------------------------------------------------------------------------
// Large-batch path: one warp per token, one hidden element per lane, with the
// top-k count known at compile time.
// ---------------------------------------------------------------------------
template <typename scalar_t, int TOPK, int WARPS_PER_BLOCK>
__global__ void moe_sum_reduce_kernel_warp_token_topk(
    const scalar_t* __restrict__ x,
    scalar_t* __restrict__ y,
    const int64_t token_num,
    const int64_t hidden_dim,
    const int64_t stride_token,
    const int64_t stride_topk,
    const int64_t out_stride_token,
    const float scale) {
  const int warp_id = threadIdx.x / 32;
  const int lane = threadIdx.x % 32;
  const int64_t t = (int64_t)blockIdx.y * WARPS_PER_BLOCK + warp_id;
  if (t >= token_num) return;

  for (int64_t d = (int64_t)blockIdx.x * 32 + lane; d < hidden_dim;
       d += (int64_t)gridDim.x * 32) {
    float acc = 0.f;
    const int64_t base = t * stride_token + d;
#pragma unroll
    for (int k = 0; k < TOPK; ++k) {
      acc += to_float(ldg_cg(&x[base + (int64_t)k * stride_topk]));
    }
    acc *= scale;
    y[t * out_stride_token + d] = from_float<scalar_t>(acc);
  }
}

// ---------------------------------------------------------------------------
// Small-batch fallback: one thread per (token, hidden) element, grid-stride
// over both dimensions.
// ---------------------------------------------------------------------------
template <typename scalar_t, int TOPK>
__global__ void moe_sum_reduce_kernel(
    const scalar_t* __restrict__ x,
    scalar_t* __restrict__ y,
    const int64_t token_num,
    const int64_t hidden_dim,
    const int64_t stride_token,
    const int64_t stride_topk,
    const int64_t out_stride_token,
    const float scale) {
  for (int t = blockIdx.y; t < token_num; t += gridDim.y) {
    for (int d = blockIdx.x * blockDim.x + threadIdx.x; d < hidden_dim;
         d += blockDim.x * gridDim.x) {
      const int64_t base = t * stride_token + d;
      float acc = 0.f;
#pragma unroll
      for (int k = 0; k < TOPK; ++k) {
        acc += to_float(x[base + (int64_t)k * stride_topk]);
      }
      acc *= scale;
      y[t * out_stride_token + d] = from_float<scalar_t>(acc);
    }
  }
}

// ---------------------------------------------------------------------------
// Host entry point: output[t, d] = routed_scaling_factor * sum_k input[t, k, d]
// ---------------------------------------------------------------------------
void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling_factor) {
  TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
  TORCH_CHECK(output.is_cuda(), "output must be a CUDA tensor");
  TORCH_CHECK(input.dim() == 3, "input must be a 3D tensor of shape [token_num, topk_num, hidden_dim]");
  TORCH_CHECK(output.dim() == 2, "output must be [token_num, hidden_dim]");
  TORCH_CHECK(input.size(0) == output.size(0), "token dim mismatch");
  TORCH_CHECK(input.size(2) == output.size(1), "hidden_dim mismatch");
  TORCH_CHECK(input.is_contiguous(), "expect input to be contiguous");
  TORCH_CHECK(output.is_contiguous(), "expect output to be contiguous");

  const int64_t token_num = input.size(0);
  const int64_t topk_num = input.size(1);
  const int64_t hidden_dim = input.size(2);

  const int64_t in_stride_token = input.stride(0);
  const int64_t in_stride_topk = input.stride(1);
  const int64_t out_stride_token = output.stride(0);

  const float scale = static_cast<float>(routed_scaling_factor);
  auto stream = at::cuda::getCurrentCUDAStream();

  // The vectorized kernel consumes 16 bf16 elements (two 16-byte packs) per
  // chunk, so the hidden dimension must be a multiple of 16 on this path.
  const bool fast_bf16_vec_ok =
      (input.scalar_type() == at::kBFloat16) && (token_num > 256) && (hidden_dim % 16 == 0);

  // Fast path: vectorized bf16, one warp per token.
  if (fast_bf16_vec_ok) {
    constexpr int WARPS_PER_BLOCK = 8;
    constexpr int THREADS = WARPS_PER_BLOCK * 32;

    const int64_t n_chunks = hidden_dim / 16;
    int64_t grid_x = (n_chunks + 32 - 1) / 32;
    if (grid_x > 65535) grid_x = 65535;
    int64_t grid_y = (token_num + WARPS_PER_BLOCK - 1) / WARPS_PER_BLOCK;
    if (grid_y > 65535) grid_y = 65535;

    dim3 block(THREADS);
    dim3 grid(static_cast<unsigned int>(grid_x), static_cast<unsigned int>(grid_y));

    moe_sum_reduce_warp_per_token_vec_kernel<WARPS_PER_BLOCK><<<grid, block, 0, stream>>>(
        reinterpret_cast<const at::BFloat16*>(input.data_ptr()),
        reinterpret_cast<at::BFloat16*>(output.data_ptr()),
        token_num,
        hidden_dim,
        topk_num,
        in_stride_token,
        in_stride_topk,
        out_stride_token,
        scale);
    TORCH_CHECK(cudaGetLastError() == cudaSuccess, "moe_sum_reduce CUDA kernel launch failed");
    return;
  }

  const bool per_token_use_one_warp = (token_num > 128);

  // Map the runtime top-k count onto a compile-time constant. Only the
  // commonly used top-k values are instantiated; anything else fails loudly
  // rather than reducing with the wrong count.
  auto dispatch_topk = [&](auto&& launch_kernel) {
    switch (topk_num) {
      case 2:
        launch_kernel(std::integral_constant<int, 2>{});
        break;
      case 4:
        launch_kernel(std::integral_constant<int, 4>{});
        break;
      case 8:
        launch_kernel(std::integral_constant<int, 8>{});
        break;
      case 9:
        launch_kernel(std::integral_constant<int, 9>{});
        break;
      default:
        TORCH_CHECK(false, "moe_sum_reduce: unsupported topk_num ", topk_num);
        break;
    }
  };

  if (!per_token_use_one_warp) {
    // ---------- small-token path: one thread per element ----------
    const int block_size = 256;
    int64_t grid_x = (hidden_dim + block_size - 1) / block_size;
    grid_x = grid_x > 65535 ? 65535 : grid_x;
    int64_t grid_y = token_num < 65535 ? token_num : 65535;

    dim3 block(block_size);
    dim3 grid(static_cast<unsigned int>(grid_x), static_cast<unsigned int>(grid_y));

    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::kHalf, at::kBFloat16, input.scalar_type(), "moe_sum_reduce_cuda_small_token", [&] {
          using scalar_t_ = scalar_t;
          auto launch_small_token_kernel = [&](auto topk_c) {
            constexpr int TK = decltype(topk_c)::value;
            moe_sum_reduce_kernel<scalar_t_, TK><<<grid, block, 0, stream>>>(
                input.data_ptr<scalar_t_>(),
                output.data_ptr<scalar_t_>(),
                token_num,
                hidden_dim,
                in_stride_token,
                in_stride_topk,
                out_stride_token,
                scale);
          };
          dispatch_topk(launch_small_token_kernel);
        });
  } else {
    // ---------- warp-per-token path ----------
    constexpr int WARPS_PER_BLOCK = 4;
    constexpr int THREADS = WARPS_PER_BLOCK * 32;

    int64_t gx = (hidden_dim + 32 - 1) / 32;
    gx = gx > 65535 ? 65535 : gx;
    int64_t gy = (token_num + WARPS_PER_BLOCK - 1) / WARPS_PER_BLOCK;
    gy = gy > 65535 ? 65535 : gy;
    dim3 block(THREADS);
    dim3 grid(static_cast<unsigned int>(gx), static_cast<unsigned int>(gy));

    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::kHalf, at::kBFloat16, input.scalar_type(), "moe_sum_reduce_cuda_large_token", [&] {
          using scalar_t_ = scalar_t;
          auto launch_large_token_kernel = [&](auto topk_c) {
            constexpr int TK = decltype(topk_c)::value;
            moe_sum_reduce_kernel_warp_token_topk<scalar_t_, TK, WARPS_PER_BLOCK>
                <<<grid, block, 0, stream>>>(
                    input.data_ptr<scalar_t_>(),
                    output.data_ptr<scalar_t_>(),
                    token_num,
                    hidden_dim,
                    in_stride_token,
                    in_stride_topk,
                    out_stride_token,
                    scale);
          };
          dispatch_topk(launch_large_token_kernel);
        });
  }

  TORCH_CHECK(cudaGetLastError() == cudaSuccess, "CUDA kernel launch failed");
}
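
// ---------------------------------------------------------------------------
// Optional standalone binding: a minimal sketch, assuming this file is built
// on its own as a PyTorch C++ extension rather than registered through the
// repository's own machinery (presumably in "utils.h").
// MOE_SUM_REDUCE_STANDALONE_BINDING is a hypothetical guard macro introduced
// here for illustration; leave it undefined when building inside the library
// so this block compiles away.
// ---------------------------------------------------------------------------
#ifdef MOE_SUM_REDUCE_STANDALONE_BINDING
#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // output[t, d] = routed_scaling_factor * sum_k input[t, k, d], written into
  // the caller-provided output tensor.
  m.def("moe_sum_reduce", &moe_sum_reduce, "MoE top-k sum reduction with routed scaling (CUDA)");
}
#endif  // MOE_SUM_REDUCE_STANDALONE_BINDING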