Unverified Commit 22a6b9fc authored by Binyao Jiang's avatar Binyao Jiang Committed by GitHub
Browse files

Remove unnecessary metadata_expand.max_seq_len_k operations in fa3 to… (#7140)

parent b02df20a
...@@ -394,7 +394,6 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -394,7 +394,6 @@ class FlashAttentionBackend(AttentionBackend):
dtype=torch.int32, dtype=torch.int32,
) )
metadata_expand.max_seq_len_q = 1 metadata_expand.max_seq_len_q = 1
metadata_expand.max_seq_len_k = self.speculative_step_id + 1
metadata_expand.cu_seqlens_q = torch.arange( metadata_expand.cu_seqlens_q = torch.arange(
0, 0,
metadata_expand.cache_seqlens_int32.numel() + 1, metadata_expand.cache_seqlens_int32.numel() + 1,
...@@ -550,9 +549,6 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -550,9 +549,6 @@ class FlashAttentionBackend(AttentionBackend):
), ),
(1, 0), (1, 0),
) )
metadata_expand.max_seq_len_k = (
metadata_expand.cache_seqlens_int32.max().item()
)
self.forward_metadata_spec_decode_expand = metadata_expand self.forward_metadata_spec_decode_expand = metadata_expand
elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed(): elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32) metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
...@@ -1421,9 +1417,6 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -1421,9 +1417,6 @@ class FlashAttentionBackend(AttentionBackend):
] ]
) )
metadata_expand.max_seq_len_q = 1 metadata_expand.max_seq_len_q = 1
metadata_expand.max_seq_len_k = (
self.speculative_step_id + 1
) # , do this in replay
metadata_expand.cu_seqlens_q = ( metadata_expand.cu_seqlens_q = (
self.draft_decode_metadata_topk_expand["cu_seqlens_q"][ self.draft_decode_metadata_topk_expand["cu_seqlens_q"][
: bs * self.topk + 1 : bs * self.topk + 1
...@@ -1766,9 +1759,6 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -1766,9 +1759,6 @@ class FlashAttentionBackend(AttentionBackend):
dtype=torch.int32, dtype=torch.int32,
) )
) )
metadata_expand.max_seq_len_k = (
metadata_expand.cache_seqlens_int32.max().item()
)
elif forward_mode.is_draft_extend(): elif forward_mode.is_draft_extend():
metadata = self.draft_extend_metadata[bs] metadata = self.draft_extend_metadata[bs]
metadata.cache_seqlens_int32.copy_(seq_lens) metadata.cache_seqlens_int32.copy_(seq_lens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment