// NOTE(review): the original angle-bracket include targets were lost in
// extraction; this is the minimal set implied by the symbols used below
// (torch::Tensor, OptionalCUDAGuard, getCurrentCUDAStream, std::min,
// int64_t) -- confirm against the project's other .cu files.
#include <torch/all.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <algorithm>
#include <cstdint>

#include "utils.h"

// Sums the TOPK expert slices of `input` into `out` for every token.
//
// Launch layout: one thread block per token (blockIdx.x == token index);
// threads of the block stride over the hidden dimension `d`.
//
// out:   [..., d]        per-token reduced output
// input: [..., topk, d]  per-token, per-expert activations
// d:     hidden size (innermost, contiguous dimension)
template <typename scalar_t, int TOPK>
__global__ void moe_sum_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., topk, d]
    const int d) {
  const int64_t token_idx = blockIdx.x;
  // Block-stride loop: adjacent threads touch adjacent elements of the
  // innermost dimension, so global loads/stores are coalesced.
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    scalar_t x = 0.0;
    // TOPK is a compile-time constant, so this fully unrolls.
#pragma unroll
    for (int k = 0; k < TOPK; ++k) {
      // SGLANG_LDG: read-only cached load wrapper (defined in utils.h).
      x += SGLANG_LDG(&input[token_idx * TOPK * d + k * d + idx]);
    }
    out[token_idx * d + idx] = x;
  }
}

// Reduces `input` [num_tokens, topk, hidden_size] over the topk dimension
// into `output` [num_tokens, hidden_size] on the current CUDA stream.
//
// topk in {2, 3, 4} dispatches to an unrolled specialization of
// moe_sum_kernel; any other value falls back to at::sum_out over dim 1.
void moe_sum(
    torch::Tensor& input,   // [num_tokens, topk, hidden_size]
    torch::Tensor& output)  // [num_tokens, hidden_size]
{
  const int hidden_size = input.size(-1);
  const auto num_tokens = output.numel() / hidden_size;
  const int topk = input.size(1);

  dim3 grid(num_tokens);
  // Cap the block at 1024 threads (hardware limit); the kernel's stride
  // loop covers hidden_size > 1024.
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  switch (topk) {
    case 2:
      DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
        moe_sum_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
            hidden_size);
      });
      break;

    case 3:
      DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
        moe_sum_kernel<scalar_t, 3><<<grid, block, 0, stream>>>(
            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
            hidden_size);
      });
      break;

    case 4:
      DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
        moe_sum_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
            hidden_size);
      });
      break;

    default:
      // No unrolled specialization for this topk; use ATen's reduction.
      at::sum_out(output, input, 1);
      break;
  }
}