/**
 * @NOTE: This file is adapted from
 * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py
 * We:
 * 1. adapt it from tilelang to pure CUDA
 * 2. optimize the performance a little
 * 3. fix a potential illegal memory access
 */
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <cstddef>
#include <cstdint>
#include <optional>

namespace {

constexpr int TopK = 2048;
constexpr int kThreadsPerBlock = 1024;
constexpr size_t kSmem = 32 * 1024 * sizeof(uint32_t);  // 128KB

struct FastTopKParams {
  const float* __restrict__ input;  // [B, input_stride]
  int32_t* __restrict__ indices;    // [B, TopK]
  int32_t* __restrict__ lengths;    // [B]
  int64_t input_stride;
};

// when length <= TopK, we can directly write the indices
__device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) {
  const auto tid = threadIdx.x;
  for (int i = tid; i < TopK; i += kThreadsPerBlock) {
    indice[i] = (i < length) ? i : -1;
  }
}

// keep the first `length` entries, set others to -1
__device__ void naive_topk_transform(
    const float* __restrict__ score,
    int32_t length,
    int32_t* __restrict__ dst_page_table,
    const int32_t* __restrict__ src_page_table) {
  const auto tid = threadIdx.x;
  for (auto i = tid; i < TopK; i += kThreadsPerBlock) {
    dst_page_table[i] = (i < length) ? src_page_table[i] : -1;
  }
}

// coarse 8-bit radix key: top byte of the order-preserving fp16 encoding of x
__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t {
  __half h = __float2half_rn(x);
  uint16_t bits = __half_as_ushort(h);
  uint16_t key = (bits & 0x8000) ? static_cast<uint16_t>(~bits) : static_cast<uint16_t>(bits | 0x8000);
  return static_cast<uint8_t>(key >> 8);
}

// full 32-bit radix key: unsigned order of the key matches the float order (for non-NaN inputs)
__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t {
  uint32_t bits = __float_as_uint(x);
  return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
}
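// Worked example of the order-preserving key (a hedged sketch, not used by the kernels):
// +1.0f has bits 0x3F800000 and maps to key 0xBF800000, while -1.0f has bits 0xBF800000
// and maps to key 0x407FFFFF, so key(-1.0f) < key(+1.0f) just as -1.0f < +1.0f.
// `host_order_key` below is a hypothetical host-side mirror of convert_to_uint32.
#if 0
static inline uint32_t host_order_key(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));  // bit-cast; requires <cstring> at file scope
  return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
}
#endif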
__device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int length) {
  // An optimized topk kernel copied from the tilelang kernel.
  // We assume length > TopK here, or it will crash.
  int topk = TopK;
  constexpr auto BLOCK_SIZE = 1024;
  constexpr auto RADIX = 256;
  constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int));

  alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128];
  alignas(128) __shared__ int s_counter;
  alignas(128) __shared__ int s_threshold_bin_id;
  alignas(128) __shared__ int s_num_input[2];
  auto& s_histogram = s_histogram_buf[0];
  // allocate for two rounds
  extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE];

  const int tx = threadIdx.x;

  // stage 1: 8bit coarse histogram
  if (tx < RADIX + 1) s_histogram[tx] = 0;
  __syncthreads();
  for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
    const auto bin = convert_to_uint8(input[idx]);
    ::atomicAdd(&s_histogram[bin], 1);
  }
  __syncthreads();

  // inclusive suffix sum: afterwards s_histogram[i] holds the number of keys >= i
  const auto run_cumsum = [&] {
#pragma unroll 8
    for (int i = 0; i < 8; ++i) {
      static_assert(1 << 8 == RADIX);
      if (C10_LIKELY(tx < RADIX)) {
        const auto j = 1 << i;
        const auto k = i & 1;
        auto value = s_histogram_buf[k][tx];
        if (tx < RADIX - j) {
          value += s_histogram_buf[k][tx + j];
        }
        s_histogram_buf[k ^ 1][tx] = value;
      }
      __syncthreads();
    }
  };

  run_cumsum();
  // threshold bin: counting from the largest key downward, this is where the count first exceeds topk
  if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) {
    s_threshold_bin_id = tx;
    s_num_input[0] = 0;
    s_counter = 0;
  }
  __syncthreads();

  const auto threshold_bin = s_threshold_bin_id;
  topk -= s_histogram[threshold_bin + 1];
  if (topk == 0) {
    // everything above the threshold bin is exactly the top-k set
    for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
      const auto bin = static_cast<int>(convert_to_uint8(input[idx]));
      if (bin > threshold_bin) {
        const auto pos = ::atomicAdd(&s_counter, 1);
        index[pos] = idx;
      }
    }
    __syncthreads();
    return;
  } else {
    __syncthreads();
    if (tx < RADIX + 1) {
      s_histogram[tx] = 0;
    }
    __syncthreads();
    for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
      const auto raw_input = input[idx];
      const auto bin = static_cast<int>(convert_to_uint8(raw_input));
      if (bin > threshold_bin) {
        const auto pos = ::atomicAdd(&s_counter, 1);
        index[pos] = idx;
      } else if (bin == threshold_bin) {
        const auto pos = ::atomicAdd(&s_num_input[0], 1);
        /// NOTE: (dark) fuse the histogram computation here
        if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) {
          s_input_idx[0][pos] = idx;
          const auto bin = convert_to_uint32(raw_input);
          const auto sub_bin = (bin >> 24) & 0xFF;
          ::atomicAdd(&s_histogram[sub_bin], 1);
        }
      }
    }
    __syncthreads();
  }

  // stage 2: refine with 8bit radix passes
#pragma unroll 4
  for (int round = 0; round < 4; ++round) {
    __shared__ int s_last_remain;
    const auto r_idx = round % 2;
    // clip here to prevent overflow
    const auto _raw_num_input = s_num_input[r_idx];
    const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE);
    run_cumsum();
    if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) {
      s_threshold_bin_id = tx;
      s_num_input[r_idx ^ 1] = 0;
      s_last_remain = topk - s_histogram[tx + 1];
    }
    __syncthreads();
    const auto threshold_bin = s_threshold_bin_id;
    topk -= s_histogram[threshold_bin + 1];
    if (topk == 0) {
      for (int i = tx; i < num_input; i += BLOCK_SIZE) {
        const auto idx = s_input_idx[r_idx][i];
        const auto offset = 24 - round * 8;
        const auto bin = (convert_to_uint32(input[idx]) >> offset) & 0xFF;
        if (bin > threshold_bin) {
          const auto pos = ::atomicAdd(&s_counter, 1);
          index[pos] = idx;
        }
      }
      __syncthreads();
      break;
    } else {
      __syncthreads();
      if (tx < RADIX + 1) {
        s_histogram[tx] = 0;
      }
      __syncthreads();
      for (int i = tx; i < num_input; i += BLOCK_SIZE) {
        const auto idx = s_input_idx[r_idx][i];
        const auto raw_input = input[idx];
        const auto offset = 24 - round * 8;
        const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF;
        if (bin > threshold_bin) {
          const auto pos = ::atomicAdd(&s_counter, 1);
          index[pos] = idx;
        } else if (bin == threshold_bin) {
          if (round == 3) {
            // last pass: take only as many threshold-bin elements as still needed, filling from the back
            const auto pos = ::atomicAdd(&s_last_remain, -1);
            if (pos > 0) {
              index[TopK - pos] = idx;
            }
          } else {
            const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1);
            if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) {
              /// NOTE: (dark) fuse the histogram computation here
              s_input_idx[r_idx ^ 1][pos] = idx;
              const auto bin = convert_to_uint32(raw_input);
              const auto sub_bin = (bin >> (offset - 8)) & 0xFF;
              ::atomicAdd(&s_histogram[sub_bin], 1);
            }
          }
        }
      }
      __syncthreads();
    }
  }
}
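// Reference semantics of fast_topk_cuda_tl (a hedged sketch, not compiled here): for a row
// with length > TopK it fills index[0, TopK) with the positions of the TopK largest scores,
// in unspecified order; the radix keys preserve float ordering, so the result agrees with a
// sort by score except that ties among equal scores at the cut may be broken differently.
// `reference_topk_indices` is a hypothetical host-side equivalent, not part of this file.
#if 0
static std::vector<int> reference_topk_indices(const std::vector<float>& score, int k) {
  std::vector<int> idx(score.size());    // requires <algorithm>, <numeric>, <vector>
  std::iota(idx.begin(), idx.end(), 0);  // candidate indices 0..n-1
  std::nth_element(idx.begin(), idx.begin() + k, idx.end(), [&](int a, int b) { return score[a] > score[b]; });
  idx.resize(k);                         // keep the k best; order is unspecified
  return idx;
}
#endif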
__global__ __launch_bounds__(kThreadsPerBlock)  // topk
void topk_kernel(const FastTopKParams params) {
  const auto& [input, indices, lengths, input_stride] = params;
  const auto bid = static_cast<int64_t>(blockIdx.x);
  const auto length = lengths[bid];
  const auto indice = indices + bid * TopK;
  const auto score = input + bid * input_stride;
  if (length <= TopK) {
    return naive_topk_cuda(score, indice, length);
  } else {
    return fast_topk_cuda_tl(score, indice, length);
  }
}

__global__ __launch_bounds__(kThreadsPerBlock)  // decode
void topk_transform_decode_kernel(
    const FastTopKParams params,
    int32_t* __restrict__ dst_page_table,
    const int32_t* __restrict__ src_page_table,
    const int64_t src_stride) {
  const auto& [input, _, lengths, input_stride] = params;
  const auto bid = static_cast<int64_t>(blockIdx.x);
  const auto tid = threadIdx.x;
  const auto length = lengths[bid];
  const auto src_page_entry = src_page_table + bid * src_stride;
  const auto dst_page_entry = dst_page_table + bid * TopK;
  const auto score = input + bid * input_stride;
  if (length <= TopK) {
    return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
  } else {
    __shared__ int s_indices[TopK];
    fast_topk_cuda_tl(score, s_indices, length);
    // copy src[s_indices] to dst, we manually unroll here
    static_assert(TopK % kThreadsPerBlock == 0);
    static_assert(TopK / kThreadsPerBlock == 2);
    const auto idx_0 = tid;
    const auto pos_0 = s_indices[idx_0];
    dst_page_entry[idx_0] = src_page_entry[pos_0];
    const auto idx_1 = tid + kThreadsPerBlock;
    const auto pos_1 = s_indices[idx_1];
    dst_page_entry[idx_1] = src_page_entry[pos_1];
  }
}

__global__ __launch_bounds__(kThreadsPerBlock)  // prefill
void topk_transform_prefill_kernel(
    const FastTopKParams params,
    int32_t* __restrict__ dst_page_table,
    const int32_t* __restrict__ src_page_table,
    const int64_t src_stride,
    const int32_t* __restrict__ cu_seqlens_q,
    const int64_t prefill_bs) {
  const auto& [input, _, lengths, input_stride] = params;
  const auto bid = static_cast<int64_t>(blockIdx.x);
  const auto tid = threadIdx.x;
  const auto length = lengths[bid];
  const auto dst_page_entry = dst_page_table + bid * TopK;
  const auto score = input + bid * input_stride;

  /// NOTE: prefill bs is usually small, so a simple lookup loop is enough here.
  /// We ensure that the last cu_seqlens entry equals the number of blocks launched.
  __shared__ const int32_t* s_src_page_entry;
  if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) {
    if (tid < prefill_bs) {
      if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) {
        s_src_page_entry = src_page_table + tid * src_stride;
      }
    }
  } else {
    for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) {
      if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) {
        s_src_page_entry = src_page_table + i * src_stride;
      }
    }
  }
  __syncthreads();

  const auto src_page_entry = s_src_page_entry;
  if (length <= TopK) {
    return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
  } else {
    __shared__ int s_indices[TopK];
    fast_topk_cuda_tl(score, s_indices, length);
    // copy src[s_indices] to dst, we manually unroll here
    static_assert(TopK % kThreadsPerBlock == 0);
    static_assert(TopK / kThreadsPerBlock == 2);
    const auto idx_0 = tid;
    const auto pos_0 = s_indices[idx_0];
    dst_page_entry[idx_0] = src_page_entry[pos_0];
    const auto idx_1 = tid + kThreadsPerBlock;
    const auto pos_1 = s_indices[idx_1];
    dst_page_entry[idx_1] = src_page_entry[pos_1];
  }
}

auto get_params(at::Tensor score, at::Tensor lengths, std::optional<at::Tensor> indices_opt = std::nullopt)
    -> FastTopKParams {
  const auto B = score.size(0);
  TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1);
  TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous());
  TORCH_CHECK(lengths.size(0) == B);
  int32_t* indices_data_ptr = nullptr;
  if (indices_opt.has_value()) {
    const auto& indices = indices_opt.value();
    TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous());
    TORCH_CHECK(indices.size(0) == B);
    TORCH_CHECK(indices.size(1) == TopK);
    indices_data_ptr = indices.data_ptr<int32_t>();
  }
  return FastTopKParams{
      .input = score.data_ptr<float>(),
      .indices = indices_data_ptr,
      .lengths = lengths.data_ptr<int32_t>(),
      .input_stride = score.stride(0),
  };
}

// opt in to the large dynamic shared-memory carveout for a kernel, at most once per kernel
template <auto f, int max_dynamic_smem = kSmem>
void setup_kernel_smem_once() {
  [[maybe_unused]] static const auto result = [] {
    return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem);
  }();
  TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result));
}

}  // namespace

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
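// Expected inputs for the interfaces below (as implied by get_params and the shape checks):
//   score:          [B, L]            float32, innermost stride 1
//   indices:        [B, TopK]         int32, contiguous (fast_topk_interface only)
//   lengths:        [B]               int32, contiguous; lengths[b] is the valid prefix length of row b
//   dst_page_table: [B, TopK]         int32, contiguous
//   src_page_table: [prefill_bs, *]   int32, innermost stride 1
//   cu_seqlens_q:   [prefill_bs + 1]  int32, contiguous
// Rows with lengths[b] <= TopK are passed through with -1 padding; longer rows take the radix top-k path.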
void fast_topk_interface(at::Tensor score, at::Tensor indices, at::Tensor lengths) {
  CHECK_CUDA(score);
  CHECK_CUDA(indices);
  CHECK_CUDA(lengths);
  const auto params = get_params(score, lengths, indices);
  const auto B = score.size(0);
  const auto stream = at::cuda::getCurrentCUDAStream().stream();
  const auto grid = dim3{static_cast<uint32_t>(B)};
  const auto block = dim3{kThreadsPerBlock};
  setup_kernel_smem_once<topk_kernel>();
  topk_kernel<<<grid, block, kSmem, stream>>>(params);
  const auto result = cudaGetLastError();
  TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
}

void fast_topk_transform_interface(
    at::Tensor score,
    at::Tensor lengths,
    at::Tensor dst_page_table,
    at::Tensor src_page_table,
    at::Tensor cu_seqlens_q) {
  CHECK_CUDA(score);
  CHECK_CUDA(lengths);
  CHECK_CUDA(dst_page_table);
  CHECK_CUDA(src_page_table);
  CHECK_CUDA(cu_seqlens_q);
  const auto params = get_params(score, lengths);
  const auto B = score.size(0);
  TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous());
  TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1);
  TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous());
  const auto prefill_bs = cu_seqlens_q.size(0) - 1;
  TORCH_CHECK(dst_page_table.size(0) == B);
  TORCH_CHECK(dst_page_table.size(1) == TopK);
  TORCH_CHECK(src_page_table.size(0) == prefill_bs);
  TORCH_CHECK(prefill_bs <= B);  // prefill_bs should be no larger than the expanded bs

  // launch kernel
  const auto stream = at::cuda::getCurrentCUDAStream().stream();
  const auto grid = dim3{static_cast<uint32_t>(B)};
  const auto block = dim3{kThreadsPerBlock};
  const auto src_stride = src_page_table.stride(0);

  // dispatch to decode or prefill
  const auto is_decode = (prefill_bs == B);
  if (is_decode) {
    setup_kernel_smem_once<topk_transform_decode_kernel>();
    topk_transform_decode_kernel<<<grid, block, kSmem, stream>>>(
        params, dst_page_table.data_ptr<int32_t>(), src_page_table.data_ptr<int32_t>(), src_stride);
  } else {
    setup_kernel_smem_once<topk_transform_prefill_kernel>();
    topk_transform_prefill_kernel<<<grid, block, kSmem, stream>>>(
        params,
        dst_page_table.data_ptr<int32_t>(),
        src_page_table.data_ptr<int32_t>(),
        src_stride,
        cu_seqlens_q.data_ptr<int32_t>(),
        prefill_bs);
  }
  const auto result = cudaGetLastError();
  TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
}
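// Example registration (a hedged sketch, not part of this file): one way these entry points
// could be exposed from a PyTorch C++ extension. The op names "fast_topk" and
// "fast_topk_transform" are hypothetical; the surrounding project may bind them differently.
#if 0
#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fast_topk", &fast_topk_interface, "Write the top-2048 indices per row (CUDA)");
  m.def("fast_topk_transform", &fast_topk_transform_interface, "Top-2048 page-table gather (CUDA)");
}
#endif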