Commit aee62d74 authored by Chang Su, committed by GitHub

Optimize GPU memory usage in FlashAttentionBackend's strided indexing (#5262)


Co-authored-by: ch-wan <cwan39@gatech.edu>
parent cd7e32e2
@@ -977,10 +977,12 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.max_seq_len_k + self.page_size - 1
             ) // self.page_size
             page_indices = self.req_to_token[
-                :,
-                self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
+                req_pool_indices[:, None],
+                self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
+                    None, :
+                ],
             ]
-            page_indices = page_indices[req_pool_indices] // self.page_size
+            page_indices //= self.page_size
             metadata.page_table[:, :max_seq_pages].copy_(page_indices)
             metadata.page_table[:, max_seq_pages:].fill_(0)
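The change above can be illustrated with a minimal sketch. It uses NumPy arrays as a stand-in for the PyTorch tensors in the diff, on the assumption that advanced indexing behaves the same way for this pattern. All sizes and the contents of `req_to_token`, `req_pool_indices`, and `strided_indices` are hypothetical values picked for illustration and are not taken from the commit. The old form first gathers the strided columns for every row of the pool and only then selects the rows in the batch. The new form broadcasts a `(batch, 1)` row index against a `(1, pages)` column index, so only the batch-sized result is materialized.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
pool_size, max_context_len, page_size = 64, 32, 4
max_seq_pages = 5

rng = np.random.default_rng(0)
# Stand-in for self.req_to_token: one row of token slots per request slot.
req_to_token = rng.integers(0, 1000, size=(pool_size, max_context_len))
# Stand-in for req_pool_indices: the rows used by the current batch.
req_pool_indices = np.array([7, 21, 40])
# Stand-in for strided_indices: the first token position of each page.
strided_indices = np.arange(0, max_context_len, page_size)

# Old form: gather the strided columns for EVERY row, then pick batch rows.
# The intermediate has shape (pool_size, max_seq_pages).
intermediate = req_to_token[:, strided_indices[:max_seq_pages]]
old_pages = intermediate[req_pool_indices] // page_size

# New form: broadcast row and column indices so that only a
# (batch, max_seq_pages) array is ever created.
new_pages = req_to_token[
    req_pool_indices[:, None],
    strided_indices[:max_seq_pages][None, :],
]
# Advanced indexing returns a copy, so the in-place divide does not
# modify req_to_token.
new_pages //= page_size

print(intermediate.shape, new_pages.shape)
print(np.array_equal(old_pages, new_pages))
```

In this sketch the intermediate produced by the old form grows with the pool size, while the new form's result grows only with the batch size, which is the memory saving the commit title refers to.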