Unverified Commit bf9a5ddb authored by Giancarlo Delfin's avatar Giancarlo Delfin Committed by GitHub
Browse files

[MLA] Optimize mla indexer prepare uniform decode for MTP > 1 (#39458)


Signed-off-by: default avatarGiancarlo Delfin <gdelfin@inferact.ai>
parent 79e799eb
...@@ -8,6 +8,7 @@ import vllm.envs as envs ...@@ -8,6 +8,7 @@ import vllm.envs as envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
get_paged_mqa_logits_metadata, get_paged_mqa_logits_metadata,
has_deep_gemm, has_deep_gemm,
...@@ -30,6 +31,40 @@ from vllm.v1.worker.cp_utils import get_total_cp_world_size ...@@ -30,6 +31,40 @@ from vllm.v1.worker.cp_utils import get_total_cp_world_size
logger = init_logger(__name__) logger = init_logger(__name__)
@triton.jit
def _prepare_uniform_decode_kernel(
seq_lens_ptr,
decode_seq_lens_ptr,
block_table_ptr,
block_table_stride,
expanded_block_table_ptr,
expanded_bt_stride,
decode_lens_ptr,
max_decode_len,
BLOCK_SIZE: tl.constexpr,
):
idx = tl.program_id(0)
req_id = idx // max_decode_len
local_idx = idx % max_decode_len
# Compute number of KVs attended to by this token.
seq_len = tl.load(seq_lens_ptr + req_id)
per_token_seq_len = seq_len - max_decode_len + local_idx + 1
tl.store(decode_seq_lens_ptr + idx, per_token_seq_len)
# Copy block table row.
src = block_table_ptr + req_id * block_table_stride
dst = expanded_block_table_ptr + idx * expanded_bt_stride
for i in tl.range(0, expanded_bt_stride, BLOCK_SIZE):
off = i + tl.arange(0, BLOCK_SIZE)
mask = off < expanded_bt_stride
src_block = tl.load(src + off, mask=mask)
tl.store(dst + off, src_block, mask=mask)
# All reqs now have decode_len = 1.
tl.store(decode_lens_ptr + idx, 1)
def split_indexer_prefill_chunks( def split_indexer_prefill_chunks(
seq_lens_cpu: torch.Tensor, seq_lens_cpu: torch.Tensor,
query_lens_cpu: torch.Tensor, query_lens_cpu: torch.Tensor,
...@@ -405,52 +440,75 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -405,52 +440,75 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
Returns (seq_lens, block_table, decode_lens, batch_size, requires_padding). Returns (seq_lens, block_table, decode_lens, batch_size, requires_padding).
seq_lens is 1D (batch_size,) for flatten/plain, 2D (B, next_n) for native MTP. seq_lens is 1D (batch_size,) for flatten/plain, 2D (B, next_n) for native MTP.
""" """
min_decode_len = int(decode_lens_cpu.min().item())
if not use_native and max_decode_len > 1: if not use_native and max_decode_len > 1:
assert self.decode_seq_lens_buffer.dim() == 1 assert self.decode_seq_lens_buffer.dim() == 1
# Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is if min_decode_len == max_decode_len:
# padding) and decode_lens [3, 1, 4, 0] in the below example comments. # Uniform decode lengths.
# The context lengths are therefore num_decode_tokens = num_decodes * max_decode_len
# [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0]. _prepare_uniform_decode_kernel[(num_decode_tokens,)](
seq_lens,
# 3 + 1 + 4 + 0 = 8 self.decode_seq_lens_buffer,
actual_expanded = int(decode_lens_cpu.sum().item()) block_table,
block_table.stride(0),
# Fuse expanded_base and expanded_starts into a single repeat_interleave: self.expanded_block_table_buffer,
# seq_len_i = (context_start[b] - query_start_loc[b]) + arange[i] + 1 self.expanded_block_table_buffer.stride(0),
# where context_start[b] = seq_lens[b] - decode_lens[b]. self.decode_lens_buffer,
# Example: offsets = [7-0, 6-3, 8-4, 0-8] = [7, 3, 4, -8] max_decode_len,
# expanded_offsets = [7, 7, 7, 3, 4, 4, 4, 4] BLOCK_SIZE=1024,
# result = [8, 9, 10, 7, 9, 10, 11, 12] )
expanded_offsets = torch.repeat_interleave( self.decode_seq_lens_buffer[num_decode_tokens:] = 0
seq_lens - decode_lens - query_start_loc, seq_lens = self.decode_seq_lens_buffer[:num_decode_tokens]
decode_lens, block_table = self.expanded_block_table_buffer[:num_decode_tokens]
output_size=actual_expanded, decode_lens = self.decode_lens_buffer[:num_decode_tokens]
) return seq_lens, block_table, decode_lens, num_decode_tokens, False
else:
# Variable decode lengths.
# Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
# padding) and decode_lens [3, 1, 4, 0] in the below example comments.
# The context lengths are therefore
# [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].
# 3 + 1 + 4 + 0 = 8
actual_expanded = int(decode_lens_cpu.sum().item())
# Fuse expanded_base and expanded_starts into a single
# repeat_interleave:
# seq_len_i = (context_start[b] - query_start_loc[b]) + arange[i] + 1
# where context_start[b] = seq_lens[b] - decode_lens[b].
# Example: offsets = [7-0, 6-3, 8-4, 0-8] = [7, 3, 4, -8]
# expanded_offsets = [7, 7, 7, 3, 4, 4, 4, 4]
# result = [8, 9, 10, 7, 9, 10, 11, 12]
expanded_offsets = torch.repeat_interleave(
seq_lens - decode_lens - query_start_loc,
decode_lens,
output_size=actual_expanded,
)
# [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space # [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
self.decode_seq_lens_buffer[:actual_expanded] = ( self.decode_seq_lens_buffer[:actual_expanded] = (
expanded_offsets + self.arange_buffer[:actual_expanded] + 1 expanded_offsets + self.arange_buffer[:actual_expanded] + 1
)
self.decode_seq_lens_buffer[actual_expanded:] = 0
seq_lens = self.decode_seq_lens_buffer[:num_decode_tokens]
# Give each of the flattened entries the same block table row as the
# original request.
self.expanded_block_table_buffer[:actual_expanded] = (
torch.repeat_interleave(
block_table, decode_lens, dim=0, output_size=actual_expanded
) )
) self.decode_seq_lens_buffer[actual_expanded:] = 0
if actual_expanded < num_decode_tokens: seq_lens = self.decode_seq_lens_buffer[:num_decode_tokens]
self.expanded_block_table_buffer[
actual_expanded:num_decode_tokens, 0 # Give each of the flattened entries the same block table row as the
] = 0 # original request.
block_table = self.expanded_block_table_buffer[:num_decode_tokens] self.expanded_block_table_buffer[:actual_expanded] = (
torch.repeat_interleave(
# All reqs now have decode_len=1 block_table, decode_lens, dim=0, output_size=actual_expanded
self.decode_lens_buffer[:num_decode_tokens] = 1 )
decode_lens = self.decode_lens_buffer[:num_decode_tokens] )
return seq_lens, block_table, decode_lens, num_decode_tokens, False if actual_expanded < num_decode_tokens:
self.expanded_block_table_buffer[
actual_expanded:num_decode_tokens, 0
] = 0
block_table = self.expanded_block_table_buffer[:num_decode_tokens]
# All reqs now have decode_len=1
self.decode_lens_buffer[:num_decode_tokens] = 1
decode_lens = self.decode_lens_buffer[:num_decode_tokens]
return seq_lens, block_table, decode_lens, num_decode_tokens, False
else: else:
# Native path: plain decode (next_n==1) or spec decode # Native path: plain decode (next_n==1) or spec decode
# with 2D per-token context lengths (next_n > 1). # with 2D per-token context lengths (next_n > 1).
...@@ -459,7 +517,6 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -459,7 +517,6 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
# decode_len < next_n due to padding or short prefills), the simple # decode_len < next_n due to padding or short prefills), the simple
# reshape in sparse_attn_indexer won't work. Use pack_seq_triton # reshape in sparse_attn_indexer won't work. Use pack_seq_triton
# (requires_padding) instead. # (requires_padding) instead.
min_decode_len = int(decode_lens_cpu.min().item())
requires_padding = min_decode_len != max_decode_len requires_padding = min_decode_len != max_decode_len
if use_native and next_n > 1: if use_native and next_n > 1:
assert self.decode_seq_lens_buffer.dim() == 2 assert self.decode_seq_lens_buffer.dim() == 2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment