Unverified Commit e51046be authored by yinghui, committed by GitHub

perf: trtllm_mla attention backend spec decoding speedup w/ cuda graph (#12093)

parent 4eeeae1e
```diff
@@ -423,14 +423,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 PAGED_SIZE=self.page_size,
             )
-
-            # Record the true maximum sequence length for this capture batch so that
-            # the kernel launch path (which requires an int not a tensor) can reuse
-            # it safely during both capture and replay.
-            max_seq_len_val = int(seq_lens.max().item())
             metadata = TRTLLMMLADecodeMetadata(
                 block_kv_indices,
-                max_seq_len_val,
+                self.max_context_len,
             )
 
         if forward_mode.is_draft_extend(include_v2=True):
             num_tokens_per_bs = num_tokens // bs
```
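Why removing the `seq_lens.max().item()` read speeds up capture and replay: `.item()` copies a scalar from device to host, which synchronizes the CUDA stream; that sync is not permitted during `torch.cuda.CUDAGraph` capture, and on the hot path it would stall the CPU every decode step. A minimal standalone sketch of the two patterns (`MAX_CONTEXT_LEN` below stands in for `self.max_context_len` and is illustrative, not the backend's real value):

```python
import torch

# Illustrative constant standing in for the backend's self.max_context_len.
MAX_CONTEXT_LEN = 8192

def max_seq_len_synced(seq_lens: torch.Tensor) -> int:
    # Device -> host read: .item() synchronizes the stream, and inside
    # CUDA graph capture it fails with "operation not permitted when
    # stream is capturing".
    return int(seq_lens.max().item())

def max_seq_len_static() -> int:
    # Static upper bound known before capture: no tensor read, so the int
    # baked into the captured kernel launch stays valid for every replay.
    return MAX_CONTEXT_LEN
```

The trade-off is that the kernel now receives a conservative upper bound rather than the batch's true maximum, which is safe as long as the kernel treats `max_seq_len` as a bound. The second hunk deletes the replay-time bookkeeping that the static bound makes unnecessary: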
```diff
@@ -509,13 +504,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 PAGED_SIZE=self.page_size,
             )
-
-            # Update stored max_seq_len so subsequent kernel calls use the correct value
-            # Prefer CPU tensor to avoid GPU synchronization when available.
-            if seq_lens_cpu is not None:
-                metadata.max_seq_len = int(seq_lens_cpu.max().item())
-            else:
-                metadata.max_seq_len = int(seq_lens.max().item())
 
     def get_cuda_graph_seq_len_fill_value(self) -> int:
         """Get the fill value for sequence lengths in CUDA graph."""
         return 1
```
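On the replay path, the deleted block recomputed `metadata.max_seq_len` per batch, preferring the CPU copy of `seq_lens` to dodge a GPU sync; with a fixed bound, that bookkeeping disappears entirely. A rough sketch of the resulting decode loop shape, using hypothetical stand-in names (`decode_step`, `static_seq_lens`) rather than the backend's actual API:

```python
import torch

def decode_step(seq_lens: torch.Tensor) -> torch.Tensor:
    # Placeholder for the attention kernel launch; the real kernel takes
    # max_seq_len as a plain int argument baked in at capture time.
    return seq_lens * 2

static_seq_lens = torch.zeros(8, dtype=torch.int32, device="cuda")

# Warm up once outside capture so lazy CUDA initialization has happened.
decode_step(static_seq_lens)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = decode_step(static_seq_lens)

# Per-step replay: refresh the static input buffer, then replay the graph.
# There is no seq_lens.max().item() on this path, so no device->host sync.
new_lens = torch.randint(1, 128, (8,), dtype=torch.int32, device="cuda")
static_seq_lens.copy_(new_lens)
g.replay()
```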