"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "30f8a68c4cc55e0f3a717b891931847c97190843"
Unverified commit cded039b authored by Stefan He, committed by GitHub

[FA3] Init Spec Page Table only when Spec is enabled to save ~40MB (#9455)

parent 275f9df3
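The ~40MB in the title is just the size of the skipped "page_table_draft_decode" buffer: max_bs x max_num_pages int32 entries. A back-of-envelope check of that arithmetic, using illustrative values for max_bs, max_context_len, and page_size (the real numbers depend on server configuration, not on anything in this patch):

# Rough arithmetic for the memory saved by not allocating
# "page_table_draft_decode". All inputs below are illustrative
# assumptions, not the project's defaults.
max_bs = 160                # assumed CUDA-graph max batch size
max_context_len = 65536     # assumed maximum context length
page_size = 1               # assumed KV-cache page size

max_num_pages = (max_context_len + page_size - 1) // page_size
bytes_per_entry = 4         # torch.int32

saved_bytes = max_bs * max_num_pages * bytes_per_entry
print(f"saved: {saved_bytes / 2**20:.1f} MiB")  # 40.0 MiB with these values

With a larger page_size the table shrinks proportionally, so the exact saving varies by deployment.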
@@ -1163,6 +1163,8 @@ class FlashAttentionBackend(AttentionBackend):
         This creates fixed-size tensors that will be reused during CUDA graph replay
         to avoid memory allocations.
         """
+        max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size
+
         # This is being used by normal decode and draft decode when topk == 1
         self.decode_cuda_graph_metadata = {
             "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
@@ -1174,13 +1176,7 @@ class FlashAttentionBackend(AttentionBackend):
             ),
             "page_table": torch.zeros(
                 max_bs,
-                (self.max_context_len + self.page_size - 1) // self.page_size,
-                dtype=torch.int32,
-                device=self.device,
-            ),
-            "page_table_draft_decode": torch.zeros(
-                max_bs,
-                (self.max_context_len + self.page_size - 1) // self.page_size,
+                max_num_pages,
                 dtype=torch.int32,
                 device=self.device,
             ),
@@ -1188,7 +1184,6 @@ class FlashAttentionBackend(AttentionBackend):
                 0, self.max_context_len, self.page_size, device=self.device
             ),
         }
-
         # Only allocate local attention buffers if local attention is enabled
         # This prevents OOM errors when local attention is not being used
         if self.attention_chunk_size is not None:
@@ -1274,6 +1269,14 @@ class FlashAttentionBackend(AttentionBackend):
             self.speculative_num_draft_tokens is not None
             and self.speculative_num_draft_tokens > 0
         ):
+            # "page_table_draft_decode" will be set only when spec decoding enabled to save memory
+            self.decode_cuda_graph_metadata["page_table_draft_decode"] = torch.zeros(
+                max_bs,
+                max_num_pages,
+                dtype=torch.int32,
+                device=self.device,
+            )
+
             self.target_verify_metadata = {
                 "cache_seqlens": torch.zeros(
                     max_bs, dtype=torch.int32, device=self.device
@@ -1290,7 +1293,7 @@ class FlashAttentionBackend(AttentionBackend):
                 ),
                 "page_table": torch.zeros(
                     max_bs,
-                    (self.max_context_len + self.page_size - 1) // self.page_size,
+                    max_num_pages,
                     dtype=torch.int32,
                     device=self.device,
                 ),
@@ -1313,7 +1316,7 @@ class FlashAttentionBackend(AttentionBackend):
                 ),
                 "page_table": torch.zeros(
                     max_bs,
-                    (self.max_context_len + self.page_size - 1) // self.page_size,
+                    max_num_pages,
                     dtype=torch.int32,
                     device=self.device,
                 ),
...
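Taken together, the patch leaves the allocation path in roughly the shape below. This is a condensed sketch under the patch's own names, not the file verbatim: unrelated buffers are elided and the method signature is abbreviated.

import torch

def init_cuda_graph_state(self, max_bs: int):
    # Hoisted once so every page-table buffer shares the same size expression.
    max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size

    # Always allocated: used by normal decode, and by draft decode when topk == 1.
    self.decode_cuda_graph_metadata = {
        "page_table": torch.zeros(
            max_bs, max_num_pages, dtype=torch.int32, device=self.device
        ),
        # ... other always-on buffers elided ...
    }

    # The draft-decode page table is only paid for when speculative
    # decoding is actually enabled.
    if (
        self.speculative_num_draft_tokens is not None
        and self.speculative_num_draft_tokens > 0
    ):
        self.decode_cuda_graph_metadata["page_table_draft_decode"] = torch.zeros(
            max_bs, max_num_pages, dtype=torch.int32, device=self.device
        )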