Unverified commit aee62d74 authored by Chang Su, committed by GitHub

Optimize GPU memory usage in FlashAttentionBackend's strided indexing (#5262)


Co-authored-by: ch-wan <cwan39@gatech.edu>
parent cd7e32e2
@@ -977,10 +977,12 @@ class FlashAttentionBackend(AttentionBackend):
                     metadata.max_seq_len_k + self.page_size - 1
                 ) // self.page_size
                 page_indices = self.req_to_token[
-                    :,
-                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
+                    req_pool_indices[:, None],
+                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
+                        None, :
+                    ],
                 ]
-                page_indices = page_indices[req_pool_indices] // self.page_size
+                page_indices //= self.page_size
                 metadata.page_table[:, :max_seq_pages].copy_(page_indices)
                 metadata.page_table[:, max_seq_pages:].fill_(0)
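For context, here is a minimal PyTorch sketch of the change. The shapes and values are illustrative assumptions; only the identifiers `req_to_token`, `req_pool_indices`, `strided_indices`, and `page_size` come from the diff. The old path gathers strided columns for every request in the pool, materializing a large intermediate before selecting the batch rows; the new path uses broadcast index tensors so only the rows in the current batch are gathered.

```python
import torch

# Illustrative shapes (assumptions, not values from the repository).
num_reqs_total, max_context_len = 1024, 8192  # full request-to-token pool
page_size = 16
batch_size = 8
max_seq_pages = max_context_len // page_size

req_to_token = torch.randint(0, 1 << 20, (num_reqs_total, max_context_len))
req_pool_indices = torch.randint(0, num_reqs_total, (batch_size,))
strided_indices = torch.arange(0, max_context_len, page_size)

# Before: slicing all rows first materializes a
# (num_reqs_total, max_seq_pages) intermediate, then picks the batch rows.
old = req_to_token[:, strided_indices[:max_seq_pages]]
old = old[req_pool_indices] // page_size

# After: broadcast advanced indexing gathers only the
# (batch_size, max_seq_pages) entries that are actually needed.
new = req_to_token[
    req_pool_indices[:, None],
    strided_indices[:max_seq_pages][None, :],
]
new //= page_size

assert torch.equal(old, new)
```

Both paths produce identical page indices; the second avoids an intermediate whose leading dimension is the size of the whole request pool rather than the batch, which is the memory saving described in the commit title.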