Unverified Commit 205d5cb4 authored by Chang Su, committed by GitHub

perf: Optimize local attention memory allocation in FlashAttentionBackend (#6356)

parent 3d7f7a43
@@ -1434,19 +1434,7 @@ class FlashAttentionBackend(AttentionBackend):
             self.decode_cuda_graph_metadata[bs] = metadata
             if self.attention_chunk_size is not None:
-                metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
-                    local_query_start_loc=self.decode_cuda_graph_local_attn_metadata[
-                        "local_query_start_loc"
-                    ],
-                    local_seqused_k=self.decode_cuda_graph_local_attn_metadata[
-                        "local_seqused_k"
-                    ],
-                    local_block_table=self.decode_cuda_graph_local_attn_metadata[
-                        "local_block_table"
-                    ],
-                    local_max_query_len=1,
-                    local_max_seq_len=1,
-                )
+                self._update_local_attn_metadata_for_capture(metadata, batch_size)
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
@@ -1807,6 +1795,62 @@ class FlashAttentionBackend(AttentionBackend):
         )
         metadata.local_attn_metadata = local_metadata

+    def _update_local_attn_metadata_for_capture(
+        self, metadata: FlashAttentionMetadata, bs: int
+    ):
+        """Update local attention metadata during the CUDA graph capture phase.
+
+        This method computes the exact buffer sizes needed for local attention
+        metadata during capture and creates views of the pre-allocated buffers
+        with exactly those sizes, so no more memory is used than necessary.
+        """
+        seq_lens_capture = metadata.cache_seqlens_int32
+        max_seq_len = int(seq_lens_capture.max().item())
+        page_table_capture = metadata.page_table
+
+        cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
+        seqlens_np = seq_lens_capture.cpu().numpy()
+        (
+            seqlens_q_local_np,
+            cu_seqlens_q_local_np,
+            seqlens_k_local_np,
+            block_table_local_np,
+        ) = make_local_attention_virtual_batches(
+            self.attention_chunk_size,
+            cu_seqlens_q_np,
+            seqlens_np,
+            page_table_capture,
+            self.page_size,
+        )
+
+        # Get the exact dimensions from the virtual-batch calculation.
+        q_len = len(cu_seqlens_q_local_np)
+        k_len = len(seqlens_k_local_np)
+        b0 = block_table_local_np.shape[0] if block_table_local_np.shape[0] > 0 else bs
+        b1 = block_table_local_np.shape[1] if block_table_local_np.shape[1] > 0 else 1
+
+        # Create views of the pre-allocated buffers with exactly these sizes.
+        # This is the key optimization: we only use the memory we actually need.
+        local_query_start_loc = self.decode_cuda_graph_local_attn_metadata[
+            "local_query_start_loc"
+        ][:q_len]
+        local_seqused_k = self.decode_cuda_graph_local_attn_metadata["local_seqused_k"][
+            :k_len
+        ]
+        local_block_table = self.decode_cuda_graph_local_attn_metadata[
+            "local_block_table"
+        ][:b0, :b1]
+
+        metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+            local_query_start_loc=local_query_start_loc,
+            local_seqused_k=local_seqused_k,
+            local_block_table=local_block_table,
+            local_max_query_len=1,
+            local_max_seq_len=max_seq_len,
+        )
+
     def _update_local_attn_metadata_for_replay(
         self, metadata: FlashAttentionMetadata, bs: int
     ):
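The core trick in this patch is that slicing a pre-allocated tensor from the front returns a view over the same storage rather than a copy, so per-capture metadata can be sized exactly without any new allocation. This matters for CUDA graphs because a captured graph replays against fixed device addresses: as long as the views keep the original base pointer, replays continue to read and write the same memory. Below is a minimal, self-contained sketch of that pattern; the buffer names and sizes are illustrative stand-ins, not taken from the SGLang source.

```python
import torch

# Hypothetical worst-case bounds, standing in for the sizes that
# FlashAttentionBackend would pre-allocate before CUDA graph capture.
MAX_VIRTUAL_BATCHES = 256
MAX_BLOCKS_PER_SEQ = 128

# One-time allocation of the full-size buffers (illustrative names).
local_attn_buffers = {
    "local_query_start_loc": torch.zeros(MAX_VIRTUAL_BATCHES + 1, dtype=torch.int32),
    "local_seqused_k": torch.zeros(MAX_VIRTUAL_BATCHES, dtype=torch.int32),
    "local_block_table": torch.zeros(
        MAX_VIRTUAL_BATCHES, MAX_BLOCKS_PER_SEQ, dtype=torch.int32
    ),
}

# Suppose a given capture step only needs these exact sizes.
q_len, k_len, b0, b1 = 9, 8, 8, 16

# Slicing returns views, not copies: no new memory is allocated.
local_query_start_loc = local_attn_buffers["local_query_start_loc"][:q_len]
local_seqused_k = local_attn_buffers["local_seqused_k"][:k_len]
local_block_table = local_attn_buffers["local_block_table"][:b0, :b1]

# The views share storage with the full buffers: the base pointer is
# unchanged, which is what keeps a captured CUDA graph's addresses valid.
assert (
    local_query_start_loc.data_ptr()
    == local_attn_buffers["local_query_start_loc"].data_ptr()
)
assert (
    local_block_table.data_ptr()
    == local_attn_buffers["local_block_table"].data_ptr()
)
```

At replay time, the same front-of-buffer views can then be refilled in place with the current batch's values, which is presumably what `_update_local_attn_metadata_for_replay` (whose body is truncated above) does.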