Unverified Commit 27c6342e authored by Li Tao and committed by GitHub

Fix an argument issue when flash_attn>=2.5.7 (#1068)



fix an argument issue when flash_attn>=2.5.7
Signed-off-by: Li Tao <lit@nvidia.com>
Co-authored-by: Li Tao <lit@nvidia.com>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 098e3006
@@ -79,6 +79,7 @@ _flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1")
 _flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3")
 _flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
 _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
+_flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7")
 if _flash_attn_version >= _flash_attn_version_required:
     from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func
@@ -1292,6 +1293,8 @@ class AttnFuncWithCP(torch.autograd.Function):
             fa_optional_forward_kwargs["window_size"] = [-1, 0] if causal else [-1, -1]
         if _flash_attn_2_4_plus:
             fa_optional_forward_kwargs["alibi_slopes"] = None
+        if _flash_attn_2_5_7_plus:
+            fa_optional_forward_kwargs["block_table"] = None
         # Flash Attn inputs
         q_inputs = [None, None]
@@ -3448,6 +3451,8 @@ class FlashAttention(torch.nn.Module):
                 fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
             if _flash_attn_2_4_1_plus:
                 fa_optional_forward_kwargs["deterministic"] = self.deterministic
+            if _flash_attn_2_5_7_plus:
+                fa_optional_forward_kwargs["block_table"] = None
             output = flash_attn_forward_func(
                 query_layer,
                 key_layer,
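
For context: flash-attn 2.5.7 extended its varlen attention entry points with a block_table argument (paged KV-cache support), so Transformer Engine now passes block_table=None explicitly whenever that version or newer is installed, keeping the argument list aligned with the new signature. The snippet below is a minimal, self-contained sketch of the same version-gating pattern, not the Transformer Engine source; build_fa_optional_kwargs is a hypothetical helper name introduced here for illustration.

from importlib.metadata import version, PackageNotFoundError
from packaging.version import Version as PkgVersion

# Detect the installed flash-attn version; fall back to "0" if it is absent.
try:
    _flash_attn_version = PkgVersion(version("flash-attn"))
except PackageNotFoundError:
    _flash_attn_version = PkgVersion("0")

_flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
_flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
_flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7")

def build_fa_optional_kwargs(alibi_slopes=None, deterministic=False):
    """Collect keyword arguments that only newer flash-attn versions accept."""
    kwargs = {}
    if _flash_attn_2_4_plus:
        kwargs["alibi_slopes"] = alibi_slopes
    if _flash_attn_2_4_1_plus:
        kwargs["deterministic"] = deterministic
    if _flash_attn_2_5_7_plus:
        # flash-attn >= 2.5.7 accepts a block_table argument (paged KV cache);
        # pass None explicitly to stay on the non-paged code path.
        kwargs["block_table"] = None
    return kwargs

# Usage sketch (assumed call shape, mirroring the diff above):
#   kwargs = build_fa_optional_kwargs(alibi_slopes=None, deterministic=False)
#   out = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k,
#                                max_seqlen_q, max_seqlen_k, dropout_p,
#                                softmax_scale, causal=causal, **kwargs)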