Unverified Commit bf9a5ddb authored by Giancarlo Delfin's avatar Giancarlo Delfin Committed by GitHub
Browse files

[MLA] Optimize mla indexer prepare uniform decode for MTP > 1 (#39458)


Signed-off-by: default avatarGiancarlo Delfin <gdelfin@inferact.ai>
parent 79e799eb
......@@ -8,6 +8,7 @@ import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import (
get_paged_mqa_logits_metadata,
has_deep_gemm,
......@@ -30,6 +31,40 @@ from vllm.v1.worker.cp_utils import get_total_cp_world_size
logger = init_logger(__name__)
@triton.jit
def _prepare_uniform_decode_kernel(
seq_lens_ptr,
decode_seq_lens_ptr,
block_table_ptr,
block_table_stride,
expanded_block_table_ptr,
expanded_bt_stride,
decode_lens_ptr,
max_decode_len,
BLOCK_SIZE: tl.constexpr,
):
idx = tl.program_id(0)
req_id = idx // max_decode_len
local_idx = idx % max_decode_len
# Compute number of KVs attended to by this token.
seq_len = tl.load(seq_lens_ptr + req_id)
per_token_seq_len = seq_len - max_decode_len + local_idx + 1
tl.store(decode_seq_lens_ptr + idx, per_token_seq_len)
# Copy block table row.
src = block_table_ptr + req_id * block_table_stride
dst = expanded_block_table_ptr + idx * expanded_bt_stride
for i in tl.range(0, expanded_bt_stride, BLOCK_SIZE):
off = i + tl.arange(0, BLOCK_SIZE)
mask = off < expanded_bt_stride
src_block = tl.load(src + off, mask=mask)
tl.store(dst + off, src_block, mask=mask)
# All reqs now have decode_len = 1.
tl.store(decode_lens_ptr + idx, 1)
def split_indexer_prefill_chunks(
seq_lens_cpu: torch.Tensor,
query_lens_cpu: torch.Tensor,
......@@ -405,52 +440,75 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
Returns (seq_lens, block_table, decode_lens, batch_size, requires_padding).
seq_lens is 1D (batch_size,) for flatten/plain, 2D (B, next_n) for native MTP.
"""
min_decode_len = int(decode_lens_cpu.min().item())
if not use_native and max_decode_len > 1:
assert self.decode_seq_lens_buffer.dim() == 1
# Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
# padding) and decode_lens [3, 1, 4, 0] in the below example comments.
# The context lengths are therefore
# [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].
# 3 + 1 + 4 + 0 = 8
actual_expanded = int(decode_lens_cpu.sum().item())
# Fuse expanded_base and expanded_starts into a single repeat_interleave:
# seq_len_i = (context_start[b] - query_start_loc[b]) + arange[i] + 1
# where context_start[b] = seq_lens[b] - decode_lens[b].
# Example: offsets = [7-0, 6-3, 8-4, 0-8] = [7, 3, 4, -8]
# expanded_offsets = [7, 7, 7, 3, 4, 4, 4, 4]
# result = [8, 9, 10, 7, 9, 10, 11, 12]
expanded_offsets = torch.repeat_interleave(
seq_lens - decode_lens - query_start_loc,
decode_lens,
output_size=actual_expanded,
)
if min_decode_len == max_decode_len:
# Uniform decode lengths.
num_decode_tokens = num_decodes * max_decode_len
_prepare_uniform_decode_kernel[(num_decode_tokens,)](
seq_lens,
self.decode_seq_lens_buffer,
block_table,
block_table.stride(0),
self.expanded_block_table_buffer,
self.expanded_block_table_buffer.stride(0),
self.decode_lens_buffer,
max_decode_len,
BLOCK_SIZE=1024,
)
self.decode_seq_lens_buffer[num_decode_tokens:] = 0
seq_lens = self.decode_seq_lens_buffer[:num_decode_tokens]
block_table = self.expanded_block_table_buffer[:num_decode_tokens]
decode_lens = self.decode_lens_buffer[:num_decode_tokens]
return seq_lens, block_table, decode_lens, num_decode_tokens, False
else:
# Variable decode lengths.
# Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
# padding) and decode_lens [3, 1, 4, 0] in the below example comments.
# The context lengths are therefore
# [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].
# 3 + 1 + 4 + 0 = 8
actual_expanded = int(decode_lens_cpu.sum().item())
# Fuse expanded_base and expanded_starts into a single
# repeat_interleave:
# seq_len_i = (context_start[b] - query_start_loc[b]) + arange[i] + 1
# where context_start[b] = seq_lens[b] - decode_lens[b].
# Example: offsets = [7-0, 6-3, 8-4, 0-8] = [7, 3, 4, -8]
# expanded_offsets = [7, 7, 7, 3, 4, 4, 4, 4]
# result = [8, 9, 10, 7, 9, 10, 11, 12]
expanded_offsets = torch.repeat_interleave(
seq_lens - decode_lens - query_start_loc,
decode_lens,
output_size=actual_expanded,
)
# [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
self.decode_seq_lens_buffer[:actual_expanded] = (
expanded_offsets + self.arange_buffer[:actual_expanded] + 1
)
self.decode_seq_lens_buffer[actual_expanded:] = 0
seq_lens = self.decode_seq_lens_buffer[:num_decode_tokens]
# Give each of the flattened entries the same block table row as the
# original request.
self.expanded_block_table_buffer[:actual_expanded] = (
torch.repeat_interleave(
block_table, decode_lens, dim=0, output_size=actual_expanded
# [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
self.decode_seq_lens_buffer[:actual_expanded] = (
expanded_offsets + self.arange_buffer[:actual_expanded] + 1
)
)
if actual_expanded < num_decode_tokens:
self.expanded_block_table_buffer[
actual_expanded:num_decode_tokens, 0
] = 0
block_table = self.expanded_block_table_buffer[:num_decode_tokens]
# All reqs now have decode_len=1
self.decode_lens_buffer[:num_decode_tokens] = 1
decode_lens = self.decode_lens_buffer[:num_decode_tokens]
return seq_lens, block_table, decode_lens, num_decode_tokens, False
self.decode_seq_lens_buffer[actual_expanded:] = 0
seq_lens = self.decode_seq_lens_buffer[:num_decode_tokens]
# Give each of the flattened entries the same block table row as the
# original request.
self.expanded_block_table_buffer[:actual_expanded] = (
torch.repeat_interleave(
block_table, decode_lens, dim=0, output_size=actual_expanded
)
)
if actual_expanded < num_decode_tokens:
self.expanded_block_table_buffer[
actual_expanded:num_decode_tokens, 0
] = 0
block_table = self.expanded_block_table_buffer[:num_decode_tokens]
# All reqs now have decode_len=1
self.decode_lens_buffer[:num_decode_tokens] = 1
decode_lens = self.decode_lens_buffer[:num_decode_tokens]
return seq_lens, block_table, decode_lens, num_decode_tokens, False
else:
# Native path: plain decode (next_n==1) or spec decode
# with 2D per-token context lengths (next_n > 1).
......@@ -459,7 +517,6 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
# decode_len < next_n due to padding or short prefills), the simple
# reshape in sparse_attn_indexer won't work. Use pack_seq_triton
# (requires_padding) instead.
min_decode_len = int(decode_lens_cpu.min().item())
requires_padding = min_decode_len != max_decode_len
if use_native and next_n > 1:
assert self.decode_seq_lens_buffer.dim() == 2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment