Unverified commit 0d802283 authored by Charlene Yang, committed by GitHub

[Common] Skip cuDNN 9.10.0/9.10.1 due to bugs (#1937)



* exclude 9.10.0/.1 for certain configs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix kv_channels
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add get_backend to tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add init files
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix numerics and cuda graph tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix jax tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove prints
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor changes after renaming
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix import structure and rename get_attention_backends
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix docs and benchmarks
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix get backend calls
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "fix get backend calls"

This reverts commit 653cbb51c697bc2f975416bb3aac1d85f76c36dc.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "fix docs and benchmarks"

This reverts commit 98cd52e04ff7c53e26b412195f5744e39f7ed0e9.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix docs, benchmarks and pre-commit ci
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix dpa/mha flash attn selection
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix rng states
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix ModelConfig
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix backend selection on Ampere
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix issues from last merge
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Update tests/pytorch/utils.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove initialization of rng_states to None
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* redefine ModelConfig
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix ModelConfig
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix seed for CP tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Update tests/pytorch/test_sanity.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* move fixture from utils to individual tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent ab5cc407
@@ -9,11 +9,11 @@ import numpy as np
 import torch
 import nvtx
 import transformer_engine
-from tests.pytorch.fused_attn.test_fused_attn import (
+from tests.pytorch.utils import (
     ModelConfig,
-    _get_attention_backends,
-    _run_dot_product_attention,
+    get_available_attention_backends,
 )
+from tests.pytorch.attention.test_attention import _run_dot_product_attention

 pd.set_option("display.precision", 4)
@@ -197,7 +197,7 @@ def main():
     )
     for model in model_configs.keys():
         config = model_configs[model]
-        available_backends, fused_attn_backends = _get_attention_backends(
+        available_backends, _, fused_attn_backends = get_available_attention_backends(
             config,
             qkv_dtype=dtype,
             qkv_layout=qkv_layout,
...
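For reference, both benchmarks now unpack the renamed helper's three-element return value. A minimal sketch of the call (the config values here are illustrative, not taken from this diff):

import torch
from tests.pytorch.utils import ModelConfig, get_available_attention_backends

# hypothetical config: batch 2, seqlen 2048, 16 heads, head_dim 64
config = ModelConfig(2, 2048, 16, 64)
available_backends, _, fused_attn_backends = get_available_attention_backends(
    config,
    qkv_dtype=torch.bfloat16,
    qkv_layout="bshd_bshd_bshd",
)
# available_backends holds (flash, fused, unfused) support flags
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends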
@@ -5,7 +5,7 @@
 import os
 import torch
 from typing import Tuple
-from tests.pytorch.fused_attn.test_fused_attn import ModelConfig
+from tests.pytorch.utils import ModelConfig
 from transformer_engine.pytorch.attention import DotProductAttention

 # Initialize RNG state
...
@@ -375,7 +375,7 @@
 "\n",
 "Our [unit tests](https://github.com/NVIDIA/TransformerEngine/tree/main/tests) demonstrate the use of Transformer Engine dot product attention APIs. Users are encouraged to use them as a template when integrating Transformer Engine to their ML workflows.\n",
 "\n",
-"For example, in PyTorch, [test_dot_product_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts."
+"For example, in PyTorch, [test_dot_product_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts."
 ]
 },
 {
@@ -394,10 +394,10 @@
 "| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | Yes (PyTorch only) | No | Yes |\n",
 "\n",
 "Some unit tests are provided to serve as a starting point for integrating such features into users' models. For example,\n",
-"- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
-"- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
-"- Multi-Latent Attention: [test_dpa_mla](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
-"- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py)"
+"- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n",
+"- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n",
+"- Multi-Latent Attention: [test_dpa_mla](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n",
+"- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention_with_cp.py)"
 ]
 },
 {
@@ -458,7 +458,7 @@
 " </tr>\n",
 "</table>\n",
 "\n",
-"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
+"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
 "\n",
 "<div class=\"alert alert-info\">\n",
 "<b>Note</b>\n",
@@ -548,7 +548,7 @@
 "id": "dda4a589",
 "metadata": {},
 "source": [
-"Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py).\n",
+"Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py).\n",
 "\n",
 "### 3.3 Attention Bias\n",
 "\n",
@@ -594,7 +594,7 @@
 "\n",
 "The framework-native backends do not explicitly support `ALiBi`, but users can convert `ALiBi` to a regular `post_scale_bias` bias to achieve the same effect. In PyTorch, this utility function, `transformer_engine.pytorch.attention.get_alibi`, can be used to help with the conversion.\n",
 "\n",
-"More examples of how to use the various attention biases are at [test_dpa_bias](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)."
+"More examples of how to use the various attention biases are at [test_dpa_bias](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)."
 ]
 },
 {
@@ -612,7 +612,7 @@
 "\n",
 "- `DelayedScaling.fp8_mha=True (default=False)`: This option, on top of `fp8_dpa=True`, removes the casting operations at the beginning and end of the `FusedAttention` module. This feature is experimental. \n",
 "\n",
-"Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`."
+"Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`."
 ]
 }
 ],
...
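For the FP8 attention options named in the notebook text above, a minimal usage sketch (hedged: `model` and `inp` are placeholders for a Transformer Engine module and its input; `fp8_dpa`/`fp8_mha` are the `DelayedScaling` knobs the notebook describes):

import os
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

os.environ["NVTE_FP8_DPA_BWD"] = "0"  # FP8 forward only, higher-precision backward
fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=False)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(inp)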
@@ -9,11 +9,11 @@ import numpy as np
 import torch
 import nvtx
 import transformer_engine
-from tests.pytorch.fused_attn.test_fused_attn import (
+from tests.pytorch.utils import (
     ModelConfig,
-    _get_attention_backends,
-    _run_dot_product_attention,
+    get_available_attention_backends,
 )
+from tests.pytorch.attention.test_attention import _run_dot_product_attention

 # data type
 dtype = torch.bfloat16
@@ -90,7 +90,7 @@ def main():
     models = ["test_0"]
     for model in models:
         config = model_configs[model]
-        available_backends, fused_attn_backends = _get_attention_backends(
+        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_layout,
...
@@ -45,8 +45,8 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
 NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
 NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
...
@@ -28,7 +28,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
...
@@ -41,6 +41,6 @@ do
     fi
     # Run tests
-    NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
+    NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py
 done
@@ -372,7 +372,7 @@ class FusedAttnRunner:
             self.head_dim_v,
             (-1, -1) if self.window_size is None else self.window_size,
         ).get_fused_attn_backend()
-        if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
+        if self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
             pytest.skip("Unsupported inputs combination or device compute capability.")
         if (
...
@@ -13,7 +13,7 @@ from transformer_engine.pytorch.attention.dot_product_attention.context_parallel
     get_cu_seqlens_on_cp_rank,
 )
 import transformer_engine_torch as tex
-from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
+from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn
 from transformer_engine.pytorch.fp8 import fp8_autocast
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
 from transformer_engine.common.recipe import DelayedScaling
...
@@ -4,6 +4,8 @@
 import os
 import subprocess
+import sys
+import pathlib

 import pytest
 import torch
@@ -12,26 +14,28 @@ from transformer_engine.pytorch.utils import (
     get_cudnn_version,
 )
 from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils
-from test_fused_attn import ModelConfig
+
+_current_file = pathlib.Path(__file__).resolve()
+sys.path.append(str(_current_file.parent.parent))
+from utils import ModelConfig, get_available_attention_backends
+
+# Initialize RNG state
+seed = 1234
+torch.manual_seed(seed)
+torch.cuda.manual_seed(seed)

 model_configs_flash_attn = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # MHA
-    "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # MHA
-    "cp_1_2": ModelConfig(
-        2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
-    ),  # MHA
-    "cp_1_3": ModelConfig(
-        2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
-    ),  # MHA
-    "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # GQA
-    "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # GQA
-    "cp_2_2": ModelConfig(
-        2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
-    ),  # GQA
-    "cp_2_3": ModelConfig(
-        2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
-    ),  # GQA
+    "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"),  # MHA
+    "cp_1_1": ModelConfig(2, 4096, 12, 128),  # MHA
+    "cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)),  # MHA
+    "cp_1_3": ModelConfig(2, 4096, 12, 128, window_size=(512, 512)),  # MHA
+    "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"),  # GQA
+    "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2),  # GQA
+    "cp_2_2": ModelConfig(
+        2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0)
+    ),  # GQA
+    "cp_2_3": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, window_size=(512, 512)),  # GQA
 }
@@ -43,7 +47,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
         "--nproc-per-node=" + str(num_gpus_per_node),
     ]
     te_path = os.getenv("TE_PATH", "/opt/transformerengine")
-    script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py")
+    script_path = os.path.join(te_path, "tests/pytorch/attention/run_attention_with_cp.py")
     args.append(script_path)
     for k, v in kwargs.items():
         args.append(f"{k}={v}")
@@ -93,32 +97,36 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
 model_configs_fused_attn = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # MHA
-    "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # MHA
-    "cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"),  # MHA
-    "cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"),  # MHA
-    "cp_1_4": ModelConfig(
-        2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
-    ),  # MHA
-    "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # GQA
-    "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # GQA
-    "cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"),  # GQA
-    "cp_2_3": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"),  # GQA
-    "cp_2_4": ModelConfig(
-        2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
-    ),  # GQA
-    "cp_3_0": ModelConfig(
-        2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=64
-    ),  # MLA
-    "cp_3_1": ModelConfig(
-        2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", head_dim_v=64
-    ),  # MLA
-    "cp_3_2": ModelConfig(
-        2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias", head_dim_v=64
-    ),  # MLA
-    "cp_3_3": ModelConfig(
-        2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias", head_dim_v=64
-    ),  # MLA
+    "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"),  # MHA
+    "cp_1_1": ModelConfig(2, 4096, 12, 128),  # MHA
+    "cp_1_2": ModelConfig(
+        2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
+    ),  # MHA
+    "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"),  # MHA
+    "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)),  # MHA
+    "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"),  # GQA
+    "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2),  # GQA
+    "cp_2_2": ModelConfig(
+        2,
+        4096,
+        12,
+        128,
+        num_gqa_groups=2,
+        attn_mask_type="causal",
+        attn_bias_type="post_scale_bias",
+    ),  # GQA
+    "cp_2_3": ModelConfig(
+        2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias"
+    ),  # GQA
+    "cp_2_4": ModelConfig(
+        2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0)
+    ),  # GQA
+    "cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64),  # MLA
+    "cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64),  # MLA
+    "cp_3_2": ModelConfig(
+        2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64
+    ),  # MLA
+    "cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64),  # MLA
 }
@@ -175,6 +183,17 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
         pytest.skip("MLA CP currently only support KV P2P!")
     if dtype == "fp8" and config.head_dim_qk != config.head_dim_v:
         pytest.skip("MLA CP currently does not support FP8 attention!")
+    dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
+    available_backends, _, fused_attn_backends = get_available_attention_backends(
+        config,
+        qkv_dtype=dtypes[dtype],
+        qkv_layout="_".join([qkv_format] * 3),
+        window_size=config.window_size,
+        context_parallel=True,
+    )
+    _, fused_attn_supported, _ = available_backends
+    if not fused_attn_supported:
+        pytest.skip("No attention backend available.")
     subprocess.run(
         get_bash_arguments(
...
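The dict rewrite above trades the old positional (b, h, hg, d, sq, skv, p, mask, bias) signature for the new (batch_size, max_seqlen_q, num_heads, head_dim_qk) plus keywords. For comparison, the cp_2_0 entry in both forms:

# old signature: b, h, hg, d, sq, skv, p, mask, bias
ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias")
# new signature: batch_size, max_seqlen_q, num_heads, head_dim_qk, plus keywords
ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal")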
@@ -5,18 +5,14 @@
 from collections import OrderedDict
 from typing import List
 import os
+import sys
+import pathlib
 import logging
 import math

 import pytest
 import torch
-from test_fused_attn import (
-    ModelConfig,
-    reset_rng_states,
-    _get_attention_backends,
-)
 from torch.distributions import Exponential
 from transformer_engine.pytorch import make_graphed_callables
 from transformer_engine.common import recipe
@@ -34,26 +30,25 @@ from transformer_engine.pytorch.utils import (
     is_bf16_compatible,
 )

-# Initialize RNG state
-seed = 1234
-torch.manual_seed(seed)
-torch.cuda.manual_seed(seed)
-_cpu_rng_state = torch.get_rng_state()
-_cuda_rng_state = torch.cuda.get_rng_state()
+_current_file = pathlib.Path(__file__).resolve()
+sys.path.append(str(_current_file.parent.parent))
+from utils import (
+    ModelConfig,
+    reset_rng_states,
+    get_available_attention_backends,
+)
+
+# Reset RNG states
+reset_rng_states()

 param_types = [torch.float16]
 if is_bf16_compatible():
     param_types.append(torch.bfloat16)

 model_configs_infer = {
-    # test: b, h, hg, d, sq, skv, p, mask, bias
-    "infer_0": ModelConfig(
-        4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16
-    ),
-    "infer_1": ModelConfig(
-        2, 16, 4, 256, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
-    ),
+    # test: b, sq, hq, dqk,
+    "infer_0": ModelConfig(4, 64, 16, 128, total_requests=8, max_ctx_len=16),
+    "infer_1": ModelConfig(2, 66, 16, 256, num_gqa_groups=4, total_requests=6, max_ctx_len=16),
 }

 qkv_formats = ["bshd", "sbhd", "thd"]
@@ -470,7 +465,7 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
     qkv_layout = qkv_format + "_" + "_".join([inference_params_qkv_format] * 2)
     if is_paged:
         qkv_layout = "paged_kv_" + qkv_layout
-    available_backends, _, fused_attn_backends = _get_attention_backends(
+    available_backends, _, fused_attn_backends = get_available_attention_backends(
         config,
         qkv_dtype=dtype,
         qkv_layout=qkv_layout,
...
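A sketch of how the layout string in the last hunk is assembled (format values assumed):

qkv_format = "bshd"                   # query layout in the test
inference_params_qkv_format = "bshd"  # KV-cache layout
qkv_layout = qkv_format + "_" + "_".join([inference_params_qkv_format] * 2)
# -> "bshd_bshd_bshd"; with paging it becomes "paged_kv_bshd_bshd_bshd"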
@@ -10,6 +10,8 @@ import torch
 import transformer_engine.pytorch as te
 from transformer_engine.common import recipe
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
+from utils import ModelConfig, get_available_attention_backends

 # Check if FP8 is supported
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -22,10 +24,13 @@ fp8_recipes = [
     recipe.DelayedScaling(),
 ]

-SIZE = 512
-NUM_HEADS = 8
-NUM_LAYERS = 5
-EPSILON = 0.1
+model_config = {
+    "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
+}
+SIZE = model_config["small"].hidden_size
+NUM_HEADS = model_config["small"].num_heads
+NUM_LAYERS = model_config["small"].num_layers
+EPSILON = model_config["small"].eps

 # Flash attention saves some internal tensor for the backward pass
 # that cannot be offloaded to CPU.
@@ -130,6 +135,18 @@ def test_cpu_offload(fp8_recipe, model_key) -> None:
     if fp8_recipe.mxfp8() and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)

+    if model_key in ["multihead_attention", "transformer_layer"]:
+        available_backends, *_ = get_available_attention_backends(
+            model_config["small"],
+            qkv_dtype=torch.bfloat16,
+            qkv_layout="sbhd_sbhd_sbhd",
+        )
+        _, fused_attn_supported, _ = available_backends
+        if not fused_attn_supported:
+            pytest.skip("Fused attention backend not available.")
+        os.environ["NVTE_FLASH_ATTN"] = "0"
+        _attention_backends["backend_selection_requires_update"] = True
+
     without_offloading = _measure_memory_between_forward_and_backward(
         models_list, fp8_recipe, False
     )
...
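Outside this test, the same fused-attention toggle used in the added block looks as follows (a sketch; `_attention_backends` is a private module-level dict, so this is test-only plumbing rather than public API):

import os
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends

os.environ["NVTE_FLASH_ATTN"] = "0"  # rule flash attention out of selection
_attention_backends["backend_selection_requires_update"] = True  # re-run selection on next forward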
@@ -23,7 +23,7 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.utils import is_bf16_compatible
 import transformer_engine.pytorch.ops as te_ops
 from transformer_engine.common import recipe
+from utils import ModelConfig, reset_rng_states

 # Check if FP8 is supported.
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -32,27 +32,12 @@ fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
 )
 mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

-# Record initial RNG state.
-seed = 1234
-torch.manual_seed(seed)
-torch.cuda.manual_seed(seed)
-_cpu_rng_state = torch.get_rng_state()
-_cuda_rng_state = torch.cuda.get_rng_state()
-
-
-@dataclass
-class ModelConfig:
-    """Data tensor dimensions within Transformer model"""
-
-    sequence_length: int
-    batch_size: int
-    hidden_size: int
-    num_heads: int
-    kv_channels: int
-
-
-model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}
+# Reset RNG states.
+reset_rng_states()
+
+model_configs = {
+    "small": ModelConfig(32, 2, 2, 32),
+}

 fp8_recipes = [
     recipe.DelayedScaling(),
@@ -67,12 +52,6 @@ if is_bf16_compatible():  # bf16 requires sm_80 or higher
     dtypes.append(torch.bfloat16)

-def reset_rng_states() -> None:
-    """Revert to initial RNG state."""
-    torch.set_rng_state(_cpu_rng_state)
-    torch.cuda.set_rng_state(_cuda_rng_state)
-

 @pytest.fixture(autouse=True)
 def reset_global_fp8_state():
     yield
@@ -107,7 +86,7 @@ def generate_data(
     """Generate synthetic data."""
     gen_func = torch.ones if warmup else torch.randn
     return gen_func(
-        model_config.sequence_length,
+        model_config.max_seqlen_q,
         model_config.batch_size,
         model_config.hidden_size,
         device="cuda",
@@ -389,7 +368,7 @@ def generate_data_for_dot_product_attention(
     gen_func = torch.ones if warmup else torch.randn
     return [
         gen_func(
-            model_config.sequence_length,
+            model_config.max_seqlen_q,
             model_config.batch_size,
             model_config.num_heads,
             model_config.kv_channels,
@@ -483,8 +462,8 @@ def _test_cuda_graphs_with_kwargs(
         (
             model_config.batch_size,
             1,
-            model_config.sequence_length,
-            model_config.sequence_length,
+            model_config.max_seqlen_q,
+            model_config.max_seqlen_kv,
         ),
         dtype=torch.bool,
         device="cuda",
@@ -510,8 +489,8 @@ def _test_cuda_graphs_with_kwargs(
         (
             model_config.batch_size,
             1,
-            model_config.sequence_length,
-            model_config.sequence_length,
+            model_config.max_seqlen_q,
+            model_config.max_seqlen_kv,
         ),
         dtype=torch.bool,
         device="cuda",
...
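The removed local dataclass maps onto the shared ModelConfig as follows; a sketch based on the "small" entry (old fields in the comment, derived attributes checked below):

# old: ModelConfig(sequence_length=2, batch_size=32, hidden_size=64, num_heads=2, kv_channels=32)
# new: positional (batch_size, max_seqlen_q, num_heads, head_dim_qk)
config = ModelConfig(32, 2, 2, 32)
assert config.max_seqlen_q == 2   # was sequence_length
assert config.hidden_size == 64   # num_heads * head_dim_qk
assert config.kv_channels == 32   # equals head_dim_qk when head_dim_v is not set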
@@ -4,12 +4,24 @@
 from __future__ import annotations

+import logging
+import os
+from contextlib import contextmanager
+
+import pytest
 import torch

 import transformer_engine
 import transformer_engine.common.recipe
 import transformer_engine.pytorch as te
 import transformer_engine_torch as tex
+from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
+from transformer_engine.pytorch.attention.dot_product_attention.utils import (
+    get_attention_backend,
+    AttentionParams,
+    AttentionLogging,
+)
+from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend


 def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype:
@@ -106,3 +118,178 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]:
     if name == "fp8_block_scaling":
         return transformer_engine.common.recipe.Float8BlockScaling()
     raise ValueError(f"Unsupported quantization scheme ({name})")
+
+
+# Cached RNG state
+_rng_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+
+
+def reset_rng_states() -> None:
+    """Revert to deterministic RNG state"""
+    global _rng_states
+    if _rng_states is None:
+        torch.manual_seed(1234)
+        torch.cuda.manual_seed(1234)
+        _rng_states = (torch.get_rng_state(), torch.cuda.get_rng_state())
+    else:
+        cpu_rng_state, cuda_rng_state = _rng_states
+        torch.set_rng_state(cpu_rng_state)
+        torch.cuda.set_rng_state(cuda_rng_state)
+
+
+class ModelConfig:
+    def __init__(
+        self,
+        batch_size: int,
+        max_seqlen_q: int,
+        num_heads: int,
+        head_dim_qk: int,
+        max_seqlen_kv: int = None,
+        num_gqa_groups: int = None,
+        head_dim_v: int = None,
+        dropout_p: float = 0.0,
+        attn_mask_type: str = "no_mask",
+        attn_bias_type: str = "no_bias",
+        alibi_type: str = "none",
+        bias_shape: str = "1hss",
+        window_size: Tuple[int, int] = (-1, -1),
+        total_requests: int = None,
+        max_ctx_len: int = None,
+        num_layers: int = 1,
+        eps: float = 1e-5,
+    ):
+        self.batch_size = batch_size
+        self.max_seqlen_q = max_seqlen_q
+        self.max_seqlen_kv = max_seqlen_q if max_seqlen_kv is None else max_seqlen_kv
+        self.num_heads = num_heads
+        self.num_gqa_groups = num_heads if num_gqa_groups is None else num_gqa_groups
+        self.head_dim_qk = head_dim_qk
+        self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v
+        if self.head_dim_qk == self.head_dim_v:
+            self.kv_channels = self.head_dim_qk
+        else:
+            self.kv_channels = (self.head_dim_qk, self.head_dim_v)
+        self.hidden_size = self.num_heads * self.head_dim_qk
+        self.hidden_size_kv = self.num_gqa_groups * self.head_dim_v
+        self.dropout_p = dropout_p
+        self.attn_mask_type = attn_mask_type
+        self.attn_bias_type = attn_bias_type
+        self.alibi_type = alibi_type
+        self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv) else "cross"
+        self.bias_shape = bias_shape
+        self.window_size = window_size
+        self.total_requests = total_requests
+        self.max_ctx_len = max_ctx_len
+        self.num_layers = num_layers
+        self.eps = eps
+
+
+@contextmanager
+def logging_context(highest_level=logging.WARNING):
+    previous_level = logging.root.manager.disable
+    logging.disable(highest_level)
+    try:
+        yield
+    finally:
+        logging.disable(previous_level)
+
+
+def get_available_attention_backends(
+    config: ModelConfig,
+    qkv_dtype: torch.dtype,
+    qkv_layout: str,
+    window_size: Tuple[int, int] = (-1, -1),
+    pad_between_seqs: bool = False,
+    context_parallel: bool = False,
+    deterministic: bool = False,
+    fp8: bool = False,
+    fp8_meta: Optional[Dict[str, Any]] = None,
+    is_training: bool = True,
+    inference_params: Optional[InferenceParams] = None,
+) -> Tuple[List, List]:
+    """Check for all available attention backends that support a model configuration"""
+    os.environ["NVTE_FLASH_ATTN"] = "1"
+    os.environ["NVTE_FUSED_ATTN"] = "1"
+    os.environ["NVTE_UNFUSED_ATTN"] = "1"
+    _attention_backends["backend_selection_requires_update"] = True
+
+    alibi_slopes_shape = None
+    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
+        if config.bias_shape == "1hss":
+            alibi_slopes_shape = [config.num_heads]
+        if config.bias_shape == "bhss":
+            alibi_slopes_shape = [config.batch_size, config.num_heads]
+
+    core_attention_bias_shape = (
+        config.bias_shape if config.attn_bias_type == "post_scale_bias" else None
+    )
+    core_attention_bias_requires_grad = False
+    # d=256 is supported by cuDNN 9.0+ for inference but not training
+    if (
+        config.attn_bias_type == "post_scale_bias"
+        and config.head_dim_qk <= 128
+        and config.head_dim_v <= 128
+    ):
+        core_attention_bias_requires_grad = True
+
+    fused_attn_backends = []
+    available_backends = None
+    flash_attention_backend = None
+    fused_attention_backend = None
+
+    def test():
+        attention_params = AttentionParams(
+            qkv_dtype=qkv_dtype,
+            qkv_layout=qkv_layout,
+            batch_size=config.batch_size,
+            num_heads=config.num_heads,
+            num_gqa_groups=config.num_gqa_groups,
+            max_seqlen_q=config.max_seqlen_q,
+            max_seqlen_kv=config.max_seqlen_kv,
+            head_dim_qk=config.head_dim_qk,
+            head_dim_v=config.head_dim_v,
+            attn_mask_type=config.attn_mask_type,
+            window_size=window_size,
+            alibi_slopes_shape=alibi_slopes_shape,
+            core_attention_bias_type=config.attn_bias_type,
+            core_attention_bias_shape=core_attention_bias_shape,
+            core_attention_bias_requires_grad=core_attention_bias_requires_grad,
+            pad_between_seqs=pad_between_seqs,
+            attention_dropout=config.dropout_p,
+            context_parallel=context_parallel,
+            deterministic=deterministic,
+            fp8=fp8,
+            fp8_meta=fp8_meta,
+            is_training=is_training,
+            inference_params=inference_params,
+        )
+        (
+            use_flash_attention,
+            use_fused_attention,
+            flash_attention_backend,
+            fused_attention_backend,
+            use_unfused_attention,
+            available_backends,
+        ) = get_attention_backend(attention_params)
+        # Set attention.py _attention_backends var using return value
+        # from get_attention_backend()
+        _attention_backends["use_flash_attention"] = use_flash_attention
+        _attention_backends["use_fused_attention"] = use_fused_attention
+        _attention_backends["flash_attention_backend"] = flash_attention_backend
+        _attention_backends["fused_attention_backend"] = fused_attention_backend
+        _attention_backends["use_unfused_attention"] = use_unfused_attention
+        _attention_backends["backend_selection_requires_update"] = False
+        return available_backends, flash_attention_backend, fused_attention_backend
+
+    backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
+    if AttentionLogging._is_logging_setup is False:
+        AttentionLogging.setup_logging()
+    with logging_context(highest_level=AttentionLogging._log_level):
+        for i in range(3):
+            os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
+            _attention_backends["backend_selection_requires_update"] = True
+            available_backends, flash_attention_backend, fused_attention_backend = test()
+            if fused_attention_backend == FusedAttnBackend[backends[i]]:
+                fused_attn_backends.append(fused_attention_backend)
+    return available_backends, flash_attention_backend, fused_attn_backends
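The helper above probes the three fused sub-backends by sweeping NVTE_FUSED_ATTN_BACKEND. Forcing a single one by hand follows the same pattern (a sketch; assumes the same imports as the helper):

import os

# 0: F16_max512_seqlen, 1: F16_arbitrary_seqlen, 2: FP8
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "2"
_attention_backends["backend_selection_requires_update"] = True
# The next backend selection only reports the FP8 fused backend
# if cuDNN actually supports it for the given inputs.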
@@ -183,7 +183,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
          attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
          attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) &&
        (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) &&
-       !requires_64bit_ragged_offset) {
+       !requires_64bit_ragged_offset &&
+       // 9.10.0: known bugs with SDPA FP8
+       (cudnn_runtime_version != 91000)) {
      if (cudnn_runtime_version >= 8900) {
        backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
      } else {
@@ -239,10 +241,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
         // 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1
         (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 &&
          layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) ||
-        // 9.10: any head_dim + any arch + fprop + paged
-        // 9.10: any head_dim + any arch + fprop + non_paged + sq > 1
-        // 9.10: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM}
-        (!is_training && cudnn_runtime_version >= 91000 &&
+        // 9.10.2: any head_dim + any arch + fprop + paged
+        // 9.10.2: any head_dim + any arch + fprop + non_paged + sq > 1
+        // 9.10.2: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM}
+        (!is_training && cudnn_runtime_version >= 91002 &&
         (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD || max_seqlen_q > 1 ||
          (max_seqlen_q == 1 && attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK &&
           attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) ||
@@ -358,7 +360,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
            max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
            dropout == 0.0)))) &&
        // check 64-bit ragged offset support
-       (supported_ragged_offset_size)) {
+       (supported_ragged_offset_size) &&
+       // 9.10.0/9.10.1: known bugs with SDPA F16
+       (cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001)) {
      flag_arb = true;
    }
    if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) {
...
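The cuDNN 9.x runtime version is encoded as major*10000 + minor*100 + patch, so the constants 91000, 91001, and 91002 above correspond to 9.10.0, 9.10.1, and 9.10.2 (the diff's own comments pair them this way). A sketch of the two gates, transliterated to Python for readability:

def cudnn_fp8_sdpa_ok(cudnn_runtime_version: int) -> bool:
    # 9.10.0: known bugs with SDPA FP8
    return cudnn_runtime_version != 91000

def cudnn_f16_arbitrary_seqlen_ok(cudnn_runtime_version: int) -> bool:
    # 9.10.0/9.10.1: known bugs with SDPA F16
    return cudnn_runtime_version not in (91000, 91001)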