"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "d8b13cb0e58060e2428949faa36e625fe2978d97"
Unverified commit 161b1d98, authored by Charlene Yang, committed by GitHub

[PyTorch] Drop FA as an installation requirement (#1226)



* WIP: make FA2 optional
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: fix logic
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor tweak
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add L1 test to test all supported FA versions
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update version to 2.1.1 and trim L1 tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update onnxruntime version
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove onnxruntime from L1 FA versions tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 43b9e1ee
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e

: "${TE_PATH:=/opt/transformerengine}"

pip install pytest==8.2.1

# All flash-attn versions supported by the PyTorch attention backend;
# 3.0.0b1 (FA3) is built from source out of the hopper/ subdirectory.
FA_versions=(2.1.1 2.3.0 2.4.0.post1 2.4.1 2.5.7 2.6.3 3.0.0b1)
for fa_version in "${FA_versions[@]}"
do
    # Lexicographic comparison: safe for this pinned list, since every 2.x
    # string sorts below "3.0.0" and "3.0.0b1" sorts above it (prefix rule).
    if [ "${fa_version}" \< "3.0.0" ]
    then
        pip install flash-attn=="${fa_version}"
    else
        # FA3 is packaged separately as flashattn-hopper; its Python interface
        # module is not installed by pip, so fetch it into site-packages.
        pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper"
        python_path=$(python -c "import site; print(site.getsitepackages()[0])")
        mkdir -p "${python_path}/flashattn_hopper"
        wget -P "${python_path}/flashattn_hopper" https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py
    fi
    NVTE_TORCH_COMPILE=0 pytest -v -s "${TE_PATH}/tests/pytorch/fused_attn/test_fused_attn.py"
done
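The `\<` test above compares version strings lexicographically, which works for this pinned list but is fragile in general (a hypothetical "10.0.0" would sort below "3.0.0" as a string). A minimal PEP 440-aware sketch of the same gate using the packaging library — illustrative only, not part of the PR:

# Hedged sketch: route a flash-attn version string to the FA2 or FA3 branch.
from packaging.version import Version

def is_fa3(version_string: str) -> bool:
    # Use the major component: Version("3.0.0b1").major == 3, so pre-releases
    # route to FA3. A naive Version(v) < Version("3.0.0") test would mis-route
    # them, because PEP 440 sorts pre-releases *before* the final release,
    # while the bash string comparison sorts "3.0.0b1" *after* "3.0.0".
    return Version(version_string).major >= 3

for v in ["2.1.1", "2.3.0", "2.4.0.post1", "2.4.1", "2.5.7", "2.6.3", "3.0.0b1"]:
    print(v, "->", "FA3 source build" if is_fa3(v) else "pip install flash-attn")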
@@ -93,7 +93,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
     # Framework-specific requirements
     if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
         if "pytorch" in frameworks:
-            install_reqs.extend(["torch", "flash-attn>=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0"])
+            install_reqs.extend(["torch"])
             test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"])
         if "jax" in frameworks:
             install_reqs.extend(["jax", "flax>=0.7.1"])
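Dropping flash-attn from install_reqs means installing the PyTorch flavor no longer builds flash-attn from source; users install a supported version themselves. One conventional way to keep the supported range discoverable is an optional extra — a hypothetical illustration of that pattern (this PR does not add such an extra; the version range below reuses the removed pin):

# Hypothetical setup() fragment: the optional-extra pattern for flash-attn.
from setuptools import setup

setup(
    name="example-package",      # placeholder, not the real package name
    install_requires=["torch"],  # the only hard framework requirement
    extras_require={
        # opt in explicitly: pip install "example-package[flash-attn]"
        "flash-attn": ["flash-attn>=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0"],
    },
)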
@@ -20,9 +20,8 @@ from transformer_engine.pytorch.attention import (
     MultiheadAttention,
     RotaryPositionEmbedding,
     get_attention_backend,
-    _flash_attn_2_plus,
     _flash_attn_2_3_plus,
-    _flash_attn_3_plus,
+    _flash_attn_3_is_installed,
     check_set_window_size,
     AttentionParams,
     _attention_backends,
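The rename from _flash_attn_3_plus to _flash_attn_3_is_installed reflects the new contract: flash-attn availability is probed at import time rather than guaranteed by the installer. A hedged sketch of such a probe — the flag names come from the imports above, but the body is an assumption, not TransformerEngine's actual attention.py code:

# Illustrative guarded probe for optional flash-attn (FA2) and FA3 packages.
from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version

try:
    _flash_attn_version = Version(version("flash-attn"))
    _flash_attn_is_installed = True
except PackageNotFoundError:
    # flash-attn absent: backend selection falls back to other implementations
    _flash_attn_version = Version("0")
    _flash_attn_is_installed = False

_flash_attn_2_3_plus = _flash_attn_is_installed and _flash_attn_version >= Version("2.3")

try:
    # FA3 ships as the separate flashattn-hopper package (see the L1 script above)
    import flashattn_hopper.flash_attn_interface  # noqa: F401

    _flash_attn_3_is_installed = True
except ImportError:
    _flash_attn_3_is_installed = False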
@@ -1353,7 +1352,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
     os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
     config = model_configs_fp8_vs_f16[model]
-    if _flash_attn_3_plus and not is_training:
+    if _flash_attn_3_is_installed and not is_training:
         if RoPE:
             pytest.skip("Flash Attention doesn't support FP8 MHA with RoPE.")
         os.environ["NVTE_FLASH_ATTN"] = "1"
@@ -1381,7 +1380,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
         rtol = 5e-1
         rmse_tol = 0.15
     logging.debug("========== {:^25s} ==========".format("forward output"))
-    if _flash_attn_3_plus and not is_training:
+    if _flash_attn_3_is_installed and not is_training:
         _error(
             flash_attn_fwd_fp8,
             fused_attn_fwd_f16,
@@ -1534,7 +1533,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
     os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
     os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
-    if _flash_attn_3_plus and not is_training:
+    if _flash_attn_3_is_installed and not is_training:
         os.environ["NVTE_FLASH_ATTN"] = "1"
         os.environ["NVTE_FUSED_ATTN"] = "0"
         _attention_backends["backend_selection_requires_update"] = True
@@ -1561,7 +1560,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
     rmse_tol = 0.1
     bwd_names = ["dq", "dk", "dv"]
     logging.debug("========== {:^25s} ==========".format("forward output"))
-    if _flash_attn_3_plus and not is_training:
+    if _flash_attn_3_is_installed and not is_training:
         _error(
             flash_attn_fwd_fp8,
             fused_attn_fwd_f16,
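All four hunks gate the FP8-vs-F16 comparison on FA3 actually being present, and the surrounding environment-variable toggle follows one pattern. A hedged sketch of that pattern as a helper (the variable names come from the diff; the helper itself is hypothetical):

import os

from transformer_engine.pytorch.attention import _attention_backends

def use_flash_attn_backend(enable: bool) -> None:
    # Hypothetical helper mirroring the toggle pattern in these tests:
    # NVTE_FLASH_ATTN / NVTE_FUSED_ATTN pick the backend, and the cached
    # selection is invalidated so the next call re-runs backend selection.
    os.environ["NVTE_FLASH_ATTN"] = "1" if enable else "0"
    os.environ["NVTE_FUSED_ATTN"] = "0" if enable else "1"
    _attention_backends["backend_selection_requires_update"] = True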
@@ -56,7 +56,7 @@ if __name__ == "__main__":
         description="Transformer acceleration library - Torch Lib",
         ext_modules=ext_modules,
         cmdclass={"build_ext": CMakeBuildExtension},
-        install_requires=["torch", "flash-attn>=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0"],
+        install_requires=["torch"],
         tests_require=["numpy", "onnxruntime", "torchvision"],
     )
     if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
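Net effect of the two setup changes: neither the framework-level requirements nor the Torch extension's setup pulls in flash-attn, so pip installation of the library succeeds without compiling FA. Users who want the flash backend install a supported flash-attn version themselves; when it is absent, the tests above (and, presumably, backend selection generally) fall back to the fused-attention path.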