Unverified Commit 250f5cb5 authored by Mikko Lauri, committed by GitHub

Add AITER attention backend (#12549)



* add aiter attention backend

* Apply style fixes

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent dc6bd151
@@ -21,6 +21,7 @@ Refer to the table below for an overview of the available attention families and
| attention family | main feature |
|---|---|
| FlashAttention | minimizes memory reads/writes through tiling and recomputation |
| AI Tensor Engine for ROCm | FlashAttention implementation optimized for AMD ROCm accelerators |
| SageAttention | quantizes attention to int8 |
| PyTorch native | built-in PyTorch implementation using [scaled_dot_product_attention](./fp16#scaled-dot-product-attention) |
| xFormers | memory-efficient attention with support for various attention kernels |
@@ -139,6 +140,7 @@ Refer to the table below for a complete list of available attention backends and
| `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
| `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
...
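The new `aiter` key plugs into the existing backend-selection API described on this docs page. Below is a minimal, hedged sketch of enabling it; it assumes a ROCm GPU (exposed to PyTorch as `"cuda"`, matching the `_check_device_cuda` constraint added later in this diff), `aiter>=0.1.5`, and the `set_attention_backend` helper documented elsewhere in this guide. The Flux checkpoint is only an illustrative choice.

```python
import torch
from diffusers import FluxPipeline

# Illustrative model; any diffusers model exposing set_attention_backend should work.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")  # ROCm devices are addressed as "cuda" in PyTorch

# "aiter" is the backend key added to the table above.
pipe.transformer.set_attention_backend("aiter")

image = pipe("a photo of a cat", num_inference_steps=28).images[0]
image.save("cat.png")
```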
@@ -27,6 +27,8 @@ if torch.distributed.is_available():
from ..utils import (
    get_logger,
    is_aiter_available,
    is_aiter_version,
    is_flash_attn_3_available,
    is_flash_attn_available,
    is_flash_attn_version,
@@ -47,6 +49,7 @@ if TYPE_CHECKING:
    from ._modeling_parallel import ParallelConfig

_REQUIRED_FLASH_VERSION = "2.6.3"
_REQUIRED_AITER_VERSION = "0.1.5"
_REQUIRED_SAGE_VERSION = "2.1.1"
_REQUIRED_FLEX_VERSION = "2.5.0"
_REQUIRED_XLA_VERSION = "2.2"
@@ -54,6 +57,7 @@ _REQUIRED_XFORMERS_VERSION = "0.0.29"
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
_CAN_USE_NPU_ATTN = is_torch_npu_available()
@@ -78,6 +82,12 @@ else:
    flash_attn_3_func = None
    flash_attn_3_varlen_func = None

if _CAN_USE_AITER_ATTN:
    from aiter import flash_attn_func as aiter_flash_attn_func
else:
    aiter_flash_attn_func = None

if DIFFUSERS_ENABLE_HUB_KERNELS:
    if not is_kernels_available():
        raise ImportError(
@@ -178,6 +188,9 @@ class AttentionBackendName(str, Enum):
    _FLASH_3_HUB = "_flash_3_hub"
    # _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub"  # not supported yet.

    # `aiter`
    AITER = "aiter"

    # PyTorch native
    FLEX = "flex"
    NATIVE = "native"
@@ -414,6 +427,12 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
            )
    elif backend == AttentionBackendName.AITER:
        if not _CAN_USE_AITER_ATTN:
            raise RuntimeError(
                f"Aiter attention backend '{backend.value}' is not usable because the `aiter` package is missing or the version is too old. Please install `aiter>={_REQUIRED_AITER_VERSION}`."
            )
    elif backend in [
        AttentionBackendName.SAGE,
        AttentionBackendName.SAGE_VARLEN,
@@ -1397,6 +1416,47 @@ def _flash_varlen_attention_3(
    return (out, lse) if return_lse else out


@_AttentionBackendRegistry.register(
    AttentionBackendName.AITER,
    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _aiter_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    return_lse: bool = False,
    _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
    if not return_lse and torch.is_grad_enabled():
        # aiter requires return_lse=True by assertion when gradients are enabled.
        out, lse, *_ = aiter_flash_attn_func(
            q=query,
            k=key,
            v=value,
            dropout_p=dropout_p,
            softmax_scale=scale,
            causal=is_causal,
            return_lse=True,
        )
    else:
        out = aiter_flash_attn_func(
            q=query,
            k=key,
            v=value,
            dropout_p=dropout_p,
            softmax_scale=scale,
            causal=is_causal,
            return_lse=return_lse,
        )
        if return_lse:
            out, lse, *_ = out

    return (out, lse) if return_lse else out


@_AttentionBackendRegistry.register(
    AttentionBackendName.FLEX,
    constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
...
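For a quick smoke test of the wrapper itself, something like the following sketch could be used. It assumes the module path `diffusers.models.attention_dispatch` (not stated in this diff), a ROCm GPU visible as `"cuda"`, and that the tensors follow the (batch, seq_len, heads, head_dim) layout used by the other flash-attention style backends in this file.

```python
import torch

# Assumed module path; the diff shows the file's contents, not its location.
from diffusers.models.attention_dispatch import _aiter_flash_attention

# bf16 inputs satisfy the _check_qkv_dtype_bf16_or_fp16 constraint registered above.
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)

with torch.no_grad():
    out = _aiter_flash_attention(q, k, v, is_causal=False)

print(out.shape)  # expected to match the query shape: torch.Size([2, 128, 8, 64])
```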
@@ -64,6 +64,8 @@ from .import_utils import (
    get_objects_from_module,
    is_accelerate_available,
    is_accelerate_version,
    is_aiter_available,
    is_aiter_version,
    is_better_profanity_available,
    is_bitsandbytes_available,
    is_bitsandbytes_version,
...
@@ -226,6 +226,7 @@ _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("
_sageattention_available, _sageattention_version = _is_package_available("sageattention")
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
_aiter_available, _aiter_version = _is_package_available("aiter")
_kornia_available, _kornia_version = _is_package_available("kornia")
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
@@ -406,6 +407,10 @@ def is_flash_attn_3_available():
    return _flash_attn_3_available


def is_aiter_available():
    return _aiter_available


def is_kornia_available():
    return _kornia_available
@@ -911,6 +916,22 @@ def is_flash_attn_version(operation: str, version: str):
    return compare_versions(parse(_flash_attn_version), operation, version)


@cache
def is_aiter_version(operation: str, version: str):
    """
    Compares the current aiter version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    if not _aiter_available:
        return False
    return compare_versions(parse(_aiter_version), operation, version)


def get_objects_from_module(module):
    """
    Returns a dict of object names and values in a module, while skipping private/internal objects
...
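The two helpers added to `import_utils.py` mirror the existing flash-attention guards, so callers can probe for a usable install before opting into the backend. A minimal sketch, assuming the helpers are re-exported from `diffusers.utils` as the `__init__` diff above indicates:

```python
from diffusers.utils import is_aiter_available, is_aiter_version

# Fall back to PyTorch's native SDPA backend when aiter is absent or too old.
# "0.1.5" matches _REQUIRED_AITER_VERSION from the attention dispatch diff.
backend = "aiter" if is_aiter_available() and is_aiter_version(">=", "0.1.5") else "native"
print(f"selected attention backend: {backend}")
```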
@@ -14,6 +14,10 @@ pytest tests/others/test_attention_backends.py
Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).

Tests for the aiter backend were conducted, and the slices for those tests collected, on an MI355X
with a torch nightly from 2025-09-25 (ad2f7315ca66b42497047bb7951f696b50f1e81b) and
aiter 0.1.5.post4.dev20+ga25e55e79.
""" """
import os import os
...@@ -44,6 +48,10 @@ FORWARD_CASES = [ ...@@ -44,6 +48,10 @@ FORWARD_CASES = [
"_native_cudnn", "_native_cudnn",
torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16), torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
), ),
(
"aiter",
torch.tensor([0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891], dtype=torch.bfloat16),
)
] ]

COMPILE_CASES = [
@@ -63,6 +71,11 @@ COMPILE_CASES = [
        torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
        True,
    ),
    (
        "aiter",
        torch.tensor([0.0391, 0.0391, 0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164], dtype=torch.bfloat16),
        True,
    ),
]
# fmt: on
...
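To exercise only the new cases, the test module can be filtered by keyword. The sketch below assumes it is run from the repository root on a ROCm machine with `aiter` installed, and that pytest's generated test ids contain the case name `"aiter"`.

```python
import pytest

# -k selects tests whose id mentions "aiter"; -v prints each selected case.
pytest.main(["tests/others/test_attention_backends.py", "-k", "aiter", "-v"])
```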