Unverified Commit b55d830e authored by Roberto L. Castro's avatar Roberto L. Castro Committed by GitHub
Browse files

[Perf][Kernel] Persistent TopK scheduler: unified CUDAGraph-safe kernel with...


[Perf][Kernel] Persistent TopK scheduler: unified CUDAGraph-safe kernel with dynamic per-row dispatch - DeepSeek-V3.2 DSA decode (#37421)
Signed-off-by: default avatarLopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: default avatarRoberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
Co-authored-by: default avatarClaude Sonnet 4.5 <noreply@anthropic.com>
Co-authored-by: default avatarLucas Wilkinson <LucasWilkinson@users.noreply.github.com>
parent 75e01a39
......@@ -18,10 +18,9 @@ steps:
source_file_dependencies:
- csrc/
- tests/kernels/core
- tests/kernels/test_top_k_per_row.py
- tests/kernels/test_concat_mla_q.py
commands:
- pytest -v -s kernels/core kernels/test_top_k_per_row.py kernels/test_concat_mla_q.py
- pytest -v -s kernels/core kernels/test_concat_mla_q.py
- label: Kernels Attention Test %N
timeout_in_minutes: 35
......@@ -107,6 +106,7 @@ steps:
- vllm/v1/attention/backends/mla/flashinfer_mla.py
- vllm/v1/attention/selector.py
- vllm/platforms/cuda.py
- tests/kernels/test_top_k_per_row.py
commands:
- nvidia-smi
- python3 examples/basic/offline_inference/chat.py
......@@ -117,6 +117,7 @@ steps:
- pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py
- pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py
- pytest -v -s tests/kernels/attention/test_flashinfer_mla_decode.py
- pytest -v -s tests/kernels/test_top_k_per_row.py
# Quantization
- pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
- pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
......
......@@ -114,9 +114,9 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
int64_t numRows, int64_t stride0, int64_t stride1,
int64_t topK);
void large_context_topk(const torch::Tensor& score, torch::Tensor& indices,
const torch::Tensor& lengths,
std::optional<torch::Tensor> row_starts_opt);
void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
torch::Tensor& output, torch::Tensor& workspace, int64_t k,
int64_t max_seq_len);
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& weight, torch::Tensor& scale,
......
/*
* Persistent TopK Scheduler for DSA Indexer
*/
#ifndef PERSISTENT_TOPK_CUH_
#define PERSISTENT_TOPK_CUH_
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cub/cub.cuh>
#include <cstdint>
namespace vllm {
namespace persistent {
// ============================================================================
// Constants
// ============================================================================
constexpr int TopK = 2048;
constexpr int kThreadsPerBlock = 1024;
constexpr int RADIX = 256;
// Medium path: all shared state in dynamic smem (no static __shared__,
// which would inflate the kernel's smem footprint and kill occupancy
// for the decode/trivial paths).
constexpr size_t kMediumHistBytes = 2 * (RADIX + 128) * sizeof(int); // 3072
constexpr size_t kMediumScalarsBytes = 5 * sizeof(int); // 20
constexpr size_t kMediumHeaderSize =
(kMediumHistBytes + kMediumScalarsBytes + 127) & ~size_t(127); // 3200
constexpr int MAX_BUFFERED_ITEMS = 4096;
constexpr size_t kSmemMedium =
kMediumHeaderSize + 2 * MAX_BUFFERED_ITEMS * sizeof(int); // 35968
constexpr uint32_t RADIX_THRESHOLD = 32768;
// Decode path constants
constexpr int kDecodeBins = 2048;
constexpr uint32_t HIST2048_THRESHOLD = 8192;
// Large path: fixed shared memory for histograms + scalars
constexpr size_t kFixedSmemLarge =
((RADIX + RADIX + 5) * sizeof(uint32_t) + 15) & ~size_t(15);
// ============================================================================
// Common helpers
// ============================================================================
__device__ __forceinline__ auto convert_to_uint32_v2(float x) -> uint32_t {
uint32_t bits = __float_as_uint(x);
return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
}
__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t {
__half h = __float2half_rn(x);
uint16_t bits = __half_as_ushort(h);
uint16_t key = (bits & 0x8000) ? static_cast<uint16_t>(~bits)
: static_cast<uint16_t>(bits | 0x8000);
return static_cast<uint8_t>(key >> 8);
}
// ============================================================================
// Vectorized load helpers
// ============================================================================
// Unconditional float4 load with cache hint (.cg = cache at global level only).
__device__ __forceinline__ void load_float4(const float* ptr, float& v0,
float& v1, float& v2, float& v3) {
uint32_t r0, r1, r2, r3;
asm volatile("ld.global.cg.v4.u32 {%0,%1,%2,%3}, [%4];\n"
: "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3)
: "l"(ptr));
v0 = __uint_as_float(r0);
v1 = __uint_as_float(r1);
v2 = __uint_as_float(r2);
v3 = __uint_as_float(r3);
}
// Per-element predicated scalar loads with -inf default.
__device__ __forceinline__ void load_float4_predicated(const float* ptr,
int base, int seq_len,
float& v0, float& v1,
float& v2, float& v3) {
uint32_t r0, r1, r2, r3;
int p0 = (base < seq_len);
int p1 = (base + 1 < seq_len);
int p2 = (base + 2 < seq_len);
int p3 = (base + 3 < seq_len);
asm volatile(
"{\n"
" .reg .pred pr0, pr1, pr2, pr3;\n"
" setp.ne.u32 pr0, %4, 0;\n"
" setp.ne.u32 pr1, %5, 0;\n"
" setp.ne.u32 pr2, %6, 0;\n"
" setp.ne.u32 pr3, %7, 0;\n"
" mov.u32 %0, 0xFF800000;\n"
" mov.u32 %1, 0xFF800000;\n"
" mov.u32 %2, 0xFF800000;\n"
" mov.u32 %3, 0xFF800000;\n"
" @pr0 ld.global.cg.u32 %0, [%8];\n"
" @pr1 ld.global.cg.u32 %1, [%8+4];\n"
" @pr2 ld.global.cg.u32 %2, [%8+8];\n"
" @pr3 ld.global.cg.u32 %3, [%8+12];\n"
"}\n"
: "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3)
: "r"(p0), "r"(p1), "r"(p2), "r"(p3), "l"(ptr));
v0 = __uint_as_float(r0);
v1 = __uint_as_float(r1);
v2 = __uint_as_float(r2);
v3 = __uint_as_float(r3);
}
// ============================================================================
// Large path: inter-CTA coordination state (one per group)
// ============================================================================
struct RadixRowState {
uint32_t histogram[3][256]; // Triple-buffered histograms
uint32_t remaining_k;
uint32_t prefix;
int arrival_counter;
int output_counter;
};
// ============================================================================
// Kernel parameters
// ============================================================================
struct PersistentTopKParams {
const float* __restrict__ input; // [num_rows, stride]
int32_t* __restrict__ output; // [num_rows, TopK]
int32_t* __restrict__ lengths; // [num_rows]
RadixRowState* row_states; // large path: per-group state
uint32_t num_rows;
uint32_t stride;
uint32_t chunk_size; // large path: elements per CTA
uint32_t ctas_per_group; // 1=medium, >1=large
uint32_t max_seq_len; // max seq_len across all rows (for early CTA exit)
};
// ============================================================================
// Decode path: 2048-bin histogram for short sequences (seq_len <= 8192)
// Uses 11-bit half-precision bins for fine granularity.
// One histogram pass typically suffices since 8192/2048 = 4 elements/bin avg.
// ============================================================================
// 11-bit bin from half-precision representation (ascending: high values -> high
// bins)
__device__ __forceinline__ uint32_t decode_bin(float x) {
__half hx = __float2half(x);
uint16_t bits = __half_as_ushort(hx);
uint16_t key = (bits & 0x8000) ? static_cast<uint16_t>(~bits)
: static_cast<uint16_t>(bits | 0x8000);
return key >> 5;
}
__device__ __noinline__ void histogram_2048_topk(
const float* __restrict__ logits, int32_t* __restrict__ output_indices,
int32_t seq_len) {
extern __shared__ int decode_smem[];
const int tx = threadIdx.x;
const int lane = tx & 31;
// ---- Layout constants ----
constexpr int SBASE = 8192 - 8; // 8184
constexpr int RHIST = RADIX + 128; // 384
constexpr int BOFF = 2 * RHIST; // 768
constexpr int DBUF = (SBASE - BOFF) / 2; // 3708
constexpr int MAX_ITEMS_PER_THREAD =
(HIST2048_THRESHOLD + kThreadsPerBlock - 1) / kThreadsPerBlock;
enum : int { sTHR = 0, sOUT = 1, sREF = 2, sFIN = 3, sBUF0 = 4, sBUF1 = 5 };
// ---- Initialize scalars (prevents stale data from prior rows) ----
if (tx < 8) {
decode_smem[SBASE + tx] = 0;
}
// ---- Phase 1: Build 2048-bin histogram with float4 vectorized loads ----
int* histo = decode_smem;
uint16_t reg_bins[MAX_ITEMS_PER_THREAD];
int nitems = 0;
for (int i = tx; i < kDecodeBins; i += kThreadsPerBlock) {
histo[i] = 0;
}
__syncthreads();
const int n_vec = (seq_len + 3) >> 2;
const bool row_aligned = ((reinterpret_cast<uintptr_t>(logits) & 15) == 0);
for (int i = tx; i < n_vec; i += kThreadsPerBlock) {
const int base = i << 2;
float v0, v1, v2, v3;
if (row_aligned && base + 3 < seq_len) {
load_float4(logits + base, v0, v1, v2, v3);
} else {
load_float4_predicated(logits + base, base, seq_len, v0, v1, v2, v3);
}
const uint16_t b0 = static_cast<uint16_t>(decode_bin(v0));
const uint16_t b1 = static_cast<uint16_t>(decode_bin(v1));
const uint16_t b2 = static_cast<uint16_t>(decode_bin(v2));
const uint16_t b3 = static_cast<uint16_t>(decode_bin(v3));
reg_bins[nitems++] = b0;
reg_bins[nitems++] = b1;
reg_bins[nitems++] = b2;
reg_bins[nitems++] = b3;
atomicAdd(&histo[b0], 1);
atomicAdd(&histo[b1], 1);
atomicAdd(&histo[b2], 1);
atomicAdd(&histo[b3], 1);
}
__syncthreads();
// ---- CUB suffix sum ----
using BlockScanT = cub::BlockScan<int, kThreadsPerBlock>;
const int h0 = histo[2 * tx];
const int pair_sum = h0 + histo[2 * tx + 1];
auto& scan_storage = *reinterpret_cast<typename BlockScanT::TempStorage*>(
decode_smem + kDecodeBins);
int pair_prefix, total;
BlockScanT(scan_storage).ExclusiveSum(pair_sum, pair_prefix, total);
// Find threshold bin purely from registers
const int pair_suffix = total - pair_prefix;
if (pair_suffix >= TopK && (pair_suffix - h0) < TopK) {
decode_smem[SBASE + sTHR] = 2 * tx;
}
{
const int right_suf = pair_suffix - h0;
const int next_suf = pair_suffix - pair_sum;
if (right_suf >= TopK && next_suf < TopK) {
decode_smem[SBASE + sTHR] = 2 * tx + 1;
}
}
__syncthreads();
const int threshold = decode_smem[SBASE + sTHR];
// ---- Phase 2: Collection with warp-aggregated atomicAdds ----
int* bufs[2] = {decode_smem + BOFF, decode_smem + BOFF + DBUF};
const int sOUT_abs = SBASE + sOUT;
const int sBUF0_abs = SBASE + sBUF0;
{
const uint32_t uthr = static_cast<uint32_t>(threshold);
int item = 0;
const int n_vec_iters = (n_vec + kThreadsPerBlock - 1) / kThreadsPerBlock;
for (int iter = 0; iter < n_vec_iters; iter++) {
const int i = tx + iter * kThreadsPerBlock;
const bool vec_valid = (i < n_vec);
const int base_idx = i << 2;
#pragma unroll 4
for (int sub = 0; sub < 4; sub++) {
const int elem_idx = base_idx + sub;
uint32_t bin = 0;
if (vec_valid) bin = reg_bins[item++];
const bool is_above = vec_valid && (bin > uthr);
const bool is_equal = vec_valid && (bin == uthr);
const uint32_t above_mask = __ballot_sync(0xffffffff, is_above);
if (above_mask) {
const int above_count = __popc(above_mask);
const int above_rank = __popc(above_mask & ((1u << lane) - 1));
int above_base;
if (lane == 0) {
above_base = atomicAdd(&decode_smem[sOUT_abs], above_count);
}
above_base = __shfl_sync(0xffffffff, above_base, 0);
if (is_above) {
output_indices[above_base + above_rank] = elem_idx;
}
}
const uint32_t equal_mask = __ballot_sync(0xffffffff, is_equal);
if (equal_mask) {
const int equal_count = __popc(equal_mask);
const int equal_rank = __popc(equal_mask & ((1u << lane) - 1));
int equal_base;
if (lane == 0) {
equal_base = atomicAdd(&decode_smem[sBUF0_abs], equal_count);
}
equal_base = __shfl_sync(0xffffffff, equal_base, 0);
if (is_equal && __builtin_expect(equal_base + equal_rank < DBUF, 1)) {
bufs[0][equal_base + equal_rank] = elem_idx;
}
}
}
}
}
__syncthreads();
int remaining_k = TopK - decode_smem[SBASE + sOUT];
if (remaining_k <= 0) return;
// If all buffered elements fit, output them all (common for short seqs)
const int raw_buf0 = decode_smem[SBASE + sBUF0];
if (raw_buf0 <= remaining_k) {
const int nb = (raw_buf0 < DBUF) ? raw_buf0 : DBUF;
const int base = decode_smem[SBASE + sOUT];
for (int i = tx; i < nb; i += kThreadsPerBlock) {
output_indices[base + i] = bufs[0][i];
}
__syncthreads();
return;
}
// ---- Phase 3: Deferred refinement (rare path) ----
int* refine[2] = {decode_smem, decode_smem + RHIST};
const int num_buf0 = (raw_buf0 < DBUF) ? raw_buf0 : DBUF;
for (int i = tx; i < RHIST; i += kThreadsPerBlock) {
refine[0][i] = 0;
}
__syncthreads();
for (int i = tx; i < num_buf0; i += kThreadsPerBlock) {
const uint32_t fp32 = convert_to_uint32_v2(logits[bufs[0][i]]);
atomicAdd(&refine[0][(fp32 >> 24) & 0xFF], 1);
}
__syncthreads();
auto compute_suffix_sum = [&]() {
#pragma unroll 8
for (int i = 0; i < 8; ++i) {
if (tx < RADIX) {
const int stride = 1 << i;
const int s = i & 1;
const int d = s ^ 1;
int value = refine[s][tx];
if (tx < RADIX - stride) value += refine[s][tx + stride];
refine[d][tx] = value;
}
__syncthreads();
}
};
#pragma unroll 4
for (int pass = 0; pass < 4; ++pass) {
const int src = pass & 1;
const int dst = src ^ 1;
const int raw_buf = decode_smem[SBASE + sBUF0 + src];
const int num_buffered = (raw_buf < DBUF) ? raw_buf : DBUF;
compute_suffix_sum();
if (tx < RADIX && refine[0][tx] > remaining_k &&
refine[0][tx + 1] <= remaining_k) {
decode_smem[SBASE + sREF] = tx;
decode_smem[SBASE + sBUF0 + dst] = 0;
decode_smem[SBASE + sFIN] = remaining_k - refine[0][tx + 1];
}
__syncthreads();
const int ref_thr = decode_smem[SBASE + sREF];
remaining_k -= refine[0][ref_thr + 1];
const int bit_offset = 24 - pass * 8;
if (remaining_k == 0) {
for (int i = tx; i < num_buffered; i += kThreadsPerBlock) {
const int idx = bufs[src][i];
const uint32_t fp32 = convert_to_uint32_v2(logits[idx]);
if (((fp32 >> bit_offset) & 0xFF) > static_cast<uint32_t>(ref_thr)) {
const int pos = atomicAdd(&decode_smem[SBASE + sOUT], 1);
output_indices[pos] = idx;
}
}
__syncthreads();
break;
}
__syncthreads();
if (tx < RADIX + 1) refine[0][tx] = 0;
__syncthreads();
for (int i = tx; i < num_buffered; i += kThreadsPerBlock) {
const int idx = bufs[src][i];
const float logit_val = logits[idx];
const uint32_t fp32 = convert_to_uint32_v2(logit_val);
const int bin = (fp32 >> bit_offset) & 0xFF;
if (bin > ref_thr) {
const int pos = atomicAdd(&decode_smem[SBASE + sOUT], 1);
output_indices[pos] = idx;
} else if (bin == ref_thr) {
if (pass == 3) {
const int slot = atomicAdd(&decode_smem[SBASE + sFIN], -1);
if (slot > 0) output_indices[TopK - slot] = idx;
} else {
const int bp = atomicAdd(&decode_smem[SBASE + sBUF0 + dst], 1);
if (__builtin_expect(bp < DBUF, 1)) {
bufs[dst][bp] = idx;
const int nbo = bit_offset - 8;
atomicAdd(&refine[0][(fp32 >> nbo) & 0xFF], 1);
}
}
}
}
__syncthreads();
}
}
// ============================================================================
// Medium path: coarse FP16 histogram + 4-pass FP32 radix refinement
// For sequences 8K < seq_len <= 64K.
// ============================================================================
// Adapted from:
// https://github.com/sgl-project/sglang/blob/v0.5.8/sgl-kernel/csrc/elementwise/topk.cu#L87
// by: DarkSharpness
// which at the same time is an optimized topk kernel copied from tilelang
// kernel
__device__ __noinline__ void histogram_256_topk(
const float* __restrict__ logits, int* __restrict__ output_indices,
int logits_offset, int seq_len) {
// All shared state lives in dynamic shared memory to avoid static
extern __shared__ char medium_smem[];
int (*shared_histogram)[RADIX + 128] =
reinterpret_cast<int (*)[RADIX + 128]>(medium_smem);
int* medium_scalars = reinterpret_cast<int*>(medium_smem + kMediumHistBytes);
int& shared_output_count = medium_scalars[0];
int& shared_threshold_bin = medium_scalars[1];
int* shared_buffered_count = &medium_scalars[2];
int& shared_final_k = medium_scalars[4];
int (*buffered_indices)[MAX_BUFFERED_ITEMS] =
reinterpret_cast<int (*)[MAX_BUFFERED_ITEMS]>(medium_smem +
kMediumHeaderSize);
const int thread_id = threadIdx.x;
int remaining_k = TopK;
if (thread_id < RADIX + 1) {
shared_histogram[0][thread_id] = 0;
}
__syncthreads();
for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
const auto bin = convert_to_uint8(logits[idx + logits_offset]);
atomicAdd(&shared_histogram[0][bin], 1);
}
__syncthreads();
auto compute_cumulative_sum = [&]() {
#pragma unroll 8
for (int i = 0; i < 8; ++i) {
if (__builtin_expect(thread_id < RADIX, 1)) {
const int stride = 1 << i;
const int src_buffer = i & 1;
const int dst_buffer = src_buffer ^ 1;
int value = shared_histogram[src_buffer][thread_id];
if (thread_id < RADIX - stride) {
value += shared_histogram[src_buffer][thread_id + stride];
}
shared_histogram[dst_buffer][thread_id] = value;
}
__syncthreads();
}
};
compute_cumulative_sum();
if (thread_id < RADIX && shared_histogram[0][thread_id] > remaining_k &&
shared_histogram[0][thread_id + 1] <= remaining_k) {
shared_threshold_bin = thread_id;
shared_buffered_count[0] = 0;
shared_output_count = 0;
}
__syncthreads();
const int threshold_bin = shared_threshold_bin;
remaining_k -= shared_histogram[0][threshold_bin + 1];
if (remaining_k == 0) {
for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
const int bin = convert_to_uint8(logits[idx + logits_offset]);
if (bin > threshold_bin) {
const int output_pos = atomicAdd(&shared_output_count, 1);
output_indices[output_pos] = idx;
}
}
__syncthreads();
return;
}
__syncthreads();
if (thread_id < RADIX + 1) {
shared_histogram[0][thread_id] = 0;
}
__syncthreads();
for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
const float logit_value = logits[idx + logits_offset];
const int bin = convert_to_uint8(logit_value);
if (bin > threshold_bin) {
const int output_pos = atomicAdd(&shared_output_count, 1);
output_indices[output_pos] = idx;
} else if (bin == threshold_bin) {
const int buffer_pos = atomicAdd(&shared_buffered_count[0], 1);
if (__builtin_expect(buffer_pos < MAX_BUFFERED_ITEMS, 1)) {
buffered_indices[0][buffer_pos] = idx;
const uint32_t fp32_bits = convert_to_uint32_v2(logit_value);
const int next_bin = (fp32_bits >> 24) & 0xFF;
atomicAdd(&shared_histogram[0][next_bin], 1);
}
}
}
__syncthreads();
#pragma unroll 4
for (int pass = 0; pass < 4; ++pass) {
const int src_buffer = pass % 2;
const int dst_buffer = src_buffer ^ 1;
const int raw_buffered = shared_buffered_count[src_buffer];
const int num_buffered =
(raw_buffered < MAX_BUFFERED_ITEMS) ? raw_buffered : MAX_BUFFERED_ITEMS;
compute_cumulative_sum();
if (thread_id < RADIX && shared_histogram[0][thread_id] > remaining_k &&
shared_histogram[0][thread_id + 1] <= remaining_k) {
shared_threshold_bin = thread_id;
shared_buffered_count[dst_buffer] = 0;
shared_final_k = remaining_k - shared_histogram[0][thread_id + 1];
}
__syncthreads();
const int threshold_bin = shared_threshold_bin;
remaining_k -= shared_histogram[0][threshold_bin + 1];
const int bit_offset = 24 - pass * 8;
if (remaining_k == 0) {
for (int i = thread_id; i < num_buffered; i += kThreadsPerBlock) {
const int idx = buffered_indices[src_buffer][i];
const uint32_t fp32_bits =
convert_to_uint32_v2(logits[idx + logits_offset]);
const int bin = (fp32_bits >> bit_offset) & 0xFF;
if (bin > threshold_bin) {
const int output_pos = atomicAdd(&shared_output_count, 1);
output_indices[output_pos] = idx;
}
}
__syncthreads();
break;
}
__syncthreads();
if (thread_id < RADIX + 1) {
shared_histogram[0][thread_id] = 0;
}
__syncthreads();
for (int i = thread_id; i < num_buffered; i += kThreadsPerBlock) {
const int idx = buffered_indices[src_buffer][i];
const float logit_value = logits[idx + logits_offset];
const uint32_t fp32_bits = convert_to_uint32_v2(logit_value);
const int bin = (fp32_bits >> bit_offset) & 0xFF;
if (bin > threshold_bin) {
const int output_pos = atomicAdd(&shared_output_count, 1);
output_indices[output_pos] = idx;
} else if (bin == threshold_bin) {
if (pass == 3) {
const int slot = atomicAdd(&shared_final_k, -1);
if (slot > 0) {
output_indices[TopK - slot] = idx;
}
} else {
const int buffer_pos =
atomicAdd(&shared_buffered_count[dst_buffer], 1);
if (__builtin_expect(buffer_pos < MAX_BUFFERED_ITEMS, 1)) {
buffered_indices[dst_buffer][buffer_pos] = idx;
const int next_bit_offset = bit_offset - 8;
const int next_bin = (fp32_bits >> next_bit_offset) & 0xFF;
atomicAdd(&shared_histogram[0][next_bin], 1);
}
}
}
}
__syncthreads();
}
}
// ============================================================================
// Inter-CTA sync primitives
// ============================================================================
__device__ __forceinline__ int ld_acquire(int* ptr) {
int state = 0;
#if (__CUDA_ARCH__ >= 700)
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
: "=r"(state)
: "l"(ptr));
#else
asm volatile("ld.cg.global.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr));
#endif
return state;
}
__device__ __forceinline__ void red_release(int* ptr, int val) {
#if (__CUDA_ARCH__ >= 700)
asm volatile("fence.acq_rel.gpu;\n");
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
:
: "l"(ptr), "r"(val));
#else
__threadfence();
atomicAdd(ptr, val);
#endif
}
__device__ __forceinline__ void st_release(int* ptr, int val) {
#if (__CUDA_ARCH__ >= 700)
asm volatile("fence.acq_rel.gpu;\n");
asm volatile("st.release.gpu.global.b32 [%0], %1;\n" : : "l"(ptr), "r"(val));
#else
__threadfence();
atomicExch(ptr, val);
#endif
}
__device__ __forceinline__ void wait_ge(int* ptr, int target_val,
int thread_idx) {
if (thread_idx == 0) {
#pragma unroll 1
while (ld_acquire(ptr) < target_val) {
}
}
__syncthreads();
}
// ============================================================================
// Large path: multi-CTA radix select for sequences > 64K
//
// Each row is processed by a group of CTAs. Each CTA loads its chunk into
// shared memory as ordered uint32, then participates in 4 rounds of
// coordinated radix select via global-memory histograms and barriers.
// ============================================================================
// ============================================================================
// Multi-CTA cooperative RadixTopK for a single large row.
// Adapted from https://github.com/flashinfer-ai/flashinfer/pull/2215
// ============================================================================
template <uint32_t VEC_SIZE>
__device__ void radix_topk(const float* __restrict__ row_input,
int32_t* __restrict__ row_output, uint32_t seq_len,
uint32_t my_chunk_start, uint32_t chunk_size,
uint32_t* local_histogram, uint32_t* suffix_sum,
uint32_t* shared_scalars, uint32_t* shared_ordered,
RadixRowState* state, uint32_t cta_in_group,
uint32_t ctas_per_group, int& barrier_phase,
uint32_t iter, uint32_t tx) {
const uint32_t my_chunk_end = (my_chunk_start + chunk_size < seq_len)
? my_chunk_start + chunk_size
: seq_len;
const uint32_t actual_chunk_size =
(my_chunk_start < seq_len) ? (my_chunk_end - my_chunk_start) : 0;
// -- Stage 1: Load chunk to shared memory as ordered uint32 --
{
const uint32_t aligned_size = (actual_chunk_size / VEC_SIZE) * VEC_SIZE;
for (uint32_t i = tx * VEC_SIZE; i < aligned_size;
i += kThreadsPerBlock * VEC_SIZE) {
const float* src = row_input + my_chunk_start + i;
if constexpr (VEC_SIZE == 4) {
float4 v = *reinterpret_cast<const float4*>(src);
shared_ordered[i] = convert_to_uint32_v2(v.x);
shared_ordered[i + 1] = convert_to_uint32_v2(v.y);
shared_ordered[i + 2] = convert_to_uint32_v2(v.z);
shared_ordered[i + 3] = convert_to_uint32_v2(v.w);
} else if constexpr (VEC_SIZE == 2) {
float2 v = *reinterpret_cast<const float2*>(src);
shared_ordered[i] = convert_to_uint32_v2(v.x);
shared_ordered[i + 1] = convert_to_uint32_v2(v.y);
} else {
shared_ordered[i] = convert_to_uint32_v2(*src);
}
}
for (uint32_t i = aligned_size + tx; i < actual_chunk_size;
i += kThreadsPerBlock) {
shared_ordered[i] = convert_to_uint32_v2(row_input[my_chunk_start + i]);
}
}
__syncthreads();
// -- Init radix select state --
if (tx == 0) {
shared_scalars[0] = 0; // prefix
shared_scalars[1] = TopK; // remaining_k
}
__syncthreads();
// -- Initial barrier --
if (tx == 0) {
red_release(&state->arrival_counter, 1);
}
wait_ge(&state->arrival_counter,
(barrier_phase + 1) * static_cast<int>(ctas_per_group), tx);
barrier_phase++;
__syncthreads();
if (cta_in_group == 0 && tx == 0) {
st_release(&state->output_counter, 0);
}
// -- Stage 2: 4 rounds of radix select --
for (uint32_t round = 0; round < 4; round++) {
const uint32_t global_round = iter * 4 + round;
const uint32_t shift = 24 - round * 8;
const uint32_t prefix = shared_scalars[0];
const uint32_t remaining_k = shared_scalars[1];
uint32_t* current_hist = state->histogram[global_round % 3];
uint32_t* next_hist = state->histogram[(global_round + 1) % 3];
for (uint32_t i = tx; i < RADIX; i += kThreadsPerBlock) {
local_histogram[i] = 0;
}
__syncthreads();
for (uint32_t i = tx; i < actual_chunk_size; i += kThreadsPerBlock) {
uint32_t ordered = shared_ordered[i];
uint32_t mask = (round == 0) ? 0u : (~0u << (32 - round * 8));
if ((ordered & mask) == prefix) {
uint32_t bucket = (ordered >> shift) & 0xFF;
atomicAdd(&local_histogram[bucket], 1);
}
}
__syncthreads();
for (uint32_t i = tx; i < RADIX; i += kThreadsPerBlock) {
if (local_histogram[i] > 0) {
atomicAdd(&current_hist[i], local_histogram[i]);
}
}
if (cta_in_group == 0) {
for (uint32_t i = tx; i < RADIX; i += kThreadsPerBlock) {
next_hist[i] = 0;
}
}
if (tx == 0) {
red_release(&state->arrival_counter, 1);
}
wait_ge(&state->arrival_counter,
(barrier_phase + 1) * static_cast<int>(ctas_per_group), tx);
barrier_phase++;
__syncthreads();
for (uint32_t i = tx; i < RADIX; i += kThreadsPerBlock) {
suffix_sum[i] = current_hist[i];
}
__syncthreads();
for (uint32_t stride = 1; stride < RADIX; stride *= 2) {
uint32_t val = 0;
if (tx < RADIX) {
val = suffix_sum[tx];
if (tx + stride < RADIX) val += suffix_sum[tx + stride];
}
__syncthreads();
if (tx < RADIX) suffix_sum[tx] = val;
__syncthreads();
}
if (tx == 0) {
shared_scalars[2] = 0;
shared_scalars[3] = remaining_k;
}
__syncthreads();
if (tx < RADIX) {
uint32_t count_ge = suffix_sum[tx];
uint32_t count_gt = (tx + 1 < RADIX) ? suffix_sum[tx + 1] : 0;
if (count_ge >= remaining_k && count_gt < remaining_k) {
shared_scalars[2] = tx;
shared_scalars[3] = remaining_k - count_gt;
}
}
__syncthreads();
if (tx == 0) {
shared_scalars[0] = prefix | (shared_scalars[2] << shift);
shared_scalars[1] = shared_scalars[3];
}
__syncthreads();
} // end 4 radix rounds
// -- Count local > pivot elements --
const uint32_t ordered_pivot = shared_scalars[0];
if (tx == 0) suffix_sum[0] = 0;
__syncthreads();
uint32_t my_gt_count = 0;
for (uint32_t i = tx; i < actual_chunk_size; i += kThreadsPerBlock) {
if (shared_ordered[i] > ordered_pivot) my_gt_count++;
}
for (int offset = 16; offset > 0; offset /= 2) {
my_gt_count += __shfl_down_sync(0xffffffff, my_gt_count, offset);
}
if (tx % 32 == 0 && my_gt_count > 0) {
atomicAdd(&suffix_sum[0], my_gt_count);
}
__syncthreads();
const uint32_t local_gt_count = suffix_sum[0];
// -- Stage 3: Collect top-k indices --
if (tx == 0) {
local_histogram[0] = 0;
if (local_gt_count > 0) {
local_histogram[1] =
atomicAdd(&state->output_counter, static_cast<int>(local_gt_count));
}
}
__syncthreads();
for (uint32_t i = tx; i < actual_chunk_size; i += kThreadsPerBlock) {
if (shared_ordered[i] > ordered_pivot) {
uint32_t local_pos = atomicAdd(&local_histogram[0], 1);
int pos = static_cast<int>(local_histogram[1]) + local_pos;
row_output[pos] = static_cast<int32_t>(my_chunk_start + i);
}
}
if (tx == 0) {
red_release(&state->arrival_counter, 1);
}
wait_ge(&state->arrival_counter,
(barrier_phase + 1) * static_cast<int>(ctas_per_group), tx);
barrier_phase++;
__syncthreads();
for (uint32_t i = tx; i < actual_chunk_size; i += kThreadsPerBlock) {
if (shared_ordered[i] == ordered_pivot) {
int pos = atomicAdd(&state->output_counter, 1);
if (pos < TopK) {
row_output[pos] = static_cast<int32_t>(my_chunk_start + i);
}
}
}
}
// ============================================================================
// Persistent kernel — BS≤32, decode/medium/large paths with RadixTopK
// BS>32 uses standalone histogram_256_buffered_topk (separate kernel,
// see filtered_topk.cuh)
// ============================================================================
template <uint32_t VEC_SIZE = 1>
__global__ void __launch_bounds__(kThreadsPerBlock, 2)
persistent_topk_kernel(PersistentTopKParams params) {
const uint32_t tx = threadIdx.x;
extern __shared__ uint8_t smem_raw[];
// ========================================================================
// Group mode: multi-CTA groups with static round-robin row assignment.
// Non-large rows: CTA-0 handles trivial/decode/medium.
// Large rows: all CTAs in the group cooperate via RadixTopK.
// ========================================================================
const uint32_t ctas_per_group = params.ctas_per_group;
const uint32_t group_id = blockIdx.x / ctas_per_group;
const uint32_t cta_in_group = blockIdx.x % ctas_per_group;
const uint32_t num_groups = gridDim.x / ctas_per_group;
const uint32_t chunk_size = params.chunk_size;
if (blockIdx.x >= num_groups * ctas_per_group) return;
// Early exit: non-CTA-0 threads are never needed if no large rows exist
if (cta_in_group != 0 && params.max_seq_len <= RADIX_THRESHOLD) return;
uint32_t* local_histogram = reinterpret_cast<uint32_t*>(smem_raw);
uint32_t* suffix_sum = local_histogram + RADIX;
uint32_t* shared_scalars = suffix_sum + RADIX;
uint32_t* shared_ordered =
reinterpret_cast<uint32_t*>(smem_raw + kFixedSmemLarge);
// RadixRowState for multi-CTA cooperative radix
RadixRowState* state = &params.row_states[group_id];
// -- Initialize RadixRowState (only needed if large rows exist) --
if (params.max_seq_len > RADIX_THRESHOLD) {
if (cta_in_group == 0) {
for (uint32_t buf = 0; buf < 3; buf++) {
for (uint32_t i = tx; i < RADIX; i += kThreadsPerBlock) {
state->histogram[buf][i] = 0;
}
}
if (tx == 0) {
state->remaining_k = 0;
state->prefix = 0;
state->arrival_counter = 0;
state->output_counter = 0;
}
}
__syncthreads();
}
int barrier_phase = 0;
const uint32_t total_iters = (params.num_rows + num_groups - 1) / num_groups;
for (uint32_t iter = 0; iter < total_iters; iter++) {
// Static round-robin: all CTAs in the group implicitly agree on the row
uint32_t row_idx = group_id + iter * num_groups;
if (row_idx >= params.num_rows) break;
const uint32_t seq_len = params.lengths[row_idx];
int32_t* row_output = params.output + row_idx * TopK;
const float* row_input = params.input + row_idx * params.stride;
if (seq_len <= RADIX_THRESHOLD) {
if (cta_in_group == 0) {
if (seq_len <= static_cast<uint32_t>(TopK)) {
// Trivial case: seq_len <= TopK
for (uint32_t i = tx; i < static_cast<uint32_t>(TopK);
i += kThreadsPerBlock) {
row_output[i] = (i < seq_len) ? static_cast<int32_t>(i) : -1;
}
} else if (seq_len <= static_cast<uint32_t>(HIST2048_THRESHOLD)) {
histogram_2048_topk(row_input, row_output, seq_len);
} else {
histogram_256_topk(row_input, row_output, 0, seq_len);
}
}
continue;
}
const uint32_t my_chunk_start = cta_in_group * chunk_size;
radix_topk<VEC_SIZE>(row_input, row_output, seq_len, my_chunk_start,
chunk_size, local_histogram, suffix_sum,
shared_scalars, shared_ordered, state, cta_in_group,
ctas_per_group, barrier_phase, iter, tx);
}
}
} // namespace persistent
// ============================================================================
// FlashInfer FilteredTopK (BS>32 dispatch) — float32 only.
// Extracted from flashinfer_topk.cuh. Lives in namespace vllm (not persistent).
// Adapted from https://github.com/flashinfer-ai/flashinfer/pull/2215
// ============================================================================
#define FLASHINFER_CUDA_CALL(func, ...) \
{ \
cudaError_t e = (func); \
if (e != cudaSuccess) { \
return e; \
} \
}
#define FLASHINFER_INLINE inline __attribute__((always_inline)) __device__
template <typename T, size_t N>
struct vec_t {
T data[N];
FLASHINFER_INLINE T& operator[](size_t i) { return data[i]; }
FLASHINFER_INLINE const T& operator[](size_t i) const { return data[i]; }
FLASHINFER_INLINE void cast_load(const T* ptr) {
#pragma unroll
for (size_t i = 0; i < N; ++i) {
data[i] = ptr[i];
}
}
FLASHINFER_INLINE void cast_store(T* ptr) const {
#pragma unroll
for (size_t i = 0; i < N; ++i) {
ptr[i] = data[i];
}
}
};
#undef FLASHINFER_INLINE
// FilteredTopK traits for different data types
template <typename DType>
struct FilteredTopKTraits;
// Specialization for float (32-bit): coarse histogram uses FP16 high 8 bits, 4
// refinement rounds
template <>
struct FilteredTopKTraits<float> {
using OrderedType = uint32_t;
static constexpr int NUM_REFINE_ROUNDS = 4;
static constexpr int FIRST_REFINE_SHIFT = 24;
__device__ __forceinline__ static uint8_t ToCoarseKey(float x) {
// Convert to FP16 representation and extract high 8 bits
__half h = __float2half_rn(x);
uint16_t bits = __half_as_ushort(h);
uint16_t key = (bits & 0x8000) ? static_cast<uint16_t>(~bits)
: static_cast<uint16_t>(bits | 0x8000);
return static_cast<uint8_t>(key >> 8);
}
__device__ __forceinline__ static OrderedType ToOrdered(float x) {
uint32_t bits = __float_as_uint(x);
return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
}
};
constexpr uint32_t FILTERED_TOPK_MAX_K = 2048;
constexpr uint32_t FILTERED_TOPK_BLOCK_THREADS = 1024;
constexpr uint32_t FILTERED_TOPK_SMEM_INPUT_SIZE =
16 * 1024; // 16K indices per buffer
constexpr size_t FILTERED_TOPK_SMEM_DYNAMIC =
sizeof(int) * 2 * FILTERED_TOPK_SMEM_INPUT_SIZE; // 128KB
/*!
* \brief Filtered Top-K kernel for ragged sequences.
*
* \tparam DType Data type (float, half, nv_bfloat16)
* \tparam IdType Index type (int32_t)
* \tparam VEC_SIZE Vector size for input loads (1, 2, 4, or 8)
*/
template <typename DType, typename IdType, int VEC_SIZE>
__global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
FilteredTopKUnifiedKernel(const DType* __restrict__ input,
IdType* __restrict__ output,
const IdType* __restrict__ lengths,
uint32_t num_rows, uint32_t top_k,
uint32_t max_len) {
constexpr uint32_t BLOCK_SIZE = FILTERED_TOPK_BLOCK_THREADS;
constexpr int RADIX = 256;
constexpr int SMEM_INPUT_SIZE = FILTERED_TOPK_SMEM_INPUT_SIZE;
const uint32_t bid = blockIdx.x;
const int tx = threadIdx.x;
if (bid >= num_rows) return;
const int length =
(lengths != nullptr) ? lengths[bid] : static_cast<int>(max_len);
const DType* score = input + bid * max_len;
IdType* dst = output + bid * top_k;
// Trivial case: length <= top_k
if (length <= static_cast<int>(top_k)) {
for (int i = tx; i < static_cast<int>(top_k); i += BLOCK_SIZE) {
dst[i] = (i < length) ? static_cast<IdType>(i) : static_cast<IdType>(-1);
}
return;
}
// Static shared memory
alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128];
alignas(128) __shared__ int s_counter;
alignas(128) __shared__ int s_threshold_bin_id;
alignas(128) __shared__ int s_num_input[2];
alignas(128) __shared__ int s_indices[FILTERED_TOPK_MAX_K];
auto& s_histogram = s_histogram_buf[0];
// Dynamic shared memory for input double buffer
extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE];
using Traits = FilteredTopKTraits<DType>;
int topk = top_k;
// Stage 1: 8-bit coarse histogram with vectorized loads
if (tx < RADIX + 1) s_histogram[tx] = 0;
__syncthreads();
vec_t<DType, VEC_SIZE> score_vec;
const int aligned_length = (length / VEC_SIZE) * VEC_SIZE;
#pragma unroll 2
for (int base = tx * VEC_SIZE; base < aligned_length;
base += BLOCK_SIZE * VEC_SIZE) {
score_vec.cast_load(&score[base]);
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
const auto bin = Traits::ToCoarseKey(score_vec[j]);
atomicAdd(&s_histogram[bin], 1);
}
}
// Handle tail
for (int i = aligned_length + tx; i < length; i += BLOCK_SIZE) {
const auto bin = Traits::ToCoarseKey(score[i]);
atomicAdd(&s_histogram[bin], 1);
}
__syncthreads();
// Suffix sum
const auto run_cumsum = [&]() {
#pragma unroll 8
for (int i = 0; i < 8; ++i) {
if (tx < RADIX) {
const auto j = 1 << i;
const auto k = i & 1;
auto value = s_histogram_buf[k][tx];
if (tx < RADIX - j) {
value += s_histogram_buf[k][tx + j];
}
s_histogram_buf[k ^ 1][tx] = value;
}
__syncthreads();
}
};
run_cumsum();
if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) {
s_threshold_bin_id = tx;
s_num_input[0] = 0;
s_counter = 0;
}
__syncthreads();
const auto threshold_bin = s_threshold_bin_id;
topk -= s_histogram[threshold_bin + 1];
constexpr int NUM_ROUNDS = Traits::NUM_REFINE_ROUNDS;
constexpr int FIRST_SHIFT = Traits::FIRST_REFINE_SHIFT;
if (topk == 0) {
// Collect indices where bin > threshold
#pragma unroll 2
for (int base = tx * VEC_SIZE; base < aligned_length;
base += BLOCK_SIZE * VEC_SIZE) {
score_vec.cast_load(&score[base]);
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
const auto bin = static_cast<int>(Traits::ToCoarseKey(score_vec[j]));
if (bin > threshold_bin) {
const auto pos = atomicAdd(&s_counter, 1);
s_indices[pos] = base + j;
}
}
}
// Handle tail
for (int i = aligned_length + tx; i < length; i += BLOCK_SIZE) {
const auto bin = static_cast<int>(Traits::ToCoarseKey(score[i]));
if (bin > threshold_bin) {
const auto pos = atomicAdd(&s_counter, 1);
s_indices[pos] = i;
}
}
__syncthreads();
} else {
__syncthreads();
if (tx < RADIX + 1) s_histogram[tx] = 0;
__syncthreads();
// Filter + histogram for refinement
auto filter_and_add_to_histogram = [&](auto raw_input, int index) {
const auto bin = static_cast<int>(Traits::ToCoarseKey(raw_input));
if (bin > threshold_bin) {
const auto pos = atomicAdd(&s_counter, 1);
s_indices[pos] = index;
} else if (bin == threshold_bin) {
const auto pos = atomicAdd(&s_num_input[0], 1);
if (__builtin_expect(pos < SMEM_INPUT_SIZE, 1)) {
s_input_idx[0][pos] = index;
const auto ordered = Traits::ToOrdered(raw_input);
const auto sub_bin = (ordered >> FIRST_SHIFT) & 0xFF;
atomicAdd(&s_histogram[sub_bin], 1);
}
}
};
#pragma unroll 2
for (int base = tx * VEC_SIZE; base < aligned_length;
base += BLOCK_SIZE * VEC_SIZE) {
score_vec.cast_load(&score[base]);
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
filter_and_add_to_histogram(score_vec[j], base + j);
}
}
// Handle tail
for (int i = aligned_length + tx; i < length; i += BLOCK_SIZE) {
filter_and_add_to_histogram(score[i], i);
}
__syncthreads();
// Stage 2: refine with 8bit radix passes
#pragma unroll
for (int round = 0; round < NUM_ROUNDS; ++round) {
__shared__ int s_last_remain;
const auto r_idx = round % 2;
const auto _raw_num_input = s_num_input[r_idx];
const auto num_input =
(_raw_num_input < SMEM_INPUT_SIZE) ? _raw_num_input : SMEM_INPUT_SIZE;
run_cumsum();
if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) {
s_threshold_bin_id = tx;
s_num_input[r_idx ^ 1] = 0;
s_last_remain = topk - s_histogram[tx + 1];
}
__syncthreads();
const auto threshold = s_threshold_bin_id;
topk -= s_histogram[threshold + 1];
const int offset = FIRST_SHIFT - round * 8;
const bool is_last_round = (round == NUM_ROUNDS - 1);
if (topk == 0) {
for (int i = tx; i < num_input; i += BLOCK_SIZE) {
const auto idx = s_input_idx[r_idx][i];
const auto bin = (Traits::ToOrdered(score[idx]) >> offset) & 0xFF;
if (static_cast<int>(bin) > threshold) {
const auto pos = atomicAdd(&s_counter, 1);
s_indices[pos] = idx;
}
}
__syncthreads();
break;
} else {
__syncthreads();
if (tx < RADIX + 1) s_histogram[tx] = 0;
__syncthreads();
for (int i = tx; i < num_input; i += BLOCK_SIZE) {
const auto idx = s_input_idx[r_idx][i];
const auto raw_input = score[idx];
const auto bin = (Traits::ToOrdered(raw_input) >> offset) & 0xFF;
if (static_cast<int>(bin) > threshold) {
const auto pos = atomicAdd(&s_counter, 1);
s_indices[pos] = idx;
} else if (static_cast<int>(bin) == threshold) {
if (is_last_round) {
const auto pos = atomicAdd(&s_last_remain, -1);
if (pos > 0) {
s_indices[top_k - pos] = idx;
}
} else {
const auto pos = atomicAdd(&s_num_input[r_idx ^ 1], 1);
if (__builtin_expect(pos < SMEM_INPUT_SIZE, 1)) {
s_input_idx[r_idx ^ 1][pos] = idx;
const auto bin32 = Traits::ToOrdered(raw_input);
const auto sub_bin = (bin32 >> (offset - 8)) & 0xFF;
atomicAdd(&s_histogram[sub_bin], 1);
}
}
}
}
__syncthreads();
}
}
}
// Output phase - mode-specific
#pragma unroll 2
for (int base = tx; base < static_cast<int>(top_k); base += BLOCK_SIZE) {
const int idx = s_indices[base];
dst[base] = static_cast<IdType>(idx);
}
}
// Helper to compute GCD for VEC_SIZE selection
constexpr uint32_t gcd(uint32_t a, uint32_t b) {
while (b != 0) {
uint32_t t = b;
b = a % b;
a = t;
}
return a;
}
// Compute optimal VEC_SIZE based on max_len and dtype
// Returns 1, 2, 4, or 8
template <typename DType>
constexpr int ComputeFilteredTopKVecSize(uint32_t max_len) {
constexpr int MAX_VEC = 16 / sizeof(DType); // 4 for float32, 8 for fp16/bf16
// Use GCD to find largest power-of-2 divisor
const uint32_t g = gcd(max_len, static_cast<uint32_t>(MAX_VEC));
return static_cast<int>(g);
}
template <typename DType, typename IdType>
cudaError_t FilteredTopKRaggedTransform(DType* input, IdType* output_indices,
IdType* lengths, uint32_t num_rows,
uint32_t top_k_val, uint32_t max_len,
cudaStream_t stream = 0) {
constexpr size_t smem_size = FILTERED_TOPK_SMEM_DYNAMIC;
constexpr int MAX_VEC = 16 / sizeof(DType);
dim3 grid(num_rows);
dim3 block(FILTERED_TOPK_BLOCK_THREADS);
void* args[] = {&input, &output_indices, &lengths,
&num_rows, &top_k_val, &max_len};
const int vec_size = ComputeFilteredTopKVecSize<DType>(max_len);
#define DISPATCH_VEC_SIZE(VS) \
if (vec_size == VS) { \
auto kernel = FilteredTopKUnifiedKernel<DType, IdType, VS>; \
FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( \
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); \
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, grid, block, args, \
smem_size, stream)); \
return cudaSuccess; \
}
DISPATCH_VEC_SIZE(1)
DISPATCH_VEC_SIZE(2)
DISPATCH_VEC_SIZE(4)
if constexpr (MAX_VEC >= 8) {
DISPATCH_VEC_SIZE(8)
}
#undef DISPATCH_VEC_SIZE
return cudaSuccess;
}
} // namespace vllm
#endif // PERSISTENT_TOPK_CUH_
// Portions of this file are adapted from SGLang PR:
// https://github.com/sgl-project/sglang/pull/11194
// and
// https://github.com/sgl-project/sglang/pull/17747
// Persistent TopK kernel for DeepSeek V3 sparse attention indexer.
// See persistent_topk.cuh for kernel implementation.
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <algorithm>
#ifndef USE_ROCM
#include <cub/cub.cuh>
#else
#include <hipcub/hipcub.hpp>
#include "persistent_topk.cuh"
#endif
namespace vllm {
constexpr int TopK = 2048; // DeepSeek V3 sparse attention top-k
constexpr int kThreadsPerBlock = 1024; // Threads per block
// Shared memory budget
#if defined(USE_ROCM)
constexpr size_t kSmem = 48 * 1024; // ROCm default: 48KB
#else
// Reduced from 128KB to 32KB to improve occupancy.
// Each radix pass needs at most ~TopK candidates in the threshold bin,
// so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient.
constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes)
#endif
struct FastTopKParams {
const float* __restrict__ input; // [batch, seq_len] Logits
const int32_t* __restrict__ row_starts; // [batch] Offset into each row
// (optional)
int32_t* __restrict__ indices; // [batch, TopK] Output top-k indices
int32_t* __restrict__ lengths; // [batch] Sequence lengths per row
int64_t input_stride; // Stride between rows
};
__device__ __forceinline__ auto convert_to_uint32_v2(float x) -> uint32_t {
uint32_t bits = __float_as_uint(x);
return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
}
__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t {
__half h = __float2half_rn(x);
uint16_t bits = __half_as_ushort(h);
uint16_t key = (bits & 0x8000) ? static_cast<uint16_t>(~bits)
: static_cast<uint16_t>(bits | 0x8000);
return static_cast<uint8_t>(key >> 8);
}
__device__ void naive_topk_cuda(const float* __restrict__ logits,
int32_t* __restrict__ output_indices,
int32_t seq_len) {
const int thread_id = threadIdx.x;
for (int i = thread_id; i < TopK; i += kThreadsPerBlock) {
output_indices[i] = (i < seq_len) ? i : -1;
}
}
// Adapted from:
// https://github.com/sgl-project/sglang/blob/v0.5.8/sgl-kernel/csrc/elementwise/topk.cu#L87
// by: DarkSharpness
// which at the same time is an optimized topk kernel copied from tilelang
// kernel
__device__ void fast_topk_cuda_tl(
const float* __restrict__ logits, // Input logits [seq_len]
int* __restrict__ output_indices, // Output top-k indices [TopK]
int logits_offset, // Starting offset in logits array
int seq_len) // Number of valid logits to process
{
constexpr int RADIX = 256;
constexpr int MAX_BUFFERED_ITEMS = kSmem / (2 * sizeof(int));
alignas(128) __shared__ int shared_histogram[2][RADIX + 128];
alignas(128) __shared__ int shared_output_count;
alignas(128) __shared__ int shared_threshold_bin;
alignas(128) __shared__ int shared_buffered_count[2];
extern __shared__ int buffered_indices[][MAX_BUFFERED_ITEMS];
const int thread_id = threadIdx.x;
int remaining_k = TopK;
// Pass 0: Build coarse 8-bit histogram using FP16 high bits
if (thread_id < RADIX + 1) {
shared_histogram[0][thread_id] = 0;
}
__syncthreads();
for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
const auto bin = convert_to_uint8(logits[idx + logits_offset]);
::atomicAdd(&shared_histogram[0][bin], 1);
}
__syncthreads();
// Helper: Compute cumulative sum (suffix sum) over histogram using ping-pong
// buffers
auto compute_cumulative_sum = [&]() {
static_assert(1 << 8 == RADIX,
"Radix must be 256 for 8 unrolled iterations");
#pragma unroll 8
for (int i = 0; i < 8; ++i) {
if (C10_LIKELY(thread_id < RADIX)) {
const int stride = 1 << i;
const int src_buffer = i & 1;
const int dst_buffer = src_buffer ^ 1;
int value = shared_histogram[src_buffer][thread_id];
if (thread_id < RADIX - stride) {
value += shared_histogram[src_buffer][thread_id + stride];
}
shared_histogram[dst_buffer][thread_id] = value;
}
__syncthreads();
}
};
compute_cumulative_sum();
// Find threshold bin where cumsum crosses remaining_k
if (thread_id < RADIX && shared_histogram[0][thread_id] > remaining_k &&
shared_histogram[0][thread_id + 1] <= remaining_k) {
shared_threshold_bin = thread_id;
shared_buffered_count[0] = 0;
shared_output_count = 0;
}
__syncthreads();
const int threshold_bin = shared_threshold_bin;
remaining_k -= shared_histogram[0][threshold_bin + 1];
// Early exit if threshold bin perfectly matches remaining_k
if (remaining_k == 0) {
for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
const int bin = convert_to_uint8(logits[idx + logits_offset]);
if (bin > threshold_bin) {
const int output_pos = ::atomicAdd(&shared_output_count, 1);
output_indices[output_pos] = idx;
}
}
__syncthreads();
return;
}
// Prepare for refinement passes: Process threshold bin
__syncthreads();
if (thread_id < RADIX + 1) {
shared_histogram[0][thread_id] = 0;
}
__syncthreads();
// Scan all elements and:
// 1. Write indices > threshold_bin to output
// 2. Buffer indices == threshold_bin for refinement
// 3. Build histogram for next refinement pass (fused optimization)
for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
const float logit_value = logits[idx + logits_offset];
const int bin = convert_to_uint8(logit_value);
if (bin > threshold_bin) {
// in top-k, write to output
const int output_pos = ::atomicAdd(&shared_output_count, 1);
output_indices[output_pos] = idx;
} else if (bin == threshold_bin) {
// Candidate for top-k, needs refinement
const int buffer_pos = ::atomicAdd(&shared_buffered_count[0], 1);
if (C10_LIKELY(buffer_pos < MAX_BUFFERED_ITEMS)) {
buffered_indices[0][buffer_pos] = idx;
// Fused: Build histogram for next pass
const uint32_t fp32_bits = convert_to_uint32_v2(logit_value);
const int next_bin = (fp32_bits >> 24) & 0xFF;
::atomicAdd(&shared_histogram[0][next_bin], 1);
}
}
void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
torch::Tensor& output, torch::Tensor& workspace, int64_t k,
int64_t max_seq_len) {
#ifndef USE_ROCM
TORCH_CHECK(logits.is_cuda(), "logits must be CUDA tensor");
TORCH_CHECK(lengths.is_cuda(), "lengths must be CUDA tensor");
TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor");
TORCH_CHECK(logits.dtype() == torch::kFloat32, "Only float32 supported");
TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");
TORCH_CHECK(output.dtype() == torch::kInt32, "output must be int32");
TORCH_CHECK(logits.dim() == 2, "logits must be 2D");
TORCH_CHECK(lengths.dim() == 1, "lengths must be 1D");
TORCH_CHECK(output.dim() == 2, "output must be 2D");
const int64_t num_rows = logits.size(0);
const int64_t stride = logits.size(1);
TORCH_CHECK(lengths.size(0) == num_rows, "lengths size mismatch");
TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k,
"output size mismatch");
namespace P = vllm::persistent;
TORCH_CHECK(k == P::TopK, "k must be 2048");
TORCH_CHECK(k <= stride, "k out of range");
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
static int num_sms = 0;
static int max_smem_per_block = 0;
if (num_sms == 0) {
int device;
cudaGetDevice(&device);
cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device);
cudaDeviceGetAttribute(&max_smem_per_block,
cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
}
__syncthreads();
// ============================================================================
// Passes 1-4: Refine using 8-bit passes over FP32 bits
// ============================================================================
// FP32 bits [31:0] split into 4 bytes processed MSB-first:
// Pass 1: bits [31:24], Pass 2: bits [23:16], Pass 3: bits [15:8], Pass 4:
// bits [7:0]
#pragma unroll 4
for (int pass = 0; pass < 4; ++pass) {
__shared__ int shared_final_k; // For final pass: remaining slots to fill
const int src_buffer = pass % 2;
const int dst_buffer = src_buffer ^ 1;
// Clamp buffered count to prevent overflow
const int raw_buffered = shared_buffered_count[src_buffer];
const int num_buffered =
(raw_buffered < MAX_BUFFERED_ITEMS) ? raw_buffered : MAX_BUFFERED_ITEMS;
compute_cumulative_sum();
// Find threshold bin for this pass
if (thread_id < RADIX && shared_histogram[0][thread_id] > remaining_k &&
shared_histogram[0][thread_id + 1] <= remaining_k) {
shared_threshold_bin = thread_id;
shared_buffered_count[dst_buffer] = 0;
shared_final_k = remaining_k - shared_histogram[0][thread_id + 1];
}
__syncthreads();
const int threshold_bin = shared_threshold_bin;
remaining_k -= shared_histogram[0][threshold_bin + 1];
// Bit offset for this pass: 24, 16, 8, 0
const int bit_offset = 24 - pass * 8;
// Early exit if threshold bin perfectly matches
if (remaining_k == 0) {
for (int i = thread_id; i < num_buffered; i += kThreadsPerBlock) {
const int idx = buffered_indices[src_buffer][i];
const uint32_t fp32_bits =
convert_to_uint32_v2(logits[idx + logits_offset]);
const int bin = (fp32_bits >> bit_offset) & 0xFF;
if (bin > threshold_bin) {
const int output_pos = ::atomicAdd(&shared_output_count, 1);
output_indices[output_pos] = idx;
}
}
__syncthreads();
break;
if (num_rows > 32 && max_smem_per_block >= 128 * 1024) {
cudaError_t status = vllm::FilteredTopKRaggedTransform<float, int32_t>(
logits.data_ptr<float>(), output.data_ptr<int32_t>(),
lengths.data_ptr<int32_t>(), static_cast<uint32_t>(num_rows),
static_cast<uint32_t>(k), static_cast<uint32_t>(stride), stream);
TORCH_CHECK(status == cudaSuccess,
"FilteredTopK failed: ", cudaGetErrorString(status));
} else {
TORCH_CHECK(workspace.is_cuda(), "workspace must be CUDA tensor");
TORCH_CHECK(workspace.dtype() == torch::kUInt8, "workspace must be uint8");
// Smem cap: smaller smem → more CTAs/group → more per-row parallelism for
// large path. Empirically tuned.
int effective_max_smem;
if (num_rows <= 4) {
effective_max_smem =
std::min(max_smem_per_block, static_cast<int>(P::kSmemMedium));
} else if (num_rows <= 8) {
constexpr int kSmemCapMedium = 48 * 1024;
effective_max_smem = std::min(max_smem_per_block, kSmemCapMedium);
} else {
effective_max_smem = max_smem_per_block;
}
// Continue refinement
__syncthreads();
if (thread_id < RADIX + 1) {
shared_histogram[0][thread_id] = 0;
size_t available_for_ordered =
static_cast<size_t>(effective_max_smem) - P::kFixedSmemLarge;
uint32_t max_chunk_elements =
static_cast<uint32_t>(available_for_ordered / sizeof(uint32_t));
uint32_t vec_size = 1;
if (stride % 4 == 0)
vec_size = 4;
else if (stride % 2 == 0)
vec_size = 2;
max_chunk_elements = (max_chunk_elements / vec_size) * vec_size;
uint32_t min_chunk = vec_size * P::kThreadsPerBlock;
if (max_chunk_elements < min_chunk) max_chunk_elements = min_chunk;
uint32_t ctas_per_group =
(static_cast<uint32_t>(stride) + max_chunk_elements - 1) /
max_chunk_elements;
uint32_t chunk_size =
(static_cast<uint32_t>(stride) + ctas_per_group - 1) / ctas_per_group;
chunk_size = ((chunk_size + vec_size - 1) / vec_size) * vec_size;
if (chunk_size > max_chunk_elements) chunk_size = max_chunk_elements;
size_t smem_size = P::kFixedSmemLarge + chunk_size * sizeof(uint32_t);
if (smem_size < P::kSmemMedium) smem_size = P::kSmemMedium;
int occupancy = 1;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy, P::persistent_topk_kernel<4>, P::kThreadsPerBlock,
smem_size);
if (occupancy < 1) occupancy = 1;
uint32_t max_resident_ctas = static_cast<uint32_t>(num_sms) * occupancy;
uint32_t num_groups = std::min(max_resident_ctas / ctas_per_group,
static_cast<uint32_t>(num_rows));
if (num_groups == 0) num_groups = 1;
uint32_t total_ctas = num_groups * ctas_per_group;
size_t state_bytes = num_groups * sizeof(P::RadixRowState);
TORCH_CHECK(workspace.size(0) >= static_cast<int64_t>(state_bytes),
"workspace too small, need ", state_bytes, " bytes");
P::PersistentTopKParams params;
params.input = logits.data_ptr<float>();
params.output = output.data_ptr<int32_t>();
params.lengths = lengths.data_ptr<int32_t>();
params.num_rows = static_cast<uint32_t>(num_rows);
params.stride = static_cast<uint32_t>(stride);
params.chunk_size = chunk_size;
params.row_states =
reinterpret_cast<P::RadixRowState*>(workspace.data_ptr<uint8_t>());
params.ctas_per_group = ctas_per_group;
params.max_seq_len = static_cast<uint32_t>(max_seq_len);
#define LAUNCH_PERSISTENT(VS) \
do { \
auto kernel = &P::persistent_topk_kernel<VS>; \
cudaError_t err = cudaFuncSetAttribute( \
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); \
TORCH_CHECK(err == cudaSuccess, \
"Failed to set smem: ", cudaGetErrorString(err)); \
kernel<<<total_ctas, P::kThreadsPerBlock, smem_size, stream>>>(params); \
} while (0)
if (vec_size == 4) {
LAUNCH_PERSISTENT(4);
} else if (vec_size == 2) {
LAUNCH_PERSISTENT(2);
} else {
LAUNCH_PERSISTENT(1);
}
__syncthreads();
for (int i = thread_id; i < num_buffered; i += kThreadsPerBlock) {
const int idx = buffered_indices[src_buffer][i];
const float logit_value = logits[idx + logits_offset];
const uint32_t fp32_bits = convert_to_uint32_v2(logit_value);
const int bin = (fp32_bits >> bit_offset) & 0xFF;
if (bin > threshold_bin) {
// Definitely in top-k
const int output_pos = ::atomicAdd(&shared_output_count, 1);
output_indices[output_pos] = idx;
} else if (bin == threshold_bin) {
if (pass == 3) {
// Final pass (bits [7:0]): No more refinement possible
// Fill remaining slots in reverse order to maintain descending order
const int slot = ::atomicAdd(&shared_final_k, -1);
if (slot > 0) {
output_indices[TopK - slot] = idx;
}
} else {
// Buffer for next pass and build next histogram
const int buffer_pos =
::atomicAdd(&shared_buffered_count[dst_buffer], 1);
if (C10_LIKELY(buffer_pos < MAX_BUFFERED_ITEMS)) {
buffered_indices[dst_buffer][buffer_pos] = idx;
// Fused: Build histogram for next pass
const int next_bit_offset = bit_offset - 8;
const int next_bin = (fp32_bits >> next_bit_offset) & 0xFF;
::atomicAdd(&shared_histogram[0][next_bin], 1);
}
}
}
}
__syncthreads();
#undef LAUNCH_PERSISTENT
}
}
__global__ __launch_bounds__(kThreadsPerBlock) void topk_kernel(
const FastTopKParams params) {
const auto& [input, row_starts, indices, lengths, input_stride] = params;
const uint64_t batch_idx = blockIdx.x;
const int logits_offset = row_starts == nullptr ? 0 : row_starts[batch_idx];
const int seq_len = lengths[batch_idx];
int* output_indices = indices + batch_idx * TopK;
const float* logits = input + batch_idx * input_stride;
if (seq_len <= TopK) {
// Shortcut: All elements are in top-k
return naive_topk_cuda(logits, output_indices, seq_len);
} else {
return fast_topk_cuda_tl(logits, output_indices, logits_offset, seq_len);
}
}
FastTopKParams get_params(
const at::Tensor& score, const at::Tensor& lengths,
std::optional<at::Tensor> row_starts_opt = std::nullopt,
std::optional<at::Tensor> indices_opt = std::nullopt) {
const int64_t batch_size = score.size(0);
TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1,
"score must be 2D with contiguous rows");
TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous() &&
lengths.size(0) == batch_size,
"lengths must be 1D contiguous with size matching batch");
const int32_t* row_starts_ptr = nullptr;
if (row_starts_opt.has_value()) {
const auto& row_starts = *row_starts_opt;
TORCH_CHECK(row_starts.dim() == 1 && row_starts.size(0) == batch_size,
"row_starts must be 1D with size matching batch");
row_starts_ptr = row_starts.data_ptr<int32_t>();
}
int32_t* indices_ptr = nullptr;
if (indices_opt.has_value()) {
const auto& indices = *indices_opt;
TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous() &&
indices.size(0) == batch_size && indices.size(1) == TopK,
"indices must be 2D contiguous [batch, TopK]");
indices_ptr = indices.data_ptr<int32_t>();
}
return FastTopKParams{
.input = score.data_ptr<float>(),
.row_starts = row_starts_ptr,
.indices = indices_ptr,
.lengths = lengths.data_ptr<int32_t>(),
.input_stride = score.stride(0),
};
}
template <auto* kernel_func, size_t smem_bytes>
void setup_kernel_smem_once() {
static const cudaError_t result = []() -> cudaError_t {
#ifdef USE_ROCM
auto func_ptr = reinterpret_cast<const void*>(kernel_func);
cudaError_t err = cudaGetLastError();
TORCH_CHECK(err == cudaSuccess,
"persistent_topk failed: ", cudaGetErrorString(err));
#else
auto func_ptr = kernel_func;
TORCH_CHECK(false, "persistent_topk is not supported on ROCm");
#endif
return cudaFuncSetAttribute(
func_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
}();
TORCH_CHECK(
result == cudaSuccess,
"Failed to set kernel shared memory limit: ", cudaGetErrorString(result));
}
} // namespace vllm
void large_context_topk(
const torch::Tensor& logits, torch::Tensor& indices,
const torch::Tensor& seq_lens,
std::optional<torch::Tensor> row_starts = std::nullopt) {
TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor");
TORCH_CHECK(indices.is_cuda(), "indices must be a CUDA tensor");
TORCH_CHECK(seq_lens.is_cuda(), "seq_lens must be a CUDA tensor");
if (row_starts.has_value()) {
TORCH_CHECK(row_starts->is_cuda(), "row_starts must be a CUDA tensor");
}
const auto params = vllm::get_params(logits, seq_lens, row_starts, indices);
const int64_t batch_size = logits.size(0);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const dim3 grid(static_cast<uint32_t>(batch_size));
const dim3 block(vllm::kThreadsPerBlock);
vllm::setup_kernel_smem_once<vllm::topk_kernel, vllm::kSmem>();
vllm::topk_kernel<<<grid, block, vllm::kSmem, stream>>>(params);
const cudaError_t result = cudaGetLastError();
TORCH_CHECK(result == cudaSuccess,
"large_context_topk kernel failed: ", cudaGetErrorString(result));
}
\ No newline at end of file
......@@ -197,10 +197,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
ops.def(
"large_context_topk(Tensor score, Tensor indices, Tensor lengths, "
"Tensor? "
"row_starts_opt) -> ()");
ops.impl("large_context_topk", torch::kCUDA, &large_context_topk);
"persistent_topk(Tensor logits, Tensor lengths, Tensor! output, "
"Tensor workspace, int k, int max_seq_len) -> ()");
ops.impl("persistent_topk", torch::kCUDA, &persistent_topk);
// Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor.
......
......@@ -122,6 +122,39 @@ def compare_top_k_results(
return True
def validate_topk_against_reference(
logits: torch.Tensor,
cuda_indices: torch.Tensor,
row_starts: torch.Tensor,
row_ends: torch.Tensor,
top_k: int,
kernel_name: str,
) -> None:
"""
Validate CUDA top-k results against PyTorch reference implementation.
Args:
logits: Input logits tensor
cuda_indices: CUDA kernel output indices
row_starts: Row start positions
row_ends: Row end positions
top_k: Number of top elements to select
kernel_name: Name of the kernel being tested (for error messages)
"""
num_rows = cuda_indices.shape[0]
torch_indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
for i in range(num_rows):
row_end = int(row_ends[i])
k_i = min(top_k, row_end)
idx = logits[i, :row_end].topk(k_i, dim=-1)[1]
torch_indices[i, :k_i] = idx
assert compare_top_k_results(
logits, cuda_indices, torch_indices, row_starts, row_ends, top_k
), f"{kernel_name} results don't match torch.topk"
@pytest.mark.parametrize("num_rows", NUM_ROWS)
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("clean_logits", [True, False])
......@@ -278,111 +311,540 @@ def test_top_k_per_row_decode_large_vocab_size(clean_logits: bool) -> None:
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@pytest.mark.parametrize(
"seq_len_range,test_id",
[
pytest.param((4000, 8000), "short_sequences", id="short"),
pytest.param((8000, 32000), "medium_sequences", id="medium"),
pytest.param((32000, 163840), "long_sequences", id="long"),
],
)
@pytest.mark.parametrize("clean_logits", [True, False])
@pytest.mark.parametrize("top_k", [2048])
@pytest.mark.parametrize("next_n", [1, 4])
@torch.inference_mode()
def test_deepseek_hybrid_topk(clean_logits: bool) -> None:
def test_deepseek_persistent_topk(
seq_len_range: tuple[int, int],
test_id: str,
clean_logits: bool,
top_k: int,
next_n: int,
) -> None:
"""
Test persistent_topk with varying sequence lengths and speculative decoding.
Supports speculative decoding with next_n > 1.
"""
set_random_seed(42 if test_id == "short_sequences" else 43)
torch.set_default_device("cuda:0")
top_k = 2048
# Test case 1: Short sequences (< 8192)
batch_size_short = 4
next_n = 1
num_rows_short = batch_size_short * next_n
batch_size = 4
num_rows = batch_size * next_n
# Create sequences with max length < 8192
seq_lens_short = torch.randint(
4000, 8000, (batch_size_short,), dtype=torch.int32, device="cuda"
seq_lens = torch.randint(
seq_len_range[0],
seq_len_range[1],
(batch_size,),
dtype=torch.int32,
device="cuda",
)
row_starts_short = torch.zeros(num_rows_short, dtype=torch.int32, device="cuda")
row_indices_short = torch.arange(num_rows_short, device="cuda") // next_n
next_n_offset_short = torch.arange(num_rows_short, device="cuda") % next_n
row_ends_short = (
seq_lens_short[row_indices_short] - next_n + next_n_offset_short + 1
# Compute row boundaries for speculative decoding
row_starts = torch.zeros(num_rows, dtype=torch.int32, device="cuda")
row_indices = torch.arange(num_rows, device="cuda") // next_n
next_n_offset = torch.arange(num_rows, device="cuda") % next_n
row_ends = seq_lens[row_indices] - next_n + next_n_offset + 1
logits = create_random_logits(
row_starts, row_ends, torch.float32, 42, clean_logits, "random"
)
logits_short = create_random_logits(
row_starts_short, row_ends_short, torch.float32, 42, clean_logits, "random"
indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
if next_n == 1:
lengths = seq_lens
else:
offsets = torch.arange(next_n, device=logits.device, dtype=torch.int32)
lengths = (seq_lens.unsqueeze(1) - next_n + 1 + offsets).flatten()
workspace = torch.empty(1024 * 1024, dtype=torch.uint8, device="cuda")
max_seq_len = int(seq_lens.max().item())
torch.ops._C.persistent_topk(
logits, lengths, indices, workspace, top_k, max_seq_len
)
indices_vllm = torch.empty(
(num_rows_short, top_k), dtype=torch.int32, device="cuda"
validate_topk_against_reference(
logits, indices, row_starts, row_ends, top_k, f"persistent_topk ({test_id})"
)
# Use vllm's kernel for short sequences
torch.ops._C.top_k_per_row_decode(
logits_short,
next_n,
seq_lens_short,
indices_vllm,
num_rows_short,
logits_short.stride(0),
logits_short.stride(1),
top_k,
def run_large_context_topk_test(
batch_size: int,
seq_lens: list[int],
top_k: int,
data_type: str = "random",
seed: int = 42,
) -> None:
"""
Helper to run persistent_topk kernel test with given parameters.
Args:
batch_size: Number of rows/sequences
seq_lens: List of sequence lengths (one per row)
top_k: Number of top elements to select
data_type: Type of test data to generate
seed: Random seed for reproducibility
"""
torch.set_default_device("cuda:0")
set_random_seed(seed)
# Create test data
num_rows = batch_size
max_len = max(seq_lens)
lengths = torch.tensor(seq_lens, dtype=torch.int32, device="cuda")
if data_type == "random":
logits = torch.randn(num_rows, max_len, dtype=torch.float32, device="cuda")
elif data_type == "sorted_asc":
# Each row gets its own ascending sequence based on its length
logits = torch.empty(num_rows, max_len, dtype=torch.float32, device="cuda")
for i, length in enumerate(seq_lens):
logits[i, :length] = torch.arange(
length, dtype=torch.float32, device="cuda"
)
if length < max_len:
logits[i, length:] = float("-inf")
elif data_type == "sorted_desc":
# Each row gets its own descending sequence based on its length
logits = torch.empty(num_rows, max_len, dtype=torch.float32, device="cuda")
for i, length in enumerate(seq_lens):
logits[i, :length] = torch.arange(
length, 0, -1, dtype=torch.float32, device="cuda"
)
if length < max_len:
logits[i, length:] = float("-inf")
elif data_type == "all_same":
logits = torch.ones(num_rows, max_len, dtype=torch.float32, device="cuda")
for i, length in enumerate(seq_lens):
if length < max_len:
logits[i, length:] = float("-inf")
elif data_type == "many_ties":
# Only 10 unique values, many duplicates
logits = torch.randint(0, 10, (num_rows, max_len), device="cuda").float() / 10.0
for i, length in enumerate(seq_lens):
if length < max_len:
logits[i, length:] = float("-inf")
elif data_type == "small_differences":
# Very small differences to test float precision
base = torch.randn(num_rows, max_len, dtype=torch.float32, device="cuda")
noise = (
torch.randn(num_rows, max_len, dtype=torch.float32, device="cuda") * 1e-6
)
logits = base + noise
for i, length in enumerate(seq_lens):
if length < max_len:
logits[i, length:] = float("-inf")
else:
raise ValueError(f"Unknown data_type: {data_type}")
# Create output tensor
indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
workspace = torch.empty(1024 * 1024, dtype=torch.uint8, device="cuda")
max_seq_len = max(seq_lens)
torch.ops._C.persistent_topk(
logits, lengths, indices, workspace, top_k, max_seq_len
)
# Test case 2: Long sequences (>= 8192) - should use large_context_topk kernel
batch_size_long = 4
num_rows_long = batch_size_long * next_n
torch.accelerator.synchronize()
torch_indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
for i in range(num_rows):
length = seq_lens[i]
k_i = min(top_k, length)
if k_i > 0:
idx = logits[i, :length].topk(k_i, dim=-1)[1]
torch_indices[i, :k_i] = idx
if k_i < top_k:
torch_indices[i, k_i:] = -1
else:
torch_indices[i, :] = -1
# Compare results
for i in range(num_rows):
length = seq_lens[i]
k_i = min(top_k, length)
if k_i == 0:
continue
cuda_row = indices[i, :k_i].cpu()
torch_row = torch_indices[i, :k_i].cpu()
# Filter out -1 padding values from cuda_row
valid_mask = cuda_row >= 0
cuda_row = cuda_row[valid_mask]
# Compare sets (order may differ for ties)
cuda_set = set(cuda_row.tolist())
torch_set = set(torch_row.tolist())
if cuda_set == torch_set:
continue
# If sets differ, check if it's due to equal values (ties)
cuda_vals = logits[i, cuda_row].cpu()
torch_vals = logits[i, torch_row].cpu()
# Check that min CUDA value >= max of values NOT in top-k
if k_i < length:
non_topk_indices = torch.tensor(
list(set(range(length)) - cuda_set), dtype=torch.int32
)
if len(non_topk_indices) > 0:
non_topk_vals = logits[i, non_topk_indices].cpu()
min_cuda_val = cuda_vals.min()
max_non_topk = non_topk_vals.max()
# Allow small tolerance for floating point errors
assert min_cuda_val >= max_non_topk - 1e-4, (
f"Row {i}: CUDA top-k contains values smaller than non-top-k. "
f"Min CUDA: {min_cuda_val}, Max non-top-k: {max_non_topk}, "
f"Length: {length}, k: {k_i}, CUDA indices: {sorted(cuda_set)[:10]}..., " # noqa: E501
f"Expected indices: {sorted(torch_set)[:10]}..."
)
# For ties, verify the values are close
assert torch.allclose(
cuda_vals.sort(descending=True)[0],
torch_vals.sort(descending=True)[0],
rtol=1e-4,
atol=1e-4,
), f"""Row {i}: Top-k values don't match.
CUDA: {cuda_vals.sort(descending=True)[0][:10]},
Torch: {torch_vals.sort(descending=True)[0][:10]}"""
# Create sequences with max length >= 8192
seq_lens_long = torch.randint(
8192, 16384, (batch_size_long,), dtype=torch.int32, device="cuda"
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@pytest.mark.parametrize(
"test_config",
[
# ==================== CATEGORY: Sequence Length Edge Cases ====================
pytest.param(
{"seq_lens": [1, 10, 100, 2048], "top_k": 2048, "data_type": "random"},
id="seq_len_edge_very_small_to_medium",
),
pytest.param(
{
"seq_lens": [2049, 2100, 2500, 3000],
"top_k": 2048,
"data_type": "random",
},
id="seq_len_edge_above_k",
),
pytest.param(
{"seq_lens": [8000, 16384, 20000], "top_k": 2048, "data_type": "random"},
id="algo_transition_filtered_radix",
),
# ==================== CATEGORY: Data Distributions ====================
pytest.param(
{"seq_lens": [5000, 10000], "top_k": 2048, "data_type": "sorted_asc"},
id="data_sorted_ascending",
),
pytest.param(
{"seq_lens": [5000, 10000], "top_k": 2048, "data_type": "sorted_desc"},
id="data_sorted_descending",
),
pytest.param(
{"seq_lens": [5000, 10000], "top_k": 2048, "data_type": "all_same"},
id="data_all_same",
),
pytest.param(
{"seq_lens": [5000, 10000], "top_k": 2048, "data_type": "many_ties"},
id="data_many_ties",
),
pytest.param(
{
"seq_lens": [5000, 10000],
"top_k": 2048,
"data_type": "small_differences",
},
id="data_float_precision",
),
# ==================== CATEGORY: Alignment / Vectorization ====================
pytest.param(
{
"seq_lens": [2055, 2056, 2057, 2063],
"top_k": 2048,
"data_type": "random",
},
id="align_vec_boundaries_low",
),
pytest.param(
{
"seq_lens": [4095, 4096, 4097, 4102],
"top_k": 2048,
"data_type": "random",
},
id="align_4k_boundary",
),
pytest.param(
{
"seq_lens": [8191, 8192, 8193, 8198],
"top_k": 2048,
"data_type": "random",
},
id="align_8k_boundary",
),
pytest.param(
{
"seq_lens": [16383, 16384, 16385, 16390],
"top_k": 2048,
"data_type": "random",
},
id="align_16k_boundary",
),
],
)
@torch.inference_mode()
def test_persistent_topk_correctness(test_config: dict) -> None:
"""
Comprehensive correctness tests covering:
- Sequence length edge cases (trivial, boundary, varied)
- Very small sequences (< 100 elements)
- Mixed sequence lengths in same batch
- Data distributions (sorted, ties, precision)
- Memory alignment / vectorization boundaries
"""
run_large_context_topk_test(
batch_size=len(test_config["seq_lens"]),
seq_lens=test_config["seq_lens"],
top_k=test_config["top_k"],
data_type=test_config.get("data_type", "random"),
)
row_starts_long = torch.zeros(num_rows_long, dtype=torch.int32, device="cuda")
row_indices_long = torch.arange(num_rows_long, device="cuda") // next_n
next_n_offset_long = torch.arange(num_rows_long, device="cuda") % next_n
row_ends_long = seq_lens_long[row_indices_long] - next_n + next_n_offset_long + 1
logits_long = create_random_logits(
row_starts_long, row_ends_long, torch.float32, 43, clean_logits, "random"
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@pytest.mark.parametrize(
"test_config",
[
# ==================== CATEGORY: Batch Size Scalability ====================
pytest.param(
{"batch_size": 1, "seq_len": 5000, "top_k": 2048},
id="batch_1",
),
pytest.param(
{"batch_size": 4, "seq_len": 5000, "top_k": 2048},
id="batch_4",
),
pytest.param(
{"batch_size": 32, "seq_len": 5000, "top_k": 2048},
id="batch_32",
),
pytest.param(
{"batch_size": 256, "seq_len": 5000, "top_k": 2048},
id="batch_256",
),
# ==================== CATEGORY: Single-CTA vs Multi-CTA ====================
pytest.param(
{"batch_size": 2, "seq_len": 4096, "top_k": 2048},
id="single_cta_4k",
),
pytest.param(
{"batch_size": 2, "seq_len": 8192, "top_k": 2048},
id="single_cta_8k",
),
pytest.param(
{"batch_size": 2, "seq_len": 163840, "top_k": 2048},
id="multi_cta_163840_dsv3_max",
),
# ==================== CATEGORY: Extreme Cases ====================
pytest.param(
{"batch_size": 512, "seq_len": 5000, "top_k": 2048},
id="extreme_large_batch",
),
pytest.param(
{"batch_size": 2, "seq_len": 163840, "top_k": 2048},
id="extreme_dsv3_max_context",
),
],
)
@torch.inference_mode()
def test_persistent_topk_algorithm_paths(test_config: dict) -> None:
"""
Test different algorithm execution paths (capped at 163840 for DeepSeek V3.2):
- Batch size scalability (1, 4, 32, 256)
- Single-CTA vs Multi-CTA execution
- Extreme configurations (large batch, max context length)
"""
run_large_context_topk_test(
batch_size=test_config["batch_size"],
seq_lens=[test_config["seq_len"]] * test_config["batch_size"],
top_k=test_config["top_k"],
)
indices = torch.empty((num_rows_long, top_k), dtype=torch.int32, device="cuda")
# Use large_context_topk kernel for long sequences
if next_n == 1:
lengths = seq_lens_long
else:
offsets = torch.arange(next_n, device=logits_long.device, dtype=torch.int32)
lengths = (seq_lens_long.unsqueeze(1) - next_n + 1 + offsets).flatten()
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@torch.inference_mode()
def test_persistent_topk_stress() -> None:
"""
Stress test with random configurations to catch edge cases.
Capped at 163840 (DeepSeek V3.2 max context) for realistic testing.
"""
torch.set_default_device("cuda:0")
top_k = 2048
torch.ops._C.large_context_topk(
logits_long,
indices,
lengths,
None,
for seed in range(3):
set_random_seed(seed)
# Random batch size (limited for speed)
batch_size = torch.randint(1, 32, (1,)).item()
# Random sequence lengths capped at DeepSeek V3.2 max context
seq_lens = torch.randint(100, 163840, (batch_size,)).tolist()
run_large_context_topk_test(
batch_size=batch_size,
seq_lens=seq_lens,
top_k=top_k,
seed=seed,
)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@pytest.mark.parametrize(
"test_config",
[
# Mixed batch: rows spanning all four paths (trivial, decode, medium, large)
pytest.param(
{
"seq_lens": [2000, 6000, 30000, 80000],
"top_k": 2048,
"data_type": "random",
},
id="mixed_all_paths",
),
# All decode/medium rows (typical decode scenario)
pytest.param(
{
"seq_lens": [2048, 4096, 8192, 16000],
"top_k": 2048,
"data_type": "random",
},
id="all_decode_medium",
),
# All large rows
pytest.param(
{
"seq_lens": [70000, 100000, 163840],
"top_k": 2048,
"data_type": "random",
},
id="all_large",
),
# Boundary around LARGE_THRESHOLD (32K)
pytest.param(
{
"seq_lens": [32767, 32768, 32769, 32772],
"top_k": 2048,
"data_type": "random",
},
id="large_threshold_boundary",
),
# Single row medium
pytest.param(
{
"seq_lens": [5000],
"top_k": 2048,
"data_type": "random",
},
id="single_row_medium",
),
# Single row large
pytest.param(
{
"seq_lens": [100000],
"top_k": 2048,
"data_type": "random",
},
id="single_row_large",
),
# Trivial rows mixed with medium and large
pytest.param(
{
"seq_lens": [100, 2048, 10000, 80000],
"top_k": 2048,
"data_type": "random",
},
id="trivial_medium_large_mix",
),
],
)
@torch.inference_mode()
def test_persistent_topk(test_config: dict) -> None:
"""
Tests specific to the persistent_topk kernel:
- Mixed medium/large rows in the same batch (dynamic per-row dispatch)
- Boundary around LARGE_THRESHOLD (32K)
- Trivial + medium + large rows in a single batch
"""
run_large_context_topk_test(
batch_size=len(test_config["seq_lens"]),
seq_lens=test_config["seq_lens"],
top_k=test_config["top_k"],
data_type=test_config.get("data_type", "random"),
)
torch_indices_short = torch.empty(
(num_rows_short, top_k), dtype=torch.int32, device="cuda"
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@torch.inference_mode()
def test_persistent_topk_padded_stride() -> None:
"""
Test persistent_topk with padded logits (large stride, small seq_len)
to simulate the e2e CUDAGraph scenario where fp8_paged_mqa_logits
returns [B, max_model_len] with max_model_len=163840.
"""
set_random_seed(42)
torch.set_default_device("cuda:0")
top_k = 2048
batch_size = 4
padded_stride = 163840 # DeepSeek-V3.2 max_model_len
actual_seq_lens = [3000, 5000, 8000, 12000]
# Create padded logits tensor (like fp8_paged_mqa_logits output)
logits = torch.full(
(batch_size, padded_stride),
float("-inf"),
dtype=torch.float32,
device="cuda",
)
for i in range(num_rows_short):
row_end = int(row_ends_short[i])
k_i = min(top_k, row_end)
idx = logits_short[i, :row_end].topk(k_i, dim=-1)[1]
torch_indices_short[i, :k_i] = idx
for i, sl in enumerate(actual_seq_lens):
logits[i, :sl] = torch.randn(sl, dtype=torch.float32, device="cuda")
assert compare_top_k_results(
logits_short,
indices_vllm,
torch_indices_short,
row_starts_short,
row_ends_short,
top_k,
), "top_k_per_row_decode kernel (short sequences) doesn't match torch.topk"
lengths = torch.tensor(actual_seq_lens, dtype=torch.int32, device="cuda")
indices = torch.empty((batch_size, top_k), dtype=torch.int32, device="cuda")
workspace = torch.empty(1024 * 1024, dtype=torch.uint8, device="cuda")
torch_indices_long = torch.empty(
(num_rows_long, top_k), dtype=torch.int32, device="cuda"
torch.ops._C.persistent_topk(
logits, lengths, indices, workspace, top_k, max(actual_seq_lens)
)
for i in range(num_rows_long):
row_end = int(row_ends_long[i])
k_i = min(top_k, row_end)
idx = logits_long[i, :row_end].topk(k_i, dim=-1)[1]
torch_indices_long[i, :k_i] = idx
torch.accelerator.synchronize()
assert compare_top_k_results(
logits_long, indices, torch_indices_long, row_starts_long, row_ends_long, top_k
), "large_context_topk kernel (long sequences) doesn't match torch.topk"
# Validate against torch.topk
for i in range(batch_size):
sl = actual_seq_lens[i]
k_i = min(top_k, sl)
expected = logits[i, :sl].topk(k_i, dim=-1)[1].cpu()
actual = indices[i, :k_i].cpu()
expected_set = set(expected.tolist())
actual_set = set(actual.tolist())
if expected_set != actual_set:
# Allow ties
expected_vals = logits[i, expected].cpu().sort(descending=True)[0]
actual_vals = logits[i, actual].cpu().sort(descending=True)[0]
assert torch.allclose(expected_vals, actual_vals, rtol=1e-4, atol=1e-4), (
f"Row {i}: persistent_topk with padded stride doesn't match. "
f"seq_len={sl}, stride={padded_stride}"
)
......@@ -25,6 +25,8 @@ elif current_platform.is_xpu():
logger = init_logger(__name__)
RADIX_TOPK_WORKSPACE_SIZE = 1024 * 1024
def sparse_attn_indexer(
hidden_states: torch.Tensor,
......@@ -51,6 +53,7 @@ def sparse_attn_indexer(
current_workspace_manager().get_simultaneous(
((total_seq_lens, head_dim), torch.float8_e4m3fn),
((total_seq_lens, 4), torch.uint8),
((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8),
)
# Dummy allocation to simulate for peak logits tensor memory during inference.
......@@ -157,15 +160,6 @@ def sparse_attn_indexer(
topk_tokens,
)
# Compute lengths from row spans
# lengths = (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks).to(torch.int32)
# torch.ops._C.large_context_topk(
# logits,
# topk_indices,
# lengths,
# chunk.cu_seqlen_ks, # row_starts
# )
if has_decode:
decode_metadata = attn_metadata.decode
assert decode_metadata is not None
......@@ -204,23 +198,29 @@ def sparse_attn_indexer(
num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
if decode_metadata.use_large_context_topk:
if next_n == 1:
lengths = decode_metadata.seq_lens
else:
# (bs,) -> (bs, 1) + (next_n,) -> (bs, next_n) -> (bs * next_n,)
lengths = (
decode_metadata.seq_lens.unsqueeze(1)
- next_n
+ 1
+ decode_metadata.offsets
).flatten()
torch.ops._C.large_context_topk(
if next_n == 1:
lengths = decode_metadata.seq_lens
else:
# (bs,) -> (bs, 1) + (next_n,) -> (bs, next_n) -> (bs * next_n,)
lengths = (
decode_metadata.seq_lens.unsqueeze(1)
- next_n
+ 1
+ decode_metadata.offsets
).flatten()
if current_platform.is_cuda():
workspace_manager = current_workspace_manager()
(topk_workspace,) = workspace_manager.get_simultaneous(
((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8),
)
torch.ops._C.persistent_topk(
logits,
topk_indices,
lengths,
None,
topk_indices,
topk_workspace,
topk_tokens,
attn_metadata.max_seq_len,
)
else:
if current_platform.is_xpu():
......
......@@ -67,7 +67,9 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer
from vllm.model_executor.layers.sparse_attn_indexer import (
SparseAttnIndexer,
)
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
......@@ -1203,7 +1205,9 @@ class DeepseekV2Model(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: DeepseekV2DecoderLayer(
vllm_config, prefix, topk_indices_buffer=topk_indices_buffer
vllm_config,
prefix,
topk_indices_buffer=topk_indices_buffer,
),
prefix=f"{prefix}.layers",
)
......
......@@ -145,7 +145,6 @@ class DeepSeekV32IndexerDecodeMetadata:
decode_lens: torch.Tensor
requires_padding: bool
schedule_metadata: torch.Tensor
use_large_context_topk: bool
offsets: torch.Tensor | None # Precomputed offsets for speculative decoding
......@@ -437,7 +436,6 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
if use_native and next_n > 1:
offsets = self.offsets_buffer
batch_size = num_decodes
elif max_decode_len > 1:
# Flatten multi-token decode requests into single-token
# batch entries, expanding seq_lens and block tables so
......@@ -496,10 +494,8 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
self.decode_lens_buffer[:num_decode_tokens] = 1
decode_lens = self.decode_lens_buffer[:num_decode_tokens]
offsets = None
batch_size = num_decode_tokens
else:
offsets = None
batch_size = num_decodes
# DeepGEMM is required for the paged MQA logits on CUDA devices
if current_platform.is_cuda() and has_deep_gemm():
......@@ -509,20 +505,12 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
self.num_sms,
)
# Decide which top-k kernel to use based on batch size and sequence length
# Decision logic based on micro-benchmark results:
# - large_context_topk wins for batch <= 128 and seq_len > 8K
# - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
_is_large_context = common_attn_metadata.max_seq_len > 8192
use_large_context_topk = batch_size <= 128 and _is_large_context
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=block_table,
seq_lens=seq_lens,
decode_lens=decode_lens,
requires_padding=False,
schedule_metadata=self.scheduler_metadata_buffer,
use_large_context_topk=use_large_context_topk,
offsets=offsets,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment