Unverified commit d5b6e50f authored by yinghui, committed by GitHub

perf: trtllm mla performance minor improvements (#12435)

parent 9632e48f
@@ -219,6 +219,7 @@ class TRTLLMMLADecodeMetadata:
sum_seq_lens_q: Optional[int] = None
cu_seqlens_q: Optional[torch.Tensor] = None
seq_lens_q: Optional[torch.Tensor] = None
seq_lens_k: Optional[torch.Tensor] = None
class TRTLLMMLABackend(FlashInferMLAAttnBackend):
@@ -404,8 +405,38 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
spec_info,
)
metadata = TRTLLMMLADecodeMetadata()
if forward_mode.is_target_verify():
seq_lens = seq_lens + self.num_draft_tokens
metadata.seq_lens_k = torch.zeros(
(bs,), dtype=torch.int32, device=seq_lens.device
)
metadata.seq_lens_k.copy_(seq_lens.to(dtype=torch.int32))
elif forward_mode.is_draft_extend(include_v2=True):
num_tokens_per_bs = num_tokens // bs
metadata.max_seq_len_q = num_tokens_per_bs
metadata.sum_seq_lens_q = num_tokens_per_bs * bs
metadata.cu_seqlens_q = torch.arange(
0,
bs * num_tokens_per_bs + 1,
num_tokens_per_bs,
dtype=torch.int32,
device=seq_lens.device,
)
metadata.seq_lens_q = torch.full(
(bs,), num_tokens_per_bs, dtype=torch.int32, device=seq_lens.device
)
# NOTE(draft_extend seq_len handling):
# forward_batch.seq_lens holds prev_context + verified tokens per request.
# Because pad_draft_extend_query pads each query to max_draft_tokens, we need
# seq_lens = prev_context + max_draft_tokens so that queries align with KVs
# correctly when calling flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla.
seq_lens = seq_lens - metadata.seq_lens_q + metadata.max_seq_len_q
metadata.seq_lens_k = torch.zeros(
(bs,), dtype=torch.int32, device=seq_lens.device
)
metadata.seq_lens_k.copy_(seq_lens.to(dtype=torch.int32))
# Custom fast-path for decode/idle.
# Capture with full width so future longer sequences are safe during replay
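To make the NOTE above concrete, here is a small worked example of the seq_lens_k arithmetic. It is an illustrative sketch only, not part of the commit; the batch values are made up and only torch is assumed.

```python
import torch

# Hypothetical draft-extend batch of two requests.
# forward_batch.seq_lens: previous context + verified tokens per request.
seq_lens = torch.tensor([10, 17], dtype=torch.int32)
# metadata.seq_lens_q: accepted draft tokens (real query length) per request.
seq_lens_q = torch.tensor([2, 3], dtype=torch.int32)
# metadata.max_seq_len_q: every query is padded to the longest draft length.
max_seq_len_q = 3

# With padded queries, each request needs prev_context + max_seq_len_q KV
# entries so that query rows and KV positions line up.
seq_lens_k = seq_lens - seq_lens_q + max_seq_len_q
print(seq_lens_k)  # tensor([11, 17], dtype=torch.int32)
```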
@@ -423,24 +454,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
PAGED_SIZE=self.page_size,
)
metadata = TRTLLMMLADecodeMetadata(
block_kv_indices,
self.max_context_len,
)
if forward_mode.is_draft_extend(include_v2=True):
num_tokens_per_bs = num_tokens // bs
metadata.max_seq_len_q = num_tokens_per_bs + 1
metadata.sum_seq_lens_q = num_tokens_per_bs * bs
metadata.cu_seqlens_q = torch.arange(
0,
bs * num_tokens_per_bs + 1,
num_tokens_per_bs,
dtype=torch.int32,
device=seq_lens.device,
)
metadata.seq_lens_q = torch.full(
(bs,), num_tokens_per_bs, dtype=torch.int32, device=seq_lens.device
)
metadata.block_kv_indices = block_kv_indices
metadata.max_seq_len_k = self.max_context_len
self.decode_cuda_graph_metadata[bs] = metadata
self.forward_decode_metadata = metadata
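As a side note, the capture-time cu_seqlens_q built with torch.arange above assumes a uniform number of query tokens per request. A minimal sketch of what it evaluates to, with hypothetical shapes and only torch assumed:

```python
import torch

bs, num_tokens_per_bs = 4, 3  # hypothetical capture-time shapes

# Uniform per-request query lengths give evenly spaced cumulative offsets.
cu_seqlens_q = torch.arange(
    0, bs * num_tokens_per_bs + 1, num_tokens_per_bs, dtype=torch.int32
)
print(cu_seqlens_q)  # tensor([ 0,  3,  6,  9, 12], dtype=torch.int32)

# The matching per-request query lengths.
seq_lens_q = torch.full((bs,), num_tokens_per_bs, dtype=torch.int32)
print(seq_lens_q)    # tensor([3, 3, 3, 3], dtype=torch.int32)
```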
@@ -473,17 +489,17 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
seq_lens_cpu,
)
if forward_mode.is_target_verify():
seq_lens = seq_lens + self.num_draft_tokens
del seq_lens_sum  # seq_lens_sum does not include num_draft_tokens, but it is not needed here
metadata = self.decode_cuda_graph_metadata[bs]
if forward_mode.is_draft_extend(include_v2=True):
if forward_mode.is_target_verify():
seq_lens = seq_lens[:bs] + self.num_draft_tokens
metadata.seq_lens_k.copy_(seq_lens.to(dtype=torch.int32))
del seq_lens_sum  # seq_lens_sum does not include num_draft_tokens, but it is not needed here
elif forward_mode.is_draft_extend(include_v2=True):
accept_length = spec_info.accept_length[:bs]
if spec_info.accept_length_cpu:
metadata.max_seq_len_q = max(spec_info.accept_length_cpu[:bs])
metadata.sum_seq_lens_q = sum(spec_info.accept_length_cpu[:bs])
metadata.max_seq_len_q = max(spec_info.accept_length_cpu[:bs]) + 1
metadata.sum_seq_lens_q = sum(spec_info.accept_length_cpu[:bs]) + bs
else:
metadata.max_seq_len_q = 1
metadata.sum_seq_lens_q = bs
@@ -491,12 +507,15 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
torch.cumsum(accept_length, dim=0, dtype=torch.int32)
)
metadata.seq_lens_q.copy_(accept_length)
# see NOTE(draft_extend seq_len handling)
seq_lens = seq_lens[:bs] - metadata.seq_lens_q + metadata.max_seq_len_q
metadata.seq_lens_k.copy_(seq_lens.to(torch.int32))
# Update block indices for new sequences.
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices[:bs],
seq_lens[:bs],
seq_lens,
None,
metadata.block_kv_indices,
self.req_to_token.stride(0),
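During CUDA graph replay, the metadata tensors above are refreshed in place with copy_ rather than reallocated, so the captured graph keeps reading from the same storage. A minimal sketch of that pattern (names and values are illustrative, only torch assumed):

```python
import torch

bs = 3
# Capture time: buffers allocated once with fixed shapes.
cu_seqlens_q = torch.zeros((bs + 1,), dtype=torch.int32)
seq_lens_q = torch.zeros((bs,), dtype=torch.int32)

# Replay time: write new values into the existing buffers; index 0 of
# cu_seqlens_q stays 0 and only the running sums are overwritten.
accept_length = torch.tensor([2, 1, 3], dtype=torch.int32)
cu_seqlens_q[1:].copy_(torch.cumsum(accept_length, dim=0, dtype=torch.int32))
seq_lens_q.copy_(accept_length)

print(cu_seqlens_q)  # tensor([0, 2, 3, 6], dtype=torch.int32)
print(seq_lens_q)    # tensor([2, 1, 3], dtype=torch.int32)
```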
@@ -538,7 +557,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
or forward_batch.forward_mode.is_draft_extend(include_v2=True)
):
bs = forward_batch.batch_size
self.forward_decode_metadata = TRTLLMMLADecodeMetadata()
# Get maximum sequence length.
if getattr(forward_batch, "seq_lens_cpu", None) is not None:
max_seq = forward_batch.seq_lens_cpu.max().item()
@@ -550,21 +569,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
if forward_batch.forward_mode.is_target_verify():
max_seq = max_seq + self.num_draft_tokens
seq_lens = seq_lens + self.num_draft_tokens
max_seqlen_pad = self._calc_padded_blocks(max_seq)
block_kv_indices = self._create_block_kv_indices(
bs,
max_seqlen_pad,
forward_batch.req_pool_indices,
seq_lens,
seq_lens.device,
)
max_seq_len_val = int(max_seq)
self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
block_kv_indices, max_seq_len_val
)
if forward_batch.forward_mode.is_draft_extend(include_v2=True):
self.forward_decode_metadata.seq_lens_k = seq_lens
elif forward_batch.forward_mode.is_draft_extend(include_v2=True):
max_seq = forward_batch.seq_lens_cpu.max().item()
sum_seq_lens_q = sum(forward_batch.extend_seq_lens_cpu)
@@ -575,11 +581,26 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
),
(1, 0),
)
# see NOTE(draft_extend seq_len handling)
seq_lens = seq_lens - forward_batch.extend_seq_lens + max_seq_len_q
self.forward_decode_metadata.max_seq_len_q = max_seq_len_q
self.forward_decode_metadata.sum_seq_lens_q = sum_seq_lens_q
self.forward_decode_metadata.cu_seqlens_q = cu_seqlens_q
self.forward_decode_metadata.seq_lens_q = forward_batch.extend_seq_lens
self.forward_decode_metadata.seq_lens_k = seq_lens
max_seqlen_pad = self._calc_padded_blocks(max_seq)
block_kv_indices = self._create_block_kv_indices(
bs,
max_seqlen_pad,
forward_batch.req_pool_indices,
seq_lens,
seq_lens.device,
)
self.forward_decode_metadata.block_kv_indices = block_kv_indices
self.forward_decode_metadata.max_seq_len_k = int(max_seq)
forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
else:
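The call truncated at the top of the hunk above appears to be the usual pad-a-cumsum idiom for turning variable per-request query lengths into cu_seqlens_q (a leading zero prepended to the running sum), matching the cumsum used in the replay path. A hedged sketch of that idiom, assuming only torch:

```python
import torch
import torch.nn.functional as F

# Hypothetical variable per-request query lengths (extend_seq_lens).
extend_seq_lens = torch.tensor([2, 4, 1], dtype=torch.int32)

# Prefix-sum the lengths, then prepend a zero to get cumulative offsets.
cu_seqlens_q = F.pad(
    torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
)
print(cu_seqlens_q)  # tensor([0, 2, 6, 7], dtype=torch.int32)
```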
@@ -899,18 +920,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
)
q = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
else:
# In the FP8 path, the query and rope parts are already merged by quantize_and_rope_for_fp8.
q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
if k_rope is not None:
k = torch.cat([k, k_rope], dim=-1)
k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
if (
forward_batch.forward_mode.is_target_verify()
or forward_batch.forward_mode.is_draft_extend(include_v2=True)
@@ -936,23 +948,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
bmm1_scale = q_scale * k_scale * layer.scaling
if forward_batch.forward_mode.is_target_verify():
seq_lens = (
forward_batch.seq_lens.to(torch.int32)
+ forward_batch.spec_info.draft_token_num
)
max_seq_len = (
metadata.max_seq_len_k + forward_batch.spec_info.draft_token_num
)
else:
# forward_batch.seq_lens is the seq_lens of the prev_context + verified tokens.
# To account for pad_draft_extend_query, we need seq_lens = prev_context + max_draft_tokens.
# This will ensure queries align with kvs correctly when calling
# flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla.
seq_lens = (
forward_batch.seq_lens
- metadata.seq_lens_q
+ metadata.max_seq_len_q
).to(torch.int32)
max_seq_len = metadata.max_seq_len_k + metadata.max_seq_len_q
# Check if we're in CUDA graph mode (buffers are pre-allocated)
if self.padded_q_buffer is not None:
@@ -986,7 +985,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
kv_lora_rank=self.kv_lora_rank,
qk_rope_head_dim=self.qk_rope_head_dim,
block_tables=metadata.block_kv_indices,
seq_lens=seq_lens,
seq_lens=metadata.seq_lens_k,
max_seq_len=max_seq_len,
bmm1_scale=bmm1_scale,
)
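The net effect of the two hunks above is that the hot path no longer rebuilds the key lengths on every attention call: seq_lens_k is filled once when the decode metadata is created and simply reused at the trtllm_batch_decode_with_kv_cache_mla call site. A rough before/after sketch of the idea (illustrative only; the function names are stand-ins, not the backend's API):

```python
import torch

def key_lens_old(forward_batch_seq_lens: torch.Tensor, draft_token_num: int) -> torch.Tensor:
    # Old path: a fresh int32 tensor is materialized inside every layer's forward.
    return forward_batch_seq_lens.to(torch.int32) + draft_token_num

def key_lens_new(metadata_seq_lens_k: torch.Tensor) -> torch.Tensor:
    # New path: the tensor was computed once per batch when the metadata was
    # built, so every layer reuses the same buffer.
    return metadata_seq_lens_k
```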
@@ -1003,6 +1002,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
return output
if k_rope is not None:
k = torch.cat([k, k_rope], dim=-1)
k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
if forward_batch.attn_attend_prefix_cache:
# MHA for chunked prefix kv cache when running model with MLA
assert forward_batch.prefix_chunk_idx is not None