Unverified commit 56e0b351, authored by Charlene Yang, committed by GitHub

[C/PyTorch] Add support for bottom-right-diagonal causal mask (#960)
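In short: this PR adds `causal_bottom_right` and `padding_causal_bottom_right` attention mask types to the C enums, the cuDNN fused-attention path, and the PyTorch modules, and introduces a `get_attention_backend` selector. A minimal usage sketch, assuming the public PyTorch API as changed by this PR (tensor shapes and sizes are illustrative only):

```python
# Hedged sketch: request the new bottom-right-aligned causal mask.
import torch
import transformer_engine.pytorch as te

attn = te.DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    qkv_format="sbhd",
    attn_mask_type="causal_bottom_right",  # mask type added by this PR
)

# seqlen_q < seqlen_kv: the causal diagonal is anchored at the bottom-right
# corner of the softmax matrix, matching flash-attn >= 2.1 semantics and
# the KV-caching inference case.
q = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn(512, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
v = torch.randn(512, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
out = attn(q, k, v)
```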



* update to FE 1.5.1 and add bottom right causal
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* adjust logic for backend selection
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update FE to 1.5.2
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add get_attention_backend function
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update get_attention_backend
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix get_attention_backend
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* tweak get_attention_backend and fix unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fixes for unfused, get_backend, etc
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/pytorch/attention.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix cpu offload
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes for get_attention_backend
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* explicitly skip FP32 and padding tests because there is no support
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix for window size check
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update check_set_window_size and add enc_dec_attn_mask_type/enc_dec_window_size
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent f9dd37f7
......@@ -345,10 +345,10 @@
"| cuDNN attention | `bshd`, `sbhd`, `thd` | PyTorch: 3 formats, i.e. 15 layouts<br>JAX, PaddlePaddle: `bs3hd`, `bshd_bs2hd`, `bshd_bshd_bshd` layouts |\n",
"| Framework-native attention | `bshd`, `sbhd`<br>(`sbhd` requires transpose operations) | PyTorch, JAX, PaddlePaddle: 2 formats, i.e. 10 layouts |\n",
"\n",
"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](../../tests/pytorch/fused_attention/test_fused_attn.py) and [test_dpa_qkv_layout_thd](../../tests/pytorch/fused_attention/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention._get_qkv_layout](../../transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](../../tests/pytorch/fused_attention/test_fused_attn.py) and [test_dpa_qkv_layout_thd](../../tests/pytorch/fused_attention/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.get_qkv_layout](../../transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
"\n",
"<div class=\"alert alert-info\">\n",
"<b>Note:</b> When RoPE is employed, the <code>qkv_layout</code> may change in Transformer Engine PyTorch through [_get_qkv_layout](../../transformer_engine/pytorch/attention.py). This is due to the in-place nature of our RoPE implementations. We convert `q`, `k`, `v` tensors from their initial layout to the corresponding <code>hd_hd_hd</code> layout. For example, from <code>sbh3d</code> in <code>pytorch.MultiHeadAttention</code> before RoPE, to <code>sbhd_sbhd_sbhd</code> in <code>pytorch.DotProductAttention</code> after RoPE.\n",
"<b>Note:</b> When RoPE is employed, the <code>qkv_layout</code> may change in Transformer Engine PyTorch through [get_qkv_layout](../../transformer_engine/pytorch/attention.py). This is due to the in-place nature of our RoPE implementations. We convert `q`, `k`, `v` tensors from their initial layout to the corresponding <code>hd_hd_hd</code> layout. For example, from <code>sbh3d</code> in <code>pytorch.MultiHeadAttention</code> before RoPE, to <code>sbhd_sbhd_sbhd</code> in <code>pytorch.DotProductAttention</code> after RoPE.\n",
"</div>\n"
]
},
......
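As an aside on the rename above: a hedged sketch of calling the now-public utility directly, based on the call sites visible later in this diff (it returns the inferred layout string plus the three tensors, possibly adjusted to a supported layout):

```python
import torch
from transformer_engine.pytorch.attention import get_qkv_layout

s, b, h, d = 512, 2, 16, 64
q, k, v = (torch.randn(s, b, h, d, device="cuda", dtype=torch.bfloat16) for _ in range(3))

# Infer the joint q/k/v memory layout for the given qkv_format.
qkv_layout, q, k, v = get_qkv_layout(q, k, v, qkv_format="sbhd")
print(qkv_layout)  # e.g. "sbhd_sbhd_sbhd" for three separately allocated tensors
```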
......@@ -7,7 +7,7 @@ import logging
import math
import os
from importlib.metadata import version
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Tuple, Union, Optional
import pytest
import torch
......@@ -19,6 +19,9 @@ from transformer_engine.pytorch.attention import (
DotProductAttention,
MultiheadAttention,
RotaryPositionEmbedding,
get_attention_backend,
_flash_attn_2_plus,
_flash_attn_2_3_plus,
)
from transformer_engine.pytorch.constants import TE_DType
import transformer_engine.pytorch.cpp_extensions as ext
......@@ -99,104 +102,96 @@ class ModelConfig:
self.bias_shape = bias_shape
def _is_fused_attention_supported(
def _get_attention_backends(
config: ModelConfig,
dtype: torch.dtype,
qkv_layout: str = "sbh3d",
) -> Tuple[bool, NVTE_Fused_Attn_Backend]:
"""Check if FusedAttention supports a model configuration"""
backends = []
qkv_dtype: torch.dtype,
qkv_layout: str,
window_size: Tuple[int, int] = (-1, -1),
pad_between_seqs: bool = False,
context_parallel: bool = False,
deterministic: bool = False,
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
) -> Tuple[List, List]:
"""Check if what attention backends support a model configuration"""
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "1"
alibi_slopes_shape = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
if config.bias_shape == "1hss":
alibi_slopes_shape = [config.num_heads]
if config.bias_shape == "bhss":
alibi_slopes_shape = [config.batch_size, config.num_heads]
core_attention_bias_shape = (
config.bias_shape if config.attn_bias_type == "post_scale_bias" else None
)
core_attention_bias_requires_grad = False
# d=256 is supported by cuDNN 9.0+ for inference but not training
if config.attn_bias_type == "post_scale_bias" and config.head_dim <= 128:
core_attention_bias_requires_grad = True
fused_attn_backends = []
available_backends = None
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
backend = tex.get_fused_attn_backend(
TE_DType[dtype],
TE_DType[dtype],
QKVLayout[qkv_layout],
AttnBiasType[config.attn_bias_type],
AttnMaskType[config.attn_mask_type],
config.dropout_p,
config.num_heads,
config.num_gqa_groups,
config.max_seqlen_q,
config.max_seqlen_kv,
config.head_dim,
_, _, _, available_backends, fused_attention_backend = get_attention_backend(
qkv_dtype=qkv_dtype,
qkv_layout=qkv_layout,
batch_size=config.batch_size,
num_heads=config.num_heads,
num_gqa_groups=config.num_gqa_groups,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
head_dim=config.head_dim,
attn_mask_type=config.attn_mask_type,
window_size=window_size,
alibi_slopes_shape=alibi_slopes_shape,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias_shape=core_attention_bias_shape,
core_attention_bias_requires_grad=core_attention_bias_requires_grad,
pad_between_seqs=pad_between_seqs,
attention_dropout=config.dropout_p,
context_parallel=context_parallel,
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
)
if backend == FusedAttnBackend["FP8"]:
backends.append(backend)
return True, backends
if backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
backends.append(backend)
return True, backends
if backend == FusedAttnBackend["F16_max512_seqlen"]:
backends.append(backend)
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
fused_attn_backends.append(fused_attention_backend)
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
backend = tex.get_fused_attn_backend(
TE_DType[dtype],
TE_DType[dtype],
QKVLayout[qkv_layout],
AttnBiasType[config.attn_bias_type],
AttnMaskType[config.attn_mask_type],
config.dropout_p,
config.num_heads,
config.num_gqa_groups,
config.max_seqlen_q,
config.max_seqlen_kv,
config.head_dim,
_, _, _, available_backends, fused_attention_backend = get_attention_backend(
qkv_dtype=qkv_dtype,
qkv_layout=qkv_layout,
batch_size=config.batch_size,
num_heads=config.num_heads,
num_gqa_groups=config.num_gqa_groups,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
head_dim=config.head_dim,
attn_mask_type=config.attn_mask_type,
window_size=window_size,
alibi_slopes_shape=alibi_slopes_shape,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias_shape=core_attention_bias_shape,
core_attention_bias_requires_grad=core_attention_bias_requires_grad,
pad_between_seqs=pad_between_seqs,
attention_dropout=config.dropout_p,
context_parallel=context_parallel,
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
)
if backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
backends.append(backend)
return True, backends
return False, backends
@functools.lru_cache(maxsize=None)
def _is_flash_attention_2_available() -> bool:
"""Check if flash-attn 2.0+ is available"""
Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2")
@functools.lru_cache(maxsize=None)
def _is_flash_attention_2_1() -> bool:
"""Check if flash-attn 2.1+ is available"""
Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2.1")
@functools.lru_cache(maxsize=None)
def _is_flash_attention_2_3() -> bool:
"""Check if flash-attn 2.3+ is available"""
Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2.3")
def _is_flash_attention_supported(config: ModelConfig) -> bool:
"""Check if FlashAttention supports a model configuration"""
if get_device_compute_capability() < (8, 0):
return False
if config.attn_bias_type not in ["no_bias", "alibi"]:
return False
if config.num_heads != config.num_gqa_groups and not _is_flash_attention_2_available():
return False
if "causal" in config.attn_mask_type and config.attn_type == "cross":
if _is_flash_attention_2_1():
# FAv2.1 implements causal mask for cross attention differently
# https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag
return False
return True
def _is_unfused_attention_supported(
config: ModelConfig,
qkv_format: str,
) -> bool:
"""Check if UnfusedDotProductAttention supports a model configuration"""
if "padding" in config.attn_mask_type:
return False
if "causal" in config.attn_mask_type and config.attn_type == "cross":
return False
if qkv_format == "thd":
return False
return True
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
fused_attn_backends.append(fused_attention_backend)
elif (
fused_attention_backend != FusedAttnBackend["No_Backend"]
and fused_attention_backend is not None
):
fused_attn_backends.append(fused_attention_backend)
return available_backends, fused_attn_backends
model_configs_base = {
......@@ -255,25 +250,21 @@ def test_dot_product_attention(
if "3" in qkv_layout and config.attn_type == "cross":
pytest.skip("No need to test this layout for cross attention")
# Skip if only unfused backend is supported
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
unfused_attn_supported = _is_unfused_attention_supported(config, qkv_format)
if config.max_seqlen_q <= 512 and config.max_seqlen_kv <= 512:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
# Test backend availability
window_size = (2, 2) if swa else (-1, -1)
available_backends, fused_attn_backends = _get_attention_backends(
config,
dtype,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=window_size,
pad_between_seqs=pad_between_seqs,
)
if swa:
fused_attn_supported = False
flash_attn_supported = _is_flash_attention_supported(config)
if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
if qkv_format == "thd" and "padding" not in config.attn_mask_type:
pytest.skip("THD layout requires padding/padding_causal mask type.")
# d=256 is supported by cuDNN 9.0+ for inference but not training
is_training = config.head_dim <= 128
# UnfusedDotProductAttention backend
if unfused_attn_supported:
......@@ -296,7 +287,7 @@ def test_dot_product_attention(
# FusedAttention backend
if fused_attn_supported:
if len(fused_attn_backend) == 1:
if len(fused_attn_backends) == 1:
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype,
config,
......@@ -308,7 +299,7 @@ def test_dot_product_attention(
pad_between_seqs,
is_training,
)
if len(fused_attn_backend) == 2:
if len(fused_attn_backends) == 2:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype,
......@@ -363,7 +354,7 @@ def test_dot_product_attention(
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
for i, _ in enumerate(flash_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
if fused_attn_supported and len(fused_attn_backend) == 2:
if fused_attn_supported and len(fused_attn_backends) == 2:
logging.info("[test_dot_product_attention]: fused attn backend 0 vs 1")
torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
for i, _ in enumerate(fused_attn_bwd):
......@@ -393,6 +384,14 @@ model_configs_mask = {
"mask_5_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding_causal", "no_bias"),
"mask_6_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"mask_6_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"mask_7_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_7_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
"mask_8_0": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"mask_8_1": ModelConfig(
1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
),
}
......@@ -508,7 +507,7 @@ model_configs_swa = {
}
@pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.")
@pytest.mark.skipif(not _flash_attn_2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa])
@pytest.mark.parametrize("model", model_configs_swa.keys())
......@@ -530,7 +529,7 @@ model_configs_alibi_slopes = {
}
@pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.")
@pytest.mark.skipif(not _flash_attn_2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
@pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
......@@ -994,17 +993,16 @@ def test_transformer_layer(
tols = dict(atol=5e-1, rtol=5e-2)
workspace_opt = True
# Skip if only unfused backend is supported
if config.max_seqlen_q <= 512 and config.max_seqlen_kv <= 512:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
# Test backend availability
available_backends, fused_attn_backends = _get_attention_backends(
config,
dtype,
qkv_dtype=dtype,
qkv_layout="sbh3d" if fused_qkv_params else "sb3hd",
)
flash_attn_supported = _is_flash_attention_supported(config)
unfused_attn_supported = _is_unfused_attention_supported(config, qkv_format)
if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
# UnfusedDotProductAttention backend
......
......@@ -5,9 +5,10 @@
import os
import pytest
import subprocess
from test_fused_attn import (
ModelConfig,
_is_flash_attention_2_available,
from test_fused_attn import ModelConfig
from transformer_engine.pytorch.attention import (
_flash_attn_2_plus,
_flash_attn_2_3_plus,
)
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
......@@ -33,7 +34,7 @@ def get_bash_arguments(**kwargs):
return args
@pytest.mark.skipif(not _is_flash_attention_2_available(), reason="Flash-attn 2.0+ is required.")
@pytest.mark.skipif(not _flash_attn_2_plus, reason="Flash-attn 2.0+ is required.")
@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
......
......@@ -1593,6 +1593,8 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
ffn_hidden_size=4 * D,
num_attention_heads=H,
attn_input_format=input_format,
self_attn_mask_type="causal_bottom_right",
enc_dec_attn_mask_type="causal_bottom_right",
layer_number=layer_number,
attention_dropout=0.0,
params_dtype=dtype,
......@@ -1606,6 +1608,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
qkv_format=input_format,
layer_number=layer_number,
attention_dropout=0.0,
attn_mask_type="causal_bottom_right",
params_dtype=dtype,
)
.cuda()
......
......@@ -670,7 +670,7 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
apply_residual_connection_post_layernorm=True,
output_layernorm=True,
zero_centered_gamma=zero_centered_gamma,
self_attn_mask_type="padding",
self_attn_mask_type="causal",
normalization=normalization,
device="cuda",
)
......
......@@ -119,17 +119,22 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD))) {
flag_m512 = true;
}
if (((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) ||
if ( // architecture
((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) ||
(cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) &&
// sequence length
((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0) ||
(cudnn_runtime_version >= 90000)) &&
// number of heads
((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) ||
(cudnn_runtime_version >= 8907)) &&
((head_dim <= 128 && head_dim % 8 == 0)
// head dimension
((head_dim <= 128 && head_dim % 8 == 0) ||
// TODO (cyang): add is_training to nvte_get_fused_attn_backend
// d=256 only supported for forward
|| (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim <= 256 &&
head_dim % 8 == 0)) &&
(sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim <= 256 &&
head_dim % 8 == 0)) &&
// bias type
((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
((cudnn_runtime_version >= 8906) &&
(bias_type == NVTE_Bias_Type::NVTE_NO_BIAS ||
......@@ -139,16 +144,24 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 90))) ||
((cudnn_runtime_version >= 90000) &&
(bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 80))) &&
// mask type
((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
((cudnn_runtime_version >= 8906) &&
(attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))) &&
attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) ||
((cudnn_runtime_version >= 90300) &&
attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
(qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) &&
max_seqlen_q <= max_seqlen_kv && dropout == 0.0)) &&
// bias + mask combination
(!(cudnn_runtime_version >= 8906 &&
(attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) &&
bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) &&
// qkv format
((qkv_format == NVTE_QKV_Format::NVTE_SBHD) ||
(sm_arch_ >= 90 && cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups &&
qkv_format == NVTE_QKV_Format::NVTE_THD) ||
......
......@@ -61,6 +61,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK));
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (is_training && dropout_probability != 0.0f);
......@@ -203,6 +205,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
.set_name("flash_attention")
.set_is_inference(false)
.set_causal_mask(is_causal)
.set_causal_mask_bottom_right(is_bottom_right)
.set_attn_scale(attn_scale);
sdpa_options.set_alibi_mask(is_alibi);
......@@ -376,6 +379,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK));
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (dropout_probability != 0.0f);
......@@ -544,6 +549,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
sdpa_backward_options = fe::graph::SDPA_backward_attributes()
.set_name("flash_attention_backward")
.set_causal_mask(is_causal)
.set_causal_mask_bottom_right(is_bottom_right)
.set_attn_scale(attn_scale);
sdpa_backward_options.set_alibi_mask(is_alibi);
......
......@@ -93,10 +93,14 @@ enum NVTE_Mask_Type {
NVTE_NO_MASK = 0,
/*! Padding attention mask */
NVTE_PADDING_MASK = 1,
/*! Causal attention mask */
/*! Causal attention mask (aligned to the top left corner) */
NVTE_CAUSAL_MASK = 2,
/*! Padding and causal attention mask */
/*! Padding and causal attention mask (aligned to the top left corner) */
NVTE_PADDING_CAUSAL_MASK = 3,
/*! Causal attention mask (aligned to the bottom right corner) */
NVTE_CAUSAL_BOTTOM_RIGHT_MASK = 4,
/*! Padding and causal attention mask (aligned to the bottom right corner) */
NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK = 5,
};
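To make the two alignments concrete, a reference sketch (illustrative Python, not TE code) of which (query i, key j) entries each causal variant keeps, for seqlen_q = sq and seqlen_kv = skv:

```python
def causal_top_left(i: int, j: int, sq: int, skv: int) -> bool:
    # Diagonal anchored at the top-left corner: key j is visible iff j <= i.
    return j <= i

def causal_bottom_right(i: int, j: int, sq: int, skv: int) -> bool:
    # Diagonal anchored at the bottom-right corner: shift by skv - sq.
    return j <= i + (skv - sq)

# With sq == skv the two masks coincide. With sq < skv (e.g. KV caching,
# where the query holds only the newest tokens), bottom-right alignment
# lets every query attend to the full cached prefix.
sq, skv = 2, 5
print([[int(causal_top_left(i, j, sq, skv)) for j in range(skv)] for i in range(sq)])
# [[1, 0, 0, 0, 0], [1, 1, 0, 0, 0]]
print([[int(causal_bottom_right(i, j, sq, skv)) for j in range(skv)] for i in range(sq)])
# [[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]]
```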
/*! \enum NVTE_Fused_Attn_Backend
......
......@@ -73,6 +73,7 @@ from transformer_engine.pytorch.graph import is_graph_capturing
_flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
_flash_attn_version_required = PkgVersion("2.0.6")
_flash_attn_max_version = PkgVersion("2.5.8")
_flash_attn_2_plus = _flash_attn_version >= PkgVersion("2")
_flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1")
_flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3")
_flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
......@@ -116,6 +117,422 @@ _alibi_cache = {
__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]
def get_attention_backend(
qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor,
qkv_dtype: torch.dtype = torch.bfloat16,
qkv_layout: str = "sbh3d",
batch_size: int = 1,
num_heads: int = 16,
num_gqa_groups: int = 16,
max_seqlen_q: int = 128,
max_seqlen_kv: int = 128,
head_dim: int = 64,
attn_mask_type: str = "no_mask",
window_size: Tuple[int, int] = (-1, -1),
alibi_slopes_shape: Optional[Union[torch.Size, List]] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias_shape: str = "1hss",
core_attention_bias_requires_grad: bool = True,
pad_between_seqs: bool = False,
attention_dropout: float = 0.0,
context_parallel: bool = False,
deterministic: bool = False,
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
):
"""
Select an attention backend based on the user input and runtime environment.
Parameters
----------
qkv_type: Union[torch.Tensor, Float8Tensor], default = `torch.Tensor`
Type of query/key/value tensors, {`torch.Tensor`, `Float8Tensor`}.
qkv_dtype: torch.dtype, default = `torch.bfloat16`
Data type of query/key/value tensors.
qkv_layout: str, default = "sbh3d"
Query/key/value tensor memory layout.
batch_size: int, default = 1
Batch size.
num_heads: int, default = 16
Number of attention heads in the query tensor.
num_gqa_groups: int, default = 16
Number of attention heads in key and value tensors.
max_seqlen_q: int, default = 128
Maximum sequence length of the query tensor.
max_seqlen_kv: int, default = 128
Maximum sequence length of the key and value tensors.
head_dim: int, default = 64
The size of each attention head.
attn_mask_type: str, default = `no_mask`
Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`,
`causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
window_size: Tuple[int, int], default = (-1, -1)
Sliding window attention size.
alibi_slopes_shape: Optional[Union[torch.Size, List]], default = `None`
Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`.
core_attention_bias_type: str, default = `no_bias`
Attention bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}.
core_attention_bias_shape: str, default = `1hss`
Attention bias shape, {`1hss`, `b1ss`, `bhss`}.
core_attention_bias_requires_grad: bool, default = `True`
Whether attention bias requires gradient.
pad_between_seqs: bool, default = `False`
Whether there is padding between sequences in a batch.
This only applies to `qkv_format=thd`.
attention_dropout: float, default = 0.0
Attention dropout.
context_parallel: bool, default = `False`
Whether context parallelism is used or not.
deterministic: bool, default = `False`
Whether to run `DotProductAttention` with determinism or not.
fp8: bool, default = `False`
Whether `DotProductAttention` is in an `fp8_autocast` region.
fp8_meta: Optional[Dict[str, Any]], default = `None`
The FP8 metadata tensor of `DotProductAttention`.
Returns
----------
use_flash_attention: bool
Whether the `FlashAttention` backend has been selected.
use_fused_attention: bool
Whether the `FusedAttention` backend has been selected.
use_unfused_attention: bool
Whether the `UnfusedDotProductAttention` backend has been selected.
available_backends: List[bool]
All available backends that could support the provided input. A list of Booleans
in the form of [use_flash_attention, use_fused_attention, use_unfused_attention].
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
If `use_fused_attention = True`, the `FusedAttention` sub-backend, else `None`.
"""
logger = logging.getLogger("DotProductAttention")
# Filter: Environment variables
use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1"))
use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1"))
use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
if not use_flash_attention:
logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
if not use_fused_attention:
logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0")
if not use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0")
# Filter: ONNX mode
if is_in_onnx_export_mode():
if use_flash_attention:
logger.debug("Disabling FlashAttention due to ONNX mode")
use_flash_attention = False
if use_fused_attention:
logger.debug("Disabling FusedAttention due to ONNX mode")
use_fused_attention = False
# Filter: Compute capability
device_compute_capability = get_device_compute_capability()
if device_compute_capability < (8, 0):
if use_flash_attention:
logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
use_flash_attention = False
if use_fused_attention:
logger.debug("Disabling FusedAttention as it requires compute capability sm80+")
use_fused_attention = False
# Filter: Context parallelism
if context_parallel and use_unfused_attention:
logger.debug(
"Disabling UnfusedDotProductAttention as it does not support context parallelism"
)
use_unfused_attention = False
# Filter: Data type
if use_flash_attention and (
qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type == Float8Tensor
):
logger.debug(
"Disabling FlashAttention due to unsupported QKV data type. "
"Supported: qkv_type = torch.Tensor, qkv_dtype = {torch.bfloat16, torch.float16}. "
"Found: qkv_type = %s, qkv_dtype = %s.",
qkv_type,
qkv_dtype,
)
use_flash_attention = False
if use_fused_attention and (qkv_dtype not in [torch.bfloat16, torch.float16]):
logger.debug(
"Disabling FusedAttention due to unsupported QKV data type. "
"Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
"Found: qkv_dtype = %s.",
qkv_dtype,
)
use_fused_attention = False
# Filter: Execution type
if fp8 and fp8_meta["recipe"].fp8_dpa:
if use_flash_attention:
logger.debug("Disabling FlashAttention as it does not support FP8")
use_flash_attention = False
if use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8")
use_unfused_attention = False
# Filter: Head dimension
if use_flash_attention and (
head_dim > 256
or head_dim % 8 != 0
or (head_dim > 192 and device_compute_capability not in ((8, 0), (9, 0)))
):
logger.debug(
"Disabling FlashAttention due to unsupported head_dim. "
"Supported: head_dim %%8 = 0, head_dim <= 256 (>192 requires sm80/90). "
"Found: head_dim = %s on sm%s.",
head_dim,
".".join([str(i) for i in device_compute_capability]),
)
use_flash_attention = False
# Filter: QKV layout
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
if qkv_format == "thd":
if use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd")
use_unfused_attention = False
if use_flash_attention and pad_between_seqs:
logger.debug(
"Disabling FlashAttention for qkv_format = thd when there is "
"padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
)
use_flash_attention = False
# Filter: Attention mask
# attn_mask_type | supported backends
# -------------------------------------------------------------------
# no_mask | All
# padding | FlashAttention, FusedAttention
# causal |
# self-attention | All
# cross-attention | FusedAttention
# padding_causal |
# self-attention | FlashAttention, FusedAttention
# cross-attention | FusedAttention
# causal_bottom_right | All
# padding_causal_bottom_right | FlashAttention, FusedAttention
# arbitrary | UnfusedDotProductAttention
if attn_mask_type == "arbitrary":
if use_flash_attention:
logger.debug("Disabling FlashAttention for arbitrary mask")
use_flash_attention = False
if use_fused_attention:
logger.debug("Disabling FusedAttention for arbitrary mask")
use_fused_attention = False
if use_unfused_attention and "padding" in attn_mask_type:
logger.debug("Disabling UnfusedDotProductAttention for %s mask", attn_mask_type)
use_unfused_attention = False
if use_unfused_attention and attn_mask_type == "causal" and max_seqlen_q != max_seqlen_kv:
logger.debug(
"Disabling UnfusedDotProductAttention for "
"top-left-diagonal causal masks for cross-attention"
)
use_unfused_attention = False
if (
use_flash_attention
and _flash_attn_2_1_plus
and attn_mask_type in ["causal", "padding_causal"]
and max_seqlen_q != max_seqlen_kv
):
logger.warning(
"Disabling FlashAttention as it only supports bottom-right-diagonal "
"causal mask since flash-attn 2.1. See "
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
)
use_flash_attention = False
if (
use_flash_attention
and not _flash_attn_2_1_plus
and attn_mask_type in ["causal_bottom_right", "padding_causal_bottom_right"]
and max_seqlen_q != max_seqlen_kv
):
logger.warning(
"Disabling FlashAttention as it only supports top-left-diagonal "
"causal mask before flash-attn 2.1. See "
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
)
use_flash_attention = False
# Filter: Sliding window attention
if window_size is not None and window_size[0] != -1 and window_size[1] not in [-1, 0]:
if use_unfused_attention:
logger.debug(
"Disabling UnfusedDotProductAttention as "
"it does not support sliding window attention"
)
use_unfused_attention = False
if use_fused_attention:
logger.debug("Disabling FusedAttention as it does not support sliding window attention")
use_fused_attention = False
if use_flash_attention and (not _flash_attn_2_3_plus or context_parallel):
logger.debug(
"Disabling FlashAttention as sliding window attention requires "
"flash-attn 2.3+ and no context parallelism"
)
use_flash_attention = False
# Filter: Attention bias
if use_flash_attention and (
core_attention_bias_type not in ["no_bias", "alibi"]
or core_attention_bias_shape is not None
):
logger.debug("Disabling FlashAttention for pre/post_scale_bias")
use_flash_attention = False
fu_core_attention_bias_type = core_attention_bias_type
fu_core_attention_bias_shape = core_attention_bias_shape
fu_core_attention_bias_requires_grad = core_attention_bias_requires_grad
if (
use_fused_attention
and core_attention_bias_type == "alibi"
and alibi_slopes_shape is not None
):
fu_core_attention_bias_type = "post_scale_bias"
fu_core_attention_bias_requires_grad = False
if (
len(alibi_slopes_shape) == 2
and alibi_slopes_shape[0] == batch_size
and alibi_slopes_shape[1] == num_heads
):
fu_core_attention_bias_shape = "bhss"
elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
fu_core_attention_bias_shape = "1hss"
if (
use_fused_attention
and fu_core_attention_bias_type == "post_scale_bias"
and fu_core_attention_bias_shape != "1hss"
):
if fu_core_attention_bias_requires_grad:
# remove this line when cuDNN adds bwd support for
# [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape")
use_fused_attention = False
else:
# max512 backend will only support [1, h, s, s]
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
# Filter: cuDNN support
fused_attention_backend = None
if use_fused_attention:
q_type = TE_DType[qkv_dtype]
kv_type = q_type
if fp8 and fp8_meta["recipe"].fp8_dpa:
q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
kv_type = q_type
fused_attention_backend = tex.get_fused_attn_backend(
q_type,
kv_type,
QKVLayout[qkv_layout],
AttnBiasType[fu_core_attention_bias_type],
AttnMaskType[attn_mask_type],
attention_dropout,
num_heads,
num_gqa_groups,
max_seqlen_q,
max_seqlen_kv,
head_dim,
)
if fused_attention_backend == FusedAttnBackend["No_Backend"] or (
context_parallel and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"]
):
logger.debug("Disabling FusedAttention as no backend supports the provided input")
use_fused_attention = False
elif (
fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
and fu_core_attention_bias_type == "post_scale_bias"
and fu_core_attention_bias_shape != "1hss"
):
logger.debug(
"Disabling FusedAttention as cuDNN sub-backend 0 only supports post_scale_bias in"
" [1, H, S, S] shape"
)
use_fused_attention = False
# Filter: Determinism
# backend | deterministic
# ---------------------------------------------
# FlashAttention |
# flash-attn >=2.0, <2.4.1 | no
# flash-attn >=2.4.1 | yes
# FusedAttention |
# sub-backend 0 | yes
# sub-backend 1 | workspace optimization path and sm90+: yes;
# | otherwise: no
# sub-backend 2 | no
# UnfusedDotProductAttention | yes
if use_flash_attention and deterministic and not _flash_attn_2_4_1_plus:
logger.warning(
"Disabling FlashAttention as version <2.4.1 does not support deterministic "
"execution. To use FlashAttention with deterministic behavior, "
"please install flash-attn >= 2.4.1."
)
use_flash_attention = False
if (
use_fused_attention
and (
fused_attention_backend == FusedAttnBackend["FP8"]
or (
fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
and device_compute_capability < (9, 0)
)
)
and deterministic
):
logger.debug("Disabling FusedAttention for determinism reasons")
use_fused_attention = False
# All available backends
available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention]
# Select FusedAttention for performance
if (
use_flash_attention
and use_fused_attention
and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
):
if device_compute_capability == (9, 0):
logger.debug(
"Disabling FlashAttention to give FusedAttention preference on Hopper+ "
"for performance reasons"
)
use_flash_attention = False
# Selected backend
if use_flash_attention:
use_fused_attention = False
use_unfused_attention = False
elif use_fused_attention:
use_unfused_attention = False
if use_flash_attention:
selected_backend = "FlashAttention"
elif use_fused_attention:
selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})"
elif use_unfused_attention:
selected_backend = "UnfusedDotProductAttention"
else:
selected_backend = "NoBackend"
logger.debug(
"Available backends: FlashAttention=%s, FusedAttention=%s, UnfusedDotProductAttention=%s",
bool(available_backends[0]),
bool(available_backends[1]),
bool(available_backends[2]),
)
logger.debug("Selected backend: %s", selected_backend)
return (
use_flash_attention,
use_fused_attention,
use_unfused_attention,
available_backends,
fused_attention_backend,
)
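For reference, a hedged sketch of calling the new selector directly; the keyword names mirror the signature above, and the argument values are illustrative:

```python
import torch
from transformer_engine.pytorch.attention import get_attention_backend

(
    use_flash,       # FlashAttention selected
    use_fused,       # FusedAttention (cuDNN) selected
    use_unfused,     # UnfusedDotProductAttention selected
    available,       # [flash, fused, unfused] availability flags
    fused_backend,   # cuDNN sub-backend if fused attention is available, else None
) = get_attention_backend(
    qkv_dtype=torch.bfloat16,
    qkv_layout="sbhd_sbhd_sbhd",
    batch_size=2,
    num_heads=16,
    num_gqa_groups=16,
    max_seqlen_q=128,
    max_seqlen_kv=512,
    head_dim=64,
    attn_mask_type="causal_bottom_right",
    core_attention_bias_type="no_bias",
    core_attention_bias_shape=None,
)
```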
class InferenceParams: # pylint: disable=too-few-public-methods
"""
Inference parameters that are passed to the main model in order
......@@ -2071,7 +2488,6 @@ class UnfusedDotProductAttention(torch.nn.Module):
alibi_slopes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Unfused attention fprop"""
assert (
qkv_layout in QKVLayouts
), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!"
......@@ -2259,7 +2675,7 @@ class _PrepareQKVForFA(torch.autograd.Function):
return dq, dk, dv
def _get_qkv_layout(
def get_qkv_layout(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
......@@ -2382,19 +2798,47 @@ def check_set_window_size(
attn_mask_type: str,
window_size: Tuple[int, int] = None,
):
"""Check if sliding window size is compliant with mask type and if not,
assert or set it to the appropriate size
"""Check if sliding window size is compliant with attention mask type.
If not, set it to the appropriate size.
attn_mask_type | window_size
-------------------------------------------------------------------------
no_mask, padding, arbitrary | (-1, -1) or (>=0, >=0)
causal, padding_causal | (-1, 0) or (>=0, 0)
causal_bottom_right, padding_causal_bottom_right | (-1, 0) or (>=0, 0)
"""
orig_window_size = window_size
if "causal" in attn_mask_type:
if window_size is None:
if orig_window_size is None or (
orig_window_size[0] == -1 and orig_window_size[1] in [-1, 0]
):
window_size = (-1, 0)
warnings.warn(
"window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
)
elif orig_window_size[0] >= 0:
window_size = (orig_window_size[0], 0)
warnings.warn(
"window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
)
else:
assert (
window_size[1] == 0
), "window_size[1] should be 0 when self_attn_mask_type includes 'causal'!"
else:
if window_size is None:
assert False, (
"window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
)
elif attn_mask_type in ["no_mask", "padding", "arbitrary"]:
if orig_window_size is None or (
orig_window_size[0] == -1 and orig_window_size[1] in [-1, 0]
):
window_size = (-1, -1)
warnings.warn(
"window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
)
elif orig_window_size[0] < 0 or orig_window_size[0] < 0:
assert False, (
"window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
)
else:
assert False, "Invalid attn_mask_type: " + attn_mask_type
return window_size
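A few hedged examples of the adjusted behavior, assuming the new implementation in this hunk (some calls also emit the warnings shown above):

```python
from transformer_engine.pytorch.attention import check_set_window_size

# Causal mask types force the right window edge to 0.
assert check_set_window_size("causal", None) == (-1, 0)
assert check_set_window_size("padding_causal_bottom_right", (4, -1)) == (4, 0)

# Non-causal mask types default to no sliding window.
assert check_set_window_size("no_mask", None) == (-1, -1)
```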
......@@ -4300,24 +4744,35 @@ class DotProductAttention(TransformerEngineBaseModule):
dropout probability for the dropout op during multi-head attention.
attn_mask_type: str, default = `causal`
type of attention mask passed into softmax operation, options are "`no_mask`",
"`padding`", "`causal`", "`padding,causal`", "`causal,padding`", and
"`arbitrary`", where "`padding,causal`" and "`causal,padding`" are equivalent.
This arg can be overridden by :attr:`attn_mask_type` in the `forward` method.
It is useful for cases involving compilation/tracing, e.g. ONNX export, and the
forward arg is useful for dynamically changing mask types, e.g. a different mask
for training and inference. For "`no_mask`", no attention mask is applied. For
"`causal`" or the causal mask in "`padding,causal`", TransformerEngine calculates
and applies an upper triangular mask to the softmax input. No user input is
needed. For "`padding`" or the padding mask in "`padding,causal`", users need to
provide the locations of padded tokens either via :attr:`cu_seqlens_q` and
:attr:`cu_seqlens_kv` in the shape of [batch_size + 1] or :attr:`attention_mask`
in the shape [batch_size, 1, 1, max_seq_len]. For the "`arbitrary`" mask, users
need to provide a mask that is broadcastable to the shape of softmax input.
"`padding`", "`causal`", "`padding,causal`", "`causal,padding`",
"`padding_causal`", "`causal_bottom_right`", "`padding_causal_bottom_right`", and
"`arbitrary`", where "`padding,causal`", "`causal,padding`" and "`padding_causal`"
are equivalent. This arg can be overridden by :attr:`attn_mask_type` in the
`forward` method. It is useful for cases involving compilation/tracing, e.g.
ONNX export, and the forward arg is useful for dynamically changing mask types,
e.g. a different mask for training and inference.
1. For "`no_mask`", no attention mask is applied.
2. For "`causal`", "`causal_bottom_right`", or the causal mask in
"`padding_causal`" and "`padding_causal_bottom_right`", TransformerEngine
calculates and applies an upper triangular mask to the softmax input.
No user input is needed. Causal masks without the "`bottom_right`" suffix align
the diagonal line to the top left corner of the softmax matrix. With
"`bottom_right`", the causal mask is aligned to the bottom right corner, which is
often used in inference/KV caching.
3. For "`padding`", or the padding mask in "`padding_causal`" and
"`padding_causal_bottom_right`", users need to provide the locations of padded
tokens, either via :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv` (both in shape
[batch_size + 1]), or via :attr:`attention_mask` (one tensor for self-attention
in shape [batch_size, 1, 1, max_seqlen_q], or two tensors in a tuple for
cross-attention in shapes [batch_size, 1, 1, max_seqlen_q] and
[batch_size, 1, 1, max_seqlen_kv]).
4. For "`arbitrary`", users need to provide a mask that is broadcastable to
the shape of softmax input [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
window_size: Optional[Tuple[int, int]], default = `None`
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically. Similar to :attr:`attn_mask_type`, it can
window and "`causal`" mask specifically. Similar to :attr:`attn_mask_type`, it can
be overridden by :attr:`window_size` in `forward` as well.
attention_type: str, default = `self`
type of attention, either "`self`" or "`cross`".
......@@ -4333,7 +4788,7 @@ class DotProductAttention(TransformerEngineBaseModule):
equal length, and the `thd` format is used for when sequences in a batch
have different lengths. Please note that these formats do not reflect how
tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
For that, please use `_get_qkv_layout` to gain the layout information.
For that, please use `get_qkv_layout` to gain the layout information.
softmax_scale: Optional[float], default = `None`
softmax scale for the attention scores. If `None`, defaults to
`1.0 / math.sqrt(kv_channels)`.
......@@ -4385,8 +4840,7 @@ class DotProductAttention(TransformerEngineBaseModule):
if attn_mask_type == "causal_padding":
attn_mask_type = "padding_causal"
self.attn_mask_type = attn_mask_type
self.window_size = window_size
self.window_size = check_set_window_size(attn_mask_type, self.window_size)
self.window_size = check_set_window_size(attn_mask_type, window_size)
if tp_group is None:
self.tp_size = tp_size
if tp_size == 1:
......@@ -4620,13 +5074,13 @@ class DotProductAttention(TransformerEngineBaseModule):
Value tensor.
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
default = `None`. Boolean tensor(s) used to mask out attention softmax input.
It should be 'None' for 'causal' and 'no_mask' types. For 'padding' masks, it should be
It should be `None` for causal masks and "`no_mask`". For padding masks, it should be
a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
for cross-attention. For the 'arbitrary' mask type, it should be in a shape that is
broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value
means the corresponding position is masked out and a `False` means that position is
allowed to participate in attention.
for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable
to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means
the corresponding position is masked out and a `False` means that position
is allowed to participate in attention.
qkv_format: str, default = `None`
If provided, overrides :attr:`qkv_format` from initialization.
cu_seqlens_q: Optional[torch.Tensor], default = `None`
......@@ -4651,9 +5105,13 @@ class DotProductAttention(TransformerEngineBaseModule):
max_seqlen_kv: Optional[int], default = `None`
Maximum sequence length in `key_layer` and `value_layer`.
Calculated from `cu_seqlens_kv` if not provided.
attn_mask_type: {`no_mask`, `padding`, `causal`, `padding,causal`, `causal,padding`,
`arbitrary`}, default = `None`. Type of attention mask passed into
softmax operation. 'padding,causal' and 'causal,padding' are equivalent.
attn_mask_type: {'no_mask', 'padding', 'causal', 'padding,causal', 'causal,padding',
'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right',
'arbitrary'}, default = `None`. Type of attention mask passed into
softmax operation. 'padding,causal', 'causal,padding' and 'padding_causal'
are equivalent. By default, causal masks are aligned to the top left corner
of the softmax matrix. When "`bottom_right`" is specified in the mask type,
causal masks are aligned to the bottom right corner.
window_size: Optional[Tuple[int, int]], default = `None`
Sliding window size for local attention.
checkpoint_core_attention : bool, default = `False`
......@@ -4724,18 +5182,17 @@ class DotProductAttention(TransformerEngineBaseModule):
assert (
query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
), "DotProductAttention only supports CUDA tensors."
assert (
query_layer.dtype == key_layer.dtype and query_layer.dtype == value_layer.dtype
), "Queries, keys and values must have the same data type!"
assert key_layer.shape == value_layer.shape, "Keys and values must have the same shape!"
if attn_mask_type is not None:
window_size = check_set_window_size(attn_mask_type, window_size)
if attn_mask_type is None:
attn_mask_type = self.attn_mask_type
else:
attn_mask_type = attn_mask_type.replace(",", "_")
if attn_mask_type == "causal_padding":
attn_mask_type = "padding_causal"
assert (
attn_mask_type in AttnMaskTypes
), f"Attention mask type {attn_mask_type} is not supported!"
......@@ -4744,6 +5201,10 @@ class DotProductAttention(TransformerEngineBaseModule):
"padding" in attn_mask_type
), "Attention mask type must be padding or padding_causal for qkv_format=thd!"
if window_size is None:
window_size = self.window_size
window_size = check_set_window_size(attn_mask_type, window_size)
if self.rng_states_tracker is not None and is_graph_capturing():
assert isinstance(
self.rng_states_tracker, CudaRNGStatesTracker
......@@ -4752,9 +5213,6 @@ class DotProductAttention(TransformerEngineBaseModule):
graph_safe_rng_available()
), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture."
if window_size is None:
window_size = self.window_size
if qkv_format is None:
qkv_format = self.qkv_format
......@@ -4826,6 +5284,7 @@ class DotProductAttention(TransformerEngineBaseModule):
if max_seqlen_kv is None:
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
max_seqlen_kv = pow(2, math.ceil(math.log2(seqlens_kv.max().item())))
batch_size = len(cu_seqlens_q) - 1
if qkv_format in ["sbhd", "bshd"]:
assert all(
......@@ -4833,8 +5292,10 @@ class DotProductAttention(TransformerEngineBaseModule):
), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!"
if qkv_format == "sbhd":
max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0])
batch_size = query_layer.shape[1]
if qkv_format == "bshd":
max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1])
batch_size = query_layer.shape[0]
if cu_seqlens_q is not None:
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
assert all(
......@@ -4853,176 +5314,14 @@ class DotProductAttention(TransformerEngineBaseModule):
and isinstance(key_layer, Float8Tensor)
and isinstance(value_layer, Float8Tensor)
):
qkv_layout, query_layer._data, key_layer._data, value_layer._data = _get_qkv_layout(
qkv_layout, query_layer._data, key_layer._data, value_layer._data = get_qkv_layout(
query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format
)
else:
qkv_layout, query_layer, key_layer, value_layer = _get_qkv_layout(
qkv_layout, query_layer, key_layer, value_layer = get_qkv_layout(
query_layer, key_layer, value_layer, qkv_format=qkv_format
)
# The priority for attention backends (subject to availability and clearing the filters)
# is: FlashAttention > FusedAttention (cuDNN) > UnfusedDotProductAttention.
use_flash_attention = self.use_flash_attention
use_fused_attention = self.use_fused_attention
use_unfused_attention = True
# The following section filters out some backends based on
# certain asserts before executing the forward pass.
# Filter: QKV layout.
if qkv_format == "thd":
if use_unfused_attention:
self.logger.debug("Disabling UnusedDotProductAttention for qkv_format = thd")
use_unfused_attention = False
if use_fused_attention and (
(
cu_seqlens_q_padded is not None
and torch.equal(cu_seqlens_q_padded, cu_seqlens_q)
)
or (
cu_seqlens_kv_padded is not None
and torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv)
)
):
self.logger.debug(
"Disabling FlashAttention for qkv_format = thd "
"when there is padding between sequences."
)
use_flash_attention = False
# Filter: ONNX export.
if is_in_onnx_export_mode():
if use_flash_attention:
self.logger.debug("Disabling FlashAttention for ONNX mode")
use_flash_attention = False
if use_fused_attention:
self.logger.debug("Disabling FusedAttention for ONNX mode")
use_fused_attention = False
# Filter: Input type.
if use_flash_attention and (
query_layer.dtype not in [torch.bfloat16, torch.float16]
or key_layer.dtype not in [torch.bfloat16, torch.float16]
or value_layer.dtype not in [torch.bfloat16, torch.float16]
or any(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer])
):
self.logger.debug(
"Disabling FlashAttention due to unsupported QKV data types. "
"Supported: [torch.bfloat16, torch.float16]. "
"Found: query_layer.dtype=%s, key_layer.dtype=%s, value_layer.dtype=%s.",
query_layer.dtype,
key_layer.dtype,
value_layer.dtype,
)
use_flash_attention = False
if use_fused_attention and (
query_layer.dtype not in [torch.bfloat16, torch.float16]
or key_layer.dtype not in [torch.bfloat16, torch.float16]
or value_layer.dtype not in [torch.bfloat16, torch.float16]
):
self.logger.debug(
"Disabling FusedAttention due to unsupported QKV data types. "
"Supported: [torch.bfloat16, torch.float16, Float8Tensor]. "
"Found: query_layer.dtype=%s, key_layer.dtype=%s, value_layer.dtype=%s.",
query_layer.dtype,
key_layer.dtype,
value_layer.dtype,
)
use_fused_attention = False
# Filter: Execution type.
if use_flash_attention and self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
self.logger.debug("Disabling FlashAttention as it does not support FP8 execution.")
use_flash_attention = False
if use_unfused_attention and self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
self.logger.debug(
"Disabling UnfusedDotProductAttention as it does not support FP8 execution."
)
use_unfused_attention = False
# Filter: Device and dimensions.
# FAv2 supports head_dim <= 256, and for >192 requires sm80/sm90
# FAv2 requires head_dim % 8 == 0
if use_flash_attention and (
query_layer.shape[-1] > 256
or query_layer.shape[-1] % 8 != 0
or (
query_layer.shape[-1] > 192
and self.device_compute_capability not in ((8, 0), (9, 0))
)
):
self.logger.debug(
"Disabling FlashAttention due to unsupported head_dim. "
"Supported: %%8 == 0, and <= 256; sm80/90 for >192. "
"Found: query_layer.shape[-1]=%s, key_layer.shape[-1]=%s, sm=%s",
query_layer.shape[-1],
key_layer.shape[-1],
".".join([str(i) for i in self.device_compute_capability]),
)
use_flash_attention = False
# Filter: cross attention + causal mask.
# (in training mode)
if (
use_flash_attention
and inference_params is None
and _flash_attn_2_1_plus
and "causal" in attn_mask_type
and max_seqlen_q != max_seqlen_kv
):
self.logger.warning(
"In training mode, disable the use of FlashAttention since version 2.1+ has "
"changed its behavior for causal mask in cross attention. See "
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
)
use_flash_attention = False
context_parallel = (
self.cp_group is not None and get_distributed_world_size(self.cp_group) != 1
)
# Filter: sliding window attention.
# UnfusedDotProductAttention can support SWA via arbitrary attention mask.
if window_size not in ((-1, -1), (-1, 0)):
if use_fused_attention:
self.logger.debug("Disabling FusedAttention for SWA")
use_fused_attention = False
if (not _flash_attn_2_3_plus) or context_parallel:
if use_flash_attention:
self.logger.debug(
"Disabling FusedAttention as it requires flash-attn 2.3+ "
"and no context parallelism"
)
use_flash_attention = False
# Filter: Attention mask type.
# attn_mask_type(s) | supported backends
# ------------------------------------------------
# no_mask | All
# padding | UnfusedDotProductAttention, FlashAttention, FusedAttention
# causal | All
# padding + causal | FlashAttention, FusedAttention
# arbitrary | UnfusedDotProductAttention
#
if attn_mask_type == "arbitrary":
if use_flash_attention:
self.logger.debug("Disabling FlashAttention for arbitrary mask")
use_flash_attention = False
if use_fused_attention:
self.logger.debug("Disabling FusedAttention for arbitrary mask")
use_fused_attention = False
if (
use_unfused_attention
and inference_params is None
and "causal" in attn_mask_type
and max_seqlen_q != max_seqlen_kv
):
self.logger.debug("Disabling UnusedDotProductAttention for qkv_format = thd")
use_unfused_attention = False
# Filter: bias.
global _alibi_cache
if alibi_slopes is not None:
assert (
......@@ -5044,130 +5343,77 @@ class DotProductAttention(TransformerEngineBaseModule):
_alibi_cache["_alibi_slopes_require_update"] = True
_alibi_cache["_alibi_bias_require_update"] = True
if use_flash_attention and (
core_attention_bias_type not in ["no_bias", "alibi"]
or core_attention_bias is not None
):
self.logger.debug("Disabling FlashAttention for pre/post_scale_bias")
use_flash_attention = False
deterministic = (
not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
or torch.are_deterministic_algorithms_enabled()
)
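# Deterministic execution is requested either by setting
# NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 or by calling
# torch.use_deterministic_algorithms(True).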
fu_core_attention_bias_type = core_attention_bias_type
fu_core_attention_bias = core_attention_bias
if (
core_attention_bias_type == "alibi"
and use_fused_attention
and alibi_slopes is not None
):
fu_core_attention_bias_type = "post_scale_bias"
_, fu_core_attention_bias = get_alibi(
query_layer.shape[-2],
max_seqlen_q,
max_seqlen_kv,
alibi_slopes=alibi_slopes,
bias_dtype=query_layer.dtype,
)
if (
use_fused_attention
and fu_core_attention_bias_type == "post_scale_bias"
and (
fu_core_attention_bias.shape[0] != 1
or fu_core_attention_bias.shape[1] != query_layer.shape[-2]
)
):
if fu_core_attention_bias.requires_grad:
# remove this line when cuDNN adds bwd support for
# [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
self.logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape")
use_fused_attention = False
else:
# max512 backend will only support [1, h, s, s]
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
context_parallel = (
self.cp_group is not None and get_distributed_world_size(self.cp_group) != 1
)
if use_fused_attention:
q_type = TE_DType[query_layer.dtype]
kv_type = TE_DType[key_layer.dtype]
if self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
if isinstance(query_layer, Float8Tensor) and isinstance(
key_layer, Float8Tensor
):
q_type = query_layer._fp8_dtype
kv_type = value_layer._fp8_dtype
else:
q_type = forward_dtype
kv_type = forward_dtype
fused_attention_backend = tex.get_fused_attn_backend(
q_type,
kv_type,
QKVLayout[qkv_layout],
AttnBiasType[fu_core_attention_bias_type],
AttnMaskType[attn_mask_type],
self.attention_dropout,
query_layer.shape[-2], # num_attn_heads
key_layer.shape[-2], # num_gqa_groups
max_seqlen_q,
max_seqlen_kv,
query_layer.shape[-1], # head_dim
)
# DPA does not support FP8; for FP8, use cpp_extensions modules directly
is_backend_avail = fused_attention_backend in [
FusedAttnBackend["F16_max512_seqlen"],
FusedAttnBackend["F16_arbitrary_seqlen"],
FusedAttnBackend["FP8"],
]
use_fused_attention = (
use_fused_attention
and is_backend_avail
and (
not context_parallel
or fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
)
)
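# Of the fused backends, only F16_arbitrary_seqlen is compatible with
# context parallelism; the max512 and FP8 backends are excluded here.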
core_attention_bias_shape = None
if core_attention_bias is not None:
if (
fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
and fu_core_attention_bias_type == "post_scale_bias"
and (
fu_core_attention_bias.shape[0] != 1
or fu_core_attention_bias.shape[1] != query_layer.shape[-2]
)
core_attention_bias.shape[0] == batch_size
and core_attention_bias.shape[1] == query_layer.shape[-2]
):
self.logger.debug(
"Disabling FusedAttention as no backend supports the provided input"
)
use_fused_attention = False
# Filter: determinism.
# backend | deterministic
# ---------------------------------------------------------
# flash-attn v1 | yes
# flash-attn v2 | no
# FusedAttnBackend["F16_max512_seqlen"] | yes
# FusedAttnBackend["F16_arbitrary_seqlen"] | workspace optimization path: yes; otherwise: no
# UnfusedDotProductAttention | yes
#
# Note that FusedAttnBackend["F16_arbitrary_seqlen"] only has workspace optimization path
# on sm90 architectures.
#
if (
use_fused_attention
and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
and self.deterministic
and self.device_compute_capability != (9, 0)
):
self.logger.debug("Disabling FusedAttention for determinism reasons")
use_fused_attention = False
core_attention_bias_shape = "bhss"
elif (
core_attention_bias.shape[0] == 1
and core_attention_bias.shape[1] == query_layer.shape[-2]
):
core_attention_bias_shape = "1hss"
elif (
core_attention_bias.shape[0] == batch_size and core_attention_bias.shape[1] == 1
):
core_attention_bias_shape = "b1ss"
elif core_attention_bias.shape[0] == 1 and core_attention_bias.shape[1] == 1:
core_attention_bias_shape = "11ss"
else:
assert (
False
), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes"
pad_between_seqs = (
cu_seqlens_q_padded is not None
and not torch.equal(cu_seqlens_q_padded, cu_seqlens_q)
) or (
cu_seqlens_kv_padded is not None
and not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv)
)
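# True when the padded offsets differ from the actual cumulative sequence
# lengths, i.e. padding tokens sit between sequences in the packed input.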
# Select FusedAttention on sm90 and FlashAttention on others for performance
if (
use_flash_attention
and use_fused_attention
and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
):
if self.device_compute_capability == (9, 0):
self.logger.debug(
"Disabling FlashAttention to give FusedAttention preference on Hopper+ "
"for performance reasons"
)
use_flash_attention = False
(
use_flash_attention,
use_fused_attention,
use_unfused_attention,
_,
fused_attention_backend,
) = get_attention_backend(
qkv_type=type(query_layer),
qkv_dtype=query_layer.dtype,
qkv_layout=qkv_layout,
batch_size=batch_size,
num_heads=query_layer.shape[-2],
num_gqa_groups=key_layer.shape[-2],
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
head_dim=query_layer.shape[-1],
attn_mask_type=attn_mask_type,
window_size=window_size,
alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias_shape=core_attention_bias_shape,
core_attention_bias_requires_grad=(
core_attention_bias.requires_grad if core_attention_bias is not None else False
),
pad_between_seqs=pad_between_seqs,
attention_dropout=self.attention_dropout,
context_parallel=context_parallel,
deterministic=deterministic,
fp8=self.fp8,
fp8_meta=self.fp8_meta,
)
run_config = {
"compute_capability": "sm"
......@@ -5240,6 +5486,17 @@ class DotProductAttention(TransformerEngineBaseModule):
int(os.getenv("NVTE_FP8_DPA_BWD", "1")),
)
self.logger.debug("Running with config=%s", run_config)
fu_core_attention_bias_type = core_attention_bias_type
fu_core_attention_bias = core_attention_bias
if core_attention_bias_type == "alibi" and alibi_slopes is not None:
fu_core_attention_bias_type = "post_scale_bias"
_, fu_core_attention_bias = get_alibi(
query_layer.shape[-2],
max_seqlen_q,
max_seqlen_kv,
alibi_slopes=alibi_slopes,
bias_dtype=query_layer.dtype,
)
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
self.fused_attention,
......@@ -5289,10 +5546,6 @@ class DotProductAttention(TransformerEngineBaseModule):
fp8_meta=self.fp8_meta,
)
assert (
not context_parallel
), "Context parallelism is only implemented with Flash Attention and Fused Attention!"
from .cpu_offload import CPUOffloadEnabled
if CPUOffloadEnabled:
......@@ -5371,7 +5624,8 @@ class MultiheadAttention(torch.nn.Module):
layer_number: int, default = `None`
layer number of the current `TransformerLayer` when multiple such modules are
concatenated to form a transformer block.
attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal' 'arbitrary'},
attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
'padding_causal_bottom_right', 'arbitrary'},
default = `causal`
type of attention mask passed into softmax operation. Overridden by
:attr:`attn_mask_type` in the `forward` method. The forward
......@@ -5382,7 +5636,7 @@ class MultiheadAttention(torch.nn.Module):
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically. Similar to :attr:`attn_mask_type`, it can
window and "`causal`" mask specifically. Similar to :attr:`attn_mask_type`, it can
be overridden by :attr:`window_size` in `forward` as well.
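For example, with equal sequence lengths, window_size = (127, 0) lets each
query attend to itself and the 127 preceding key positions.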
num_gqa_groups : int, default = `None`
number of GQA groups in the transformer layer.
......@@ -5429,7 +5683,7 @@ class MultiheadAttention(torch.nn.Module):
are used for when sequences in a batch are of equal length or padded to
equal length. Please note that these formats do not reflect how
tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
For that, please use `_get_qkv_layout` to gain the layout information.
For that, please use `get_qkv_layout` to gain the layout information.
Parallelism parameters
----------------------
......@@ -5513,8 +5767,7 @@ class MultiheadAttention(torch.nn.Module):
self.qkv_format = qkv_format
self.attn_mask_type = attn_mask_type
self.window_size = window_size
self.window_size = check_set_window_size(attn_mask_type, self.window_size)
self.window_size = check_set_window_size(attn_mask_type, window_size)
self.layer_number = layer_number
self.input_layernorm = input_layernorm
self.attention_type = attention_type
......@@ -5762,16 +6015,20 @@ class MultiheadAttention(torch.nn.Module):
Input tensor.
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
default = `None`. Boolean tensor(s) used to mask out attention softmax input.
It should be 'None' for 'causal' and 'no_mask' types. For 'padding' masks, it should be
It should be `None` for causal masks and "`no_mask`". For padding masks, it should be
a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
for cross-attention. For the 'arbitrary' mask type, it should be in a shape that is
broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value
means the corresponding position is masked out and a `False` means that position is
allowed to participate in attention.
attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'},
for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable to
[batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means
the corresponding position is masked out and a `False` means that position
is allowed to participate in attention.
attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
'padding_causal_bottom_right', 'arbitrary'},
default = `None`
type of attention mask passed into softmax operation.
type of attention mask passed into softmax operation. By default,
causal masks are aligned to the top left corner of the softmax matrix.
When "`bottom_right`" is specified in the mask type, causal masks are
aligned to the bottom right corner.
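For example, with max_seqlen_q = 2 and max_seqlen_kv = 4, "`causal`" lets
query 0 attend to key 0 only, while "`causal_bottom_right`" lets query 0
attend to keys 0-2 and query 1 to keys 0-3.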
window_size: Optional[Tuple[int, int]], default = `None`
sliding window size for local attention.
encoder_output : Optional[torch.Tensor], default = `None`
......@@ -5812,12 +6069,11 @@ class MultiheadAttention(torch.nn.Module):
"""
# hidden_states: [sq, b, h]
if attn_mask_type is not None:
window_size = check_set_window_size(attn_mask_type, window_size)
if attn_mask_type is None:
attn_mask_type = self.attn_mask_type
if window_size is None:
window_size = self.window_size
window_size = check_set_window_size(attn_mask_type, window_size)
if "padding" in attn_mask_type and attention_mask is not None:
for i, _ in enumerate(attention_mask):
......
......@@ -22,7 +22,15 @@ TE_DType = {
torch.bfloat16: tex.DType.kBFloat16,
}
AttnMaskTypes = ("causal", "padding", "padding_causal", "arbitrary", "no_mask")
AttnMaskTypes = (
"no_mask",
"padding",
"causal",
"padding_causal",
"causal_bottom_right",
"padding_causal_bottom_right",
"arbitrary",
)
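# Illustrative only (not used by the library): for an "arbitrary" mask, the two
# causal alignments can be built explicitly, with True marking masked-out
# positions, assuming sq <= skv:
#   sq, skv = 2, 4
#   top_left = torch.arange(skv)[None, :] > torch.arange(sq)[:, None]
#   bottom_right = torch.arange(skv)[None, :] > torch.arange(sq)[:, None] + (skv - sq)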
AttnTypes = ("self", "cross")
......
......@@ -64,6 +64,8 @@ AttnMaskType = {
"padding": NVTE_Mask_Type.NVTE_PADDING_MASK,
"causal": NVTE_Mask_Type.NVTE_CAUSAL_MASK,
"padding_causal": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK,
"causal_bottom_right": NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK,
"padding_causal_bottom_right": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
}
FusedAttnBackend = {
......
......@@ -311,7 +311,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
.value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
.value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK)
.value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK);
.value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)
.value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK)
.value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK",
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);
py::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout")
.value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD)
......
......@@ -332,7 +332,9 @@ class FusedScaleMaskSoftmax(nn.Module):
if self.attn_mask_type == "arbitrary":
return False # Custom masks not supported
if self.attn_mask_type == "causal": # unfused causal softmax kernel
if self.attn_mask_type == "causal_bottom_right" or (
self.attn_mask_type == "causal" and sq == sk
): # fused causal softmax kernel
return True
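# The fused aligned-causal kernel masks relative to the bottom-right corner,
# which coincides with top-left "causal" alignment only when sq == sk.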
if (
......@@ -361,7 +363,7 @@ class FusedScaleMaskSoftmax(nn.Module):
"""Fused masked softmax kernel"""
scale = 1.0 if scale is None else scale
if self.attn_mask_type == "causal":
if "causal" in self.attn_mask_type:
return ScaledAlignedCausalMaskedSoftmax.apply(inp, scale)
# input is 4D tensor (b, np, sq, sk)
......@@ -379,7 +381,7 @@ class FusedScaleMaskSoftmax(nn.Module):
if scale is not None:
inp = inp * scale
if self.attn_mask_type == "causal":
if "causal" in self.attn_mask_type:
seq_len_q, seq_len_k = inp.size(2), inp.size(3)
if is_in_onnx_export_mode() and self.kvcache_max_seq > 0:
assert self.kvcache_max_seq >= seq_len_k
......
......@@ -130,19 +130,26 @@ class TransformerLayer(torch.nn.Module):
kv_channels: int, default = `None`
number of query-key-value channels per attention head. defaults to
:attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
self_attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'},
self_attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
'padding_causal_bottom_right', 'arbitrary'},
default = `causal`
type of attention mask passed into softmax operation. Overridden by
:attr:`self_attn_mask_type` in the `forward` method. The forward
arg is useful for dynamically changing mask types, e.g. a different
mask for training and inference. The init arg is useful for cases
involving compilation/tracing, e.g. ONNX export.
type of attention mask passed into softmax operation for encoder.
Overridden by :attr:`self_attn_mask_type` in the `forward` method.
The forward arg is useful for dynamically changing mask types, e.g.
a different mask for training and inference. The init arg is useful
for cases involving compilation/tracing, e.g. ONNX export.
window_size: Optional[Tuple[int, int]], default = `None`
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically. Similar to :attr:`self_attn_mask_type`, it can
be overridden by :attr:`window_size` in `forward` as well.
sliding window size for local attention in encoder, where query at position i
attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k
- seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean
no sliding window and "`causal`" mask specifically. Similar to
:attr:`self_attn_mask_type`, it can be overridden by :attr:`window_size`
in `forward` as well.
enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
default = `no_mask`
type of attention mask passed into softmax operation for decoder.
enc_dec_window_size: Optional[Tuple[int, int]], default = `None`
sliding window size for local attention in decoder.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
......@@ -238,6 +245,8 @@ class TransformerLayer(torch.nn.Module):
kv_channels: Optional[int] = None,
self_attn_mask_type: str = "causal",
window_size: Optional[Tuple[int, int]] = None,
enc_dec_attn_mask_type: str = "no_mask",
enc_dec_window_size: Optional[Tuple[int, int]] = None,
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
params_dtype: Optional[torch.dtype] = None,
......@@ -270,8 +279,11 @@ class TransformerLayer(torch.nn.Module):
super().__init__()
self.self_attn_mask_type = self_attn_mask_type
self.window_size = window_size
self.window_size = check_set_window_size(self_attn_mask_type, self.window_size)
self.window_size = check_set_window_size(self_attn_mask_type, window_size)
self.enc_dec_attn_mask_type = enc_dec_attn_mask_type
self.enc_dec_window_size = check_set_window_size(
enc_dec_attn_mask_type, enc_dec_window_size
)
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad
ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad
......@@ -368,7 +380,7 @@ class TransformerLayer(torch.nn.Module):
self.inter_attention = MultiheadAttention(
*attention_args,
**common_attention_kwargs,
attn_mask_type="padding",
attn_mask_type=enc_dec_attn_mask_type,
input_layernorm=True,
attention_type="cross",
bias=bias,
......@@ -502,6 +514,8 @@ class TransformerLayer(torch.nn.Module):
window_size: Optional[Tuple[int, int]] = None,
encoder_output: Optional[torch.Tensor] = None,
enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
enc_dec_attn_mask_type: Optional[str] = None,
enc_dec_window_size: Optional[Tuple[int, int]] = None,
is_first_microbatch: Optional[bool] = None,
checkpoint_core_attention: bool = False,
inference_params: Optional[InferenceParams] = None,
......@@ -524,28 +538,37 @@ class TransformerLayer(torch.nn.Module):
hidden_states : torch.Tensor
Input tensor.
attention_mask : Optional[torch.Tensor], default = `None`
Boolean tensor used to mask out self-attention softmax input.
It should be in [batch_size, 1, 1, seqlen_q] for 'padding' mask,
and broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]
for 'arbitrary'. It should be 'None' for 'causal' and 'no_mask'.
Boolean tensor used to mask out self-attention softmax input. It should be
in [batch_size, 1, 1, seqlen_q] for padding masks, and broadcastable
to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] for "`arbitrary`"
mask. It should be `None` for causal masks and "`no_mask`" type.
A `True` value means the corresponding position is masked out and
a `False` means that position is allowed to participate in attention.
self_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
self_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal',
'causal_bottom_right', 'padding_causal_bottom_right', 'arbitrary'},
default = `causal`
Type of attention mask passed into softmax operation.
Type of attention mask passed into softmax operation for encoder.
By default, causal masks are aligned to the top left corner of
the softmax matrix. When "`bottom_right`" is specified in the mask type,
causal masks are aligned to the bottom right corner.
window_size: Optional[Tuple[int, int]], default = `None`
sliding window size for local attention.
Sliding window size for local attention in encoder.
encoder_output : Optional[torch.Tensor], default = `None`
Output of the encoder block to be fed into the decoder block if using
`layer_type="decoder"`.
enc_dec_attn_mask : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
default = `None`. Boolean tensors used to mask out inter-attention softmax input if
using `layer_type="decoder"`. It should be a tuple of two masks in
[batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for 'padding' mask.
[batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for padding masks.
It should be broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]
for 'arbitrary' mask. It should be 'None' for 'causal' and 'no_mask'. A `True` value
means the corresponding position is masked out and a `False` means that position is
allowed to participate in attention.
for "`arbitrary`" mask. It should be `None` for causal masks and "`no_mask`".
A `True` value means the corresponding position is masked out and a `False`
means that position is allowed to participate in attention.
enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
default = `None`
Type of attention mask passed into softmax operation for decoder.
enc_dec_window_size: Optional[Tuple[int, int]], default = `None`
Sliding window size for local attention in decoder.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split
......@@ -582,16 +605,23 @@ class TransformerLayer(torch.nn.Module):
to efficiently calculate and store the context during inference.
"""
if self_attn_mask_type is not None:
window_size = check_set_window_size(self_attn_mask_type, window_size)
if self_attn_mask_type is None:
self_attn_mask_type = self.self_attn_mask_type
if window_size is None:
window_size = self.window_size
window_size = check_set_window_size(self_attn_mask_type, window_size)
if enc_dec_attn_mask_type is None:
enc_dec_attn_mask_type = self.enc_dec_attn_mask_type
if enc_dec_window_size is None:
enc_dec_window_size = self.enc_dec_window_size
enc_dec_window_size = check_set_window_size(enc_dec_attn_mask_type, enc_dec_window_size)
assert (
self_attn_mask_type in AttnMaskTypes
), f"self_attn_mask_type {self_attn_mask_type} not supported"
assert (
enc_dec_attn_mask_type in AttnMaskTypes
), f"enc_dec_attn_mask_type {enc_dec_attn_mask_type} not supported"
hidden_states = hidden_states.contiguous()
......@@ -604,6 +634,12 @@ class TransformerLayer(torch.nn.Module):
"padding" in self_attn_mask_type or self_attn_mask_type == "arbitrary"
) and attention_mask is not None:
assert attention_mask.dtype == torch.bool, "Attention mask must be a boolean tensor"
if (
"padding" in enc_dec_attn_mask_type or enc_dec_attn_mask_type == "arbitrary"
) and enc_dec_attn_mask is not None:
assert all(
enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask))
), "Encoder-decoder attention mask must be boolean tensor(s)"
# For AMP
if torch.is_autocast_enabled():
......@@ -614,7 +650,7 @@ class TransformerLayer(torch.nn.Module):
hidden_states,
attention_mask=attention_mask,
attn_mask_type=self_attn_mask_type,
window_size=window_size,
window_size=enc_dec_window_size,
inference_params=inference_params,
is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention,
......