Unverified commit d5b6e50f authored by yinghui, committed by GitHub

perf: trtllm mla performance minor improvements (#12435)

parent 9632e48f
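
The diff below precomputes the int32 key sequence-length tensor (`seq_lens_k`) once when the decode metadata is built, instead of rebuilding it inside the attention call, and updates a preallocated buffer in place with `copy_()` on CUDA graph replay. As a rough illustration of that pattern only (a minimal sketch with hypothetical names such as `DecodeMetadataSketch`, `capture`, and `replay`; this is not the sglang code in the diff):

# Illustrative sketch, not the sglang implementation: allocate seq_lens_k once at
# capture time, refresh it in place at replay time, and let the attention call
# read metadata.seq_lens_k directly instead of rebuilding an int32 tensor per forward.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class DecodeMetadataSketch:
    # Hypothetical stand-in for TRTLLMMLADecodeMetadata.
    seq_lens_k: Optional[torch.Tensor] = None


def capture(bs: int, seq_lens: torch.Tensor, num_draft_tokens: int) -> DecodeMetadataSketch:
    # Capture time: allocate a fixed int32 buffer sized to the batch.
    meta = DecodeMetadataSketch()
    meta.seq_lens_k = torch.zeros((bs,), dtype=torch.int32, device=seq_lens.device)
    meta.seq_lens_k.copy_((seq_lens + num_draft_tokens).to(torch.int32))
    return meta


def replay(meta: DecodeMetadataSketch, seq_lens: torch.Tensor, num_draft_tokens: int) -> None:
    # Replay time: only an in-place copy, no new allocations, so the captured graph stays valid.
    meta.seq_lens_k.copy_((seq_lens + num_draft_tokens).to(torch.int32))
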
@@ -219,6 +219,7 @@ class TRTLLMMLADecodeMetadata:
     sum_seq_lens_q: Optional[int] = None
     cu_seqlens_q: Optional[torch.Tensor] = None
     seq_lens_q: Optional[torch.Tensor] = None
+    seq_lens_k: Optional[torch.Tensor] = None
 
 
 class TRTLLMMLABackend(FlashInferMLAAttnBackend):
@@ -404,8 +405,38 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 spec_info,
             )
+        metadata = TRTLLMMLADecodeMetadata()
         if forward_mode.is_target_verify():
             seq_lens = seq_lens + self.num_draft_tokens
+            metadata.seq_lens_k = torch.zeros(
+                (bs,), dtype=torch.int32, device=seq_lens.device
+            )
+            metadata.seq_lens_k.copy_(seq_lens.to(dtype=torch.int32))
+        elif forward_mode.is_draft_extend(include_v2=True):
+            num_tokens_per_bs = num_tokens // bs
+            metadata.max_seq_len_q = num_tokens_per_bs
+            metadata.sum_seq_lens_q = num_tokens_per_bs * bs
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                bs * num_tokens_per_bs + 1,
+                num_tokens_per_bs,
+                dtype=torch.int32,
+                device=seq_lens.device,
+            )
+            metadata.seq_lens_q = torch.full(
+                (bs,), num_tokens_per_bs, dtype=torch.int32, device=seq_lens.device
+            )
+            # NOTE(draft_extend seq_len handling):
+            # forward_batch.seq_lens is the seq_lens of the prev_context + verified tokens.
+            # To account for pad_draft_extend_query, we need seq_lens = prev_context + max_draft_tokens.
+            # This will ensure queries align with kvs correctly when calling
+            # flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla.
+            seq_lens = seq_lens - metadata.seq_lens_q + metadata.max_seq_len_q
+            metadata.seq_lens_k = torch.zeros(
+                (bs,), dtype=torch.int32, device=seq_lens.device
+            )
+            metadata.seq_lens_k.copy_(seq_lens.to(dtype=torch.int32))
         # Custom fast-path for decode/idle.
         # Capture with full width so future longer sequences are safe during replay
@@ -423,24 +454,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             PAGED_SIZE=self.page_size,
         )
-        metadata = TRTLLMMLADecodeMetadata(
-            block_kv_indices,
-            self.max_context_len,
-        )
-        if forward_mode.is_draft_extend(include_v2=True):
-            num_tokens_per_bs = num_tokens // bs
-            metadata.max_seq_len_q = num_tokens_per_bs + 1
-            metadata.sum_seq_lens_q = num_tokens_per_bs * bs
-            metadata.cu_seqlens_q = torch.arange(
-                0,
-                bs * num_tokens_per_bs + 1,
-                num_tokens_per_bs,
-                dtype=torch.int32,
-                device=seq_lens.device,
-            )
-            metadata.seq_lens_q = torch.full(
-                (bs,), num_tokens_per_bs, dtype=torch.int32, device=seq_lens.device
-            )
+        metadata.block_kv_indices = block_kv_indices
+        metadata.max_seq_len_k = self.max_context_len
         self.decode_cuda_graph_metadata[bs] = metadata
         self.forward_decode_metadata = metadata
@@ -473,17 +489,17 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             seq_lens_cpu,
         )
-        if forward_mode.is_target_verify():
-            seq_lens = seq_lens + self.num_draft_tokens
-        del seq_lens_sum  # not handle "num_draft_tokens" but we do not need it
         metadata = self.decode_cuda_graph_metadata[bs]
-        if forward_mode.is_draft_extend(include_v2=True):
+        if forward_mode.is_target_verify():
+            seq_lens = seq_lens[:bs] + self.num_draft_tokens
+            metadata.seq_lens_k.copy_(seq_lens.to(dtype=torch.int32))
+            del seq_lens_sum  # not handle "num_draft_tokens" but we do not need it
+        elif forward_mode.is_draft_extend(include_v2=True):
             accept_length = spec_info.accept_length[:bs]
             if spec_info.accept_length_cpu:
-                metadata.max_seq_len_q = max(spec_info.accept_length_cpu[:bs])
-                metadata.sum_seq_lens_q = sum(spec_info.accept_length_cpu[:bs])
+                metadata.max_seq_len_q = max(spec_info.accept_length_cpu[:bs]) + 1
+                metadata.sum_seq_lens_q = sum(spec_info.accept_length_cpu[:bs]) + bs
             else:
                 metadata.max_seq_len_q = 1
                 metadata.sum_seq_lens_q = bs
@@ -491,12 +507,15 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 torch.cumsum(accept_length, dim=0, dtype=torch.int32)
             )
             metadata.seq_lens_q.copy_(accept_length)
+            # see NOTE(draft_extend seq_len handling)
+            seq_lens = seq_lens[:bs] - metadata.seq_lens_q + metadata.max_seq_len_q
+            metadata.seq_lens_k.copy_(seq_lens.to(torch.int32))
         # Update block indices for new sequences.
         create_flashmla_kv_indices_triton[(bs,)](
             self.req_to_token,
             req_pool_indices[:bs],
-            seq_lens[:bs],
+            seq_lens,
             None,
             metadata.block_kv_indices,
             self.req_to_token.stride(0),
@@ -538,7 +557,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             or forward_batch.forward_mode.is_draft_extend(include_v2=True)
         ):
             bs = forward_batch.batch_size
+            self.forward_decode_metadata = TRTLLMMLADecodeMetadata()
             # Get maximum sequence length.
             if getattr(forward_batch, "seq_lens_cpu", None) is not None:
                 max_seq = forward_batch.seq_lens_cpu.max().item()
@@ -550,21 +569,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             if forward_batch.forward_mode.is_target_verify():
                 max_seq = max_seq + self.num_draft_tokens
                 seq_lens = seq_lens + self.num_draft_tokens
+                self.forward_decode_metadata.seq_lens_k = seq_lens
-            max_seqlen_pad = self._calc_padded_blocks(max_seq)
-            block_kv_indices = self._create_block_kv_indices(
-                bs,
-                max_seqlen_pad,
-                forward_batch.req_pool_indices,
-                seq_lens,
-                seq_lens.device,
-            )
-            max_seq_len_val = int(max_seq)
-            self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
-                block_kv_indices, max_seq_len_val
-            )
-            if forward_batch.forward_mode.is_draft_extend(include_v2=True):
+            elif forward_batch.forward_mode.is_draft_extend(include_v2=True):
                 max_seq = forward_batch.seq_lens_cpu.max().item()
                 sum_seq_lens_q = sum(forward_batch.extend_seq_lens_cpu)
@@ -575,11 +581,26 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                     ),
                     (1, 0),
                 )
+                # see NOTE(draft_extend seq_len handling)
+                seq_lens = seq_lens - forward_batch.extend_seq_lens + max_seq_len_q
                 self.forward_decode_metadata.max_seq_len_q = max_seq_len_q
                 self.forward_decode_metadata.sum_seq_lens_q = sum_seq_lens_q
                 self.forward_decode_metadata.cu_seqlens_q = cu_seqlens_q
                 self.forward_decode_metadata.seq_lens_q = forward_batch.extend_seq_lens
+                self.forward_decode_metadata.seq_lens_k = seq_lens
+            max_seqlen_pad = self._calc_padded_blocks(max_seq)
+            block_kv_indices = self._create_block_kv_indices(
+                bs,
+                max_seqlen_pad,
+                forward_batch.req_pool_indices,
+                seq_lens,
+                seq_lens.device,
+            )
+            self.forward_decode_metadata.block_kv_indices = block_kv_indices
+            self.forward_decode_metadata.max_seq_len_k = int(max_seq)
             forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
         else:
@@ -899,18 +920,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
             q = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
-        else:
-            # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
-            q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
         q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-        if k_rope is not None:
-            k = torch.cat([k, k_rope], dim=-1)
-        k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
-        v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
         if (
             forward_batch.forward_mode.is_target_verify()
             or forward_batch.forward_mode.is_draft_extend(include_v2=True)
@@ -936,23 +948,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             bmm1_scale = q_scale * k_scale * layer.scaling
             if forward_batch.forward_mode.is_target_verify():
-                seq_lens = (
-                    forward_batch.seq_lens.to(torch.int32)
-                    + forward_batch.spec_info.draft_token_num
-                )
                 max_seq_len = (
                     metadata.max_seq_len_k + forward_batch.spec_info.draft_token_num
                 )
             else:
-                # forward_batch.seq_lens is the seq_lens of the prev_context + verified tokens.
-                # To account for pad_draft_extend_query, we need seq_lens = prev_context + max_draft_tokens.
-                # This will ensure queries align with kvs correctly when calling
-                # flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla.
-                seq_lens = (
-                    forward_batch.seq_lens
-                    - metadata.seq_lens_q
-                    + metadata.max_seq_len_q
-                ).to(torch.int32)
                 max_seq_len = metadata.max_seq_len_k + metadata.max_seq_len_q
             # Check if we're in CUDA graph mode (buffers are pre-allocated)
             if self.padded_q_buffer is not None:
@@ -986,7 +985,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 kv_lora_rank=self.kv_lora_rank,
                 qk_rope_head_dim=self.qk_rope_head_dim,
                 block_tables=metadata.block_kv_indices,
-                seq_lens=seq_lens,
+                seq_lens=metadata.seq_lens_k,
                 max_seq_len=max_seq_len,
                 bmm1_scale=bmm1_scale,
             )
@@ -1003,6 +1002,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
             return output
+        if k_rope is not None:
+            k = torch.cat([k, k_rope], dim=-1)
+        k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
+        v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
         if forward_batch.attn_attend_prefix_cache:
             # MHA for chunked prefix kv cache when running model with MLA
             assert forward_batch.prefix_chunk_idx is not None
...
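
To make the NOTE(draft_extend seq_len handling) above concrete, here is a small worked example with hypothetical numbers (not taken from the diff): two requests with previous contexts of 10 and 20 tokens accept 2 and 3 draft tokens, and the draft-extend queries are padded to max_seq_len_q = 4, so the KV lengths handed to the kernel must be prev_context + 4.

# Hypothetical numbers, for illustration only.
import torch

seq_lens = torch.tensor([12, 23], dtype=torch.int32)       # prev_context + verified tokens
extend_seq_lens = torch.tensor([2, 3], dtype=torch.int32)  # per-request query lengths before padding
max_seq_len_q = 4                                          # padded query length per request

# Because the queries are padded to max_seq_len_q, the KV lengths must be
# prev_context + max_seq_len_q for queries and keys to stay aligned position by position.
seq_lens_k = seq_lens - extend_seq_lens + max_seq_len_q
assert seq_lens_k.tolist() == [14, 24]                     # 10 + 4 and 20 + 4
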