Unverified commit 0d802283, authored by Charlene Yang, committed by GitHub

[Common] Skip cuDNN 9.10.0/9.10.1 due to bugs (#1937)

* exclude 9.10.0/.1 for certain configs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix kv_channels
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add get_backend to tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add init files
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix numerics and cuda graph tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix jax tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove prints
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor changes after renaming
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix import structure and rename get_attention_backends
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix docs and benchmarks
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix get backend calls
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "fix get backend calls"

This reverts commit 653cbb51c697bc2f975416bb3aac1d85f76c36dc.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "fix docs and benchmarks"

This reverts commit 98cd52e04ff7c53e26b412195f5744e39f7ed0e9.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix docs, benchmarks and pre-commit ci
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix dpa/mha flash attn selection
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix rng states
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix ModelConfig
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix backend selection on Ampere
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix issues from last merge
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Update tests/pytorch/utils.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove initialization of rng_states to None
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* redefine ModelConfig
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix ModelConfig
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix seed for CP tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Update tests/pytorch/test_sanity.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* move fixture from utils to individual tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent ab5cc407
......@@ -9,11 +9,11 @@ import numpy as np
import torch
import nvtx
import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import (
from tests.pytorch.utils import (
ModelConfig,
_get_attention_backends,
_run_dot_product_attention,
get_available_attention_backends,
)
from tests.pytorch.attention.test_attention import _run_dot_product_attention
pd.set_option("display.precision", 4)
......@@ -197,7 +197,7 @@ def main():
)
for model in model_configs.keys():
config = model_configs[model]
available_backends, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
......
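Editor's note, not part of the diff: a minimal sketch of how a benchmark script queries backend availability after this rename. It assumes the TransformerEngine repo root is on sys.path and that the new ModelConfig takes (batch_size, max_seqlen_q, num_heads, head_dim_qk) positionally, which is inferred from the updated call sites (see the ModelConfig sketch further down). The renamed helper returns three values instead of two; the middle one (the flash attention backend) is discarded with `_` in the benchmarks above.

import torch
from tests.pytorch.utils import ModelConfig, get_available_attention_backends

# Hypothetical config: batch 2, seqlen 2048, 16 heads, head_dim 64
config = ModelConfig(2, 2048, 16, 64)

available_backends, _, fused_attn_backends = get_available_attention_backends(
    config,
    qkv_dtype=torch.bfloat16,
    qkv_layout="bshd_bshd_bshd",
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends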
......@@ -5,7 +5,7 @@
import os
import torch
from typing import Tuple
from tests.pytorch.fused_attn.test_fused_attn import ModelConfig
from tests.pytorch.utils import ModelConfig
from transformer_engine.pytorch.attention import DotProductAttention
# Initialize RNG state
......
......@@ -375,7 +375,7 @@
"\n",
"Our [unit tests](https://github.com/NVIDIA/TransformerEngine/tree/main/tests) demonstrate the use of Transformer Engine dot product attention APIs. Users are encouraged to use them as a template when integrating Transformer Engine to their ML workflows.\n",
"\n",
"For example, in PyTorch, [test_dot_product_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts."
"For example, in PyTorch, [test_dot_product_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts."
]
},
{
......@@ -394,10 +394,10 @@
"| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | Yes (PyTorch only) | No | Yes |\n",
"\n",
"Some unit tests are provided to serve as a starting point for integrating such features into users' models. For example,\n",
"- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
"- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
"- Multi-Latent Attention: [test_dpa_mla](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
"- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py)"
"- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n",
"- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n",
"- Multi-Latent Attention: [test_dpa_mla](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n",
"- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention_with_cp.py)"
]
},
{
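Editor's note, not taken from the notebook: the cell above points at unit tests for sliding window attention and MQA/GQA; the following is a hedged sketch of exercising both features through transformer_engine.pytorch.DotProductAttention. Tensor shapes assume the default "sbhd" qkv_format, and the keyword names should be checked against the installed release.

import torch
from transformer_engine.pytorch import DotProductAttention

seqlen, batch, heads, kv_groups, head_dim = 2048, 2, 16, 4, 64

attn = DotProductAttention(
    num_attention_heads=heads,
    kv_channels=head_dim,
    num_gqa_groups=kv_groups,   # GQA: fewer KV heads than query heads
    attn_mask_type="causal",
    window_size=(128, 0),       # sliding window over the previous 128 tokens
)

# q: [sq, b, h, d]; k/v: [skv, b, hg, d] for the default "sbhd" format
q = torch.randn(seqlen, batch, heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(seqlen, batch, kv_groups, head_dim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(seqlen, batch, kv_groups, head_dim, dtype=torch.bfloat16, device="cuda")

out = attn(q, k, v)  # [sq, b, h * d]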
......@@ -458,7 +458,7 @@
" </tr>\n",
"</table>\n",
"\n",
"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
"\n",
"<div class=\"alert alert-info\">\n",
"<b>Note</b>\n",
......@@ -548,7 +548,7 @@
"id": "dda4a589",
"metadata": {},
"source": [
"Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py).\n",
"Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py).\n",
"\n",
"### 3.3 Attention Bias\n",
"\n",
......@@ -594,7 +594,7 @@
"\n",
"The framework-native backends do not explicitly support `ALiBi`, but users can convert `ALiBi` to a regular `post_scale_bias` bias to achieve the same effect. In PyTorch, this utility function, `transformer_engine.pytorch.attention.get_alibi`, can be used to help with the conversion.\n",
"\n",
"More examples of how to use the various attention biases are at [test_dpa_bias](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)."
"More examples of how to use the various attention biases are at [test_dpa_bias](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)."
]
},
{
......@@ -612,7 +612,7 @@
"\n",
"- `DelayedScaling.fp8_mha=True (default=False)`: This option, on top of `fp8_dpa=True`, removes the casting operations at the beginning and end of the `FusedAttention` module. This feature is experimental. \n",
"\n",
"Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`."
"Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`."
]
}
],
......
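Editor's note, again as illustration only and not part of the notebook diff: a hedged sketch of the FP8 attention options the last cell describes, using the DelayedScaling.fp8_dpa/fp8_mha flags and the NVTE_FP8_DPA_BWD variable it mentions. Whether the FP8 backend is actually selected depends on the GPU (Hopper or newer) and the cuDNN version.

import os
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

os.environ["NVTE_FP8_DPA_BWD"] = "0"  # optional: FP8 forward, higher-precision backward

fp8_recipe = DelayedScaling(
    fp8_format=Format.HYBRID,
    fp8_dpa=True,   # cast Q/K/V to FP8 inside DotProductAttention
    fp8_mha=False,  # keep the casts around the MHA module (True is experimental)
)

attn = te.DotProductAttention(num_attention_heads=16, kv_channels=64)
q = torch.randn(512, 2, 16, 64, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = attn(q, k, v)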
......@@ -9,11 +9,11 @@ import numpy as np
import torch
import nvtx
import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import (
from tests.pytorch.utils import (
ModelConfig,
_get_attention_backends,
_run_dot_product_attention,
get_available_attention_backends,
)
from tests.pytorch.attention.test_attention import _run_dot_product_attention
# data type
dtype = torch.bfloat16
......@@ -90,7 +90,7 @@ def main():
models = ["test_0"]
for model in models:
config = model_configs[model]
available_backends, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
......
......@@ -45,8 +45,8 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
......
......@@ -28,7 +28,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
......
......@@ -41,6 +41,6 @@ do
fi
# Run tests
NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py
done
......@@ -372,7 +372,7 @@ class FusedAttnRunner:
self.head_dim_v,
(-1, -1) if self.window_size is None else self.window_size,
).get_fused_attn_backend()
if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
if self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
pytest.skip("Unsupported inputs combination or device compute capability.")
if (
......
......@@ -13,7 +13,7 @@ from transformer_engine.pytorch.attention.dot_product_attention.context_parallel
get_cu_seqlens_on_cp_rank,
)
import transformer_engine_torch as tex
from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
from transformer_engine.common.recipe import DelayedScaling
......
......@@ -4,8 +4,9 @@
import logging
import math
import os
import sys
import pathlib
from typing import Any, Dict, List, Tuple, Union, Optional
from contextlib import contextmanager
import pytest
import torch
......@@ -21,7 +22,6 @@ from transformer_engine.pytorch.attention.dot_product_attention.utils import (
FlashAttentionUtils,
get_attention_backend,
check_set_window_size,
AttentionParams,
)
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
......@@ -48,21 +48,22 @@ from transformer_engine.pytorch.tensor.quantized_tensor import (
restore_from_saved,
)
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import (
reset_rng_states,
ModelConfig,
dtype_tols,
logging_context,
get_available_attention_backends,
)
# Only run FP8 tests on H100
fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()
# Initialize RNG state
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
def reset_rng_states() -> None:
"""Revert back to initial RNG state"""
torch.set_rng_state(_cpu_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)
# Reset RNG states
reset_rng_states()
@pytest.fixture(autouse=True)
......@@ -71,170 +72,20 @@ def reset_global_fp8_state():
fp8.FP8GlobalStateManager.reset()
class ModelConfig:
def __init__(
self,
batch_size: int,
num_heads: int,
num_gqa_groups: int,
head_dim_qk: int,
max_seqlen_q: int,
max_seqlen_kv: int,
dropout_p: float,
attn_mask_type: str,
attn_bias_type: str,
head_dim_v: int = None,
alibi_type: str = "none",
num_layers: int = 1,
bias_shape: str = "1hss",
window_size: Tuple[int, int] = (-1, -1),
total_requests: int = None,
max_ctx_len: int = None,
):
self.batch_size = batch_size
self.num_heads = num_heads
self.num_gqa_groups = num_gqa_groups
self.head_dim_qk = head_dim_qk
self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v
self.hidden_size = num_heads * head_dim_qk
self.hidden_size_kv = num_gqa_groups * self.head_dim_v
self.max_seqlen_q = max_seqlen_q
self.max_seqlen_kv = max_seqlen_kv
self.dropout_p = dropout_p
self.attn_mask_type = attn_mask_type
self.attn_bias_type = attn_bias_type
self.alibi_type = alibi_type
self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
self.num_layers = num_layers
self.bias_shape = bias_shape
self.window_size = window_size
self.total_requests = total_requests
self.max_ctx_len = max_ctx_len
@contextmanager
def logging_context(highest_level=logging.WARNING):
previous_level = logging.root.manager.disable
logging.disable(highest_level)
try:
yield
finally:
logging.disable(previous_level)
def _get_attention_backends(
config: ModelConfig,
qkv_dtype: torch.dtype,
qkv_layout: str,
window_size: Tuple[int, int] = (-1, -1),
pad_between_seqs: bool = False,
context_parallel: bool = False,
deterministic: bool = False,
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
is_training: bool = True,
inference_params: Optional[InferenceParams] = None,
) -> Tuple[List, List]:
"""Check if what attention backends support a model configuration"""
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
alibi_slopes_shape = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
if config.bias_shape == "1hss":
alibi_slopes_shape = [config.num_heads]
if config.bias_shape == "bhss":
alibi_slopes_shape = [config.batch_size, config.num_heads]
core_attention_bias_shape = (
config.bias_shape if config.attn_bias_type == "post_scale_bias" else None
)
core_attention_bias_requires_grad = False
# d=256 is supported by cuDNN 9.0+ for inference but not training
if (
config.attn_bias_type == "post_scale_bias"
and config.head_dim_qk <= 128
and config.head_dim_v <= 128
):
core_attention_bias_requires_grad = True
fused_attn_backends = []
available_backends = None
flash_attention_backend = None
fused_attention_backend = None
def test():
attention_params = AttentionParams(
qkv_dtype=qkv_dtype,
qkv_layout=qkv_layout,
batch_size=config.batch_size,
num_heads=config.num_heads,
num_gqa_groups=config.num_gqa_groups,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
head_dim_qk=config.head_dim_qk,
head_dim_v=config.head_dim_v,
attn_mask_type=config.attn_mask_type,
window_size=window_size,
alibi_slopes_shape=alibi_slopes_shape,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias_shape=core_attention_bias_shape,
core_attention_bias_requires_grad=core_attention_bias_requires_grad,
pad_between_seqs=pad_between_seqs,
attention_dropout=config.dropout_p,
context_parallel=context_parallel,
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
is_training=is_training,
inference_params=inference_params,
)
(
use_flash_attention,
use_fused_attention,
flash_attention_backend,
fused_attention_backend,
use_unfused_attention,
available_backends,
) = get_attention_backend(attention_params)
# Set attention.py _attention_backends var using return value
# from get_attention_backend()
_attention_backends["use_flash_attention"] = use_flash_attention
_attention_backends["use_fused_attention"] = use_fused_attention
_attention_backends["flash_attention_backend"] = flash_attention_backend
_attention_backends["fused_attention_backend"] = fused_attention_backend
_attention_backends["use_unfused_attention"] = use_unfused_attention
_attention_backends["backend_selection_requires_update"] = False
return available_backends, flash_attention_backend, fused_attention_backend
backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
with logging_context():
for i in range(3):
os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
_attention_backends["backend_selection_requires_update"] = True
available_backends, flash_attention_backend, fused_attention_backend = test()
if fused_attention_backend == FusedAttnBackend[backends[i]]:
fused_attn_backends.append(fused_attention_backend)
return available_backends, flash_attention_backend, fused_attn_backends
model_configs_base = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
"base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
"base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
"base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_4_0": ModelConfig(8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_4_1": ModelConfig(8, 16, 16, 192, 128, 2048, 0.0, "no_mask", "no_bias"),
"base_5_0": ModelConfig(8, 16, 16, 512, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_5_1": ModelConfig(8, 16, 16, 512, 128, 2048, 0.0, "no_mask", "no_bias"),
"base_6_0": ModelConfig(8, 16, 16, 1024, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_6_1": ModelConfig(8, 16, 16, 1024, 128, 2048, 0.0, "no_mask", "no_bias"),
"base_1_0": ModelConfig(8, 128, 16, 64),
"base_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256),
"base_2_0": ModelConfig(2, 2048, 24, 128),
"base_2_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
"base_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048),
"base_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048),
"base_4_0": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048),
"base_4_1": ModelConfig(8, 128, 16, 192, max_seqlen_kv=2048),
"base_5_0": ModelConfig(8, 1, 16, 512, max_seqlen_kv=2048),
"base_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048),
"base_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
"base_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048),
}
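Editor's note: the compacted configs above rely on the ModelConfig that this PR moves into tests/pytorch/utils.py, a file not shown in this diff. As a hedged reconstruction inferred from the call sites (e.g. ModelConfig(8, 128, 16, 64) reads as batch 8, seqlen 128, 16 heads, head_dim 64), the new signature presumably looks roughly like:

from typing import Tuple

class ModelConfig:
    # Reconstruction for readability only; the argument order beyond the first
    # four positional parameters and the exact set of fields are assumptions.
    def __init__(
        self,
        batch_size: int,
        max_seqlen_q: int,
        num_heads: int,
        head_dim_qk: int,
        max_seqlen_kv: int = None,   # defaults to max_seqlen_q
        num_gqa_groups: int = None,  # defaults to num_heads (no GQA)
        head_dim_v: int = None,      # defaults to head_dim_qk
        dropout_p: float = 0.0,
        attn_mask_type: str = "no_mask",
        attn_bias_type: str = "no_bias",
        alibi_type: str = "none",
        bias_shape: str = "1hss",
        window_size: Tuple[int, int] = (-1, -1),
        num_layers: int = 1,
    ):
        self.batch_size = batch_size
        self.max_seqlen_q = max_seqlen_q
        self.max_seqlen_kv = max_seqlen_q if max_seqlen_kv is None else max_seqlen_kv
        self.num_heads = num_heads
        self.num_gqa_groups = num_heads if num_gqa_groups is None else num_gqa_groups
        self.head_dim_qk = head_dim_qk
        self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v
        self.hidden_size = num_heads * head_dim_qk
        self.hidden_size_kv = self.num_gqa_groups * self.head_dim_v
        self.dropout_p = dropout_p
        self.attn_mask_type = attn_mask_type
        self.attn_bias_type = attn_bias_type
        self.alibi_type = alibi_type
        self.bias_shape = bias_shape
        self.window_size = window_size
        self.num_layers = num_layers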
......@@ -278,7 +129,7 @@ def test_dot_product_attention(
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
is_training = True
available_backends, _, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
......@@ -289,7 +140,7 @@ def test_dot_product_attention(
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported:
is_training = False
available_backends, _, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
......@@ -413,33 +264,19 @@ def test_dpa_checkpoint(dtype, model_configs, model):
model_configs_mla = {
# test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
"mla_1_0": ModelConfig(
8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias", head_dim_v=128
), # self , 0
"mla_1_1": ModelConfig(
4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128
), # cross, 0
"mla_1_2": ModelConfig(
4, 16, 16, 192, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128
), # cross, 0
"mla_2_0": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64
), # self , 1
"mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), # self , 0
"mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), # cross, 0
"mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), # cross, 0
"mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), # self , 1
"mla_2_1": ModelConfig(
1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64
1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
), # cross, 1
"mla_2_2": ModelConfig(
1, 24, 24, 192, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=128
1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
), # cross, 1
"mla_3_0": ModelConfig(
8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=64
), # inference
"mla_3_1": ModelConfig(
8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128
), # inference
"mla_3_2": ModelConfig(
8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128
), # inference
"mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), # inference
"mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference
"mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference
}
......@@ -454,40 +291,46 @@ def test_dpa_mla(dtype, model_configs, model):
model_configs_mask = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"mask_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"mask_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"mask_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"mask_2_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_2_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_2_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
"mask_3_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"mask_3_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"mask_3_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"mask_4_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"mask_4_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"mask_4_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"mask_5_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
"mask_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
"mask_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal"),
"mask_1_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
"mask_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
"mask_2_1": ModelConfig(
2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal_bottom_right"
),
"mask_2_2": ModelConfig(
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"
),
"mask_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
"mask_3_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
"mask_3_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
"mask_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
"mask_4_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"),
"mask_4_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
"mask_5_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
"mask_5_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right"
),
"mask_5_2": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
),
"mask_6_0": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal"),
"mask_6_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal"),
"mask_7_0": ModelConfig(
2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"
),
"mask_7_1": ModelConfig(
2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"
),
"mask_6_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"),
"mask_6_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"),
"mask_7_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_7_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_8_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding", "no_bias"),
"mask_8_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding", "no_bias"),
"mask_9_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding_causal", "no_bias"),
"mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding_causal", "no_bias"),
"mask_8_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding"),
"mask_8_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding"),
"mask_9_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
"mask_9_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
"mask_10_0": ModelConfig(
2, 24, 24, 128, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"
),
"mask_10_1": ModelConfig(
2, 16, 16, 256, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"
),
}
......@@ -503,44 +346,102 @@ def test_dpa_mask(dtype, model_configs, model):
model_configs_bias = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
"bias_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "post_scale_bias"),
"bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias"),
"bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "post_scale_bias"),
"bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "alibi"), # skipped
"bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "alibi"), # skipped
"bias_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"), # skipped
"bias_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "post_scale_bias"), # skipped
"bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"),
"bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"),
"bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"),
"bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"),
"bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"), # skipped
"bias_1_5": ModelConfig(
2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi"
), # skipped
"bias_2_0": ModelConfig(
4, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"
), # skipped
"bias_2_1": ModelConfig(
2,
128,
16,
64,
max_seqlen_kv=256,
attn_mask_type="padding",
attn_bias_type="post_scale_bias",
), # skipped
"bias_2_2": ModelConfig(
4, 24, 24, 128, 2048, 2048, 0.0, "padding", "post_scale_bias"
4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="post_scale_bias"
), # skipped
"bias_2_3": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding", "post_scale_bias"
2,
2048,
24,
128,
max_seqlen_kv=4096,
attn_mask_type="padding",
attn_bias_type="post_scale_bias",
), # skipped
"bias_2_4": ModelConfig(
4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi"
), # skipped
"bias_2_5": ModelConfig(
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="alibi"
), # skipped
"bias_2_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding", "alibi"), # skipped
"bias_2_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "alibi"), # skipped
"bias_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
"bias_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "post_scale_bias"),
"bias_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
"bias_3_0": ModelConfig(
4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
),
"bias_3_1": ModelConfig(
2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="causal", attn_bias_type="post_scale_bias"
),
"bias_3_2": ModelConfig(
4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
),
"bias_3_3": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "causal", "post_scale_bias"
2,
2048,
24,
128,
max_seqlen_kv=4096,
attn_mask_type="causal",
attn_bias_type="post_scale_bias",
), # skipped
"bias_3_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi"),
"bias_3_5": ModelConfig(
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="alibi"
), # skipped
"bias_3_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi"),
"bias_3_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "alibi"), # skipped
"bias_4_0": ModelConfig(
4, 16, 16, 64, 128, 128, 0.0, "padding_causal", "post_scale_bias"
4, 128, 16, 64, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
), # skipped
"bias_4_1": ModelConfig(
2, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"
2,
128,
16,
64,
max_seqlen_kv=256,
attn_mask_type="padding_causal",
attn_bias_type="post_scale_bias",
), # skipped
"bias_4_2": ModelConfig(
4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "post_scale_bias"
4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
), # skipped
"bias_4_3": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"
2,
2048,
24,
128,
max_seqlen_kv=4096,
attn_mask_type="padding_causal",
attn_bias_type="post_scale_bias",
), # skipped
"bias_4_4": ModelConfig(
4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="alibi"
), # skipped
"bias_4_5": ModelConfig(
2,
2048,
24,
128,
max_seqlen_kv=4096,
attn_mask_type="padding_causal",
attn_bias_type="alibi",
), # skipped
"bias_4_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "alibi"), # skipped
"bias_4_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "alibi"), # skipped
}
......@@ -555,33 +456,29 @@ def test_dpa_bias(dtype, model_configs, model):
model_configs_bias_shapes = {
# test: b, h, hg, d, sq, skv, p,
"bias_1_0": ModelConfig(
"bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="11ss"),
"bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"),
"bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"),
"bias_1_3": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"),
"bias_1_4": ModelConfig(
4,
16,
16,
64,
128,
2048,
24,
128,
0.0,
# mask, bias, bias_shape,
"no_mask",
"post_scale_bias",
bias_shape="11ss",
),
"bias_1_1": ModelConfig(
2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias", bias_shape="1hss"
),
"bias_1_2": ModelConfig(
4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="b1ss"
),
"bias_1_3": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="bhss"
),
"bias_1_4": ModelConfig(
4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="1hss", alibi_type="custom"
attn_mask_type="causal",
attn_bias_type="alibi",
bias_shape="1hss",
alibi_type="custom",
),
"bias_1_5": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="bhss", alibi_type="custom"
2,
2048,
24,
128,
attn_mask_type="causal",
attn_bias_type="alibi",
bias_shape="bhss",
alibi_type="custom",
),
}
......@@ -597,29 +494,31 @@ def test_dpa_bias_shapes(dtype, model_configs, model):
model_configs_swa = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"swa_1_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
"swa_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"swa_2_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"swa_3_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
"swa_4_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"swa_4_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"swa_4_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"swa_5_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"swa_5_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"swa_5_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"swa_6_1": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
"swa_1_1": ModelConfig(2, 2048, 16, 64),
"swa_1_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4),
"swa_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096),
"swa_2_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
"swa_2_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal"),
"swa_2_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
"swa_3_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
"swa_3_2": ModelConfig(
2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal_bottom_right"
),
"swa_3_3": ModelConfig(
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"
),
"swa_4_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
"swa_4_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding"),
"swa_4_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
"swa_5_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
"swa_5_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
"swa_5_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
"swa_6_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
"swa_6_2": ModelConfig(
2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal_bottom_right"
),
"swa_6_3": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
),
}
......@@ -635,13 +534,31 @@ def test_dpa_sliding_window(dtype, model_configs, model):
model_configs_alibi_slopes = {
# test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type
"alibi_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "alibi", alibi_type="vanilla"),
"alibi_1_1": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "causal", "alibi", alibi_type="vanilla"),
"alibi_1_0": ModelConfig(
2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla"
),
"alibi_1_1": ModelConfig(
1,
128,
16,
64,
max_seqlen_kv=256,
attn_mask_type="causal",
attn_bias_type="alibi",
alibi_type="vanilla",
),
"alibi_2_0": ModelConfig(
2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type="custom"
2, 1024, 24, 128, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="custom"
),
"alibi_2_1": ModelConfig(
1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type="custom"
1,
1024,
24,
128,
max_seqlen_kv=2048,
attn_mask_type="causal",
attn_bias_type="alibi",
alibi_type="custom",
),
}
......@@ -671,16 +588,38 @@ qkv_layouts = [
model_configs_layout = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"layout_0_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
"layout_0_1": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
"layout_0_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
"layout_0_3": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"),
"layout_1_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"layout_1_1": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
"layout_1_2": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"layout_1_3": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"),
"layout_2_0": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"),
"layout_2_1": ModelConfig(2, 24, 24, 256, 2048, 2048, 0.0, "causal", "post_scale_bias"),
"layout_0_0": ModelConfig(2, 128, 16, 64),
"layout_0_1": ModelConfig(
2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
),
"layout_0_2": ModelConfig(1, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
"layout_0_3": ModelConfig(
1,
128,
16,
64,
max_seqlen_kv=256,
attn_mask_type="padding_causal",
attn_bias_type="post_scale_bias",
),
"layout_1_0": ModelConfig(2, 2048, 24, 128),
"layout_1_1": ModelConfig(
2, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
),
"layout_1_2": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
"layout_1_3": ModelConfig(
1,
2048,
24,
128,
max_seqlen_kv=4096,
attn_mask_type="padding_causal",
attn_bias_type="post_scale_bias",
),
"layout_2_0": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048),
"layout_2_1": ModelConfig(
2, 2048, 24, 256, attn_mask_type="causal", attn_bias_type="post_scale_bias"
),
}
......@@ -697,55 +636,54 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
model_configs_layout_thd = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"layout_0_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"layout_0_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"layout_0_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"layout_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"layout_2_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
"layout_0_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
"layout_0_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
"layout_0_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
"layout_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
"layout_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"),
"layout_1_2": ModelConfig(
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"
),
"layout_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
"layout_2_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right"
),
"layout_2_2": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
),
"layout_3_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4)
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
),
"layout_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding", window_size=(4, 4)),
"layout_3_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4)
2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding", window_size=(4, 4)
),
"layout_3_2": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias", window_size=(4, 4)
),
"layout_4_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0)
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", window_size=(4, 4)
),
"layout_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal", window_size=(4, 0)),
"layout_4_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0)
2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal", window_size=(4, 0)
),
"layout_4_2": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias", window_size=(4, 0)
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal", window_size=(4, 0)
),
"layout_5_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0)
2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right", window_size=(4, 0)
),
"layout_5_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0)
2,
2048,
24,
128,
num_gqa_groups=1,
attn_mask_type="padding_causal_bottom_right",
window_size=(4, 0),
),
"layout_5_2": ModelConfig(
2,
24,
2048,
24,
128,
2048,
4096,
0.0,
"padding_causal_bottom_right",
"no_bias",
max_seqlen_kv=4096,
attn_mask_type="padding_causal_bottom_right",
window_size=(4, 0),
),
}
......@@ -1135,16 +1073,22 @@ def _run_dot_product_attention(
model_configs_te_layer = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
"te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
"te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"),
"te_1_3": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
"te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
"te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"te_2_3": ModelConfig(1, 16, 16, 64, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"),
"te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"),
"te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"),
"te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"),
"te_1_1": ModelConfig(
4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
),
"te_1_2": ModelConfig(
2, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"
),
"te_1_3": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
"te_2_0": ModelConfig(1, 2048, 16, 64, attn_mask_type="causal"),
"te_2_1": ModelConfig(2, 2048, 16, 64),
"te_2_2": ModelConfig(1, 2048, 16, 64, attn_mask_type="padding"),
"te_2_3": ModelConfig(
1, 2048, 16, 64, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
),
"te_3_0": ModelConfig(4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
"te_3_1": ModelConfig(4, 2048, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
}
......@@ -1168,7 +1112,7 @@ def test_transformer_layer(
# Test backend availability
is_training = True
available_backends, _, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=(
......@@ -1179,7 +1123,7 @@ def test_transformer_layer(
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported:
is_training = False
available_backends, _, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=(
......@@ -1492,20 +1436,164 @@ def _run_transformer_layer(
return out, inp.grad
model_configs_fp8_extra_state = {
"large": ModelConfig(2, 128, 4, 128, num_layers=1),
}
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_sanity_attention_extra_state(model, dtype):
config = model_configs_fp8_extra_state[model]
# Test backend availability
is_training = True
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=torch.float8_e4m3fn,
qkv_layout="sb3hd",
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported and not flash_attn_supported:
pytest.skip("No attention backend available.")
outputs = _run_attention_extra_state(dtype, config, checkpoint=False)
outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True)
outputs_checkpoint_v1_6 = _run_attention_extra_state(
dtype, config, mimic_v1_6=True, checkpoint=True
)
# Check that results match
tols = dtype_tols(dtype)
if dtype in (torch.float16, torch.bfloat16):
tols.update(dict(rtol=2e-2, atol=2e-3))
for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
torch.testing.assert_close(
test,
ref,
**tols,
)
for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)):
torch.testing.assert_close(
test,
ref,
**tols,
)
def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
steps = 10
path = "checkpoint.pt"
fp8_enabled = True
fp8_recipe = recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
fp8_dpa=fp8_enabled,
fp8_mha=False,
)
reset_rng_states()
hidden_states = torch.randn(
(config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
def get_model(dtype, config):
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.0,
attention_dropout=0.0,
fuse_qkv_params=True,
params_dtype=dtype,
device="cuda",
)
return block
block = get_model(dtype, config)
for i in range(steps // 2):
with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
output = block(hidden_states, None)
loss = output.sum()
loss.backward()
if checkpoint:
sd = block.state_dict()
if mimic_v1_6:
sd["self_attention.core_attention.fused_attention._extra_state"] = sd[
"self_attention.core_attention._extra_state"
]
del sd["self_attention.core_attention._extra_state"]
torch.save(sd, path)
param_grads = []
for p in block.parameters():
if p.requires_grad:
param_grads.append(p.grad.clone())
_cpu_rng_state_new = torch.get_rng_state()
_cuda_rng_state_new = torch.cuda.get_rng_state()
del block
block = get_model(dtype, config)
block.load_state_dict(torch.load(path, weights_only=False))
torch.set_rng_state(_cpu_rng_state_new)
torch.cuda.set_rng_state(_cuda_rng_state_new)
for p in block.parameters():
if p.requires_grad:
p.grad = param_grads.pop(0)
assert not param_grads, "Oops!"
for i in range((steps + 1) // 2):
with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
output = block(hidden_states, None)
loss = output.sum()
loss.backward()
torch.cuda.synchronize()
if os.path.exists(path):
os.remove(path)
outputs = [output, hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
model_configs_fp8_vs_f16 = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"fp8_9": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_10": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_11": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"),
"fp8_12": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_13": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),
"fp8_15": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"fp8_16": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"fp8_17": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "padding", "no_bias"),
"fp8_18": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"fp8_19": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"fp8_20": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "padding_causal", "no_bias"),
"fp8_9": ModelConfig(2, 2048, 16, 128),
"fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12),
"fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4),
"fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"),
"fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"),
"fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
"fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"),
"fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"),
"fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"),
"fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"),
"fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"),
"fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
}
param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
......@@ -1554,18 +1642,30 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
config = model_configs_fp8_vs_f16[model]
if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < (
9,
7,
0,
):
pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7")
if (
FlashAttentionUtils.v3_is_installed
and not is_training
and "padding" not in config.attn_mask_type
):
# Test backend availability
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=torch.float8_e4m3fn,
qkv_layout=qkv_format.replace("hd", "h3d"),
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
if not fp8_dpa_bwd:
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_format.replace("hd", "h3d"),
is_training=is_training,
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
pytest.skip("No attention backend available.")
if flash_attn_supported:
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
......@@ -1591,11 +1691,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
rtol = 5e-1
rmse_tol = 0.15
logging.debug("========== {:^25s} ==========".format("forward output"))
if (
FlashAttentionUtils.v3_is_installed
and not is_training
and "padding" not in config.attn_mask_type
):
if flash_attn_supported:
_error(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
......@@ -1768,23 +1864,34 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
# if get_device_compute_capability() >= (10, 0):
# config.dropout_p = 0.1
if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < (
9,
7,
0,
):
pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7")
if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
pytest.skip("qkv_layout not applicable for MQA/GQA")
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
if (
FlashAttentionUtils.v3_is_installed
and not is_training
and "padding" not in config.attn_mask_type
):
# Test backend availability
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=torch.float8_e4m3fn,
qkv_layout=qkv_layout,
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# Skip if only unfused backend is supported
if flash_attn_supported + fused_attn_supported < 1:
pytest.skip("No FP8 attention backend available.")
if not fp8_dpa_bwd:
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
is_training=is_training,
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
pytest.skip("No attention backend available.")
if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
pytest.skip("qkv_layout not applicable for MQA/GQA")
if flash_attn_supported:
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
......@@ -1813,11 +1920,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
rmse_tol = 0.11
bwd_names = ["dq", "dk", "dv"]
logging.debug("========== {:^25s} ==========".format("forward output"))
if (
FlashAttentionUtils.v3_is_installed
and not is_training
and "padding" not in config.attn_mask_type
):
if flash_attn_supported:
_error(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
......@@ -1991,14 +2094,14 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
model_configs_fp8 = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"fp8_1": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "no_mask", "no_bias"),
"fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
"fp8_3": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_4": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_5": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "causal", "no_bias"),
"fp8_6": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "causal", "no_bias"),
"fp8_7": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_8": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_1": ModelConfig(1, 512, 1, 64),
"fp8_2": ModelConfig(4, 512, 16, 64),
"fp8_3": ModelConfig(1, 2048, 1, 128),
"fp8_4": ModelConfig(2, 2048, 24, 128),
"fp8_5": ModelConfig(1, 512, 1, 64, attn_mask_type="causal"),
"fp8_6": ModelConfig(4, 512, 16, 64, attn_mask_type="causal"),
"fp8_7": ModelConfig(1, 2048, 1, 128, attn_mask_type="causal"),
"fp8_8": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
}
param_types_fp8 = [torch.float16, torch.bfloat16]
cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1"))
......@@ -2027,6 +2130,18 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
config = model_configs_fp8[model]
# Test backend availability
is_training = True
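# The custom FP8 MHA path packs Q/K/V into one tensor: cuDNN frontend v0 expects the
# ragged "t3hd" layout, while v1 expects the "bs3hd" layout used below.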
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=torch.float8_e4m3fn,
qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd",
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not (fused_attn_backends and unfused_attn_supported):
pytest.skip("Not enough backends to run this test with.")
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")
......
......@@ -4,6 +4,8 @@
import os
import subprocess
import sys
import pathlib
import pytest
import torch
......@@ -12,26 +14,28 @@ from transformer_engine.pytorch.utils import (
get_cudnn_version,
)
from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils
from test_fused_attn import ModelConfig
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import ModelConfig, get_available_attention_backends
# Initialize RNG state
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
model_configs_flash_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # MHA
"cp_1_3": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
), # MHA
"cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA
"cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA
"cp_1_3": ModelConfig(2, 4096, 12, 128, window_size=(512, 512)), # MHA
"cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA
"cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA
"cp_2_2": ModelConfig(
2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # GQA
"cp_2_3": ModelConfig(
2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0)
), # GQA
"cp_2_3": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, window_size=(512, 512)), # GQA
}
......@@ -43,7 +47,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
"--nproc-per-node=" + str(num_gpus_per_node),
]
te_path = os.getenv("TE_PATH", "/opt/transformerengine")
script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py")
script_path = os.path.join(te_path, "tests/pytorch/attention/run_attention_with_cp.py")
args.append(script_path)
for k, v in kwargs.items():
args.append(f"{k}={v}")
......@@ -93,32 +97,36 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
model_configs_fused_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA
"cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA
"cp_1_4": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA
"cp_1_2": ModelConfig(
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
), # MHA
"cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
"cp_2_3": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
"cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA
"cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA
"cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA
"cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA
"cp_2_2": ModelConfig(
2,
4096,
12,
128,
num_gqa_groups=2,
attn_mask_type="causal",
attn_bias_type="post_scale_bias",
), # GQA
"cp_2_3": ModelConfig(
2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias"
), # GQA
"cp_2_4": ModelConfig(
2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0)
), # GQA
"cp_3_0": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=64
), # MLA
"cp_3_1": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", head_dim_v=64
), # MLA
"cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA
"cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64), # MLA
"cp_3_2": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias", head_dim_v=64
), # MLA
"cp_3_3": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias", head_dim_v=64
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64
), # MLA
"cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64), # MLA
}
......@@ -175,6 +183,17 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
pytest.skip("MLA CP currently only support KV P2P!")
if dtype == "fp8" and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently does not support FP8 attention!")
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtypes[dtype],
qkv_layout="_".join([qkv_format] * 3),
window_size=config.window_size,
context_parallel=True,
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
pytest.skip("No attention backend available.")
subprocess.run(
get_bash_arguments(
......
......@@ -5,18 +5,14 @@
from collections import OrderedDict
from typing import List
import os
import sys
import pathlib
import logging
import math
import pytest
import torch
from test_fused_attn import (
ModelConfig,
reset_rng_states,
_get_attention_backends,
)
from torch.distributions import Exponential
from transformer_engine.pytorch import make_graphed_callables
from transformer_engine.common import recipe
......@@ -34,26 +30,25 @@ from transformer_engine.pytorch.utils import (
is_bf16_compatible,
)
# Initialize RNG state
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import (
ModelConfig,
reset_rng_states,
get_available_attention_backends,
)
# Reset RNG states
reset_rng_states()
param_types = [torch.float16]
if is_bf16_compatible():
param_types.append(torch.bfloat16)
model_configs_infer = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"infer_0": ModelConfig(
4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16
),
"infer_1": ModelConfig(
2, 16, 4, 256, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
),
# test: b, sq, hq, dqk,
"infer_0": ModelConfig(4, 64, 16, 128, total_requests=8, max_ctx_len=16),
"infer_1": ModelConfig(2, 66, 16, 256, num_gqa_groups=4, total_requests=6, max_ctx_len=16),
}
qkv_formats = ["bshd", "sbhd", "thd"]
......@@ -470,7 +465,7 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
qkv_layout = qkv_format + "_" + "_".join([inference_params_qkv_format] * 2)
if is_paged:
qkv_layout = "paged_kv_" + qkv_layout
available_backends, _, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
......
......@@ -10,6 +10,8 @@ import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from utils import ModelConfig, get_available_attention_backends
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
......@@ -22,10 +24,13 @@ fp8_recipes = [
recipe.DelayedScaling(),
]
SIZE = 512
NUM_HEADS = 8
NUM_LAYERS = 5
EPSILON = 0.1
model_config = {
"small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
}
SIZE = model_config["small"].hidden_size
NUM_HEADS = model_config["small"].num_heads
NUM_LAYERS = model_config["small"].num_layers
EPSILON = model_config["small"].eps
# Flash attention saves some internal tensor for the backward pass
# that cannot be offloaded to CPU.
......@@ -130,6 +135,18 @@ def test_cpu_offload(fp8_recipe, model_key) -> None:
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if model_key in ["multihead_attention", "transformer_layer"]:
available_backends, *_ = get_available_attention_backends(
model_config["small"],
qkv_dtype=torch.bfloat16,
qkv_layout="sbhd_sbhd_sbhd",
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
pytest.skip("Fused attention backend not available.")
os.environ["NVTE_FLASH_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
without_offloading = _measure_memory_between_forward_and_backward(
models_list, fp8_recipe, False
)
......
......@@ -23,7 +23,7 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common import recipe
from utils import ModelConfig, reset_rng_states
# Check if FP8 is supported.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
......@@ -32,27 +32,12 @@ fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
)
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
# Reset RNG states.
reset_rng_states()
# Record initial RNG state.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
@dataclass
class ModelConfig:
"""Data tensor dimensions within Transformer model"""
sequence_length: int
batch_size: int
hidden_size: int
num_heads: int
kv_channels: int
model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}
model_configs = {
"small": ModelConfig(32, 2, 2, 32),
}
fp8_recipes = [
recipe.DelayedScaling(),
......@@ -67,12 +52,6 @@ if is_bf16_compatible(): # bf16 requires sm_80 or higher
dtypes.append(torch.bfloat16)
def reset_rng_states() -> None:
"""Revert to initial RNG state."""
torch.set_rng_state(_cpu_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
......@@ -107,7 +86,7 @@ def generate_data(
"""Generate synthetic data."""
gen_func = torch.ones if warmup else torch.randn
return gen_func(
model_config.sequence_length,
model_config.max_seqlen_q,
model_config.batch_size,
model_config.hidden_size,
device="cuda",
......@@ -389,7 +368,7 @@ def generate_data_for_dot_product_attention(
gen_func = torch.ones if warmup else torch.randn
return [
gen_func(
model_config.sequence_length,
model_config.max_seqlen_q,
model_config.batch_size,
model_config.num_heads,
model_config.kv_channels,
......@@ -483,8 +462,8 @@ def _test_cuda_graphs_with_kwargs(
(
model_config.batch_size,
1,
model_config.sequence_length,
model_config.sequence_length,
model_config.max_seqlen_q,
model_config.max_seqlen_kv,
),
dtype=torch.bool,
device="cuda",
......@@ -510,8 +489,8 @@ def _test_cuda_graphs_with_kwargs(
(
model_config.batch_size,
1,
model_config.sequence_length,
model_config.sequence_length,
model_config.max_seqlen_q,
model_config.max_seqlen_kv,
),
dtype=torch.bool,
device="cuda",
......
......@@ -40,11 +40,13 @@ from transformer_engine.pytorch import (
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.pytorch.utils import get_device_compute_capability, get_cudnn_version
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from utils import ModelConfig, reset_rng_states, get_available_attention_backends
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
......@@ -56,33 +58,18 @@ fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
sm_80plus = get_device_compute_capability() >= (8, 0)
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Record initial RNG state from script run.
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
# Reset RNG states.
reset_rng_states()
torch._dynamo.config.recompile_limit = 16
class ModelConfig:
def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
self.hidden_size = hidden_size
self.eps = eps
self.num_attention_heads = num_attention_heads
self.embed = embed
self.num_layers = num_layers
self.seq_len = seq_len
model_configs = {
"small": ModelConfig(128, 1e-5, 8, 36, 4, 128),
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
"small": ModelConfig(1, 128, 8, 16, num_layers=4),
"126m": ModelConfig(1, 2048, 12, 64, num_layers=12),
}
model_configs_inference = {
# hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 256),
"126m": ModelConfig(1, 256, 12, 64, num_layers=12),
}
backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
......@@ -124,6 +111,18 @@ fp8_recipes = [
]
def is_fused_attn_available(
config: ModelConfig, dtype: torch.dtype, qkv_layout="bshd_bshd_bshd", is_training=True
):
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
is_training=is_training,
)
return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends
def get_causal_attn_mask(sq: int) -> torch.Tensor:
return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
......@@ -173,12 +172,6 @@ def assert_allclose(
raise AssertionError(msg)
def reset_rng_states() -> None:
"""revert back to initial RNG state."""
torch.set_rng_state(_cpu_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
......@@ -531,13 +524,13 @@ def _test_e2e_selective_recompute(
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
params_dtype=dtype,
......@@ -546,13 +539,13 @@ def _test_e2e_selective_recompute(
)
te_inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
(config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
te_out = block(
......@@ -626,13 +619,13 @@ def _test_e2e_full_recompute(
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
params_dtype=dtype,
......@@ -641,14 +634,14 @@ def _test_e2e_full_recompute(
)
te_inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
(config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=use_reentrant,
)
if use_reentrant:
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
if recompute:
......@@ -757,13 +750,13 @@ def _test_e2e_checkpointing_get_model(config, dtype):
return TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
params_dtype=dtype,
......@@ -775,7 +768,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
reset_rng_states()
te_inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
(config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
......@@ -805,14 +798,14 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
if p.requires_grad:
param_grads.append(p.grad.clone())
global _cpu_rng_state, _cuda_rng_state
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
del block
block = _test_e2e_checkpointing_get_model(config, dtype)
block.load_state_dict(torch.load(path, weights_only=False))
reset_rng_states()
torch.set_rng_state(_cpu_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)
for p in block.parameters():
if p.requires_grad:
......@@ -845,6 +838,8 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
@pytest.mark.parametrize("model", ["126m"])
def test_gpt_checkpointing(dtype, bs, model):
config = model_configs[model]
if not is_fused_attn_available(config, dtype):
pytest.skip("No attention backend available.")
outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
......@@ -865,13 +860,13 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
reset_rng_states()
inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
(config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad()
inp_attn_mask = get_causal_attn_mask(config.seq_len)
inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)
out = block(inp_hidden_states, attention_mask=inp_attn_mask)
loss = out.sum()
......@@ -891,11 +886,13 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
config = model_configs[model]
if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
pytest.skip("No attention backend available.")
te_gpt = TransformerLayer(
hidden_size=config.hidden_size,
ffn_hidden_size=4 * config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_attention_heads=config.num_heads,
layernorm_epsilon=config.eps,
attention_dropout=0.1,
hidden_dropout=0.1,
......@@ -910,7 +907,7 @@ def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
TorchGPT(
config.hidden_size,
config.eps,
config.num_attention_heads,
config.num_heads,
parallel_attention_mlp=parallel_attention_mlp,
)
.to(dtype=dtype)
......@@ -971,13 +968,13 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
reset_rng_states()
inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
(config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad()
inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None
inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) if mask_type == "causal" else None
forward_kwargs = {}
if te:
......@@ -1002,10 +999,12 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
@pytest.mark.parametrize("mask_type", mask_types)
def test_mha_accuracy(dtype, bs, model, mask_type):
config = model_configs[model]
if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
pytest.skip("No attention backend available.")
te_mha = MultiheadAttention(
config.hidden_size,
config.num_attention_heads,
config.num_heads,
fuse_qkv_params=True,
params_dtype=dtype,
qkv_weight_interleaved=False,
......@@ -1016,7 +1015,7 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
torch_mha = (
TorchMHA(
config.hidden_size,
config.num_attention_heads,
config.num_heads,
)
.to(dtype=dtype)
.cuda()
......@@ -1062,7 +1061,7 @@ def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False,
FP8GlobalStateManager.reset()
inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
(config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
......@@ -1094,11 +1093,12 @@ def _test_dpa_accuracy(block, bs, dtype, config):
reset_rng_states()
mask = torch.triu(
torch.ones(config.seq_len, config.seq_len, dtype=torch.bool, device="cuda"), diagonal=1
torch.ones(config.max_seqlen_q, config.max_seqlen_kv, dtype=torch.bool, device="cuda"),
diagonal=1,
)
query, key, value = [
torch.randn(
(config.seq_len, bs, config.num_attention_heads, config.embed),
(config.max_seqlen_q, bs, config.num_heads, config.kv_channels),
dtype=dtype,
device="cuda",
requires_grad=True,
......@@ -1127,8 +1127,8 @@ def test_dpa_accuracy(dtype, bs, model):
te_dpa = (
DotProductAttention(
config.num_attention_heads,
config.embed,
config.num_heads,
config.kv_channels,
attention_dropout=0.0, # disable dropout, FU uses rng differently
)
.to(dtype=dtype)
......@@ -1137,7 +1137,7 @@ def test_dpa_accuracy(dtype, bs, model):
torch_dpa = (
TorchDotProductAttention(
config.embed,
config.kv_channels,
0.0, # dropout
)
.to(dtype=dtype)
......@@ -1286,7 +1286,7 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
pytest.skip("DelayedScaling recipe is not supported with save_original_input")
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
......@@ -1726,7 +1726,7 @@ def _test_grouped_linear_accuracy(
FP8GlobalStateManager.reset()
inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
(config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
......@@ -1739,14 +1739,14 @@ def _test_grouped_linear_accuracy(
split_size = 16
if recipe.mxfp8():
split_size = 128
m = config.seq_len // split_size
m = config.max_seqlen_q // split_size
dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
dist.append(dist[-1])  # Duplicate the last cut point so one of the splits is empty
m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
m_splits = m_splits * split_size
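# Illustrative numbers: with max_seqlen_q=2048 and split_size=16, m=128; sorted cut
# points [40, 100] become [40, 100, 100], giving splits [40, 60, 0, 28] * 16 =
# [640, 960, 0, 448] tokens, which sum back to 2048 across num_gemms=4 GEMMs.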
assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
assert m_splits.sum() == config.max_seqlen_q and len(m_splits) == num_gemms
else:
m_splits = torch.tensor([config.seq_len])
m_splits = torch.tensor([config.max_seqlen_q])
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
if isinstance(block, GroupedLinear):
......@@ -1812,7 +1812,7 @@ def test_grouped_linear_accuracy(
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
......@@ -1916,7 +1916,7 @@ def test_grouped_linear_accuracy_save_original_input(
pytest.skip("DelayedScaling recipe is not supported with save_original_input")
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
......@@ -2064,14 +2064,14 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, r
FP8GlobalStateManager.reset()
inp_hidden_states = torch.randn(
(config.seq_len * bs, config.hidden_size),
(config.max_seqlen_q * bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad()
m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs)
m_splits = _generate_random_numbers(num_gemms, config.max_seqlen_q * bs)
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
if isinstance(block, TorchGroupedLinearWithPadding):
......@@ -2124,7 +2124,7 @@ def test_padding_grouped_linear_accuracy(
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
......@@ -2201,7 +2201,7 @@ def test_padding_grouped_linear_accuracy_save_original_input(
pytest.skip("DelayedScaling recipe is not supported with save_original_input")
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
......@@ -2258,9 +2258,11 @@ def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
# Placeholders used for graph capture.
static_input = torch.randn(
config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True
)
static_input = torch.randn(
config.max_seqlen_q, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True
)
static_target = torch.randn(
config.max_seqlen_q, bs, config.hidden_size, device="cuda", dtype=dtype
)
static_target = torch.randn(config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype)
real_input = torch.rand_like(static_input)
real_target = torch.rand_like(static_target)
......@@ -2324,7 +2326,7 @@ def test_gpt_cuda_graph(dtype, bs, model):
block_args = (
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
)
block_kwargs = dict(
layernorm_epsilon=config.eps,
......@@ -2332,7 +2334,7 @@ def test_gpt_cuda_graph(dtype, bs, model):
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
......@@ -2367,13 +2369,13 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
params_dtype=dtype,
......@@ -2382,13 +2384,13 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
)
te_inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
(config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)
with fp8_autocast(enabled=True, fp8_recipe=recipe):
te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
......@@ -2451,13 +2453,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
block_sbhd = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0,
attention_dropout=0,
kv_channels=config.embed,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
......@@ -2472,13 +2474,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
block_bshd = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0,
attention_dropout=0,
kv_channels=config.embed,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
......@@ -2490,13 +2492,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
block_thd = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0,
attention_dropout=0,
kv_channels=config.embed,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
......@@ -2511,15 +2513,15 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
assert torch.all(torch.eq(p1, p2) & torch.eq(p1, p3)), f"{n1}, {n2} and {n3} not identical"
x_sbhd = torch.randn(
(config.seq_len, bs, config.hidden_size),
(config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
x_bshd = x_sbhd.transpose(0, 1).contiguous()
x_thd = x_bshd.reshape(bs * config.seq_len, config.hidden_size).contiguous()
x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.seq_len
x_thd = x_bshd.reshape(bs * config.max_seqlen_q, config.hidden_size).contiguous()
x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.max_seqlen_q
# To make sure forward is also identical (just in case some module decides
# to act fancy)
......@@ -2546,165 +2548,15 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
x_thd,
cu_seqlens_q=x_thd_cumsum,
cu_seqlens_kv=x_thd_cumsum,
max_seqlen_q=config.seq_len,
max_seqlen_kv=config.seq_len,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
)
torch.testing.assert_close(
y_bshd,
y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),
)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model_key", model_configs_inference.keys())
@pytest.mark.parametrize("use_RoPE", all_boolean)
@pytest.mark.parametrize("input_format", input_formats_inference)
@pytest.mark.parametrize("module", module_inference)
@pytest.mark.parametrize("backend", backends_inference)
@pytest.mark.parametrize("is_paged", [False, True])
def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged):
reset_rng_states()
if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32:
pytest.skip("FusedAttention and FlashAttention do not support FP32")
if use_RoPE:
pytest.skip("KV cache does not support starting positions for RoPE")
if (
backend == "FusedAttention"
and get_device_compute_capability() == (8, 9)
and get_cudnn_version() < (9, 12, 0)
):
pytest.skip("Skip KV cache for sm89 and cuDNN < 9.12")
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
elif backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
elif backend == "UnfusedAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
config = model_configs_inference[model_key]
S = config.seq_len
B = bs
H = config.num_attention_heads
D = config.hidden_size
head_size = config.embed
layer_number = 1
# Limits the max size of KV-cache
B_max = B
S_max = S
if module == "TransformerLayer":
model = TransformerLayer(
hidden_size=D,
ffn_hidden_size=4 * D,
num_attention_heads=H,
attn_input_format=input_format,
self_attn_mask_type="causal",
enc_dec_attn_mask_type="causal",
layer_number=layer_number,
attention_dropout=0.0,
params_dtype=dtype,
device="cuda",
).eval()
else:
model = (
MultiheadAttention(
hidden_size=D,
num_attention_heads=H,
qkv_format=input_format,
layer_number=layer_number,
attention_dropout=0.0,
attn_mask_type="causal",
params_dtype=dtype,
)
.cuda()
.eval()
y_thd.reshape(bs, config.max_seqlen_q, config.hidden_size).contiguous(),
)
inference_params = InferenceParams(
max_batch_size=B_max,
max_sequence_length=S_max,
num_heads_kv=H,
head_dim_k=head_size,
dtype=dtype,
is_paged=is_paged,
total_num_pages=int(B_max * S_max / 256),
page_size=256,
)
rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda")
input = torch.randn((S, B, D), dtype=dtype, device="cuda")
if input_format == "bshd":
input = input.transpose(0, 1).contiguous()
incremental_output = torch.zeros_like(input)
# Generate output for the entire sequence
full_output = model(hidden_states=input, rotary_pos_emb=rotary_freqs if use_RoPE else None)
# Incrementally generate outputs using the KV cache
step_dict = OrderedDict(zip(list(range(B)), [1] * B))
for i in range(S):
inference_params.pre_step(step_dict)
if input_format == "sbhd":
incremental_input = input[i].view(1, B, D)
else:
incremental_input = input[:, i, :].view(B, 1, D)
seqlens_q = torch.ones(B, dtype=torch.int32, device="cuda")
cu_seqlens_q = torch.zeros(B + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv = cu_seqlens_q.clone()
mask_type = "padding"
kwargs = {}
if module == "TransformerLayer":
kwargs["self_attn_mask_type"] = mask_type
else:
kwargs["attn_mask_type"] = mask_type
line_output = model(
hidden_states=incremental_input,
inference_params=inference_params,
rotary_pos_emb=rotary_freqs if use_RoPE else None,
**kwargs,
max_seqlen_q=1,
max_seqlen_kv=S,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
)
if input_format == "sbhd":
incremental_output[i, :, :] = line_output.view(B, D)
else:
incremental_output[:, i, :] = line_output.view(B, D)
if module == "TransformerLayer":
atol = {
torch.float32: 5e-3,
torch.half: 5e-3,
torch.bfloat16: 5e-2,
}
else:
atol = {
torch.float32: 1e-3,
torch.half: 1e-3,
torch.bfloat16: 1e-2,
}
# Check if the fully generated output matches the one generated incrementally
assert_allclose(full_output, incremental_output, atol[dtype])
@pytest.mark.parametrize(
"shape",
......
......@@ -46,7 +46,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.distributed import checkpoint
from utils import dtype_tols
from utils import ModelConfig, dtype_tols
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
......@@ -59,8 +59,6 @@ mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available(
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))
......@@ -105,37 +103,22 @@ def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
return torch.min(amax_history, dim=0).values
def reset_rng_states() -> None:
"""revert back to initial RNG state."""
global _cpu_rng_state, _cuda_rng_state
torch.set_rng_state(_cpu_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)
@dataclass
class ModelConfig:
"""Transformer model configuration"""
num_layers: int
seq_len: int
batch_size: int
hidden_size: int
num_attention_heads: int
kv_channels: Optional[int] = None
def is_fp8_supported(self):
if self.seq_len * self.batch_size % 16:
return False
if self.hidden_size % 16:
return False
return True
def is_fp8_supported(config: ModelConfig):
if (
config.max_seqlen_q * config.batch_size % 16
or config.max_seqlen_kv * config.batch_size % 16
):
return False
if config.hidden_size % 16 or config.hidden_size_kv % 16:
return False
return True
model_configs = {
"126m": ModelConfig(12, 2048, 2, 768, 12),
"small": ModelConfig(2, 32, 2, 64, 2),
"weird": ModelConfig(2, 37, 3, 69, 3),
"large": ModelConfig(1, 128, 2, 512, 4, 128),
"126m": ModelConfig(2, 2048, 12, 64, num_layers=12),
"small": ModelConfig(2, 32, 2, 32, num_layers=2),
"weird": ModelConfig(3, 37, 3, 23, num_layers=2),
"large": ModelConfig(2, 128, 4, 128, num_layers=1),
}
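A quick illustration of the new helper (not part of the PR, and assuming ModelConfig derives hidden_size as num_heads * head_dim_qk, the quantity the replaced dataclass carried explicitly):
# "weird": 3 * 37 = 111 tokens and 3 * 23 = 69 channels are not multiples of 16, so the
# FP8 GEMM shape requirements are not met and the FP8 variants of the tests skip it.
assert not is_fp8_supported(model_configs["weird"])
# "126m": 2 * 2048 tokens and 12 * 64 = 768 channels are multiples of 16, so FP8 runs.
assert is_fp8_supported(model_configs["126m"])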
fp8_recipes = [
......@@ -184,7 +167,7 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
# Placeholders used for capture.
static_input = torch.randn(
config.seq_len,
config.max_seqlen_q,
config.batch_size,
config.hidden_size,
device="cuda",
......@@ -192,7 +175,7 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
requires_grad=True,
)
static_target = torch.randn(
config.seq_len, config.batch_size, config.hidden_size, device="cuda", dtype=dtype
config.max_seqlen_q, config.batch_size, config.hidden_size, device="cuda", dtype=dtype
)
real_input = torch.rand_like(static_input)
......@@ -236,7 +219,7 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
(config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=torch.float32,
device="cuda",
requires_grad=True,
......@@ -244,7 +227,7 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = torch.randint(
2,
(1, 1, config.seq_len, config.seq_len),
(1, 1, config.max_seqlen_q, config.max_seqlen_kv),
dtype=torch.bool,
device="cuda",
)
......@@ -271,14 +254,14 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
(config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_attn_mask = torch.randint(
2,
(1, 1, config.seq_len, config.seq_len),
(1, 1, config.max_seqlen_q, config.max_seqlen_kv),
dtype=torch.bool,
device="cuda",
)
......@@ -311,7 +294,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
te_inp_hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
(config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
......@@ -337,7 +320,7 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
(config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
......@@ -345,7 +328,7 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_attn_mask = torch.randint(
2,
(config.batch_size, 1, 1, config.seq_len),
(config.batch_size, 1, 1, config.max_seqlen_q),
dtype=torch.bool,
device="cuda",
)
......@@ -363,21 +346,21 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
(config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_attn_mask = torch.randint(
2,
(1, 1, config.seq_len, config.seq_len),
(1, 1, config.max_seqlen_q, config.max_seqlen_kv),
dtype=torch.bool,
device="cuda",
)
enc_dec_attn_mask = torch.randint(
2,
(config.batch_size, 1, 1, config.seq_len),
(config.batch_size, 1, 1, config.max_seqlen_kv),
dtype=torch.bool,
device="cuda",
)
......@@ -405,7 +388,7 @@ def _test_sanity_common(
pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")
te_inp = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
(config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=not skip_dgrad,
......@@ -433,7 +416,7 @@ def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)
pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")
te_inp = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
(config.max_seqlen_q, config.batch_size, config.hidden_size),
device="cuda",
requires_grad=True,
)
......@@ -494,7 +477,7 @@ def test_sanity_layernorm_linear(
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -528,7 +511,7 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -555,7 +538,7 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
pytest.skip("Quantized model parameters are not supported in debug mode.")
config = model_configs[model]
ffn_hidden_size = 4 * config.hidden_size
num_tokens = bs * config.seq_len
num_tokens = bs * config.max_seqlen_q
if fp8_recipe is not None:
if not fp8_available:
......@@ -564,7 +547,7 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
use_fp8 = fp8_recipe is not None
......@@ -600,7 +583,7 @@ def test_sanity_grouped_linear(
ffn_hidden_size = 4 * config.hidden_size
# Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
bs = bs * 16
num_tokens = bs * config.seq_len * (num_gemms - 1)
num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)
if fp8_recipe is not None:
if not fp8_available:
......@@ -609,7 +592,7 @@ def test_sanity_grouped_linear(
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if not config.is_fp8_supported():
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
use_fp8 = fp8_recipe is not None
......@@ -621,7 +604,7 @@ def test_sanity_grouped_linear(
inp_hidden_states = torch.randn(
num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
m_splits = [bs * config.seq_len] * num_gemms
m_splits = [bs * config.max_seqlen_q] * num_gemms
if empty_split == "first":
m_splits[0] = 0
elif empty_split == "last":
......@@ -665,7 +648,7 @@ def test_sanity_layernorm_mlp(
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -719,7 +702,7 @@ def test_sanity_gpt(
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -729,7 +712,7 @@ def test_sanity_gpt(
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
......@@ -788,7 +771,7 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -798,7 +781,7 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
......@@ -849,7 +832,7 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -859,7 +842,7 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
......@@ -908,7 +891,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -918,7 +901,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
......@@ -945,7 +928,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -955,7 +938,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
......@@ -985,7 +968,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -995,7 +978,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
......@@ -1028,7 +1011,7 @@ def test_sanity_gradient_accumulation_fusion(
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -1038,7 +1021,7 @@ def test_sanity_gradient_accumulation_fusion(
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
......@@ -1074,7 +1057,7 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling():
pytest.skip("cuda graph not supported for float8_block_scaling recipe")
if not config.is_fp8_supported():
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -1084,7 +1067,7 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
......@@ -1156,133 +1139,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
torch.cuda.synchronize()
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_sanity_attention_extra_state(model, dtype):
config = model_configs[model]
outputs = _run_attention_extra_state(dtype, config, checkpoint=False)
outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True)
outputs_checkpoint_v1_6 = _run_attention_extra_state(
dtype, config, mimic_v1_6=True, checkpoint=True
)
# Check that results match
tols = dtype_tols(dtype)
if dtype in (torch.float16, torch.bfloat16):
tols.update(dict(rtol=2e-2, atol=2e-3))
for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
torch.testing.assert_close(
test,
ref,
**tols,
)
for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)):
torch.testing.assert_close(
test,
ref,
**tols,
)
def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
steps = 10
path = "checkpoint.pt"
fp8_enabled = True
fp8_recipe = recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
fp8_dpa=fp8_enabled,
fp8_mha=False,
)
reset_rng_states()
hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
def get_model(dtype, config):
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.0,
attention_dropout=0.0,
fuse_qkv_params=True,
params_dtype=dtype,
device="cuda",
)
return block
block = get_model(dtype, config)
for i in range(steps // 2):
with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
output = block(hidden_states, None)
loss = output.sum()
loss.backward()
if checkpoint:
sd = block.state_dict()
if mimic_v1_6:
sd["self_attention.core_attention.fused_attention._extra_state"] = sd[
"self_attention.core_attention._extra_state"
]
del sd["self_attention.core_attention._extra_state"]
torch.save(sd, path)
param_grads = []
for p in block.parameters():
if p.requires_grad:
param_grads.append(p.grad.clone())
_cpu_rng_state_new = torch.get_rng_state()
_cuda_rng_state_new = torch.cuda.get_rng_state()
del block
block = get_model(dtype, config)
block.load_state_dict(torch.load(path, weights_only=False))
torch.set_rng_state(_cpu_rng_state_new)
torch.cuda.set_rng_state(_cuda_rng_state_new)
for p in block.parameters():
if p.requires_grad:
p.grad = param_grads.pop(0)
assert not param_grads, "Unused parameter gradients remain after checkpoint reload"
for i in range((steps + 1) // 2):
with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
output = block(hidden_states, None)
loss = output.sum()
loss.backward()
torch.cuda.synchronize()
if os.path.exists(path):
os.remove(path)
outputs = [output, hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_replace_raw_data_for_float8tensor():
"""Test the functionality of replace_raw_data"""
......
@@ -4,12 +4,24 @@
from __future__ import annotations
import logging
import os
from contextlib import contextmanager
import pytest
import torch
import transformer_engine
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
get_attention_backend,
AttentionParams,
AttentionLogging,
)
from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype:
@@ -106,3 +118,178 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]:
if name == "fp8_block_scaling":
return transformer_engine.common.recipe.Float8BlockScaling()
raise ValueError(f"Unsupported quantization scheme ({name})")
# Cached RNG state
_rng_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
def reset_rng_states() -> None:
"""Revert to deterministic RNG state"""
global _rng_states
if _rng_states is None:
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
_rng_states = (torch.get_rng_state(), torch.cuda.get_rng_state())
else:
cpu_rng_state, cuda_rng_state = _rng_states
torch.set_rng_state(cpu_rng_state)
torch.cuda.set_rng_state(cuda_rng_state)
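# Usage sketch (illustrative, not taken from the diff): calling reset_rng_states()
# at the start of a test restores the cached seed-1234 CPU/CUDA snapshot, so
# parameter initialization and dropout draws repeat exactly across tests, e.g.
#     reset_rng_states()
#     x = torch.randn(4, device="cuda")  # same values every time the test runs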
class ModelConfig:
def __init__(
self,
batch_size: int,
max_seqlen_q: int,
num_heads: int,
head_dim_qk: int,
max_seqlen_kv: int = None,
num_gqa_groups: int = None,
head_dim_v: int = None,
dropout_p: float = 0.0,
attn_mask_type: str = "no_mask",
attn_bias_type: str = "no_bias",
alibi_type: str = "none",
bias_shape: str = "1hss",
window_size: Tuple[int, int] = (-1, -1),
total_requests: int = None,
max_ctx_len: int = None,
num_layers: int = 1,
eps: float = 1e-5,
):
self.batch_size = batch_size
self.max_seqlen_q = max_seqlen_q
self.max_seqlen_kv = max_seqlen_q if max_seqlen_kv is None else max_seqlen_kv
self.num_heads = num_heads
self.num_gqa_groups = num_heads if num_gqa_groups is None else num_gqa_groups
self.head_dim_qk = head_dim_qk
self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v
if self.head_dim_qk == self.head_dim_v:
self.kv_channels = self.head_dim_qk
else:
self.kv_channels = (self.head_dim_qk, self.head_dim_v)
self.hidden_size = self.num_heads * self.head_dim_qk
self.hidden_size_kv = self.num_gqa_groups * self.head_dim_v
self.dropout_p = dropout_p
self.attn_mask_type = attn_mask_type
self.attn_bias_type = attn_bias_type
self.alibi_type = alibi_type
self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv) else "cross"
self.bias_shape = bias_shape
self.window_size = window_size
self.total_requests = total_requests
self.max_ctx_len = max_ctx_len
self.num_layers = num_layers
self.eps = eps
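# Construction sketch (illustrative values only): a GQA cross-attention config in
# which head_dim_qk != head_dim_v, so kv_channels becomes a (qk, v) tuple and
# hidden_size / hidden_size_kv are derived from num_heads / num_gqa_groups.
#     cfg = ModelConfig(
#         batch_size=2,
#         max_seqlen_q=128,
#         num_heads=16,
#         head_dim_qk=64,
#         max_seqlen_kv=256,
#         num_gqa_groups=4,
#         head_dim_v=32,
#     )
#     assert cfg.kv_channels == (64, 32) and cfg.attn_type == "cross"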
@contextmanager
def logging_context(highest_level=logging.WARNING):
previous_level = logging.root.manager.disable
logging.disable(highest_level)
try:
yield
finally:
logging.disable(previous_level)
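# Usage sketch (illustrative): suppress log records at or below `highest_level`
# for the duration of the block, e.g. while probing attention backends:
#     with logging_context(highest_level=logging.WARNING):
#         probe_backends()  # hypothetical call; its DEBUG/INFO/WARNING logs are silenced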
def get_available_attention_backends(
config: ModelConfig,
qkv_dtype: torch.dtype,
qkv_layout: str,
window_size: Tuple[int, int] = (-1, -1),
pad_between_seqs: bool = False,
context_parallel: bool = False,
deterministic: bool = False,
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
is_training: bool = True,
inference_params: Optional[InferenceParams] = None,
) -> Tuple[List, List]:
"""Check for all available attention backends that support a model configuration"""
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
alibi_slopes_shape = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
if config.bias_shape == "1hss":
alibi_slopes_shape = [config.num_heads]
if config.bias_shape == "bhss":
alibi_slopes_shape = [config.batch_size, config.num_heads]
core_attention_bias_shape = (
config.bias_shape if config.attn_bias_type == "post_scale_bias" else None
)
core_attention_bias_requires_grad = False
# d=256 is supported by cuDNN 9.0+ for inference but not training, so only
# request bias gradients when both head dims are <= 128
if (
config.attn_bias_type == "post_scale_bias"
and config.head_dim_qk <= 128
and config.head_dim_v <= 128
):
core_attention_bias_requires_grad = True
fused_attn_backends = []
available_backends = None
flash_attention_backend = None
fused_attention_backend = None
def test():
attention_params = AttentionParams(
qkv_dtype=qkv_dtype,
qkv_layout=qkv_layout,
batch_size=config.batch_size,
num_heads=config.num_heads,
num_gqa_groups=config.num_gqa_groups,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
head_dim_qk=config.head_dim_qk,
head_dim_v=config.head_dim_v,
attn_mask_type=config.attn_mask_type,
window_size=window_size,
alibi_slopes_shape=alibi_slopes_shape,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias_shape=core_attention_bias_shape,
core_attention_bias_requires_grad=core_attention_bias_requires_grad,
pad_between_seqs=pad_between_seqs,
attention_dropout=config.dropout_p,
context_parallel=context_parallel,
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
is_training=is_training,
inference_params=inference_params,
)
(
use_flash_attention,
use_fused_attention,
flash_attention_backend,
fused_attention_backend,
use_unfused_attention,
available_backends,
) = get_attention_backend(attention_params)
# Set attention.py _attention_backends var using return value
# from get_attention_backend()
_attention_backends["use_flash_attention"] = use_flash_attention
_attention_backends["use_fused_attention"] = use_fused_attention
_attention_backends["flash_attention_backend"] = flash_attention_backend
_attention_backends["fused_attention_backend"] = fused_attention_backend
_attention_backends["use_unfused_attention"] = use_unfused_attention
_attention_backends["backend_selection_requires_update"] = False
return available_backends, flash_attention_backend, fused_attention_backend
backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
if AttentionLogging._is_logging_setup is False:
AttentionLogging.setup_logging()
with logging_context(highest_level=AttentionLogging._log_level):
for i in range(3):
os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
_attention_backends["backend_selection_requires_update"] = True
available_backends, flash_attention_backend, fused_attention_backend = test()
if fused_attention_backend == FusedAttnBackend[backends[i]]:
fused_attn_backends.append(fused_attention_backend)
return available_backends, flash_attention_backend, fused_attn_backends
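# Call-site sketch (illustrative; the unpacking assumes available_backends is the
# (flash, fused, unfused) boolean triple returned by get_attention_backend):
#     config = ModelConfig(2, 512, 16, 64)
#     available_backends, _, fused_attn_backends = get_available_attention_backends(
#         config, qkv_dtype=torch.bfloat16, qkv_layout="sb3hd"
#     )
#     flash_ok, fused_ok, unfused_ok = available_backends
#     if not fused_ok:
#         pytest.skip("FusedAttention does not support this model configuration")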
@@ -183,7 +183,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) &&
(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) &&
!requires_64bit_ragged_offset) {
!requires_64bit_ragged_offset &&
// 9.10.0: known bugs with SDPA FP8
(cudnn_runtime_version != 91000)) {
if (cudnn_runtime_version >= 8900) {
backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
} else {
@@ -239,10 +241,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
// 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1
(!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 &&
layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) ||
// 9.10: any head_dim + any arch + fprop + paged
// 9.10: any head_dim + any arch + fprop + non_paged + sq > 1
// 9.10: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM}
(!is_training && cudnn_runtime_version >= 91000 &&
// 9.10.2: any head_dim + any arch + fprop + paged
// 9.10.2: any head_dim + any arch + fprop + non_paged + sq > 1
// 9.10.2: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM}
(!is_training && cudnn_runtime_version >= 91002 &&
(layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD || max_seqlen_q > 1 ||
(max_seqlen_q == 1 && attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK &&
attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) ||
@@ -358,7 +360,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
dropout == 0.0)))) &&
// check 64-bit ragged offset support
(supported_ragged_offset_size)) {
(supported_ragged_offset_size) &&
// 9.10.0/9.10.1: known bugs with SDPA F16
(cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001)) {
flag_arb = true;
}
if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) {
......
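// Refactoring sketch (hypothetical helper, not part of the patch): if further buggy
// cuDNN releases need to be excluded later, the repeated version checks could be
// collected in one place, e.g.
//
//   inline bool cudnn_has_known_sdpa_f16_bug(int cudnn_runtime_version) {
//     // 9.10.0 and 9.10.1 ship known bugs in the F16 arbitrary-seqlen backend
//     return cudnn_runtime_version == 91000 || cudnn_runtime_version == 91001;
//   }
//
// so the guard above would read `!cudnn_has_known_sdpa_f16_bug(cudnn_runtime_version)`.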