"vscode:/vscode.git/clone" did not exist on "f7db5f0fa9db2ea5680e373fcb1b21fb0c32797e"
Unverified Commit b55d830e authored by Roberto L. Castro's avatar Roberto L. Castro Committed by GitHub
Browse files

[Perf][Kernel] Persistent TopK scheduler: unified CUDAGraph-safe kernel with...


[Perf][Kernel] Persistent TopK scheduler: unified CUDAGraph-safe kernel with dynamic per-row dispatch - DeepSeek-V3.2 DSA decode (#37421)
Signed-off-by: default avatarLopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: default avatarRoberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
Co-authored-by: default avatarClaude Sonnet 4.5 <noreply@anthropic.com>
Co-authored-by: default avatarLucas Wilkinson <LucasWilkinson@users.noreply.github.com>
parent 75e01a39
......@@ -18,10 +18,9 @@ steps:
source_file_dependencies:
- csrc/
- tests/kernels/core
- tests/kernels/test_top_k_per_row.py
- tests/kernels/test_concat_mla_q.py
commands:
- pytest -v -s kernels/core kernels/test_top_k_per_row.py kernels/test_concat_mla_q.py
- pytest -v -s kernels/core kernels/test_concat_mla_q.py
- label: Kernels Attention Test %N
timeout_in_minutes: 35
......@@ -107,6 +106,7 @@ steps:
- vllm/v1/attention/backends/mla/flashinfer_mla.py
- vllm/v1/attention/selector.py
- vllm/platforms/cuda.py
- tests/kernels/test_top_k_per_row.py
commands:
- nvidia-smi
- python3 examples/basic/offline_inference/chat.py
......@@ -117,6 +117,7 @@ steps:
- pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py
- pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py
- pytest -v -s tests/kernels/attention/test_flashinfer_mla_decode.py
- pytest -v -s tests/kernels/test_top_k_per_row.py
# Quantization
- pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
- pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
......
......@@ -114,9 +114,9 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
int64_t numRows, int64_t stride0, int64_t stride1,
int64_t topK);
void large_context_topk(const torch::Tensor& score, torch::Tensor& indices,
const torch::Tensor& lengths,
std::optional<torch::Tensor> row_starts_opt);
void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
torch::Tensor& output, torch::Tensor& workspace, int64_t k,
int64_t max_seq_len);
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& weight, torch::Tensor& scale,
......
This diff is collapsed.
This diff is collapsed.
......@@ -197,10 +197,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
ops.def(
"large_context_topk(Tensor score, Tensor indices, Tensor lengths, "
"Tensor? "
"row_starts_opt) -> ()");
ops.impl("large_context_topk", torch::kCUDA, &large_context_topk);
"persistent_topk(Tensor logits, Tensor lengths, Tensor! output, "
"Tensor workspace, int k, int max_seq_len) -> ()");
ops.impl("persistent_topk", torch::kCUDA, &persistent_topk);
// Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor.
......
This diff is collapsed.
......@@ -25,6 +25,8 @@ elif current_platform.is_xpu():
logger = init_logger(__name__)
RADIX_TOPK_WORKSPACE_SIZE = 1024 * 1024
def sparse_attn_indexer(
hidden_states: torch.Tensor,
......@@ -51,6 +53,7 @@ def sparse_attn_indexer(
current_workspace_manager().get_simultaneous(
((total_seq_lens, head_dim), torch.float8_e4m3fn),
((total_seq_lens, 4), torch.uint8),
((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8),
)
# Dummy allocation to simulate for peak logits tensor memory during inference.
......@@ -157,15 +160,6 @@ def sparse_attn_indexer(
topk_tokens,
)
# Compute lengths from row spans
# lengths = (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks).to(torch.int32)
# torch.ops._C.large_context_topk(
# logits,
# topk_indices,
# lengths,
# chunk.cu_seqlen_ks, # row_starts
# )
if has_decode:
decode_metadata = attn_metadata.decode
assert decode_metadata is not None
......@@ -204,7 +198,6 @@ def sparse_attn_indexer(
num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
if decode_metadata.use_large_context_topk:
if next_n == 1:
lengths = decode_metadata.seq_lens
else:
......@@ -216,11 +209,18 @@ def sparse_attn_indexer(
+ decode_metadata.offsets
).flatten()
torch.ops._C.large_context_topk(
if current_platform.is_cuda():
workspace_manager = current_workspace_manager()
(topk_workspace,) = workspace_manager.get_simultaneous(
((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8),
)
torch.ops._C.persistent_topk(
logits,
topk_indices,
lengths,
None,
topk_indices,
topk_workspace,
topk_tokens,
attn_metadata.max_seq_len,
)
else:
if current_platform.is_xpu():
......
......@@ -67,7 +67,9 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer
from vllm.model_executor.layers.sparse_attn_indexer import (
SparseAttnIndexer,
)
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
......@@ -1203,7 +1205,9 @@ class DeepseekV2Model(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: DeepseekV2DecoderLayer(
vllm_config, prefix, topk_indices_buffer=topk_indices_buffer
vllm_config,
prefix,
topk_indices_buffer=topk_indices_buffer,
),
prefix=f"{prefix}.layers",
)
......
......@@ -145,7 +145,6 @@ class DeepSeekV32IndexerDecodeMetadata:
decode_lens: torch.Tensor
requires_padding: bool
schedule_metadata: torch.Tensor
use_large_context_topk: bool
offsets: torch.Tensor | None # Precomputed offsets for speculative decoding
......@@ -437,7 +436,6 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
if use_native and next_n > 1:
offsets = self.offsets_buffer
batch_size = num_decodes
elif max_decode_len > 1:
# Flatten multi-token decode requests into single-token
# batch entries, expanding seq_lens and block tables so
......@@ -496,10 +494,8 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
self.decode_lens_buffer[:num_decode_tokens] = 1
decode_lens = self.decode_lens_buffer[:num_decode_tokens]
offsets = None
batch_size = num_decode_tokens
else:
offsets = None
batch_size = num_decodes
# DeepGEMM is required for the paged MQA logits on CUDA devices
if current_platform.is_cuda() and has_deep_gemm():
......@@ -509,20 +505,12 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
self.num_sms,
)
# Decide which top-k kernel to use based on batch size and sequence length
# Decision logic based on micro-benchmark results:
# - large_context_topk wins for batch <= 128 and seq_len > 8K
# - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
_is_large_context = common_attn_metadata.max_seq_len > 8192
use_large_context_topk = batch_size <= 128 and _is_large_context
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=block_table,
seq_lens=seq_lens,
decode_lens=decode_lens,
requires_padding=False,
schedule_metadata=self.scheduler_metadata_buffer,
use_large_context_topk=use_large_context_topk,
offsets=offsets,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment