Unverified Commit 161b1d98, authored by Charlene Yang, committed by GitHub

[PyTorch] Drop FA as an installation requirement (#1226)



* WIP: make FA2 optional
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: fix logic
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor tweak
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add L1 test to test all supported FA versions
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update version to 2.1.1 and trim L1 tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update onnxruntime version
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove onnxruntime from L1 FA versions tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 43b9e1ee
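
In short, the PyTorch extras no longer pin flash-attn at install time; instead, flash-attn v2 (and the v3/Hopper beta) is probed when the package is imported, and every flash-attn code path is gated on the result. A minimal, standalone sketch of that detection pattern (variable names here are illustrative, not TE's exact symbols; the real flags appear in the attention.py hunks below):

# Sketch only (assumed names): probe for flash-attn and record a usable version,
# instead of hard-requiring it at install time.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version

FA_MIN, FA_MAX = Version("2.1.1"), Version("2.6.3")

flash_attn_is_installed = False
flash_attn_version = Version("0")

try:
    flash_attn_version = Version(version("flash-attn"))
except PackageNotFoundError:
    # flash-attn is absent: leave the flag False and let backend selection fall
    # back to fused/unfused attention instead of failing at import.
    pass
else:
    flash_attn_is_installed = FA_MIN <= flash_attn_version <= FA_MAX

print(f"flash-attn usable: {flash_attn_is_installed} (found {flash_attn_version})")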
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

set -e

: ${TE_PATH:=/opt/transformerengine}

pip install pytest==8.2.1

# Install each supported flash-attn version in turn and run the fused-attention
# test suite against it. Versions below 3.0.0 come from PyPI; the FA3 beta is
# built from the flash-attention "hopper" subdirectory.
FA_versions=(2.1.1 2.3.0 2.4.0.post1 2.4.1 2.5.7 2.6.3 3.0.0b1)
for fa_version in "${FA_versions[@]}"
do
    if [ "${fa_version}" \< "3.0.0" ]
    then
        pip install flash-attn==${fa_version}
    else
        pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper"
        python_path=`python -c "import site; print(site.getsitepackages()[0])"`
        mkdir -p $python_path/flashattn_hopper
        wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py
    fi
    NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
done
@@ -93,7 +93,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
     # Framework-specific requirements
     if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
         if "pytorch" in frameworks:
-            install_reqs.extend(["torch", "flash-attn>=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0"])
+            install_reqs.extend(["torch"])
             test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"])
         if "jax" in frameworks:
             install_reqs.extend(["jax", "flax>=0.7.1"])
...
@@ -20,9 +20,8 @@ from transformer_engine.pytorch.attention import (
     MultiheadAttention,
     RotaryPositionEmbedding,
     get_attention_backend,
-    _flash_attn_2_plus,
     _flash_attn_2_3_plus,
-    _flash_attn_3_plus,
+    _flash_attn_3_is_installed,
     check_set_window_size,
     AttentionParams,
     _attention_backends,
@@ -1353,7 +1352,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
     os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
     config = model_configs_fp8_vs_f16[model]
-    if _flash_attn_3_plus and not is_training:
+    if _flash_attn_3_is_installed and not is_training:
         if RoPE:
             pytest.skip("Flash Attention doesn't support FP8 MHA with RoPE.")
         os.environ["NVTE_FLASH_ATTN"] = "1"
@@ -1381,7 +1380,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
     rtol = 5e-1
     rmse_tol = 0.15
     logging.debug("========== {:^25s} ==========".format("forward output"))
-    if _flash_attn_3_plus and not is_training:
+    if _flash_attn_3_is_installed and not is_training:
         _error(
             flash_attn_fwd_fp8,
             fused_attn_fwd_f16,
@@ -1534,7 +1533,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
     os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
     os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
-    if _flash_attn_3_plus and not is_training:
+    if _flash_attn_3_is_installed and not is_training:
         os.environ["NVTE_FLASH_ATTN"] = "1"
         os.environ["NVTE_FUSED_ATTN"] = "0"
         _attention_backends["backend_selection_requires_update"] = True
@@ -1561,7 +1560,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
     rmse_tol = 0.1
     bwd_names = ["dq", "dk", "dv"]
     logging.debug("========== {:^25s} ==========".format("forward output"))
-    if _flash_attn_3_plus and not is_training:
+    if _flash_attn_3_is_installed and not is_training:
         _error(
             flash_attn_fwd_fp8,
             fused_attn_fwd_f16,
...
@@ -12,6 +12,7 @@ import os
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import warnings
 import logging
+import functools
 from dataclasses import dataclass, fields
 import numpy as np
@@ -95,21 +96,76 @@ _log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2]
 _formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
 _stream_handler = logging.StreamHandler()
 _stream_handler.setFormatter(_formatter)
+fa_logger = logging.getLogger()
+fa_logger.setLevel(_log_level)
+if not fa_logger.hasHandlers():
+    fa_logger.addHandler(_stream_handler)
+
+
+@functools.lru_cache(maxsize=None)
+def _get_supported_versions(version_min, version_max):
+    return ">= " + str(version_min) + ", " + "<= " + str(version_max)
+
+
 _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
 _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
 _NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
-_flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
-_flash_attn_version_required = PkgVersion("2.0.6")
+
+# Detect flash-attn v2 in the environment
+_flash_attn_is_installed = False
+_flash_attn_version = PkgVersion("0")
+_flash_attn_version_required = PkgVersion("2.1.1")
 _flash_attn_max_version = PkgVersion("2.6.3")
-_flash_attn_2_plus = _flash_attn_version >= PkgVersion("2")
-_flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1")
-_flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3")
-_flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
-_flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
-_flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7")
-_flash_attn_2_6_0_plus = _flash_attn_version >= PkgVersion("2.6.0")
-_flash_attn_3_plus = False
+_flash_attn_2_plus = False
+_flash_attn_2_1_plus = False
+_flash_attn_2_3_plus = False
+_flash_attn_2_4_plus = False
+_flash_attn_2_4_1_plus = False
+_flash_attn_2_5_7_plus = False
+_flash_attn_2_6_0_plus = False
+
+try:
+    _flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
+except PackageNotFoundError:
+    if get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN:
+        fa_logger.debug(
+            "flash-attn v2 is not installed. To use, please install it by"
+            """ "pip install flash-attn".""",
+        )
+else:
+    if _flash_attn_version_required <= _flash_attn_version <= _flash_attn_max_version:
+        from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
+        from flash_attn.flash_attn_interface import (
+            _flash_attn_varlen_forward as flash_attn_varlen_fwd,
+        )
+        from flash_attn.flash_attn_interface import (
+            _flash_attn_varlen_backward as flash_attn_varlen_bwd,
+        )
+        from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
+
+        _flash_attn_is_installed = True
+        _flash_attn_2_plus = _flash_attn_version >= PkgVersion("2")
+        _flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1")
+        _flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3")
+        _flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
+        _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
+        _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7")
+        _flash_attn_2_6_0_plus = _flash_attn_version >= PkgVersion("2.6.0")
+    elif get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN:
+        fa_logger.warning(
+            "Supported flash-attn versions are %s. Found flash-attn %s.",
+            _get_supported_versions(
+                _flash_attn_version_required,
+                _flash_attn_max_version,
+            ),
+            _flash_attn_version,
+        )
+
+# Detect flash-attn v3 in the environment
+# This section will be removed when FA3 is released as a regular FA package,
+# i.e. flashattn-hopper 3.0.0 as flash-attn 3.0.0
+_flash_attn_3_is_installed = False
+_flash_attn_3_version = PkgVersion("0")
+_flash_attn_3_0_0_beta = False
 _use_flash_attn_3 = False
 _flash_attn_3_installation_steps = """\
 (1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper"
@@ -117,18 +173,11 @@ _flash_attn_3_installation_steps = """\
 (3) mkdir -p $python_path/flashattn_hopper
 (4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py"""
 try:
-    _flash_attn_v3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
-    _flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.9")
-    _flash_attn_3_0_0_beta = _flash_attn_3_plus and _flash_attn_v3_version < PkgVersion("3.0.0")
+    _flash_attn_3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
 except PackageNotFoundError:
-    if get_device_compute_capability() == (9, 0) and _NVTE_FLASH_ATTN:
-        fa3_logger = logging.getLogger()
-        fa3_logger.setLevel(_log_level)
-        if not fa3_logger.hasHandlers():
-            fa3_logger.addHandler(_stream_handler)
-        fa3_logger.debug(
-            "To use flash-attn v3, please follow these steps to install the flashattn-hopper "
-            "package: \n%s",
+    if get_device_compute_capability() >= (9, 0) and _NVTE_FLASH_ATTN:
+        fa_logger.debug(
            "flash-attn v3 is not installed. To use, please install it by \n%s",
             _flash_attn_3_installation_steps,
         )
 else:
@@ -143,13 +192,10 @@ else:
         _flash_attn_varlen_backward as flash_attn_varlen_bwd_v3,
     )
+    _flash_attn_3_is_installed = True
+    _flash_attn_3_0_0_beta = PkgVersion("3.0.0b") < _flash_attn_3_version < PkgVersion("3.0.0")
     _use_flash_attn_3 = True
-
-if _flash_attn_version >= _flash_attn_version_required:
-    from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as flash_attn_varlen_fwd
-    from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as flash_attn_varlen_bwd
-    from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
 
 _attention_backends = {
     "attention_params": None,
@@ -319,7 +365,12 @@ def get_attention_backend(
         + str(
             (lambda x, y: x * 10 + y)(device_compute_capability[0], device_compute_capability[1])
         ),
-        "flash_attn_version": _flash_attn_version,
+        "flash_attn_version": (
+            str(_flash_attn_version) if _flash_attn_is_installed else "not installed"
+        ),
+        "flash_attn_3_version": (
+            str(_flash_attn_3_version) if _flash_attn_3_is_installed else "not installed"
+        ),
         "cudnn_version": ".".join([str(i) for i in cudnn_version]),
     }
     attention_params_dict = {
@@ -330,15 +381,17 @@ def get_attention_backend(
     run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
     logger.debug("Running with config=%s", run_config)
 
+    # The following sections check if `FlashAttention` supports the provided attention params,
+    # regardless of whether FA2 or FA3 is installed. If FA2 or FA3 is not installed but is
+    # necessary for performance/functionality, a warning will be issued to prompt users to
+    # install an appropriate FA version.
+    global _flash_attn_version_required, _flash_attn_max_version, _use_flash_attn_3
+
     # Filter: Environment variables
-    global _NVTE_FLASH_ATTN, _NVTE_FUSED_ATTN, _NVTE_UNFUSED_ATTN
-    _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
-    _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
-    _NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
-    use_flash_attention = _NVTE_FLASH_ATTN
-    use_fused_attention = _NVTE_FUSED_ATTN
-    use_unfused_attention = _NVTE_UNFUSED_ATTN
-    if not use_flash_attention:
+    use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1"))
+    use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1"))
+    use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
+    if not use_flash_attention and _flash_attn_is_installed:
         logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
     if not use_fused_attention:
         logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0")
@@ -347,7 +400,7 @@ def get_attention_backend(
     # Filter: ONNX mode
     if is_in_onnx_export_mode():
-        if use_flash_attention:
+        if use_flash_attention and _flash_attn_is_installed:
             logger.debug("Disabling FlashAttention due to ONNX mode")
             use_flash_attention = False
         if use_fused_attention:
@@ -355,32 +408,31 @@ def get_attention_backend(
             use_fused_attention = False
 
     # Filter: Compute capability
-    global _use_flash_attn_3
     if device_compute_capability < (8, 0):
-        if use_flash_attention:
+        if use_flash_attention and _flash_attn_is_installed:
             logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
             use_flash_attention = False
         if use_fused_attention:
             logger.debug("Disabling FusedAttention as it requires compute capability sm80+")
             use_fused_attention = False
     if device_compute_capability < (9, 0):
-        if use_flash_attention and _use_flash_attn_3:
+        if use_flash_attention and _flash_attn_3_is_installed:
             logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+")
             _use_flash_attn_3 = False
 
     # Filter: Data type
     if qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type not in [
         torch.Tensor,
         Float8Tensor,
     ]:
-        if use_flash_attention:
+        if use_flash_attention and _flash_attn_is_installed:
             logger.debug(
                 "Disabling FlashAttention due to unsupported QKV data type. "
                 "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
                 "Found: qkv_dtype = %s.",
                 qkv_dtype,
             )
             use_flash_attention = False
         if use_fused_attention:
             logger.debug(
                 "Disabling FusedAttention due to unsupported QKV data type. "
@@ -393,7 +445,8 @@ def get_attention_backend(
     # Filter: Execution type
     if fp8 and fp8_meta["recipe"].fp8_dpa:
         if use_flash_attention and not _use_flash_attn_3:
-            logger.debug("Disabling FlashAttention as FlashAttention 2 does not support FP8")
+            if _flash_attn_is_installed:
+                logger.debug("Disabling FlashAttention as FlashAttention 2 does not support FP8")
             use_flash_attention = False
         if use_flash_attention and _use_flash_attn_3 and is_training:
             logger.debug(
@@ -406,22 +459,24 @@ def get_attention_backend(
     # Filter: Head dimension
     if use_flash_attention and head_dim_qk != head_dim_v:
-        logger.debug("Disabling FlashAttention as it does not support MLA.")
+        if _flash_attn_is_installed:
+            logger.debug("Disabling FlashAttention as it does not support MLA.")
         use_flash_attention = False
     if use_flash_attention and (
         head_dim_qk > 256
         or head_dim_qk % 8 != 0
         or (head_dim_qk > 192 and device_compute_capability not in ((8, 0), (9, 0)))
     ):
-        logger.debug(
-            "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. "
-            "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
-            "head_dim_qk <= 256 (>192 requires sm80/90). "
-            "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
-            head_dim_qk,
-            head_dim_v,
-            ".".join([str(i) for i in device_compute_capability]),
-        )
+        if _flash_attn_is_installed:
+            logger.debug(
+                "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. "
+                "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
+                "head_dim_qk <= 256 (>192 requires sm80/90). "
+                "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
+                head_dim_qk,
+                head_dim_v,
+                ".".join([str(i) for i in device_compute_capability]),
+            )
         use_flash_attention = False
     qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
     if use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd":
@@ -438,10 +493,11 @@ def get_attention_backend(
         logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd")
         use_unfused_attention = False
     if use_flash_attention and pad_between_seqs:
-        logger.debug(
-            "Disabling FlashAttention for qkv_format = thd when there is "
-            "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
-        )
+        if _flash_attn_is_installed:
+            logger.debug(
+                "Disabling FlashAttention for qkv_format = thd when there is "
+                "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
+            )
         use_flash_attention = False
 
     # Filter: Dropout
@@ -468,34 +524,39 @@ def get_attention_backend(
         use_unfused_attention = False
     if context_parallel and use_flash_attention:
         if fp8 and fp8_meta["recipe"].fp8_dpa:
-            logger.debug(
-                "Disabling FlashAttention as it does not support context parallelism with FP8"
-            )
+            if _flash_attn_is_installed:
+                logger.debug(
+                    "Disabling FlashAttention as it does not support context parallelism with FP8"
+                )
             use_flash_attention = False
         if "bottom_right" in attn_mask_type:
-            logger.debug(
-                "Disabling FlashAttention as it does not support context parallelism with"
-                " causal_bottom_right masking"
-            )
+            if _flash_attn_is_installed:
+                logger.debug(
+                    "Disabling FlashAttention as it does not support context parallelism with"
+                    " causal_bottom_right masking"
+                )
             use_flash_attention = False
         elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
-            logger.debug(
-                "Disabling FlashAttention as it does not support context parallelism with causal"
-                " masking for cross-attention"
-            )
+            if _flash_attn_is_installed:
+                logger.debug(
+                    "Disabling FlashAttention as it does not support context parallelism with"
+                    " causal masking for cross-attention"
+                )
             use_flash_attention = False
         elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
-            logger.debug(
-                "Disabling FlashAttention as it does not support context parallelism with bias type"
-                " of %s",
-                core_attention_bias_type,
-            )
+            if _flash_attn_is_installed:
+                logger.debug(
+                    "Disabling FlashAttention as it does not support context parallelism with bias"
+                    " type of %s",
+                    core_attention_bias_type,
+                )
             use_flash_attention = False
         elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
-            logger.debug(
-                "Disabling FlashAttention as it does not support context parallelism with attention"
-                " bias for THD format"
-            )
+            if _flash_attn_is_installed:
+                logger.debug(
+                    "Disabling FlashAttention as it does not support context parallelism with"
+                    " attention bias for THD format"
+                )
             use_flash_attention = False
 
     if context_parallel and use_fused_attention:
@@ -551,7 +612,7 @@ def get_attention_backend(
     #     arbitrary            | One tensor in shape broadcastable to  | UnfusedDotProductAttention
     #                          | [b, h, sq, skv]                       |
     if attn_mask_type == "arbitrary":
-        if use_flash_attention:
+        if use_flash_attention and _flash_attn_is_installed:
             logger.debug("Disabling FlashAttention for arbitrary mask")
             use_flash_attention = False
         if use_fused_attention:
@@ -571,28 +632,32 @@ def get_attention_backend(
         _use_flash_attn_3 = False
     if (
         use_flash_attention
-        and _flash_attn_2_1_plus
         and attn_mask_type in ["causal", "padding_causal"]
         and max_seqlen_q != max_seqlen_kv
     ):
-        logger.warning(
-            "Disabling FlashAttention as it only supports bottom-right-diagonal "
-            "causal mask since flash-attn 2.1. See "
-            "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
-        )
-        use_flash_attention = False
+        if _flash_attn_2_1_plus:
+            logger.warning(
+                "Disabling FlashAttention as it only supports bottom-right-diagonal "
+                "causal mask since flash-attn 2.1. See "
+                "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
+            )
+            use_flash_attention = False
+        if not _flash_attn_is_installed:
+            _flash_attn_max_version = PkgVersion("2.1")
     if (
         use_flash_attention
-        and not _flash_attn_2_1_plus
         and attn_mask_type in ["causal_bottom_right", "padding_causal_bottom_right"]
         and max_seqlen_q != max_seqlen_kv
     ):
-        logger.warning(
-            "Disabling FlashAttention as it only supports top-left-diagonal "
-            "causal mask before flash-attn 2.1. See "
-            "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
-        )
-        use_flash_attention = False
+        if not _flash_attn_is_installed:
+            _flash_attn_version_required = PkgVersion("2.1")
+        elif not _flash_attn_2_1_plus and not _use_flash_attn_3:
+            logger.warning(
+                "Disabling FlashAttention as it only supports top-left-diagonal "
+                "causal mask before flash-attn 2.1. See "
+                "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
+            )
+            use_flash_attention = False
     if (
         use_flash_attention
         and _use_flash_attn_3
@@ -645,15 +710,19 @@ def get_attention_backend(
             attn_mask_type,
         )
         use_fused_attention = False
-    if (
-        use_flash_attention
-        and (window_size[0] != -1 or window_size[1] not in [-1, 0])
-        and not _flash_attn_2_3_plus
-    ):
-        logger.debug(
-            "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+"
-        )
-        use_flash_attention = False
+    if use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
+        if _use_flash_attn_3:
+            logger.debug(
+                "Disabling FlashAttention 3 as it does not support sliding window attention"
+            )
+            _use_flash_attn_3 = False
+        if not _flash_attn_is_installed:
+            _flash_attn_version_required = PkgVersion("2.3")
+        elif not _flash_attn_2_3_plus:
+            logger.debug(
+                "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+"
+            )
+            use_flash_attention = False
 
     # Filter: Attention bias
     #    backend                   | bias types               | ALiBi diagonal alignment
@@ -668,7 +737,9 @@ def get_attention_backend(
         if _use_flash_attn_3:
             logger.debug("Disabling FlashAttention 3 for ALiBi")
             _use_flash_attn_3 = False
-        if not _use_flash_attn_3 and not _flash_attn_2_4_plus:
+        if not _flash_attn_is_installed:
+            _flash_attn_version_required = PkgVersion("2.4")
+        elif not _flash_attn_2_4_plus:
             logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+")
             use_flash_attention = False
@@ -676,7 +747,8 @@ def get_attention_backend(
         core_attention_bias_type not in ["no_bias", "alibi"]
         or core_attention_bias_shape is not None
     ):
-        logger.debug("Disabling FlashAttention for pre/post_scale_bias")
+        if _flash_attn_is_installed:
+            logger.debug("Disabling FlashAttention for pre/post_scale_bias")
         use_flash_attention = False
 
     fu_core_attention_bias_type = core_attention_bias_type
@@ -780,13 +852,16 @@ def get_attention_backend(
     #                              |       otherwise: no
    # sub-backend 2                | no
     # UnfusedDotProductAttention  | yes
-    if use_flash_attention and deterministic and not _flash_attn_2_4_1_plus:
-        logger.warning(
-            "Disabling FlashAttention as version <2.4.1 does not support deterministic "
-            "execution. To use FlashAttention with deterministic behavior, "
-            "please install flash-attn >= 2.4.1."
-        )
-        use_flash_attention = False
+    if use_flash_attention and deterministic:
+        if not _flash_attn_is_installed:
+            _flash_attn_version_required = PkgVersion("2.4.1")
+        elif not _flash_attn_2_4_1_plus and not _use_flash_attn_3:
+            logger.warning(
+                "Disabling FlashAttention as version <2.4.1 does not support deterministic "
+                "execution. To use FlashAttention with deterministic behavior, "
+                "please install flash-attn >= 2.4.1."
+            )
+            use_flash_attention = False
     if use_fused_attention and deterministic:
         if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
             logger.debug("Disabling FusedAttention for determinism reasons")
@@ -805,6 +880,23 @@ def get_attention_backend(
     # All available backends
     available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention]
+
+    # `FusedAttention` and `FlashAttention` are faster backends than `UnfusedDotProductAttention`.
+    # When `FusedAttention` does not support the provided attention params, and `FlashAttention`
+    # does, we recommend users to install flash-attn if not installed already.
+    if not use_fused_attention and use_flash_attention and not _flash_attn_is_installed:
+        logger.warning(
+            "flash-attn may provide important feature support or performance improvement."
+            " Please install flash-attn %s.",
+            _get_supported_versions(
+                _flash_attn_version_required,
+                _flash_attn_max_version,
+            ),
+        )
+    if use_flash_attention and not _flash_attn_is_installed:
+        use_flash_attention = False
+        available_backends[0] = False
+
     logger.debug(
         "Available backends = {FlashAttention=%s, FusedAttention=%s%s,"
         " UnfusedDotProductAttention=%s}",
@@ -5029,12 +5121,13 @@ class FlashAttention(torch.nn.Module):
     ) -> None:
         super().__init__()
 
-        assert (
-            _flash_attn_version >= _flash_attn_version_required
-        ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
-        assert (
-            _flash_attn_version <= _flash_attn_max_version
-        ), f"FlashAttention maximum version {_flash_attn_max_version} is supported."
+        if _flash_attn_is_installed:
+            assert (
+                _flash_attn_version >= _flash_attn_version_required
+            ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
+            assert (
+                _flash_attn_version <= _flash_attn_max_version
+            ), f"FlashAttention maximum version {_flash_attn_max_version} is supported."
 
         self.softmax_scale = softmax_scale
         self.attention_dropout_ctx = attention_dropout_ctx
@@ -5305,7 +5398,7 @@ class FlashAttention(torch.nn.Module):
                 if _flash_attn_3_0_0_beta:
                     e.args = (
                         e.args[0]
-                        + ". Please update your FlashAttention 3 (beta) installation as it "
+                        + ". Please update your flash-attn v3 (beta) installation as it "
                         + "may have added more supported arguments to its API. \n"
                         + _flash_attn_3_installation_steps,
                     ) + e.args[1:]
@@ -7900,7 +7993,7 @@ class DotProductAttention(TransformerEngineBaseModule):
             fp8=self.fp8,
             fp8_meta=self.fp8_meta,
         )
-        global _attention_backends, _flash_attn_3_plus, _use_flash_attn_3
+        global _attention_backends, _use_flash_attn_3
         if (
             _attention_backends["attention_params"] is None
             or attention_params != _attention_backends["attention_params"]
@@ -7908,7 +8001,7 @@ class DotProductAttention(TransformerEngineBaseModule):
             _attention_backends["attention_params"] = attention_params
             _attention_backends["backend_selection_requires_update"] = True
         if _attention_backends["backend_selection_requires_update"]:
-            _use_flash_attn_3 = _flash_attn_3_plus
+            _use_flash_attn_3 = _flash_attn_3_is_installed
             (
                 use_flash_attention,
                 use_fused_attention,
@@ -7919,7 +8012,7 @@ class DotProductAttention(TransformerEngineBaseModule):
             if use_flash_attention:
                 self.logger.info(
                     "Running with FlashAttention backend (version %s)",
-                    _flash_attn_version if not _use_flash_attn_3 else _flash_attn_v3_version,
+                    _flash_attn_version if not _use_flash_attn_3 else _flash_attn_3_version,
                 )
             elif use_fused_attention:
                 self.logger.info(
...
@@ -56,7 +56,7 @@ if __name__ == "__main__":
         description="Transformer acceleration library - Torch Lib",
         ext_modules=ext_modules,
         cmdclass={"build_ext": CMakeBuildExtension},
-        install_requires=["torch", "flash-attn>=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0"],
+        install_requires=["torch"],
         tests_require=["numpy", "onnxruntime", "torchvision"],
     )
     if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
...
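
Backend choice remains controllable through the same environment variables this diff touches (NVTE_FLASH_ATTN, NVTE_FUSED_ATTN, NVTE_UNFUSED_ATTN). A hedged usage sketch, mirroring how the updated tests force a backend; the module arguments and tensor shapes here are illustrative, and a CUDA device is assumed:

import os

# Select backends before the first forward pass; with flash-attn not installed,
# TE now simply skips the FlashAttention path rather than failing at import.
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "1"

import torch
import transformer_engine.pytorch as te

attn = te.DotProductAttention(num_attention_heads=16, kv_channels=64)
# Default qkv_format is "sbhd": [seq, batch, heads, head_dim].
q = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
v = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
out = attn(q, k, v)
print(out.shape)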