Unverified commit b55d830e, authored by Roberto L. Castro and committed by GitHub
Browse files

[Perf][Kernel] Persistent TopK scheduler: unified CUDAGraph-safe kernel with...


[Perf][Kernel] Persistent TopK scheduler: unified CUDAGraph-safe kernel with dynamic per-row dispatch - DeepSeek-V3.2 DSA decode (#37421)
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
parent 75e01a39
...@@ -18,10 +18,9 @@ steps: ...@@ -18,10 +18,9 @@ steps:
source_file_dependencies: source_file_dependencies:
- csrc/ - csrc/
- tests/kernels/core - tests/kernels/core
- tests/kernels/test_top_k_per_row.py
- tests/kernels/test_concat_mla_q.py - tests/kernels/test_concat_mla_q.py
commands: commands:
- pytest -v -s kernels/core kernels/test_top_k_per_row.py kernels/test_concat_mla_q.py - pytest -v -s kernels/core kernels/test_concat_mla_q.py
- label: Kernels Attention Test %N - label: Kernels Attention Test %N
timeout_in_minutes: 35 timeout_in_minutes: 35
...@@ -107,6 +106,7 @@ steps: ...@@ -107,6 +106,7 @@ steps:
- vllm/v1/attention/backends/mla/flashinfer_mla.py - vllm/v1/attention/backends/mla/flashinfer_mla.py
- vllm/v1/attention/selector.py - vllm/v1/attention/selector.py
- vllm/platforms/cuda.py - vllm/platforms/cuda.py
- tests/kernels/test_top_k_per_row.py
commands: commands:
- nvidia-smi - nvidia-smi
- python3 examples/basic/offline_inference/chat.py - python3 examples/basic/offline_inference/chat.py
...@@ -117,6 +117,7 @@ steps: ...@@ -117,6 +117,7 @@ steps:
- pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py
- pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py - pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py
- pytest -v -s tests/kernels/attention/test_flashinfer_mla_decode.py - pytest -v -s tests/kernels/attention/test_flashinfer_mla_decode.py
- pytest -v -s tests/kernels/test_top_k_per_row.py
# Quantization # Quantization
- pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8' - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
- pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
......
...@@ -114,9 +114,9 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n, ...@@ -114,9 +114,9 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
int64_t numRows, int64_t stride0, int64_t stride1, int64_t numRows, int64_t stride0, int64_t stride1,
int64_t topK); int64_t topK);
void large_context_topk(const torch::Tensor& score, torch::Tensor& indices, void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
const torch::Tensor& lengths, torch::Tensor& output, torch::Tensor& workspace, int64_t k,
std::optional<torch::Tensor> row_starts_opt); int64_t max_seq_len);
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& weight, torch::Tensor& scale, torch::Tensor& weight, torch::Tensor& scale,
......
This diff is collapsed.
This diff is collapsed.
...@@ -197,10 +197,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -197,10 +197,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode); ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
ops.def( ops.def(
"large_context_topk(Tensor score, Tensor indices, Tensor lengths, " "persistent_topk(Tensor logits, Tensor lengths, Tensor! output, "
"Tensor? " "Tensor workspace, int k, int max_seq_len) -> ()");
"row_starts_opt) -> ()"); ops.impl("persistent_topk", torch::kCUDA, &persistent_topk);
ops.impl("large_context_topk", torch::kCUDA, &large_context_topk);
// Layernorm-quant // Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor. // Apply Root Mean Square (RMS) Normalization to the input tensor.
......
This diff is collapsed.
...@@ -25,6 +25,8 @@ elif current_platform.is_xpu(): ...@@ -25,6 +25,8 @@ elif current_platform.is_xpu():
logger = init_logger(__name__) logger = init_logger(__name__)
RADIX_TOPK_WORKSPACE_SIZE = 1024 * 1024
def sparse_attn_indexer( def sparse_attn_indexer(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -51,6 +53,7 @@ def sparse_attn_indexer( ...@@ -51,6 +53,7 @@ def sparse_attn_indexer(
current_workspace_manager().get_simultaneous( current_workspace_manager().get_simultaneous(
((total_seq_lens, head_dim), torch.float8_e4m3fn), ((total_seq_lens, head_dim), torch.float8_e4m3fn),
((total_seq_lens, 4), torch.uint8), ((total_seq_lens, 4), torch.uint8),
((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8),
) )
# Dummy allocation to simulate for peak logits tensor memory during inference. # Dummy allocation to simulate for peak logits tensor memory during inference.
...@@ -157,15 +160,6 @@ def sparse_attn_indexer( ...@@ -157,15 +160,6 @@ def sparse_attn_indexer(
topk_tokens, topk_tokens,
) )
# Compute lengths from row spans
# lengths = (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks).to(torch.int32)
# torch.ops._C.large_context_topk(
# logits,
# topk_indices,
# lengths,
# chunk.cu_seqlen_ks, # row_starts
# )
if has_decode: if has_decode:
decode_metadata = attn_metadata.decode decode_metadata = attn_metadata.decode
assert decode_metadata is not None assert decode_metadata is not None
...@@ -204,7 +198,6 @@ def sparse_attn_indexer( ...@@ -204,7 +198,6 @@ def sparse_attn_indexer(
num_rows = logits.shape[0] num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens] topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
if decode_metadata.use_large_context_topk:
if next_n == 1: if next_n == 1:
lengths = decode_metadata.seq_lens lengths = decode_metadata.seq_lens
else: else:
...@@ -216,11 +209,18 @@ def sparse_attn_indexer( ...@@ -216,11 +209,18 @@ def sparse_attn_indexer(
+ decode_metadata.offsets + decode_metadata.offsets
).flatten() ).flatten()
torch.ops._C.large_context_topk( if current_platform.is_cuda():
workspace_manager = current_workspace_manager()
(topk_workspace,) = workspace_manager.get_simultaneous(
((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8),
)
torch.ops._C.persistent_topk(
logits, logits,
topk_indices,
lengths, lengths,
None, topk_indices,
topk_workspace,
topk_tokens,
attn_metadata.max_seq_len,
) )
else: else:
if current_platform.is_xpu(): if current_platform.is_xpu():
......
...@@ -67,7 +67,9 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( ...@@ -67,7 +67,9 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, per_token_group_quant_fp8,
) )
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer from vllm.model_executor.layers.sparse_attn_indexer import (
SparseAttnIndexer,
)
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
...@@ -1203,7 +1205,9 @@ class DeepseekV2Model(nn.Module): ...@@ -1203,7 +1205,9 @@ class DeepseekV2Model(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: DeepseekV2DecoderLayer( lambda prefix: DeepseekV2DecoderLayer(
vllm_config, prefix, topk_indices_buffer=topk_indices_buffer vllm_config,
prefix,
topk_indices_buffer=topk_indices_buffer,
), ),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
......
...@@ -145,7 +145,6 @@ class DeepSeekV32IndexerDecodeMetadata: ...@@ -145,7 +145,6 @@ class DeepSeekV32IndexerDecodeMetadata:
decode_lens: torch.Tensor decode_lens: torch.Tensor
requires_padding: bool requires_padding: bool
schedule_metadata: torch.Tensor schedule_metadata: torch.Tensor
use_large_context_topk: bool
offsets: torch.Tensor | None # Precomputed offsets for speculative decoding offsets: torch.Tensor | None # Precomputed offsets for speculative decoding
...@@ -437,7 +436,6 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -437,7 +436,6 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
if use_native and next_n > 1: if use_native and next_n > 1:
offsets = self.offsets_buffer offsets = self.offsets_buffer
batch_size = num_decodes
elif max_decode_len > 1: elif max_decode_len > 1:
# Flatten multi-token decode requests into single-token # Flatten multi-token decode requests into single-token
# batch entries, expanding seq_lens and block tables so # batch entries, expanding seq_lens and block tables so
...@@ -496,10 +494,8 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -496,10 +494,8 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
self.decode_lens_buffer[:num_decode_tokens] = 1 self.decode_lens_buffer[:num_decode_tokens] = 1
decode_lens = self.decode_lens_buffer[:num_decode_tokens] decode_lens = self.decode_lens_buffer[:num_decode_tokens]
offsets = None offsets = None
batch_size = num_decode_tokens
else: else:
offsets = None offsets = None
batch_size = num_decodes
# DeepGEMM is required for the paged MQA logits on CUDA devices # DeepGEMM is required for the paged MQA logits on CUDA devices
if current_platform.is_cuda() and has_deep_gemm(): if current_platform.is_cuda() and has_deep_gemm():
...@@ -509,20 +505,12 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -509,20 +505,12 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
self.num_sms, self.num_sms,
) )
# Decide which top-k kernel to use based on batch size and sequence length
# Decision logic based on micro-benchmark results:
# - large_context_topk wins for batch <= 128 and seq_len > 8K
# - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
_is_large_context = common_attn_metadata.max_seq_len > 8192
use_large_context_topk = batch_size <= 128 and _is_large_context
decode_metadata = DeepSeekV32IndexerDecodeMetadata( decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=block_table, block_table=block_table,
seq_lens=seq_lens, seq_lens=seq_lens,
decode_lens=decode_lens, decode_lens=decode_lens,
requires_padding=False, requires_padding=False,
schedule_metadata=self.scheduler_metadata_buffer, schedule_metadata=self.scheduler_metadata_buffer,
use_large_context_topk=use_large_context_topk,
offsets=offsets, offsets=offsets,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment