Unverified Commit afdce12c authored by Roberto L. Castro's avatar Roberto L. Castro Committed by GitHub
Browse files

[Perf][Kernel] Add faster topKperRow decode kernel for DeepSeek-V3.2 sparse attention (#33680)


Signed-off-by: default avatarLopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: default avatarRoberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
Co-authored-by: default avatarClaude Sonnet 4.5 <noreply@anthropic.com>
parent 82e11973
......@@ -293,6 +293,7 @@ set(VLLM_EXT_SRC
"csrc/fused_qknorm_rope_kernel.cu"
"csrc/layernorm_quant_kernels.cu"
"csrc/sampler.cu"
"csrc/topk.cu"
"csrc/cuda_view.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/w8a8/int8/scaled_quant.cu"
......
......@@ -114,6 +114,10 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
int64_t numRows, int64_t stride0, int64_t stride1,
int64_t topK);
void large_context_topk(const torch::Tensor& score, torch::Tensor& indices,
const torch::Tensor& lengths,
std::optional<torch::Tensor> row_starts_opt);
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& weight, torch::Tensor& scale,
double epsilon);
......
// Portions of this file are adapted from SGLang PR:
// https://github.com/sgl-project/sglang/pull/11194
// and
// https://github.com/sgl-project/sglang/pull/17747
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#ifndef USE_ROCM
#include <cub/cub.cuh>
#else
#include <hipcub/hipcub.hpp>
#endif
namespace vllm {
constexpr int TopK = 2048; // DeepSeek V3 sparse attention top-k
constexpr int kThreadsPerBlock = 1024; // Threads per block
// Shared memory budget
#if defined(USE_ROCM)
constexpr size_t kSmem = 48 * 1024; // ROCm default: 48KB
#else
// Reduced from 128KB to 32KB to improve occupancy.
// Each radix pass needs at most ~TopK candidates in the threshold bin,
// so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient.
constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes)
#endif
struct FastTopKParams {
const float* __restrict__ input; // [batch, seq_len] Logits
const int32_t* __restrict__ row_starts; // [batch] Offset into each row
// (optional)
int32_t* __restrict__ indices; // [batch, TopK] Output top-k indices
int32_t* __restrict__ lengths; // [batch] Sequence lengths per row
int64_t input_stride; // Stride between rows
};
__device__ __forceinline__ auto convert_to_uint32_v2(float x) -> uint32_t {
uint32_t bits = __float_as_uint(x);
return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
}
__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t {
__half h = __float2half_rn(x);
uint16_t bits = __half_as_ushort(h);
uint16_t key = (bits & 0x8000) ? static_cast<uint16_t>(~bits)
: static_cast<uint16_t>(bits | 0x8000);
return static_cast<uint8_t>(key >> 8);
}
__device__ void naive_topk_cuda(const float* __restrict__ logits,
int32_t* __restrict__ output_indices,
int32_t seq_len) {
const int thread_id = threadIdx.x;
for (int i = thread_id; i < TopK; i += kThreadsPerBlock) {
output_indices[i] = (i < seq_len) ? i : -1;
}
}
// Adapted from:
// https://github.com/sgl-project/sglang/blob/v0.5.8/sgl-kernel/csrc/elementwise/topk.cu#L87
// by: DarkSharpness
// which at the same time is an optimized topk kernel copied from tilelang
// kernel
__device__ void fast_topk_cuda_tl(
const float* __restrict__ logits, // Input logits [seq_len]
int* __restrict__ output_indices, // Output top-k indices [TopK]
int logits_offset, // Starting offset in logits array
int seq_len) // Number of valid logits to process
{
constexpr int RADIX = 256;
constexpr int MAX_BUFFERED_ITEMS = kSmem / (2 * sizeof(int));
alignas(128) __shared__ int shared_histogram[2][RADIX + 128];
alignas(128) __shared__ int shared_output_count;
alignas(128) __shared__ int shared_threshold_bin;
alignas(128) __shared__ int shared_buffered_count[2];
extern __shared__ int buffered_indices[][MAX_BUFFERED_ITEMS];
const int thread_id = threadIdx.x;
int remaining_k = TopK;
// Pass 0: Build coarse 8-bit histogram using FP16 high bits
if (thread_id < RADIX + 1) {
shared_histogram[0][thread_id] = 0;
}
__syncthreads();
for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
const auto bin = convert_to_uint8(logits[idx + logits_offset]);
::atomicAdd(&shared_histogram[0][bin], 1);
}
__syncthreads();
// Helper: Compute cumulative sum (suffix sum) over histogram using ping-pong
// buffers
auto compute_cumulative_sum = [&]() {
static_assert(1 << 8 == RADIX,
"Radix must be 256 for 8 unrolled iterations");
#pragma unroll 8
for (int i = 0; i < 8; ++i) {
if (C10_LIKELY(thread_id < RADIX)) {
const int stride = 1 << i;
const int src_buffer = i & 1;
const int dst_buffer = src_buffer ^ 1;
int value = shared_histogram[src_buffer][thread_id];
if (thread_id < RADIX - stride) {
value += shared_histogram[src_buffer][thread_id + stride];
}
shared_histogram[dst_buffer][thread_id] = value;
}
__syncthreads();
}
};
compute_cumulative_sum();
// Find threshold bin where cumsum crosses remaining_k
if (thread_id < RADIX && shared_histogram[0][thread_id] > remaining_k &&
shared_histogram[0][thread_id + 1] <= remaining_k) {
shared_threshold_bin = thread_id;
shared_buffered_count[0] = 0;
shared_output_count = 0;
}
__syncthreads();
const int threshold_bin = shared_threshold_bin;
remaining_k -= shared_histogram[0][threshold_bin + 1];
// Early exit if threshold bin perfectly matches remaining_k
if (remaining_k == 0) {
for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
const int bin = convert_to_uint8(logits[idx + logits_offset]);
if (bin > threshold_bin) {
const int output_pos = ::atomicAdd(&shared_output_count, 1);
output_indices[output_pos] = idx;
}
}
__syncthreads();
return;
}
// Prepare for refinement passes: Process threshold bin
__syncthreads();
if (thread_id < RADIX + 1) {
shared_histogram[0][thread_id] = 0;
}
__syncthreads();
// Scan all elements and:
// 1. Write indices > threshold_bin to output
// 2. Buffer indices == threshold_bin for refinement
// 3. Build histogram for next refinement pass (fused optimization)
for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
const float logit_value = logits[idx + logits_offset];
const int bin = convert_to_uint8(logit_value);
if (bin > threshold_bin) {
// in top-k, write to output
const int output_pos = ::atomicAdd(&shared_output_count, 1);
output_indices[output_pos] = idx;
} else if (bin == threshold_bin) {
// Candidate for top-k, needs refinement
const int buffer_pos = ::atomicAdd(&shared_buffered_count[0], 1);
if (C10_LIKELY(buffer_pos < MAX_BUFFERED_ITEMS)) {
buffered_indices[0][buffer_pos] = idx;
// Fused: Build histogram for next pass
const uint32_t fp32_bits = convert_to_uint32_v2(logit_value);
const int next_bin = (fp32_bits >> 24) & 0xFF;
::atomicAdd(&shared_histogram[0][next_bin], 1);
}
}
}
__syncthreads();
// ============================================================================
// Passes 1-4: Refine using 8-bit passes over FP32 bits
// ============================================================================
// FP32 bits [31:0] split into 4 bytes processed MSB-first:
// Pass 1: bits [31:24], Pass 2: bits [23:16], Pass 3: bits [15:8], Pass 4:
// bits [7:0]
#pragma unroll 4
for (int pass = 0; pass < 4; ++pass) {
__shared__ int shared_final_k; // For final pass: remaining slots to fill
const int src_buffer = pass % 2;
const int dst_buffer = src_buffer ^ 1;
// Clamp buffered count to prevent overflow
const int raw_buffered = shared_buffered_count[src_buffer];
const int num_buffered =
(raw_buffered < MAX_BUFFERED_ITEMS) ? raw_buffered : MAX_BUFFERED_ITEMS;
compute_cumulative_sum();
// Find threshold bin for this pass
if (thread_id < RADIX && shared_histogram[0][thread_id] > remaining_k &&
shared_histogram[0][thread_id + 1] <= remaining_k) {
shared_threshold_bin = thread_id;
shared_buffered_count[dst_buffer] = 0;
shared_final_k = remaining_k - shared_histogram[0][thread_id + 1];
}
__syncthreads();
const int threshold_bin = shared_threshold_bin;
remaining_k -= shared_histogram[0][threshold_bin + 1];
// Bit offset for this pass: 24, 16, 8, 0
const int bit_offset = 24 - pass * 8;
// Early exit if threshold bin perfectly matches
if (remaining_k == 0) {
for (int i = thread_id; i < num_buffered; i += kThreadsPerBlock) {
const int idx = buffered_indices[src_buffer][i];
const uint32_t fp32_bits =
convert_to_uint32_v2(logits[idx + logits_offset]);
const int bin = (fp32_bits >> bit_offset) & 0xFF;
if (bin > threshold_bin) {
const int output_pos = ::atomicAdd(&shared_output_count, 1);
output_indices[output_pos] = idx;
}
}
__syncthreads();
break;
}
// Continue refinement
__syncthreads();
if (thread_id < RADIX + 1) {
shared_histogram[0][thread_id] = 0;
}
__syncthreads();
for (int i = thread_id; i < num_buffered; i += kThreadsPerBlock) {
const int idx = buffered_indices[src_buffer][i];
const float logit_value = logits[idx + logits_offset];
const uint32_t fp32_bits = convert_to_uint32_v2(logit_value);
const int bin = (fp32_bits >> bit_offset) & 0xFF;
if (bin > threshold_bin) {
// Definitely in top-k
const int output_pos = ::atomicAdd(&shared_output_count, 1);
output_indices[output_pos] = idx;
} else if (bin == threshold_bin) {
if (pass == 3) {
// Final pass (bits [7:0]): No more refinement possible
// Fill remaining slots in reverse order to maintain descending order
const int slot = ::atomicAdd(&shared_final_k, -1);
if (slot > 0) {
output_indices[TopK - slot] = idx;
}
} else {
// Buffer for next pass and build next histogram
const int buffer_pos =
::atomicAdd(&shared_buffered_count[dst_buffer], 1);
if (C10_LIKELY(buffer_pos < MAX_BUFFERED_ITEMS)) {
buffered_indices[dst_buffer][buffer_pos] = idx;
// Fused: Build histogram for next pass
const int next_bit_offset = bit_offset - 8;
const int next_bin = (fp32_bits >> next_bit_offset) & 0xFF;
::atomicAdd(&shared_histogram[0][next_bin], 1);
}
}
}
}
__syncthreads();
}
}
__global__ __launch_bounds__(kThreadsPerBlock) void topk_kernel(
const FastTopKParams params) {
const auto& [input, row_starts, indices, lengths, input_stride] = params;
const uint64_t batch_idx = blockIdx.x;
const int logits_offset = row_starts == nullptr ? 0 : row_starts[batch_idx];
const int seq_len = lengths[batch_idx];
int* output_indices = indices + batch_idx * TopK;
const float* logits = input + batch_idx * input_stride;
if (seq_len <= TopK) {
// Shortcut: All elements are in top-k
return naive_topk_cuda(logits, output_indices, seq_len);
} else {
return fast_topk_cuda_tl(logits, output_indices, logits_offset, seq_len);
}
}
FastTopKParams get_params(
const at::Tensor& score, const at::Tensor& lengths,
std::optional<at::Tensor> row_starts_opt = std::nullopt,
std::optional<at::Tensor> indices_opt = std::nullopt) {
const int64_t batch_size = score.size(0);
TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1,
"score must be 2D with contiguous rows");
TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous() &&
lengths.size(0) == batch_size,
"lengths must be 1D contiguous with size matching batch");
const int32_t* row_starts_ptr = nullptr;
if (row_starts_opt.has_value()) {
const auto& row_starts = *row_starts_opt;
TORCH_CHECK(row_starts.dim() == 1 && row_starts.size(0) == batch_size,
"row_starts must be 1D with size matching batch");
row_starts_ptr = row_starts.data_ptr<int32_t>();
}
int32_t* indices_ptr = nullptr;
if (indices_opt.has_value()) {
const auto& indices = *indices_opt;
TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous() &&
indices.size(0) == batch_size && indices.size(1) == TopK,
"indices must be 2D contiguous [batch, TopK]");
indices_ptr = indices.data_ptr<int32_t>();
}
return FastTopKParams{
.input = score.data_ptr<float>(),
.row_starts = row_starts_ptr,
.indices = indices_ptr,
.lengths = lengths.data_ptr<int32_t>(),
.input_stride = score.stride(0),
};
}
template <auto* kernel_func, size_t smem_bytes>
void setup_kernel_smem_once() {
static const cudaError_t result = []() -> cudaError_t {
#ifdef USE_ROCM
auto func_ptr = reinterpret_cast<const void*>(kernel_func);
#else
auto func_ptr = kernel_func;
#endif
return cudaFuncSetAttribute(
func_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
}();
TORCH_CHECK(
result == cudaSuccess,
"Failed to set kernel shared memory limit: ", cudaGetErrorString(result));
}
} // namespace vllm
void large_context_topk(
const torch::Tensor& logits, torch::Tensor& indices,
const torch::Tensor& seq_lens,
c10::optional<torch::Tensor> row_starts = c10::nullopt) {
TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor");
TORCH_CHECK(indices.is_cuda(), "indices must be a CUDA tensor");
TORCH_CHECK(seq_lens.is_cuda(), "seq_lens must be a CUDA tensor");
if (row_starts.has_value()) {
TORCH_CHECK(row_starts->is_cuda(), "row_starts must be a CUDA tensor");
}
const auto params = vllm::get_params(logits, seq_lens, row_starts, indices);
const int64_t batch_size = logits.size(0);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const dim3 grid(static_cast<uint32_t>(batch_size));
const dim3 block(vllm::kThreadsPerBlock);
vllm::setup_kernel_smem_once<vllm::topk_kernel, vllm::kSmem>();
vllm::topk_kernel<<<grid, block, vllm::kSmem, stream>>>(params);
const cudaError_t result = cudaGetLastError();
TORCH_CHECK(result == cudaSuccess,
"large_context_topk kernel failed: ", cudaGetErrorString(result));
}
\ No newline at end of file
......@@ -190,6 +190,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"int numRows, int stride0, int stride1, int topK) -> ()");
ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
ops.def(
"large_context_topk(Tensor score, Tensor indices, Tensor lengths, "
"Tensor? "
"row_starts_opt) -> ()");
ops.impl("large_context_topk", torch::kCUDA, &large_context_topk);
// Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
......
......@@ -275,3 +275,114 @@ def test_top_k_per_row_decode_large_vocab_size(clean_logits: bool) -> None:
_run_top_k_per_row_decode_test(
top_k, batch_size, next_n, vocab_size, clean_logits, data_generation
)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@pytest.mark.parametrize("clean_logits", [True, False])
@torch.inference_mode()
def test_deepseek_hybrid_topk(clean_logits: bool) -> None:
torch.set_default_device("cuda:0")
top_k = 2048
# Test case 1: Short sequences (< 8192)
batch_size_short = 4
next_n = 1
num_rows_short = batch_size_short * next_n
# Create sequences with max length < 8192
seq_lens_short = torch.randint(
4000, 8000, (batch_size_short,), dtype=torch.int32, device="cuda"
)
row_starts_short = torch.zeros(num_rows_short, dtype=torch.int32, device="cuda")
row_indices_short = torch.arange(num_rows_short, device="cuda") // next_n
next_n_offset_short = torch.arange(num_rows_short, device="cuda") % next_n
row_ends_short = (
seq_lens_short[row_indices_short] - next_n + next_n_offset_short + 1
)
logits_short = create_random_logits(
row_starts_short, row_ends_short, torch.float32, 42, clean_logits, "random"
)
indices_vllm = torch.empty(
(num_rows_short, top_k), dtype=torch.int32, device="cuda"
)
# Use vllm's kernel for short sequences
torch.ops._C.top_k_per_row_decode(
logits_short,
next_n,
seq_lens_short,
indices_vllm,
num_rows_short,
logits_short.stride(0),
logits_short.stride(1),
top_k,
)
# Test case 2: Long sequences (>= 8192) - should use large_context_topk kernel
batch_size_long = 4
num_rows_long = batch_size_long * next_n
# Create sequences with max length >= 8192
seq_lens_long = torch.randint(
8192, 16384, (batch_size_long,), dtype=torch.int32, device="cuda"
)
row_starts_long = torch.zeros(num_rows_long, dtype=torch.int32, device="cuda")
row_indices_long = torch.arange(num_rows_long, device="cuda") // next_n
next_n_offset_long = torch.arange(num_rows_long, device="cuda") % next_n
row_ends_long = seq_lens_long[row_indices_long] - next_n + next_n_offset_long + 1
logits_long = create_random_logits(
row_starts_long, row_ends_long, torch.float32, 43, clean_logits, "random"
)
indices = torch.empty((num_rows_long, top_k), dtype=torch.int32, device="cuda")
# Use large_context_topk kernel for long sequences
if next_n == 1:
lengths = seq_lens_long
else:
offsets = torch.arange(next_n, device=logits_long.device, dtype=torch.int32)
lengths = (seq_lens_long.unsqueeze(1) - next_n + 1 + offsets).flatten()
torch.ops._C.large_context_topk(
logits_long,
indices,
lengths,
None,
)
torch_indices_short = torch.empty(
(num_rows_short, top_k), dtype=torch.int32, device="cuda"
)
for i in range(num_rows_short):
row_end = int(row_ends_short[i])
k_i = min(top_k, row_end)
idx = logits_short[i, :row_end].topk(k_i, dim=-1)[1]
torch_indices_short[i, :k_i] = idx
assert compare_top_k_results(
logits_short,
indices_vllm,
torch_indices_short,
row_starts_short,
row_ends_short,
top_k,
), "top_k_per_row_decode kernel (short sequences) doesn't match torch.topk"
torch_indices_long = torch.empty(
(num_rows_long, top_k), dtype=torch.int32, device="cuda"
)
for i in range(num_rows_long):
row_end = int(row_ends_long[i])
k_i = min(top_k, row_end)
idx = logits_long[i, :row_end].topk(k_i, dim=-1)[1]
torch_indices_long[i, :k_i] = idx
assert compare_top_k_results(
logits_long, indices, torch_indices_long, row_starts_long, row_ends_long, top_k
), "large_context_topk kernel (long sequences) doesn't match torch.topk"
......@@ -126,6 +126,15 @@ def sparse_attn_indexer(
topk_tokens,
)
# Compute lengths from row spans
# lengths = (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks).to(torch.int32)
# torch.ops._C.large_context_topk(
# logits,
# topk_indices,
# lengths,
# chunk.cu_seqlen_ks, # row_starts
# )
if has_decode:
decode_metadata = attn_metadata.decode
# kv_cache size requirement [num_block, block_size, n_head, head_dim],
......@@ -162,8 +171,27 @@ def sparse_attn_indexer(
)
num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
if decode_metadata.use_large_context_topk:
if next_n == 1:
lengths = decode_metadata.seq_lens
else:
# (bs,) -> (bs, 1) + (next_n,) -> (bs, next_n) -> (bs * next_n,)
lengths = (
decode_metadata.seq_lens.unsqueeze(1)
- next_n
+ 1
+ decode_metadata.offsets
).flatten()
torch.ops._C.large_context_topk(
logits,
topk_indices,
lengths,
None,
)
else:
torch.ops._C.top_k_per_row_decode(
logits,
next_n,
......
......@@ -86,6 +86,8 @@ class DeepSeekV32IndexerDecodeMetadata:
decode_lens: torch.Tensor
requires_padding: bool
schedule_metadata: torch.Tensor
use_large_context_topk: bool
offsets: torch.Tensor | None # Precomputed offsets for speculative decoding
@dataclass
......@@ -320,6 +322,21 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
# Use CPU to avoid GPU sync; breaking async scheduling
requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
# Decide which top-k kernel to use based on batch size and sequence length
batch_size = num_decodes
_is_large_context = common_attn_metadata.max_seq_len > 8192
# Decision logic based on micro-benchmark results:
# - large_context_topk wins for batch <= 128 and seq_len > 8K
# - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
use_large_context_topk = batch_size <= 128 and _is_large_context
next_n = 1 + self.num_speculative_tokens
if next_n > 1:
offsets = torch.arange(next_n, device=self.device, dtype=torch.int32)
else:
offsets = None
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
if is_deep_gemm_supported():
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
......@@ -331,6 +348,8 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
decode_lens=decode_lens,
requires_padding=requires_padding,
schedule_metadata=self.scheduler_metadata_buffer,
use_large_context_topk=use_large_context_topk,
offsets=offsets,
)
attn_metadata = DeepseekV32IndexerMetadata(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment