Unverified Commit da55d247 authored by cyanguwa, committed by GitHub

Disable FAv2.1+ for causal mask in cross attention (#522)



* disable FAv2.1 if causal+cross attn
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove comment and add warning
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* include both causal and padding+causal
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add a space
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 15088217
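For context, here is a minimal PyTorch-only sketch (not part of this commit) of the behavior change the new check guards against. Per the FlashAttention README section linked in the warning below, version 2.1 aligns the causal mask to the bottom-right corner of the attention matrix when the query and key/value lengths differ, whereas earlier versions align it to the top-left, which is the layout Transformer Engine assumes:

```python
# Illustration only: build both causal-mask alignments for a cross-attention shape.
import torch

seqlen_q, seqlen_kv = 2, 5  # cross attention: query shorter than key/value

# Pre-2.1 semantics (top-left aligned): query position i attends to key positions <= i.
top_left = torch.ones(seqlen_q, seqlen_kv, dtype=torch.int64).tril(diagonal=0)

# 2.1+ semantics (bottom-right aligned): the diagonal shifts by seqlen_kv - seqlen_q.
bottom_right = torch.ones(seqlen_q, seqlen_kv, dtype=torch.int64).tril(
    diagonal=seqlen_kv - seqlen_q
)

print(top_left)      # [[1, 0, 0, 0, 0], [1, 1, 0, 0, 0]]
print(bottom_right)  # [[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]]
```

Because the two semantics disagree exactly when max_seqlen_q != max_seqlen_kv, the commit falls back to a non-flash backend in that case rather than silently computing a differently masked attention.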
@@ -56,6 +56,7 @@ from transformer_engine.pytorch.jit import jit_fuser
 _flash_attn_version = packaging.version.Version(version("flash-attn"))
 _flash_attn_version_required = packaging.version.Version("1.0.6")
 _flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
+_flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1")
 if _flash_attn_2_available:
     from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module
@@ -2134,6 +2135,16 @@ class DotProductAttention(torch.nn.Module):
         if not _flash_attn_2_available and self.num_gqa_groups != self.num_attention_heads:
             use_flash_attention = False
 
+        if (_flash_attn_2_1_plus
+            and causal_mask
+            and max_seqlen_q != max_seqlen_kv):
+            warnings.warn(
+                "Disabling the use of FlashAttention since version 2.1+ has changed its behavior "
+                "for causal mask in cross attention. See "
+                "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
+            )
+            use_flash_attention = False
+
         if core_attention_bias_type != "no_bias" or core_attention_bias is not None:
             use_flash_attention = False
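For illustration, here is a standalone sketch of the same version gate outside Transformer Engine; the helper name can_use_flash_attention and its arguments are hypothetical, and flash-attn must be installed for the version lookup to succeed:

```python
# Sketch of gating FlashAttention usage on the installed flash-attn version.
import warnings
from importlib.metadata import version

import packaging.version

_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1")


def can_use_flash_attention(causal_mask: bool, max_seqlen_q: int, max_seqlen_kv: int) -> bool:
    """Return False when FlashAttention 2.1+ would apply bottom-right causal alignment."""
    if _flash_attn_2_1_plus and causal_mask and max_seqlen_q != max_seqlen_kv:
        warnings.warn(
            "Disabling FlashAttention: version 2.1+ changed the causal-mask alignment "
            "for cross attention. See "
            "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
        )
        return False
    return True
```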