// SPDX-License-Identifier: MIT // // ============================================================================ // TOP-K KERNEL IMPLEMENTATION // ============================================================================ // // This file implements three adaptive strategies for efficient Top-K selection: // // 1. BlockTopkFilter - Ballot-based filtering for large, sparse datasets // - Uses __ballot() to identify and compact passing candidates // - Accumulates filtered candidates in local data share staging buffer // - Ideal when most values don't make it into Top-K // // 2. BlockTopkSort - Bitonic sort/merge for moderate datasets // - Loads capacity-sized chunks, sorts, and merges using bitonic properties // - Pure register-based, no local data share overhead // - Ideal when most values need consideration // // 3. BlockTopkMerge - Efficient merging of pre-sorted chunks // - Assumes input is already sorted in k-sized chunks // - Used for multi-block reduction phase // // GPU Optimizations Used: // - DPP (Data Parallel Primitives) for small-stride shuffles (≤8) // - Wave intrinsics (__ballot, __popcll, __shfl) for parallel operations // - Buffer load instructions for coalesced memory access // - Bitonic sort/merge leveraging wave-level parallelism // - Med3 intrinsics for branchless min/max operations // // See detailed examples and explanations inline with each strategy class. // ============================================================================ #include #include #include #include #include #include #include "dispatch_utils.h" #include "py_itfs_common.h" #include "warp_sort.h" #include "quick_all_reduce_base.h" #define HIP_CHECK(val) \ { \ utils::hip_check_((val), __FILE__, __LINE__); \ } // Forward declaration of topk_per_row kernel from topk_per_row_kernels.cu namespace aiter { // Phase enum for distinguishing prefill vs decode paths enum class Phase { Prefill, Decode, }; template __global__ void topk_per_row(const float* logits, const int* rowStarts, const int* rowEnds, int* outIndices, int stride0, int stride1, int rowOffset); // Forward declaration of standalone_stable_radix_11bits from topk_per_row_kernels.cu template void standalone_stable_radix_11bits(void* buf, size_t& buf_size, T const* in, int batch_size, int64_t len, IdxT* rowStarts, IdxT* rowEnds, IdxT k, T* out, IdxT* out_idx, bool greater, hipStream_t stream, int next_n = 0); } // namespace aiter // Forward declaration of workspace size calculation function (at global scope) template int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows, int32_t stride0); extern template int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows, int32_t stride0); // Forward declaration of helper function to call topk_per_row kernel template void topk_per_row_kernel_launcher(const float* in, const IdxT* rowStarts, const IdxT* rowEnds, IdxT* out_idx, const float* out, int batch_size, int stride0, int stride1, int k, hipStream_t stream); // Helper function to determine if topk_per_row kernel should be used // Based on: n + K log²K ≥ 3 × Factor(n) × n // where Factor(n) = 1/3 + 1.6/(log₂(n) - 9.5) // Simplifies to: K log²K ≥ 4.8n/(log₂(n) - 9.5) // TODO: We need to confirm whether, when n <= 2048, we might choose // radix sort because the denominator becomes very small; does that // still yield the best performance? template __forceinline__ __host__ bool should_use_topk_radix(IdxT len, IdxT k) { const double n = static_cast(len); const double K = static_cast(k); if(K <= 1.0) { return false; } const double log_n = std::log2(n); const double denom = std::max(0.0001, log_n - 9.5); const double rhs = (4.8 * n) / denom; const double log_k = std::log2(K); const double lhs = K * log_k * log_k; return lhs >= rhs; } // Gather kernel to extract values based on indices (uniform length) template __global__ void gather_topk_values_kernel(const T* __restrict__ in, const IdxT* __restrict__ indices, T* __restrict__ out, int batch_size, int len, int k) { int batch_id = blockIdx.x; if(batch_id >= batch_size) return; const T* in_row = in + batch_id * len; const IdxT* idx_row = indices + batch_id * k; T* out_row = out + batch_id * k; for(int i = threadIdx.x; i < k; i += blockDim.x) { IdxT idx = idx_row[i]; if(idx >= 0 && idx < len) { out_row[i] = in_row[idx]; } } } // Gather kernel for variable length with strides template __global__ void gather_topk_values_strided_kernel(const T* __restrict__ in, const IdxT* __restrict__ indices, T* __restrict__ out, const IdxT* __restrict__ rowStarts, int batch_size, int stride0, int stride1, int k) { int batch_id = blockIdx.x; if(batch_id >= batch_size) return; IdxT start = rowStarts[batch_id]; const T* in_row = in + batch_id * stride0; const IdxT* idx_row = indices + batch_id * k; T* out_row = out + batch_id * k; for(int i = threadIdx.x; i < k; i += blockDim.x) { IdxT idx = idx_row[i]; if(idx >= 0) { // idx is relative to rowStart, need to add start and apply stride1 out_row[i] = in_row[(start + idx) * stride1]; } } } namespace topk { // ============================================================================ // TYPE TRAITS FOR DATA/COMPUTE TYPE SEPARATION // ============================================================================ // // Design Philosophy: // - DataType (DataT): The storage/I/O type for memory operations // - ComputeType (ComputeT): The type used for internal computations // // Mapping: // - fp16, bf16, float -> compute as float (better precision, consistent ops) // - int -> compute as int // // This separation allows: // 1. Memory-efficient storage with compact types (fp16, bf16) // 2. High-precision computation with float // 3. Easy extension for new types (e.g., fp8, int8) // // Usage: // using ComputeT = compute_t; // ComputeT val = type_convert::to_compute(data_val); // DataT result = type_convert::to_data(compute_val); // ============================================================================ namespace type_traits { // Primary template: maps DataType -> ComputeType template struct ComputeTypeTraits { static_assert(sizeof(DataT) == 0, "ComputeTypeTraits not specialized for this type. " "Supported types: _Float16, __bf16, float, int"); }; // Specializations for floating-point types -> float template <> struct ComputeTypeTraits<_Float16> { using type = float; }; template <> struct ComputeTypeTraits { using type = float; }; template <> struct ComputeTypeTraits<__bf16> { using type = float; }; template <> struct ComputeTypeTraits { using type = float; }; // Specialization for integer types -> int template <> struct ComputeTypeTraits { using type = int; }; // Convenience alias template using compute_t = typename ComputeTypeTraits::type; } // namespace type_traits // Bring compute_t into topk namespace for convenience using type_traits::compute_t; // ============================================================================ // TYPE CONVERSION UTILITIES // ============================================================================ namespace type_convert { // Convert from DataType to ComputeType template __device__ __host__ __forceinline__ type_traits::compute_t to_compute(DataT val) { return static_cast>(val); } // Convert from ComputeType to DataType template __device__ __host__ __forceinline__ DataT to_data(type_traits::compute_t val) { return static_cast(val); } } // namespace type_convert namespace utils { // Supported types (for validation) template struct is_supported_type { static constexpr bool value = std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v; }; template inline constexpr bool is_supported_type_v = is_supported_type::value; class HipException : public std::runtime_error { public: explicit HipException(const std::string& what) : runtime_error(what) {} }; inline void hip_check_(hipError_t val, const char* file, int line) { if(val != hipSuccess) { throw HipException(std::string(file) + ":" + std::to_string(line) + ": HIP error " + std::to_string(val) + ": " + hipGetErrorString(val)); } } /** * @brief Rounds a value up to the nearest multiple of a given number. * * This implementation uses integer arithmetic and works for any multiple, * not just powers of two. * * @tparam Multiple The multiple to round up to. * @tparam T The integer type of the value. * @param value The value to round up. * @return The smallest multiple of `Multiple` that is greater than or equal to `value`. */ template __inline__ __host__ __device__ constexpr T round_up_to_multiple_of(T value) { if(value == 0) { return 0; } static_assert(Multiple > 0, "Multiple must be positive."); return ((value - 1) / Multiple + 1) * Multiple; } /** * @brief Rounds a value up to the nearest multiple of a given number. * * This implementation uses integer arithmetic and works for any multiple, * not just powers of two. * * @tparam T The integer type of the value. * @param value The value to round up. * @param Multiple The multiple to round up to. * @return The smallest multiple of `Multiple` that is greater than or equal to `value`. */ template __inline__ __host__ __device__ constexpr T round_up_to_multiple_of(T value, size_t multiple) { return value > 0 ? ((value - 1) / multiple + 1) * multiple : 0; } /** * @brief Checks if an integer is a power of two. * * This uses the classic and highly efficient bitwise trick. * * @tparam T An unsigned integer type. * @param value The value to check. * @return True if `value` is a power of two, false otherwise. */ template __inline__ __host__ __device__ constexpr bool is_power_of_2(T value) { // static_assert(std::is_unsigned::value, "is_power_of_2 works best with unsigned types."); return (value && !(value & (value - 1))); } /** * @brief Calculates the smallest power of two not less than the given value. * * This function is also known as "ceil to power of 2". It uses a fast, * non-recursive bit-twiddling algorithm. * * @tparam T An unsigned integer type. * @param value The value to round up. * @return The smallest power of two >= `value`. Returns 1 for an input of 0. */ template __inline__ __host__ __device__ constexpr T ceil_to_power_of_2(T value) { // static_assert(std::is_unsigned::value, "ceil_to_power_of_2 works best with unsigned // types."); if(value <= 1) { return 1; } // A fast bit-twiddling algorithm to find the next power of two. // It works by smearing the highest set bit to all lower bits. T v = value - 1; // The number of shifts depends on the type size. We can be exhaustive. v |= v >> 1; v |= v >> 2; v |= v >> 4; if constexpr(sizeof(T) >= 2) v |= v >> 8; if constexpr(sizeof(T) >= 4) v |= v >> 16; if constexpr(sizeof(T) >= 8) v |= v >> 32; return v + 1; } /** * @brief Calculates the integer base-2 logarithm of a number, rounded down. * * This is a portable, recursive constexpr implementation. For performance-critical * host code, compiler intrinsics like `__builtin_clz` or C++20's `` * header are often faster. * * @tparam T An integer type. * @param n The input number. * @param p Internal counter for recursion. * @return The value of floor(log2(n)). */ template __inline__ __host__ __device__ constexpr int integer_log2(T n, int p = 0) { return (n <= 1) ? p : integer_log2(n / 2, p + 1); } __inline__ __host__ __device__ constexpr int calc_capacity(int k) { int capacity = utils::ceil_to_power_of_2(k); return (capacity < ck_tile::get_warp_size()) ? ck_tile::get_warp_size() : capacity; } } // namespace utils namespace numeric { // ============================================================================ // BOUNDS AND SENTINEL VALUES // ============================================================================ // These functions now work with ComputeType for internal operations. // The sentinel values are defined in ComputeType space (float for floating-point // DataTypes, int for integer DataTypes). // ============================================================================ /** * @brief Gets the absolute lowest possible value for a compute type. * * Uses -infinity for floating-point compute types, and the lowest finite * value for integer compute types. * * @tparam ComputeT The compute type (float or int). */ template __inline__ __device__ __host__ constexpr ComputeT get_lower_bound() { if constexpr(std::is_same_v) { return -std::numeric_limits::infinity(); } else if constexpr(std::is_same_v) { return std::numeric_limits::lowest(); } else { static_assert(sizeof(ComputeT) == 0, "Unsupported compute type"); __builtin_unreachable(); } } /** * @brief Gets the absolute highest possible value for a compute type. * * Uses +infinity for floating-point compute types, and the maximum finite * value for integer compute types. * * @tparam ComputeT The compute type (float or int). */ template __inline__ __device__ __host__ constexpr ComputeT get_upper_bound() { if constexpr(std::is_same_v) { return std::numeric_limits::infinity(); } else if constexpr(std::is_same_v) { return std::numeric_limits::max(); } else { static_assert(sizeof(ComputeT) == 0, "Unsupported compute type"); __builtin_unreachable(); } } /** * @brief Gets a sentinel value for a search algorithm (e.g., Top-K). * * The sentinel is defined in ComputeType space. For finding the largest values, * we use the lowest possible value as sentinel (so any real value will be preferred). * For finding the smallest values, we use the highest possible value. * * @tparam FindLargest If true, returns lowest value. If false, returns highest value. * @tparam ComputeT The compute type (float or int). */ template __inline__ __device__ __host__ constexpr ComputeT get_sentinel_value() { if constexpr(FindLargest) { return get_lower_bound(); } else { return get_upper_bound(); } } /** * @brief Gets sentinel value based on DataType (converts to appropriate ComputeType). * * This is a convenience overload that deduces the ComputeType from DataType. * * @tparam FindLargest If true, returns lowest value. If false, returns highest value. * @tparam DataT The data type (fp16, bf16, float, int). */ template __inline__ __device__ __host__ constexpr compute_t get_sentinel_value_for_data() { return get_sentinel_value>(); } /** * @brief A generic comparison function for search algorithms. * * Compares `val` against `baseline` according to the search direction * specified by the `FindLargest` template parameter. * Works with ComputeType values. * * @tparam FindLargest If true, checks if `val` is greater than `baseline`. * If false, checks if `val` is less than `baseline`. * @tparam ComputeT The compute type (float or int). * @param val The new value to check. * @param baseline The current best value. * @return True if `val` is "preferred" over `baseline`. */ template __device__ __host__ __forceinline__ constexpr bool is_preferred(ComputeT val, ComputeT baseline) { if constexpr(FindLargest) { return val > baseline; } else { return val < baseline; } } } // namespace numeric namespace sorting { // ============================================================================ // SORTING OPERATIONS (Work with ComputeType) // ============================================================================ // All sorting operations in this namespace work with ComputeType values. // The template parameter T should be the compute type (float or int). // The idxT parameter is the index type (typically int32_t). // // The sorting algorithms use: // - DPP (Data Parallel Primitives) for small-stride shuffles (≤8) // - Wave intrinsics (__ballot, __popcll, __shfl) for larger operations // - Bitonic sort/merge for efficient parallel sorting // ============================================================================ template struct BitonicMerge { // input should be a bitonic sequence, and sort it to be a monotonic sequence __device__ static void merge(T* __restrict__ val_arr, idxT* __restrict__ idx_arr) { static_assert(utils::is_power_of_2(size)); static_assert(size >= 2 * ck_tile::get_warp_size()); constexpr int arr_len = size / ck_tile::get_warp_size(); constexpr int stride = arr_len / 2; for(int i = 0; i < stride; ++i) { const int other_i = i + stride; T& val = val_arr[i]; T& other_val = val_arr[other_i]; if((val > other_val && ascending) || (val < other_val && !ascending)) { T tmp = val; val = other_val; other_val = tmp; idxT tmp2 = idx_arr[i]; idx_arr[i] = idx_arr[other_i]; idx_arr[other_i] = tmp2; } } BitonicMerge::merge(val_arr, idx_arr); BitonicMerge::merge(val_arr + arr_len / 2, idx_arr + arr_len / 2); } }; template struct BitonicSort { __device__ static void sort(T* __restrict__ val_arr, idxT* __restrict__ idx_arr) { static_assert(utils::is_power_of_2(size)); static_assert(size >= 2 * ck_tile::get_warp_size()); constexpr int arr_len = size / ck_tile::get_warp_size(); BitonicSort::sort(val_arr, idx_arr); BitonicSort::sort(val_arr + arr_len / 2, idx_arr + arr_len / 2); BitonicMerge::merge(val_arr, idx_arr); } }; template __device__ __forceinline__ idxT select_idx( const idxT& idx_a, const idxT& idx_b, const T& val_a, const T& val_b, const T& selected_val) { return (selected_val == val_a) ? idx_a : idx_b; } template struct StrideToDPP { static_assert(stride == 1 || stride == 2 || stride == 4 || stride == 8, "DPP only supports stride 1 ,2, 4, 8"); }; template <> struct StrideToDPP<1> { static constexpr int dpp_i = 0xb1; // quad_perm: [1,0,3,2] }; template <> struct StrideToDPP<2> { static constexpr int dpp_i = 0x4e; // quad_perm: [2,3,0,1] }; template <> struct StrideToDPP<4> { static constexpr int dpp_i_shl = 260; static constexpr int bank_mask_shl = 0b0101; static constexpr int dpp_i_shr = 276; static constexpr int bank_mask_shr = 0b1010; }; template <> struct StrideToDPP<8> { static constexpr int dpp_i_shl = 264; static constexpr int bank_mask_shl = 0b0011; static constexpr int dpp_i_shr = 280; static constexpr int bank_mask_shr = 0b1100; }; template __forceinline__ __device__ T mov_dpp(T x) { static_assert(sizeof(T) == 4 || sizeof(T) == 2, "mov_dpp only supports 32-bit and 16-bit types."); constexpr int dpp_i = StrideToDPP::dpp_i; constexpr int row_mask = 0xf; constexpr int bank_mask = 0xf; constexpr bool bound_ctrl = true; // Returns own value if source is out of bounds if constexpr(sizeof(T) == 4) { return aiter::mov_dpp_(x, ck_tile::number(), ck_tile::number(), ck_tile::number(), ck_tile::bool_constant()); } else if constexpr(sizeof(T) == 2) { unsigned short x_u16 = __builtin_bit_cast(unsigned short, x); unsigned int x_u32 = x_u16; unsigned int result_u32 = __builtin_amdgcn_mov_dpp(x_u32, dpp_i, row_mask, bank_mask, bound_ctrl); unsigned short result_u16 = static_cast(result_u32); return __builtin_bit_cast(T, result_u16); } } template __forceinline__ __device__ T upd_dpp(const T& old, T x) { static_assert(sizeof(T) == 4 || sizeof(T) == 2, "upd_dpp only supports 32-bit and 16-bit types."); constexpr int dpp_i = shl ? StrideToDPP::dpp_i_shl : StrideToDPP::dpp_i_shr; constexpr int row_mask = 0xf; constexpr int bank_mask = shl ? StrideToDPP::bank_mask_shl : StrideToDPP::bank_mask_shr; constexpr bool bound_ctrl = true; if constexpr(sizeof(T) == 4) { return aiter::upd_dpp_(old, x, ck_tile::number(), ck_tile::number(), ck_tile::number(), ck_tile::bool_constant()); } else if constexpr(sizeof(T) == 2) { unsigned int old_u32 = __builtin_bit_cast(unsigned short, old); unsigned int x_u32 = __builtin_bit_cast(unsigned short, x); unsigned int result_u32 = __builtin_amdgcn_update_dpp(old_u32, x_u32, dpp_i, row_mask, bank_mask, bound_ctrl); unsigned short result_u16 = static_cast(result_u32); return __builtin_bit_cast(T, result_u16); } } // Helper function to perform shuffle based on type template __forceinline__ __device__ T shfl_xor(T val, int stride) { if constexpr(sizeof(T) == 4) { return __builtin_bit_cast(T, __shfl_xor(__builtin_bit_cast(int, val), stride)); } else if constexpr(sizeof(T) == 8) { return __builtin_bit_cast(T, __shfl_xor(__builtin_bit_cast(long long, val), stride)); } else if constexpr(sizeof(T) == 2) { // 16-bit types (_Float16, __bf16) unsigned int val_u32 = __builtin_bit_cast(unsigned short, val); unsigned int result_u32 = __shfl_xor(val_u32, stride); unsigned short result_u16 = static_cast(result_u32); return __builtin_bit_cast(T, result_u16); } else { static_assert(sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8, "shfl_xor only supports 16-bit, 32-bit, and 64-bit types."); __builtin_unreachable(); } } /** * @brief Gets guard value for bitonic sort comparisons. * * This function returns boundary values used in bitonic sorting. * Works with ComputeType (float or int). * * @tparam ComputeT The compute type (float or int). * @param x If true, returns lowest value; if false, returns highest value. */ template __forceinline__ __device__ constexpr ComputeT get_guard(const bool x) { if constexpr(std::is_same_v) { return x ? -std::numeric_limits::infinity() : std::numeric_limits::infinity(); } else if constexpr(std::is_same_v) { return x ? std::numeric_limits::lowest() : std::numeric_limits::max(); } else { static_assert(sizeof(ComputeT) == 0, "get_guard only supports float and int compute types"); __builtin_unreachable(); } } // Optimized sort step using DPP for small strides template __forceinline__ __device__ typename std::enable_if<(stride <= 2), void>::type sort_step(T* __restrict__ val_arr, idxT* __restrict__ idx_arr) { const int lane = threadIdx.x & (ck_tile::get_warp_size() - 1); bool reverse = (lane >> stage) & 2; bool is_second = lane & stride; const auto val = *val_arr; const auto idx = *idx_arr; T other = mov_dpp(val); idxT other_idx = mov_dpp(idx); // Use median-of-3 to select the appropriate value T selected_val = aiter::dev_med3_(val, other, get_guard(reverse != is_second)); idxT selected_idx = select_idx(idx, other_idx, val, other, selected_val); *val_arr = selected_val; *idx_arr = selected_idx; } // Optimized sort step using DPP for small strides template __forceinline__ __device__ typename std::enable_if<(stride > 2 && stride <= 8), void>::type sort_step(T* __restrict__ val_arr, idxT* __restrict__ idx_arr) { const int lane = threadIdx.x & (ck_tile::get_warp_size() - 1); bool reverse = (lane >> stage) & 2; bool is_second = lane & stride; const auto val = *val_arr; const auto idx = *idx_arr; #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wuninitialized" T other; other = upd_dpp(other, val); other = upd_dpp(other, val); idxT other_idx; other_idx = upd_dpp(other_idx, idx); other_idx = upd_dpp(other_idx, idx); #pragma clang diagnostic pop // Use median-of-3 to select the appropriate value T selected_val = aiter::dev_med3_(val, other, get_guard(reverse != is_second)); idxT selected_idx = select_idx(idx, other_idx, val, other, selected_val); *val_arr = selected_val; *idx_arr = selected_idx; } // Fallback to shuffle for larger strides template __forceinline__ __device__ typename std::enable_if<(stride > 8), void>::type sort_step(T* __restrict__ val_arr, idxT* __restrict__ idx_arr) { const int lane = threadIdx.x & (ck_tile::get_warp_size() - 1); bool reverse = (lane >> stage) & 2; bool is_second = lane & stride; const auto val = *val_arr; const auto idx = *idx_arr; T other = shfl_xor(val, stride); idxT other_idx = shfl_xor(idx, stride); // Use median-of-3 to select the appropriate value T selected_val = aiter::dev_med3_(val, other, get_guard(reverse != is_second)); idxT selected_idx = select_idx(idx, other_idx, val, other, selected_val); *val_arr = selected_val; *idx_arr = selected_idx; } template struct BitonicSort<64, ascending, T, idxT> { __device__ static void sort(T* __restrict__ val_arr, idxT* __restrict__ idx_arr) { // Stage 0: stride = 1 (DPP optimized) sort_step(val_arr, idx_arr); // Stage 1: stride = 2, 1 (DPP optimized) sort_step(val_arr, idx_arr); sort_step(val_arr, idx_arr); // Stage 2: stride = 4, 2, 1 (DPP optimized) sort_step(val_arr, idx_arr); sort_step(val_arr, idx_arr); sort_step(val_arr, idx_arr); // Stage 3: stride = 8, 4, 2, 1 (DPP optimized) sort_step(val_arr, idx_arr); sort_step(val_arr, idx_arr); sort_step(val_arr, idx_arr); sort_step(val_arr, idx_arr); // Stage 4: stride = 16, 8, 4, 2, 1 sort_step(val_arr, idx_arr); // Uses shuffle sort_step(val_arr, idx_arr); // Uses DPP sort_step(val_arr, idx_arr); // Uses DPP sort_step(val_arr, idx_arr); // Uses DPP sort_step(val_arr, idx_arr); // Uses DPP BitonicMerge<64, ascending, T, idxT>::merge(val_arr, idx_arr); } }; // Optimized merge using DPP for small strides template __forceinline__ __device__ typename std::enable_if<(stride <= 2), void>::type merge_step(T* __restrict__ val_arr, idxT* __restrict__ idx_arr) { const int lane = threadIdx.x & (ck_tile::get_warp_size() - 1); bool is_second = lane & stride; T& val = *val_arr; idxT& idx = *idx_arr; T other = mov_dpp(val); idxT other_idx = mov_dpp(idx); // Use median-of-3 to select the appropriate value T selected_val = aiter::dev_med3_(val, other, get_guard(ascending != is_second)); idxT selected_idx = select_idx(idx, other_idx, val, other, selected_val); val = selected_val; idx = selected_idx; } // Optimized sort step using DPP for small strides template __forceinline__ __device__ typename std::enable_if<(stride > 2 && stride <= 8), void>::type merge_step(T* __restrict__ val_arr, idxT* __restrict__ idx_arr) { const int lane = threadIdx.x & (ck_tile::get_warp_size() - 1); bool is_second = lane & stride; T& val = *val_arr; idxT& idx = *idx_arr; #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wuninitialized" T other; other = upd_dpp(other, val); other = upd_dpp(other, val); idxT other_idx; other_idx = upd_dpp(other_idx, idx); other_idx = upd_dpp(other_idx, idx); #pragma clang diagnostic pop // Use median-of-3 to select the appropriate value T selected_val = aiter::dev_med3_(val, other, get_guard(ascending != is_second)); idxT selected_idx = select_idx(idx, other_idx, val, other, selected_val); val = selected_val; idx = selected_idx; } // Fallback to shuffle for larger strides template __forceinline__ __device__ typename std::enable_if<(stride > 8), void>::type merge_step(T* __restrict__ val_arr, idxT* __restrict__ idx_arr) { const int lane = threadIdx.x & (ck_tile::get_warp_size() - 1); bool is_second = lane & stride; T& val = *val_arr; idxT& idx = *idx_arr; T other = shfl_xor(val, stride); idxT other_idx = shfl_xor(idx, stride); // Use median-of-3 to select the appropriate value T selected_val = aiter::dev_med3_(val, other, get_guard(ascending != is_second)); idxT selected_idx = select_idx(idx, other_idx, val, other, selected_val); val = selected_val; idx = selected_idx; } template struct BitonicMerge<64, ascending, T, idxT> { __device__ static void merge(T* __restrict__ val_arr, idxT* __restrict__ idx_arr) { merge_step(val_arr, idx_arr); // Shuffle merge_step(val_arr, idx_arr); // Shuffle merge_step(val_arr, idx_arr); // DPP merge_step(val_arr, idx_arr); // DPP merge_step(val_arr, idx_arr); // DPP merge_step(val_arr, idx_arr); // DPP } }; } // namespace sorting namespace buffer_load_helpers { constexpr int MAX_CAPACITY = 2048; using int32x4_t = int __attribute__((ext_vector_type(4))); using floatx4_t = float __attribute__((ext_vector_type(4))); using bf16x8_t = __bf16 __attribute__((ext_vector_type(8))); using halfx8_t = _Float16 __attribute__((ext_vector_type(8))); using index_t = uint32_t; __device__ __forceinline__ static int32x4_t asm_buffer_load_dwordx4(int32x4_t srsrc, int32_t voffset, int32_t soffset, int32_t aux) __asm("llvm.amdgcn.raw.buffer.load.v4i32"); template __device__ __forceinline__ VecType buffer_load_dwordx4(int32x4_t srsrc, int32_t voffset, int32_t soffset, int32_t aux) { return __builtin_bit_cast(VecType, asm_buffer_load_dwordx4(srsrc, voffset, soffset, aux)); } } // namespace buffer_load_helpers // --- Wave-Level Priority Selection Primitives (HYGON/HIP Optimized) --- // // THREE STRATEGIES FOR TOP-K SELECTION: // // 1. WaveTopkFilter // - Uses ballot-based filtering to skip irrelevant candidates // - Best for: Large datasets where len_per_wave > capacity × 4 // - Uses local data share for staging // - Example: Finding top 100 from 1 million elements (most filtered out) // // 2. WaveTopkSort // - Processes data in capacity-sized batches with bitonic sort // - Best for: Moderate datasets where len_per_wave ≤ capacity × 4 // - Register-only, no local data share // - Example: Finding top 100 from 10,000 elements // // 3. WaveTopkMerge // - Merges pre-sorted k-sized chunks iteratively // - Best for: Multi-block reduction (merging results from multiple blocks) // - Used in second pass when first pass produces multiple results // - Example: Combining top-100 results from 8 different blocks // // Selection logic: // - Compute len_per_wave based on launch configuration // - If len_per_wave ≤ capacity × 4: Use BlockTopkSort // - If len_per_wave > capacity × 4: Use BlockTopkFilter // - For multi-block reduction: Always use BlockTopkMerge // template struct WaveTopkFilter; template struct WaveTopkSort; template struct WaveTopkMerge; template struct BlockTopkFilter; template struct BlockTopkSort; template struct BlockTopkMerge; // ============================================================================ // WAVE BUFFER (Stores priorities in ComputeType) // ============================================================================ // // WaveBuffer manages per-wave register storage for priority candidates. // Key design: // - DataT: The I/O type for loading/storing data // - ComputeT: The internal type for priorities (float or int) // - Priorities are stored as ComputeType for consistent computation // - Conversion happens at I/O boundaries // // Template parameters: // - capacity: Power-of-2 buffer capacity (>= wave size) // - DataT: Data type for I/O (fp16, bf16, float, int) // - IdxT: Index type (typically int32_t) // ============================================================================ template struct WaveBuffer { using ComputeT = compute_t; static constexpr int slots_per_lane = capacity / ck_tile::get_warp_size(); static_assert(capacity >= ck_tile::get_warp_size() && utils::is_power_of_2(capacity), "Capacity must be power-of-2 and >= wave size"); ComputeT priorities[slots_per_lane]; IdxT positions[slots_per_lane]; int lane_id; IdxT target_count; ComputeT sentinel; __device__ WaveBuffer(IdxT k, ComputeT sentinel_value) : lane_id(threadIdx.x & (ck_tile::get_warp_size() - 1)), target_count(k), sentinel(sentinel_value) { #pragma unroll for(int i = 0; i < slots_per_lane; ++i) { priorities[i] = sentinel; } } __device__ inline void reset_slot(int slot, ComputeT val = {}, IdxT pos = {}) { priorities[slot] = val; positions[slot] = pos; } // Flush results to output buffer // OutT can be DataT (for final output) or ComputeT (for LDS operations) template __device__ inline void flush_results(OutT* __restrict__ out_vals, IdxT* __restrict__ out_indices) const { #pragma unroll for(int i = 0; i < slots_per_lane; ++i) { const IdxT global_slot = i * ck_tile::get_warp_size() + lane_id; if(global_slot < target_count) { out_vals[global_slot] = static_cast(priorities[i]); out_indices[global_slot] = positions[i]; } } } }; // Helper for merging sorted sequences (used by multiple strategies) // Works with ComputeType internally, reads from ComputeType buffers template struct WaveMergeHelper { using ComputeT = compute_t; // Merges a sorted k-element chunk with the buffer's existing Top-K // Input is in ComputeType (from LDS or previous computation) // EXAMPLE (finding Top-4 largest, capacity=64, k=4): // Wave-distributed storage (64 lanes, each lane holds slots_per_lane=1 value): // Lanes 0-3: [80, 85, 90, 95] (current top-4, in ascending order) // Lanes 4-63: [-∞, -∞, ...] (sentinels) // Input chunk: in[start+0]=65, in[start+1]=70, in[start+2]=75, in[start+3]=100 // // Element-wise comparison (reads input in reverse, idx = start + 63 - lane_id): // Lane 0: idx=start+63 (out of range, skip) // ... // Lane 60: idx=start+3, reads 65, compares with -∞ → update to 65 // Lane 61: idx=start+2, reads 70, compares with -∞ → update to 70 // Lane 62: idx=start+1, reads 75, compares with -∞ → update to 75 // Lane 63: idx=start+0, reads 100, compares with -∞ → update to 100 // // After element-wise updates (before merge): // Lanes: [80,85,90,95, -∞,...,-∞, 65,70,75,100] // ↑ lanes 0-3 ↑lanes 4-59 ↑lanes 60-63 // // BitonicMerge (ascending) redistributes across all lanes: // Result: [-∞,...,-∞, 65,70,75,80,85,90,95,100] // ↑lanes 0-55 ↑──── lanes 56-63 ────↑ // // Extract top-k=4 (last 4 in ascending order): // Lanes 60-63 now contain: [85, 90, 95, 100] __device__ static void merge_sorted_range(WaveBuffer& buffer, const ComputeT* __restrict__ in, const IdxT* __restrict__ in_idx, IdxT start) { IdxT idx = start + ck_tile::get_warp_size() - 1 - buffer.lane_id; #pragma unroll for(int i = buffer.slots_per_lane - 1; i >= 0; --i, idx += ck_tile::get_warp_size()) { if(idx < start + buffer.target_count) { ComputeT candidate = in[idx]; if(numeric::is_preferred(candidate, buffer.priorities[i])) { buffer.priorities[i] = candidate; buffer.positions[i] = in_idx[idx]; } } } sorting::BitonicMerge::merge(buffer.priorities, buffer.positions); } }; // Forward declarations for kernel wrapper functions // Note: Kernels use DataT for I/O and compute_t for sentinel/internal computation template __global__ void __launch_bounds__(512, 2) topk_filter_kernel(const DataT* __restrict__ in, const IdxT* __restrict__ in_idx, int batch_size, IdxT len, IdxT k, DataT* __restrict__ out, IdxT* __restrict__ out_idx, compute_t sentinel); template __global__ void __launch_bounds__(512, 2) topk_sort_kernel(const DataT* __restrict__ in, const IdxT* __restrict__ in_idx, int batch_size, IdxT len, IdxT k, DataT* __restrict__ out, IdxT* __restrict__ out_idx, compute_t sentinel); template __global__ void __launch_bounds__(512, 2) topk_merge_kernel(const DataT* __restrict__ in, const IdxT* __restrict__ in_idx, int batch_size, IdxT len, IdxT k, DataT* __restrict__ out, IdxT* __restrict__ out_idx, compute_t sentinel); template using KernelFuncPtr = void (*)(const DataT*, const IdxT*, int, IdxT, IdxT, DataT*, IdxT*, compute_t); // Helper: Map block-level strategy class to its corresponding kernel function template // UseBufferAddressing: Controls whether BlockTopkFilter uses buffer addressing (limited to // UINT_MAX) template