Unverified Commit 56e0b351 authored by Charlene Yang, committed by GitHub

[C/PyTorch] Add support for bottom-right-diagonal causal mask (#960)



* update to FE 1.5.1 and add bottom right causal
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* adjust logic for backend selection
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update FE to 1.5.2
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add get_attention_backend function
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update get_attention_backend
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix get_attention_backend
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* tweak get_attention_backend and fix unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fixes for unfused, get_backend, etc
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/pytorch/attention.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix cpu offload
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes for get_attention_backend
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* explicitly skip FP32 and padding tests because there is no support
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix for window size check
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update check_set_window_size and add enc_dec_attn_mask_type/enc_dec_window_size
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent f9dd37f7
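
As a quick orientation before the diff, here is a minimal, hedged sketch of how the new mask type is exercised from PyTorch. It is editorial and not part of this commit; shapes, dtypes, and the default `sbhd` qkv_format are illustrative assumptions.

```python
# Editorial sketch (not in this commit): using the new bottom-right-aligned causal
# mask with transformer_engine.pytorch.DotProductAttention. Shapes follow the
# default "sbhd" qkv_format; kv_channels is the per-head dimension.
import torch
import transformer_engine.pytorch as te

seq_q, seq_kv, batch, heads, head_dim = 128, 256, 2, 16, 64
dpa = te.DotProductAttention(
    num_attention_heads=heads,
    kv_channels=head_dim,
    attention_dropout=0.0,
    attn_mask_type="causal_bottom_right",  # new mask type added by this PR
)

q = torch.randn(seq_q, batch, heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(seq_kv, batch, heads, head_dim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(seq_kv, batch, heads, head_dim, dtype=torch.bfloat16, device="cuda")
out = dpa(q, k, v)  # causal diagonal is aligned to the bottom-right corner
```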
......@@ -345,10 +345,10 @@
"| cuDNN attention | `bshd`, `sbhd`, `thd` | PyTorch: 3 formats, i.e. 15 layouts<br>JAX, PaddlePaddle: `bs3hd`, `bshd_bs2hd`, `bshd_bshd_bshd` layouts |\n",
"| Framework-native attention | `bshd`, `sbhd`<br>(`sbhd` requires transpose operations) | PyTorch, JAX, PaddlePaddle: 2 formats, i.e. 10 layouts |\n",
"\n",
"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](../../tests/pytorch/fused_attention/test_fused_attn.py) and [test_dpa_qkv_layout_thd](../../tests/pytorch/fused_attention/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention._get_qkv_layout](../../transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](../../tests/pytorch/fused_attention/test_fused_attn.py) and [test_dpa_qkv_layout_thd](../../tests/pytorch/fused_attention/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.get_qkv_layout](../../transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
"\n",
"<div class=\"alert alert-info\">\n",
"<b>Note:</b> When RoPE is employed, the <code>qkv_layout</code> may change in Transformer Engine PyTorch through [_get_qkv_layout](../../transformer_engine/pytorch/attention.py). This is due to the in-place nature of our RoPE implementations. We convert `q`, `k`, `v` tensors from their initial layout to the corresponding <code>hd_hd_hd</code> layout. For example, from <code>sbh3d</code> in <code>pytorch.MultiHeadAttention</code> before RoPE, to <code>sbhd_sbhd_sbhd</code> in <code>pytorch.DotProductAttention</code> after RoPE.\n",
"<b>Note:</b> When RoPE is employed, the <code>qkv_layout</code> may change in Transformer Engine PyTorch through [get_qkv_layout](../../transformer_engine/pytorch/attention.py). This is due to the in-place nature of our RoPE implementations. We convert `q`, `k`, `v` tensors from their initial layout to the corresponding <code>hd_hd_hd</code> layout. For example, from <code>sbh3d</code> in <code>pytorch.MultiHeadAttention</code> before RoPE, to <code>sbhd_sbhd_sbhd</code> in <code>pytorch.DotProductAttention</code> after RoPE.\n",
"</div>\n"
]
},
......
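For readers skimming the notebook change above, a short editorial illustration (not part of the diff) of what the layout strings mean in tensor terms; the shape convention assumed here is s=sequence, b=batch, h=heads, d=head_dim.

```python
# Editorial illustration: "sbh3d" packs q/k/v into one tensor, while
# "sbhd_sbhd_sbhd" keeps three separate tensors.
import torch

s, b, h, d = 128, 2, 16, 64
qkv_sbh3d = torch.randn(s, b, h, 3, d)    # single packed tensor, layout "sbh3d"
q, k, v = qkv_sbh3d.unbind(dim=3)         # three (s, b, h, d) views, "sbhd_sbhd_sbhd"
```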
......@@ -7,7 +7,7 @@ import logging
import math
import os
from importlib.metadata import version
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Tuple, Union, Optional
import pytest
import torch
......@@ -19,6 +19,9 @@ from transformer_engine.pytorch.attention import (
DotProductAttention,
MultiheadAttention,
RotaryPositionEmbedding,
get_attention_backend,
_flash_attn_2_plus,
_flash_attn_2_3_plus,
)
from transformer_engine.pytorch.constants import TE_DType
import transformer_engine.pytorch.cpp_extensions as ext
......@@ -99,104 +102,96 @@ class ModelConfig:
self.bias_shape = bias_shape
def _is_fused_attention_supported(
def _get_attention_backends(
config: ModelConfig,
dtype: torch.dtype,
qkv_layout: str = "sbh3d",
) -> Tuple[bool, NVTE_Fused_Attn_Backend]:
"""Check if FusedAttention supports a model configuration"""
backends = []
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
backend = tex.get_fused_attn_backend(
TE_DType[dtype],
TE_DType[dtype],
QKVLayout[qkv_layout],
AttnBiasType[config.attn_bias_type],
AttnMaskType[config.attn_mask_type],
config.dropout_p,
config.num_heads,
config.num_gqa_groups,
config.max_seqlen_q,
config.max_seqlen_kv,
config.head_dim,
)
if backend == FusedAttnBackend["FP8"]:
backends.append(backend)
return True, backends
if backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
backends.append(backend)
return True, backends
if backend == FusedAttnBackend["F16_max512_seqlen"]:
backends.append(backend)
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
backend = tex.get_fused_attn_backend(
TE_DType[dtype],
TE_DType[dtype],
QKVLayout[qkv_layout],
AttnBiasType[config.attn_bias_type],
AttnMaskType[config.attn_mask_type],
config.dropout_p,
config.num_heads,
config.num_gqa_groups,
config.max_seqlen_q,
config.max_seqlen_kv,
config.head_dim,
)
if backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
backends.append(backend)
return True, backends
return False, backends
@functools.lru_cache(maxsize=None)
def _is_flash_attention_2_available() -> bool:
"""Check if flash-attn 2.0+ is available"""
Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2")
@functools.lru_cache(maxsize=None)
def _is_flash_attention_2_1() -> bool:
"""Check if flash-attn 2.1+ is available"""
Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2.1")
@functools.lru_cache(maxsize=None)
def _is_flash_attention_2_3() -> bool:
"""Check if flash-attn 2.3+ is available"""
Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2.3")
qkv_dtype: torch.dtype,
qkv_layout: str,
window_size: Tuple[int, int] = (-1, -1),
pad_between_seqs: bool = False,
context_parallel: bool = False,
deterministic: bool = False,
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
) -> Tuple[List, List]:
"""Check if what attention backends support a model configuration"""
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "1"
def _is_flash_attention_supported(config: ModelConfig) -> bool:
"""Check if FlashAttention supports a model configuration"""
if get_device_compute_capability() < (8, 0):
return False
if config.attn_bias_type not in ["no_bias", "alibi"]:
return False
if config.num_heads != config.num_gqa_groups and not _is_flash_attention_2_available():
return False
if "causal" in config.attn_mask_type and config.attn_type == "cross":
if _is_flash_attention_2_1():
# FAv2.1 implements causal mask for cross attention differently
# https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag
return False
return True
alibi_slopes_shape = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
if config.bias_shape == "1hss":
alibi_slopes_shape = [config.num_heads]
if config.bias_shape == "bhss":
alibi_slopes_shape = [config.batch_size, config.num_heads]
core_attention_bias_shape = (
config.bias_shape if config.attn_bias_type == "post_scale_bias" else None
)
core_attention_bias_requires_grad = False
# d=256 is supported by cuDNN 9.0+ for inference but not training
if config.attn_bias_type == "post_scale_bias" and config.head_dim <= 128:
core_attention_bias_requires_grad = True
def _is_unfused_attention_supported(
config: ModelConfig,
qkv_format: str,
) -> bool:
"""Check if UnfusedDotProductAttention supports a model configuration"""
if "padding" in config.attn_mask_type:
return False
if "causal" in config.attn_mask_type and config.attn_type == "cross":
return False
if qkv_format == "thd":
return False
return True
fused_attn_backends = []
available_backends = None
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
_, _, _, available_backends, fused_attention_backend = get_attention_backend(
qkv_dtype=qkv_dtype,
qkv_layout=qkv_layout,
batch_size=config.batch_size,
num_heads=config.num_heads,
num_gqa_groups=config.num_gqa_groups,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
head_dim=config.head_dim,
attn_mask_type=config.attn_mask_type,
window_size=window_size,
alibi_slopes_shape=alibi_slopes_shape,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias_shape=core_attention_bias_shape,
core_attention_bias_requires_grad=core_attention_bias_requires_grad,
pad_between_seqs=pad_between_seqs,
attention_dropout=config.dropout_p,
context_parallel=context_parallel,
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
)
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
fused_attn_backends.append(fused_attention_backend)
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
_, _, _, available_backends, fused_attention_backend = get_attention_backend(
qkv_dtype=qkv_dtype,
qkv_layout=qkv_layout,
batch_size=config.batch_size,
num_heads=config.num_heads,
num_gqa_groups=config.num_gqa_groups,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
head_dim=config.head_dim,
attn_mask_type=config.attn_mask_type,
window_size=window_size,
alibi_slopes_shape=alibi_slopes_shape,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias_shape=core_attention_bias_shape,
core_attention_bias_requires_grad=core_attention_bias_requires_grad,
pad_between_seqs=pad_between_seqs,
attention_dropout=config.dropout_p,
context_parallel=context_parallel,
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
)
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
fused_attn_backends.append(fused_attention_backend)
elif (
fused_attention_backend != FusedAttnBackend["No_Backend"]
and fused_attention_backend is not None
):
fused_attn_backends.append(fused_attention_backend)
return available_backends, fused_attn_backends
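A hedged usage sketch of the helper above (editorial, not part of the diff); the positional `ModelConfig` arguments follow the order used in the `model_configs_*` dictionaries in this test file.

```python
# Editorial sketch: query which backends can run a bottom-right causal config.
import torch

config = ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias")
available_backends, fused_attn_backends = _get_attention_backends(
    config,
    qkv_dtype=torch.bfloat16,
    qkv_layout="sbh3d",
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
```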
model_configs_base = {
......@@ -255,25 +250,21 @@ def test_dot_product_attention(
if "3" in qkv_layout and config.attn_type == "cross":
pytest.skip("No need to test this layout for cross attention")
# Skip if only unfused backend is supported
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
unfused_attn_supported = _is_unfused_attention_supported(config, qkv_format)
if config.max_seqlen_q <= 512 and config.max_seqlen_kv <= 512:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
# Test backend availability
window_size = (2, 2) if swa else (-1, -1)
available_backends, fused_attn_backends = _get_attention_backends(
config,
dtype,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=window_size,
pad_between_seqs=pad_between_seqs,
)
if swa:
fused_attn_supported = False
flash_attn_supported = _is_flash_attention_supported(config)
if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
if qkv_format == "thd" and "padding" not in config.attn_mask_type:
pytest.skip("THD layout requires padding/padding_causal mask type.")
# d=256 is supported by cuDNN 9.0+ for inference but not training
is_training = config.head_dim <= 128
# UnfusedDotProductAttention backend
if unfused_attn_supported:
......@@ -296,7 +287,7 @@ def test_dot_product_attention(
# FusedAttention backend
if fused_attn_supported:
if len(fused_attn_backend) == 1:
if len(fused_attn_backends) == 1:
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype,
config,
......@@ -308,7 +299,7 @@ def test_dot_product_attention(
pad_between_seqs,
is_training,
)
if len(fused_attn_backend) == 2:
if len(fused_attn_backends) == 2:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype,
......@@ -363,7 +354,7 @@ def test_dot_product_attention(
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
for i, _ in enumerate(flash_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
if fused_attn_supported and len(fused_attn_backend) == 2:
if fused_attn_supported and len(fused_attn_backends) == 2:
logging.info("[test_dot_product_attention]: fused attn backend 0 vs 1")
torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
for i, _ in enumerate(fused_attn_bwd):
......@@ -393,6 +384,14 @@ model_configs_mask = {
"mask_5_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding_causal", "no_bias"),
"mask_6_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"mask_6_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"mask_7_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_7_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
"mask_8_0": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"mask_8_1": ModelConfig(
1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
),
}
......@@ -508,7 +507,7 @@ model_configs_swa = {
}
@pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.")
@pytest.mark.skipif(not _flash_attn_2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa])
@pytest.mark.parametrize("model", model_configs_swa.keys())
......@@ -530,7 +529,7 @@ model_configs_alibi_slopes = {
}
@pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.")
@pytest.mark.skipif(not _flash_attn_2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
@pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
......@@ -994,17 +993,16 @@ def test_transformer_layer(
tols = dict(atol=5e-1, rtol=5e-2)
workspace_opt = True
# Skip if only unfused backend is supported
if config.max_seqlen_q <= 512 and config.max_seqlen_kv <= 512:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
# Test backend availability
available_backends, fused_attn_backends = _get_attention_backends(
config,
dtype,
qkv_dtype=dtype,
qkv_layout="sbh3d" if fused_qkv_params else "sb3hd",
)
flash_attn_supported = _is_flash_attention_supported(config)
unfused_attn_supported = _is_unfused_attention_supported(config, qkv_format)
if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
# UnfusedDotProductAttention backend
......
......@@ -5,9 +5,10 @@
import os
import pytest
import subprocess
from test_fused_attn import (
ModelConfig,
_is_flash_attention_2_available,
from test_fused_attn import ModelConfig
from transformer_engine.pytorch.attention import (
_flash_attn_2_plus,
_flash_attn_2_3_plus,
)
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
......@@ -33,7 +34,7 @@ def get_bash_arguments(**kwargs):
return args
@pytest.mark.skipif(not _is_flash_attention_2_available(), reason="Flash-attn 2.0+ is required.")
@pytest.mark.skipif(not _flash_attn_2_plus, reason="Flash-attn 2.0+ is required.")
@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
......
......@@ -1593,6 +1593,8 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
ffn_hidden_size=4 * D,
num_attention_heads=H,
attn_input_format=input_format,
self_attn_mask_type="causal_bottom_right",
enc_dec_attn_mask_type="causal_bottom_right",
layer_number=layer_number,
attention_dropout=0.0,
params_dtype=dtype,
......@@ -1606,6 +1608,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
qkv_format=input_format,
layer_number=layer_number,
attention_dropout=0.0,
attn_mask_type="causal_bottom_right",
params_dtype=dtype,
)
.cuda()
......
......@@ -670,7 +670,7 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
apply_residual_connection_post_layernorm=True,
output_layernorm=True,
zero_centered_gamma=zero_centered_gamma,
self_attn_mask_type="padding",
self_attn_mask_type="causal",
normalization=normalization,
device="cuda",
)
......
......@@ -119,17 +119,22 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD))) {
flag_m512 = true;
}
if (((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) ||
if ( // architecture
((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) ||
(cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) &&
// sequence length
((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0) ||
(cudnn_runtime_version >= 90000)) &&
// number of heads
((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) ||
(cudnn_runtime_version >= 8907)) &&
((head_dim <= 128 && head_dim % 8 == 0)
// head dimension
((head_dim <= 128 && head_dim % 8 == 0) ||
// TODO (cyang): add is_training to nvte_get_fused_attn_backend
// d=256 only supported for forward
|| (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim <= 256 &&
(sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim <= 256 &&
head_dim % 8 == 0)) &&
// bias type
((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
((cudnn_runtime_version >= 8906) &&
(bias_type == NVTE_Bias_Type::NVTE_NO_BIAS ||
......@@ -139,16 +144,24 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 90))) ||
((cudnn_runtime_version >= 90000) &&
(bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 80))) &&
// mask type
((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
((cudnn_runtime_version >= 8906) &&
(attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))) &&
attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) ||
((cudnn_runtime_version >= 90300) &&
attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
(qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) &&
max_seqlen_q <= max_seqlen_kv && dropout == 0.0)) &&
// bias + mask combination
(!(cudnn_runtime_version >= 8906 &&
(attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) &&
bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) &&
// qkv format
((qkv_format == NVTE_QKV_Format::NVTE_SBHD) ||
(sm_arch_ >= 90 && cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups &&
qkv_format == NVTE_QKV_Format::NVTE_THD) ||
......
......@@ -61,6 +61,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK));
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (is_training && dropout_probability != 0.0f);
......@@ -203,6 +205,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
.set_name("flash_attention")
.set_is_inference(false)
.set_causal_mask(is_causal)
.set_causal_mask_bottom_right(is_bottom_right)
.set_attn_scale(attn_scale);
sdpa_options.set_alibi_mask(is_alibi);
......@@ -376,6 +379,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK));
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (dropout_probability != 0.0f);
......@@ -544,6 +549,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
sdpa_backward_options = fe::graph::SDPA_backward_attributes()
.set_name("flash_attention_backward")
.set_causal_mask(is_causal)
.set_causal_mask_bottom_right(is_bottom_right)
.set_attn_scale(attn_scale);
sdpa_backward_options.set_alibi_mask(is_alibi);
......
......@@ -93,10 +93,14 @@ enum NVTE_Mask_Type {
NVTE_NO_MASK = 0,
/*! Padding attention mask */
NVTE_PADDING_MASK = 1,
/*! Causal attention mask */
/*! Causal attention mask (aligned to the top left corner) */
NVTE_CAUSAL_MASK = 2,
/*! Padding and causal attention mask */
/*! Padding and causal attention mask (aligned to the top left corner) */
NVTE_PADDING_CAUSAL_MASK = 3,
/*! Causal attention mask (aligned to the bottom right corner) */
NVTE_CAUSAL_BOTTOM_RIGHT_MASK = 4,
/*! Padding and causal attention mask (aligned to the bottom right corner) */
NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK = 5,
};
/*! \enum NVTE_Fused_Attn_Backend
......
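To make the two alignments concrete, here is a small editorial reference snippet (not part of the diff) that builds the equivalent boolean masks; following the PyTorch-side convention, `True` marks a position that is masked out.

```python
# Editorial illustration of NVTE_CAUSAL_MASK (top-left aligned) vs
# NVTE_CAUSAL_BOTTOM_RIGHT_MASK (bottom-right aligned) for seq_q < seq_kv.
# Bottom-right alignment is the natural choice for KV-cache-style decoding,
# where a short query block attends to a longer key/value history.
import torch

seq_q, seq_kv = 3, 5
i = torch.arange(seq_q).unsqueeze(1)   # query positions
j = torch.arange(seq_kv).unsqueeze(0)  # key positions

top_left = j > i                           # diagonal starts at (0, 0)
bottom_right = j > i + (seq_kv - seq_q)    # diagonal ends at (seq_q - 1, seq_kv - 1)
```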
......@@ -22,7 +22,15 @@ TE_DType = {
torch.bfloat16: tex.DType.kBFloat16,
}
AttnMaskTypes = ("causal", "padding", "padding_causal", "arbitrary", "no_mask")
AttnMaskTypes = (
"no_mask",
"padding",
"causal",
"padding_causal",
"causal_bottom_right",
"padding_causal_bottom_right",
"arbitrary",
)
AttnTypes = ("self", "cross")
......
......@@ -64,6 +64,8 @@ AttnMaskType = {
"padding": NVTE_Mask_Type.NVTE_PADDING_MASK,
"causal": NVTE_Mask_Type.NVTE_CAUSAL_MASK,
"padding_causal": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK,
"causal_bottom_right": NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK,
"padding_causal_bottom_right": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
}
FusedAttnBackend = {
......
......@@ -311,7 +311,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
.value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
.value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK)
.value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK);
.value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)
.value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK)
.value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK",
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);
py::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout")
.value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD)
......
......@@ -332,7 +332,9 @@ class FusedScaleMaskSoftmax(nn.Module):
if self.attn_mask_type == "arbitrary":
return False # Custom masks not supported
if self.attn_mask_type == "causal": # unfused causal softmax kernel
if self.attn_mask_type == "causal_bottom_right" or (
self.attn_mask_type == "causal" and sq == sk
): # fused causal softmax kernel
return True
if (
......@@ -361,7 +363,7 @@ class FusedScaleMaskSoftmax(nn.Module):
"""Fused masked softmax kernel"""
scale = 1.0 if scale is None else scale
if self.attn_mask_type == "causal":
if "causal" in self.attn_mask_type:
return ScaledAlignedCausalMaskedSoftmax.apply(inp, scale)
# input is 4D tensor (b, np, sq, sk)
......@@ -379,7 +381,7 @@ class FusedScaleMaskSoftmax(nn.Module):
if scale is not None:
inp = inp * scale
if self.attn_mask_type == "causal":
if "causal" in self.attn_mask_type:
seq_len_q, seq_len_k = inp.size(2), inp.size(3)
if is_in_onnx_export_mode() and self.kvcache_max_seq > 0:
assert self.kvcache_max_seq >= seq_len_k
......
......@@ -130,19 +130,26 @@ class TransformerLayer(torch.nn.Module):
kv_channels: int, default = `None`
number of query-key-value channels per attention head. defaults to
:attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
self_attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'},
self_attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
'padding_causal_bottom_right', 'arbitrary'},
default = `causal`
type of attention mask passed into softmax operation. Overridden by
:attr:`self_attn_mask_type` in the `forward` method. The forward
arg is useful for dynamically changing mask types, e.g. a different
mask for training and inference. The init arg is useful for cases
involving compilation/tracing, e.g. ONNX export.
type of attention mask passed into softmax operation for encoder.
Overridden by :attr:`self_attn_mask_type` in the `forward` method.
The forward arg is useful for dynamically changing mask types, e.g.
a different mask for training and inference. The init arg is useful
for cases involving compilation/tracing, e.g. ONNX export.
window_size: Optional[Tuple[int, int]], default = `None`
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically. Similar to :attr:`self_attn_mask_type`, it can
be overridden by :attr:`window_size` in `forward` as well.
sliding window size for local attention in encoder, where query at position i
attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k
- seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean
no sliding window and "`causal`" mask specifically. Similar to
:attr:`self_attn_mask_type`, it can be overridden by :attr:`window_size`
in `forward` as well.
enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
default = `no_mask`
type of attention mask passed into softmax operation for decoder.
enc_dec_window_size: Optional[Tuple[int, int]], default = `None`
sliding window size for local attention in decoder.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
......@@ -238,6 +245,8 @@ class TransformerLayer(torch.nn.Module):
kv_channels: Optional[int] = None,
self_attn_mask_type: str = "causal",
window_size: Optional[Tuple[int, int]] = None,
enc_dec_attn_mask_type: str = "no_mask",
enc_dec_window_size: Optional[Tuple[int, int]] = None,
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
params_dtype: Optional[torch.dtype] = None,
......@@ -270,8 +279,11 @@ class TransformerLayer(torch.nn.Module):
super().__init__()
self.self_attn_mask_type = self_attn_mask_type
self.window_size = window_size
self.window_size = check_set_window_size(self_attn_mask_type, self.window_size)
self.window_size = check_set_window_size(self_attn_mask_type, window_size)
self.enc_dec_attn_mask_type = enc_dec_attn_mask_type
self.enc_dec_window_size = check_set_window_size(
enc_dec_attn_mask_type, enc_dec_window_size
)
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad
ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad
......@@ -368,7 +380,7 @@ class TransformerLayer(torch.nn.Module):
self.inter_attention = MultiheadAttention(
*attention_args,
**common_attention_kwargs,
attn_mask_type="padding",
attn_mask_type=enc_dec_attn_mask_type,
input_layernorm=True,
attention_type="cross",
bias=bias,
......@@ -502,6 +514,8 @@ class TransformerLayer(torch.nn.Module):
window_size: Optional[Tuple[int, int]] = None,
encoder_output: Optional[torch.Tensor] = None,
enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
enc_dec_attn_mask_type: Optional[str] = None,
enc_dec_window_size: Optional[Tuple[int, int]] = None,
is_first_microbatch: Optional[bool] = None,
checkpoint_core_attention: bool = False,
inference_params: Optional[InferenceParams] = None,
......@@ -524,28 +538,37 @@ class TransformerLayer(torch.nn.Module):
hidden_states : torch.Tensor
Input tensor.
attention_mask : Optional[torch.Tensor], default = `None`
Boolean tensor used to mask out self-attention softmax input.
It should be in [batch_size, 1, 1, seqlen_q] for 'padding' mask,
and broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]
for 'arbitrary'. It should be 'None' for 'causal' and 'no_mask'.
Boolean tensor used to mask out self-attention softmax input. It should be
in [batch_size, 1, 1, seqlen_q] for padding masks, and broadcastable
to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] for "`arbitrary`"
mask. It should be `None` for causal masks and "`no_mask`" type.
A `True` value means the corresponding position is masked out and
a `False` means that position is allowed to participate in attention.
self_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
self_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal',
'causal_bottom_right', 'padding_causal_bottom_right', 'arbitrary'},
default = `causal`
Type of attention mask passed into softmax operation.
Type of attention mask passed into softmax operation for encoder.
By default, causal masks are aligned to the top left corner of
the softmax matrix. When "`bottom_right`" is specified in the mask type,
causal masks are aligned to the bottom right corner.
window_size: Optional[Tuple[int, int]], default = `None`
sliding window size for local attention.
Sliding window size for local attention in encoder.
encoder_output : Optional[torch.Tensor], default = `None`
Output of the encoder block to be fed into the decoder block if using
`layer_type="decoder"`.
enc_dec_attn_mask : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
default = `None`. Boolean tensors used to mask out inter-attention softmax input if
using `layer_type="decoder"`. It should be a tuple of two masks in
[batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for 'padding' mask.
[batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for padding masks.
It should be broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]
for 'arbitrary' mask. It should be 'None' for 'causal' and 'no_mask'. A `True` value
means the corresponding position is masked out and a `False` means that position is
allowed to participate in attention.
for "`arbitrary`" mask. It should be `None` for causal masks and "`no_mask`".
A `True` value means the corresponding position is masked out and a `False`
means that position is allowed to participate in attention.
enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
default = `None`
Type of attention mask passed into softmax operation for decoder.
enc_dec_window_size: Optional[Tuple[int, int]], default = `None`
Sliding window size for local attention in decoder.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split
......@@ -582,16 +605,23 @@ class TransformerLayer(torch.nn.Module):
to efficiently calculate and store the context during inference.
"""
if self_attn_mask_type is not None:
window_size = check_set_window_size(self_attn_mask_type, window_size)
if self_attn_mask_type is None:
self_attn_mask_type = self.self_attn_mask_type
if window_size is None:
window_size = self.window_size
window_size = check_set_window_size(self_attn_mask_type, window_size)
if enc_dec_attn_mask_type is None:
enc_dec_attn_mask_type = self.enc_dec_attn_mask_type
if enc_dec_window_size is None:
enc_dec_window_size = self.enc_dec_window_size
enc_dec_window_size = check_set_window_size(enc_dec_attn_mask_type, enc_dec_window_size)
assert (
self_attn_mask_type in AttnMaskTypes
), f"self_attn_mask_type {self_attn_mask_type} not supported"
assert (
enc_dec_attn_mask_type in AttnMaskTypes
), f"enc_dec_attn_mask_type {enc_dec_attn_mask_type} not supported"
hidden_states = hidden_states.contiguous()
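A hedged sketch of the per-call overrides resolved by the block above (editorial; `layer`, `hidden_states`, `encoder_output`, and the padding-mask tensors are placeholders):

```python
# Editorial sketch: forward-time overrides; any argument left as None falls back
# to the value given at construction time.
out = layer(
    hidden_states,
    encoder_output=encoder_output,
    enc_dec_attn_mask=(q_pad_mask, kv_pad_mask),   # boolean padding masks
    enc_dec_attn_mask_type="padding_causal",       # new forward arg
    enc_dec_window_size=(-1, -1),                  # new forward arg
)
```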
......@@ -604,6 +634,12 @@ class TransformerLayer(torch.nn.Module):
"padding" in self_attn_mask_type or self_attn_mask_type == "arbitrary"
) and attention_mask is not None:
assert attention_mask.dtype == torch.bool, "Attention mask must be a boolean tensor"
if (
"padding" in enc_dec_attn_mask_type or enc_dec_attn_mask_type == "arbitrary"
) and enc_dec_attn_mask is not None:
assert all(
enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask))
), "Encoder-decoder attention mask must be boolean tensor(s)"
# For AMP
if torch.is_autocast_enabled():
......@@ -614,7 +650,7 @@ class TransformerLayer(torch.nn.Module):
hidden_states,
attention_mask=attention_mask,
attn_mask_type=self_attn_mask_type,
window_size=window_size,
window_size=enc_dec_window_size,
inference_params=inference_params,
is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention,
......