Unverified Commit 10cceae9 authored by Charlene Yang's avatar Charlene Yang Committed by GitHub
Browse files

[PyTorch] Move `block_table` argument to FA varlen function (#1222)



move block_table arg to varlen_func section
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent fb749619
...@@ -5012,12 +5012,12 @@ class FlashAttention(torch.nn.Module): ...@@ -5012,12 +5012,12 @@ class FlashAttention(torch.nn.Module):
fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
if _flash_attn_2_4_1_plus: if _flash_attn_2_4_1_plus:
fa_optional_forward_kwargs["deterministic"] = self.deterministic fa_optional_forward_kwargs["deterministic"] = self.deterministic
if _flash_attn_2_5_7_plus:
fa_optional_forward_kwargs["block_table"] = None
fa_optional_forward_args_thd = [] fa_optional_forward_args_thd = []
if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type: if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type:
func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3 func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3
else: else:
if _flash_attn_2_5_7_plus:
fa_optional_forward_kwargs["block_table"] = None
func = ( func = (
flash_attn_varlen_func flash_attn_varlen_func
if not _use_flash_attn_3 if not _use_flash_attn_3
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment