"tests/pytorch/test_quantized_tensor.py" did not exist on "b1a0e0a782186a4b9510b1232200befc191162e0"
Unverified commit 161b1d98, authored by Charlene Yang, committed by GitHub

[PyTorch] Drop FA as an installation requirement (#1226)



* WIP: make FA2 optional
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP: fix logic
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor tweak
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add L1 test to test all supported FA versions
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update version to 2.1.1 and trim L1 tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update onnxruntime version
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove onnxruntime from L1 FA versions tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 43b9e1ee
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e

: ${TE_PATH:=/opt/transformerengine}

pip install pytest==8.2.1

FA_versions=(2.1.1 2.3.0 2.4.0.post1 2.4.1 2.5.7 2.6.3 3.0.0b1)
for fa_version in "${FA_versions[@]}"
do
    if [ "${fa_version}" \< "3.0.0" ]
    then
        # FA2 releases install directly from PyPI
        pip install flash-attn==${fa_version}
    else
        # FA3 (beta) builds from the flash-attention repo's hopper/ subdirectory;
        # its Python interface file must then be copied into site-packages manually
        pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper"
        python_path=`python -c "import site; print(site.getsitepackages()[0])"`
        mkdir -p $python_path/flashattn_hopper
        wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py
    fi
    NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
done
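For reference, running this sweep locally only requires pointing TE_PATH at a checkout (the qa/ script path below is a hypothetical example; this page does not show where the file lives in the repo):

    TE_PATH=/path/to/TransformerEngine bash qa/L1_pytorch_FA_versions_test/test.sh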
@@ -93,7 +93,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
     # Framework-specific requirements
     if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
         if "pytorch" in frameworks:
-            install_reqs.extend(["torch", "flash-attn>=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0"])
+            install_reqs.extend(["torch"])
             test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"])
         if "jax" in frameworks:
             install_reqs.extend(["jax", "flax>=0.7.1"])
@@ -20,9 +20,8 @@ from transformer_engine.pytorch.attention import (
     MultiheadAttention,
     RotaryPositionEmbedding,
     get_attention_backend,
-    _flash_attn_2_plus,
     _flash_attn_2_3_plus,
-    _flash_attn_3_plus,
+    _flash_attn_3_is_installed,
     check_set_window_size,
     AttentionParams,
     _attention_backends,
@@ -1353,7 +1352,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
     os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
     config = model_configs_fp8_vs_f16[model]
 
-    if _flash_attn_3_plus and not is_training:
+    if _flash_attn_3_is_installed and not is_training:
         if RoPE:
             pytest.skip("Flash Attention doesn't support FP8 MHA with RoPE.")
         os.environ["NVTE_FLASH_ATTN"] = "1"
@@ -1381,7 +1380,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
     rtol = 5e-1
     rmse_tol = 0.15
     logging.debug("========== {:^25s} ==========".format("forward output"))
-    if _flash_attn_3_plus and not is_training:
+    if _flash_attn_3_is_installed and not is_training:
         _error(
             flash_attn_fwd_fp8,
             fused_attn_fwd_f16,
@@ -1534,7 +1533,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
     os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
     os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
 
-    if _flash_attn_3_plus and not is_training:
+    if _flash_attn_3_is_installed and not is_training:
         os.environ["NVTE_FLASH_ATTN"] = "1"
         os.environ["NVTE_FUSED_ATTN"] = "0"
         _attention_backends["backend_selection_requires_update"] = True
@@ -1561,7 +1560,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
     rmse_tol = 0.1
     bwd_names = ["dq", "dk", "dv"]
     logging.debug("========== {:^25s} ==========".format("forward output"))
-    if _flash_attn_3_plus and not is_training:
+    if _flash_attn_3_is_installed and not is_training:
         _error(
             flash_attn_fwd_fp8,
             fused_attn_fwd_f16,
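The pattern these tests use to pin a backend is: set the NVTE_* environment variables, then mark the cached backend selection as stale so it is recomputed on the next forward pass. A condensed sketch (identifiers as imported in the hunks above):

    import os
    from transformer_engine.pytorch.attention import _attention_backends

    os.environ["NVTE_FLASH_ATTN"] = "1"  # allow FlashAttention (FA3 when installed)
    os.environ["NVTE_FUSED_ATTN"] = "0"  # rule out FusedAttention
    # force get_attention_backend() to run again on the next call
    _attention_backends["backend_selection_requires_update"] = True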
@@ -12,6 +12,7 @@ import os
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import warnings
 import logging
+import functools
 from dataclasses import dataclass, fields
 import numpy as np
@@ -95,21 +96,76 @@ _log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2]
 _formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
 _stream_handler = logging.StreamHandler()
 _stream_handler.setFormatter(_formatter)
+fa_logger = logging.getLogger()
+fa_logger.setLevel(_log_level)
+if not fa_logger.hasHandlers():
+    fa_logger.addHandler(_stream_handler)
+
+
+@functools.lru_cache(maxsize=None)
+def _get_supported_versions(version_min, version_max):
+    return ">= " + str(version_min) + ", " + "<= " + str(version_max)
+
+
 _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
 _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
 _NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
-_flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
-_flash_attn_version_required = PkgVersion("2.0.6")
+
+# Detect flash-attn v2 in the environment
+_flash_attn_is_installed = False
+_flash_attn_version = PkgVersion("0")
+_flash_attn_version_required = PkgVersion("2.1.1")
 _flash_attn_max_version = PkgVersion("2.6.3")
-_flash_attn_2_plus = _flash_attn_version >= PkgVersion("2")
-_flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1")
-_flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3")
-_flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
-_flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
-_flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7")
-_flash_attn_2_6_0_plus = _flash_attn_version >= PkgVersion("2.6.0")
-_flash_attn_3_plus = False
+_flash_attn_2_plus = False
+_flash_attn_2_1_plus = False
+_flash_attn_2_3_plus = False
+_flash_attn_2_4_plus = False
+_flash_attn_2_4_1_plus = False
+_flash_attn_2_5_7_plus = False
+_flash_attn_2_6_0_plus = False
+try:
+    _flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
+except PackageNotFoundError:
+    if get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN:
+        fa_logger.debug(
+            "flash-attn v2 is not installed. To use, please install it by"
+            """ "pip install flash-attn".""",
+        )
+else:
+    if _flash_attn_version_required <= _flash_attn_version <= _flash_attn_max_version:
+        from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
+        from flash_attn.flash_attn_interface import (
+            _flash_attn_varlen_forward as flash_attn_varlen_fwd,
+        )
+        from flash_attn.flash_attn_interface import (
+            _flash_attn_varlen_backward as flash_attn_varlen_bwd,
+        )
+        from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
+
+        _flash_attn_is_installed = True
+        _flash_attn_2_plus = _flash_attn_version >= PkgVersion("2")
+        _flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1")
+        _flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3")
+        _flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
+        _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
+        _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7")
+        _flash_attn_2_6_0_plus = _flash_attn_version >= PkgVersion("2.6.0")
+    elif get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN:
+        fa_logger.warning(
+            "Supported flash-attn versions are %s. Found flash-attn %s.",
+            _get_supported_versions(
+                _flash_attn_version_required,
+                _flash_attn_max_version,
+            ),
+            _flash_attn_version,
+        )
+
+# Detect flash-attn v3 in the environment
+# This section will be removed when FA3 is released as a regular FA package,
+# i.e. flashattn-hopper 3.0.0 as flash-attn 3.0.0
+_flash_attn_3_is_installed = False
+_flash_attn_3_version = PkgVersion("0")
+_flash_attn_3_0_0_beta = False
 _use_flash_attn_3 = False
 _flash_attn_3_installation_steps = """\
 (1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper"
@@ -117,18 +173,11 @@ _flash_attn_3_installation_steps = """\
 (3) mkdir -p $python_path/flashattn_hopper
 (4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py"""
 try:
-    _flash_attn_v3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
-    _flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.9")
-    _flash_attn_3_0_0_beta = _flash_attn_3_plus and _flash_attn_v3_version < PkgVersion("3.0.0")
+    _flash_attn_3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
 except PackageNotFoundError:
-    if get_device_compute_capability() == (9, 0) and _NVTE_FLASH_ATTN:
-        fa3_logger = logging.getLogger()
-        fa3_logger.setLevel(_log_level)
-        if not fa3_logger.hasHandlers():
-            fa3_logger.addHandler(_stream_handler)
-        fa3_logger.debug(
-            "To use flash-attn v3, please follow these steps to install the flashattn-hopper "
-            "package: \n%s",
+    if get_device_compute_capability() >= (9, 0) and _NVTE_FLASH_ATTN:
+        fa_logger.debug(
+            "flash-attn v3 is not installed. To use, please install it by \n%s",
             _flash_attn_3_installation_steps,
         )
 else:
@@ -143,13 +192,10 @@ else:
         _flash_attn_varlen_backward as flash_attn_varlen_bwd_v3,
     )
 
+    _flash_attn_3_is_installed = True
+    _flash_attn_3_0_0_beta = PkgVersion("3.0.0b") < _flash_attn_3_version < PkgVersion("3.0.0")
+    _use_flash_attn_3 = True
 
-if _flash_attn_version >= _flash_attn_version_required:
-    from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as flash_attn_varlen_fwd
-    from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as flash_attn_varlen_bwd
-    from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
-
 _attention_backends = {
     "attention_params": None,
@@ -319,7 +365,12 @@ def get_attention_backend(
             + str(
                 (lambda x, y: x * 10 + y)(device_compute_capability[0], device_compute_capability[1])
             ),
-        "flash_attn_version": _flash_attn_version,
+        "flash_attn_version": (
+            str(_flash_attn_version) if _flash_attn_is_installed else "not installed"
+        ),
+        "flash_attn_3_version": (
+            str(_flash_attn_3_version) if _flash_attn_3_is_installed else "not installed"
+        ),
         "cudnn_version": ".".join([str(i) for i in cudnn_version]),
     }
     attention_params_dict = {
@@ -330,15 +381,17 @@ def get_attention_backend(
         run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
     logger.debug("Running with config=%s", run_config)
 
+    # The following sections check if `FlashAttention` supports the provided attention params,
+    # regardless of whether FA2 or FA3 is installed. If FA2 or FA3 is not installed but is
+    # necessary for performance/functionality, a warning will be issued to prompt users to
+    # install an appropriate FA version.
+    global _flash_attn_version_required, _flash_attn_max_version, _use_flash_attn_3
+
     # Filter: Environment variables
-    use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1"))
-    use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1"))
-    use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
-    if not use_flash_attention:
+    global _NVTE_FLASH_ATTN, _NVTE_FUSED_ATTN, _NVTE_UNFUSED_ATTN
+    _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
+    _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
+    _NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
+    use_flash_attention = _NVTE_FLASH_ATTN
+    use_fused_attention = _NVTE_FUSED_ATTN
+    use_unfused_attention = _NVTE_UNFUSED_ATTN
+    if not use_flash_attention and _flash_attn_is_installed:
         logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
     if not use_fused_attention:
         logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0")
@@ -347,7 +400,7 @@ def get_attention_backend(
 
     # Filter: ONNX mode
     if is_in_onnx_export_mode():
-        if use_flash_attention:
+        if use_flash_attention and _flash_attn_is_installed:
            logger.debug("Disabling FlashAttention due to ONNX mode")
            use_flash_attention = False
        if use_fused_attention:
@@ -355,32 +408,31 @@ def get_attention_backend(
             use_fused_attention = False
 
     # Filter: Compute capability
-    global _use_flash_attn_3
     if device_compute_capability < (8, 0):
-        if use_flash_attention:
+        if use_flash_attention and _flash_attn_is_installed:
             logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
-            use_flash_attention = False
+        use_flash_attention = False
         if use_fused_attention:
             logger.debug("Disabling FusedAttention as it requires compute capability sm80+")
             use_fused_attention = False
     if device_compute_capability < (9, 0):
-        if use_flash_attention and _use_flash_attn_3:
+        if use_flash_attention and _flash_attn_3_is_installed:
             logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+")
-            _use_flash_attn_3 = False
+        _use_flash_attn_3 = False
 
     # Filter: Data type
     if qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type not in [
         torch.Tensor,
         Float8Tensor,
     ]:
-        if use_flash_attention:
+        if use_flash_attention and _flash_attn_is_installed:
             logger.debug(
                 "Disabling FlashAttention due to unsupported QKV data type. "
                 "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
                 "Found: qkv_dtype = %s.",
                 qkv_dtype,
             )
-            use_flash_attention = False
+        use_flash_attention = False
         if use_fused_attention:
             logger.debug(
                 "Disabling FusedAttention due to unsupported QKV data type. "
@@ -393,7 +445,8 @@ def get_attention_backend(
     # Filter: Execution type
     if fp8 and fp8_meta["recipe"].fp8_dpa:
         if use_flash_attention and not _use_flash_attn_3:
-            logger.debug("Disabling FlashAttention as FlashAttention 2 does not support FP8")
+            if _flash_attn_is_installed:
+                logger.debug("Disabling FlashAttention as FlashAttention 2 does not support FP8")
             use_flash_attention = False
         if use_flash_attention and _use_flash_attn_3 and is_training:
             logger.debug(
@@ -406,22 +459,24 @@ def get_attention_backend(
     # Filter: Head dimension
     if use_flash_attention and head_dim_qk != head_dim_v:
-        logger.debug("Disabling FlashAttention as it does not support MLA.")
+        if _flash_attn_is_installed:
+            logger.debug("Disabling FlashAttention as it does not support MLA.")
         use_flash_attention = False
     if use_flash_attention and (
         head_dim_qk > 256
         or head_dim_qk % 8 != 0
         or (head_dim_qk > 192 and device_compute_capability not in ((8, 0), (9, 0)))
     ):
-        logger.debug(
-            "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. "
-            "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
-            "head_dim_qk <= 256 (>192 requires sm80/90). "
-            "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
-            head_dim_qk,
-            head_dim_v,
-            ".".join([str(i) for i in device_compute_capability]),
-        )
+        if _flash_attn_is_installed:
+            logger.debug(
+                "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. "
+                "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
+                "head_dim_qk <= 256 (>192 requires sm80/90). "
+                "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
+                head_dim_qk,
+                head_dim_v,
+                ".".join([str(i) for i in device_compute_capability]),
+            )
         use_flash_attention = False
 
     qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
     if use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd":
@@ -438,10 +493,11 @@ def get_attention_backend(
             logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd")
             use_unfused_attention = False
         if use_flash_attention and pad_between_seqs:
-            logger.debug(
-                "Disabling FlashAttention for qkv_format = thd when there is "
-                "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
-            )
+            if _flash_attn_is_installed:
+                logger.debug(
+                    "Disabling FlashAttention for qkv_format = thd when there is "
+                    "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
+                )
             use_flash_attention = False
 
     # Filter: Dropout
@@ -468,34 +524,39 @@ def get_attention_backend(
             use_unfused_attention = False
     if context_parallel and use_flash_attention:
         if fp8 and fp8_meta["recipe"].fp8_dpa:
-            logger.debug(
-                "Disabling FlashAttention as it does not support context parallelism with FP8"
-            )
+            if _flash_attn_is_installed:
+                logger.debug(
+                    "Disabling FlashAttention as it does not support context parallelism with FP8"
+                )
             use_flash_attention = False
         if "bottom_right" in attn_mask_type:
-            logger.debug(
-                "Disabling FlashAttention as it does not support context parallelism with"
-                " causal_bottom_right masking"
-            )
+            if _flash_attn_is_installed:
+                logger.debug(
+                    "Disabling FlashAttention as it does not support context parallelism with"
+                    " causal_bottom_right masking"
+                )
             use_flash_attention = False
         elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
-            logger.debug(
-                "Disabling FlashAttention as it does not support context parallelism with causal"
-                " masking for cross-attention"
-            )
+            if _flash_attn_is_installed:
+                logger.debug(
+                    "Disabling FlashAttention as it does not support context parallelism with"
+                    " causal masking for cross-attention"
+                )
             use_flash_attention = False
         elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
-            logger.debug(
-                "Disabling FlashAttention as it does not support context parallelism with bias type"
-                " of %s",
-                core_attention_bias_type,
-            )
+            if _flash_attn_is_installed:
+                logger.debug(
+                    "Disabling FlashAttention as it does not support context parallelism with bias"
+                    " type of %s",
+                    core_attention_bias_type,
+                )
             use_flash_attention = False
         elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
-            logger.debug(
-                "Disabling FlashAttention as it does not support context parallelism with attention"
-                " bias for THD format"
-            )
+            if _flash_attn_is_installed:
+                logger.debug(
+                    "Disabling FlashAttention as it does not support context parallelism with"
+                    " attention bias for THD format"
+                )
             use_flash_attention = False
 
     if context_parallel and use_fused_attention:
@@ -551,7 +612,7 @@ def get_attention_backend(
     #     arbitrary    | One tensor in shape broadcastable to | UnfusedDotProductAttention
     #                  | [b, h, sq, skv]                      |
     if attn_mask_type == "arbitrary":
-        if use_flash_attention:
+        if use_flash_attention and _flash_attn_is_installed:
             logger.debug("Disabling FlashAttention for arbitrary mask")
         use_flash_attention = False
         if use_fused_attention:
@@ -571,28 +632,32 @@ def get_attention_backend(
                 _use_flash_attn_3 = False
     if (
         use_flash_attention
-        and _flash_attn_2_1_plus
         and attn_mask_type in ["causal", "padding_causal"]
         and max_seqlen_q != max_seqlen_kv
     ):
-        logger.warning(
-            "Disabling FlashAttention as it only supports bottom-right-diagonal "
-            "causal mask since flash-attn 2.1. See "
-            "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
-        )
-        use_flash_attention = False
+        if _flash_attn_2_1_plus:
+            logger.warning(
+                "Disabling FlashAttention as it only supports bottom-right-diagonal "
+                "causal mask since flash-attn 2.1. See "
+                "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
+            )
+            use_flash_attention = False
+        if not _flash_attn_is_installed:
+            _flash_attn_max_version = PkgVersion("2.1")
     if (
         use_flash_attention
-        and not _flash_attn_2_1_plus
         and attn_mask_type in ["causal_bottom_right", "padding_causal_bottom_right"]
        and max_seqlen_q != max_seqlen_kv
     ):
-        logger.warning(
-            "Disabling FlashAttention as it only supports top-left-diagonal "
-            "causal mask before flash-attn 2.1. See "
-            "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
-        )
-        use_flash_attention = False
+        if not _flash_attn_is_installed:
+            _flash_attn_version_required = PkgVersion("2.1")
+        elif not _flash_attn_2_1_plus and not _use_flash_attn_3:
+            logger.warning(
+                "Disabling FlashAttention as it only supports top-left-diagonal "
+                "causal mask before flash-attn 2.1. See "
+                "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
+            )
+            use_flash_attention = False
     if (
         use_flash_attention
         and _use_flash_attn_3
@@ -645,15 +710,19 @@ def get_attention_backend(
                 attn_mask_type,
             )
             use_fused_attention = False
-    if (
-        use_flash_attention
-        and (window_size[0] != -1 or window_size[1] not in [-1, 0])
-        and not _flash_attn_2_3_plus
-    ):
-        logger.debug(
-            "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+"
-        )
-        use_flash_attention = False
+    if use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
+        if _use_flash_attn_3:
+            logger.debug(
+                "Disabling FlashAttention 3 as it does not support sliding window attention"
+            )
+            _use_flash_attn_3 = False
+        if not _flash_attn_is_installed:
+            _flash_attn_version_required = PkgVersion("2.3")
+        elif not _flash_attn_2_3_plus:
+            logger.debug(
+                "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+"
+            )
+            use_flash_attention = False
 
     # Filter: Attention bias
     # backend                 | bias types              | ALiBi diagonal alignment
@@ -668,7 +737,9 @@ def get_attention_backend(
         if _use_flash_attn_3:
             logger.debug("Disabling FlashAttention 3 for ALiBi")
             _use_flash_attn_3 = False
-        if not _use_flash_attn_3 and not _flash_attn_2_4_plus:
+        if not _flash_attn_is_installed:
+            _flash_attn_version_required = PkgVersion("2.4")
+        elif not _flash_attn_2_4_plus:
             logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+")
             use_flash_attention = False
@@ -676,7 +747,8 @@ if use_flash_attention and (
         core_attention_bias_type not in ["no_bias", "alibi"]
         or core_attention_bias_shape is not None
     ):
-        logger.debug("Disabling FlashAttention for pre/post_scale_bias")
+        if _flash_attn_is_installed:
+            logger.debug("Disabling FlashAttention for pre/post_scale_bias")
         use_flash_attention = False
 
     fu_core_attention_bias_type = core_attention_bias_type
@@ -780,13 +852,16 @@ def get_attention_backend(
     #                              |       otherwise: no
     #     sub-backend 2            | no
     # UnfusedDotProductAttention   | yes
-    if use_flash_attention and deterministic and not _flash_attn_2_4_1_plus:
-        logger.warning(
-            "Disabling FlashAttention as version <2.4.1 does not support deterministic "
-            "execution. To use FlashAttention with deterministic behavior, "
-            "please install flash-attn >= 2.4.1."
-        )
-        use_flash_attention = False
+    if use_flash_attention and deterministic:
+        if not _flash_attn_is_installed:
+            _flash_attn_version_required = PkgVersion("2.4.1")
+        elif not _flash_attn_2_4_1_plus and not _use_flash_attn_3:
+            logger.warning(
+                "Disabling FlashAttention as version <2.4.1 does not support deterministic "
+                "execution. To use FlashAttention with deterministic behavior, "
+                "please install flash-attn >= 2.4.1."
+            )
+            use_flash_attention = False
     if use_fused_attention and deterministic:
         if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
             logger.debug("Disabling FusedAttention for determinism reasons")
@@ -805,6 +880,23 @@ def get_attention_backend(
 
     # All available backends
     available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention]
+
+    # `FusedAttention` and `FlashAttention` are faster backends than `UnfusedDotProductAttention`.
+    # When `FusedAttention` does not support the provided attention params, and `FlashAttention`
+    # does, we recommend users to install flash-attn if not installed already.
+    if not use_fused_attention and use_flash_attention and not _flash_attn_is_installed:
+        logger.warning(
+            "flash-attn may provide important feature support or performance improvement."
+            " Please install flash-attn %s.",
+            _get_supported_versions(
+                _flash_attn_version_required,
+                _flash_attn_max_version,
+            ),
+        )
+    if use_flash_attention and not _flash_attn_is_installed:
+        use_flash_attention = False
+        available_backends[0] = False
+
     logger.debug(
         "Available backends = {FlashAttention=%s, FusedAttention=%s%s,"
         " UnfusedDotProductAttention=%s}",
@@ -5029,12 +5121,13 @@ class FlashAttention(torch.nn.Module):
     ) -> None:
         super().__init__()
 
-        assert (
-            _flash_attn_version >= _flash_attn_version_required
-        ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
-        assert (
-            _flash_attn_version <= _flash_attn_max_version
-        ), f"FlashAttention maximum version {_flash_attn_max_version} is supported."
+        if _flash_attn_is_installed:
+            assert (
+                _flash_attn_version >= _flash_attn_version_required
+            ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
+            assert (
+                _flash_attn_version <= _flash_attn_max_version
+            ), f"FlashAttention maximum version {_flash_attn_max_version} is supported."
 
         self.softmax_scale = softmax_scale
         self.attention_dropout_ctx = attention_dropout_ctx
@@ -5305,7 +5398,7 @@ class FlashAttention(torch.nn.Module):
                 if _flash_attn_3_0_0_beta:
                     e.args = (
                         e.args[0]
-                        + ". Please update your FlashAttention 3 (beta) installation as it "
+                        + ". Please update your flash-attn v3 (beta) installation as it "
                         + "may have added more supported arguments to its API. \n"
                         + _flash_attn_3_installation_steps,
                     ) + e.args[1:]
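This except-path decorates whatever error escaped the FA3 call with installation guidance instead of swallowing it. The same idiom in isolation (a sketch; fa3_call is a hypothetical stand-in for the real FA3 entry point, and e.args[0] is assumed to be a string):

    try:
        fa3_call()
    except TypeError as e:
        # prepend context to the original message, keep remaining args, re-raise
        e.args = (e.args[0] + ". Please update your flash-attn v3 (beta) installation.",) + e.args[1:]
        raise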
@@ -7900,7 +7993,7 @@ class DotProductAttention(TransformerEngineBaseModule):
             fp8=self.fp8,
             fp8_meta=self.fp8_meta,
         )
-        global _attention_backends, _flash_attn_3_plus, _use_flash_attn_3
+        global _attention_backends, _use_flash_attn_3
         if (
             _attention_backends["attention_params"] is None
             or attention_params != _attention_backends["attention_params"]
@@ -7908,7 +8001,7 @@ class DotProductAttention(TransformerEngineBaseModule):
             _attention_backends["attention_params"] = attention_params
             _attention_backends["backend_selection_requires_update"] = True
         if _attention_backends["backend_selection_requires_update"]:
-            _use_flash_attn_3 = _flash_attn_3_plus
+            _use_flash_attn_3 = _flash_attn_3_is_installed
             (
                 use_flash_attention,
                 use_fused_attention,
@@ -7919,7 +8012,7 @@ class DotProductAttention(TransformerEngineBaseModule):
             if use_flash_attention:
                 self.logger.info(
                     "Running with FlashAttention backend (version %s)",
-                    _flash_attn_version if not _use_flash_attn_3 else _flash_attn_v3_version,
+                    _flash_attn_version if not _use_flash_attn_3 else _flash_attn_3_version,
                 )
             elif use_fused_attention:
                 self.logger.info(
@@ -56,7 +56,7 @@ if __name__ == "__main__":
         description="Transformer acceleration library - Torch Lib",
         ext_modules=ext_modules,
         cmdclass={"build_ext": CMakeBuildExtension},
-        install_requires=["torch", "flash-attn>=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0"],
+        install_requires=["torch"],
         tests_require=["numpy", "onnxruntime", "torchvision"],
     )
 if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):