Unverified Commit af1973b8 authored by Qiaolin Yu, committed by GitHub

Fix max_seq_len_k in trtllm_mha attention backend (#9416)

parent 5cfbb4c1
@@ -127,7 +127,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
# Precompute maximum sequence length
-metadata.max_seq_len_k = self.max_context_len
+metadata.max_seq_len_k = seq_lens[:bs].max().item()
# Precompute page table
metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]
@@ -156,7 +156,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
metadata = self.decode_cuda_graph_metadata[bs]
max_len = seq_lens_cpu.max().item()
max_seq_pages = (max_len + self.page_size - 1) // self.page_size
-metadata.max_seq_len_k = self.max_context_len
+metadata.max_seq_len_k = max_len
metadata.cache_seqlens_int32.copy_(seq_lens)
page_indices = self.req_to_token[
@@ -265,7 +265,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
workspace_buffer=self.workspace_buffer,
block_tables=self.forward_metadata.page_table,
seq_lens=self.forward_metadata.cache_seqlens_int32,
-max_seq_len=self.forward_metadata.max_seq_len_k,
+max_seq_len=self.max_context_len,
bmm1_scale=bmm1_scale,
bmm2_scale=bmm2_scale,
window_left=layer.sliding_window_size,
@@ -320,7 +320,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
block_tables=self.forward_metadata.page_table,
seq_lens=self.forward_metadata.cache_seqlens_int32,
max_q_len=self.forward_metadata.max_seq_len_q,
-max_kv_len=self.forward_metadata.max_seq_len_k,
+max_kv_len=self.max_context_len,
bmm1_scale=bmm1_scale,
bmm2_scale=bmm2_scale,
batch_size=forward_batch.batch_size,
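Taken together, the four hunks separate the two lengths: `metadata.max_seq_len_k` now tracks the actual per-batch maximum KV sequence length, while the kernel calls receive the fixed `self.max_context_len` directly. Below is a minimal, self-contained sketch of that pattern; the concrete values and the commented-out `decode_kernel` call are illustrative placeholders, not part of the backend.

```python
import torch
from types import SimpleNamespace

max_context_len = 8192                  # static model-wide upper bound (hypothetical value)
seq_lens = torch.tensor([17, 512, 64])  # actual KV lengths of the current batch (made up)
bs = seq_lens.numel()

metadata = SimpleNamespace()
metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)

# After the fix, the metadata records the real per-batch maximum KV length...
metadata.max_seq_len_k = seq_lens[:bs].max().item()  # 512, not 8192

# ...while the kernel invocation keeps the static bound, roughly:
#   decode_kernel(..., seq_lens=metadata.cache_seqlens_int32,
#                 max_seq_len=max_context_len, ...)
print(metadata.max_seq_len_k, max_context_len)
```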