Unverified commit af1973b8 authored by Qiaolin Yu, committed by GitHub

Fix max_seq_len_k in trtllm_mha attention backend (#9416)

parent 5cfbb4c1
@@ -127,7 +127,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
         # Precompute maximum sequence length
-        metadata.max_seq_len_k = self.max_context_len
+        metadata.max_seq_len_k = seq_lens[:bs].max().item()
         # Precompute page table
         metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]
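For intuition, a minimal standalone sketch (hypothetical values, not taken from the backend) of what the changed line computes: the old code pinned max_seq_len_k to the backend-wide context cap, while the fix derives it from the sequence lengths actually present in the batch.

```python
import torch

max_context_len = 8192  # backend-wide cap; the old value of max_seq_len_k
seq_lens = torch.tensor([17, 250, 93, 4], dtype=torch.int32)  # hypothetical KV lengths
bs = 3  # only the first bs entries belong to this batch

# Old: metadata.max_seq_len_k = self.max_context_len        -> 8192
# New: metadata.max_seq_len_k = seq_lens[:bs].max().item()  -> 250
max_seq_len_k = seq_lens[:bs].max().item()
print(max_seq_len_k)  # 250, the true maximum KV length in this batch
```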
@@ -156,7 +156,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         metadata = self.decode_cuda_graph_metadata[bs]
         max_len = seq_lens_cpu.max().item()
         max_seq_pages = (max_len + self.page_size - 1) // self.page_size
-        metadata.max_seq_len_k = self.max_context_len
+        metadata.max_seq_len_k = max_len
         metadata.cache_seqlens_int32.copy_(seq_lens)
         page_indices = self.req_to_token[
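This replay path also recomputes the page count with a ceiling division over the page size; a worked instance with hypothetical numbers:

```python
max_len = 250    # maximum KV length in the batch (hypothetical)
page_size = 64   # KV-cache page size (hypothetical)

# Ceiling division: pages needed to hold max_len tokens.
max_seq_pages = (max_len + page_size - 1) // page_size
print(max_seq_pages)  # 4 -- three pages hold only 192 tokens, so a fourth is needed
```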
@@ -265,7 +265,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             workspace_buffer=self.workspace_buffer,
             block_tables=self.forward_metadata.page_table,
             seq_lens=self.forward_metadata.cache_seqlens_int32,
-            max_seq_len=self.forward_metadata.max_seq_len_k,
+            max_seq_len=self.max_context_len,
             bmm1_scale=bmm1_scale,
             bmm2_scale=bmm2_scale,
             window_left=layer.sliding_window_size,
@@ -320,7 +320,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             block_tables=self.forward_metadata.page_table,
             seq_lens=self.forward_metadata.cache_seqlens_int32,
             max_q_len=self.forward_metadata.max_seq_len_q,
-            max_kv_len=self.forward_metadata.max_seq_len_k,
+            max_kv_len=self.max_context_len,
             bmm1_scale=bmm1_scale,
             bmm2_scale=bmm2_scale,
             batch_size=forward_batch.batch_size,
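Note that both call sites now pass self.max_context_len for the kernel's maximum-length argument instead of the per-batch max_seq_len_k. One plausible reading, an assumption rather than anything stated in the diff, is that a scalar baked into a captured CUDA graph must stay constant across replays, while per-batch values travel in tensors such as cache_seqlens_int32 that can be rewritten in place. A hypothetical sketch of that split:

```python
import torch

def run_decode_kernel(seq_lens: torch.Tensor, max_seq_len: int) -> None:
    """Stand-in for the TRT-LLM decode attention call (hypothetical)."""
    assert int(seq_lens.max()) <= max_seq_len  # the static cap must bound every request

max_context_len = 8192  # static scalar: fixed when the graph is captured
cache_seqlens = torch.tensor([17, 250, 93], dtype=torch.int32)  # dynamic tensor

# On each replay the tensor contents change in place...
cache_seqlens.copy_(torch.tensor([18, 251, 94], dtype=torch.int32))
# ...but the scalar bound handed to the kernel stays the capture-time constant.
run_decode_kernel(cache_seqlens, max_seq_len=max_context_len)
```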