Unverified Commit cded039b authored by Stefan He's avatar Stefan He Committed by GitHub
Browse files

[FA3] Init Spec Page Table only when Spec is enabled to save ~40MB (#9455)

parent 275f9df3
......@@ -1163,6 +1163,8 @@ class FlashAttentionBackend(AttentionBackend):
This creates fixed-size tensors that will be reused during CUDA graph replay
to avoid memory allocations.
"""
max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size
# This is being used by normal decode and draft decode when topk == 1
self.decode_cuda_graph_metadata = {
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
......@@ -1174,13 +1176,7 @@ class FlashAttentionBackend(AttentionBackend):
),
"page_table": torch.zeros(
max_bs,
(self.max_context_len + self.page_size - 1) // self.page_size,
dtype=torch.int32,
device=self.device,
),
"page_table_draft_decode": torch.zeros(
max_bs,
(self.max_context_len + self.page_size - 1) // self.page_size,
max_num_pages,
dtype=torch.int32,
device=self.device,
),
......@@ -1188,7 +1184,6 @@ class FlashAttentionBackend(AttentionBackend):
0, self.max_context_len, self.page_size, device=self.device
),
}
# Only allocate local attention buffers if local attention is enabled
# This prevents OOM errors when local attention is not being used
if self.attention_chunk_size is not None:
......@@ -1274,6 +1269,14 @@ class FlashAttentionBackend(AttentionBackend):
self.speculative_num_draft_tokens is not None
and self.speculative_num_draft_tokens > 0
):
# "page_table_draft_decode" will be set only when spec decoding enabled to save memory
self.decode_cuda_graph_metadata["page_table_draft_decode"] = torch.zeros(
max_bs,
max_num_pages,
dtype=torch.int32,
device=self.device,
)
self.target_verify_metadata = {
"cache_seqlens": torch.zeros(
max_bs, dtype=torch.int32, device=self.device
......@@ -1290,7 +1293,7 @@ class FlashAttentionBackend(AttentionBackend):
),
"page_table": torch.zeros(
max_bs,
(self.max_context_len + self.page_size - 1) // self.page_size,
max_num_pages,
dtype=torch.int32,
device=self.device,
),
......@@ -1313,7 +1316,7 @@ class FlashAttentionBackend(AttentionBackend):
),
"page_table": torch.zeros(
max_bs,
(self.max_context_len + self.page_size - 1) // self.page_size,
max_num_pages,
dtype=torch.int32,
device=self.device,
),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment