Unverified commit 4418f599, authored by JieXin Liang, committed by GitHub
Browse files

Fix FA3 DeepSeek prefill performance regression (#5624)


Co-authored-by: ispobock <ispobaoke@gmail.com>
parent 04f2abcb
@@ -583,13 +583,17 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MLA
         elif self.attention_backend == "fa3":
             # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
+            if forward_batch.extend_prefix_lens_cpu is not None:
+                sum_extend_prefix_lens = sum(forward_batch.extend_prefix_lens_cpu)
             if (
                 forward_batch.forward_mode.is_extend()
                 and not self.disable_chunked_prefix_cache
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
-                and sum(forward_batch.extend_prefix_lens_cpu)
-                >= self.chunked_prefix_cache_threshold
+                and (
+                    sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
+                    or sum_extend_prefix_lens == 0
+                )
             ):
                 return AttnForwardMethod.MHA_CHUNKED_KV
         else:
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment