Unverified Commit 37339478 authored by Kshitij Lakhani, committed by GitHub

Refactoring attention.py part 1 (#1542)



* Create pytorch/dot_product_attention module and pytorch/d_p_a/utils.py
Move attention logging into a separate class in pytorch/d_p_a/utils.py
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
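
A rough sketch of what the extracted logging class could look like, reconstructed from the module-level logging code removed from attention.py later in this diff; only AttentionLogging, setup_logging() and fa_logger are names confirmed by the diff, the rest is assumed:

```python
# Hedged reconstruction of the logging helper moved into d_p_a/utils.py.
import os
import logging


class AttentionLogging:
    """Holds the NVTE_DEBUG / NVTE_DEBUG_LEVEL driven logger that previously
    lived at module scope in pytorch/attention.py."""

    _log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
    fa_logger = logging.getLogger(__name__)

    @classmethod
    def setup_logging(cls):
        # NVTE_DEBUG = 0/1 disables/enables debug mode, default = 0
        # NVTE_DEBUG_LEVEL = 0/1/2 enables increasingly verbose output, default = 0
        log_level = int(os.getenv("NVTE_DEBUG", "0")) * int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
        log_level = cls._log_levels[log_level if log_level in [0, 1, 2] else 2]
        formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        cls.fa_logger.setLevel(log_level)
        if not cls.fa_logger.hasHandlers():
            cls.fa_logger.addHandler(stream_handler)
```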

* Create FlashAttentionUtils class in pytorch/d_p_a/utils.py for versioning info
Move versioning info out of pytorch/attention.py
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
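
A hedged sketch of the versioning state consolidated into FlashAttentionUtils; the attribute names are taken from their uses later in this diff, and anything not visible there is an assumption:

```python
# Sketch only: class-level flags replace the old module-level _flash_attn_* globals.
from packaging.version import Version as PkgVersion


class FlashAttentionUtils:
    # flash-attn v2 detection
    is_installed = False
    version = PkgVersion("0")
    version_required = PkgVersion("2.1.1")
    version_required_blackwell = PkgVersion("2.7.3")
    max_version = PkgVersion("2.7.4.post1")
    v2_plus = False
    v2_3_plus = False

    # flash-attn v3 (flashattn-hopper) detection
    v3_is_installed = False
    fa3_version = PkgVersion("0")
    v3_installation_steps = ""  # pip/wget instructions, as in the original module

    @classmethod
    def set_flash_attention_version(cls):
        """Derive the v2_* capability flags from cls.version."""
        cls.v2_plus = cls.version >= PkgVersion("2")
        cls.v2_3_plus = cls.version >= PkgVersion("2.3")

    @classmethod
    def set_flash_attention_3_params(cls):
        """Mark FA3 as usable once the flashattn-hopper imports succeed."""
        cls.v3_is_installed = True
```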

* Move AttentionParams and get_attention_backend from attention.py to d_p_a/utils.py
Fix tests and imports for the above refactor change
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
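
The updated tests later in this diff unpack a five-element tuple from the relocated get_attention_backend(); a minimal usage sketch of the new import path (the parameter values are placeholders):

```python
import torch
from transformer_engine.pytorch.dot_product_attention.utils import (
    AttentionParams,
    get_attention_backend,
)

# Placeholder values; the AttentionParams fields keep their defaults from attention.py.
attention_params = AttentionParams(
    qkv_dtype=torch.bfloat16,
    max_seqlen_q=128,
    max_seqlen_kv=128,
)
(
    use_flash_attention,
    use_fused_attention,
    fused_attention_backend,
    use_unfused_attention,
    available_backends,
) = get_attention_backend(attention_params)
```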

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Move get_qkv_layout(), get_full_mask(), get_alibi(), get_attention_quantizers() to d_p_a/utils.py
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
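
attention.py now reaches these helpers through a module alias (the dpa_utils import added later in this diff); the call below uses the signature of the get_full_mask() definition removed further down and needs a CUDA device:

```python
import transformer_engine.pytorch.dot_product_attention.utils as dpa_utils

# Build a full [1, 1, max_seqlen_q, max_seqlen_kv] causal mask on the GPU.
attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = dpa_utils.get_full_mask(
    max_seqlen_q=128,
    max_seqlen_kv=128,
    attn_mask_type="causal",
)
```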

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Move tensor packing and unpacking helper functions from pyt/attention.py to d_p_a/utils.py
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
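
The packing/unpacking helpers themselves are not shown in this diff; purely as an illustration of what they do (the names and signatures below are hypothetical), packing gathers the rows selected by cumulative-seqlen indices and unpacking scatters them back into a padded buffer:

```python
import torch


def pack_tensor_sketch(indices: torch.Tensor, tensor: torch.Tensor) -> torch.Tensor:
    """Gather the valid (non-padded) rows of a [total_tokens, hidden] tensor."""
    return torch.gather(tensor, 0, indices.unsqueeze(-1).expand(-1, tensor.size(1)))


def unpack_tensor_sketch(indices: torch.Tensor, total_tokens: int, packed: torch.Tensor) -> torch.Tensor:
    """Scatter packed rows back into a zero-initialized [total_tokens, hidden] buffer."""
    out = torch.zeros(total_tokens, packed.size(1), dtype=packed.dtype, device=packed.device)
    return out.scatter_(0, indices.unsqueeze(-1).expand(-1, packed.size(1)), packed)
```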

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Move cumulative seqlens and indices methods from pyt/attention.py to d_p_a/utils.py
Rename cumulative functions from _cu_ to _cumul_ to differentiate them from the CUDA cu call-naming convention
Rename tensor packing methods with a leading underscore to mark them as internal to the file
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove unnecessary imports in pytorch/attention.py and d_p_a/utils.py
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Create d_p_a/inference.py and move InferenceParams from pyt/attention.py to it
Modify tests and other files to import InferenceParams correctly
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

Modify docs API for InferenceParams
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
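
Callers now import InferenceParams from the new module, as the docs and test changes below show; a minimal usage sketch:

```python
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams

# The constructor signature is unchanged by the move (see the docs update below).
inference_params = InferenceParams(max_batch_size=4, max_sequence_length=2048)
```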

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Create d_p_a/rope.py and move RoPE methods from pytorch/attention.py to it
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
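
The RoPE entry points keep their names and only change module, as the test updates below show; a short sketch of the new import path (the usage lines are illustrative and not modified by this PR):

```python
import torch
from transformer_engine.pytorch.dot_product_attention.rope import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
)

rope = RotaryPositionEmbedding(64)    # rotary dimension, typically the head size
freqs = rope(max_seq_len=2048)        # frequencies for up to 2048 positions
q = torch.randn(2048, 2, 16, 64)      # [seq, batch, heads, head_dim] ("sbhd")
q_rotated = apply_rotary_pos_emb(q, freqs)
```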

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Code cleanup
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix bug exposed by QA testing
Code cleanup
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix incorrect pack_tensor arg type
Code cleanup
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* nit: Resolve lint errors
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove typedef FAUtils for FlashAttentionUtils
Use attn_log instead of att_log
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

Fix lint error
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* nit: Revert the function names from get_cumul back to the earlier get_cu
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* nit: Fix typos, make imports explicit, and remove extra comments
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

---------
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent c257bf31
......@@ -31,7 +31,7 @@ pyTorch
.. autoapiclass:: transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
:members: forward, set_context_parallel_group, set_tensor_parallel_group
.. autoapiclass:: transformer_engine.pytorch.InferenceParams(max_batch_size, max_sequence_length)
.. autoapiclass:: transformer_engine.pytorch.dot_product_attention.inference.InferenceParams(max_batch_size, max_sequence_length)
.. autoapiclass:: transformer_engine.pytorch.CudaRNGStatesTracker()
:members: reset, get_states, set_states, add, fork
......
......@@ -458,7 +458,7 @@
" </tr>\n",
"</table>\n",
"\n",
"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
"\n",
"<div class=\"alert alert-info\">\n",
"<b>Note</b>\n",
......
......@@ -11,7 +11,7 @@ import torch
from torch import nn
import transformer_engine as te
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
from transformer_engine.pytorch.dot_product_attention.rope import RotaryPositionEmbedding
from transformer_engine.pytorch.fp8 import fp8_model_init
import transformers
......
......@@ -18,15 +18,15 @@ from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model
from transformer_engine.pytorch.attention import (
DotProductAttention,
MultiheadAttention,
RotaryPositionEmbedding,
_attention_backends,
)
from transformer_engine.pytorch.dot_product_attention.utils import (
FlashAttentionUtils,
get_attention_backend,
_flash_attn_is_installed,
_flash_attn_2_3_plus,
_flash_attn_3_is_installed,
check_set_window_size,
AttentionParams,
_attention_backends,
)
from transformer_engine.pytorch.dot_product_attention.rope import RotaryPositionEmbedding
from transformer_engine.pytorch.constants import TE_DType
import transformer_engine.pytorch.cpp_extensions as ext
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
......@@ -191,9 +191,20 @@ def _get_attention_backends(
fp8=fp8,
fp8_meta=fp8_meta,
)
_, _, fused_attention_backend, _, available_backends = get_attention_backend(
attention_params
)
(
use_flash_attention,
use_fused_attention,
fused_attention_backend,
use_unfused_attention,
available_backends,
) = get_attention_backend(attention_params)
# Set attention.py _attention_backends var using return value
# from get_attention_backend()
_attention_backends["use_flash_attention"] = use_flash_attention
_attention_backends["use_fused_attention"] = use_fused_attention
_attention_backends["fused_attention_backend"] = fused_attention_backend
_attention_backends["use_unfused_attention"] = use_unfused_attention
_attention_backends["backend_selection_requires_update"] = False
return available_backends, fused_attention_backend
backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
......@@ -269,12 +280,12 @@ def test_dot_product_attention(
# manually pads and unpads the input and output of FlashAttention for testing purposes
if (
pad_between_seqs
and _flash_attn_is_installed
and FlashAttentionUtils.is_installed
and not (
config.max_seqlen_q != config.max_seqlen_kv
and config.attn_mask_type in ["causal", "padding_causal"]
)
and (config.window_size[0] == -1 or _flash_attn_2_3_plus)
and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)
):
flash_attn_supported = True
......@@ -581,7 +592,7 @@ model_configs_swa = {
}
@pytest.mark.skipif(not _flash_attn_2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa])
@pytest.mark.parametrize("model", model_configs_swa.keys())
......@@ -603,7 +614,7 @@ model_configs_alibi_slopes = {
}
@pytest.mark.skipif(not _flash_attn_2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
@pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
......@@ -1445,7 +1456,11 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
):
pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7")
if _flash_attn_3_is_installed and not is_training and "padding" not in config.attn_mask_type:
if (
FlashAttentionUtils.v3_is_installed
and not is_training
and "padding" not in config.attn_mask_type
):
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
......@@ -1471,7 +1486,11 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
rtol = 5e-1
rmse_tol = 0.15
logging.debug("========== {:^25s} ==========".format("forward output"))
if _flash_attn_3_is_installed and not is_training and "padding" not in config.attn_mask_type:
if (
FlashAttentionUtils.v3_is_installed
and not is_training
and "padding" not in config.attn_mask_type
):
_error(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
......@@ -1656,7 +1675,11 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
if _flash_attn_3_is_installed and not is_training and "padding" not in config.attn_mask_type:
if (
FlashAttentionUtils.v3_is_installed
and not is_training
and "padding" not in config.attn_mask_type
):
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
......@@ -1685,7 +1708,11 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
rmse_tol = 0.11
bwd_names = ["dq", "dk", "dv"]
logging.debug("========== {:^25s} ==========".format("forward output"))
if _flash_attn_3_is_installed and not is_training and "padding" not in config.attn_mask_type:
if (
FlashAttentionUtils.v3_is_installed
and not is_training
and "padding" not in config.attn_mask_type
):
_error(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
......
......@@ -7,15 +7,11 @@ import subprocess
import pytest
import torch
from transformer_engine.pytorch.attention import (
_flash_attn_2_plus,
_flash_attn_2_3_plus,
)
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
get_cudnn_version,
)
from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils
from test_fused_attn import ModelConfig
model_configs_flash_attn = {
......@@ -54,7 +50,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
return args
@pytest.mark.skipif(not _flash_attn_2_plus, reason="Flash-attn 2.0+ is required.")
@pytest.mark.skipif(not FlashAttentionUtils.v2_plus, reason="Flash-attn 2.0+ is required.")
@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
......
......@@ -5,7 +5,7 @@ import math
import pytest
import torch
from typing import Callable, Tuple, Union
from transformer_engine.pytorch.attention import (
from transformer_engine.pytorch.dot_product_attention.rope import (
RotaryPositionEmbedding,
apply_rotary_pos_emb,
)
......
......@@ -34,10 +34,10 @@ from transformer_engine.pytorch import (
RMSNorm,
TransformerLayer,
LayerNorm,
InferenceParams,
Fp8Padding,
Fp8Unpadding,
)
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
......
......@@ -89,8 +89,8 @@ from transformer_engine.pytorch.module import Fp8Padding, Fp8Unpadding
from transformer_engine.pytorch.module import initialize_ub
from transformer_engine.pytorch.module import destroy_ub
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention import MultiheadAttention
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.transformer import TransformerLayer
from transformer_engine.pytorch.permutation import (
moe_permute,
......
......@@ -12,17 +12,13 @@ import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import warnings
import logging
import functools
from dataclasses import dataclass, fields
import numpy as np
from packaging.version import Version as PkgVersion
import torch
import torch.nn.functional as F
import transformer_engine_torch as tex
import transformer_engine as te
from transformer_engine.pytorch.utils import (
get_cudnn_version,
nvtx_range_pop,
......@@ -31,18 +27,9 @@ from transformer_engine.pytorch.utils import (
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
fused_attn_fwd,
fused_attn_bwd,
QKVLayout,
AttnBiasType,
AttnMaskType,
FusedAttnBackend,
META_QKV,
META_DQKV,
META_O,
META_DO,
META_S,
META_DP,
META_O_CP,
META_DQKV_CP,
)
from transformer_engine.pytorch.fp8 import (
FP8GlobalStateManager,
......@@ -87,47 +74,18 @@ from transformer_engine.pytorch.tensor.quantized_tensor import (
restore_from_saved,
)
# Import attention utils
import transformer_engine.pytorch.dot_product_attention.utils as dpa_utils
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
from transformer_engine.pytorch.dot_product_attention.utils import AttentionLogging as attn_log
from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_pos_emb
# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2]
_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
_stream_handler = logging.StreamHandler()
_stream_handler.setFormatter(_formatter)
fa_logger = logging.getLogger(__name__)
fa_logger.setLevel(_log_level)
if not fa_logger.hasHandlers():
fa_logger.addHandler(_stream_handler)
@functools.lru_cache(maxsize=None)
def _get_supported_versions(version_min, version_max):
return ">= " + str(version_min) + ", " + "<= " + str(version_max)
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
_NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
# Detect flash-attn v2 in the environment
_flash_attn_is_installed = False
_flash_attn_version = PkgVersion("0")
_flash_attn_version_required = PkgVersion("2.1.1")
_flash_attn_version_required_blackwell = PkgVersion("2.7.3")
_flash_attn_max_version = PkgVersion("2.7.4.post1")
_flash_attn_2_plus = False
_flash_attn_2_1_plus = False
_flash_attn_2_3_plus = False
_flash_attn_2_4_plus = False
_flash_attn_2_4_1_plus = False
_flash_attn_2_5_7_plus = False
_flash_attn_2_6_0_plus = False
_flash_attn_2_7_0_plus = False
# Setup Attention Logging
attn_log.setup_logging()
# Global vars for flash attn imports
flash_attn_cuda_bwd = None
flash_attn_func = None
flash_attn_varlen_func = None
......@@ -135,23 +93,26 @@ _flash_attn_fwd = None
_flash_attn_bwd = None
_flash_attn_varlen_fwd = None
_flash_attn_varlen_bwd = None
try:
_flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
fa_utils.version = PkgVersion(get_pkg_version("flash-attn"))
except PackageNotFoundError:
if torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN:
fa_logger.debug(
if (
torch.cuda.is_available()
and get_device_compute_capability() >= (8, 0)
and dpa_utils._NVTE_FLASH_ATTN
):
attn_log.fa_logger.debug(
"flash-attn v2 is not installed. To use, please install it by"
""" "pip3 install flash-attn".""",
)
else:
if torch.cuda.is_available() and get_device_compute_capability() >= (10, 0):
if _flash_attn_version_required_blackwell <= _flash_attn_version <= _flash_attn_max_version:
_flash_attn_is_installed = True
elif _flash_attn_version_required <= _flash_attn_version <= _flash_attn_max_version:
_flash_attn_is_installed = True
if fa_utils.version_required_blackwell <= fa_utils.version <= fa_utils.max_version:
fa_utils.is_installed = True
elif fa_utils.version_required <= fa_utils.version <= fa_utils.max_version:
fa_utils.is_installed = True
if _flash_attn_is_installed:
if fa_utils.is_installed:
from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
from flash_attn.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd
......@@ -163,51 +124,40 @@ else:
_flash_attn_varlen_backward as _flash_attn_varlen_bwd,
)
_flash_attn_2_plus = _flash_attn_version >= PkgVersion("2")
_flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1")
_flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3")
_flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
_flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
_flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7")
_flash_attn_2_6_0_plus = _flash_attn_version >= PkgVersion("2.6.0")
_flash_attn_2_7_0_plus = _flash_attn_version >= PkgVersion("2.7.0")
# Setup Flash attention utils
fa_utils.set_flash_attention_version()
elif (
torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN
torch.cuda.is_available()
and get_device_compute_capability() >= (8, 0)
and dpa_utils._NVTE_FLASH_ATTN
):
fa_logger.warning(
attn_log.fa_logger.warning(
"Supported flash-attn versions are %s. Found flash-attn %s.",
_get_supported_versions(
dpa_utils._get_supported_versions(
(
_flash_attn_version_required
fa_utils.version_required
if get_device_compute_capability() < (10, 0)
else _flash_attn_version_required_blackwell
else fa_utils.version_required_blackwell
),
_flash_attn_max_version,
fa_utils.max_version,
),
_flash_attn_version,
fa_utils.version,
)
# Detect flash-attn v3 in the environment
# This section will be removed when FA3 is released as a regular FA package,
# i.e. flashattn-hopper 3.0.0 as flash-attn 3.0.0
_flash_attn_3_is_installed = False
_flash_attn_3_version = PkgVersion("0")
_flash_attn_3_0_0_beta = False
_use_flash_attn_3 = False
# TODO(cyang): update FA to 2.7.3 when its FA3 compilation issue is resolved
# https://github.com/Dao-AILab/flash-attention/issues/1452
_flash_attn_3_installation_steps = """\
(1) pip3 install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper"
(2) python_path=`python3 -c "import site; print(site.getsitepackages()[0])"`
(3) mkdir -p $python_path/flashattn_hopper
(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py"""
try:
_flash_attn_3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
fa_utils.fa3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
except PackageNotFoundError:
if torch.cuda.is_available() and get_device_compute_capability() >= (9, 0) and _NVTE_FLASH_ATTN:
fa_logger.debug(
if (
torch.cuda.is_available()
and get_device_compute_capability() >= (9, 0)
and dpa_utils._NVTE_FLASH_ATTN
):
attn_log.fa_logger.debug(
"flash-attn v3 is not installed. To use, please install it by \n%s",
_flash_attn_3_installation_steps,
fa_utils.v3_installation_steps,
)
else:
from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
......@@ -223,10 +173,9 @@ else:
_flash_attn_varlen_backward as _flash_attn_varlen_bwd_v3,
)
_flash_attn_3_is_installed = True
_flash_attn_3_0_0_beta = PkgVersion("3.0.0b") < _flash_attn_3_version < PkgVersion("3.0.0")
_use_flash_attn_3 = True
fa_utils.set_flash_attention_3_params()
# Global vars for available attention backends and ALiBi cache
_attention_backends = {
"attention_params": None,
"use_flash_attention": None,
......@@ -236,107 +185,6 @@ _attention_backends = {
"backend_selection_requires_update": False,
}
@dataclass(eq=True)
class AttentionParams:
"""
Attention parameters used to determine which backend to be used.
Parameters
----------
qkv_type: Union[torch.Tensor, Float8Tensor], default = `torch.Tensor`
Type of query/key/value tensors, {`torch.Tensor`, `Float8Tensor`}.
qkv_dtype: torch.dtype, default = `torch.bfloat16`
Data type of query/key/value tensors.
qkv_layout: str, default = "sbh3d"
Query/key/value tensor memory layout.
batch_size: int, default = 1
Batch size.
num_heads: int, default = 16
Number of attention heads in the query tensor.
num_gqa_groups: int, default = 16
Number of attention heads in key and value tensors.
max_seqlen_q: int, default = 128
Maximum sequence length of the query tensor.
max_seqlen_kv: int, default = 128
Maximum sequence length of the key and value tensors.
head_dim_qk: int, default = 64
The size of each attention head in query and key tensors.
head_dim_v: int, default = 64
The size of each attention head in the value tensor.
attn_mask_type: str, default = `no_mask`
Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`,
`causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
window_size: Tuple[int, int], default = None
Sliding window attention size.
alibi_slopes_shape: Optional[Union[torch.Size, List]], default = `None`
Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`.
core_attention_bias_type: str, default = `no_bias`
Attention bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}.
core_attention_bias_shape: str, default = `1hss`
Attention bias shape, {`1hss`, `b1ss`, `bhss`}.
core_attention_bias_requires_grad: bool, default = `True`
Whether attention bias requires gradient.
pad_between_seqs: bool, default = `False`
Whether there is padding between sequences in a batch.
This only applies to `qkv_format=thd`.
attention_dropout: float, default = 0.0
Attention dropout.
context_parallel: bool, default = `False`
Whether context parallelism is used or not.
deterministic: bool, default = `False`
Whether to run `DotProductAttention` with determinism or not.
is_training: bool, default = `True`
Whether in training mode (`True`) or inference mode (`False`)
fp8: bool, default = `False`
Whether `DotProductAttention` is in an `fp8_autocast` region.
fp8_meta: Optional[Dict[str Any]], default = `None`
The FP8 metadata tensor of `DotProductAttention`.
"""
qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
qkv_dtype: torch.dtype = torch.bfloat16
qkv_layout: str = "sbh3d"
batch_size: int = 1
num_heads: int = 16
num_gqa_groups: int = 16
max_seqlen_q: int = 128
max_seqlen_kv: int = 128
head_dim_qk: int = 64
head_dim_v: int = 64
attn_mask_type: str = "no_mask"
window_size: Union[Tuple[int, int], None] = None
alibi_slopes_shape: Union[torch.Size, List, None] = None
core_attention_bias_type: str = "no_bias"
core_attention_bias_shape: str = "1hss"
core_attention_bias_requires_grad: bool = True
pad_between_seqs: bool = False
attention_dropout: float = 0.0
context_parallel: bool = False
deterministic: bool = False
is_training: bool = True
fp8: bool = False
fp8_meta: Union[Dict[str, Any], None] = None
def __eq__(self, other):
"""
Overwrite dataclass.__eq__ so that only fp8_meta["recipe"] is compared,
since all other entries of fp8_meta are unused in get_attention_backend.
"""
if not isinstance(other, self.__class__):
return NotImplemented
for field in fields(self):
fname = field.name
sf = getattr(self, fname)
of = getattr(other, fname)
if fname != "fp8_meta":
if sf != of:
return False
elif sf.get("recipe", None) != of.get("recipe", None):
return False
return True
_alibi_cache = {
"_num_heads": None,
"_alibi_slopes": None,
......@@ -348,8 +196,7 @@ _alibi_cache = {
"_alibi_bias_require_update": False,
}
__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]
__all__ = ["DotProductAttention", "MultiheadAttention"]
def maybe_contiguous(tensor: torch.Tensor) -> torch.Tensor:
......@@ -357,1196 +204,6 @@ def maybe_contiguous(tensor: torch.Tensor) -> torch.Tensor:
return tensor.contiguous() if tensor.stride(-1) != 1 else tensor
def get_attention_backend(
attention_params: AttentionParams = None,
):
"""
Select the appropriate attention backend/sub-backend based on user input and runtime environment.
Parameters
----------
See `AttentionParams`.
Returns
----------
use_flash_attention: bool
Whether the `FlashAttention` backend has been selected.
use_fused_attention: bool
Whether the `FusedAttention` backend has been selected.
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
If `use_fused_attention = True`, one of `FusedAttention` three sub-backends, else `None`.
use_unfused_attention: bool
Whether the `UnfusedDotProductAttention` backend has been selected.
available_backends: List[bool]
All available backends that could support the provided input. A list of Booleans
in the form of [use_flash_attention, use_fused_attention, use_unfused_attention].
"""
qkv_type = attention_params.qkv_type
qkv_dtype = attention_params.qkv_dtype
qkv_layout = attention_params.qkv_layout
batch_size = attention_params.batch_size
num_heads = attention_params.num_heads
num_gqa_groups = attention_params.num_gqa_groups
max_seqlen_q = attention_params.max_seqlen_q
max_seqlen_kv = attention_params.max_seqlen_kv
head_dim_qk = attention_params.head_dim_qk
head_dim_v = attention_params.head_dim_v
attn_mask_type = attention_params.attn_mask_type
window_size = attention_params.window_size
alibi_slopes_shape = attention_params.alibi_slopes_shape
core_attention_bias_type = attention_params.core_attention_bias_type
core_attention_bias_shape = attention_params.core_attention_bias_shape
core_attention_bias_requires_grad = attention_params.core_attention_bias_requires_grad
pad_between_seqs = attention_params.pad_between_seqs
attention_dropout = attention_params.attention_dropout
context_parallel = attention_params.context_parallel
deterministic = attention_params.deterministic
is_training = attention_params.is_training
fp8 = attention_params.fp8
fp8_meta = attention_params.fp8_meta
# Run config
logger = logging.getLogger("DotProductAttention")
logger.setLevel(_log_level)
if not logger.hasHandlers():
logger.addHandler(_stream_handler)
device_compute_capability = get_device_compute_capability()
cudnn_version = get_cudnn_version()
run_config = {
"transformer_engine_version": te.__version__,
"compute_capability": "sm"
+ str(10 * device_compute_capability[0] + device_compute_capability[1]),
"flash_attn_version": (
str(_flash_attn_version) if _flash_attn_is_installed else "not installed"
),
"flash_attn_3_version": (
str(_flash_attn_3_version) if _flash_attn_3_is_installed else "not installed"
),
"cudnn_version": ".".join([str(i) for i in cudnn_version]),
}
attention_params_dict = {
field.name: getattr(attention_params, field.name) for field in fields(attention_params)
}
run_config.update(attention_params_dict)
if fp8:
run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
logger.debug("Running with config=%s", run_config)
# The following sections check if `FlashAttention` supports the provided attention params,
# regardless of whether FA2 or FA3 is installed. If FA2 or FA3 is not installed but is
# necessary for performance/functionality, a warning will be issued to prompt users to
# install an appropriate FA version.
global _flash_attn_version_required, _flash_attn_max_version, _use_flash_attn_3
# Filter: Environment variables
use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1"))
use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1"))
use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
if not use_flash_attention and _flash_attn_is_installed:
logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
if not use_fused_attention:
logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0")
if not use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0")
# Filter: Compute capability
if device_compute_capability < (8, 0):
if use_flash_attention and _flash_attn_is_installed:
logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
use_flash_attention = False
if use_fused_attention:
logger.debug("Disabling FusedAttention as it requires compute capability sm80+")
use_fused_attention = False
if device_compute_capability < (9, 0):
if use_flash_attention and _flash_attn_3_is_installed:
logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+")
_use_flash_attn_3 = False
# Filter: Data type
if qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type not in [
torch.Tensor,
Float8Tensor,
]:
if use_flash_attention and _flash_attn_is_installed:
logger.debug(
"Disabling FlashAttention due to unsupported QKV data type. "
"Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
"Found: qkv_dtype = %s.",
qkv_dtype,
)
use_flash_attention = False
if use_fused_attention:
logger.debug(
"Disabling FusedAttention due to unsupported QKV data type. "
"Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
"Found: qkv_dtype = %s.",
qkv_dtype,
)
use_fused_attention = False
# Filter: Execution type
if fp8 and fp8_meta["recipe"].fp8_dpa:
if use_flash_attention and not _use_flash_attn_3:
if _flash_attn_is_installed:
logger.debug("Disabling FlashAttention as FlashAttention 2 does not support FP8")
use_flash_attention = False
if use_flash_attention and _use_flash_attn_3 and is_training:
logger.debug(
"Disabling FlashAttention as FlashAttention 3 does not support FP8 training"
)
use_flash_attention = False
if use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8")
use_unfused_attention = False
# Filter: Head dimension
if use_flash_attention and head_dim_qk != head_dim_v:
if _flash_attn_is_installed:
logger.debug("Disabling FlashAttention as it does not support MLA.")
use_flash_attention = False
if use_flash_attention and (
head_dim_qk > 256
or head_dim_qk % 8 != 0
or (
head_dim_qk > 192
and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0))
)
):
if _flash_attn_is_installed:
logger.debug(
"Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. "
"Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
"head_dim_qk <= 256 (>192 requires sm80/90/100+). "
"Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
head_dim_qk,
head_dim_v,
".".join([str(i) for i in device_compute_capability]),
)
use_flash_attention = False
qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
if use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd":
logger.debug(
"Disabling FusedAttention as MLA is not supported with qkv_layout = %s",
qkv_layout,
)
use_fused_attention = False
# Filter: QKV layout
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
if qkv_format == "thd":
if use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd")
use_unfused_attention = False
if use_flash_attention and pad_between_seqs:
if _flash_attn_is_installed:
logger.debug(
"Disabling FlashAttention for qkv_format = thd when there is "
"padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
)
use_flash_attention = False
# Filter: Dropout
if attention_dropout != 0.0 and use_flash_attention and _use_flash_attn_3:
logger.debug("Disabling FlashAttention 3 for dropout")
_use_flash_attn_3 = False
# Filter: Context parallelism
# qkv_format | attn_mask_type | attn_bias_type | supported backends
# ----------------------------------------------------------------------------------------------------
# bshd, sbhd | self-attention: | no_bias, post_scale_bias | FlashAttention, FusedAttention
# | no_mask, causal | |
# | cross-attention: | |
# | no_mask | |
# thd | self-attention: | no_bias | FlashAttention, FusedAttention
# | padding, padding_causal | | if no padding between sequences,
# | cross-attention: | | FusedAttention
# | padding | | if there is padding between sequences
# Note: context parallelism requires seq_len % (cp_size * 2) == 0 for each sequence in q, k, v.
if context_parallel and use_unfused_attention:
logger.debug(
"Disabling UnfusedDotProductAttention as it does not support context parallelism"
)
use_unfused_attention = False
if context_parallel and use_flash_attention:
if fp8 and fp8_meta["recipe"].fp8_dpa:
if _flash_attn_is_installed:
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with FP8"
)
use_flash_attention = False
if "bottom_right" in attn_mask_type:
if _flash_attn_is_installed:
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with"
" causal_bottom_right masking"
)
use_flash_attention = False
elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
if _flash_attn_is_installed:
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with"
" causal masking for cross-attention"
)
use_flash_attention = False
elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
if _flash_attn_is_installed:
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with bias"
" type of %s",
core_attention_bias_type,
)
use_flash_attention = False
elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
if _flash_attn_is_installed:
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with"
" attention bias for THD format"
)
use_flash_attention = False
if context_parallel and use_fused_attention:
if "bottom_right" in attn_mask_type:
logger.debug(
"Disabling FusedAttention as it does not support context parallelism with"
" causal_bottom_right masking"
)
use_fused_attention = False
elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
logger.debug(
"Disabling FusedAttention as it does not support context parallelism with causal"
" masking for cross-attention"
)
use_fused_attention = False
elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
logger.debug(
"Disabling FusedAttention as it does not support context parallelism with bias type"
" of %s",
core_attention_bias_type,
)
use_fused_attention = False
elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
logger.debug(
"Disabling FusedAttention as it does not support context parallelism with attention"
" bias for THD format"
)
use_fused_attention = False
elif head_dim_qk != head_dim_v:
logger.debug(
"Disabling FusedAttention as it does not support context parallelism with MLA"
)
use_fused_attention = False
# Filter: Attention mask
# attn_mask_type | attention_mask | supported backends
# ----------------------------------------------------------------------------------------
# no_mask | None | All
# padding | | All
# self-attention | One tensor in shape [b, 1, 1, sq] |
# cross-attention | Tuple of two tensors in shapes |
# | [b, 1, 1, sq] and [b, 1, 1, skv] |
# causal | None |
# self-attention | | All
# cross-attention | | FusedAttention, UnfusedDotProductAttention
# padding_causal | Same as "padding" |
# self-attention | | All
# cross-attention | | FusedAttention, UnfusedDotProductAttention
# causal_bottom_right | None | All
# padding_causal_bottom_right | Same as "padding" | All
# arbitrary | One tensor in shape broadcastable to | UnfusedDotProductAttention
# | [b, h, sq, skv] |
if attn_mask_type == "arbitrary":
if use_flash_attention and _flash_attn_is_installed:
logger.debug("Disabling FlashAttention for arbitrary mask")
use_flash_attention = False
if use_fused_attention:
logger.debug("Disabling FusedAttention for arbitrary mask")
use_fused_attention = False
if (
use_flash_attention
and _use_flash_attn_3
and attn_mask_type in ["causal", "padding_causal"]
and max_seqlen_q != max_seqlen_kv
):
logger.warning(
"Disabling FlashAttention 3 as it only supports bottom-right-diagonal "
"causal mask since flash-attn 2.1. See "
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
)
_use_flash_attn_3 = False
if (
use_flash_attention
and attn_mask_type in ["causal", "padding_causal"]
and max_seqlen_q != max_seqlen_kv
):
if _flash_attn_2_1_plus:
logger.warning(
"Disabling FlashAttention as it only supports bottom-right-diagonal "
"causal mask since flash-attn 2.1. See "
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
)
use_flash_attention = False
if not _flash_attn_is_installed:
_flash_attn_max_version = PkgVersion("2.1")
if (
use_flash_attention
and attn_mask_type in ["causal_bottom_right", "padding_causal_bottom_right"]
and max_seqlen_q != max_seqlen_kv
):
if not _flash_attn_is_installed:
_flash_attn_version_required = PkgVersion("2.1")
elif not _flash_attn_2_1_plus and not _use_flash_attn_3:
logger.warning(
"Disabling FlashAttention as it only supports top-left-diagonal "
"causal mask before flash-attn 2.1. See "
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
)
use_flash_attention = False
if (
use_flash_attention
and _use_flash_attn_3
and fp8
and fp8_meta["recipe"].fp8_dpa
and "padding" in attn_mask_type
):
logger.debug("Disabling FlashAttention 3 for FP8 and padding masks")
_use_flash_attn_3 = False
# Filter: Sliding window attention
# backend | window_size | diagonal alignment
# ---------------------------------------------------------------------------------
# FlashAttention | (-1, -1) or (>=0, >=0) | bottom right
# FusedAttention | (-1, 0) or (>=0, 0) | top left
# UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both;
# | | converts window_size to an 'arbitrary' mask
if window_size is None:
window_size = check_set_window_size(attn_mask_type, window_size)
else:
if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha):
logger.debug(
"Disabling FusedAttention as it does not support sliding window attention"
" for FP8"
)
use_fused_attention = False
elif window_size[1] != 0 or attention_dropout != 0.0:
logger.debug(
"Disabling FusedAttention as it only supports sliding window attention "
"with (left, 0) and no dropout"
)
use_fused_attention = False
elif max_seqlen_q > max_seqlen_kv:
logger.debug(
"Disabling FusedAttention as it does not support sliding window attention "
"with s_q > s_kv for cross-attention"
)
use_fused_attention = False
if use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
if _use_flash_attn_3:
logger.debug(
"Disabling FlashAttention 3 as it does not support sliding window attention"
)
_use_flash_attn_3 = False
if not _flash_attn_is_installed:
_flash_attn_version_required = PkgVersion("2.3")
elif not _flash_attn_2_3_plus:
logger.debug(
"Disabling FlashAttention as sliding window attention requires flash-attn 2.3+"
)
use_flash_attention = False
# Filter: Attention bias
# backend | bias types | ALiBi diagonal alignment
# ---------------------------------------------------------------------------------
# FlashAttention | no_bias, alibi/alibi_slopes | bottom right
# FusedAttention | no_bias, post_scale_bias |
# | alibi/alibi_slopes | top left,
# | | bottom_right (converts to a 'post_scale_bias' bias)
# UnfusedDotProductAttention | no_bias, pre/post_scale_bias |
# | alibi/alibi_slopes | both; converts to a 'post_scale_bias' bias
if use_flash_attention and core_attention_bias_type == "alibi":
if _use_flash_attn_3:
logger.debug("Disabling FlashAttention 3 for ALiBi")
_use_flash_attn_3 = False
if not _flash_attn_is_installed:
_flash_attn_version_required = PkgVersion("2.4")
elif not _flash_attn_2_4_plus:
logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+")
use_flash_attention = False
if use_flash_attention and (
core_attention_bias_type not in ["no_bias", "alibi"]
or core_attention_bias_shape is not None
):
if _flash_attn_is_installed:
logger.debug("Disabling FlashAttention for pre/post_scale_bias")
use_flash_attention = False
fu_core_attention_bias_type = core_attention_bias_type
fu_core_attention_bias_shape = core_attention_bias_shape
fu_core_attention_bias_requires_grad = core_attention_bias_requires_grad
if (
use_fused_attention
and core_attention_bias_type == "alibi"
and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv)
):
fu_core_attention_bias_type = "post_scale_bias"
fu_core_attention_bias_requires_grad = False
if alibi_slopes_shape is None:
fu_core_attention_bias_shape = "1hss"
elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
fu_core_attention_bias_shape = "1hss"
elif (
len(alibi_slopes_shape) == 2
and alibi_slopes_shape[0] == batch_size
and alibi_slopes_shape[1] == num_heads
):
fu_core_attention_bias_shape = "bhss"
if (
use_fused_attention
and fu_core_attention_bias_type == "post_scale_bias"
and fu_core_attention_bias_shape != "1hss"
):
if fu_core_attention_bias_requires_grad:
# remove this line when cuDNN adds bwd support for
# [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape")
use_fused_attention = False
else:
# max512 backend will only support [1, h, s, s]
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
# Filter: cuDNN support
fused_attention_backend = None
if use_fused_attention:
q_type = TE_DType[qkv_dtype]
kv_type = q_type
if fp8 and fp8_meta["recipe"].fp8_dpa:
q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
kv_type = q_type
fused_attention_backend = tex.get_fused_attn_backend(
q_type,
kv_type,
QKVLayout[qkv_layout],
AttnBiasType[fu_core_attention_bias_type],
AttnMaskType[attn_mask_type],
attention_dropout,
num_heads,
num_gqa_groups,
max_seqlen_q,
max_seqlen_kv,
head_dim_qk,
head_dim_v,
window_size[0],
window_size[1],
)
if fused_attention_backend == FusedAttnBackend["No_Backend"]:
logger.debug("Disabling FusedAttention as no backend supports the provided input")
use_fused_attention = False
fused_attention_backend = None
if (
use_fused_attention
and window_size is not None
and window_size[0] != -1
and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"]
):
logger.debug(
"Disabling FusedAttention as only sub-backend %s does not support "
"slidng window attention",
int(fused_attention_backend),
)
use_fused_attention = False
fused_attention_backend = None
if (
use_fused_attention
and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
and fu_core_attention_bias_type == "post_scale_bias"
and fu_core_attention_bias_shape != "1hss"
):
logger.debug(
"Disabling FusedAttention as cuDNN sub-backend 0 only supports post_scale_bias in"
" [1, H, S, S] shape"
)
use_fused_attention = False
fused_attention_backend = None
# Filter: Determinism
# backend | deterministic
# ---------------------------------------------
# FlashAttention |
# flash-attn >=2.0, <2.4.1 | no
# flash-attn >=2.4.1 | yes
# FusedAttention |
# sub-backend 0 | yes
# sub-backend 1 | workspace optimization path and sm90+: yes;
# | otherwise: no
# sub-backend 2 | no
# UnfusedDotProductAttention | yes
if use_flash_attention and deterministic:
if not _flash_attn_is_installed:
_flash_attn_version_required = PkgVersion("2.4.1")
elif not _flash_attn_2_4_1_plus and not _use_flash_attn_3:
logger.warning(
"Disabling FlashAttention as version <2.4.1 does not support deterministic "
"execution. To use FlashAttention with deterministic behavior, "
"please install flash-attn >= 2.4.1."
)
use_flash_attention = False
if use_fused_attention and deterministic:
if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
logger.debug("Disabling FusedAttention for determinism reasons")
use_fused_attention = False
if (
fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
and is_training
and (
device_compute_capability < (9, 0)
or core_attention_bias_requires_grad
or cudnn_version < (8, 9, 5)
)
):
logger.debug("Disabling FusedAttention for determinism reasons")
use_fused_attention = False
# All available backends
available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention]
# `FusedAttention` and `FlashAttention` are faster backends than `UnfusedDotProductAttention`.
# When `FusedAttention` does not support the provided attention params, and `FlashAttention`
# does, we recommend users to install flash-attn if not installed already.
if not use_fused_attention and use_flash_attention and not _flash_attn_is_installed:
logger.warning(
"flash-attn may provide important feature support or performance improvement."
" Please install flash-attn %s.",
_get_supported_versions(
_flash_attn_version_required,
_flash_attn_max_version,
),
)
if use_flash_attention and not _flash_attn_is_installed:
use_flash_attention = False
available_backends[0] = False
logger.debug(
"Available backends = {FlashAttention=%s, FusedAttention=%s%s,"
" UnfusedDotProductAttention=%s}",
bool(available_backends[0]),
bool(available_backends[1]),
(
f" (sub-backend {int(fused_attention_backend)})"
if fused_attention_backend is not None
else ""
),
bool(available_backends[2]),
)
# Select FusedAttention for performance
if (
use_flash_attention
and use_fused_attention
and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
):
if device_compute_capability >= (9, 0):
logger.debug(
"Disabling FlashAttention to give FusedAttention preference on Hopper+ "
"for performance reasons"
)
use_flash_attention = False
if (
use_flash_attention
and use_fused_attention
and fused_attention_backend == FusedAttnBackend["FP8"]
and _use_flash_attn_3
):
logger.debug(
"Disabling FlashAttention 3 to give FusedAttention preference for performance reasons "
"in FP8 execution"
)
use_flash_attention = False
# Selected backend
if use_flash_attention:
use_fused_attention = False
use_unfused_attention = False
elif use_fused_attention:
use_unfused_attention = False
selected_backend = "NoBackend"
if use_flash_attention:
selected_backend = "FlashAttention"
elif use_fused_attention:
selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})"
elif use_unfused_attention:
selected_backend = "UnfusedDotProductAttention"
logger.debug("Selected backend = %s", selected_backend)
global _attention_backends
_attention_backends["use_flash_attention"] = use_flash_attention
_attention_backends["use_fused_attention"] = use_fused_attention
_attention_backends["fused_attention_backend"] = fused_attention_backend
_attention_backends["use_unfused_attention"] = use_unfused_attention
_attention_backends["backend_selection_requires_update"] = False
return (
use_flash_attention,
use_fused_attention,
fused_attention_backend,
use_unfused_attention,
available_backends,
)
class InferenceParams: # pylint: disable=too-few-public-methods
"""
Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference.
Parameters
----------
max_batch_size : int
maximum batch size during inference.
max_sequence_length : int
maximum sequence length during inference.
"""
def __init__(self, max_batch_size, max_sequence_length):
self.max_sequence_length = max_sequence_length
self.max_batch_size = max_batch_size
self.sequence_len_offset = 0
self.batch_size_offset = 0
self.key_value_memory_dict = {}
def swap_key_value_dict(self, batch_indices):
"""
Reorders the KV cache using the specified batch indices.
Parameters
----------
batch_indices : List[int]
Sequence of indices to reorder along the batch dimensions of
the KV cache. Must have a length equal to the batch size.
"""
if len(self.key_value_memory_dict) == 0:
raise ValueError("should not swap when dict in empty")
for layer_number, inference_memory in self.key_value_memory_dict.items():
inference_key_memory, inference_value_memory = inference_memory
assert (
len(batch_indices) == inference_key_memory.shape[1]
) # make sure batch size is the same
new_inference_key_memory = inference_key_memory[:, batch_indices]
new_inference_value_memory = inference_value_memory[:, batch_indices]
self.key_value_memory_dict[layer_number] = (
new_inference_key_memory,
new_inference_value_memory,
)
@torch.no_grad()
def get_full_mask(
max_seqlen_q: int,
max_seqlen_kv: int,
attn_mask_type: str = "no_mask",
attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None,
window_size: Tuple[int, int] = None,
attention_type: str = "self",
bottom_right_alignment: bool = True,
) -> torch.Tensor:
"""
Get full attention mask in [..., max_seqlen_q, max_seqlen_kv] shape, based on `attn_mask_type`,
`attention_mask`, and `window_size`. For sliding window attention, the diagonal alignment depends
on both `attn_mask_type` and `bottom_right_alignment`, as detailed below.::
attn_mask_type output shape diagonal alignment
--------------------------------------------------------------------------------------------
no_mask [1, 1, max_seqlen_q, max_seqlen_kv] follow bottom_right_alignment
causal [1, 1, max_seqlen_q, max_seqlen_kv] always top left
causal_bottom_right [1, 1, max_seqlen_q, max_seqlen_kv] always bottom right
padding [batch_size, 1, max_seqlen_q, max_seqlen_kv] follow bottom_right_alignment
padding_causal [batch_size, 1, max_seqlen_q, max_seqlen_kv] always top left
padding_causal_bottom_right [batch_size, 1, max_seqlen_q, max_seqlen_kv] always bottom right
arbitrary same as attention_mask follow bottom_right_alignment
.. note::
For "padding_bottom_right" mask, or "padding" mask with `bottom_right_alignment` = True, the bottom right
diagonal comes from the bottom right corner of the [actual_seqlens_q[i], actual_seqlens_kv[i]] matrix,
i = 0,...,batch_size-1, not the [max_seqlen_q, max_seqlen_kv] matrix. For example, with max_seqlen_q = 4,
max_seqlen_kv = 4, attn_mask_type = "padding", attention_type = "cross", and attention_mask = (
[[False, False, True, True], [False, False, False, False]],
[[False, False, False, True], [False, True, True, True]]), the returned full attention mask has [2, 4, 4]
shape and is,::
[[[False, False, False, True],
[False, False, False, True],
[ True, True, True, True],
[ True, True, True, True]],
[[False, True, True, True],
[False, True, True, True],
[False, True, True, True],
[False, True, True, True]]]
Parameters
----------
max_seqlen_q: int
Maximum sequence length for queries.
max_seqlen_kv: int
Maximum sequence length for keys and values.
attn_mask_type: str, default = `no_mask`
Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`",
"`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"}
attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
default = `None`
Boolean tensor(s) used to mask out attention softmax input. Please see DotProductAttention
for the requirements of `attention_mask` for different `attn_mask_type`s.
window_size: Tuple[int, int], default = `None`
Sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
`attn_mask_type`.
attention_type: str, default = "self"
Attention type, {"self", "cross"}
bottom_right_alignment: bool, default = `True`
Whether to align the diagonal of the sliding window attention to the bottom right (`True`)
or top left (`False`) corner of the softmax matrix. Ignored if `attn_mask_type` explicitly
specifies "causal" or "causal_bottom_right".
Returns
----------
attn_mask_type: str
For sliding window attention (>=0, >0), "arbitrary"; otherwise, the same as input `attn_mask_type`
attention_mask: torch.Tensor
The full attention mask based on `attn_mask_type`, `attention_mask` and `window_size`
actual_seqlens_q: torch.Tensor
For padding masks, the actual sequence lengths for queries, in shape [batch_size].
For other masks, `None`.
actual_seqlens_kv: Optional[torch.Tensor], default = `None`
For padding masks, the actual sequence lengths for keys and values, in shape [batch_size].
For other masks, `None`.
"""
# perform basic checks
change_type = window_size is not None and (
window_size[0] != -1 or window_size[1] not in [-1, 0]
)
if window_size is None:
window_size = (-1, -1)
if "causal" in attn_mask_type:
window_size = (window_size[0], 0)
window_size = (
max_seqlen_kv if window_size[0] == -1 else window_size[0],
max_seqlen_q if window_size[1] == -1 else window_size[1],
)
# apply padding mask
actual_seqlens_q = None
actual_seqlens_kv = None
if "padding" in attn_mask_type:
if attention_type == "self":
attention_mask = torch.logical_or(
attention_mask.squeeze(1).unsqueeze(3), attention_mask
)
else:
attention_mask = torch.logical_or(
attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1]
)
m = attention_mask.logical_not()
actual_seqlens_q = m[:, 0, :, 0].sum(dim=1)
actual_seqlens_kv = m[:, 0, 0, :].sum(dim=1)
# apply SWA mask
mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
1, 1, max_seqlen_q, 1
) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv)
swa_left = None
swa_right = None
if attn_mask_type == "causal_bottom_right" or (
attn_mask_type in ["no_mask", "arbitrary"] and bottom_right_alignment
):
swa_left = mask + max_seqlen_kv - max_seqlen_q - window_size[0]
swa_right = mask + max_seqlen_kv - max_seqlen_q + window_size[1]
elif attn_mask_type in ["causal", "padding_causal"] or (
attn_mask_type in ["no_mask", "padding", "arbitrary"] and not bottom_right_alignment
):
swa_left = mask - window_size[0]
swa_right = mask + window_size[1]
elif attn_mask_type == "padding_causal_bottom_right" or (
attn_mask_type == "padding" and bottom_right_alignment
):
batch_size = attention_mask.shape[0]
swa_left = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
actual_seqlens_kv - actual_seqlens_q - window_size[0]
).view(batch_size, 1, 1, 1)
swa_right = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
actual_seqlens_kv - actual_seqlens_q + window_size[1]
).view(batch_size, 1, 1, 1)
swa_mask = torch.logical_not(
torch.where(swa_left <= 0, 1, 0) - torch.where(swa_right < 0, 1, 0)
)
if attention_mask is not None:
attention_mask = torch.logical_or(swa_mask, attention_mask)
else:
attention_mask = swa_mask
# change mask type
if change_type:
attn_mask_type = "arbitrary"
return attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv
@torch.no_grad()
def get_alibi(
num_heads: int,
max_seqlen_q: int,
max_seqlen_kv: int,
actual_seqlens_q: Optional[torch.Tensor] = None,
actual_seqlens_kv: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
bias_dtype: Optional[torch.dtype] = None,
bottom_right_alignment: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Parameters
----------
num_heads: int
Number of heads.
max_seqlen_q: int
Maximum sequence length for queries.
max_seqlen_kv: int
Maximum sequence length for keys and values.
actual_seqlens_q: Optional[torch.Tensor], default = `None`
Actual sequence lengths for queries, in shape [batch_size].
actual_seqlens_kv: Optional[torch.Tensor], default = `None`
Actual sequence lengths for keys and values, in shape [batch_size].
alibi_slopes: Optional[torch.Tensor], default = `None`
Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads].
bias_dtype: Optional[torch.dtype], default = `None`
Dtype of the generated ALiBi bias. If None, use torch.float32.
bottom_right_alignment: bool, default = `True`
Whether to align the diagonal of the ALiBi bias to the bottom right corner of
the matrix (`True`) or top left (`False`).
Returns
----------
alibi_slopes: torch.Tensor
ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads].
alibi_bias: torch.Tensor
ALiBi bias in FP32 or `bias_dtype`. Its shape is
(1) [1, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in [num_heads] shape,
and `actual_seqlens_q` and `actual_seqlens_kv` are `None`; or
(2) [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in
[batch_size, num_heads] shape, or, if `alibi_slopes` is in [num_heads] shape and
`actual_seqlens_q` and `actual_seqlens_kv` are not `None`.
"""
global _alibi_cache
if _alibi_cache["_alibi_slopes_require_update"]:
if alibi_slopes is not None:
_alibi_cache["_alibi_slopes"] = alibi_slopes
else:
n = 2 ** math.floor(math.log2(num_heads))
m_0 = 2.0 ** (-8.0 / n)
m = torch.pow(m_0, torch.arange(1, 1 + n))
if n < num_heads:
m_hat_0 = 2.0 ** (-4.0 / n)
m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
m = torch.cat([m, m_hat])
_alibi_cache["_alibi_slopes"] = m.to(dtype=torch.float32, device="cuda")
_alibi_cache["_num_heads"] = num_heads
_alibi_cache["_alibi_slopes_require_update"] = False
if _alibi_cache["_alibi_bias_require_update"]:
assert _alibi_cache["_alibi_slopes"] is not None, "ALiBi slopes cannot be None!"
if _alibi_cache["_alibi_slopes"].dim() == 1:
slopes_shape = torch.Size([1, _alibi_cache["_alibi_slopes"].shape[0], 1, 1])
elif _alibi_cache["_alibi_slopes"].dim() == 2:
slopes_shape = torch.Size([*_alibi_cache["_alibi_slopes"].shape[:], 1, 1])
else:
raise ValueError("ALiBi slopes cannot exceed 2 dimensions.")
bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
1, 1, max_seqlen_q, 1
) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(
1, 1, 1, max_seqlen_kv
)
if actual_seqlens_q is None and actual_seqlens_kv is None:
if bottom_right_alignment:
bias = bias + max_seqlen_kv - max_seqlen_q
elif actual_seqlens_q is not None and actual_seqlens_kv is not None:
batch_size = actual_seqlens_q.shape[0]
bias = bias.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv)
if bottom_right_alignment:
bias = bias + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1)
else:
assert (
False
), "actual_seqlens_q and actual_seqlens_kv need to be both None or torch.Tensors!"
bias = bias.abs().mul(-1)
bias = bias * _alibi_cache["_alibi_slopes"].view(slopes_shape)
_alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv
_alibi_cache["_bottom_right_alignment"] = bottom_right_alignment
bias_dtype = torch.float32 if bias_dtype is None else bias_dtype
_alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda")
_alibi_cache["_alibi_bias_require_update"] = False
return _alibi_cache["_alibi_slopes"], _alibi_cache["_alibi_bias"]
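# Worked example (editor's note): with the default slopes and num_heads = 4, the code
# above computes n = 4, m_0 = 2 ** (-8 / 4) = 0.25, and slopes
#
#   [0.25 ** 1, 0.25 ** 2, 0.25 ** 3, 0.25 ** 4] = [2**-2, 2**-4, 2**-6, 2**-8]
#
# With bottom_right_alignment=True and no actual sequence lengths, the bias for head h is
#
#   alibi_bias[0, h, i, j] = -abs(i - j + max_seqlen_kv - max_seqlen_q) * slopes[h]
#
# Results are memoized in _alibi_cache; new slopes/bias are only recomputed when the
# corresponding *_require_update flags have been set by the caller.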
def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
"""
Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
the samples in a batch.
"""
mask = mask.squeeze(1).squeeze(1)
reduced_mask = mask.logical_not().sum(dim=1)
cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
zero = torch.zeros(1, dtype=torch.int32, device="cuda")
cu_seqlens = torch.cat((zero, cu_seqlens))
return cu_seqlens
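# Example sketch (editor's illustration; assumes a CUDA device since the zero offset
# above is allocated on "cuda"): two sequences of lengths 3 and 2 padded to
# max_seqlen = 4, where True marks padded positions:
#
#   mask = torch.tensor(
#       [[[[False, False, False, True]]],
#        [[[False, False, True,  True]]]],
#       device="cuda",
#   )                                    # shape [2, 1, 1, 4]
#   get_cu_seqlens(mask)                 # -> tensor([0, 3, 5], dtype=torch.int32)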
def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
the samples in a batch, and another int32 tensor of shape [batch_size * max_seqlen, 1, 1]
containing the indices for the valid tokens.
"""
mask = mask.squeeze(1).squeeze(1)
bs, seqlen = mask.shape
reduced_mask = mask.logical_not().sum(dim=1)
cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
zero = torch.zeros(1, dtype=torch.int32, device="cuda")
cu_seqlens = torch.cat((zero, cu_seqlens))
mask = mask.reshape(-1)
indices = mask.logical_not().nonzero()
indices = indices.unsqueeze(-1)
num_nonzeros = indices.shape[0]
pad_amount = bs * seqlen - num_nonzeros
indices = F.pad(
input=indices, pad=(0, 0, 0, 0, 0, pad_amount), mode="constant", value=float(bs * seqlen)
)
return cu_seqlens, indices
def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
"""
Given max_seqlen and cu_seqlens of shape [batch_size + 1], returns an int32
tensor of shape [batch_size * max_seqlen, 1, 1] containing the indices for
the valid tokens in a batch.
"""
bs = len(cu_seqlens) - 1
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
indices = [i * max_seqlen + ii for i, j in enumerate(seqlens) for ii in range(j)]
indices = torch.Tensor(indices).unsqueeze(1).unsqueeze(1).to(dtype=torch.int64, device="cuda")
num_nonzeros = indices.shape[0]
pad_amount = bs * max_seqlen - num_nonzeros
indices = F.pad(
input=indices,
pad=(0, 0, 0, 0, 0, pad_amount),
mode="constant",
value=float(bs * max_seqlen),
)
return indices
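# Example sketch (editor's illustration): continuing the example above with
# cu_seqlens = [0, 3, 5] and max_seqlen = 4, the valid tokens sit at flattened
# positions [0, 1, 2] (sample 0) and [4, 5] (sample 1), so
#
#   get_indices(4, cu_seqlens)   # -> shape [8, 1, 1]; the first 5 entries are
#                                #    [0, 1, 2, 4, 5], the rest are padded with 8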
_cu_seqlens_cache = {}
def _get_full_cu_seqlens(
batch_size: int,
max_seqlen: int,
device: torch.device,
) -> torch.Tensor:
"""Cumulative sequence lengths in full data batch
All sequences in batch have the maximum sequence length.
"""
global _cu_seqlens_cache
if (batch_size, max_seqlen) not in _cu_seqlens_cache:
_cu_seqlens_cache[(batch_size, max_seqlen)] = torch.arange(
0,
(batch_size + 1) * max_seqlen,
step=max_seqlen,
dtype=torch.int32,
device=device,
)
return _cu_seqlens_cache[(batch_size, max_seqlen)]
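# Example (editor's note): _get_full_cu_seqlens(2, 4, torch.device("cuda")) returns
# tensor([0, 4, 8], dtype=torch.int32); the result is cached per (batch_size, max_seqlen).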
@jit_fuser
def pack_tensor(
indices: torch.Tensor,
tensor: torch.Tensor,
) -> torch.Tensor:
"""
Packs the given tensor using the `indices`.
"""
padding_indice = torch.zeros(
1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device
)
indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
if isinstance(tensor, Float8Tensor):
tensor_data = torch.cat((tensor._data, padding_indice), dim=0)
gathered_data = torch.gather(tensor_data, 0, indices)
packed = Float8Tensor.make_like(tensor, data=gathered_data, shape=gathered_data.shape)
else:
tensor = torch.cat((tensor, padding_indice), dim=0)
packed = torch.gather(tensor, 0, indices)
return packed
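# Usage sketch (editor's illustration): `indices` is expected in the [total, 1, 1] shape
# produced by get_indices()/get_cu_seqlens_and_indices(), and `tensor` in
# [batch_size * max_seqlen, h, d] layout. Valid-token rows are gathered to the front;
# index entries equal to the pad value (batch_size * max_seqlen) select the appended
# all-zero row:
#
#   t = torch.randn(8, 16, 64, device="cuda")   # batch_size=2, max_seqlen=4
#   packed = pack_tensor(indices, t)            # rows 0..4 hold the 5 valid tokens,
#                                               # the remaining rows are zeros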
@jit_fuser
def pack_2_tensors(
indices: torch.Tensor,
t1: torch.Tensor,
t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Packs the given 2 tensors using the `indices`.
"""
t1_packed = pack_tensor(indices, t1)
t2_packed = pack_tensor(indices, t2)
return t1_packed, t2_packed
@jit_fuser
def pack_3_tensors(
indices: torch.Tensor,
t1: torch.Tensor,
t2: torch.Tensor,
t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Packs the given 3 tensors using the `indices`.
"""
t1_packed = pack_tensor(indices, t1)
t2_packed = pack_tensor(indices, t2)
t3_packed = pack_tensor(indices, t3)
return t1_packed, t2_packed, t3_packed
@jit_fuser
def unpack_tensor(
indices: torch.Tensor,
dim0: int,
tensor: torch.Tensor,
) -> torch.Tensor:
"""
Inverse of `pack_tensor`.
"""
indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
unpacked = torch.zeros(
dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device
)
if isinstance(tensor, Float8Tensor):
unpacked.scatter_(0, indices, tensor._data)
unpacked_data = unpacked[0:-1, :, :]
unpacked = Float8Tensor.make_like(tensor, data=unpacked_data, shape=unpacked_data.shape)
else:
unpacked.scatter_(0, indices, tensor)
unpacked = unpacked[0:-1, :, :]
return unpacked
@jit_fuser
def unpack_2_tensors(
indices: torch.Tensor,
dim0: int,
t1: torch.Tensor,
t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Inverse of `pack_2_tensors`.
"""
t1_unpacked = unpack_tensor(indices, dim0, t1)
t2_unpacked = unpack_tensor(indices, dim0, t2)
return t1_unpacked, t2_unpacked
@jit_fuser
def unpack_3_tensors(
indices: torch.Tensor,
dim0: int,
t1: torch.Tensor,
t2: torch.Tensor,
t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Inverse of `pack_3_tensors`.
"""
t1_unpacked = unpack_tensor(indices, dim0, t1)
t2_unpacked = unpack_tensor(indices, dim0, t2)
t3_unpacked = unpack_tensor(indices, dim0, t3)
return t1_unpacked, t2_unpacked, t3_unpacked
class PackTensors(torch.autograd.Function):
"""
Autograd function to pack tensors.
"""
@staticmethod
def forward(
ctx, indices: torch.Tensor, *tensors: Tuple[torch.Tensor, ...]
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring
assert 1 <= len(tensors) <= 3, f"Packing {len(tensors)} tensors not supported."
ctx.save_for_backward(indices)
ctx.dim0 = tensors[0].shape[0]
if len(tensors) == 1:
return pack_tensor(indices, *tensors)
if len(tensors) == 2:
return pack_2_tensors(indices, *tensors)
return pack_3_tensors(indices, *tensors)
@staticmethod
def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]):
# pylint: disable=missing-function-docstring
(indices,) = ctx.saved_tensors
if len(grad_outputs) == 1:
return None, unpack_tensor(indices, ctx.dim0, *grad_outputs)
if len(grad_outputs) == 2:
return None, *unpack_2_tensors(indices, ctx.dim0, *grad_outputs)
return None, *unpack_3_tensors(indices, ctx.dim0, *grad_outputs)
class UnpackTensor(torch.autograd.Function):
"""
Autograd function to unpack a tensor.
"""
@staticmethod
def forward(
ctx,
indices: torch.Tensor,
dim0: int,
tensor: torch.Tensor,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
ctx.save_for_backward(indices)
return unpack_tensor(indices, dim0, tensor)
@staticmethod
def backward(ctx, grad_output):
# pylint: disable=missing-function-docstring
(indices,) = ctx.saved_tensors
return None, None, pack_tensor(indices, grad_output)
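# Usage sketch (editor's note): these autograd wrappers are what the FlashAttention path
# below uses to strip padding tokens before the kernel call and to restore them
# afterwards, keeping the operation differentiable:
#
#   q_p, k_p, v_p = PackTensors.apply(indices_q, q, k, v)
#   # ... run attention on the packed tensors ...
#   out = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, out_packed)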
def flash_attn_p2p_communicate(
rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm
):
......@@ -1811,49 +468,6 @@ def flash_attn_a2a_communicate(
return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs
def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False):
"""Get the list of quantizers used in attention from the quantizers list."""
if not fp8:
num_of_nones = 8 if cp_specific_quantizers else 6
return [None] * num_of_nones
QKV_quantizer = quantizers["scaling_fwd"][META_QKV]
QKV_quantizer.internal = True
QKV_quantizer.set_usage(rowwise=True, columnwise=False)
O_quantizer = quantizers["scaling_fwd"][META_O]
O_quantizer.set_usage(rowwise=True, columnwise=False)
S_quantizer = quantizers["scaling_fwd"][META_S]
S_quantizer.internal = True
S_quantizer.set_usage(rowwise=True, columnwise=False)
dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV]
dQKV_quantizer.internal = True
dQKV_quantizer.set_usage(rowwise=True, columnwise=False)
dO_quantizer = quantizers["scaling_bwd"][META_DO]
dO_quantizer.set_usage(rowwise=True, columnwise=False)
dO_quantizer.internal = True
dP_quantizer = quantizers["scaling_bwd"][META_DP]
dP_quantizer.set_usage(rowwise=True, columnwise=False)
dP_quantizer.internal = True
dQKV_CP_quantizer = quantizers["scaling_bwd"][META_DQKV_CP]
dQKV_CP_quantizer.set_usage(rowwise=True, columnwise=False)
dQKV_CP_quantizer.internal = True
O_CP_quantizer = quantizers["scaling_fwd"][META_O_CP]
O_CP_quantizer.set_usage(rowwise=True, columnwise=False)
if cp_specific_quantizers:
return (
QKV_quantizer,
O_quantizer,
O_CP_quantizer,
S_quantizer,
dQKV_quantizer,
dQKV_CP_quantizer,
dO_quantizer,
dP_quantizer,
)
return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer
_cu_seqlens_info_with_cp_cache = {}
......@@ -1988,7 +602,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dQKV_CP_quantizer,
dO_quantizer,
dP_quantizer,
) = get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=True)
) = dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=True)
if fp8:
if use_fused_attention:
......@@ -2071,12 +685,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if use_fused_attention:
softmax_lse_in_packed_format = get_cudnn_version() >= (9, 6, 0)
else:
softmax_lse_in_packed_format = _flash_attn_2_6_0_plus or _use_flash_attn_3
softmax_lse_in_packed_format = fa_utils.v2_6_0_plus or fa_utils.use_v3
flash_attn_fwd = None
if not use_fused_attention:
fa_forward_kwargs = {"softmax_scale": softmax_scale}
if _use_flash_attn_3:
if fa_utils.use_v3:
if qkv_format == "thd":
flash_attn_fwd = _flash_attn_varlen_fwd_v3
else:
......@@ -2089,16 +703,16 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
flash_attn_fwd = _flash_attn_fwd
fa_forward_kwargs["dropout_p"] = dropout_p
fa_forward_kwargs["return_softmax"] = False
if (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus) or _use_flash_attn_3:
if (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus) or fa_utils.use_v3:
fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
elif _flash_attn_2_7_0_plus:
elif fa_utils.v2_7_0_plus:
fa_forward_kwargs["window_size_left"] = -1
fa_forward_kwargs["window_size_right"] = 0 if causal else -1
if _flash_attn_2_4_plus:
if fa_utils.v2_4_plus:
fa_forward_kwargs["alibi_slopes"] = None
if _flash_attn_2_5_7_plus and qkv_format == "thd":
if fa_utils.v2_5_7_plus and qkv_format == "thd":
fa_forward_kwargs["block_table"] = None
if _flash_attn_2_6_0_plus:
if fa_utils.v2_6_0_plus:
fa_forward_kwargs["softcap"] = 0.0
# Flash Attn inputs
......@@ -2266,15 +880,15 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
causal=True,
**fa_forward_kwargs,
)
if not _flash_attn_2_7_0_plus:
if not fa_utils.v2_7_0_plus:
out_per_step[i] = fa_outputs[4]
softmax_lse_per_step[i] = fa_outputs[5]
if not _use_flash_attn_3:
if not fa_utils.use_v3:
rng_states[i] = fa_outputs[7]
else:
out_per_step[i] = fa_outputs[0]
softmax_lse_per_step[i] = fa_outputs[1]
if not _use_flash_attn_3:
if not fa_utils.use_v3:
rng_states[i] = fa_outputs[3]
elif i <= rank:
if pad_between_seqs:
......@@ -2380,11 +994,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
max_seqlen_q,
max_seqlen_kv // 2,
]
if _use_flash_attn_3 or (
_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus
if fa_utils.use_v3 or (
fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
):
fa_forward_kwargs["window_size"] = (-1, -1)
elif _flash_attn_2_7_0_plus:
elif fa_utils.v2_7_0_plus:
fa_forward_kwargs["window_size_left"] = -1
fa_forward_kwargs["window_size_right"] = -1
fa_outputs = flash_attn_fwd(
......@@ -2403,15 +1017,15 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
causal=False,
**fa_forward_kwargs,
)
if not _flash_attn_2_7_0_plus:
if not fa_utils.v2_7_0_plus:
out_per_step[i] = fa_outputs[4]
softmax_lse_per_step[i] = fa_outputs[5]
if not _use_flash_attn_3:
if not fa_utils.use_v3:
rng_states[i] = fa_outputs[7]
else:
out_per_step[i] = fa_outputs[0]
softmax_lse_per_step[i] = fa_outputs[1]
if not _use_flash_attn_3:
if not fa_utils.use_v3:
rng_states[i] = fa_outputs[3]
else:
if pad_between_seqs:
......@@ -2526,11 +1140,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
max_seqlen_q // 2,
max_seqlen_kv,
]
if _use_flash_attn_3 or (
_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus
if fa_utils.use_v3 or (
fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
):
fa_forward_kwargs["window_size"] = (-1, -1)
elif _flash_attn_2_7_0_plus:
elif fa_utils.v2_7_0_plus:
fa_forward_kwargs["window_size_left"] = -1
fa_forward_kwargs["window_size_right"] = -1
fa_outputs = flash_attn_fwd(
......@@ -2549,15 +1163,15 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
causal=False,
**fa_forward_kwargs,
)
if not _flash_attn_2_7_0_plus:
if not fa_utils.v2_7_0_plus:
out_per_step[i] = fa_outputs[4]
softmax_lse_per_step[i] = fa_outputs[5]
if not _use_flash_attn_3:
if not fa_utils.use_v3:
rng_states[i] = fa_outputs[7]
else:
out_per_step[i] = fa_outputs[0]
softmax_lse_per_step[i] = fa_outputs[1]
if not _use_flash_attn_3:
if not fa_utils.use_v3:
rng_states[i] = fa_outputs[3]
else:
if pad_between_seqs:
......@@ -2664,15 +1278,15 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
causal=False,
**fa_forward_kwargs,
)
if not _flash_attn_2_7_0_plus:
if not fa_utils.v2_7_0_plus:
out_per_step[i] = fa_outputs[4]
softmax_lse_per_step[i] = fa_outputs[5]
if not _use_flash_attn_3:
if not fa_utils.use_v3:
rng_states[i] = fa_outputs[7]
else:
out_per_step[i] = fa_outputs[0]
softmax_lse_per_step[i] = fa_outputs[1]
if not _use_flash_attn_3:
if not fa_utils.use_v3:
rng_states[i] = fa_outputs[3]
if i > 0:
......@@ -3036,7 +1650,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
flash_attn_bwd = None
if not ctx.use_fused_attention:
fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
if _use_flash_attn_3:
if fa_utils.use_v3:
if ctx.qkv_format == "thd":
flash_attn_bwd = _flash_attn_varlen_bwd_v3
else:
......@@ -3048,11 +1662,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else:
flash_attn_bwd = _flash_attn_bwd
fa_backward_kwargs["dropout_p"] = ctx.dropout_p
if _flash_attn_2_4_plus:
if fa_utils.v2_4_plus:
fa_backward_kwargs["alibi_slopes"] = None
if _flash_attn_2_4_1_plus:
if fa_utils.v2_4_1_plus:
fa_backward_kwargs["deterministic"] = ctx.deterministic
if _flash_attn_2_6_0_plus:
if fa_utils.v2_6_0_plus:
fa_backward_kwargs["softcap"] = 0.0
for i in range(cp_size):
......@@ -3186,14 +1800,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
]
if _use_flash_attn_3 or (
_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus
):
if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
fa_backward_kwargs["window_size"] = (-1, 0)
elif _flash_attn_2_7_0_plus:
elif fa_utils.v2_7_0_plus:
fa_backward_kwargs["window_size_left"] = -1
fa_backward_kwargs["window_size_right"] = 0
if not _use_flash_attn_3:
if not fa_utils.use_v3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
flash_attn_bwd(
dout_,
......@@ -3303,14 +1915,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
ctx.max_seqlen_q,
ctx.max_seqlen_kv // 2,
]
if _use_flash_attn_3 or (
_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus
):
if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
fa_backward_kwargs["window_size"] = (-1, -1)
if _flash_attn_2_7_0_plus:
if fa_utils.v2_7_0_plus:
fa_backward_kwargs["window_size_left"] = -1
fa_backward_kwargs["window_size_right"] = -1
if not _use_flash_attn_3:
if not fa_utils.use_v3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
flash_attn_bwd(
dout_,
......@@ -3422,14 +2032,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
ctx.max_seqlen_q // 2,
ctx.max_seqlen_kv,
]
if _use_flash_attn_3 or (
_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus
):
if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
fa_backward_kwargs["window_size"] = (-1, -1)
elif _flash_attn_2_7_0_plus:
elif fa_utils.v2_7_0_plus:
fa_backward_kwargs["window_size_left"] = -1
fa_backward_kwargs["window_size_right"] = -1
if not _use_flash_attn_3:
if not fa_utils.use_v3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
flash_attn_bwd(
dout_,
......@@ -3518,12 +2126,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
]
if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus):
if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
fa_backward_kwargs["window_size"] = (-1, -1)
elif _flash_attn_2_7_0_plus:
elif fa_utils.v2_7_0_plus:
fa_backward_kwargs["window_size_left"] = -1
fa_backward_kwargs["window_size_right"] = -1
if not _use_flash_attn_3:
if not fa_utils.use_v3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
flash_attn_bwd(
dout,
......@@ -3851,13 +2459,13 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
assert (
use_fused_attention or _flash_attn_2_3_plus
use_fused_attention or fa_utils.v2_3_plus
), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"
flash_attn_fwd = None
if not use_fused_attention:
fa_forward_kwargs = {"softmax_scale": softmax_scale}
if _use_flash_attn_3:
if fa_utils.use_v3:
if qkv_format == "thd":
flash_attn_fwd = _flash_attn_varlen_fwd_v3
else:
......@@ -3869,11 +2477,11 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
flash_attn_fwd = _flash_attn_fwd
fa_forward_kwargs["dropout_p"] = dropout_p
fa_forward_kwargs["return_softmax"] = False
if _flash_attn_2_4_plus:
if fa_utils.v2_4_plus:
fa_forward_kwargs["alibi_slopes"] = None
if _flash_attn_2_5_7_plus and qkv_format == "thd":
if fa_utils.v2_5_7_plus and qkv_format == "thd":
fa_forward_kwargs["block_table"] = None
if _flash_attn_2_6_0_plus:
if fa_utils.v2_6_0_plus:
fa_forward_kwargs["softcap"] = 0.0
assert qkv_format != "thd", f"{qkv_format} format is not supported!"
......@@ -3947,7 +2555,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
)
max_seqlen_kv_ = seq_end_idx - seq_start_idx
if use_fused_attention or qkv_format == "thd":
cu_seqlens_kv_per_step[i] = _get_full_cu_seqlens(
cu_seqlens_kv_per_step[i] = dpa_utils.get_full_cu_seqlens(
k.shape[1], max_seqlen_kv_, k.device
)
k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]]
......@@ -3984,11 +2592,9 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
max_seqlen_q,
max_seqlen_kv_,
]
if _use_flash_attn_3 or (
_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus
):
if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
fa_forward_kwargs["window_size"] = window_size_per_step[i]
elif _flash_attn_2_7_0_plus:
elif fa_utils.v2_7_0_plus:
fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0]
fa_forward_kwargs["window_size_right"] = window_size_per_step[i][1]
fa_outputs = flash_attn_fwd(
......@@ -3999,15 +2605,15 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
causal=causal,
**fa_forward_kwargs,
)
if not _flash_attn_2_7_0_plus:
if not fa_utils.v2_7_0_plus:
out_per_step[i] = fa_outputs[4]
softmax_lse_per_step[i] = fa_outputs[5]
if not _use_flash_attn_3:
if not fa_utils.use_v3:
rng_states[i] = fa_outputs[7]
else:
out_per_step[i] = fa_outputs[0]
softmax_lse_per_step[i] = fa_outputs[1]
if not _use_flash_attn_3:
if not fa_utils.use_v3:
rng_states[i] = fa_outputs[3]
if i > 0:
......@@ -4107,7 +2713,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
flash_attn_bwd = None
if not ctx.use_fused_attention:
fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
if _use_flash_attn_3:
if fa_utils.use_v3:
if ctx.qkv_format == "thd":
flash_attn_bwd = _flash_attn_varlen_bwd_v3
else:
......@@ -4119,11 +2725,11 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
else:
flash_attn_bwd = _flash_attn_bwd
fa_backward_kwargs["dropout_p"] = ctx.dropout_p
if _flash_attn_2_4_plus:
if fa_utils.v2_4_plus:
fa_backward_kwargs["alibi_slopes"] = None
if _flash_attn_2_4_1_plus:
if fa_utils.v2_4_1_plus:
fa_backward_kwargs["deterministic"] = ctx.deterministic
if _flash_attn_2_6_0_plus:
if fa_utils.v2_6_0_plus:
fa_backward_kwargs["softcap"] = 0.0
for i in range(len(local_seq_chunk_ids) + 1):
......@@ -4180,11 +2786,11 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
ctx.max_seqlen_q,
max_seqlen_kv,
]
if not _use_flash_attn_3:
if not fa_utils.use_v3:
fa_backward_kwargs["rng_state"] = rng_states[i]
if _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus:
if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
fa_backward_kwargs["window_size"] = window_size_per_step[i]
if _flash_attn_2_7_0_plus:
if fa_utils.v2_7_0_plus:
fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0]
fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1]
flash_attn_bwd(
......@@ -4318,13 +2924,13 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
window_size == (-1, 0)
or window_size == (-1, -1)
or use_fused_attention
or _flash_attn_2_3_plus
or fa_utils.v2_3_plus
), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"
flash_attn_fwd = None
if not use_fused_attention:
fa_forward_kwargs = {"softmax_scale": softmax_scale}
if _use_flash_attn_3:
if fa_utils.use_v3:
if qkv_format == "thd":
flash_attn_fwd = _flash_attn_varlen_fwd_v3
else:
......@@ -4337,16 +2943,16 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
flash_attn_fwd = _flash_attn_fwd
fa_forward_kwargs["dropout_p"] = dropout_p
fa_forward_kwargs["return_softmax"] = False
if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus):
if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
fa_forward_kwargs["window_size"] = window_size
elif _flash_attn_2_7_0_plus:
elif fa_utils.v2_7_0_plus:
fa_forward_kwargs["window_size_left"] = window_size[0]
fa_forward_kwargs["window_size_right"] = window_size[1]
if _flash_attn_2_4_plus:
if fa_utils.v2_4_plus:
fa_forward_kwargs["alibi_slopes"] = None
if _flash_attn_2_5_7_plus and qkv_format == "thd":
if fa_utils.v2_5_7_plus and qkv_format == "thd":
fa_forward_kwargs["block_table"] = None
if _flash_attn_2_6_0_plus:
if fa_utils.v2_6_0_plus:
fa_forward_kwargs["softcap"] = 0.0
assert (
......@@ -4368,7 +2974,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
is_output_fp8 = False
QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False)
dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False)
)
if fp8:
if use_fused_attention:
......@@ -4458,12 +3064,12 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
causal=causal,
**fa_forward_kwargs,
)
if not _flash_attn_2_7_0_plus:
if not fa_utils.v2_7_0_plus:
out, softmax_lse = fa_outputs[4], fa_outputs[5]
rng_state = fa_outputs[7] if not _use_flash_attn_3 else None
rng_state = fa_outputs[7] if not fa_utils.use_v3 else None
else:
out, softmax_lse = fa_outputs[0], fa_outputs[1]
rng_state = fa_outputs[3] if not _use_flash_attn_3 else None
rng_state = fa_outputs[3] if not fa_utils.use_v3 else None
aux_ctx_tensors = [softmax_lse, rng_state]
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out.device)
......@@ -4634,7 +3240,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
flash_attn_bwd = None
if not ctx.use_fused_attention:
fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
if _use_flash_attn_3:
if fa_utils.use_v3:
if ctx.qkv_format == "thd":
flash_attn_bwd = _flash_attn_varlen_bwd_v3
else:
......@@ -4647,16 +3253,16 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
else:
flash_attn_bwd = _flash_attn_bwd
fa_backward_kwargs["dropout_p"] = ctx.dropout_p
if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus):
if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
fa_backward_kwargs["window_size"] = ctx.window_size
elif _flash_attn_2_7_0_plus:
elif fa_utils.v2_7_0_plus:
fa_backward_kwargs["window_size_left"] = ctx.window_size[0]
fa_backward_kwargs["window_size_right"] = ctx.window_size[1]
if _flash_attn_2_4_plus:
if fa_utils.v2_4_plus:
fa_backward_kwargs["alibi_slopes"] = None
if _flash_attn_2_4_1_plus:
if fa_utils.v2_4_1_plus:
fa_backward_kwargs["deterministic"] = ctx.deterministic
if _flash_attn_2_6_0_plus:
if fa_utils.v2_6_0_plus:
fa_backward_kwargs["softcap"] = 0.0
if ctx.use_fused_attention:
......@@ -4723,7 +3329,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
]
if not _use_flash_attn_3:
if not fa_utils.use_v3:
fa_backward_kwargs["rng_state"] = rng_state
flash_attn_bwd(
dout,
......@@ -4904,221 +3510,6 @@ def attn_forward_func_with_cp(
return out
class RotaryPositionEmbedding(torch.nn.Module):
"""
Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
"""
def __init__(
self,
dim: int,
rotary_percent: float = 1.0,
seq_len_interpolation_factor: Optional[int] = None,
pretrained_max_position_embeddings: Optional[int] = None,
rotary_base: float = 10000.0,
):
"""
Parameters
----------
dim: int
rotary embedding dimension
rotary_percent: float
Percent of rotary dimension to use for rotary position embeddings.
seq_len_interpolation_factor: int
if not None, discrete positions will be interpolated by this factor via the trick in
https://arxiv.org/abs/2306.15595
pretrained_max_position_embeddings: int
pre-trained max_position_embeddings before position interpolation
"""
super().__init__()
if rotary_percent < 1.0:
dim = int(dim * rotary_percent)
self.seq_len_interpolation_factor = seq_len_interpolation_factor
self.rotary_base = rotary_base
inv_freq = 1.0 / (
self.rotary_base
** (
torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device())
/ dim
)
)
self.register_buffer("inv_freq", inv_freq)
self.pretrained_max_position_embeddings = pretrained_max_position_embeddings
def forward(self, max_seq_len: int, offset: int = 0):
"""
Create rotary position embedding frequencies
Parameters
----------
max_seq_len: int
sequence length of a sample
offset: int, default = 0
fixed offset for frequencies
"""
seq = (
torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ offset
)
if (
self.pretrained_max_position_embeddings is not None
and self.seq_len_interpolation_factor is not None
):
if (
max_seq_len
> self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor
):
# dynamic linear scaling (length > position we have learned)
seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
else:
# fixed linear scaling
seq *= 1 / self.seq_len_interpolation_factor
freqs = torch.einsum("i , j -> i j", seq, self.inv_freq)
# first part even vector components, second part odd vector components,
# 2 * dim in dimension size
emb = torch.cat((freqs, freqs), dim=-1)
# emb [seq_length, .., dim]
return emb.reshape(emb.size(0), 1, 1, emb.size(1))
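# Usage sketch (editor's illustration; assumes a CUDA device, since `inv_freq` is
# created on the current CUDA device):
#
#   rope = RotaryPositionEmbedding(dim=64)
#   freqs = rope(max_seq_len=128)        # shape [128, 1, 1, 64], float32
#   # `freqs` is then consumed by apply_rotary_pos_emb() defined further below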
class FusedRoPEFunc(torch.autograd.Function):
"""
Function for FusedRoPE
This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and
the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid
the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.
"""
@staticmethod
def forward(
ctx,
t: torch.Tensor,
freqs: torch.Tensor,
tensor_format: str = "sbhd",
cu_seqlens: Union[torch.Tensor, None] = None,
cp_size: int = 1,
cp_rank: int = 0,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
if freqs.dtype != torch.float32:
freqs = freqs.float()
if tensor_format == "sbhd":
output = tex.fused_rope_forward(t, freqs, False)
elif tensor_format == "bshd":
output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1)
elif tensor_format == "thd":
output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, cp_size, cp_rank)
else:
raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
ctx.save_for_backward(freqs, cu_seqlens)
ctx.tensor_format = tensor_format
ctx.cp_size = cp_size
ctx.cp_rank = cp_rank
return output
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
freqs, cu_seqlens = ctx.saved_tensors
if ctx.tensor_format == "sbhd":
grad_input = tex.fused_rope_backward(grad_output, freqs, False)
elif ctx.tensor_format == "bshd":
grad_input = tex.fused_rope_backward(
grad_output.transpose(0, 1), freqs, True
).transpose(0, 1)
elif ctx.tensor_format == "thd":
grad_input = tex.fused_rope_thd_backward(
grad_output, cu_seqlens, freqs, ctx.cp_size, ctx.cp_rank
)
else:
raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")
return grad_input, None, None, None, None, None
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
"""
Swap the two halves of the last dimension and negate the new first half,
i.e. [x1, x2] -> [-x2, x1]
"""
x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)
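# Worked example (editor's note):
#
#   _rotate_half(torch.tensor([1., 2., 3., 4.]))   # -> tensor([-3., -4., 1., 2.])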
def apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
tensor_format: str = "sbhd",
fused: bool = False,
cu_seqlens: Union[torch.Tensor, None] = None,
cp_size: int = 1,
cp_rank: int = 0,
) -> torch.Tensor:
"""
Apply rotary positional embedding tensor to the input tensor.
Parameters
----------
t: torch.Tensor
Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
rotary positional embedding will be applied.
freqs: torch.Tensor
Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
with `s2 >= s` and `d2 <= d`.
fused: bool, default = False
Whether to use a fused implementation for applying RoPE.
tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
cu_seqlens: torch.Tensor, default = None.
Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
dtype torch.int32. Only valid when `tensor_format` is 'thd'.
Should be `cu_seqlens_padded` when cp_size > 1.
cp_size: int, default = 1.
Context parallel world size. Only valid when `tensor_format` is 'thd' and `fused` is True.
cp_rank: int, default = 0.
Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True.
"""
if fused:
assert (
tensor_format != "thd" or cu_seqlens is not None
), "cu_seqlens must not be None when tensor_format is 'thd'."
return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens, cp_size, cp_rank)
assert tensor_format in ("sbhd", "bshd"), (
"Only formats `sbhd` or `bshd` are supported for input tensor `t` "
f"when fused is False, got {tensor_format}."
)
max_seq_len = freqs.shape[0]
cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]
# Only apply the rotary embeddings up to the sequence length of the running
# input.
assert (
cur_seq_len <= max_seq_len
), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
freqs = freqs[:cur_seq_len]
if tensor_format == "bshd":
freqs = freqs.transpose(0, 1) # [seq, 1, 1, dim] -> [1, seq, 1, dim]
# cos/sin first then dtype conversion for better precision
cos_ = torch.cos(freqs).to(t.dtype)
sin_ = torch.sin(freqs).to(t.dtype)
rot_dim = freqs.shape[-1]
# ideally t_pass is empty so rotary pos embedding is applied to all tensor t
t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
# first part is cosine component
# second part is sine component, need to change signs with _rotate_half method
t = (t * cos_) + (_rotate_half(t) * sin_)
return torch.cat((t, t_pass), dim=-1)
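# Usage sketch (editor's illustration): unfused path on an `sbhd` tensor, assuming a
# CUDA device and the RotaryPositionEmbedding module defined above:
#
#   t = torch.randn(128, 2, 16, 64, device="cuda")
#   freqs = RotaryPositionEmbedding(64)(max_seq_len=128)
#   t_rot = apply_rotary_pos_emb(t, freqs, tensor_format="sbhd", fused=False)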
class _SplitAlongDim(torch.autograd.Function):
""""""
......@@ -5320,13 +3711,15 @@ class UnfusedDotProductAttention(torch.nn.Module):
key_layer.shape[0],
)
attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_full_mask(
max_seqlen_q,
max_seqlen_kv,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
window_size=window_size,
attention_type=self.attention_type,
attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = (
dpa_utils.get_full_mask(
max_seqlen_q,
max_seqlen_kv,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
window_size=window_size,
attention_type=self.attention_type,
)
)
batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
......@@ -5392,7 +3785,8 @@ class UnfusedDotProductAttention(torch.nn.Module):
if core_attention_bias_type == "post_scale_bias":
assert core_attention_bias is not None, "core_attention_bias should not be None!"
if core_attention_bias_type == "alibi":
_, core_attention_bias = get_alibi(
_, core_attention_bias = dpa_utils.get_alibi(
_alibi_cache,
output_size[1],
output_size[2],
output_size[3],
......@@ -5501,202 +3895,6 @@ class _PrepareQKVForFA(torch.autograd.Function):
return dq, dk, dv
def get_qkv_layout(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
qkv_format: str = "sbhd",
) -> str:
"""Get qkv layout.
Parameters
----------
q: torch.Tensor
Query tensor.
k: torch.Tensor
Key tensor.
v: torch.Tensor
Value tensor.
qkv_format: str, default = `sbhd`
Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for
the sequence length dimension, `b` batch size, `h` the number of attention heads,
`d` head size, and `t` the total number of tokens in a batch, i.e.
`t = sum(s_i) for i = 0...b-1`.
Returns
----------
qkv_layout: str
Memory layout of `q`, `k` and `v`. Each `qkv_format` can be mapped to one of five
memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk
of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means
`q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v`
are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and
`v = kv[:,:,:,1,:]`.
Mapping:
`sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`}
`bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`}
`thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`}
q: torch.Tensor
Query tensor. It may be different from input `q` as we try to fit tensors to
a supported layout.
k: torch.Tensor
Key tensor. It may be different from input `k` as we try to fit tensors to
a supported layout.
v: torch.Tensor
Value tensor. It may be different from input `v` as we try to fit tensors to
a supported layout.
"""
check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v])
assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!"
def run_iteratively(q, k, v):
# check data pointers
data_ptr = q.untyped_storage().data_ptr()
check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
check_ptrs_qk = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k])
data_ptr = k.untyped_storage().data_ptr()
check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])
# check tensor shapes
shape = q.shape
check_shapes_qkv = all(shape == x.shape for x in [q, k, v])
shape = k.shape
check_shapes_kv = shape[:-1] == v.shape[:-1]
# check tensor strides
stride = q.stride()
check_strides_qkv = all(stride == x.stride() for x in [q, k, v])
check_strides_kv = tuple(sk / k.shape[-1] for sk in k.stride()[:-1]) == tuple(
sv / v.shape[-1] for sv in v.stride()[:-1]
)
# check tensor offsets for h3d and 3hd layouts
prod_h_d = q.shape[-1] * q.shape[-2]
check_3hd_offsets = all(x.storage_offset() == i * prod_h_d for i, x in enumerate([q, k, v]))
check_h3d_offsets = all(
x.storage_offset() == i * q.shape[-1] for i, x in enumerate([q, k, v])
)
# check tensor offsets for hd_h2d and hd_2hd layouts
prod_all_dims = [np.prod(x.shape) for x in [q, k]]
offset = prod_all_dims[0] if check_ptrs_qkv else 0
prod_h_d = k.shape[-1] * k.shape[-2]
check_2hd_offsets = all(
x.storage_offset() == (offset + i * prod_h_d) for i, x in enumerate([k, v])
)
check_h2d_offsets = all(
x.storage_offset() == (offset + i * k.shape[-1]) for i, x in enumerate([k, v])
)
# check tensor offsets for hd_hd_hd layouts
check_hd_offsets_qkv = (
all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k, v]))
if check_ptrs_qkv
else all(x.storage_offset() == 0 for i, x in enumerate([q, k, v]))
)
check_hd_offsets_qk = (
all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k]))
if not check_ptrs_qkv and check_ptrs_qk
else all(x.storage_offset() == 0 for i, x in enumerate([q, k]))
)
check_hd_offsets_kv = (
all(x.storage_offset() == sum(prod_all_dims[1 : i + 1]) for i, x in enumerate([k, v]))
if not check_ptrs_qkv and check_ptrs_kv
else all(x.storage_offset() == 0 for i, x in enumerate([k, v]))
)
if check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_3hd_offsets:
# sb3hd, bs3hd, t3hd
# one chunk of memory, qkv, with q, k, v interleaved at dim=-3 in qkv
qkv_layout = qkv_format[:-2] + "3" + qkv_format[-2:]
elif check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_h3d_offsets:
# sbh3d, bsh3d, th3d
# one chunk of memory, qkv, with q, k, v interleaved at dim=-2 in qkv
qkv_layout = qkv_format[:-1] + "3" + qkv_format[-1:]
elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_2hd_offsets:
# sbhd_sb2hd, bshd_bs2hd, thd_t2hd
# two chunks of memory, q and kv, with k, v interleaved at dim=-3 in kv
# q and kv may be disjoint or consecutive in memory, and when consecutive, they may
# have the same data pointer, i.e. check_ptrs_qkv=True
qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_h2d_offsets:
# sbhd_sbh2d, bshd_bsh2d, thd_th2d
# two chunks of memory, q and kv, with k, v interleaved at dim=-2 in kv
# q and kv may be disjoint or consecutive in memory, and when consecutive, they may
# have the same data pointer, i.e. check_ptrs_qkv=True
qkv_layout = qkv_format + "_" + qkv_format[:-1] + "2" + qkv_format[-1:]
elif (
check_strides_kv
and check_shapes_kv
and (check_hd_offsets_qkv or check_hd_offsets_kv or check_hd_offsets_qk)
):
# sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd
# three chunks of memory, q, k and v, which may be disjoint or consecutive, and
# when consecutive, they may have the same data pointer, i.e. check_ptrs_qkv=True or
# check_ptrs_qk=True or check_ptrs_kv=True
qkv_layout = "_".join(list([qkv_format]) * 3)
else:
qkv_layout = "not_supported"
return qkv_layout
qkv_layout = run_iteratively(q, k, v)
if qkv_layout == "not_supported":
# force q,k,v to be contiguous and run get_layout again
q, k, v = [x.contiguous() for x in [q, k, v]]
qkv_layout = run_iteratively(q, k, v)
if qkv_layout == "not_supported":
raise RuntimeError("The provided qkv memory layout is not supported!")
return qkv_layout, q, k, v
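# Example sketch (editor's illustration): q, k, v sliced out of one interleaved buffer
# are detected as an `sb3hd` layout, while three independent tensors map to
# `sbhd_sbhd_sbhd`:
#
#   qkv = torch.randn(128, 2, 3, 16, 64, device="cuda")    # [s, b, 3, h, d]
#   q, k, v = qkv.unbind(dim=2)
#   layout, q, k, v = get_qkv_layout(q, k, v, qkv_format="sbhd")   # "sb3hd"
#
#   q, k, v = [torch.randn(128, 2, 16, 64, device="cuda") for _ in range(3)]
#   layout, q, k, v = get_qkv_layout(q, k, v, qkv_format="sbhd")   # "sbhd_sbhd_sbhd"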
def check_set_window_size(
attn_mask_type: str,
window_size: Tuple[int, int] = None,
):
"""Check if sliding window size is compliant with attention mask type.
If not, set it to the appropriate size.
attn_mask_type | window_size
-------------------------------------------------------------------------
no_mask, padding, arbitrary | (-1, -1) or (>=0, >=0)
causal, padding_causal | (-1, 0) or (>=0, 0)
causal_bottom_right, padding_causal_bottom_right | (-1, 0) or (>=0, 0)
"""
orig_window_size = window_size
if "causal" in attn_mask_type:
if orig_window_size is None:
window_size = (-1, 0)
elif orig_window_size == (-1, -1) or (
orig_window_size[0] >= 0 and orig_window_size[1] != 0
):
window_size = (orig_window_size[0], 0)
warnings.warn(
"window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
)
elif orig_window_size != (-1, 0) and (orig_window_size[0] < 0 or orig_window_size[1] != 0):
assert False, (
"window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
)
elif attn_mask_type in ["no_mask", "padding", "arbitrary"]:
if orig_window_size is None:
window_size = (-1, -1)
elif orig_window_size == (-1, 0):
window_size = (-1, -1)
warnings.warn(
"window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
)
elif orig_window_size != (-1, -1) and (orig_window_size[0] < 0 or orig_window_size[1] < 0):
assert False, (
"window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
)
else:
assert False, "Invalid attn_mask_type: " + attn_mask_type
return window_size
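# Examples (editor's note):
#
#   check_set_window_size("causal", None)         # -> (-1, 0)
#   check_set_window_size("no_mask", (-1, 0))     # -> (-1, -1), with a warning
#   check_set_window_size("padding", (128, 128))  # -> (128, 128), unchanged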
class FlashAttention(torch.nn.Module):
"""Dot product attention, using HazyResearch flash-attn package:
https://github.com/Dao-AILab/flash-attention
......@@ -5713,13 +3911,13 @@ class FlashAttention(torch.nn.Module):
) -> None:
super().__init__()
if _flash_attn_is_installed:
if fa_utils.is_installed:
assert (
_flash_attn_version >= _flash_attn_version_required
), f"FlashAttention minimum version {_flash_attn_version_required} is required."
fa_utils.version >= fa_utils.version_required
), f"FlashAttention minimum version {fa_utils.version_required} is required."
assert (
_flash_attn_version <= _flash_attn_max_version
), f"FlashAttention maximum version {_flash_attn_max_version} is supported."
fa_utils.version <= fa_utils.max_version
), f"FlashAttention maximum version {fa_utils.max_version} is supported."
self.softmax_scale = softmax_scale
self.attention_dropout_ctx = attention_dropout_ctx
......@@ -5728,9 +3926,9 @@ class FlashAttention(torch.nn.Module):
self.layer_number = 1 if layer_number is None else layer_number
self.deterministic = deterministic
self.logger = logging.getLogger("FlashAttention")
self.logger.setLevel(_log_level)
self.logger.setLevel(attn_log._log_level)
if not self.logger.hasHandlers():
self.logger.addHandler(_stream_handler)
self.logger.addHandler(attn_log._stream_handler)
def forward(
self,
......@@ -5834,11 +4032,13 @@ class FlashAttention(torch.nn.Module):
assert (
attention_mask is not None
), "Please provide attention_mask for padding!"
cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask)
cu_seqlens_q, indices_q = dpa_utils.get_cu_seqlens_and_indices(
attention_mask
)
else:
indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
indices_q = dpa_utils.get_indices(max_seqlen_q, cu_seqlens_q)
cu_seqlens_kv = cu_seqlens_q
query_layer, key_layer, value_layer = PackTensors.apply(
query_layer, key_layer, value_layer = dpa_utils.PackTensors.apply(
indices_q, query_layer, key_layer, value_layer
)
else:
......@@ -5846,23 +4046,29 @@ class FlashAttention(torch.nn.Module):
assert (
attention_mask is not None
), "Please provide attention_mask for padding!"
cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask[0])
cu_seqlens_kv, indices_kv = get_cu_seqlens_and_indices(attention_mask[1])
cu_seqlens_q, indices_q = dpa_utils.get_cu_seqlens_and_indices(
attention_mask[0]
)
cu_seqlens_kv, indices_kv = dpa_utils.get_cu_seqlens_and_indices(
attention_mask[1]
)
else:
indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv)
query_layer = PackTensors.apply(indices_q, query_layer)
key_layer, value_layer = PackTensors.apply(indices_kv, key_layer, value_layer)
indices_q = dpa_utils.get_indices(max_seqlen_q, cu_seqlens_q)
indices_kv = dpa_utils.get_indices(max_seqlen_kv, cu_seqlens_kv)
query_layer = dpa_utils.PackTensors.apply(indices_q, query_layer)
key_layer, value_layer = dpa_utils.PackTensors.apply(
indices_kv, key_layer, value_layer
)
else:
# Cumulative sequence lengths for unpadded data
if cu_seqlens_q is None:
cu_seqlens_q = _get_full_cu_seqlens(
cu_seqlens_q = dpa_utils.get_full_cu_seqlens(
batch_size,
max_seqlen_q,
query_layer.device,
)
if cu_seqlens_kv is None:
cu_seqlens_kv = _get_full_cu_seqlens(
cu_seqlens_kv = dpa_utils.get_full_cu_seqlens(
batch_size,
max_seqlen_kv,
key_layer.device,
......@@ -5921,28 +4127,26 @@ class FlashAttention(torch.nn.Module):
with self.attention_dropout_ctx():
fa_optional_forward_kwargs = {}
if _flash_attn_2_3_plus:
if fa_utils.v2_3_plus:
fa_optional_forward_kwargs["window_size"] = window_size
if _flash_attn_2_4_plus:
if fa_utils.v2_4_plus:
fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
if _flash_attn_2_4_1_plus:
if fa_utils.v2_4_1_plus:
fa_optional_forward_kwargs["deterministic"] = self.deterministic
fa_optional_forward_args_thd = []
if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type:
func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3
func = flash_attn_func if not fa_utils.use_v3 else flash_attn_func_v3
else:
if _flash_attn_2_5_7_plus:
if fa_utils.v2_5_7_plus:
fa_optional_forward_kwargs["block_table"] = None
func = (
flash_attn_varlen_func
if not _use_flash_attn_3
else flash_attn_varlen_func_v3
flash_attn_varlen_func if not fa_utils.use_v3 else flash_attn_varlen_func_v3
)
fa_optional_forward_args_thd.append(cu_seqlens_q)
fa_optional_forward_args_thd.append(cu_seqlens_kv)
fa_optional_forward_args_thd.append(max_seqlen_q)
fa_optional_forward_args_thd.append(max_seqlen_kv)
if _use_flash_attn_3:
if fa_utils.use_v3:
fa_3_optional_forward_kwargs = {}
fa_3_optional_forward_kwargs["window_size"] = window_size
fa_3_optional_forward_kwargs["deterministic"] = self.deterministic
......@@ -5994,12 +4198,12 @@ class FlashAttention(torch.nn.Module):
**fa_3_optional_forward_kwargs,
)
except TypeError as e:
if _flash_attn_3_0_0_beta:
if fa_utils.v3_0_0_beta:
e.args = (
e.args[0]
+ ". Please update your flash-attn v3 (beta) installation as it "
+ "may have added more supported arguments to its API. \n"
+ _flash_attn_3_installation_steps,
+ fa_utils.v3_installation_steps,
) + e.args[1:]
raise
......@@ -6021,7 +4225,7 @@ class FlashAttention(torch.nn.Module):
)
if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type:
output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output)
output = dpa_utils.UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output)
if qkv_format == "sbhd":
# (bs)hd -> bs(hd) -> sb(hd)
......@@ -6122,7 +4326,7 @@ class FusedAttnFunc(torch.autograd.Function):
fake_dtype = q.dtype
QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False)
dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False)
)
if fp8:
fused_attention_backend = FusedAttnBackend["FP8"]
......@@ -6679,20 +4883,20 @@ class FusedAttention(torch.nn.Module):
"Please provide attention_mask or cu_seqlens for padding!"
)
if self.attention_type == "self":
cu_seqlens_q = get_cu_seqlens(attention_mask)
cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask)
cu_seqlens_kv = cu_seqlens_q
else:
cu_seqlens_q = get_cu_seqlens(attention_mask[0])
cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask[0])
cu_seqlens_kv = dpa_utils.get_cu_seqlens(attention_mask[1])
else:
if cu_seqlens_q is None:
cu_seqlens_q = _get_full_cu_seqlens(
cu_seqlens_q = dpa_utils.get_full_cu_seqlens(
batch_size,
max_seqlen_q,
query_layer.device,
)
if cu_seqlens_kv is None:
cu_seqlens_kv = _get_full_cu_seqlens(
cu_seqlens_kv = dpa_utils.get_full_cu_seqlens(
batch_size,
max_seqlen_kv,
key_layer.device,
......@@ -6953,15 +5157,15 @@ class DotProductAttention(TransformerEngineBaseModule):
super().__init__()
self.logger = logging.getLogger("DotProductAttention")
self.logger.setLevel(_log_level)
self.logger.setLevel(attn_log._log_level)
if not self.logger.hasHandlers():
self.logger.addHandler(_stream_handler)
self.logger.addHandler(attn_log._stream_handler)
self.qkv_format = qkv_format
attn_mask_type = attn_mask_type.replace(",", "_")
if attn_mask_type == "causal_padding":
attn_mask_type = "padding_causal"
self.attn_mask_type = attn_mask_type
self.window_size = check_set_window_size(attn_mask_type, window_size)
self.window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)
if tp_group is None:
self.tp_size = tp_size
if tp_size == 1:
......@@ -7418,7 +5622,7 @@ class DotProductAttention(TransformerEngineBaseModule):
if window_size is None:
window_size = self.window_size
window_size = check_set_window_size(attn_mask_type, window_size)
window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)
if self.rng_states_tracker is not None and is_graph_capturing():
assert isinstance(
......@@ -7552,18 +5756,18 @@ class DotProductAttention(TransformerEngineBaseModule):
attention_mask is not None
), "Please provide attention_mask for padding!"
if self.attention_type == "self":
cu_seqlens_q = get_cu_seqlens(attention_mask)
cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask)
cu_seqlens_kv = cu_seqlens_q
else:
cu_seqlens_q = get_cu_seqlens(attention_mask[0])
cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask[0])
cu_seqlens_kv = dpa_utils.get_cu_seqlens(attention_mask[1])
else:
cu_seqlens_q = _get_full_cu_seqlens(
cu_seqlens_q = dpa_utils.get_full_cu_seqlens(
batch_size,
max_seqlen_q,
query_layer.device,
)
cu_seqlens_kv = _get_full_cu_seqlens(
cu_seqlens_kv = dpa_utils.get_full_cu_seqlens(
batch_size,
max_seqlen_kv,
key_layer.device,
......@@ -7574,11 +5778,13 @@ class DotProductAttention(TransformerEngineBaseModule):
and isinstance(key_layer, Float8Tensor)
and isinstance(value_layer, Float8Tensor)
):
qkv_layout, query_layer._data, key_layer._data, value_layer._data = get_qkv_layout(
query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format
qkv_layout, query_layer._data, key_layer._data, value_layer._data = (
dpa_utils.get_qkv_layout(
query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format
)
)
else:
qkv_layout, query_layer, key_layer, value_layer = get_qkv_layout(
qkv_layout, query_layer, key_layer, value_layer = dpa_utils.get_qkv_layout(
query_layer, key_layer, value_layer, qkv_format=qkv_format
)
......@@ -7640,7 +5846,7 @@ class DotProductAttention(TransformerEngineBaseModule):
else:
pad_between_seqs = False
attention_params = AttentionParams(
attention_params = dpa_utils.AttentionParams(
qkv_type=type(query_layer),
qkv_dtype=query_layer.dtype,
qkv_layout=qkv_layout,
......@@ -7667,7 +5873,7 @@ class DotProductAttention(TransformerEngineBaseModule):
fp8=self.fp8,
fp8_meta=self.fp8_meta,
)
global _attention_backends, _use_flash_attn_3
global _attention_backends
if (
_attention_backends["attention_params"] is None
or attention_params != _attention_backends["attention_params"]
......@@ -7675,18 +5881,25 @@ class DotProductAttention(TransformerEngineBaseModule):
_attention_backends["attention_params"] = attention_params
_attention_backends["backend_selection_requires_update"] = True
if _attention_backends["backend_selection_requires_update"]:
_use_flash_attn_3 = _flash_attn_3_is_installed
fa_utils.use_v3 = fa_utils.v3_is_installed
(
use_flash_attention,
use_fused_attention,
fused_attention_backend,
use_unfused_attention,
_,
) = get_attention_backend(attention_params)
) = dpa_utils.get_attention_backend(attention_params)
# Set global _attention_backends var using return value
# from get_attention_backend()
_attention_backends["use_flash_attention"] = use_flash_attention
_attention_backends["use_fused_attention"] = use_fused_attention
_attention_backends["fused_attention_backend"] = fused_attention_backend
_attention_backends["use_unfused_attention"] = use_unfused_attention
_attention_backends["backend_selection_requires_update"] = False
if use_flash_attention:
self.logger.info(
"Running with FlashAttention backend (version %s)",
_flash_attn_version if not _use_flash_attn_3 else _flash_attn_3_version,
fa_utils.version if not fa_utils.use_v3 else fa_utils.fa3_version,
)
elif use_fused_attention:
self.logger.info(
......@@ -7703,7 +5916,8 @@ class DotProductAttention(TransformerEngineBaseModule):
if use_flash_attention:
if core_attention_bias_type == "alibi":
alibi_slopes, _ = get_alibi(
alibi_slopes, _ = dpa_utils.get_alibi(
_alibi_cache,
query_layer.shape[-2],
max_seqlen_q,
max_seqlen_kv,
......@@ -7738,7 +5952,8 @@ class DotProductAttention(TransformerEngineBaseModule):
alibi_slopes is not None or max_seqlen_q != max_seqlen_kv
):
fu_core_attention_bias_type = "post_scale_bias"
_, fu_core_attention_bias = get_alibi(
_, fu_core_attention_bias = dpa_utils.get_alibi(
_alibi_cache,
query_layer.shape[-2],
max_seqlen_q,
max_seqlen_kv,
......@@ -8025,7 +6240,7 @@ class MultiheadAttention(torch.nn.Module):
self.qkv_format = qkv_format
self.attn_mask_type = attn_mask_type
self.window_size = check_set_window_size(attn_mask_type, window_size)
self.window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)
self.layer_number = layer_number
self.input_layernorm = input_layernorm
self.attention_type = attention_type
......@@ -8385,7 +6600,7 @@ class MultiheadAttention(torch.nn.Module):
attn_mask_type = self.attn_mask_type
if window_size is None:
window_size = self.window_size
window_size = check_set_window_size(attn_mask_type, window_size)
window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)
if "padding" in attn_mask_type and attention_mask is not None:
for mask in attention_mask:
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Python interface for dot product attention"""
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Inference classes for attention
"""
class InferenceParams: # pylint: disable=too-few-public-methods
"""
Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference.
Parameters
----------
max_batch_size : int
maximum batch size during inference.
max_sequence_length : int
maximum sequence length during inference.
"""
def __init__(self, max_batch_size, max_sequence_length):
self.max_sequence_length = max_sequence_length
self.max_batch_size = max_batch_size
self.sequence_len_offset = 0
self.batch_size_offset = 0
self.key_value_memory_dict = {}
def swap_key_value_dict(self, batch_indices):
"""
Reorders the KV cache using the specified batch indices.
Parameters
----------
batch_indices : List[int]
Sequence of indices to reorder along the batch dimensions of
the KV cache. Must have a length equal to the batch size.
"""
if len(self.key_value_memory_dict) == 0:
raise ValueError("should not swap when dict in empty")
for layer_number, inference_memory in self.key_value_memory_dict.items():
inference_key_memory, inference_value_memory = inference_memory
assert (
len(batch_indices) == inference_key_memory.shape[1]
) # make sure batch size is the same
new_inference_key_memory = inference_key_memory[:, batch_indices]
new_inference_value_memory = inference_value_memory[:, batch_indices]
self.key_value_memory_dict[layer_number] = (
new_inference_key_memory,
new_inference_value_memory,
)
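# Illustrative usage sketch, not part of the original file: how a caller might
# allocate a per-layer KV cache and reorder it with swap_key_value_dict. The
# [max_seqlen, batch, heads, head_dim] cache layout and all sizes below are
# assumptions for this example only.
def _example_swap_key_value_dict():
    import torch  # assumed available; the cache entries are torch tensors
    inference_params = InferenceParams(max_batch_size=2, max_sequence_length=8)
    key_memory = torch.zeros(8, 2, 4, 16)  # [seq, batch, heads, head_dim]
    value_memory = torch.zeros(8, 2, 4, 16)
    inference_params.key_value_memory_dict[1] = (key_memory, value_memory)
    # Reorder the two sequences in the batch (e.g. after beam reordering).
    inference_params.swap_key_value_dict([1, 0])
    return inference_params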
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Implementations of different types of Rotary Position Embedding, along with helper functions
"""
from typing import Optional, Tuple, Union
import torch
import transformer_engine_torch as tex
class RotaryPositionEmbedding(torch.nn.Module):
"""
Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
"""
def __init__(
self,
dim: int,
rotary_percent: float = 1.0,
seq_len_interpolation_factor: Optional[int] = None,
pretrained_max_position_embeddings: Optional[int] = None,
rotary_base: float = 10000.0,
):
"""
Parameters
----------
dim: int
rotary embedding dimension
rotary_percent: float
Percent of rotary dimension to use for rotary position embeddings.
seq_len_interpolation_factor: int
if not None, discrete positions will be interpolated by this factor via the trick in
https://arxiv.org/abs/2306.15595
pretrained_max_position_embeddings: int
pre-trained max_position_embeddings before position interpolation
"""
super().__init__()
if rotary_percent < 1.0:
dim = int(dim * rotary_percent)
self.seq_len_interpolation_factor = seq_len_interpolation_factor
self.rotary_base = rotary_base
inv_freq = 1.0 / (
self.rotary_base
** (
torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device())
/ dim
)
)
self.register_buffer("inv_freq", inv_freq)
self.pretrained_max_position_embeddings = pretrained_max_position_embeddings
def forward(self, max_seq_len: int, offset: int = 0):
"""
Create rotary position embedding frequencies
Parameters
----------
max_seq_len: int
sequence length of a sample
offset: int, default = 0
fixed offset for frequencies
"""
seq = (
torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ offset
)
if (
self.pretrained_max_position_embeddings is not None
and self.seq_len_interpolation_factor is not None
):
if (
max_seq_len
> self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor
):
# dynamic linear scaling (length > position we have learned)
seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
else:
# fixed linear scaling
seq *= 1 / self.seq_len_interpolation_factor
freqs = torch.einsum("i , j -> i j", seq, self.inv_freq)
# first part even vector components, second part odd vector components,
# 2 * dim in dimension size
emb = torch.cat((freqs, freqs), dim=-1)
# emb [seq_length, .., dim]
return emb.reshape(emb.size(0), 1, 1, emb.size(1))
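# Illustrative usage sketch, not part of the original file: building RoPE
# frequencies for a 64-dim rotary embedding over a 128-token sequence. A CUDA
# device is assumed because inv_freq is created on the current CUDA device.
def _example_rotary_position_embedding():
    rope = RotaryPositionEmbedding(dim=64)
    freqs = rope(max_seq_len=128)  # shape [128, 1, 1, 64]
    return freqs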
class FusedRoPEFunc(torch.autograd.Function):
"""
Function for FusedRoPE
This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and
the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid
the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.
"""
@staticmethod
def forward(
ctx,
t: torch.Tensor,
freqs: torch.Tensor,
tensor_format: str = "sbhd",
cu_seqlens: Union[torch.Tensor, None] = None,
cp_size: int = 1,
cp_rank: int = 0,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
if freqs.dtype != torch.float32:
freqs = freqs.float()
if tensor_format == "sbhd":
output = tex.fused_rope_forward(t, freqs, False)
elif tensor_format == "bshd":
output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1)
elif tensor_format == "thd":
output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, cp_size, cp_rank)
else:
raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
ctx.save_for_backward(freqs, cu_seqlens)
ctx.tensor_format = tensor_format
ctx.cp_size = cp_size
ctx.cp_rank = cp_rank
return output
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
freqs, cu_seqlens = ctx.saved_tensors
if ctx.tensor_format == "sbhd":
grad_input = tex.fused_rope_backward(grad_output, freqs, False)
elif ctx.tensor_format == "bshd":
grad_input = tex.fused_rope_backward(
grad_output.transpose(0, 1), freqs, True
).transpose(0, 1)
elif ctx.tensor_format == "thd":
grad_input = tex.fused_rope_thd_backward(
grad_output, cu_seqlens, freqs, ctx.cp_size, ctx.cp_rank
)
else:
raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")
return grad_input, None, None, None, None, None
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
"""
change sign so the last dimension becomes [-odd, +even]
"""
x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
tensor_format: str = "sbhd",
fused: bool = False,
cu_seqlens: Union[torch.Tensor, None] = None,
cp_size: int = 1,
cp_rank: int = 0,
) -> torch.Tensor:
"""
Apply rotary positional embedding tensor to the input tensor.
Parameters
----------
t: torch.Tensor
Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
rotary positional embedding will be applied.
freqs: torch.Tensor
Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
with `s2 >= s` and `d2 <= d`.
fused: bool, default = False
Whether to use the fused implementation for applying RoPE.
tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
cu_seqlens: torch.Tensor, default = None.
Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
dtype torch.int32. Only valid when `tensor_format` is 'thd'.
Should be `cu_seqlens_padded` when cp_size > 1.
cp_size: int, default = 1.
Context parallel world size. Only valid when `tensor_format` is 'thd' and `fused` is True.
cp_rank: int, default = 0.
Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True.
"""
if fused:
assert (
tensor_format != "thd" or cu_seqlens is not None
), "cu_seqlens must not be None when tensor_format is 'thd'."
return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens, cp_size, cp_rank)
assert tensor_format in ("sbhd", "bshd"), (
"Only formats `sbhd` or `bshd` are supported for input tensor `t` "
f"when fused is False, got {tensor_format}."
)
max_seq_len = freqs.shape[0]
cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]
# Only apply the rotary embeddings up to the sequence length of the running
# input.
assert (
cur_seq_len <= max_seq_len
), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
freqs = freqs[:cur_seq_len]
if tensor_format == "bshd":
freqs = freqs.transpose(0, 1) # [seq, 1, 1, dim] -> [1, seq, 1, dim]
# cos/sin first then dtype conversion for better precision
cos_ = torch.cos(freqs).to(t.dtype)
sin_ = torch.sin(freqs).to(t.dtype)
rot_dim = freqs.shape[-1]
# ideally t_pass is empty so rotary pos embedding is applied to all tensor t
t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
# first part is cosine component
# second part is sine component, need to change signs with _rotate_half method
t = (t * cos_) + (_rotate_half(t) * sin_)
return torch.cat((t, t_pass), dim=-1)
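# Illustrative usage sketch, not part of the original file: applying the
# frequencies from RotaryPositionEmbedding to an `sbhd` query tensor on the
# unfused path. All shapes are hypothetical and a CUDA device is assumed.
def _example_apply_rotary_pos_emb():
    seq_len, batch, heads, head_dim = 128, 2, 8, 64
    rope = RotaryPositionEmbedding(dim=head_dim)
    freqs = rope(max_seq_len=seq_len)  # [128, 1, 1, 64]
    q = torch.randn(seq_len, batch, heads, head_dim, device="cuda")
    # Output has the same shape as q; set fused=True to use the TE kernel instead.
    return apply_rotary_pos_emb(q, freqs, tensor_format="sbhd", fused=False)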
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Utils/Helper classes and methods for attention
"""
import math
import os
from typing import Any, Dict, List, Optional, Tuple, Union
import warnings
import logging
import functools
from dataclasses import dataclass, fields
import numpy as np
from packaging.version import Version as PkgVersion
import torch
import torch.nn.functional as F
import transformer_engine_torch as tex
import transformer_engine as te
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
QKVLayout,
AttnBiasType,
AttnMaskType,
FusedAttnBackend,
META_QKV,
META_DQKV,
META_O,
META_DO,
META_S,
META_DP,
META_O_CP,
META_DQKV_CP,
)
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
get_cudnn_version,
)
from transformer_engine.pytorch.jit import jit_fuser
# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
class AttentionLogging:
"""
Manage logging for attention module
"""
_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
_stream_handler = logging.StreamHandler()
fa_logger = logging.getLogger(__name__)
@staticmethod
def setup_logging():
"""
Set up log levels, logger and handlers
"""
_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
AttentionLogging._log_level = _log_levels[
AttentionLogging._log_level if AttentionLogging._log_level in [0, 1, 2] else 2
]
AttentionLogging._stream_handler.setFormatter(AttentionLogging._formatter)
AttentionLogging.fa_logger.setLevel(AttentionLogging._log_level)
if not AttentionLogging.fa_logger.hasHandlers():
AttentionLogging.fa_logger.addHandler(AttentionLogging._stream_handler)
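# Illustrative usage sketch, not part of the original file: _log_level is
# derived from NVTE_DEBUG and NVTE_DEBUG_LEVEL at import time, so those
# environment variables must be set before this module is imported, e.g.
#   NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 python my_script.py   (hypothetical script)
def _example_setup_attention_logging():
    AttentionLogging.setup_logging()
    AttentionLogging.fa_logger.debug("attention debug logging is active")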
@functools.lru_cache(maxsize=None)
def _get_supported_versions(version_min, version_max):
"""
Calculate version info based on min and max numbers
"""
return ">= " + str(version_min) + ", " + "<= " + str(version_max)
class FlashAttentionUtils:
"""
Manage Flash Attention versioning information
"""
# Detect flash-attn v2 in the environment
is_installed = False
version = PkgVersion("0")
version_required = PkgVersion("2.1.1")
version_required_blackwell = PkgVersion("2.7.3")
max_version = PkgVersion("2.7.4.post1")
v2_plus = False
v2_1_plus = False
v2_3_plus = False
v2_4_plus = False
v2_4_1_plus = False
v2_5_7_plus = False
v2_6_0_plus = False
v2_7_0_plus = False
v3_is_installed = False
fa3_version = PkgVersion("0")
v3_0_0_beta = False
use_v3 = False
# TODO(cyang): update FA to 2.7.3 when its FA3 compilation issue is resolved
# https://github.com/Dao-AILab/flash-attention/issues/1452
v3_installation_steps = """\
(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper"
(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
(3) mkdir -p $python_path/flashattn_hopper
(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py"""
@staticmethod
def set_flash_attention_version():
"""
Set up version info for FA v2.x
"""
FlashAttentionUtils.is_installed = True
FlashAttentionUtils.v2_plus = FlashAttentionUtils.version >= PkgVersion("2")
FlashAttentionUtils.v2_1_plus = FlashAttentionUtils.version >= PkgVersion("2.1")
FlashAttentionUtils.v2_3_plus = FlashAttentionUtils.version >= PkgVersion("2.3")
FlashAttentionUtils.v2_4_plus = FlashAttentionUtils.version >= PkgVersion("2.4")
FlashAttentionUtils.v2_4_1_plus = FlashAttentionUtils.version >= PkgVersion("2.4.1")
FlashAttentionUtils.v2_5_7_plus = FlashAttentionUtils.version >= PkgVersion("2.5.7")
FlashAttentionUtils.v2_6_0_plus = FlashAttentionUtils.version >= PkgVersion("2.6.0")
FlashAttentionUtils.v2_7_0_plus = FlashAttentionUtils.version >= PkgVersion("2.7.0")
# Detect flash-attn v3 in the environment
# This section will be removed when FA3 is released as a regular FA package,
# i.e. flashattn-hopper 3.0.0 as flash-attn 3.0.0
@staticmethod
def set_flash_attention_3_params():
"""
Set up version info for FA v3.x
"""
FlashAttentionUtils.v3_is_installed = True
FlashAttentionUtils.v3_0_0_beta = (
PkgVersion("3.0.0b") < FlashAttentionUtils.fa3_version < PkgVersion("3.0.0")
)
FlashAttentionUtils.use_v3 = True
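# Illustrative sketch, not part of the original file: how a caller (e.g.
# attention.py) might probe the environment and populate the version flags.
# The "flash-attn" distribution name and the importlib-based lookup are
# assumptions for this example only.
def _example_detect_flash_attention():
    try:
        from importlib.metadata import version as _dist_version
        FlashAttentionUtils.version = PkgVersion(_dist_version("flash-attn"))
        FlashAttentionUtils.set_flash_attention_version()
    except Exception:  # package not installed or version not parsable
        FlashAttentionUtils.is_installed = False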
@dataclass(eq=True)
class AttentionParams:
"""
Attention parameters used to determine which backend to use.
Parameters
----------
qkv_type: Union[torch.Tensor, Float8Tensor], default = `torch.Tensor`
Type of query/key/value tensors, {`torch.Tensor`, `Float8Tensor`}.
qkv_dtype: torch.dtype, default = `torch.bfloat16`
Data type of query/key/value tensors.
qkv_layout: str, default = "sbh3d"
Query/key/value tensor memory layout.
batch_size: int, default = 1
Batch size.
num_heads: int, default = 16
Number of attention heads in the query tensor.
num_gqa_groups: int, default = 16
Number of attention heads in key and value tensors.
max_seqlen_q: int, default = 128
Maximum sequence length of the query tensor.
max_seqlen_kv: int, default = 128
Maximum sequence length of the key and value tensors.
head_dim_qk: int, default = 64
The size of each attention head in query and key tensors.
head_dim_v: int, default = 64
The size of each attention head in the value tensor.
attn_mask_type: str, default = `no_mask`
Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`,
`causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
window_size: Tuple[int, int], default = None
Sliding window attention size.
alibi_slopes_shape: Optional[Union[torch.Size, List]], default = `None`
Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`.
core_attention_bias_type: str, default = `no_bias`
Attention bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}.
core_attention_bias_shape: str, default = `1hss`
Attention bias shape, {`1hss`, `b1ss`, `bhss`}.
core_attention_bias_requires_grad: bool, default = `True`
Whether attention bias requires gradient.
pad_between_seqs: bool, default = `False`
Whether there is padding between sequences in a batch.
This only applies to `qkv_format=thd`.
attention_dropout: float, default = 0.0
Attention dropout.
context_parallel: bool, default = `False`
Whether context parallelism is used or not.
deterministic: bool, default = `False`
Whether to run `DotProductAttention` with determinism or not.
is_training: bool, default = `True`
Whether in training mode (`True`) or inference mode (`False`)
fp8: bool, default = `False`
Whether `DotProductAttention` is in an `fp8_autocast` region.
fp8_meta: Optional[Dict[str, Any]], default = `None`
The FP8 metadata tensor of `DotProductAttention`.
"""
qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
qkv_dtype: torch.dtype = torch.bfloat16
qkv_layout: str = "sbh3d"
batch_size: int = 1
num_heads: int = 16
num_gqa_groups: int = 16
max_seqlen_q: int = 128
max_seqlen_kv: int = 128
head_dim_qk: int = 64
head_dim_v: int = 64
attn_mask_type: str = "no_mask"
window_size: Union[Tuple[int, int], None] = None
alibi_slopes_shape: Union[torch.Size, List, None] = None
core_attention_bias_type: str = "no_bias"
core_attention_bias_shape: str = "1hss"
core_attention_bias_requires_grad: bool = True
pad_between_seqs: bool = False
attention_dropout: float = 0.0
context_parallel: bool = False
deterministic: bool = False
is_training: bool = True
fp8: bool = False
fp8_meta: Union[Dict[str, Any], None] = None
def __eq__(self, other):
"""
Override dataclass.__eq__ so that only fp8_meta["recipe"] is compared,
since all other entries of fp8_meta are unused in get_attention_backend.
"""
if not isinstance(other, self.__class__):
return NotImplemented
for field in fields(self):
fname = field.name
sf = getattr(self, fname)
of = getattr(other, fname)
if fname != "fp8_meta":
if sf != of:
return False
elif sf.get("recipe", None) != of.get("recipe", None):
return False
return True
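# Illustrative sketch, not part of the original file: two parameter sets that
# differ only in fp8_meta entries other than "recipe" compare equal under the
# custom __eq__ above, so a cached backend selection would be reused.
def _example_attention_params_equality():
    params_a = AttentionParams(max_seqlen_q=256, max_seqlen_kv=256, fp8_meta={"recipe": None})
    params_b = AttentionParams(
        max_seqlen_q=256, max_seqlen_kv=256, fp8_meta={"recipe": None, "unused_entry": 1}
    )
    assert params_a == params_b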
def get_attention_backend(
attention_params: AttentionParams = None,
):
"""
Select the appropriate attention backend/sub-backend based on user input and runtime environment.
Parameters
----------
See `AttentionParams`.
Returns
----------
use_flash_attention: bool
Whether the `FlashAttention` backend has been selected.
use_fused_attention: bool
Whether the `FusedAttention` backend has been selected.
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
If `use_fused_attention = True`, one of `FusedAttention`'s three sub-backends, else `None`.
use_unfused_attention: bool
Whether the `UnfusedDotProductAttention` backend has been selected.
available_backends: List[bool]
All available backends that could support the provided input. A list of Booleans
in the form of [use_flash_attention, use_fused_attention, use_unfused_attention].
"""
# NOTE: As part of refactoring attention.py, populating the _attention_backends cache in attention
# is no longer performed at the end of get_attention_backend(), but the responsibility of doing so
# is shifted over to the caller of this function
qkv_type = attention_params.qkv_type
qkv_dtype = attention_params.qkv_dtype
qkv_layout = attention_params.qkv_layout
batch_size = attention_params.batch_size
num_heads = attention_params.num_heads
num_gqa_groups = attention_params.num_gqa_groups
max_seqlen_q = attention_params.max_seqlen_q
max_seqlen_kv = attention_params.max_seqlen_kv
head_dim_qk = attention_params.head_dim_qk
head_dim_v = attention_params.head_dim_v
attn_mask_type = attention_params.attn_mask_type
window_size = attention_params.window_size
alibi_slopes_shape = attention_params.alibi_slopes_shape
core_attention_bias_type = attention_params.core_attention_bias_type
core_attention_bias_shape = attention_params.core_attention_bias_shape
core_attention_bias_requires_grad = attention_params.core_attention_bias_requires_grad
pad_between_seqs = attention_params.pad_between_seqs
attention_dropout = attention_params.attention_dropout
context_parallel = attention_params.context_parallel
deterministic = attention_params.deterministic
is_training = attention_params.is_training
fp8 = attention_params.fp8
fp8_meta = attention_params.fp8_meta
# Run config
logger = logging.getLogger("DotProductAttention")
logger.setLevel(AttentionLogging._log_level)
if not logger.hasHandlers():
logger.addHandler(AttentionLogging._stream_handler)
device_compute_capability = get_device_compute_capability()
cudnn_version = get_cudnn_version()
run_config = {
"transformer_engine_version": te.__version__,
"compute_capability": "sm"
+ str(10 * device_compute_capability[0] + device_compute_capability[1]),
"flash_attn_version": (
str(FlashAttentionUtils.version)
if FlashAttentionUtils.is_installed
else "not installed"
),
"flash_attn_3_version": (
str(FlashAttentionUtils.fa3_version)
if FlashAttentionUtils.v3_is_installed
else "not installed"
),
"cudnn_version": ".".join([str(i) for i in cudnn_version]),
}
attention_params_dict = {
field.name: getattr(attention_params, field.name) for field in fields(attention_params)
}
run_config.update(attention_params_dict)
if fp8:
run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
logger.debug("Running with config=%s", run_config)
# The following sections check if `FlashAttention` supports the provided attention params,
# regardless of whether FA2 or FA3 is installed. If FA2 or FA3 is not installed but is
# necessary for performance/functionality, a warning will be issued to prompt users to
# install an appropriate FA version.
# Filter: Environment variables
use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1"))
use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1"))
use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
if not use_flash_attention and FlashAttentionUtils.is_installed:
logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
if not use_fused_attention:
logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0")
if not use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0")
# Filter: Compute capability
if device_compute_capability < (8, 0):
if use_flash_attention and FlashAttentionUtils.is_installed:
logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
use_flash_attention = False
if use_fused_attention:
logger.debug("Disabling FusedAttention as it requires compute capability sm80+")
use_fused_attention = False
if device_compute_capability < (9, 0):
if use_flash_attention and FlashAttentionUtils.v3_is_installed:
logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+")
FlashAttentionUtils.use_v3 = False
# Filter: Data type
if qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type not in [
torch.Tensor,
Float8Tensor,
]:
if use_flash_attention and FlashAttentionUtils.is_installed:
logger.debug(
"Disabling FlashAttention due to unsupported QKV data type. "
"Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
"Found: qkv_dtype = %s.",
qkv_dtype,
)
use_flash_attention = False
if use_fused_attention:
logger.debug(
"Disabling FusedAttention due to unsupported QKV data type. "
"Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
"Found: qkv_dtype = %s.",
qkv_dtype,
)
use_fused_attention = False
# Filter: Execution type
if fp8 and fp8_meta["recipe"].fp8_dpa:
if use_flash_attention and not FlashAttentionUtils.use_v3:
if FlashAttentionUtils.is_installed:
logger.debug("Disabling FlashAttention as FlashAttention 2 does not support FP8")
use_flash_attention = False
if use_flash_attention and FlashAttentionUtils.use_v3 and is_training:
logger.debug(
"Disabling FlashAttention as FlashAttention 3 does not support FP8 training"
)
use_flash_attention = False
if use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8")
use_unfused_attention = False
# Filter: Head dimension
if use_flash_attention and head_dim_qk != head_dim_v:
if FlashAttentionUtils.is_installed:
logger.debug("Disabling FlashAttention as it does not support MLA.")
use_flash_attention = False
if use_flash_attention and (
head_dim_qk > 256
or head_dim_qk % 8 != 0
or (
head_dim_qk > 192
and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0))
)
):
if FlashAttentionUtils.is_installed:
logger.debug(
"Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. "
"Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
"head_dim_qk <= 256 (>192 requires sm80/90/100+). "
"Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
head_dim_qk,
head_dim_v,
".".join([str(i) for i in device_compute_capability]),
)
use_flash_attention = False
qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
if use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd":
logger.debug(
"Disabling FusedAttention as MLA is not supported with qkv_layout = %s",
qkv_layout,
)
use_fused_attention = False
# Filter: QKV layout
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
if qkv_format == "thd":
if use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd")
use_unfused_attention = False
if use_flash_attention and pad_between_seqs:
if FlashAttentionUtils.is_installed:
logger.debug(
"Disabling FlashAttention for qkv_format = thd when there is "
"padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
)
use_flash_attention = False
# Filter: Dropout
if attention_dropout != 0.0 and use_flash_attention and FlashAttentionUtils.use_v3:
logger.debug("Disabling FlashAttention 3 for dropout")
FlashAttentionUtils.use_v3 = False
# Filter: Context parallelism
# qkv_format | attn_mask_type | attn_bias_type | supported backends
# ----------------------------------------------------------------------------------------------------
# bshd, sbhd | self-attention: | no_bias, post_scale_bias | FlashAttention, FusedAttention
# | no_mask, causal | |
# | cross-attention: | |
# | no_mask | |
# thd | self-attention: | no_bias | FlashAttention, FusedAttention
# | padding, padding_causal | | if no padding between sequences,
# | cross-attention: | | FusedAttention
# | padding | | if there is padding between sequences
# Note: context parallelism requires seq_len % (cp_size * 2) == 0 for each sequence in q, k, v.
if context_parallel and use_unfused_attention:
logger.debug(
"Disabling UnfusedDotProductAttention as it does not support context parallelism"
)
use_unfused_attention = False
if context_parallel and use_flash_attention:
if fp8 and fp8_meta["recipe"].fp8_dpa:
if FlashAttentionUtils.is_installed:
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with FP8"
)
use_flash_attention = False
if "bottom_right" in attn_mask_type:
if FlashAttentionUtils.is_installed:
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with"
" causal_bottom_right masking"
)
use_flash_attention = False
elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
if FlashAttentionUtils.is_installed:
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with"
" causal masking for cross-attention"
)
use_flash_attention = False
elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
if FlashAttentionUtils.is_installed:
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with bias"
" type of %s",
core_attention_bias_type,
)
use_flash_attention = False
elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
if FlashAttentionUtils.is_installed:
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with"
" attention bias for THD format"
)
use_flash_attention = False
if context_parallel and use_fused_attention:
if "bottom_right" in attn_mask_type:
logger.debug(
"Disabling FusedAttention as it does not support context parallelism with"
" causal_bottom_right masking"
)
use_fused_attention = False
elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
logger.debug(
"Disabling FusedAttention as it does not support context parallelism with causal"
" masking for cross-attention"
)
use_fused_attention = False
elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
logger.debug(
"Disabling FusedAttention as it does not support context parallelism with bias type"
" of %s",
core_attention_bias_type,
)
use_fused_attention = False
elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
logger.debug(
"Disabling FusedAttention as it does not support context parallelism with attention"
" bias for THD format"
)
use_fused_attention = False
elif head_dim_qk != head_dim_v:
logger.debug(
"Disabling FusedAttention as it does not support context parallelism with MLA"
)
use_fused_attention = False
# Filter: Attention mask
# attn_mask_type | attention_mask | supported backends
# ----------------------------------------------------------------------------------------
# no_mask | None | All
# padding | | All
# self-attention | One tensor in shape [b, 1, 1, sq] |
# cross-attention | Tuple of two tensors in shapes |
# | [b, 1, 1, sq] and [b, 1, 1, skv] |
# causal | None |
# self-attention | | All
# cross-attention | | FusedAttention, UnfusedDotProductAttention
# padding_causal | Same as "padding" |
# self-attention | | All
# cross-attention | | FusedAttention, UnfusedDotProductAttention
# causal_bottom_right | None | All
# padding_causal_bottom_right | Same as "padding" | All
# arbitrary | One tensor in shape broadcastable to | UnfusedDotProductAttention
# | [b, h, sq, skv] |
if attn_mask_type == "arbitrary":
if use_flash_attention and FlashAttentionUtils.is_installed:
logger.debug("Disabling FlashAttention for arbitrary mask")
use_flash_attention = False
if use_fused_attention:
logger.debug("Disabling FusedAttention for arbitrary mask")
use_fused_attention = False
if (
use_flash_attention
and FlashAttentionUtils.use_v3
and attn_mask_type in ["causal", "padding_causal"]
and max_seqlen_q != max_seqlen_kv
):
logger.warning(
"Disabling FlashAttention 3 as it only supports bottom-right-diagonal "
"causal mask since flash-attn 2.1. See "
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
)
FlashAttentionUtils.use_v3 = False
if (
use_flash_attention
and attn_mask_type in ["causal", "padding_causal"]
and max_seqlen_q != max_seqlen_kv
):
if FlashAttentionUtils.v2_1_plus:
logger.warning(
"Disabling FlashAttention as it only supports bottom-right-diagonal "
"causal mask since flash-attn 2.1. See "
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
)
use_flash_attention = False
if not FlashAttentionUtils.is_installed:
FlashAttentionUtils.max_version = PkgVersion("2.1")
if (
use_flash_attention
and attn_mask_type in ["causal_bottom_right", "padding_causal_bottom_right"]
and max_seqlen_q != max_seqlen_kv
):
if not FlashAttentionUtils.is_installed:
FlashAttentionUtils.version_required = PkgVersion("2.1")
elif not FlashAttentionUtils.v2_1_plus and not FlashAttentionUtils.use_v3:
logger.warning(
"Disabling FlashAttention as it only supports top-left-diagonal "
"causal mask before flash-attn 2.1. See "
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
)
use_flash_attention = False
if (
use_flash_attention
and FlashAttentionUtils.use_v3
and fp8
and fp8_meta["recipe"].fp8_dpa
and "padding" in attn_mask_type
):
logger.debug("Disabling FlashAttention 3 for FP8 and padding masks")
FlashAttentionUtils.use_v3 = False
# Filter: Sliding window attention
# backend | window_size | diagonal alignment
# ---------------------------------------------------------------------------------
# FlashAttention | (-1, -1) or (>=0, >=0) | bottom right
# FusedAttention | (-1, 0) or (>=0, 0) | top left
# UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both;
# | | converts window_size to an 'arbitrary' mask
if window_size is None:
window_size = check_set_window_size(attn_mask_type, window_size)
else:
if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha):
logger.debug(
"Disabling FusedAttention as it does not support sliding window attention"
" for FP8"
)
use_fused_attention = False
elif window_size[1] != 0 or attention_dropout != 0.0:
logger.debug(
"Disabling FusedAttention as it only supports sliding window attention "
"with (left, 0) and no dropout"
)
use_fused_attention = False
elif max_seqlen_q > max_seqlen_kv:
logger.debug(
"Disabling FusedAttention as it does not support sliding window attention "
"with s_q > s_kv for cross-attention"
)
use_fused_attention = False
if use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
if FlashAttentionUtils.use_v3:
logger.debug(
"Disabling FlashAttention 3 as it does not support sliding window attention"
)
FlashAttentionUtils.use_v3 = False
if not FlashAttentionUtils.is_installed:
FlashAttentionUtils.version_required = PkgVersion("2.3")
elif not FlashAttentionUtils.v2_3_plus:
logger.debug(
"Disabling FlashAttention as sliding window attention requires flash-attn 2.3+"
)
use_flash_attention = False
# Filter: Attention bias
# backend | bias types | ALiBi diagonal alignment
# ---------------------------------------------------------------------------------
# FlashAttention | no_bias, alibi/alibi_slopes | bottom right
# FusedAttention | no_bias, post_scale_bias |
# | alibi/alibi_slopes | top left,
# | | bottom_right (converts to a 'post_scale_bias' bias)
# UnfusedDotProductAttention | no_bias, pre/post_scale_bias |
# | alibi/alibi_slopes | both; converts to a 'post_scale_bias' bias
if use_flash_attention and core_attention_bias_type == "alibi":
if FlashAttentionUtils.use_v3:
logger.debug("Disabling FlashAttention 3 for ALiBi")
FlashAttentionUtils.use_v3 = False
if not FlashAttentionUtils.is_installed:
FlashAttentionUtils.version_required = PkgVersion("2.4")
elif not FlashAttentionUtils.v2_4_plus:
logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+")
use_flash_attention = False
if use_flash_attention and (
core_attention_bias_type not in ["no_bias", "alibi"]
or core_attention_bias_shape is not None
):
if FlashAttentionUtils.is_installed:
logger.debug("Disabling FlashAttention for pre/post_scale_bias")
use_flash_attention = False
fu_core_attention_bias_type = core_attention_bias_type
fu_core_attention_bias_shape = core_attention_bias_shape
fu_core_attention_bias_requires_grad = core_attention_bias_requires_grad
if (
use_fused_attention
and core_attention_bias_type == "alibi"
and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv)
):
fu_core_attention_bias_type = "post_scale_bias"
fu_core_attention_bias_requires_grad = False
if alibi_slopes_shape is None:
fu_core_attention_bias_shape = "1hss"
elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
fu_core_attention_bias_shape = "1hss"
elif (
len(alibi_slopes_shape) == 2
and alibi_slopes_shape[0] == batch_size
and alibi_slopes_shape[1] == num_heads
):
fu_core_attention_bias_shape = "bhss"
if (
use_fused_attention
and fu_core_attention_bias_type == "post_scale_bias"
and fu_core_attention_bias_shape != "1hss"
):
if fu_core_attention_bias_requires_grad:
# remove this line when cuDNN adds bwd support for
# [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape")
use_fused_attention = False
else:
# max512 backend will only support [1, h, s, s]
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
# Filter: cuDNN support
fused_attention_backend = None
if use_fused_attention:
q_type = TE_DType[qkv_dtype]
kv_type = q_type
if fp8 and fp8_meta["recipe"].fp8_dpa:
q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
kv_type = q_type
fused_attention_backend = tex.get_fused_attn_backend(
q_type,
kv_type,
QKVLayout[qkv_layout],
AttnBiasType[fu_core_attention_bias_type],
AttnMaskType[attn_mask_type],
attention_dropout,
num_heads,
num_gqa_groups,
max_seqlen_q,
max_seqlen_kv,
head_dim_qk,
head_dim_v,
window_size[0],
window_size[1],
)
if fused_attention_backend == FusedAttnBackend["No_Backend"]:
logger.debug("Disabling FusedAttention as no backend supports the provided input")
use_fused_attention = False
fused_attention_backend = None
if (
use_fused_attention
and window_size is not None
and window_size[0] != -1
and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"]
):
logger.debug(
"Disabling FusedAttention as only sub-backend %s does not support "
"slidng window attention",
int(fused_attention_backend),
)
use_fused_attention = False
fused_attention_backend = None
if (
use_fused_attention
and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
and fu_core_attention_bias_type == "post_scale_bias"
and fu_core_attention_bias_shape != "1hss"
):
logger.debug(
"Disabling FusedAttention as cuDNN sub-backend 0 only supports post_scale_bias in"
" [1, H, S, S] shape"
)
use_fused_attention = False
fused_attention_backend = None
# Filter: Determinism
# backend | deterministic
# ---------------------------------------------
# FlashAttention |
# flash-attn >=2.0, <2.4.1 | no
# flash-attn >=2.4.1 | yes
# FusedAttention |
# sub-backend 0 | yes
# sub-backend 1 | workspace optimization path and sm90+: yes;
# | otherwise: no
# sub-backend 2 | no
# UnfusedDotProductAttention | yes
if use_flash_attention and deterministic:
if not FlashAttentionUtils.is_installed:
FlashAttentionUtils.version_required = PkgVersion("2.4.1")
elif not FlashAttentionUtils.v2_4_1_plus and not FlashAttentionUtils.use_v3:
logger.warning(
"Disabling FlashAttention as version <2.4.1 does not support deterministic "
"execution. To use FlashAttention with deterministic behavior, "
"please install flash-attn >= 2.4.1."
)
use_flash_attention = False
if use_fused_attention and deterministic:
if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
logger.debug("Disabling FusedAttention for determinism reasons")
use_fused_attention = False
if (
fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
and is_training
and (
device_compute_capability < (9, 0)
or core_attention_bias_requires_grad
or cudnn_version < (8, 9, 5)
)
):
logger.debug("Disabling FusedAttention for determinism reasons")
use_fused_attention = False
# All available backends
available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention]
# `FusedAttention` and `FlashAttention` are faster backends than `UnfusedDotProductAttention`.
# When `FusedAttention` does not support the provided attention params, and `FlashAttention`
# does, we recommend users to install flash-attn if not installed already.
if not use_fused_attention and use_flash_attention and not FlashAttentionUtils.is_installed:
logger.warning(
"flash-attn may provide important feature support or performance improvement."
" Please install flash-attn %s.",
_get_supported_versions(
FlashAttentionUtils.version_required,
FlashAttentionUtils.max_version,
),
)
if use_flash_attention and not FlashAttentionUtils.is_installed:
use_flash_attention = False
available_backends[0] = False
logger.debug(
"Available backends = {FlashAttention=%s, FusedAttention=%s%s,"
" UnfusedDotProductAttention=%s}",
bool(available_backends[0]),
bool(available_backends[1]),
(
f" (sub-backend {int(fused_attention_backend)})"
if fused_attention_backend is not None
else ""
),
bool(available_backends[2]),
)
# Select FusedAttention for performance
if (
use_flash_attention
and use_fused_attention
and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
):
if device_compute_capability >= (9, 0):
logger.debug(
"Disabling FlashAttention to give FusedAttention preference on Hopper+ "
"for performance reasons"
)
use_flash_attention = False
if (
use_flash_attention
and use_fused_attention
and fused_attention_backend == FusedAttnBackend["FP8"]
and FlashAttentionUtils.use_v3
):
logger.debug(
"Disabling FlashAttention 3 to give FusedAttention preference for performance reasons "
"in FP8 execution"
)
use_flash_attention = False
# Selected backend
if use_flash_attention:
use_fused_attention = False
use_unfused_attention = False
elif use_fused_attention:
use_unfused_attention = False
selected_backend = "NoBackend"
if use_flash_attention:
selected_backend = "FlashAttention"
elif use_fused_attention:
selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})"
elif use_unfused_attention:
selected_backend = "UnfusedDotProductAttention"
logger.debug("Selected backend = %s", selected_backend)
"""global _attention_backends
_attention_backends["use_flash_attention"] = use_flash_attention
_attention_backends["use_fused_attention"] = use_fused_attention
_attention_backends["fused_attention_backend"] = fused_attention_backend
_attention_backends["use_unfused_attention"] = use_unfused_attention
_attention_backends["backend_selection_requires_update"] = False"""
return (
use_flash_attention,
use_fused_attention,
fused_attention_backend,
use_unfused_attention,
available_backends,
)
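# Illustrative sketch, not part of the original file: per the NOTE at the top
# of get_attention_backend(), the caller now owns the backend cache. The
# `_attention_backends` dict below is the hypothetical caller-side cache; the
# real one lives in attention.py (see the diff above).
def _example_select_and_cache_backend(attention_params, _attention_backends):
    (
        use_flash_attention,
        use_fused_attention,
        fused_attention_backend,
        use_unfused_attention,
        _available_backends,
    ) = get_attention_backend(attention_params)
    _attention_backends["use_flash_attention"] = use_flash_attention
    _attention_backends["use_fused_attention"] = use_fused_attention
    _attention_backends["fused_attention_backend"] = fused_attention_backend
    _attention_backends["use_unfused_attention"] = use_unfused_attention
    _attention_backends["backend_selection_requires_update"] = False
    return _attention_backends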
@torch.no_grad()
def get_full_mask(
max_seqlen_q: int,
max_seqlen_kv: int,
attn_mask_type: str = "no_mask",
attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None,
window_size: Tuple[int, int] = None,
attention_type: str = "self",
bottom_right_alignment: bool = True,
) -> torch.Tensor:
"""
Get full attention mask in [..., max_seqlen_q, max_seqlen_kv] shape, based on `attn_mask_type`,
`attention_mask`, and `window_size`. For sliding window attention, the diagonal alignment depends
on both `attn_mask_type` and `bottom_right_alignment`, as detailed below.::
attn_mask_type output shape diagonal alignment
--------------------------------------------------------------------------------------------
no_mask [1, 1, max_seqlen_q, max_seqlen_kv] follow bottom_right_alignment
causal [1, 1, max_seqlen_q, max_seqlen_kv] always top left
causal_bottom_right [1, 1, max_seqlen_q, max_seqlen_kv] always bottom right
padding [batch_size, 1, max_seqlen_q, max_seqlen_kv] follow bottom_right_alignment
padding_causal [batch_size, 1, max_seqlen_q, max_seqlen_kv] always top left
padding_causal_bottom_right [batch_size, 1, max_seqlen_q, max_seqlen_kv] always bottom right
arbitrary same as attention_mask follow bottom_right_alignment
.. note::
For "padding_bottom_right" mask, or "padding" mask with `bottom_right_alignment` = True, the bottom right
diagonal comes from the bottom right corner of the [actual_seqlens_q[i], actual_seqlens_kv[i]] matrix,
i = 0,...,batch_size-1, not the [max_seqlen_q, max_seqlen_kv] matrix. For example, with max_seqlen_q = 4,
max_seqlen_kv = 4, attn_mask_type = "padding", attention_type = "cross", and attention_mask = (
[[False, False, True, True], [False, False, False, False]],
[[False, False, False, True], [False, True, True, True]]), the returned full attention mask has [2, 4, 4]
shape and is::
[[[False, False, False, True],
[False, False, False, True],
[ True, True, True, True],
[ True, True, True, True]],
[[False, True, True, True],
[False, True, True, True],
[False, True, True, True],
[False, True, True, True]]]
Parameters
----------
max_seqlen_q: int
Maximum sequence length for queries.
max_seqlen_kv: int
Maximum sequence length for keys and values.
attn_mask_type: str, default = `no_mask`
Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`",
"`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"}
attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
default = `None`
Boolean tensor(s) used to mask out attention softmax input. Please see DotProductAttention
for the requirements of `attention_mask` for different `attn_mask_type`s.
window_size: Tuple[int, int], default = `None`
Sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
`attn_mask_type`.
attention_type: str, default = "self"
Attention type, {"self", "cross"}
bottom_right_alignment: bool, default = `True`
Whether to align the diagonal of the sliding window attention to the bottom right (`True`)
or top left (`False`) corner of the softmax matrix. Ignored if `attn_mask_type` explicitly
specifies "causal" or "causal_bottom_right".
Returns
----------
attn_mask_type: str
For sliding window attention (>=0, >0), "arbitrary"; otherwise, the same as input `attn_mask_type`
attention_mask: torch.Tensor
The full attention mask based on `attn_mask_type`, `attention_mask` and `window_size`
actual_seqlens_q: torch.Tensor
For padding masks, the actual sequence lengths for queries, in shape [batch_size].
For other masks, `None`.
actual_seqlens_kv: Optional[torch.Tensor], default = `None`
For padding masks, the actual sequence lengths for keys and values, in shape [batch_size].
For other masks, `None`.
"""
# perform basic checks
change_type = window_size is not None and (
window_size[0] != -1 or window_size[1] not in [-1, 0]
)
if window_size is None:
window_size = (-1, -1)
if "causal" in attn_mask_type:
window_size = (window_size[0], 0)
window_size = (
max_seqlen_kv if window_size[0] == -1 else window_size[0],
max_seqlen_q if window_size[1] == -1 else window_size[1],
)
# apply padding mask
actual_seqlens_q = None
actual_seqlens_kv = None
if "padding" in attn_mask_type:
if attention_type == "self":
attention_mask = torch.logical_or(
attention_mask.squeeze(1).unsqueeze(3), attention_mask
)
else:
attention_mask = torch.logical_or(
attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1]
)
m = attention_mask.logical_not()
actual_seqlens_q = m[:, 0, :, 0].sum(dim=1)
actual_seqlens_kv = m[:, 0, 0, :].sum(dim=1)
# apply SWA mask
mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
1, 1, max_seqlen_q, 1
) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv)
swa_left = None
swa_right = None
if attn_mask_type == "causal_bottom_right" or (
attn_mask_type in ["no_mask", "arbitrary"] and bottom_right_alignment
):
swa_left = mask + max_seqlen_kv - max_seqlen_q - window_size[0]
swa_right = mask + max_seqlen_kv - max_seqlen_q + window_size[1]
elif attn_mask_type in ["causal", "padding_causal"] or (
attn_mask_type in ["no_mask", "padding", "arbitrary"] and not bottom_right_alignment
):
swa_left = mask - window_size[0]
swa_right = mask + window_size[1]
elif attn_mask_type == "padding_causal_bottom_right" or (
attn_mask_type == "padding" and bottom_right_alignment
):
batch_size = attention_mask.shape[0]
swa_left = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
actual_seqlens_kv - actual_seqlens_q - window_size[0]
).view(batch_size, 1, 1, 1)
swa_right = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
actual_seqlens_kv - actual_seqlens_q + window_size[1]
).view(batch_size, 1, 1, 1)
swa_mask = torch.logical_not(
torch.where(swa_left <= 0, 1, 0) - torch.where(swa_right < 0, 1, 0)
)
if attention_mask is not None:
attention_mask = torch.logical_or(swa_mask, attention_mask)
else:
attention_mask = swa_mask
# change mask type
if change_type:
attn_mask_type = "arbitrary"
return attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv
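# Illustrative usage sketch, not part of the original file: a causal mask with
# a (2, 0) sliding window. Because the window modifies the mask, the returned
# mask type becomes "arbitrary". A CUDA device is assumed since the helper
# builds its index tensors on "cuda".
def _example_get_full_mask():
    mask_type, mask, seqlens_q, seqlens_kv = get_full_mask(
        max_seqlen_q=4,
        max_seqlen_kv=4,
        attn_mask_type="causal",
        window_size=(2, 0),
    )
    # mask_type == "arbitrary"; mask has shape [1, 1, 4, 4]; seqlens are None.
    return mask_type, mask, seqlens_q, seqlens_kv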
@torch.no_grad()
def get_alibi(
_alibi_cache: Dict[str, Any],
num_heads: int,
max_seqlen_q: int,
max_seqlen_kv: int,
actual_seqlens_q: Optional[torch.Tensor] = None,
actual_seqlens_kv: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
bias_dtype: Optional[torch.dtype] = None,
bottom_right_alignment: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Parameters
----------
num_heads: int
Number of heads.
max_seqlen_q: int
Maximum sequence length for queries.
max_seqlen_kv: int
Maximum sequence length for keys and values.
actual_seqlens_q: Optional[torch.Tensor], default = `None`
Actual sequence lengths for queries, in shape [batch_size].
actual_seqlens_kv: Optional[torch.Tensor], default = `None`
Actual sequence lengths for keys and values, in shape [batch_size].
alibi_slopes: Optional[torch.Tensor], default = `None`
Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads].
bias_dtype: Optional[torch.dtype], default = `None`
Dtype of the generated ALiBi bias. If None, use torch.float32.
bottom_right_alignment: bool, default = `True`
Whether to align the diagonal of the ALiBi bias to the bottom right corner of
the matrix (`True`) or top left (`False`).
Returns
----------
alibi_slopes: torch.Tensor
ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads].
alibi_bias: torch.Tensor
ALiBi bias in FP32 or `bias_dtype`. Its shape is
(1) [1, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in [num_heads] shape,
and `actual_seqlens_q` and `actual_seqlens_kv` are `None`; or
(2) [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in
[batch_size, num_heads] shape, or, if `alibi_slopes` is in [num_heads] shape and
`actual_seqlens_q` and `actual_seqlens_kv` are not `None`.
"""
# NOTE: As part of refactoring attention.py, get_alibi() now receives the alibi cache from the caller
# as an additional input arg
if _alibi_cache["_alibi_slopes_require_update"]:
if alibi_slopes is not None:
_alibi_cache["_alibi_slopes"] = alibi_slopes
else:
n = 2 ** math.floor(math.log2(num_heads))
m_0 = 2.0 ** (-8.0 / n)
m = torch.pow(m_0, torch.arange(1, 1 + n))
if n < num_heads:
m_hat_0 = 2.0 ** (-4.0 / n)
m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
m = torch.cat([m, m_hat])
_alibi_cache["_alibi_slopes"] = m.to(dtype=torch.float32, device="cuda")
_alibi_cache["_num_heads"] = num_heads
_alibi_cache["_alibi_slopes_require_update"] = False
if _alibi_cache["_alibi_bias_require_update"]:
assert _alibi_cache["_alibi_slopes"] is not None, "ALiBi slopes can not be None!"
if _alibi_cache["_alibi_slopes"].dim() == 1:
slopes_shape = torch.Size([1, _alibi_cache["_alibi_slopes"].shape[0], 1, 1])
elif _alibi_cache["_alibi_slopes"].dim() == 2:
slopes_shape = torch.Size([*_alibi_cache["_alibi_slopes"].shape[:], 1, 1])
else:
raise ValueError("ALiBi slopes cannot exceed 2 dimensions.")
bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
1, 1, max_seqlen_q, 1
) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(
1, 1, 1, max_seqlen_kv
)
if actual_seqlens_q is None and actual_seqlens_kv is None:
if bottom_right_alignment:
bias = bias + max_seqlen_kv - max_seqlen_q
elif actual_seqlens_q is not None and actual_seqlens_kv is not None:
batch_size = actual_seqlens_q.shape[0]
bias = bias.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv)
if bottom_right_alignment:
bias = bias + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1)
else:
assert (
False
), "actual_seqlens_q and actual_seqlens_kv need to be both None or torch.Tensors!"
bias = bias.abs().mul(-1)
bias = bias * _alibi_cache["_alibi_slopes"].view(slopes_shape)
_alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv
_alibi_cache["_bottom_right_alignment"] = bottom_right_alignment
bias_dtype = torch.float32 if bias_dtype is None else bias_dtype
_alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda")
_alibi_cache["_alibi_bias_require_update"] = False
return _alibi_cache["_alibi_slopes"], _alibi_cache["_alibi_bias"]
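# Illustrative usage sketch, not part of the original file: per the NOTE above,
# the caller owns the ALiBi cache and passes it in. The cache keys below are
# the ones this function reads; a CUDA device is assumed since the slopes and
# bias are created on "cuda".
def _example_get_alibi():
    alibi_cache = {
        "_alibi_slopes": None,
        "_alibi_slopes_require_update": True,
        "_alibi_bias": None,
        "_alibi_bias_require_update": True,
    }
    slopes, bias = get_alibi(alibi_cache, num_heads=8, max_seqlen_q=128, max_seqlen_kv=128)
    # slopes: [8] FP32 tensor; bias: [1, 8, 128, 128] tensor.
    return slopes, bias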
def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
"""
Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
the samples in a batch.
"""
mask = mask.squeeze(1).squeeze(1)
reduced_mask = mask.logical_not().sum(dim=1)
cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
zero = torch.zeros(1, dtype=torch.int32, device="cuda")
cu_seqlens = torch.cat((zero, cu_seqlens))
return cu_seqlens
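# Illustrative usage sketch, not part of the original file: in the padding mask
# True marks padded positions. Two sequences of lengths 3 and 2 give cumulative
# lengths [0, 3, 5]. A CUDA device is assumed because the helper concatenates
# with a CUDA zero tensor.
def _example_get_cu_seqlens():
    mask = torch.tensor(
        [[False, False, False, True], [False, False, True, True]], device="cuda"
    ).view(2, 1, 1, 4)
    return get_cu_seqlens(mask)  # tensor([0, 3, 5], dtype=torch.int32)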
def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
the samples in a batch, and another int32 tensor of shape [batch_size * max_seqlen, 1, 1]
containing the indices for the valid tokens.
"""
mask = mask.squeeze(1).squeeze(1)
bs, seqlen = mask.shape
reduced_mask = mask.logical_not().sum(dim=1)
cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
zero = torch.zeros(1, dtype=torch.int32, device="cuda")
cu_seqlens = torch.cat((zero, cu_seqlens))
mask = mask.reshape(-1)
indices = mask.logical_not().nonzero()
indices = indices.unsqueeze(-1)
num_nonzeros = indices.shape[0]
pad_amount = bs * seqlen - num_nonzeros
indices = F.pad(
input=indices, pad=(0, 0, 0, 0, 0, pad_amount), mode="constant", value=float(bs * seqlen)
)
return cu_seqlens, indices
def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
"""
Given max_seqlen and cu_seqlens of shape [batch_size + 1], returns an int32
tensor of shape [batch_size * max_seqlen, 1, 1] containing the indices for
the valid tokens in a batch.
"""
bs = len(cu_seqlens) - 1
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
indices = [i * max_seqlen + ii for i, j in enumerate(seqlens) for ii in range(j)]
indices = torch.Tensor(indices).unsqueeze(1).unsqueeze(1).to(dtype=torch.int64, device="cuda")
num_nonzeros = indices.shape[0]
pad_amount = bs * max_seqlen - num_nonzeros
indices = F.pad(
input=indices,
pad=(0, 0, 0, 0, 0, pad_amount),
mode="constant",
value=float(bs * max_seqlen),
)
return indices
_cu_seqlens_cache = {}
def get_full_cu_seqlens(
batch_size: int,
max_seqlen: int,
device: torch.device,
) -> torch.Tensor:
"""Cumulative sequence lengths in full data batch
All sequences in batch have the maximum sequence length.
"""
global _cu_seqlens_cache
if (batch_size, max_seqlen) not in _cu_seqlens_cache:
_cu_seqlens_cache[(batch_size, max_seqlen)] = torch.arange(
0,
(batch_size + 1) * max_seqlen,
step=max_seqlen,
dtype=torch.int32,
device=device,
)
return _cu_seqlens_cache[(batch_size, max_seqlen)]
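# Illustrative usage sketch, not part of the original file: with no padding,
# every sequence has the maximum length, so the cached cumulative lengths are
# evenly spaced.
def _example_get_full_cu_seqlens():
    cu_seqlens = get_full_cu_seqlens(batch_size=3, max_seqlen=4, device=torch.device("cpu"))
    return cu_seqlens  # tensor([0, 4, 8, 12], dtype=torch.int32)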
@jit_fuser
def _pack_tensor(
indices: torch.Tensor,
tensor: torch.Tensor,
) -> torch.Tensor:
"""
Packs the given tensor using the `indices`.
"""
padding_indice = torch.zeros(
1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device
)
indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
if isinstance(tensor, Float8Tensor):
tensor_data = torch.cat((tensor._data, padding_indice), dim=0)
gathered_data = torch.gather(tensor_data, 0, indices)
packed = Float8Tensor.make_like(tensor, data=gathered_data, shape=gathered_data.shape)
else:
tensor = torch.cat((tensor, padding_indice), dim=0)
packed = torch.gather(tensor, 0, indices)
return packed
@jit_fuser
def _pack_2_tensors(
indices: torch.Tensor,
t1: torch.Tensor,
t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Packs the given 2 tensors using the `indices`.
"""
t1_packed = _pack_tensor(indices, t1)
t2_packed = _pack_tensor(indices, t2)
return t1_packed, t2_packed
@jit_fuser
def _pack_3_tensors(
indices: torch.Tensor,
t1: torch.Tensor,
t2: torch.Tensor,
t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Packs the given 3 tensors using the `indices`.
"""
t1_packed = _pack_tensor(indices, t1)
t2_packed = _pack_tensor(indices, t2)
t3_packed = _pack_tensor(indices, t3)
return t1_packed, t2_packed, t3_packed
@jit_fuser
def _unpack_tensor(
indices: torch.Tensor,
dim0: int,
tensor: torch.Tensor,
) -> torch.Tensor:
"""
Inverse of `_pack_tensor`.
"""
indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
unpacked = torch.zeros(
dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device
)
if isinstance(tensor, Float8Tensor):
unpacked.scatter_(0, indices, tensor._data)
unpacked_data = unpacked[0:-1, :, :]
unpacked = Float8Tensor.make_like(tensor, data=unpacked_data, shape=unpacked_data.shape)
else:
unpacked.scatter_(0, indices, tensor)
unpacked = unpacked[0:-1, :, :]
return unpacked
@jit_fuser
def _unpack_2_tensors(
indices: torch.Tensor,
dim0: int,
t1: torch.Tensor,
t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Inverse of `_pack_2_tensors`.
"""
t1_unpacked = _unpack_tensor(indices, dim0, t1)
t2_unpacked = _unpack_tensor(indices, dim0, t2)
return t1_unpacked, t2_unpacked
@jit_fuser
def _unpack_3_tensors(
indices: torch.Tensor,
dim0: int,
t1: torch.Tensor,
t2: torch.Tensor,
t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Inverse of `_pack_3_tensors`.
"""
t1_unpacked = _unpack_tensor(indices, dim0, t1)
t2_unpacked = _unpack_tensor(indices, dim0, t2)
t3_unpacked = _unpack_tensor(indices, dim0, t3)
return t1_unpacked, t2_unpacked, t3_unpacked
class PackTensors(torch.autograd.Function):
"""
Autograd function to pack a tensor.
"""
@staticmethod
def forward(
ctx, indices: torch.Tensor, *tensors: Tuple[torch.Tensor, ...]
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring
assert 1 <= len(tensors) <= 3, f"Packing {len(tensors)} tensors not supported."
ctx.save_for_backward(indices)
ctx.dim0 = tensors[0].shape[0]
if len(tensors) == 1:
return _pack_tensor(indices, *tensors)
if len(tensors) == 2:
return _pack_2_tensors(indices, *tensors)
return _pack_3_tensors(indices, *tensors)
@staticmethod
def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]):
# pylint: disable=missing-function-docstring
(indices,) = ctx.saved_tensors
if len(grad_outputs) == 1:
return None, _unpack_tensor(indices, ctx.dim0, *grad_outputs)
if len(grad_outputs) == 2:
return None, *_unpack_2_tensors(indices, ctx.dim0, *grad_outputs)
return None, *_unpack_3_tensors(indices, ctx.dim0, *grad_outputs)
class UnpackTensor(torch.autograd.Function):
"""
Autograd function to unpack a tensor.
"""
@staticmethod
def forward(
ctx,
indices: torch.Tensor,
dim0: int,
tensor: torch.Tensor,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
ctx.save_for_backward(indices)
return _unpack_tensor(indices, dim0, tensor)
@staticmethod
def backward(ctx, grad_output):
# pylint: disable=missing-function-docstring
(indices,) = ctx.saved_tensors
return None, None, _pack_tensor(indices, grad_output)
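# Illustrative usage sketch (not part of the module): packing removes padded tokens
# from a `thd`-style tensor before attention, and unpacking restores the original
# shape afterwards. The shapes and index values here are made up; in practice the
# indices identify the valid (non-padding) token positions.
#
#   >>> q = torch.randn(6, 2, 8, requires_grad=True)        # [total_tokens, h, d]
#   >>> idx = torch.tensor([0, 1, 3, 4]).reshape(-1, 1, 1)  # valid token positions
#   >>> q_packed = PackTensors.apply(idx, q)                 # shape [4, 2, 8]
#   >>> q_restored = UnpackTensor.apply(idx, 6, q_packed)    # shape [6, 2, 8], zeros at rows 2 and 5
#   >>> q_packed.sum().backward()                            # gradients flow back through the gather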
def get_qkv_layout(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
qkv_format: str = "sbhd",
) -> str:
"""Get qkv layout.
Parameters
----------
q: torch.Tensor
Query tensor.
k: torch.Tensor
Key tensor.
v: torch.Tensor
Value tensor.
qkv_format: str, default = `sbhd`
Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for
the sequence length dimension, `b` batch size, `h` the number of attention heads,
`d` head size, and `t` the total number of tokens in a batch, i.e.
`t = sum(s_i) for i = 0...b-1`.
Returns
----------
qkv_layout: str
Memory layout of `q`, `k` and `v`. Each `qkv_format` can be mapped to one of five
memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk
of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means
`q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v`
are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and
`v = kv[:,:,:,1,:]`.
Mapping:
`sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`}
`bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`}
`thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`}
q: torch.Tensor
Query tensor. It may be different from input `q` as we try to fit tensors to
a supported layout.
k: torch.Tensor
Key tensor. It may be different from input `k` as we try to fit tensors to
a supported layout.
v: torch.Tensor
Value tensor. It may be different from input `v` as we try to fit tensors to
a supported layout.
"""
check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v])
assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!"
def run_iteratively(q, k, v):
# check data pointers
data_ptr = q.untyped_storage().data_ptr()
check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
check_ptrs_qk = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k])
data_ptr = k.untyped_storage().data_ptr()
check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])
# check tensor shapes
shape = q.shape
check_shapes_qkv = all(shape == x.shape for x in [q, k, v])
shape = k.shape
check_shapes_kv = shape[:-1] == v.shape[:-1]
# check tensor strides
stride = q.stride()
check_strides_qkv = all(stride == x.stride() for x in [q, k, v])
check_strides_kv = tuple(sk / k.shape[-1] for sk in k.stride()[:-1]) == tuple(
sv / v.shape[-1] for sv in v.stride()[:-1]
)
# check tensor offsets for h3d and 3hd layouts
prod_h_d = q.shape[-1] * q.shape[-2]
check_3hd_offsets = all(x.storage_offset() == i * prod_h_d for i, x in enumerate([q, k, v]))
check_h3d_offsets = all(
x.storage_offset() == i * q.shape[-1] for i, x in enumerate([q, k, v])
)
# check tensor offsets for hd_h2d and hd_2hd layouts
prod_all_dims = [np.prod(x.shape) for x in [q, k]]
offset = prod_all_dims[0] if check_ptrs_qkv else 0
prod_h_d = k.shape[-1] * k.shape[-2]
check_2hd_offsets = all(
x.storage_offset() == (offset + i * prod_h_d) for i, x in enumerate([k, v])
)
check_h2d_offsets = all(
x.storage_offset() == (offset + i * k.shape[-1]) for i, x in enumerate([k, v])
)
# check tensor offsets for hd_hd_hd layouts
check_hd_offsets_qkv = (
all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k, v]))
if check_ptrs_qkv
else all(x.storage_offset() == 0 for i, x in enumerate([q, k, v]))
)
check_hd_offsets_qk = (
all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k]))
if not check_ptrs_qkv and check_ptrs_qk
else all(x.storage_offset() == 0 for i, x in enumerate([q, k]))
)
check_hd_offsets_kv = (
all(x.storage_offset() == sum(prod_all_dims[1 : i + 1]) for i, x in enumerate([k, v]))
if not check_ptrs_qkv and check_ptrs_kv
else all(x.storage_offset() == 0 for i, x in enumerate([k, v]))
)
if check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_3hd_offsets:
# sb3hd, bs3hd, t3hd
# one chunk of memory, qkv, with q, k, v interleaved at dim=-3 in qkv
qkv_layout = qkv_format[:-2] + "3" + qkv_format[-2:]
elif check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_h3d_offsets:
# sbh3d, bsh3d, th3d
# one chunk of memory, qkv, with q, k, v interleaved at dim=-2 in qkv
qkv_layout = qkv_format[:-1] + "3" + qkv_format[-1:]
elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_2hd_offsets:
# sbhd_sb2hd, bshd_bs2hd, thd_t2hd
# two chunks of memory, q and kv, with k, v interleaved at dim=-3 in kv
# q and kv may be disjoint or consecutive in memory, and when consecutive, they may
# have the same data pointer, i.e. check_ptrs_qkv=True
qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_h2d_offsets:
# sbhd_sbh2d, bshd_bsh2d, thd_th2d
# two chunks of memory, q and kv, with k, v interleaved at dim=-2 in kv
# q and kv may be disjoint or consecutive in memory, and when consecutive, they may
# have the same data pointer, i.e. check_ptrs_qkv=True
qkv_layout = qkv_format + "_" + qkv_format[:-1] + "2" + qkv_format[-1:]
elif (
check_strides_kv
and check_shapes_kv
and (check_hd_offsets_qkv or check_hd_offsets_kv or check_hd_offsets_qk)
):
# sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd
# three chunks of memory, q, k and v, which may be disjoint or consecutive, and
# when consecutive, they may have the same data pointer, i.e. check_ptrs_qkv=True or
# check_ptrs_qk=True or check_ptrs_kv=True
qkv_layout = "_".join(list([qkv_format]) * 3)
else:
qkv_layout = "not_supported"
return qkv_layout
qkv_layout = run_iteratively(q, k, v)
if qkv_layout == "not_supported":
# force q, k, v to be contiguous and re-run the layout detection
q, k, v = [x.contiguous() for x in [q, k, v]]
qkv_layout = run_iteratively(q, k, v)
if qkv_layout == "not_supported":
raise RuntimeError("The provided qkv memory layout is not supported!")
return qkv_layout, q, k, v
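# Illustrative sketch (not part of the module): detecting an interleaved `sb3hd`
# layout. The shapes are made up; `q`, `k` and `v` below are views into a single
# qkv chunk interleaved at dim=-3, so `get_qkv_layout` can report the fused layout
# without forcing copies.
#
#   >>> qkv = torch.randn(32, 2, 3, 16, 64)                  # [s, b, 3, h, d]
#   >>> q, k, v = [x.squeeze(2) for x in qkv.split(1, dim=2)]
#   >>> get_qkv_layout(q, k, v, qkv_format="sbhd")[0]
#   'sb3hd'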
def check_set_window_size(
attn_mask_type: str,
window_size: Tuple[int, int] = None,
):
"""Check if sliding window size is compliant with attention mask type.
If not, set it to the appropriate size.
attn_mask_type                                   |   window_size
-------------------------------------------------------------------------
no_mask, padding, arbitrary                      | (-1, -1) or (>=0, >=0)
causal, padding_causal                           | (-1, 0) or (>=0, 0)
causal_bottom_right, padding_causal_bottom_right | (-1, 0) or (>=0, 0)
"""
orig_window_size = window_size
if "causal" in attn_mask_type:
if orig_window_size is None:
window_size = (-1, 0)
elif orig_window_size == (-1, -1) or (
orig_window_size[0] >= 0 and orig_window_size[1] != 0
):
window_size = (orig_window_size[0], 0)
warnings.warn(
"window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
)
elif orig_window_size != (-1, 0) and (orig_window_size[0] < 0 or orig_window_size[1] != 0):
assert False, (
"window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
)
elif attn_mask_type in ["no_mask", "padding", "arbitrary"]:
if orig_window_size is None:
window_size = (-1, -1)
elif orig_window_size == (-1, 0):
window_size = (-1, -1)
warnings.warn(
"window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
)
elif orig_window_size != (-1, -1) and (orig_window_size[0] < 0 or orig_window_size[1] < 0):
assert False, (
"window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
)
else:
assert False, "Invalid attn_mask_type: " + attn_mask_type
return window_size
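# Illustrative sketch (not part of the module): how the defaults and corrections apply.
#
#   >>> check_set_window_size("causal", None)         # default for causal masks
#   (-1, 0)
#   >>> check_set_window_size("no_mask", (-1, 0))     # warns and widens to a full window
#   (-1, -1)
#   >>> check_set_window_size("padding", (128, 128))  # already compliant, returned unchanged
#   (128, 128)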
def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False):
"""Get the list of quantizers used in attention from the quantizers list."""
if not fp8:
num_of_nones = 8 if cp_specific_quantizers else 6
return [None] * num_of_nones
QKV_quantizer = quantizers["scaling_fwd"][META_QKV]
QKV_quantizer.internal = True
QKV_quantizer.set_usage(rowwise=True, columnwise=False)
O_quantizer = quantizers["scaling_fwd"][META_O]
O_quantizer.set_usage(rowwise=True, columnwise=False)
S_quantizer = quantizers["scaling_fwd"][META_S]
S_quantizer.internal = True
S_quantizer.set_usage(rowwise=True, columnwise=False)
dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV]
dQKV_quantizer.internal = True
dQKV_quantizer.set_usage(rowwise=True, columnwise=False)
dO_quantizer = quantizers["scaling_bwd"][META_DO]
dO_quantizer.set_usage(rowwise=True, columnwise=False)
dO_quantizer.internal = True
dP_quantizer = quantizers["scaling_bwd"][META_DP]
dP_quantizer.set_usage(rowwise=True, columnwise=False)
dP_quantizer.internal = True
dQKV_CP_quantizer = quantizers["scaling_bwd"][META_DQKV_CP]
dQKV_CP_quantizer.set_usage(rowwise=True, columnwise=False)
dQKV_CP_quantizer.internal = True
O_CP_quantizer = quantizers["scaling_fwd"][META_O_CP]
O_CP_quantizer.set_usage(rowwise=True, columnwise=False)
if cp_specific_quantizers:
return (
QKV_quantizer,
O_quantizer,
O_CP_quantizer,
S_quantizer,
dQKV_quantizer,
dQKV_CP_quantizer,
dO_quantizer,
dP_quantizer,
)
return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer
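# Illustrative sketch (not part of the module): without FP8 the helper simply returns
# `None` placeholders, so callers can unpack the same number of quantizers either way.
# With fp8=True it expects `quantizers` to be a mapping with "scaling_fwd"/"scaling_bwd"
# entries indexed by the META_* constants.
#
#   >>> get_attention_quantizers(fp8=False, quantizers=None)
#   [None, None, None, None, None, None]
#   >>> len(get_attention_quantizers(fp8=False, quantizers=None, cp_specific_quantizers=True))
#   8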
@@ -12,10 +12,10 @@ import torch
 from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
 from transformer_engine.pytorch.attention import (
-    InferenceParams,
     MultiheadAttention,
-    check_set_window_size,
 )
+from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
+from transformer_engine.pytorch.dot_product_attention.utils import check_set_window_size
 from transformer_engine.pytorch.jit import (
     set_jit_fusion_options,
     warmup_jit_bias_dropout_add_all_dtypes,