Unverified commit 7b064f04 authored by Mahmoud Ashraf, committed by GitHub

[bugfix]: use correct causality condition for flashattention, flashinfer, and triton backends (#10172)
parent 43190bec
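
All three hunks below apply the same rule: attention should be causal only for decoder self-attention, while both cross-attention and encoder-only (bidirectional) layers need a non-causal mask. The old FlashAttention condition (`causal = not layer.is_cross_attention`) missed the encoder-only case, and the old FlashInfer/Triton condition missed the cross-attention case. A minimal standalone sketch of the corrected rule follows; the `AttentionType` enum and `Layer` dataclass here are simplified hypothetical stand-ins for sglang's real classes, not its actual API:

from dataclasses import dataclass
from enum import Enum, auto


class AttentionType(Enum):
    # Simplified stand-in for sglang's AttentionType enum.
    DECODER = auto()
    ENCODER_ONLY = auto()


@dataclass
class Layer:
    # Hypothetical minimal layer descriptor exposing only the two fields the fix reads.
    attn_type: AttentionType = AttentionType.DECODER
    is_cross_attention: bool = False


def is_causal(layer: Layer) -> bool:
    """Causality flag a backend would pass to its attention kernel.

    Decoder self-attention is causal; cross-attention and encoder-only
    attention are bidirectional, so they must not use a causal mask.
    """
    if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
        return False
    return True


if __name__ == "__main__":
    assert is_causal(Layer()) is True                                        # decoder self-attention
    assert is_causal(Layer(is_cross_attention=True)) is False                # cross-attention
    assert is_causal(Layer(attn_type=AttentionType.ENCODER_ONLY)) is False   # encoder-only
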
@@ -705,7 +705,9 @@ class FlashAttentionBackend(AttentionBackend):
             q = q.to(self.kv_cache_dtype)
             q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
             k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
-        causal = not layer.is_cross_attention
+        causal = True
+        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
+            causal = False

         # Check if we should use local attention
         use_local_attn = (
@@ -1005,7 +1007,9 @@ class FlashAttentionBackend(AttentionBackend):
                 if layer.sliding_window_size is not None and layer.sliding_window_size > -1
                 else (-1, -1)
             )
-            causal = not layer.is_cross_attention
+            causal = True
+            if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
+                causal = False

             # For fa3 interface version compatibility, we put new fields into conditional keyword args
             kwargs = {}
...
@@ -728,9 +728,10 @@ class FlashInferAttnBackend(AttentionBackend):
             )
         else:
             causal = True
-            if layer.attn_type == AttentionType.ENCODER_ONLY:
-                save_kv_cache = False
+            if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
                 causal = False
+            if save_kv_cache and layer.attn_type == AttentionType.ENCODER_ONLY:
+                save_kv_cache = False

             if self.forward_metadata.extend_no_prefix:
                 # NOTE: FlashInfer currently has limitations with head_dim = 32 or other dimensions
...
@@ -794,7 +794,7 @@ class TritonAttnBackend(AttentionBackend):
         logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)

         causal = True
-        if layer.attn_type == AttentionType.ENCODER_ONLY:
+        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
             causal = False

         if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
...