Unverified Commit 7a913301 authored by Cheng Wan's avatar Cheng Wan Committed by GitHub
Browse files

Save cuda graph memory for fa3 (#8567)

parent 5ce5093b
...@@ -1406,7 +1406,7 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -1406,7 +1406,7 @@ class FlashAttentionBackend(AttentionBackend):
) )
metadata.page_table = self.decode_cuda_graph_metadata[ metadata.page_table = self.decode_cuda_graph_metadata[
"page_table_draft_decode" "page_table_draft_decode"
][req_pool_indices, :] ][:bs, :]
self.decode_cuda_graph_metadata[bs] = metadata self.decode_cuda_graph_metadata[bs] = metadata
else: else:
# When top k > 1, we need two specific draft decode metadata, and then merge states # When top k > 1, we need two specific draft decode metadata, and then merge states
...@@ -1424,7 +1424,7 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -1424,7 +1424,7 @@ class FlashAttentionBackend(AttentionBackend):
][: bs + 1] ][: bs + 1]
metadata.page_table = self.draft_decode_metadata_topk_normal[ metadata.page_table = self.draft_decode_metadata_topk_normal[
"page_table" "page_table"
][req_pool_indices, :] ][:bs, :]
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk) # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
metadata_expand.cache_seqlens_int32 = ( metadata_expand.cache_seqlens_int32 = (
...@@ -1461,7 +1461,7 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -1461,7 +1461,7 @@ class FlashAttentionBackend(AttentionBackend):
metadata.max_seq_len_k = seq_lens.max().item() metadata.max_seq_len_k = seq_lens.max().item()
# Precompute page table # Precompute page table
metadata.page_table = self.decode_cuda_graph_metadata["page_table"][ metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
req_pool_indices, : :bs, :
] ]
# Precompute cumulative sequence lengths # Precompute cumulative sequence lengths
metadata.cu_seqlens_q = torch.arange( metadata.cu_seqlens_q = torch.arange(
...@@ -1498,9 +1498,7 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -1498,9 +1498,7 @@ class FlashAttentionBackend(AttentionBackend):
: (bs + 1) : (bs + 1)
] ]
metadata.page_table = self.target_verify_metadata["page_table"][ metadata.page_table = self.target_verify_metadata["page_table"][:bs, :]
req_pool_indices, :
]
self.target_verify_metadata[bs] = metadata self.target_verify_metadata[bs] = metadata
else: else:
...@@ -1519,7 +1517,7 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -1519,7 +1517,7 @@ class FlashAttentionBackend(AttentionBackend):
][: bs + 1] ][: bs + 1]
metadata.page_table = self.target_verify_metadata_topk_normal[ metadata.page_table = self.target_verify_metadata_topk_normal[
"page_table" "page_table"
][req_pool_indices, :] ][:bs, :]
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk) # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
metadata_expand.cache_seqlens_int32 = ( metadata_expand.cache_seqlens_int32 = (
...@@ -1562,9 +1560,7 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -1562,9 +1560,7 @@ class FlashAttentionBackend(AttentionBackend):
metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][ metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][
: (bs + 1) : (bs + 1)
] ]
metadata.page_table = self.draft_extend_metadata["page_table"][ metadata.page_table = self.draft_extend_metadata["page_table"][:bs, :]
req_pool_indices, :
]
self.draft_extend_metadata[bs] = metadata self.draft_extend_metadata[bs] = metadata
...@@ -1578,7 +1574,7 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -1578,7 +1574,7 @@ class FlashAttentionBackend(AttentionBackend):
][: (encoder_bs + 1)] ][: (encoder_bs + 1)]
metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][ metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
req_pool_indices, : :bs, :
] ]
self.forward_metadata = metadata self.forward_metadata = metadata
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment