Unverified Commit 79961afa authored by Minglei Zhu, committed by GitHub

optimize pad operations in fa3 to accelerate by 100+ us (#6077)

parent cfca4e0e
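Every hunk below makes the same change: instead of computing the cumulative sum, padding it with `torch.nn.functional.pad(..., (1, 0))` into a second temporary tensor, and copying all of it into the `cu_seqlens_k` buffer, the cumulative sum is written directly into `cu_seqlens_k[1:]`. That drops one tensor allocation and one kernel launch per call site, which matters on these hot CUDA-graph replay paths. The rewrite is only valid if element 0 of the preallocated buffer is already 0 and is never overwritten; that invariant is implied by the sliced copy but not shown in this diff. A minimal standalone sketch of the equivalence (the buffers here are stand-ins for the metadata fields, not the backend's actual allocation code):

```python
import torch

cache_seqlens_int32 = torch.tensor([3, 5, 2], dtype=torch.int32)
# Stand-in for the preallocated metadata buffer; assumed (as in typical
# CUDA-graph capture code) to be zero-initialized once and then reused.
cu_seqlens_k = torch.zeros(cache_seqlens_int32.numel() + 1, dtype=torch.int32)

# Old pattern: cumsum -> pad (a second temporary) -> full copy into the buffer.
old = torch.nn.functional.pad(
    torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32), (1, 0)
)

# New pattern: write the cumsum straight into the tail of the buffer;
# element 0 stays 0, so the pad is redundant.
cu_seqlens_k[1:].copy_(
    torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32)
)

assert torch.equal(cu_seqlens_k, old)  # both are [0, 3, 8, 10]
```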
@@ -1525,12 +1525,9 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
                     self.speculative_step_id + 1
                 )
-                metadata.cu_seqlens_k.copy_(
-                    torch.nn.functional.pad(
-                        torch.cumsum(
-                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                        ),
-                        (1, 0),
+                metadata.cu_seqlens_k[1:].copy_(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                     )
                 )
@@ -1554,12 +1551,9 @@ class FlashAttentionBackend(AttentionBackend):
                 # metadata.max_seq_len_q = self.topk, already set in capture
                 metadata.max_seq_len_k = seq_lens_cpu.max().item()
                 # metadata.cu_seqlens_q already set in capture
-                metadata.cu_seqlens_k.copy_(
-                    torch.nn.functional.pad(
-                        torch.cumsum(
-                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                        ),
-                        (1, 0),
+                metadata.cu_seqlens_k[1:].copy_(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                     )
                 )
@@ -1616,13 +1610,8 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.max_seq_len_k = (
                 seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
             )
-            metadata.cu_seqlens_k.copy_(
-                torch.nn.functional.pad(
-                    torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    ),
-                    (1, 0),
-                )
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
             )
             max_seq_pages = (
                 metadata.max_seq_len_k + self.page_size - 1
@@ -1641,13 +1630,8 @@ class FlashAttentionBackend(AttentionBackend):
             # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
             metadata.max_seq_len_k = seq_lens_cpu.max().item()
             # metadata.cu_seqlens_q already set in capture
-            metadata.cu_seqlens_k.copy_(
-                torch.nn.functional.pad(
-                    torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    ),
-                    (1, 0),
-                )
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
             )
             page_table = self.req_to_token[
                 req_pool_indices, : metadata.max_seq_len_k
@@ -1705,14 +1689,11 @@ class FlashAttentionBackend(AttentionBackend):
                     metadata_expand.cache_seqlens_int32.copy_(
                         mask.sum(dim=1).to(torch.int32)
                     )
-                    metadata_expand.cu_seqlens_k.copy_(
-                        torch.nn.functional.pad(
-                            torch.cumsum(
-                                metadata_expand.cache_seqlens_int32,
-                                dim=0,
-                                dtype=torch.int32,
-                            ),
-                            (1, 0),
+                    metadata_expand.cu_seqlens_k[1:].copy_(
+                        torch.cumsum(
+                            metadata_expand.cache_seqlens_int32,
+                            dim=0,
+                            dtype=torch.int32,
                         )
                     )
                     metadata_expand.max_seq_len_k = (
@@ -1723,11 +1704,8 @@ class FlashAttentionBackend(AttentionBackend):
         # Only support encoder size 1 for now
         metadata.encoder_max_seq_len_k = encoder_lens[0]
         metadata.encoder_lens_int32.copy_(encoder_lens[:1])
-        metadata.encoder_cu_seqlens_k.copy_(
-            torch.nn.functional.pad(
-                torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
-                (1, 0),
-            )
+        metadata.encoder_cu_seqlens_k[1:].copy_(
+            torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32)
         )
         metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(
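The "100+ us" in the title is plausible given that the old pattern issues an extra allocation and pad kernel at several call sites per step, but the exact saving depends on device, batch size, and which of these paths run. A rough, illustrative way to compare the two patterns in isolation; this is not code from the commit, and the sizes and the `mean_us` helper are arbitrary choices for the sketch:

```python
import time

import torch

def mean_us(fn, iters=1000):
    # Warm up once, then time; synchronize so queued GPU work is counted.
    fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1e6  # microseconds per call

device = "cuda" if torch.cuda.is_available() else "cpu"
cache_seqlens = torch.randint(1, 4096, (256,), dtype=torch.int32, device=device)
cu_seqlens = torch.zeros(257, dtype=torch.int32, device=device)

pad_us = mean_us(
    lambda: cu_seqlens.copy_(
        torch.nn.functional.pad(
            torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32), (1, 0)
        )
    )
)
sliced_us = mean_us(
    lambda: cu_seqlens[1:].copy_(
        torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
    )
)
print(f"pad + full copy: {pad_us:.1f} us, sliced copy: {sliced_us:.1f} us")
```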