Unverified Commit 56e0b351 authored by Charlene Yang and committed by GitHub

[C/PyTorch] Add support for bottom-right-diagonal causal mask (#960)



* update to FE 1.5.1 and add bottom right causal
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* adjust logic for backend selection
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update FE to 1.5.2
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add get_attention_backend function
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update get_attention_backend
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix get_attention_backend
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* tweak get_attention_backend and fix unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fixes for unfused, get_backend, etc
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/pytorch/attention.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix cpu offload
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes for get_attention_backend
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* explicitly skip FP32 and padding tests because there is no support
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix for window size check
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update check_set_window_size and add enc_dec_attn_mask_type/enc_dec_window_size
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent f9dd37f7
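For reference (not part of the diff): the two causal-mask alignments differ only in where the diagonal sits when the query and key/value sequences have different lengths. A minimal pure-Python sketch, assuming `True` marks a position that is masked out:

```python
def causal_mask(seqlen_q, seqlen_kv, bottom_right=False):
    """Build a boolean mask where True means 'do not attend'.

    Top-left alignment anchors the diagonal at (0, 0); bottom-right
    alignment anchors it at (seqlen_q - 1, seqlen_kv - 1).
    """
    offset = seqlen_kv - seqlen_q if bottom_right else 0
    return [[kv > q + offset for kv in range(seqlen_kv)]
            for q in range(seqlen_q)]

# With 2 queries and 4 keys:
# top-left: query 0 attends only to key 0
top_left = causal_mask(2, 4)
# bottom-right: the last query attends to every key,
# which matches KV-cache decoding semantics
bottom_right = causal_mask(2, 4, bottom_right=True)
```

For square sequences (`seqlen_q == seqlen_kv`) the two alignments coincide; they only diverge in the cross-attention or incremental-decoding case.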
@@ -345,10 +345,10 @@
 "| cuDNN attention | `bshd`, `sbhd`, `thd` | PyTorch: 3 formats, i.e. 15 layouts<br>JAX, PaddlePaddle: `bs3hd`, `bshd_bs2hd`, `bshd_bshd_bshd` layouts |\n",
 "| Framework-native attention | `bshd`, `sbhd`<br>(`sbhd` requires transpose operations) | PyTorch, JAX, PaddlePaddle: 2 formats, i.e. 10 layouts |\n",
 "\n",
-"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](../../tests/pytorch/fused_attention/test_fused_attn.py) and [test_dpa_qkv_layout_thd](../../tests/pytorch/fused_attention/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention._get_qkv_layout](../../transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
+"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](../../tests/pytorch/fused_attention/test_fused_attn.py) and [test_dpa_qkv_layout_thd](../../tests/pytorch/fused_attention/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.get_qkv_layout](../../transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
 "\n",
 "<div class=\"alert alert-info\">\n",
-"<b>Note:</b> When RoPE is employed, the <code>qkv_layout</code> may change in Transformer Engine PyTorch through [_get_qkv_layout](../../transformer_engine/pytorch/attention.py). This is due to the in-place nature of our RoPE implementations. We convert `q`, `k`, `v` tensors from their initial layout to the corresponding <code>hd_hd_hd</code> layout. For example, from <code>sbh3d</code> in <code>pytorch.MultiHeadAttention</code> before RoPE, to <code>sbhd_sbhd_sbhd</code> in <code>pytorch.DotProductAttention</code> after RoPE.\n",
+"<b>Note:</b> When RoPE is employed, the <code>qkv_layout</code> may change in Transformer Engine PyTorch through [get_qkv_layout](../../transformer_engine/pytorch/attention.py). This is due to the in-place nature of our RoPE implementations. We convert `q`, `k`, `v` tensors from their initial layout to the corresponding <code>hd_hd_hd</code> layout. For example, from <code>sbh3d</code> in <code>pytorch.MultiHeadAttention</code> before RoPE, to <code>sbhd_sbhd_sbhd</code> in <code>pytorch.DotProductAttention</code> after RoPE.\n",
 "</div>\n"
 ]
 },
...
@@ -7,7 +7,7 @@ import logging
 import math
 import os
 from importlib.metadata import version
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union, Optional
 import pytest
 import torch
@@ -19,6 +19,9 @@ from transformer_engine.pytorch.attention import (
     DotProductAttention,
     MultiheadAttention,
     RotaryPositionEmbedding,
+    get_attention_backend,
+    _flash_attn_2_plus,
+    _flash_attn_2_3_plus,
 )
 from transformer_engine.pytorch.constants import TE_DType
 import transformer_engine.pytorch.cpp_extensions as ext
@@ -99,104 +102,96 @@ class ModelConfig:
         self.bias_shape = bias_shape
 
-def _is_fused_attention_supported(
+def _get_attention_backends(
     config: ModelConfig,
-    dtype: torch.dtype,
-    qkv_layout: str = "sbh3d",
-) -> Tuple[bool, NVTE_Fused_Attn_Backend]:
-    """Check if FusedAttention supports a model configuration"""
-    backends = []
-    os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
-    backend = tex.get_fused_attn_backend(
-        TE_DType[dtype],
-        TE_DType[dtype],
-        QKVLayout[qkv_layout],
-        AttnBiasType[config.attn_bias_type],
-        AttnMaskType[config.attn_mask_type],
-        config.dropout_p,
-        config.num_heads,
-        config.num_gqa_groups,
-        config.max_seqlen_q,
-        config.max_seqlen_kv,
-        config.head_dim,
-    )
-    if backend == FusedAttnBackend["FP8"]:
-        backends.append(backend)
-        return True, backends
-    if backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
-        backends.append(backend)
-        return True, backends
-    if backend == FusedAttnBackend["F16_max512_seqlen"]:
-        backends.append(backend)
-        os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
-        backend = tex.get_fused_attn_backend(
-            TE_DType[dtype],
-            TE_DType[dtype],
-            QKVLayout[qkv_layout],
-            AttnBiasType[config.attn_bias_type],
-            AttnMaskType[config.attn_mask_type],
-            config.dropout_p,
-            config.num_heads,
-            config.num_gqa_groups,
-            config.max_seqlen_q,
-            config.max_seqlen_kv,
-            config.head_dim,
-        )
-        if backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
-            backends.append(backend)
-        return True, backends
-    return False, backends
-
-@functools.lru_cache(maxsize=None)
-def _is_flash_attention_2_available() -> bool:
-    """Check if flash-attn 2.0+ is available"""
-    Version = packaging.version.Version
-    return Version(version("flash-attn")) >= Version("2")
-
-@functools.lru_cache(maxsize=None)
-def _is_flash_attention_2_1() -> bool:
-    """Check if flash-attn 2.1+ is available"""
-    Version = packaging.version.Version
-    return Version(version("flash-attn")) >= Version("2.1")
-
-@functools.lru_cache(maxsize=None)
-def _is_flash_attention_2_3() -> bool:
-    """Check if flash-attn 2.3+ is available"""
-    Version = packaging.version.Version
-    return Version(version("flash-attn")) >= Version("2.3")
-
-def _is_flash_attention_supported(config: ModelConfig) -> bool:
-    """Check if FlashAttention supports a model configuration"""
-    if get_device_compute_capability() < (8, 0):
-        return False
-    if config.attn_bias_type not in ["no_bias", "alibi"]:
-        return False
-    if config.num_heads != config.num_gqa_groups and not _is_flash_attention_2_available():
-        return False
-    if "causal" in config.attn_mask_type and config.attn_type == "cross":
-        if _is_flash_attention_2_1():
-            # FAv2.1 implements causal mask for cross attention differently
-            # https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag
-            return False
-    return True
-
-def _is_unfused_attention_supported(
-    config: ModelConfig,
-    qkv_format: str,
-) -> bool:
-    """Check if UnfusedDotProductAttention supports a model configuration"""
-    if "padding" in config.attn_mask_type:
-        return False
-    if "causal" in config.attn_mask_type and config.attn_type == "cross":
-        return False
-    if qkv_format == "thd":
-        return False
-    return True
+    qkv_dtype: torch.dtype,
+    qkv_layout: str,
+    window_size: Tuple[int, int] = (-1, -1),
+    pad_between_seqs: bool = False,
+    context_parallel: bool = False,
+    deterministic: bool = False,
+    fp8: bool = False,
+    fp8_meta: Optional[Dict[str, Any]] = None,
+) -> Tuple[List, List]:
+    """Check which attention backends support a model configuration"""
+    os.environ["NVTE_FLASH_ATTN"] = "1"
+    os.environ["NVTE_FUSED_ATTN"] = "1"
+    os.environ["NVTE_UNFUSED_ATTN"] = "1"
+
+    alibi_slopes_shape = None
+    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
+        if config.bias_shape == "1hss":
+            alibi_slopes_shape = [config.num_heads]
+        if config.bias_shape == "bhss":
+            alibi_slopes_shape = [config.batch_size, config.num_heads]
+
+    core_attention_bias_shape = (
+        config.bias_shape if config.attn_bias_type == "post_scale_bias" else None
+    )
+    core_attention_bias_requires_grad = False
+    # d=256 is supported by cuDNN 9.0+ for inference but not training
+    if config.attn_bias_type == "post_scale_bias" and config.head_dim <= 128:
+        core_attention_bias_requires_grad = True
+
+    fused_attn_backends = []
+    available_backends = None
+    os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
+    _, _, _, available_backends, fused_attention_backend = get_attention_backend(
+        qkv_dtype=qkv_dtype,
+        qkv_layout=qkv_layout,
+        batch_size=config.batch_size,
+        num_heads=config.num_heads,
+        num_gqa_groups=config.num_gqa_groups,
+        max_seqlen_q=config.max_seqlen_q,
+        max_seqlen_kv=config.max_seqlen_kv,
+        head_dim=config.head_dim,
+        attn_mask_type=config.attn_mask_type,
+        window_size=window_size,
+        alibi_slopes_shape=alibi_slopes_shape,
+        core_attention_bias_type=config.attn_bias_type,
+        core_attention_bias_shape=core_attention_bias_shape,
+        core_attention_bias_requires_grad=core_attention_bias_requires_grad,
+        pad_between_seqs=pad_between_seqs,
+        attention_dropout=config.dropout_p,
+        context_parallel=context_parallel,
+        deterministic=deterministic,
+        fp8=fp8,
+        fp8_meta=fp8_meta,
+    )
+    if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
+        fused_attn_backends.append(fused_attention_backend)
+    os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
+    _, _, _, available_backends, fused_attention_backend = get_attention_backend(
+        qkv_dtype=qkv_dtype,
+        qkv_layout=qkv_layout,
+        batch_size=config.batch_size,
+        num_heads=config.num_heads,
+        num_gqa_groups=config.num_gqa_groups,
+        max_seqlen_q=config.max_seqlen_q,
+        max_seqlen_kv=config.max_seqlen_kv,
+        head_dim=config.head_dim,
+        attn_mask_type=config.attn_mask_type,
+        window_size=window_size,
+        alibi_slopes_shape=alibi_slopes_shape,
+        core_attention_bias_type=config.attn_bias_type,
+        core_attention_bias_shape=core_attention_bias_shape,
+        core_attention_bias_requires_grad=core_attention_bias_requires_grad,
+        pad_between_seqs=pad_between_seqs,
+        attention_dropout=config.dropout_p,
+        context_parallel=context_parallel,
+        deterministic=deterministic,
+        fp8=fp8,
+        fp8_meta=fp8_meta,
+    )
+    if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
+        fused_attn_backends.append(fused_attention_backend)
+    elif (
+        fused_attention_backend != FusedAttnBackend["No_Backend"]
+        and fused_attention_backend is not None
+    ):
+        fused_attn_backends.append(fused_attention_backend)
+    return available_backends, fused_attn_backends
 
 model_configs_base = {
@@ -255,25 +250,21 @@ def test_dot_product_attention(
     if "3" in qkv_layout and config.attn_type == "cross":
         pytest.skip("No need to test this layout for cross attention")
 
-    # Skip if only unfused backend is supported
-    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
-    unfused_attn_supported = _is_unfused_attention_supported(config, qkv_format)
-    if config.max_seqlen_q <= 512 and config.max_seqlen_kv <= 512:
-        os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
-    fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
+    # Test backend availability
+    window_size = (2, 2) if swa else (-1, -1)
+    available_backends, fused_attn_backends = _get_attention_backends(
         config,
-        dtype,
+        qkv_dtype=dtype,
         qkv_layout=qkv_layout,
+        window_size=window_size,
+        pad_between_seqs=pad_between_seqs,
     )
-    if swa:
-        fused_attn_supported = False
-    flash_attn_supported = _is_flash_attention_supported(config)
-    if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
+    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
+
+    # Skip if only unfused backend is supported
+    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
         pytest.skip("Less than two backends to compare.")
-    if qkv_format == "thd" and "padding" not in config.attn_mask_type:
-        pytest.skip("THD layout requires padding/padding_causal mask type.")
-    # d=256 is supported by cuDNN 9.0+ for inference but not training
     is_training = config.head_dim <= 128
 
     # UnfusedDotProductAttention backend
     if unfused_attn_supported:
@@ -296,7 +287,7 @@ def test_dot_product_attention(
 
     # FusedAttention backend
     if fused_attn_supported:
-        if len(fused_attn_backend) == 1:
+        if len(fused_attn_backends) == 1:
             fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                 dtype,
                 config,
@@ -308,7 +299,7 @@ def test_dot_product_attention(
                 pad_between_seqs,
                 is_training,
             )
-        if len(fused_attn_backend) == 2:
+        if len(fused_attn_backends) == 2:
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
             fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                 dtype,
@@ -363,7 +354,7 @@ def test_dot_product_attention(
         torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
         for i, _ in enumerate(flash_attn_bwd):
             torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
-    if fused_attn_supported and len(fused_attn_backend) == 2:
+    if fused_attn_supported and len(fused_attn_backends) == 2:
         logging.info("[test_dot_product_attention]: fused attn backend 0 vs 1")
         torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
         for i, _ in enumerate(fused_attn_bwd):
@@ -393,6 +384,14 @@ model_configs_mask = {
     "mask_5_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding_causal", "no_bias"),
     "mask_6_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
     "mask_6_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
+    "mask_7_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
+    "mask_7_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
+    "mask_8_0": ModelConfig(
+        2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
+    ),
+    "mask_8_1": ModelConfig(
+        1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
+    ),
 }
@@ -508,7 +507,7 @@ model_configs_swa = {
 }
 
-@pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.")
+@pytest.mark.skipif(not _flash_attn_2_3_plus, reason="Flash-attn 2.3+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_swa])
 @pytest.mark.parametrize("model", model_configs_swa.keys())
@@ -530,7 +529,7 @@ model_configs_alibi_slopes = {
 }
 
-@pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.")
+@pytest.mark.skipif(not _flash_attn_2_3_plus, reason="Flash-attn 2.3+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
 @pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
@@ -994,17 +993,16 @@ def test_transformer_layer(
     tols = dict(atol=5e-1, rtol=5e-2)
     workspace_opt = True
 
-    # Skip if only unfused backend is supported
-    if config.max_seqlen_q <= 512 and config.max_seqlen_kv <= 512:
-        os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
-    fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
+    # Test backend availability
+    available_backends, fused_attn_backends = _get_attention_backends(
         config,
-        dtype,
+        qkv_dtype=dtype,
         qkv_layout="sbh3d" if fused_qkv_params else "sb3hd",
     )
-    flash_attn_supported = _is_flash_attention_supported(config)
-    unfused_attn_supported = _is_unfused_attention_supported(config, qkv_format)
-    if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
+    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
+
+    # Skip if only unfused backend is supported
+    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
         pytest.skip("Less than two backends to compare.")
 
     # UnfusedDotProductAttention backend
...
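The skip logic in these tests relies on Python treating booleans as integers, so the two availability flags and the list of fused backends can be summed directly. An illustrative sketch (the helper name here is hypothetical, not the test suite's API):

```python
def enough_backends_to_compare(flash_attn_supported: bool,
                               fused_attn_backends: list,
                               unfused_attn_supported: bool) -> bool:
    # bool is a subclass of int, so True counts as 1 and False as 0;
    # each fused backend in the list also counts once
    total = len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported
    return total >= 2

# Only FlashAttention available: nothing to cross-check against
print(enough_backends_to_compare(True, [], False))
# FlashAttention plus one fused backend: results can be compared
print(enough_backends_to_compare(True, ["F16_arbitrary_seqlen"], False))
```

The test is skipped when this total is below two, since correctness is established by comparing outputs across at least two independent implementations.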
@@ -5,9 +5,10 @@
 import os
 import pytest
 import subprocess
-from test_fused_attn import (
-    ModelConfig,
-    _is_flash_attention_2_available,
+from test_fused_attn import ModelConfig
+from transformer_engine.pytorch.attention import (
+    _flash_attn_2_plus,
+    _flash_attn_2_3_plus,
 )
 from transformer_engine.pytorch.utils import (
     get_device_compute_capability,
@@ -33,7 +34,7 @@ def get_bash_arguments(**kwargs):
     return args
 
-@pytest.mark.skipif(not _is_flash_attention_2_available(), reason="Flash-attn 2.0+ is required.")
+@pytest.mark.skipif(not _flash_attn_2_plus, reason="Flash-attn 2.0+ is required.")
 @pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
 @pytest.mark.parametrize("dtype", ["bf16", "fp16"])
 @pytest.mark.parametrize("model", model_configs_flash_attn.keys())
...
@@ -1593,6 +1593,8 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
                 ffn_hidden_size=4 * D,
                 num_attention_heads=H,
                 attn_input_format=input_format,
+                self_attn_mask_type="causal_bottom_right",
+                enc_dec_attn_mask_type="causal_bottom_right",
                 layer_number=layer_number,
                 attention_dropout=0.0,
                 params_dtype=dtype,
@@ -1606,6 +1608,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
                 qkv_format=input_format,
                 layer_number=layer_number,
                 attention_dropout=0.0,
+                attn_mask_type="causal_bottom_right",
                 params_dtype=dtype,
             )
             .cuda()
...
@@ -670,7 +670,7 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
         apply_residual_connection_post_layernorm=True,
         output_layernorm=True,
         zero_centered_gamma=zero_centered_gamma,
-        self_attn_mask_type="padding",
+        self_attn_mask_type="causal",
         normalization=normalization,
         device="cuda",
     )
...
@@ -119,17 +119,22 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
        (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD))) {
     flag_m512 = true;
   }
-  if (((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) ||
+  if ( // architecture
+      ((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) ||
        (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) &&
+      // sequence length
       ((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0) ||
        (cudnn_runtime_version >= 90000)) &&
+      // number of heads
       ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) ||
        (cudnn_runtime_version >= 8907)) &&
-      ((head_dim <= 128 && head_dim % 8 == 0)
+      // head dimension
+      ((head_dim <= 128 && head_dim % 8 == 0) ||
        // TODO (cyang): add is_training to nvte_get_fused_attn_backend
        // d=256 only supported for forward
-       || (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim <= 256 &&
-           head_dim % 8 == 0)) &&
+       (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim <= 256 &&
+        head_dim % 8 == 0)) &&
+      // bias type
       ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
        ((cudnn_runtime_version >= 8906) &&
         (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS ||
@@ -139,16 +144,24 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
          (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 90))) ||
        ((cudnn_runtime_version >= 90000) &&
         (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 80))) &&
+      // mask type
       ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
        ((cudnn_runtime_version >= 8906) &&
         (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
          attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
          attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
-         attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))) &&
+         attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) ||
+       ((cudnn_runtime_version >= 90300) &&
+        attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
+        bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
+        (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) &&
+        max_seqlen_q <= max_seqlen_kv && dropout == 0.0)) &&
+      // bias + mask combination
       (!(cudnn_runtime_version >= 8906 &&
          (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
           attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) &&
          bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) &&
+      // qkv format
       ((qkv_format == NVTE_QKV_Format::NVTE_SBHD) ||
        (sm_arch_ >= 90 && cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups &&
         qkv_format == NVTE_QKV_Format::NVTE_THD) ||
...
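The new mask-type clause gates the bottom-right mask on several runtime conditions. Restated as a standalone predicate, a sketch mirroring the C++ condition above (not an actual API; version and format encodings simplified):

```python
def bottom_right_mask_eligible(cudnn_version: int, bias_type: str,
                               qkv_format: str, max_seqlen_q: int,
                               max_seqlen_kv: int, dropout: float) -> bool:
    # cuDNN 9.3+, no bias, SBHD/BSHD layouts only,
    # query no longer than key/value, and no dropout
    return (cudnn_version >= 90300
            and bias_type == "no_bias"
            and qkv_format in ("sbhd", "bshd")
            and max_seqlen_q <= max_seqlen_kv
            and dropout == 0.0)

# Eligible: cuDNN 9.3, no bias, BSHD, q shorter than kv, no dropout
print(bottom_right_mask_eligible(90300, "no_bias", "bshd", 128, 256, 0.0))
# Not eligible: cuDNN older than 9.3
print(bottom_right_mask_eligible(90100, "no_bias", "bshd", 128, 256, 0.0))
```

When any of these conditions fails, backend selection falls through to the other clauses and the bottom-right mask is handled elsewhere (or the configuration is rejected).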
@@ -61,6 +61,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
   bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
   bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
                     (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
+  bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) ||
+                          (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK));
   bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
                      (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
   bool is_dropout = (is_training && dropout_probability != 0.0f);
@@ -203,6 +205,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
             .set_name("flash_attention")
             .set_is_inference(false)
             .set_causal_mask(is_causal)
+            .set_causal_mask_bottom_right(is_bottom_right)
             .set_attn_scale(attn_scale);
     sdpa_options.set_alibi_mask(is_alibi);
@@ -376,6 +379,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
   bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
   bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
                     (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
+  bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) ||
+                          (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK));
   bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
                      (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
   bool is_dropout = (dropout_probability != 0.0f);
@@ -544,6 +549,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
     sdpa_backward_options = fe::graph::SDPA_backward_attributes()
                                 .set_name("flash_attention_backward")
                                 .set_causal_mask(is_causal)
+                                .set_causal_mask_bottom_right(is_bottom_right)
                                 .set_attn_scale(attn_scale);
     sdpa_backward_options.set_alibi_mask(is_alibi);
...
@@ -93,10 +93,14 @@ enum NVTE_Mask_Type {
   NVTE_NO_MASK = 0,
   /*! Padding attention mask */
   NVTE_PADDING_MASK = 1,
-  /*! Causal attention mask */
+  /*! Causal attention mask (aligned to the top left corner) */
   NVTE_CAUSAL_MASK = 2,
-  /*! Padding and causal attention mask */
+  /*! Padding and causal attention mask (aligned to the top left corner) */
   NVTE_PADDING_CAUSAL_MASK = 3,
+  /*! Causal attention mask (aligned to the bottom right corner) */
+  NVTE_CAUSAL_BOTTOM_RIGHT_MASK = 4,
+  /*! Padding and causal attention mask (aligned to the bottom right corner) */
+  NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK = 5,
 };
 
 /*! \enum NVTE_Fused_Attn_Backend
...
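A rough sketch of how the new padding variant combines the two effects (assumed semantics for illustration, not library code): the causal diagonal is anchored to the bottom-right corner of each sequence's *actual* lengths, and all padded positions are masked entirely:

```python
def padding_causal_bottom_right(actual_q, actual_kv, max_q, max_kv):
    """True marks a masked-out position in a [max_q, max_kv] grid."""
    # diagonal anchored to the actual (unpadded) sequence lengths
    offset = actual_kv - actual_q
    # padded rows and columns are masked by default
    mask = [[True] * max_kv for _ in range(max_q)]
    for q in range(actual_q):
        for kv in range(actual_kv):
            mask[q][kv] = kv > q + offset
    return mask

# A sequence with 2 valid queries and 3 valid keys, padded to 3 x 4:
# the last valid query attends to all valid keys; padding stays masked
mask = padding_causal_bottom_right(2, 3, 3, 4)
```

Under this reading, the plain `NVTE_CAUSAL_BOTTOM_RIGHT_MASK` is the special case where the actual lengths equal the padded lengths.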
...@@ -73,6 +73,7 @@ from transformer_engine.pytorch.graph import is_graph_capturing
_flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
_flash_attn_version_required = PkgVersion("2.0.6")
_flash_attn_max_version = PkgVersion("2.5.8")
_flash_attn_2_plus = _flash_attn_version >= PkgVersion("2")
_flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1")
_flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3")
_flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
...@@ -116,6 +117,422 @@ _alibi_cache = {
__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]
def get_attention_backend(
qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor,
qkv_dtype: torch.dtype = torch.bfloat16,
qkv_layout: str = "sbh3d",
batch_size: int = 1,
num_heads: int = 16,
num_gqa_groups: int = 16,
max_seqlen_q: int = 128,
max_seqlen_kv: int = 128,
head_dim: int = 64,
attn_mask_type: str = "no_mask",
window_size: Tuple[int, int] = (-1, -1),
alibi_slopes_shape: Optional[Union[torch.Size, List]] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias_shape: str = "1hss",
core_attention_bias_requires_grad: bool = True,
pad_between_seqs: bool = False,
attention_dropout: float = 0.0,
context_parallel: bool = False,
deterministic: bool = False,
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
):
"""
Select an attention backend based on the user input and runtime environment.
Parameters
----------
qkv_type: Union[torch.Tensor, Float8Tensor], default = `torch.Tensor`
Type of query/key/value tensors, {`torch.Tensor`, `Float8Tensor`}.
qkv_dtype: torch.dtype, default = `torch.bfloat16`
Data type of query/key/value tensors.
qkv_layout: str, default = "sbh3d"
Query/key/value tensor memory layout.
batch_size: int, default = 1
Batch size.
num_heads: int, default = 16
Number of attention heads in the query tensor.
num_gqa_groups: int, default = 16
Number of attention heads in key and value tensors.
max_seqlen_q: int, default = 128
Maximum sequence length of the query tensor.
max_seqlen_kv: int, default = 128
Maximum sequence length of the key and value tensors.
head_dim: int, default = 64
The size of each attention head.
attn_mask_type: str, default = `no_mask`
Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`,
`causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
window_size: Tuple[int, int], default = (-1, -1)
Sliding window attention size.
alibi_slopes_shape: Optional[Union[torch.Size, List]], default = `None`
Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`.
core_attention_bias_type: str, default = `no_bias`
Attention bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}.
core_attention_bias_shape: str, default = `1hss`
Attention bias shape, {`1hss`, `b1ss`, `bhss`}.
core_attention_bias_requires_grad: bool, default = `True`
Whether attention bias requires gradient.
pad_between_seqs: bool, default = `False`
Whether there is padding between sequences in a batch.
This only applies to `qkv_format=thd`.
attention_dropout: float, default = 0.0
Attention dropout.
context_parallel: bool, default = `False`
Whether context parallelism is used or not.
deterministic: bool, default = `False`
Whether to run `DotProductAttention` with determinism or not.
fp8: bool, default = `False`
Whether `DotProductAttention` is in an `fp8_autocast` region.
fp8_meta: Optional[Dict[str, Any]], default = `None`
The FP8 metadata tensor of `DotProductAttention`.
Returns
----------
use_flash_attention: bool
Whether the `FlashAttention` backend has been selected.
use_fused_attention: bool
Whether the `FusedAttention` backend has been selected.
use_unfused_attention: bool
Whether the `UnfusedDotProductAttention` backend has been selected.
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
If `use_fused_attention = True`, the `FusedAttention` sub-backend, else `None`.
available_backends: List[bool]
All available backends that could support the provided input. A list of Booleans
in the form of [use_flash_attention, use_fused_attention, use_unfused_attention].
"""
logger = logging.getLogger("DotProductAttention")
# Filter: Environment variables
use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1"))
use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1"))
use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
if not use_flash_attention:
logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
if not use_fused_attention:
logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0")
if not use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0")
# Filter: ONNX mode
if is_in_onnx_export_mode():
if use_flash_attention:
logger.debug("Disabling FlashAttention due to ONNX mode")
use_flash_attention = False
if use_fused_attention:
logger.debug("Disabling FusedAttention due to ONNX mode")
use_fused_attention = False
# Filter: Compute capability
device_compute_capability = get_device_compute_capability()
if device_compute_capability < (8, 0):
if use_flash_attention:
logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
use_flash_attention = False
if use_fused_attention:
logger.debug("Disabling FusedAttention as it requires compute capability sm80+")
use_fused_attention = False
# Filter: Context parallelism
if context_parallel and use_unfused_attention:
logger.debug(
"Disabling UnfusedDotProductAttention as it does not support context parallelism"
)
use_unfused_attention = False
# Filter: Data type
if use_flash_attention and (
qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type == Float8Tensor
):
logger.debug(
"Disabling FlashAttention due to unsupported QKV data type. "
"Supported: qkv_type = torch.Tensor, qkv_dtype = {torch.bfloat16, torch.float16}. "
"Found: qkv_type = %s, qkv_dtype = %s.",
qkv_type,
qkv_dtype,
)
use_flash_attention = False
if use_fused_attention and (qkv_dtype not in [torch.bfloat16, torch.float16]):
logger.debug(
"Disabling FusedAttention due to unsupported QKV data type. "
"Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
"Found: qkv_dtype = %s.",
qkv_dtype,
)
use_fused_attention = False
# Filter: Execution type
if fp8 and fp8_meta["recipe"].fp8_dpa:
if use_flash_attention:
logger.debug("Disabling FlashAttention as it does not support FP8")
use_flash_attention = False
if use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8")
use_unfused_attention = False
# Filter: Head dimension
if use_flash_attention and (
head_dim > 256
or head_dim % 8 != 0
or (head_dim > 192 and device_compute_capability not in ((8, 0), (9, 0)))
):
logger.debug(
"Disabling FlashAttention due to unsupported head_dim. "
"Supported: head_dim %%8 = 0, head_dim <= 256 (>192 requires sm80/90). "
"Found: head_dim = %s on sm%s.",
head_dim,
".".join([str(i) for i in device_compute_capability]),
)
use_flash_attention = False
# Filter: QKV layout
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
if qkv_format == "thd":
if use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd")
use_unfused_attention = False
if use_flash_attention and pad_between_seqs:
logger.debug(
"Disabling FlashAttention for qkv_format = thd when there is "
"padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
)
use_flash_attention = False
# Filter: Attention mask
# attn_mask_type | supported backends
# -------------------------------------------------------------------
# no_mask | All
# padding | FlashAttention, FusedAttention
# causal |
# self-attention | All
# cross-attention | FusedAttention
# padding_causal |
# self-attention | FlashAttention, FusedAttention
# cross-attention | FusedAttention
# causal_bottom_right | All
# padding_causal_bottom_right | FlashAttention, FusedAttention
# arbitrary | UnfusedDotProductAttention
if attn_mask_type == "arbitrary":
if use_flash_attention:
logger.debug("Disabling FlashAttention for arbitrary mask")
use_flash_attention = False
if use_fused_attention:
logger.debug("Disabling FusedAttention for arbitrary mask")
use_fused_attention = False
if use_unfused_attention and "padding" in attn_mask_type:
logger.debug("Disabling UnfusedDotProductAttention for %s mask", attn_mask_type)
use_unfused_attention = False
if use_unfused_attention and attn_mask_type == "causal" and max_seqlen_q != max_seqlen_kv:
logger.debug(
"Disabling UnfusedDotProductAttention for "
"top-left-diagonal causal masks for cross-attention"
)
use_unfused_attention = False
if (
use_flash_attention
and _flash_attn_2_1_plus
and attn_mask_type in ["causal", "padding_causal"]
and max_seqlen_q != max_seqlen_kv
):
logger.warning(
"Disabling FlashAttention as it only supports bottom-right-diagonal "
"causal mask since flash-attn 2.1. See "
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
)
use_flash_attention = False
if (
use_flash_attention
and not _flash_attn_2_1_plus
and attn_mask_type in ["causal_bottom_right", "padding_causal_bottom_right"]
and max_seqlen_q != max_seqlen_kv
):
logger.warning(
"Disabling FlashAttention as it only supports top-left-diagonal "
"causal mask before flash-attn 2.1. See "
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
)
use_flash_attention = False
# Filter: Sliding window attention
if window_size is not None and window_size[0] != -1 and window_size[1] not in [-1, 0]:
if use_unfused_attention:
logger.debug(
"Disabling UnfusedDotProductAttention as "
"it does not support sliding window attention"
)
use_unfused_attention = False
if use_fused_attention:
logger.debug("Disabling FusedAttention as it does not support sliding window attention")
use_fused_attention = False
if use_flash_attention and (not _flash_attn_2_3_plus or context_parallel):
logger.debug(
"Disabling FlashAttention as sliding window attention requires "
"flash-attn 2.3+ and no context parallelism"
)
use_flash_attention = False
# Filter: Attention bias
if use_flash_attention and (
core_attention_bias_type not in ["no_bias", "alibi"]
or core_attention_bias_shape is not None
):
logger.debug("Disabling FlashAttention for pre/post_scale_bias")
use_flash_attention = False
fu_core_attention_bias_type = core_attention_bias_type
fu_core_attention_bias_shape = core_attention_bias_shape
fu_core_attention_bias_requires_grad = core_attention_bias_requires_grad
if (
use_fused_attention
and core_attention_bias_type == "alibi"
and alibi_slopes_shape is not None
):
fu_core_attention_bias_type = "post_scale_bias"
fu_core_attention_bias_requires_grad = False
if (
len(alibi_slopes_shape) == 2
and alibi_slopes_shape[0] == batch_size
and alibi_slopes_shape[1] == num_heads
):
fu_core_attention_bias_shape = "bhss"
elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
fu_core_attention_bias_shape = "1hss"
if (
use_fused_attention
and fu_core_attention_bias_type == "post_scale_bias"
and fu_core_attention_bias_shape != "1hss"
):
if fu_core_attention_bias_requires_grad:
# remove this line when cuDNN adds bwd support for
# [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
logger.debug("Disabling FusedAttention for dBias in shapes other than [1, H, S, S]")
use_fused_attention = False
else:
# max512 backend will only support [1, h, s, s]
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
# Filter: cuDNN support
fused_attention_backend = None
if use_fused_attention:
q_type = TE_DType[qkv_dtype]
kv_type = q_type
if fp8 and fp8_meta["recipe"].fp8_dpa:
q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
kv_type = q_type
fused_attention_backend = tex.get_fused_attn_backend(
q_type,
kv_type,
QKVLayout[qkv_layout],
AttnBiasType[fu_core_attention_bias_type],
AttnMaskType[attn_mask_type],
attention_dropout,
num_heads,
num_gqa_groups,
max_seqlen_q,
max_seqlen_kv,
head_dim,
)
if fused_attention_backend == FusedAttnBackend["No_Backend"] or (
context_parallel and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"]
):
logger.debug("Disabling FusedAttention as no backend supports the provided input")
use_fused_attention = False
elif (
fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
and fu_core_attention_bias_type == "post_scale_bias"
and fu_core_attention_bias_shape != "1hss"
):
logger.debug(
"Disabling FusedAttention as cuDNN sub-backend 0 only supports post_scale_bias in"
" [1, H, S, S] shape"
)
use_fused_attention = False
# Filter: Determinism
# backend | deterministic
# ---------------------------------------------
# FlashAttention |
# flash-attn >=2.0, <2.4.1 | no
# flash-attn >=2.4.1 | yes
# FusedAttention |
# sub-backend 0 | yes
# sub-backend 1 | workspace optimization path and sm90+: yes;
# | otherwise: no
# sub-backend 2 | no
# UnfusedDotProductAttention | yes
if use_flash_attention and deterministic and not _flash_attn_2_4_1_plus:
logger.warning(
"Disabling FlashAttention as version <2.4.1 does not support deterministic "
"execution. To use FlashAttention with deterministic behavior, "
"please install flash-attn >= 2.4.1."
)
use_flash_attention = False
if (
use_fused_attention
and (
fused_attention_backend == FusedAttnBackend["FP8"]
or (
fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
and device_compute_capability < (9, 0)
)
)
and deterministic
):
logger.debug("Disabling FusedAttention for determinism reasons")
use_fused_attention = False
# All available backends
available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention]
# Select FusedAttention for performance
if (
use_flash_attention
and use_fused_attention
and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
):
if device_compute_capability == (9, 0):
logger.debug(
"Disabling FlashAttention to give FusedAttention preference on Hopper+ "
"for performance reasons"
)
use_flash_attention = False
# Selected backend
if use_flash_attention:
use_fused_attention = False
use_unfused_attention = False
elif use_fused_attention:
use_unfused_attention = False
if use_flash_attention:
selected_backend = "FlashAttention"
elif use_fused_attention:
selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})"
elif use_unfused_attention:
selected_backend = "UnfusedDotProductAttention"
else:
selected_backend = "NoBackend"
logger.debug(
"Available backends: FlashAttention=%s, FusedAttention=%s, UnfusedDotProductAttention=%s",
bool(available_backends[0]),
bool(available_backends[1]),
bool(available_backends[2]),
)
logger.debug("Selected backend: %s", selected_backend)
return (
use_flash_attention,
use_fused_attention,
use_unfused_attention,
available_backends,
fused_attention_backend,
)
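The filter-then-prioritize structure of `get_attention_backend` can be sketched standalone. This is a simplified illustration of the selection order only; `select_attention_backend` is not part of the TE API, and the real function also returns the cuDNN sub-backend and applies many more filters.

```python
import os

def select_attention_backend(flash_ok=True, fused_ok=True, unfused_ok=True):
    """Sketch of the selection step: environment variables filter the
    candidates first, then priority is FlashAttention > FusedAttention >
    UnfusedDotProductAttention."""
    # Mirrors the NVTE_FLASH_ATTN / NVTE_FUSED_ATTN / NVTE_UNFUSED_ATTN filters.
    available = [
        flash_ok and os.getenv("NVTE_FLASH_ATTN", "1") == "1",
        fused_ok and os.getenv("NVTE_FUSED_ATTN", "1") == "1",
        unfused_ok and os.getenv("NVTE_UNFUSED_ATTN", "1") == "1",
    ]
    names = ["FlashAttention", "FusedAttention", "UnfusedDotProductAttention"]
    # First surviving candidate wins, matching the priority order above.
    for name, ok in zip(names, available):
        if ok:
            return name, available
    return "NoBackend", available
```

As in the real function, a capability filter (here the `*_ok` flags standing in for dtype, mask, head-dim, etc. checks) only ever removes candidates; the priority order decides among whatever survives.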
class InferenceParams: # pylint: disable=too-few-public-methods
"""
Inference parameters that are passed to the main model in order
...@@ -2071,7 +2488,6 @@ class UnfusedDotProductAttention(torch.nn.Module):
alibi_slopes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Unfused attention fprop"""
assert (
qkv_layout in QKVLayouts
), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!"
...
return dq, dk, dv
def get_qkv_layout(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
...@@ -2382,19 +2798,47 @@ def check_set_window_size(
attn_mask_type: str,
window_size: Tuple[int, int] = None,
):
"""Check if sliding window size is compliant with attention mask type.
If not, set it to the appropriate size.
attn_mask_type | window_size
-------------------------------------------------------------------------
no_mask, padding, arbitrary | (-1, -1) or (>=0, >=0)
causal, padding_causal | (-1, 0) or (>=0, 0)
causal_bottom_right, padding_causal_bottom_right | (-1, 0) or (>=0, 0)
"""
orig_window_size = window_size
if "causal" in attn_mask_type:
if orig_window_size is None or (
orig_window_size[0] == -1 and orig_window_size[1] in [-1, 0]
):
window_size = (-1, 0)
warnings.warn(
"window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
)
elif orig_window_size[0] >= 0:
window_size = (orig_window_size[0], 0)
warnings.warn(
"window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
)
else:
assert False, (
"window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
)
elif attn_mask_type in ["no_mask", "padding", "arbitrary"]:
if orig_window_size is None or (
orig_window_size[0] == -1 and orig_window_size[1] in [-1, 0]
):
window_size = (-1, -1)
warnings.warn(
"window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
)
elif orig_window_size[0] < 0 or orig_window_size[1] < 0:
assert False, (
"window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
)
else:
assert False, "Invalid attn_mask_type: " + attn_mask_type
return window_size
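The mapping table in the docstring above can be exercised with a small standalone sketch. This is a simplified illustration of the same rules; `normalize_window_size` is an illustrative name, not the TE function.

```python
import warnings

def normalize_window_size(attn_mask_type, window_size=None):
    """Coerce window_size to the form required by the attention mask type
    (simplified sketch of check_set_window_size)."""
    if "causal" in attn_mask_type:
        # Causal masks require the right window bound to be 0.
        if window_size is None or (window_size[0] == -1 and window_size[1] in (-1, 0)):
            return (-1, 0)
        if window_size[0] >= 0:
            if window_size[1] != 0:
                warnings.warn(f"forcing window_size[1] = 0 for attn_mask_type={attn_mask_type}")
            return (window_size[0], 0)
        raise ValueError(f"window_size should be (-1, 0) or (>=0, 0) for {attn_mask_type}")
    if attn_mask_type in ("no_mask", "padding", "arbitrary"):
        if window_size is None or (window_size[0] == -1 and window_size[1] in (-1, 0)):
            return (-1, -1)
        if window_size[0] >= 0 and window_size[1] >= 0:
            return window_size
        raise ValueError(f"window_size should be (-1, -1) or (>=0, >=0) for {attn_mask_type}")
    raise ValueError(f"Invalid attn_mask_type: {attn_mask_type}")

print(normalize_window_size("causal"))                      # (-1, 0)
print(normalize_window_size("padding_causal", (128, 128)))  # (128, 0)
print(normalize_window_size("no_mask", (64, 64)))           # (64, 64)
```

Note that for causal masks a non-zero right bound is silently clamped to 0 (with a warning), which is exactly the behavior the rewritten check_set_window_size above implements.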
...@@ -4300,24 +4744,35 @@ class DotProductAttention(TransformerEngineBaseModule):
dropout probability for the dropout op during multi-head attention.
attn_mask_type: str, default = `causal`
type of attention mask passed into softmax operation, options are "`no_mask`",
"`padding`", "`causal`", "`padding,causal`", "`causal,padding`",
"`padding_causal`", "`causal_bottom_right`", "`padding_causal_bottom_right`", and
"`arbitrary`", where "`padding,causal`", "`causal,padding`" and "`padding_causal`"
are equivalent. This arg can be overridden by :attr:`attn_mask_type` in the
`forward` method. It is useful for cases involving compilation/tracing, e.g.
ONNX export, and the forward arg is useful for dynamically changing mask types,
e.g. a different mask for training and inference.
1. For "`no_mask`", no attention mask is applied.
2. For "`causal`", "`causal_bottom_right`", or the causal mask in
"`padding_causal`" and "`padding_causal_bottom_right`", TransformerEngine
calculates and applies an upper triangular mask to the softmax input.
No user input is needed. Causal masks without the "`bottom_right`" suffix align
the diagonal line to the top left corner of the softmax matrix. With
"`bottom_right`", the causal mask is aligned to the bottom right corner, which is
often used in inference/KV caching.
3. For "`padding`", or the padding mask in "`padding_causal`" and
"`padding_causal_bottom_right`", users need to provide the locations of padded
tokens, either via :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv` (both in shape
[batch_size + 1]), or via :attr:`attention_mask` (one tensor for self-attention
in shape [batch_size, 1, 1, max_seqlen_q], or two tensors in a tuple for
cross-attention in shapes [batch_size, 1, 1, max_seqlen_q] and
[batch_size, 1, 1, max_seqlen_kv]).
4. For "`arbitrary`", users need to provide a mask that is broadcastable to
the shape of softmax input [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
window_size: Optional[Tuple[int, int]], default = `None`
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and "`causal`" mask specifically. Similar to :attr:`attn_mask_type`, it can
be overridden by :attr:`window_size` in `forward` as well.
attention_type: str, default = `self`
type of attention, either "`self`" or "`cross`".
...@@ -4333,7 +4788,7 @@ class DotProductAttention(TransformerEngineBaseModule):
equal length, and the `thd` format is used when sequences in a batch
have different lengths. Please note that these formats do not reflect how
tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
For that, please use `get_qkv_layout` to obtain the layout information.
softmax_scale: Optional[float], default = `None`
softmax scale for the attention scores. If `None`, defaults to
`1.0 / math.sqrt(kv_channels)`.
...@@ -4385,8 +4840,7 @@ class DotProductAttention(TransformerEngineBaseModule):
if attn_mask_type == "causal_padding":
attn_mask_type = "padding_causal"
self.attn_mask_type = attn_mask_type
self.window_size = check_set_window_size(attn_mask_type, window_size)
self.window_size = check_set_window_size(attn_mask_type, self.window_size)
if tp_group is None:
self.tp_size = tp_size
if tp_size == 1:
...@@ -4620,13 +5074,13 @@ class DotProductAttention(TransformerEngineBaseModule):
Value tensor.
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
default = `None`. Boolean tensor(s) used to mask out attention softmax input.
It should be `None` for causal masks and "`no_mask`". For padding masks, it should be
a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable
to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means
the corresponding position is masked out and a `False` means that position
is allowed to participate in attention.
qkv_format: str, default = `None`
If provided, overrides :attr:`qkv_format` from initialization.
cu_seqlens_q: Optional[torch.Tensor], default = `None`
...@@ -4651,9 +5105,13 @@ class DotProductAttention(TransformerEngineBaseModule):
max_seqlen_kv: Optional[int], default = `None`
Maximum sequence length in `key_layer` and `value_layer`.
Calculated from `cu_seqlens_kv` if not provided.
attn_mask_type: {'no_mask', 'padding', 'causal', 'padding,causal', 'causal,padding',
'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right',
'arbitrary'}, default = `None`. Type of attention mask passed into
softmax operation. 'padding,causal', 'causal,padding' and 'padding_causal'
are equivalent. By default, causal masks are aligned to the top left corner
of the softmax matrix. When "`bottom_right`" is specified in the mask type,
causal masks are aligned to the bottom right corner.
window_size: Optional[Tuple[int, int]], default = `None`
Sliding window size for local attention.
checkpoint_core_attention : bool, default = `False`
...@@ -4724,18 +5182,17 @@ class DotProductAttention(TransformerEngineBaseModule):
assert (
query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
), "DotProductAttention only supports CUDA tensors."
assert (
query_layer.dtype == key_layer.dtype and query_layer.dtype == value_layer.dtype
), "Queries, keys and values must have the same data type!"
assert key_layer.shape == value_layer.shape, "Keys and values must have the same shape!"
if attn_mask_type is not None:
window_size = check_set_window_size(attn_mask_type, window_size)
if attn_mask_type is None:
attn_mask_type = self.attn_mask_type
else:
attn_mask_type = attn_mask_type.replace(",", "_")
if attn_mask_type == "causal_padding":
attn_mask_type = "padding_causal"
assert (
attn_mask_type in AttnMaskTypes
), f"Attention mask type {attn_mask_type} is not supported!"
...@@ -4744,6 +5201,10 @@ class DotProductAttention(TransformerEngineBaseModule):
"padding" in attn_mask_type
), "Attention mask type must be padding or padding_causal for qkv_format=thd!"
if window_size is None:
window_size = self.window_size
window_size = check_set_window_size(attn_mask_type, window_size)
if self.rng_states_tracker is not None and is_graph_capturing():
assert isinstance(
self.rng_states_tracker, CudaRNGStatesTracker
...@@ -4752,9 +5213,6 @@ class DotProductAttention(TransformerEngineBaseModule):
graph_safe_rng_available()
), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture."
if window_size is None:
window_size = self.window_size
if qkv_format is None:
qkv_format = self.qkv_format
...@@ -4826,6 +5284,7 @@ class DotProductAttention(TransformerEngineBaseModule):
if max_seqlen_kv is None:
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
max_seqlen_kv = pow(2, math.ceil(math.log2(seqlens_kv.max().item())))
batch_size = len(cu_seqlens_q) - 1
if qkv_format in ["sbhd", "bshd"]:
assert all(
...@@ -4833,8 +5292,10 @@ class DotProductAttention(TransformerEngineBaseModule):
), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!"
if qkv_format == "sbhd":
max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0])
batch_size = query_layer.shape[1]
if qkv_format == "bshd":
max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1])
batch_size = query_layer.shape[0]
if cu_seqlens_q is not None:
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
assert all(
...@@ -4853,176 +5314,14 @@ class DotProductAttention(TransformerEngineBaseModule):
and isinstance(key_layer, Float8Tensor)
and isinstance(value_layer, Float8Tensor)
):
qkv_layout, query_layer._data, key_layer._data, value_layer._data = get_qkv_layout(
query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format
)
else:
qkv_layout, query_layer, key_layer, value_layer = get_qkv_layout(
query_layer, key_layer, value_layer, qkv_format=qkv_format
)
# The priority for attention backends (subject to availability and clearing the filters)
# is: FlashAttention > FusedAttention (cuDNN) > UnfusedDotProductAttention.
use_flash_attention = self.use_flash_attention
use_fused_attention = self.use_fused_attention
use_unfused_attention = True

# The following section filters out some backends based on
# certain asserts before executing the forward pass.

# Filter: QKV layout.
if qkv_format == "thd":
    if use_unfused_attention:
        self.logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd")
        use_unfused_attention = False
    if use_flash_attention and (
        (
            cu_seqlens_q_padded is not None
            and not torch.equal(cu_seqlens_q_padded, cu_seqlens_q)
        )
        or (
            cu_seqlens_kv_padded is not None
            and not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv)
        )
    ):
        self.logger.debug(
            "Disabling FlashAttention for qkv_format = thd "
            "when there is padding between sequences."
        )
        use_flash_attention = False

# Filter: ONNX export.
if is_in_onnx_export_mode():
    if use_flash_attention:
        self.logger.debug("Disabling FlashAttention for ONNX mode")
        use_flash_attention = False
    if use_fused_attention:
        self.logger.debug("Disabling FusedAttention for ONNX mode")
        use_fused_attention = False

# Filter: Input type.
if use_flash_attention and (
    query_layer.dtype not in [torch.bfloat16, torch.float16]
    or key_layer.dtype not in [torch.bfloat16, torch.float16]
    or value_layer.dtype not in [torch.bfloat16, torch.float16]
    or any(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer])
):
    self.logger.debug(
        "Disabling FlashAttention due to unsupported QKV data types. "
        "Supported: [torch.bfloat16, torch.float16]. "
        "Found: query_layer.dtype=%s, key_layer.dtype=%s, value_layer.dtype=%s.",
        query_layer.dtype,
        key_layer.dtype,
        value_layer.dtype,
    )
    use_flash_attention = False
if use_fused_attention and (
    query_layer.dtype not in [torch.bfloat16, torch.float16]
    or key_layer.dtype not in [torch.bfloat16, torch.float16]
    or value_layer.dtype not in [torch.bfloat16, torch.float16]
):
    self.logger.debug(
        "Disabling FusedAttention due to unsupported QKV data types. "
        "Supported: [torch.bfloat16, torch.float16, Float8Tensor]. "
        "Found: query_layer.dtype=%s, key_layer.dtype=%s, value_layer.dtype=%s.",
        query_layer.dtype,
        key_layer.dtype,
        value_layer.dtype,
    )
    use_fused_attention = False

# Filter: Execution type.
if use_flash_attention and self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
    self.logger.debug("Disabling FlashAttention as it does not support FP8 execution.")
    use_flash_attention = False
if use_unfused_attention and self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
    self.logger.debug(
        "Disabling UnfusedDotProductAttention as it does not support FP8 execution."
    )
    use_unfused_attention = False

# Filter: Device and dimensions.
# FAv2 supports head_dim <= 256, and for > 192 requires sm80/sm90.
# FAv2 requires head_dim % 8 == 0.
if use_flash_attention and (
    query_layer.shape[-1] > 256
    or query_layer.shape[-1] % 8 != 0
    or (
        query_layer.shape[-1] > 192
        and self.device_compute_capability not in ((8, 0), (9, 0))
    )
):
    self.logger.debug(
        "Disabling FlashAttention due to unsupported head_dim. "
        "Supported: %%8 == 0, and <= 256; sm80/90 for >192. "
        "Found: query_layer.shape[-1]=%s, key_layer.shape[-1]=%s, sm=%s",
        query_layer.shape[-1],
        key_layer.shape[-1],
        ".".join([str(i) for i in self.device_compute_capability]),
    )
    use_flash_attention = False

# Filter: cross attention + causal mask (in training mode).
if (
    use_flash_attention
    and inference_params is None
    and _flash_attn_2_1_plus
    and "causal" in attn_mask_type
    and max_seqlen_q != max_seqlen_kv
):
    self.logger.warning(
        "In training mode, disable the use of FlashAttention since version 2.1+ has "
        "changed its behavior for causal mask in cross attention. See "
        "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
    )
    use_flash_attention = False

context_parallel = (
    self.cp_group is not None and get_distributed_world_size(self.cp_group) != 1
)

# Filter: sliding window attention.
# UnfusedDotProductAttention can support SWA via arbitrary attention mask.
if window_size not in ((-1, -1), (-1, 0)):
    if use_fused_attention:
        self.logger.debug("Disabling FusedAttention for SWA")
        use_fused_attention = False
    if (not _flash_attn_2_3_plus) or context_parallel:
        if use_flash_attention:
            self.logger.debug(
                "Disabling FlashAttention as it requires flash-attn 2.3+ "
                "and no context parallelism"
            )
            use_flash_attention = False

# Filter: Attention mask type.
# attn_mask_type(s) | supported backends
# ------------------------------------------------
# no_mask           | All
# padding           | UnfusedDotProductAttention, FlashAttention, FusedAttention
# causal            | All
# padding + causal  | FlashAttention, FusedAttention
# arbitrary         | UnfusedDotProductAttention
#
if attn_mask_type == "arbitrary":
    if use_flash_attention:
        self.logger.debug("Disabling FlashAttention for arbitrary mask")
        use_flash_attention = False
    if use_fused_attention:
        self.logger.debug("Disabling FusedAttention for arbitrary mask")
        use_fused_attention = False
if (
    use_unfused_attention
    and inference_params is None
    and "causal" in attn_mask_type
    and max_seqlen_q != max_seqlen_kv
):
    self.logger.debug("Disabling UnfusedDotProductAttention for causal cross attention")
    use_unfused_attention = False

# Filter: bias.
global _alibi_cache
if alibi_slopes is not None:
    assert (
...@@ -5044,130 +5343,77 @@ class DotProductAttention(TransformerEngineBaseModule):
    _alibi_cache["_alibi_slopes_require_update"] = True
    _alibi_cache["_alibi_bias_require_update"] = True
if use_flash_attention and (
    core_attention_bias_type not in ["no_bias", "alibi"]
    or core_attention_bias is not None
):
    self.logger.debug("Disabling FlashAttention for pre/post_scale_bias")
    use_flash_attention = False

fu_core_attention_bias_type = core_attention_bias_type
fu_core_attention_bias = core_attention_bias
if (
    core_attention_bias_type == "alibi"
    and use_fused_attention
    and alibi_slopes is not None
):
    fu_core_attention_bias_type = "post_scale_bias"
    _, fu_core_attention_bias = get_alibi(
        query_layer.shape[-2],
        max_seqlen_q,
        max_seqlen_kv,
        alibi_slopes=alibi_slopes,
        bias_dtype=query_layer.dtype,
    )
if (
    use_fused_attention
    and fu_core_attention_bias_type == "post_scale_bias"
    and (
        fu_core_attention_bias.shape[0] != 1
        or fu_core_attention_bias.shape[1] != query_layer.shape[-2]
    )
):
    if fu_core_attention_bias.requires_grad:
        # remove this line when cuDNN adds bwd support for
        # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
        self.logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape")
        use_fused_attention = False
    else:
        # max512 backend will only support [1, h, s, s]
        os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"

if use_fused_attention:
    q_type = TE_DType[query_layer.dtype]
    kv_type = TE_DType[key_layer.dtype]
    if self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
        if isinstance(query_layer, Float8Tensor) and isinstance(
            key_layer, Float8Tensor
        ):
            q_type = query_layer._fp8_dtype
            kv_type = value_layer._fp8_dtype
        else:
            q_type = forward_dtype
            kv_type = forward_dtype
    fused_attention_backend = tex.get_fused_attn_backend(
        q_type,
        kv_type,
        QKVLayout[qkv_layout],
        AttnBiasType[fu_core_attention_bias_type],
        AttnMaskType[attn_mask_type],
        self.attention_dropout,
        query_layer.shape[-2],  # num_attn_heads
        key_layer.shape[-2],  # num_gqa_groups
        max_seqlen_q,
        max_seqlen_kv,
        query_layer.shape[-1],  # head_dim
    )
    # DPA does not support FP8; for FP8, use cpp_extensions modules directly
    is_backend_avail = fused_attention_backend in [
        FusedAttnBackend["F16_max512_seqlen"],
        FusedAttnBackend["F16_arbitrary_seqlen"],
        FusedAttnBackend["FP8"],
    ]
    use_fused_attention = (
        use_fused_attention
        and is_backend_avail
        and (
            not context_parallel
            or fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
        )
    )
    if (
        fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
        and fu_core_attention_bias_type == "post_scale_bias"
        and (
            fu_core_attention_bias.shape[0] != 1
            or fu_core_attention_bias.shape[1] != query_layer.shape[-2]
        )
    ):
        self.logger.debug(
            "Disabling FusedAttention as no backend supports the provided input"
        )
        use_fused_attention = False

# Filter: determinism.
# backend                                   | deterministic
# ---------------------------------------------------------
# flash-attn v1                             | yes
# flash-attn v2                             | no
# FusedAttnBackend["F16_max512_seqlen"]     | yes
# FusedAttnBackend["F16_arbitrary_seqlen"]  | workspace optimization path: yes; otherwise: no
# UnfusedDotProductAttention                | yes
#
# Note that FusedAttnBackend["F16_arbitrary_seqlen"] only has workspace optimization path
# on sm90 architectures.
#
if (
    use_fused_attention
    and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
    and self.deterministic
    and self.device_compute_capability != (9, 0)
):
    self.logger.debug("Disabling FusedAttention for determinism reasons")
    use_fused_attention = False

# Select FusedAttention on sm90 and FlashAttention on others for performance
if (
    use_flash_attention
    and use_fused_attention
    and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
):
    if self.device_compute_capability == (9, 0):
        self.logger.debug(
            "Disabling FlashAttention to give FusedAttention preference on Hopper+ "
            "for performance reasons"
        )
        use_flash_attention = False

deterministic = (
    not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
    or torch.are_deterministic_algorithms_enabled()
)

context_parallel = (
    self.cp_group is not None and get_distributed_world_size(self.cp_group) != 1
)

core_attention_bias_shape = None
if core_attention_bias is not None:
    if (
        core_attention_bias.shape[0] == batch_size
        and core_attention_bias.shape[1] == query_layer.shape[-2]
    ):
        core_attention_bias_shape = "bhss"
    elif (
        core_attention_bias.shape[0] == 1
        and core_attention_bias.shape[1] == query_layer.shape[-2]
    ):
        core_attention_bias_shape = "1hss"
    elif (
        core_attention_bias.shape[0] == batch_size and core_attention_bias.shape[1] == 1
    ):
        core_attention_bias_shape = "b1ss"
    elif core_attention_bias.shape[0] == 1 and core_attention_bias.shape[1] == 1:
        core_attention_bias_shape = "11ss"
    else:
        assert (
            False
        ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes"

pad_between_seqs = (
    cu_seqlens_q_padded is not None
    and not torch.equal(cu_seqlens_q_padded, cu_seqlens_q)
) or (
    cu_seqlens_kv_padded is not None
    and not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv)
)

(
    use_flash_attention,
    use_fused_attention,
    use_unfused_attention,
    _,
    fused_attention_backend,
) = get_attention_backend(
    qkv_type=type(query_layer),
    qkv_dtype=query_layer.dtype,
    qkv_layout=qkv_layout,
    batch_size=batch_size,
    num_heads=query_layer.shape[-2],
    num_gqa_groups=key_layer.shape[-2],
    max_seqlen_q=max_seqlen_q,
    max_seqlen_kv=max_seqlen_kv,
    head_dim=query_layer.shape[-1],
    attn_mask_type=attn_mask_type,
    window_size=window_size,
    alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None,
    core_attention_bias_type=core_attention_bias_type,
    core_attention_bias_shape=core_attention_bias_shape,
    core_attention_bias_requires_grad=(
        core_attention_bias.requires_grad if core_attention_bias is not None else False
    ),
    pad_between_seqs=pad_between_seqs,
    attention_dropout=self.attention_dropout,
    context_parallel=context_parallel,
    deterministic=deterministic,
    fp8=self.fp8,
    fp8_meta=self.fp8_meta,
)
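The new bias-shape classification feeding `get_attention_backend` can be summarized in isolation (a sketch; `classify_bias_shape` is a hypothetical helper mirroring the branch above, not part of the PR):

```python
def classify_bias_shape(shape, batch_size, num_heads):
    # Map a bias tensor shape [b_dim, h_dim, sq, skv] onto the four layouts
    # {bhss, 1hss, b1ss, 11ss}; anything else is rejected.
    b_dim, h_dim = shape[0], shape[1]
    if b_dim == batch_size and h_dim == num_heads:
        return "bhss"
    if b_dim == 1 and h_dim == num_heads:
        return "1hss"
    if b_dim == batch_size and h_dim == 1:
        return "b1ss"
    if b_dim == 1 and h_dim == 1:
        return "11ss"
    raise ValueError("core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes")
```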
run_config = {
    "compute_capability": "sm"
...@@ -5240,6 +5486,17 @@ class DotProductAttention(TransformerEngineBaseModule):
    int(os.getenv("NVTE_FP8_DPA_BWD", "1")),
)
self.logger.debug("Running with config=%s", run_config)
fu_core_attention_bias_type = core_attention_bias_type
fu_core_attention_bias = core_attention_bias
if core_attention_bias_type == "alibi" and alibi_slopes is not None:
    fu_core_attention_bias_type = "post_scale_bias"
    _, fu_core_attention_bias = get_alibi(
        query_layer.shape[-2],
        max_seqlen_q,
        max_seqlen_kv,
        alibi_slopes=alibi_slopes,
        bias_dtype=query_layer.dtype,
    )
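For reference, the slopes that feed this ALiBi-to-`post_scale_bias` conversion follow the usual geometric schedule from the ALiBi paper (a sketch, exact for power-of-two head counts; TE's `get_alibi` additionally materializes the full bias tensor from these slopes):

```python
def alibi_slopes(num_heads):
    # Geometric sequence: slope_i = 2^(-8 * i / num_heads) for i = 1..num_heads.
    return [2.0 ** (-8.0 * (i + 1) / num_heads) for i in range(num_heads)]
```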
if checkpoint_core_attention:
    return self._checkpointed_attention_forward(
        self.fused_attention,
...@@ -5289,10 +5546,6 @@ class DotProductAttention(TransformerEngineBaseModule):
        fp8_meta=self.fp8_meta,
    )
assert (
    not context_parallel
), "Context parallelism is only implemented with Flash Attention and Fused Attention!"
from .cpu_offload import CPUOffloadEnabled

if CPUOffloadEnabled:
...@@ -5371,7 +5624,8 @@ class MultiheadAttention(torch.nn.Module):
layer_number: int, default = `None`
layer number of the current `TransformerLayer` when multiple such modules are
concatenated to form a transformer block.
attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
'padding_causal_bottom_right', 'arbitrary'},
default = `causal`
type of attention mask passed into softmax operation. Overridden by
:attr:`attn_mask_type` in the `forward` method. The forward
...@@ -5382,7 +5636,7 @@ class MultiheadAttention(torch.nn.Module):
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and "`causal`" mask specifically. Similar to :attr:`attn_mask_type`, it can
be overridden by :attr:`window_size` in `forward` as well.
num_gqa_groups : int, default = `None`
number of GQA groups in the transformer layer.
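The attends-to interval described for `window_size` above can be sketched with a small helper (hypothetical, not in the PR; clamped to valid key indices):

```python
def swa_key_range(i, seqlen_q, seqlen_k, window_size):
    # Inclusive [low, high] range of key positions visible to query i.
    # -1 on either side means unbounded, so (-1, -1) is no sliding window
    # and (-1, 0) is a bottom-right-aligned causal mask.
    left, right = window_size
    low = 0 if left < 0 else max(0, i + seqlen_k - seqlen_q - left)
    high = seqlen_k - 1 if right < 0 else min(seqlen_k - 1, i + seqlen_k - seqlen_q + right)
    return low, high
```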
...@@ -5429,7 +5683,7 @@ class MultiheadAttention(torch.nn.Module):
are used for when sequences in a batch are of equal length or padded to
equal length. Please note that these formats do not reflect how
tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
For that, please use `get_qkv_layout` to gain the layout information.
Parallelism parameters
----------------------
...@@ -5513,8 +5767,7 @@ class MultiheadAttention(torch.nn.Module):
self.qkv_format = qkv_format
self.attn_mask_type = attn_mask_type
self.window_size = check_set_window_size(attn_mask_type, window_size)
self.layer_number = layer_number
self.input_layernorm = input_layernorm
self.attention_type = attention_type
...@@ -5762,16 +6015,20 @@ class MultiheadAttention(torch.nn.Module):
Input tensor.
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
default = `None`. Boolean tensor(s) used to mask out attention softmax input.
It should be `None` for causal masks and "`no_mask`". For padding masks, it should be
a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable to
[batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means
the corresponding position is masked out and a `False` means that position
is allowed to participate in attention.
attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
'padding_causal_bottom_right', 'arbitrary'},
default = `None`
type of attention mask passed into softmax operation. By default,
causal masks are aligned to the top left corner of the softmax matrix.
When "`bottom_right`" is specified in the mask type, causal masks are
aligned to the bottom right corner.
window_size: Optional[Tuple[int, int]], default = `None`
sliding window size for local attention.
encoder_output : Optional[torch.Tensor], default = `None`
...@@ -5812,12 +6069,11 @@ class MultiheadAttention(torch.nn.Module):
"""
# hidden_states: [sq, b, h]
if attn_mask_type is not None:
    window_size = check_set_window_size(attn_mask_type, window_size)
if attn_mask_type is None:
    attn_mask_type = self.attn_mask_type
if window_size is None:
    window_size = self.window_size
window_size = check_set_window_size(attn_mask_type, window_size)
if "padding" in attn_mask_type and attention_mask is not None:
    for i, _ in enumerate(attention_mask):
...
...@@ -22,7 +22,15 @@ TE_DType = {
torch.bfloat16: tex.DType.kBFloat16,
}
AttnMaskTypes = (
    "no_mask",
    "padding",
    "causal",
    "padding_causal",
    "causal_bottom_right",
    "padding_causal_bottom_right",
    "arbitrary",
)
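The difference between the existing top-left-aligned causal masks and the new bottom-right variants can be illustrated in plain Python (a minimal sketch; `True` = masked out, matching the boolean-mask convention used elsewhere in this PR):

```python
def causal_mask(sq, sk, bottom_right=False):
    # Top-left alignment: query q sees keys k <= q.
    # Bottom-right alignment shifts the diagonal by sk - sq, so the last
    # query row sees every key (the two coincide when sq == sk).
    offset = sk - sq if bottom_right else 0
    return [[k > q + offset for k in range(sk)] for q in range(sq)]
```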
AttnTypes = ("self", "cross")
...
...@@ -64,6 +64,8 @@ AttnMaskType = {
"padding": NVTE_Mask_Type.NVTE_PADDING_MASK,
"causal": NVTE_Mask_Type.NVTE_CAUSAL_MASK,
"padding_causal": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK,
"causal_bottom_right": NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK,
"padding_causal_bottom_right": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
}

FusedAttnBackend = {
...
...@@ -311,7 +311,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
.value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
.value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK)
.value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)
.value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK)
.value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK",
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);
py::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout")
.value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD)
...
...@@ -332,7 +332,9 @@ class FusedScaleMaskSoftmax(nn.Module):
if self.attn_mask_type == "arbitrary":
    return False  # Custom masks not supported
if self.attn_mask_type == "causal_bottom_right" or (
    self.attn_mask_type == "causal" and sq == sk
):  # fused causal softmax kernel
    return True
if (
...@@ -361,7 +363,7 @@ class FusedScaleMaskSoftmax(nn.Module):
"""Fused masked softmax kernel"""
scale = 1.0 if scale is None else scale
if "causal" in self.attn_mask_type:
    return ScaledAlignedCausalMaskedSoftmax.apply(inp, scale)

# input is 4D tensor (b, np, sq, sk)
...@@ -379,7 +381,7 @@ class FusedScaleMaskSoftmax(nn.Module):
if scale is not None:
    inp = inp * scale
if "causal" in self.attn_mask_type:
    seq_len_q, seq_len_k = inp.size(2), inp.size(3)
    if is_in_onnx_export_mode() and self.kvcache_max_seq > 0:
        assert self.kvcache_max_seq >= seq_len_k
...
...@@ -130,19 +130,26 @@ class TransformerLayer(torch.nn.Module):
kv_channels: int, default = `None`
number of query-key-value channels per attention head. defaults to
:attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
self_attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
'padding_causal_bottom_right', 'arbitrary'},
default = `causal`
type of attention mask passed into softmax operation for encoder.
Overridden by :attr:`self_attn_mask_type` in the `forward` method.
The forward arg is useful for dynamically changing mask types, e.g.
a different mask for training and inference. The init arg is useful
for cases involving compilation/tracing, e.g. ONNX export.
window_size: Optional[Tuple[int, int]], default = `None`
sliding window size for local attention in encoder, where query at position i
attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k
- seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean
no sliding window and "`causal`" mask specifically. Similar to
:attr:`self_attn_mask_type`, it can be overridden by :attr:`window_size`
in `forward` as well.
enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
default = `no_mask`
type of attention mask passed into softmax operation for decoder.
enc_dec_window_size: Optional[Tuple[int, int]], default = `None`
sliding window size for local attention in decoder.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
...@@ -238,6 +245,8 @@ class TransformerLayer(torch.nn.Module):
kv_channels: Optional[int] = None,
self_attn_mask_type: str = "causal",
window_size: Optional[Tuple[int, int]] = None,
enc_dec_attn_mask_type: str = "no_mask",
enc_dec_window_size: Optional[Tuple[int, int]] = None,
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
params_dtype: Optional[torch.dtype] = None,
...@@ -270,8 +279,11 @@ class TransformerLayer(torch.nn.Module):
super().__init__()
self.self_attn_mask_type = self_attn_mask_type
self.window_size = check_set_window_size(self_attn_mask_type, window_size)
self.enc_dec_attn_mask_type = enc_dec_attn_mask_type
self.enc_dec_window_size = check_set_window_size(
    enc_dec_attn_mask_type, enc_dec_window_size
)
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad
ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad
...@@ -368,7 +380,7 @@ class TransformerLayer(torch.nn.Module):
self.inter_attention = MultiheadAttention(
    *attention_args,
    **common_attention_kwargs,
    attn_mask_type=enc_dec_attn_mask_type,
    input_layernorm=True,
    attention_type="cross",
    bias=bias,
...@@ -502,6 +514,8 @@ class TransformerLayer(torch.nn.Module):
window_size: Optional[Tuple[int, int]] = None,
encoder_output: Optional[torch.Tensor] = None,
enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
enc_dec_attn_mask_type: Optional[str] = None,
enc_dec_window_size: Optional[Tuple[int, int]] = None,
is_first_microbatch: Optional[bool] = None,
checkpoint_core_attention: bool = False,
inference_params: Optional[InferenceParams] = None,
...@@ -524,28 +538,37 @@ class TransformerLayer(torch.nn.Module):
hidden_states : torch.Tensor
Input tensor.
attention_mask : Optional[torch.Tensor], default = `None`
Boolean tensor used to mask out self-attention softmax input. It should be
in [batch_size, 1, 1, seqlen_q] for padding masks, and broadcastable
to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] for "`arbitrary`"
mask. It should be `None` for causal masks and "`no_mask`" type.
A `True` value means the corresponding position is masked out and
a `False` means that position is allowed to participate in attention.
self_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal',
'causal_bottom_right', 'padding_causal_bottom_right', 'arbitrary'},
default = `causal`
Type of attention mask passed into softmax operation for encoder.
By default, causal masks are aligned to the top left corner of
the softmax matrix. When "`bottom_right`" is specified in the mask type,
causal masks are aligned to the bottom right corner.
window_size: Optional[Tuple[int, int]], default = `None`
Sliding window size for local attention in encoder.
encoder_output : Optional[torch.Tensor], default = `None`
    Output of the encoder block to be fed into the decoder block if using
    `layer_type="decoder"`.
enc_dec_attn_mask : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
                   default = `None`
    Boolean tensors used to mask out inter-attention softmax input if
    using `layer_type="decoder"`. It should be a tuple of two masks in
    [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for padding masks.
    It should be broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]
    for "`arbitrary`" mask. It should be `None` for causal masks and "`no_mask`".
    A `True` value means the corresponding position is masked out and a `False`
    means that position is allowed to participate in attention.
enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
                        default = `None`
    Type of attention mask passed into softmax operation for decoder.
enc_dec_window_size: Optional[Tuple[int, int]], default = `None`
    Sliding window size for local attention in decoder.
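As a concrete illustration of the padding-mask shapes described above, the sketch below (plain PyTorch; the batch size and sequence lengths are hypothetical) builds the [batch_size, 1, 1, seqlen] boolean masks from per-sequence valid lengths, with `True` marking padded positions.

```python
import torch

# Hypothetical example data: batch of 2, query length 5, key/value length 7.
b, sq, skv = 2, 5, 7
q_lens = torch.tensor([5, 3])   # valid (unpadded) query tokens per sequence
kv_lens = torch.tensor([7, 4])  # valid key/value tokens per sequence

# True = padded position (masked out), matching the convention above.
pad_q = (torch.arange(sq)[None, :] >= q_lens[:, None]).view(b, 1, 1, sq)
pad_kv = (torch.arange(skv)[None, :] >= kv_lens[:, None]).view(b, 1, 1, skv)

# A 'padding'-type inter-attention mask is then the tuple of the two masks.
enc_dec_attn_mask = (pad_q, pad_kv)
```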
is_first_microbatch : {True, False, None}, default = None
    During training using either gradient accumulation or
    pipeline parallelism a minibatch of data is further split
...@@ -582,16 +605,23 @@ class TransformerLayer(torch.nn.Module):
to efficiently calculate and store the context during inference.
"""
if self_attn_mask_type is None:
    self_attn_mask_type = self.self_attn_mask_type
if window_size is None:
    window_size = self.window_size
window_size = check_set_window_size(self_attn_mask_type, window_size)
if enc_dec_attn_mask_type is None:
    enc_dec_attn_mask_type = self.enc_dec_attn_mask_type
if enc_dec_window_size is None:
    enc_dec_window_size = self.enc_dec_window_size
enc_dec_window_size = check_set_window_size(enc_dec_attn_mask_type, enc_dec_window_size)
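The defaulting logic above funnels both mask types through `check_set_window_size`. The real helper lives in Transformer Engine; the sketch below is a hypothetical re-implementation showing only the kind of consistency it is responsible for, under the assumed convention that `window_size = (left, right)` with `-1` meaning unbounded and that causal masks must not attend right of the diagonal.

```python
def check_window_size_sketch(attn_mask_type: str, window_size):
    # Hypothetical illustration only -- not TE's actual helper.
    # Assumed convention: window_size = (left, right), -1 = unbounded.
    if "causal" in attn_mask_type:
        if window_size is None:
            return (-1, 0)  # causal default: unlimited left, no right context
        assert window_size[1] == 0, "causal masks cannot attend right of the diagonal"
        return tuple(window_size)
    if window_size is None:
        return (-1, -1)  # no sliding-window restriction
    return tuple(window_size)
```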
assert (
    self_attn_mask_type in AttnMaskTypes
), f"self_attn_mask_type {self_attn_mask_type} not supported"
assert (
    enc_dec_attn_mask_type in AttnMaskTypes
), f"enc_dec_attn_mask_type {enc_dec_attn_mask_type} not supported"
hidden_states = hidden_states.contiguous()
...@@ -604,6 +634,12 @@ class TransformerLayer(torch.nn.Module):
if (
    "padding" in self_attn_mask_type or self_attn_mask_type == "arbitrary"
) and attention_mask is not None:
    assert attention_mask.dtype == torch.bool, "Attention mask must be a boolean tensor"
if (
    "padding" in enc_dec_attn_mask_type or enc_dec_attn_mask_type == "arbitrary"
) and enc_dec_attn_mask is not None:
    assert all(
        enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask))
    ), "Encoder-decoder attention mask must be boolean tensor(s)"
# For AMP
if torch.is_autocast_enabled():
...@@ -614,7 +650,7 @@ class TransformerLayer(torch.nn.Module):
hidden_states,
attention_mask=attention_mask,
attn_mask_type=self_attn_mask_type,
window_size=enc_dec_window_size,
inference_params=inference_params,
is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention,