Unverified commit eb577e46, authored by Pei-Lun Liao and committed by GitHub
Browse files

[Bugfix] Add missing sink tensor into flash attn cascade attn implementation (#26325)

parent 8f36850f
...@@ -607,6 +607,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -607,6 +607,7 @@ class FlashAttentionImpl(AttentionImpl):
q_descale=layer._q_scale, q_descale=layer._q_scale,
k_descale=layer._k_scale, k_descale=layer._k_scale,
v_descale=layer._v_scale, v_descale=layer._v_scale,
s_aux=self.sinks,
) )
return output return output
...@@ -767,6 +768,7 @@ def cascade_attention( ...@@ -767,6 +768,7 @@ def cascade_attention(
q_descale: Optional[torch.Tensor] = None, q_descale: Optional[torch.Tensor] = None,
k_descale: Optional[torch.Tensor] = None, k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None, v_descale: Optional[torch.Tensor] = None,
s_aux: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert alibi_slopes is None, "Cascade attention does not support ALiBi." assert alibi_slopes is None, "Cascade attention does not support ALiBi."
# TODO: Support sliding window. # TODO: Support sliding window.
...@@ -801,6 +803,9 @@ def cascade_attention( ...@@ -801,6 +803,9 @@ def cascade_attention(
q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
# s_aux is incorporated into prefix_lse inside the GPU kernel,
# enabling its effect during the final attention merge.
s_aux=s_aux,
) )
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2]) descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment