Unverified Commit dc4f5418 authored by yinghui, committed by GitHub

fix trtllm_mla attention backend when disabling cuda graph. (#12687)

parent 0648eb48
@@ -585,7 +585,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         if forward_batch.forward_mode.is_target_verify():
             max_seq = max_seq + self.num_draft_tokens
             seq_lens = seq_lens + self.num_draft_tokens
-            self.forward_decode_metadata.seq_lens_k = seq_lens
+            self.forward_decode_metadata.seq_lens_k = seq_lens.to(torch.int32)
         elif forward_batch.forward_mode.is_draft_extend(include_v2=True):
             max_seq = forward_batch.seq_lens_cpu.max().item()
@@ -604,7 +604,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             self.forward_decode_metadata.sum_seq_lens_q = sum_seq_lens_q
             self.forward_decode_metadata.cu_seqlens_q = cu_seqlens_q
             self.forward_decode_metadata.seq_lens_q = forward_batch.extend_seq_lens
-            self.forward_decode_metadata.seq_lens_k = seq_lens
+            self.forward_decode_metadata.seq_lens_k = seq_lens.to(torch.int32)
         max_seqlen_pad = self._calc_padded_blocks(max_seq)
         block_kv_indices = self._create_block_kv_indices(
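The change itself is a dtype cast: when CUDA graphs are disabled, the decode metadata is rebuilt per batch from `forward_batch.seq_lens`, which presumably arrives as int64, while the kernel-facing `seq_lens_k` field is expected to be int32. A minimal sketch of the pattern, using a hypothetical `DecodeMetadata` holder rather than the real backend class:

import torch

class DecodeMetadata:
    # Hypothetical stand-in for the backend's per-batch decode metadata.
    seq_lens_k: torch.Tensor

metadata = DecodeMetadata()

# Sequence-length tensors are commonly built as int64 by default.
seq_lens = torch.tensor([128, 64, 256], dtype=torch.int64)

# Cast explicitly so the attention kernel sees int32 sequence lengths,
# likely matching what the CUDA-graph capture path already provides.
metadata.seq_lens_k = seq_lens.to(torch.int32)

assert metadata.seq_lens_k.dtype == torch.int32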