Unverified Commit 6fc17596 authored by Stefan He's avatar Stefan He Committed by GitHub
Browse files

Optimize a pad operation to accelerate 25us (#5945)

parent ad506a4e
...@@ -1587,8 +1587,9 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -1587,8 +1587,9 @@ class FlashAttentionBackend(AttentionBackend):
metadata.max_seq_len_k = max_len metadata.max_seq_len_k = max_len
metadata.cache_seqlens_int32 = seq_lens.to(torch.int32) metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
metadata.cu_seqlens_k = torch.nn.functional.pad( # Optimize cumulative sequence length calculation
torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0) metadata.cu_seqlens_k[1:].copy_(
torch.cumsum(seq_lens, dim=0, dtype=torch.int32)
) )
max_seq_pages = ( max_seq_pages = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment