Unverified Commit c88ea833 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[MTP][Sparse MLA] Take advantage of native MTP support in indexer when possible (#36982)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent 9f9ecff4
......@@ -575,7 +575,7 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
// The range of logits within the row.
int rowStart = 0;
int seq_len = seqLens[rowIdx / next_n];
int rowEnd = seq_len - next_n + (rowIdx % next_n) + 1;
int rowEnd = max(0, seq_len - next_n + (rowIdx % next_n) + 1);
// Local pointers to this block
if constexpr (!multipleBlocksPerRow && !mergeBlocks) {
......
......@@ -206,6 +206,8 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig):
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
reorder_batch_threshold: int = 1
natively_supported_next_n: list[int] = [1, 2]
# TODO (matt): integrate kernel with next_n = 4 support
@classmethod
def get_cudagraph_support(
......@@ -231,7 +233,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
if self.vllm_config.speculative_config
else 0
)
next_n = self.num_speculative_tokens + 1
self.reorder_batch_threshold += self.num_speculative_tokens
self.use_flattening = next_n not in self.natively_supported_next_n
sm_count = num_compute_units(self.device.index)
self.num_sms = sm_count
......@@ -241,10 +245,11 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
dtype=torch.int32,
device=self.device,
)
# Pre-allocated buffers for flattening (spec decode).
self.offsets_buffer = torch.arange(
next_n, device=self.device, dtype=torch.int32
)
self.arange_buffer = torch.arange(
scheduler_config.max_num_seqs * (1 + self.num_speculative_tokens),
scheduler_config.max_num_seqs * next_n,
dtype=torch.int32,
device=self.device,
)
......@@ -323,7 +328,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold,
require_uniform=not self.use_flattening,
)
)
......@@ -372,11 +379,21 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
block_table.clamp_(min=0)
max_decode_len = int(decode_lens_cpu.max().item())
if max_decode_len > 1:
next_n = 1 + self.num_speculative_tokens
use_native = not self.use_flattening and max_decode_len == next_n
if use_native and next_n > 1:
offsets = self.offsets_buffer
batch_size = num_decodes
elif max_decode_len > 1:
# Flatten multi-token decode requests into single-token
# batch entries, expanding seq_lens and block tables so
# the kernel always sees next_n=1.
# Also handles the edge case where use_flattening=False
# but max_decode_len != next_n (e.g. a batch containing some
# short prefills (q_len < next_n) and no true decodes).
# Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
# padding) and decode_lens [3, 1, 4, 0] in the below example comments.
# The context lengths are therefore
......@@ -428,13 +445,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
offsets = None
batch_size = num_decode_tokens
else:
next_n = 1 + self.num_speculative_tokens
if next_n > 1:
offsets = torch.arange(
next_n, device=self.device, dtype=torch.int32
)
else:
offsets = None
offsets = None
batch_size = num_decodes
# DeepGEMM is required for the paged MQA logits on CUDA devices
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment