Unverified commit 22a6b9fc, authored by Binyao Jiang, committed by GitHub
Browse files

Remove unnecessary metadata_expand.max_seq_len_k operations in fa3 to… (#7140)

parent b02df20a
......@@ -394,7 +394,6 @@ class FlashAttentionBackend(AttentionBackend):
dtype=torch.int32,
)
metadata_expand.max_seq_len_q = 1
metadata_expand.max_seq_len_k = self.speculative_step_id + 1
metadata_expand.cu_seqlens_q = torch.arange(
0,
metadata_expand.cache_seqlens_int32.numel() + 1,
......@@ -550,9 +549,6 @@ class FlashAttentionBackend(AttentionBackend):
),
(1, 0),
)
metadata_expand.max_seq_len_k = (
metadata_expand.cache_seqlens_int32.max().item()
)
self.forward_metadata_spec_decode_expand = metadata_expand
elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
......@@ -1421,9 +1417,6 @@ class FlashAttentionBackend(AttentionBackend):
]
)
metadata_expand.max_seq_len_q = 1
metadata_expand.max_seq_len_k = (
self.speculative_step_id + 1
) # , do this in replay
metadata_expand.cu_seqlens_q = (
self.draft_decode_metadata_topk_expand["cu_seqlens_q"][
: bs * self.topk + 1
......@@ -1766,9 +1759,6 @@ class FlashAttentionBackend(AttentionBackend):
dtype=torch.int32,
)
)
metadata_expand.max_seq_len_k = (
metadata_expand.cache_seqlens_int32.max().item()
)
elif forward_mode.is_draft_extend():
metadata = self.draft_extend_metadata[bs]
metadata.cache_seqlens_int32.copy_(seq_lens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment