Commit c88ea833 (unverified), authored by Matthew Bonanni, committed by GitHub
Browse files

[MTP][Sparse MLA] Take advantage of native MTP support in indexer when possible (#36982)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent 9f9ecff4
...@@ -575,7 +575,7 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode( ...@@ -575,7 +575,7 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
// The range of logits within the row. // The range of logits within the row.
int rowStart = 0; int rowStart = 0;
int seq_len = seqLens[rowIdx / next_n]; int seq_len = seqLens[rowIdx / next_n];
int rowEnd = seq_len - next_n + (rowIdx % next_n) + 1; int rowEnd = max(0, seq_len - next_n + (rowIdx % next_n) + 1);
// Local pointers to this block // Local pointers to this block
if constexpr (!multipleBlocksPerRow && !mergeBlocks) { if constexpr (!multipleBlocksPerRow && !mergeBlocks) {
......
...@@ -206,6 +206,8 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig): ...@@ -206,6 +206,8 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig):
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
reorder_batch_threshold: int = 1 reorder_batch_threshold: int = 1
natively_supported_next_n: list[int] = [1, 2]
# TODO (matt): integrate kernel with next_n = 4 support
@classmethod @classmethod
def get_cudagraph_support( def get_cudagraph_support(
...@@ -231,7 +233,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -231,7 +233,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
if self.vllm_config.speculative_config if self.vllm_config.speculative_config
else 0 else 0
) )
next_n = self.num_speculative_tokens + 1
self.reorder_batch_threshold += self.num_speculative_tokens self.reorder_batch_threshold += self.num_speculative_tokens
self.use_flattening = next_n not in self.natively_supported_next_n
sm_count = num_compute_units(self.device.index) sm_count = num_compute_units(self.device.index)
self.num_sms = sm_count self.num_sms = sm_count
...@@ -241,10 +245,11 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -241,10 +245,11 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
dtype=torch.int32, dtype=torch.int32,
device=self.device, device=self.device,
) )
self.offsets_buffer = torch.arange(
# Pre-allocated buffers for flattening (spec decode). next_n, device=self.device, dtype=torch.int32
)
self.arange_buffer = torch.arange( self.arange_buffer = torch.arange(
scheduler_config.max_num_seqs * (1 + self.num_speculative_tokens), scheduler_config.max_num_seqs * next_n,
dtype=torch.int32, dtype=torch.int32,
device=self.device, device=self.device,
) )
...@@ -323,7 +328,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -323,7 +328,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills( split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold common_attn_metadata,
decode_threshold=self.reorder_batch_threshold,
require_uniform=not self.use_flattening,
) )
) )
...@@ -372,11 +379,21 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -372,11 +379,21 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
block_table.clamp_(min=0) block_table.clamp_(min=0)
max_decode_len = int(decode_lens_cpu.max().item()) max_decode_len = int(decode_lens_cpu.max().item())
if max_decode_len > 1: next_n = 1 + self.num_speculative_tokens
use_native = not self.use_flattening and max_decode_len == next_n
if use_native and next_n > 1:
offsets = self.offsets_buffer
batch_size = num_decodes
elif max_decode_len > 1:
# Flatten multi-token decode requests into single-token # Flatten multi-token decode requests into single-token
# batch entries, expanding seq_lens and block tables so # batch entries, expanding seq_lens and block tables so
# the kernel always sees next_n=1. # the kernel always sees next_n=1.
# Also handles the edge case where use_flattening=False
# but max_decode_len != next_n (e.g. a batch containing some
# short prefills (q_len < next_n) and no true decodes).
# Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is # Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
# padding) and decode_lens [3, 1, 4, 0] in the below example comments. # padding) and decode_lens [3, 1, 4, 0] in the below example comments.
# The context lengths are therefore # The context lengths are therefore
...@@ -427,12 +444,6 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -427,12 +444,6 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
decode_lens = self.decode_lens_buffer[:num_decode_tokens] decode_lens = self.decode_lens_buffer[:num_decode_tokens]
offsets = None offsets = None
batch_size = num_decode_tokens batch_size = num_decode_tokens
else:
next_n = 1 + self.num_speculative_tokens
if next_n > 1:
offsets = torch.arange(
next_n, device=self.device, dtype=torch.int32
)
else: else:
offsets = None offsets = None
batch_size = num_decodes batch_size = num_decodes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment