Unverified commit 6ffb6bd4, authored by Qingquan Song and committed by GitHub
Browse files

Fix fa3 cuda graph page_size > 1 precision and page_size=1 speed (#4855)

parent 47e6628a
@@ -322,10 +322,13 @@ class FlashAttentionBackend(AttentionBackend):
 torch.cumsum(seq_lens_in_batch, dim=0, dtype=torch.int32), (1, 0)
 )
-metadata.page_table = self.req_to_token[
-:, self.decode_cuda_graph_metadata["strided_indices"]
-]
-metadata.page_table = metadata.page_table[req_pool_indices[:bs]]
+max_seq_pages = (metadata.max_seq_len_k + self.page_size - 1) // self.page_size
+page_indices = self.req_to_token[
+:, self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages]
+]
+page_indices = page_indices[req_pool_indices[:bs]] // self.page_size
+metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+metadata.page_table[:, max_seq_pages:].fill_(0)
 self.forward_metadata = metadata
 def get_cuda_graph_seq_len_fill_value(self):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment