Unverified Commit 8e039fdc authored by Charlene Yang, committed by GitHub

Add cuDNN sliding window and set_deterministic_algorithm (#992)



* add cuDNN SWA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix SWA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add set_deterministic and minor fixes for SWA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add AttentionParams
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change window_size to int64_t; fix SWA/determinism tests; cache _attention_backends
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add window_size to get_backend; fix JAX and Paddle
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes; add set_deterministic to bwd_impl
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix FP8 tests due to determinism
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add support matrix for SWA and bias
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci

* minor fixes and lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci

* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add wording on window_size special cases
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor tweak on wording
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix JAX assertion error
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix wording
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* call bwd with deterministic=true for JAX/Paddle
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add determinism wording in documentation
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
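
For reference, a minimal PyTorch sketch of the two user-facing features this PR introduces: a `window_size` argument for sliding-window attention, and the `NVTE_ALLOW_NONDETERMINISTIC_ALGO` toggle for a deterministic backward pass (cuDNN 9.0+). The exact forward signature is an assumption inferred from the test changes below.

```python
import os
import torch
from transformer_engine.pytorch.attention import DotProductAttention

# "0" requests the deterministic backward (workspace optimization path);
# "1" allows the faster, nondeterministic algorithm. cuDNN 9.0+ only.
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"  # prefer the cuDNN fused-attention backend

b, sq, h, d = 2, 128, 16, 64
q, k, v = (
    torch.randn(sq, b, h, d, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    for _ in range(3)
)

attn = DotProductAttention(h, d, attn_mask_type="causal")
# (left, right) sliding window in tokens; (-1, -1) means no window, and
# causal masks force the right window to 0 (see check_set_window_size below).
out = attn(q, k, v, qkv_format="sbhd", window_size=(16, 0))
out.backward(torch.ones_like(out))
```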
parent 166bb078
......@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
"id": "141fa8bd",
"id": "8ae3bc43",
"metadata": {},
"source": [
"# Attention Is All You Need!\n",
......@@ -22,7 +22,7 @@
},
{
"cell_type": "markdown",
"id": "09a60057",
"id": "47421c01",
"metadata": {},
"source": [
"## 1. Attention Backends\n",
......@@ -38,7 +38,7 @@
},
{
"cell_type": "markdown",
"id": "f387274e",
"id": "e52f60f0",
"metadata": {},
"source": [
"### 1.1 Flash vs. Non-Flash\n",
......@@ -58,7 +58,7 @@
},
{
"cell_type": "markdown",
"id": "f1389145",
"id": "bb909ac4",
"metadata": {},
"source": [
"### 1.2 flash-attention\n",
......@@ -97,7 +97,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "bbc5c73f",
"id": "9a380859",
"metadata": {},
"outputs": [],
"source": [
......@@ -113,7 +113,7 @@
{
"cell_type": "code",
"execution_count": 2,
"id": "173638b6",
"id": "0584bb01",
"metadata": {},
"outputs": [
{
......@@ -140,7 +140,7 @@
},
{
"cell_type": "markdown",
"id": "0f62d2fa",
"id": "45e53fc9",
"metadata": {},
"source": [
"## 2. Backend Selection\n",
......@@ -160,7 +160,7 @@
},
{
"cell_type": "markdown",
"id": "86e16a2b",
"id": "6dfeade3",
"metadata": {},
"source": [
"### 2.1 Debug Information\n",
......@@ -177,7 +177,7 @@
},
{
"cell_type": "markdown",
"id": "e439434e",
"id": "7e3b7981",
"metadata": {},
"source": [
"The [example_attention.py](./example_attention.py) script runs a very basic model with two attention backends, cuDNN attention and flash-attention. Here `NVTE_DEBUG_LEVEL=1` allows us to find out which backend/sub-backend was actually used during runtime."
......@@ -186,7 +186,7 @@
{
"cell_type": "code",
"execution_count": 22,
"id": "9d002327",
"id": "961c51d4",
"metadata": {},
"outputs": [
{
......@@ -210,7 +210,7 @@
},
{
"cell_type": "markdown",
"id": "bbf1756c",
"id": "11bfbbd7",
"metadata": {},
"source": [
"To collect more information, users can turn on `NVTE_DEBUG_LEVEL=2`. In this example, it allows us to find out more about the run config. Users are encouraged to provide if users intend to file a bug with Transformer Engine. For example, "
......@@ -219,7 +219,7 @@
{
"cell_type": "code",
"execution_count": 25,
"id": "66a2f34c",
"id": "162a2be1",
"metadata": {},
"outputs": [
{
......@@ -249,7 +249,7 @@
},
{
"cell_type": "markdown",
"id": "9f964732",
"id": "779a51e6",
"metadata": {},
"source": [
"### 2.2 User Control\n",
......@@ -274,11 +274,16 @@
"\n",
"Users can experiment with these two paths through the following environment variable. However, please be aware of the possible Out-Of-Memory risks.\n",
"```\n",
"NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT = 0 # disables workspace optimization path\n",
"NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT = 1 # enables workspace optimization path\n",
"Before cuDNN 9.0:\n",
" NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT = 0 # disables workspace optimization path\n",
" NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT = 1 # enables workspace optimization path\n",
"\n",
"After cuDNN 9.0:\n",
" NVTE_ALLOW_NONDETERMINISTIC_ALGO = 1 # disables workspace optimization path\n",
" NVTE_ALLOW_NONDETERMINISTIC_ALGO = 0 # enables workspace optimization path\n",
"```\n",
"<div class=\"alert alert-info\">\n",
"<b>Note:</b> Environment variables <code>NVTE_FLASH_ATTN</code>, <code>NVTE_FUSED_ATTN</code> and <code>NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT</code> are only supported in PyTorch, not JAX or PaddlePaddle.\n",
"<b>Note:</b> Environment variables <code>NVTE_FLASH_ATTN</code>, <code>NVTE_FUSED_ATTN</code>, <code>NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT</code> and <code>NVTE_ALLOW_NONDETERMINISTIC_ALGO</code> are only supported in PyTorch, and will be added to JAX and PaddlePaddle in the future.\n",
"</div>\n",
"\n",
"### 2.3 Example Tests\n",
......@@ -290,7 +295,7 @@
},
{
"cell_type": "markdown",
"id": "3ad85b86",
"id": "ccd5650d",
"metadata": {},
"source": [
"## 3. Backend Support\n",
......@@ -311,7 +316,7 @@
},
{
"cell_type": "markdown",
"id": "37920af4",
"id": "8439b389",
"metadata": {},
"source": [
"### 3.1 QKV Layout\n",
......@@ -354,7 +359,7 @@
},
{
"cell_type": "markdown",
"id": "94c69fae",
"id": "0290f8e9",
"metadata": {},
"source": [
"### 3.2 Attention Mask\n",
......@@ -387,7 +392,7 @@
{
"cell_type": "code",
"execution_count": 6,
"id": "4c87df64",
"id": "b1b7cdd4",
"metadata": {},
"outputs": [
{
......@@ -408,7 +413,7 @@
},
{
"cell_type": "markdown",
"id": "5ec0c75d",
"id": "e045c284",
"metadata": {},
"source": [
"Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](../../tests/pytorch/fused_attention/test_fused_attn.py).\n",
......@@ -432,7 +437,7 @@
},
{
"cell_type": "markdown",
"id": "3f8f6f1c",
"id": "8b8a4e40",
"metadata": {},
"source": [
"### 3.4 FP8 Attention\n",
......
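
To make the notebook hunk above concrete, here is a short sketch of the controls it documents, assuming they are read at backend-selection time (per the note above, these variables are honored in PyTorch only for now):

```python
import os

# Before cuDNN 9.0: explicit workspace-optimization toggle.
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1"  # 1 enables, 0 disables

# cuDNN 9.0 and later: the same path is steered through determinism instead;
# 0 enables the (deterministic) workspace optimization path, 1 disables it.
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"

# Debug logging from section 2.1: level 1 prints the chosen backend,
# level 2 additionally dumps the run config for bug reports.
os.environ["NVTE_DEBUG_LEVEL"] = "2"
```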
......@@ -22,6 +22,9 @@ from transformer_engine.pytorch.attention import (
get_attention_backend,
_flash_attn_2_plus,
_flash_attn_2_3_plus,
check_set_window_size,
AttentionParams,
_attention_backends,
)
from transformer_engine.pytorch.constants import TE_DType
import transformer_engine.pytorch.cpp_extensions as ext
......@@ -84,6 +87,7 @@ class ModelConfig:
alibi_type: str = "none",
num_layers: int = 1,
bias_shape: str = "1hss",
window_size: Tuple[int, int] = (-1, -1),
):
self.batch_size = batch_size
self.num_heads = num_heads
......@@ -100,6 +104,7 @@ class ModelConfig:
self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
self.num_layers = num_layers
self.bias_shape = bias_shape
self.window_size = window_size
def _get_attention_backends(
......@@ -118,6 +123,8 @@ def _get_attention_backends(
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "1"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
alibi_slopes_shape = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
......@@ -136,33 +143,10 @@ def _get_attention_backends(
fused_attn_backends = []
available_backends = None
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
_, _, _, available_backends, fused_attention_backend = get_attention_backend(
qkv_dtype=qkv_dtype,
qkv_layout=qkv_layout,
batch_size=config.batch_size,
num_heads=config.num_heads,
num_gqa_groups=config.num_gqa_groups,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
head_dim=config.head_dim,
attn_mask_type=config.attn_mask_type,
window_size=window_size,
alibi_slopes_shape=alibi_slopes_shape,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias_shape=core_attention_bias_shape,
core_attention_bias_requires_grad=core_attention_bias_requires_grad,
pad_between_seqs=pad_between_seqs,
attention_dropout=config.dropout_p,
context_parallel=context_parallel,
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
)
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
fused_attn_backends.append(fused_attention_backend)
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
_, _, _, available_backends, fused_attention_backend = get_attention_backend(
fused_attention_backend = None
def test():
attention_params = AttentionParams(
qkv_dtype=qkv_dtype,
qkv_layout=qkv_layout,
batch_size=config.batch_size,
......@@ -184,13 +168,18 @@ def _get_attention_backends(
fp8=fp8,
fp8_meta=fp8_meta,
)
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
_, _, fused_attention_backend, _, available_backends = get_attention_backend(
attention_params
)
return available_backends, fused_attention_backend
backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
for i in range(3):
os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
_attention_backends["backend_selection_requires_update"] = True
available_backends, fused_attention_backend = test()
if fused_attention_backend == FusedAttnBackend[backends[i]]:
fused_attn_backends.append(fused_attention_backend)
elif (
fused_attention_backend != FusedAttnBackend["No_Backend"]
and fused_attention_backend is not None
):
fused_attn_backends.append(fused_attention_backend)
return available_backends, fused_attn_backends
......@@ -211,18 +200,6 @@ if is_bf16_compatible(): # bf16 requires sm_80 or higher
param_types_lean = [torch.bfloat16]
def get_swa(seq_q, seq_kv, w=None):
"""Generate a random sliding window size (left, right) if w is None,
and create its equivalent attention mask in [seq_q, seq_kv] shape"""
if w is None:
w = torch.randint(0, seq_kv, [2], dtype=torch.int32, device="cuda")
m = torch.ones(seq_q, seq_kv, dtype=torch.bool, device="cuda")
mu = torch.triu(m, diagonal=seq_kv - seq_q - w[0])
ml = torch.tril(mu, diagonal=seq_kv - seq_q + w[1])
ml = ~ml
return w, ml
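
The helper removed above encoded a (left, right) window as a banded boolean mask; for reference, a standalone sketch of the same bottom-right-aligned construction, where `True` marks positions to mask out:

```python
import torch

def sliding_window_mask(seq_q: int, seq_kv: int, w) -> torch.Tensor:
    # Allowed band: within w[0] tokens left and w[1] tokens right of the
    # bottom-right-aligned diagonal (same triu/tril construction as get_swa).
    m = torch.ones(seq_q, seq_kv, dtype=torch.bool)
    band = torch.tril(
        torch.triu(m, diagonal=seq_kv - seq_q - w[0]),
        diagonal=seq_kv - seq_q + w[1],
    )
    return ~band  # invert: True = masked out
```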
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
......@@ -251,15 +228,22 @@ def test_dot_product_attention(
pytest.skip("No need to test this layout for cross attention")
# Test backend availability
window_size = (2, 2) if swa else (-1, -1)
window_size = (-1, -1)
if swa:
window_size = tuple(torch.randint(0, config.max_seqlen_kv, [2], dtype=torch.int32).tolist())
config.window_size = check_set_window_size(config.attn_mask_type, window_size)
available_backends, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=window_size,
window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
# manually pads and unpads the input and output of FlashAttention for testing purposes
if pad_between_seqs:
flash_attn_supported = True
# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
......@@ -268,9 +252,6 @@ def test_dot_product_attention(
is_training = config.head_dim <= 128
# UnfusedDotProductAttention backend
if unfused_attn_supported:
if swa:
attn_mask_type = config.attn_mask_type
config.attn_mask_type = "arbitrary"
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
dtype,
config,
......@@ -278,12 +259,9 @@ def test_dot_product_attention(
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
if swa:
config.attn_mask_type = attn_mask_type
# FusedAttention backend
if fused_attn_supported:
......@@ -295,7 +273,6 @@ def test_dot_product_attention(
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
......@@ -308,7 +285,6 @@ def test_dot_product_attention(
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
......@@ -320,7 +296,6 @@ def test_dot_product_attention(
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
......@@ -334,7 +309,6 @@ def test_dot_product_attention(
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
......@@ -499,11 +473,19 @@ def test_dpa_bias_shapes(dtype, model_configs, model):
model_configs_swa = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
"swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
"swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
"swa_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
"swa_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"),
"swa_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"swa_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
}
......@@ -622,7 +604,6 @@ def _run_dot_product_attention(
ckpt_attn: bool,
qkv_layout: str,
workspace_opt: bool,
swa: bool,
pad_between_seqs: bool,
is_training: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
......@@ -637,6 +618,8 @@ def _run_dot_product_attention(
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
# Create seqlens
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
......@@ -733,11 +716,6 @@ def _run_dot_product_attention(
attention_mask_q.to(device="cuda"),
attention_mask_kv.to(device="cuda"),
)
window_size = None
if swa:
window_size, attention_mask = get_swa(config.max_seqlen_q, config.max_seqlen_kv)
elif "causal" in config.attn_mask_type:
window_size, attention_mask = (-1, 0), None
alibi_slopes = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
......@@ -902,7 +880,7 @@ def _run_dot_product_attention(
q,
k,
v,
window_size=window_size,
window_size=config.window_size,
attention_mask=attention_mask,
qkv_format=qkv_format,
max_seqlen_q=config.max_seqlen_q,
......@@ -1121,6 +1099,8 @@ def _run_transformer_layer(
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
# Create input tensor
inp = torch.randn(
......@@ -1279,6 +1259,9 @@ def _rmse(a, b):
def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd):
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
config = model_configs_fp8_vs_f16[model]
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
......@@ -1457,6 +1440,9 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd):
pytest.skip("qkv_layout not applicable for MQA/GQA")
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(dtype, config, True, qkv_layout)
......@@ -1741,6 +1727,8 @@ def _run_custom_mha_fp8(dtype, config, backend):
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
inp = 0.0001 * torch.randint(
-100,
......@@ -1794,6 +1782,8 @@ def _run_ref_mha_f16(dtype, config, backend):
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
inp = torch.load("qkv.pt").to(device="cuda")
inp.requires_grad = True
......
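
A recurring idiom in the test changes above: after flipping the `NVTE_*` flags, the new per-`AttentionParams` backend cache must be invalidated. A condensed sketch of that pattern (the helper name is illustrative, not part of the PR):

```python
import os
from transformer_engine.pytorch.attention import _attention_backends

def select_backends(flash: bool, fused: bool, unfused: bool) -> None:
    """Route attention through env vars and invalidate the cached selection."""
    os.environ["NVTE_FLASH_ATTN"] = "1" if flash else "0"
    os.environ["NVTE_FUSED_ATTN"] = "1" if fused else "0"
    os.environ["NVTE_UNFUSED_ATTN"] = "1" if unfused else "0"
    # Without this, a selection cached for the same AttentionParams is reused.
    _attention_backends["backend_selection_requires_update"] = True
```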
......@@ -72,7 +72,8 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) {
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim) {
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left,
int64_t window_size_right) {
using namespace transformer_engine;
NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
const int device_id = cuda::current_device();
......@@ -116,7 +117,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(qkv_layout == NVTE_QKV_Layout::NVTE_SBHD_SB2HD) ||
(qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) ||
(qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) ||
(qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD))) {
(qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) &&
((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0))) {
flag_m512 = true;
}
if ( // architecture
......@@ -165,7 +167,19 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
((qkv_format == NVTE_QKV_Format::NVTE_SBHD) ||
(sm_arch_ >= 90 && cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups &&
qkv_format == NVTE_QKV_Format::NVTE_THD) ||
(qkv_format == NVTE_QKV_Format::NVTE_BSHD))) {
(qkv_format == NVTE_QKV_Format::NVTE_BSHD)) &&
// sliding window
((cudnn_runtime_version < 90200 && window_size_left == -1 &&
(window_size_right == -1 || window_size_right == 0)) ||
(cudnn_runtime_version >= 90200 &&
((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) ||
((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 &&
(attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
(attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
max_seqlen_q == max_seqlen_kv)) &&
dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
(qkv_format == NVTE_QKV_Format::NVTE_BSHD ||
qkv_format == NVTE_QKV_Format::NVTE_SBHD)))))) {
flag_arb = true;
}
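
A Python paraphrase of the sliding-window gate added above for the arbitrary-seqlen backend, offered as a readability aid only (the C++ condition is authoritative):

```python
def swa_allows_arbitrary_seqlen(cudnn_version, left, right, mask,
                                max_seqlen_q, max_seqlen_kv,
                                dropout, bias_type, qkv_format):
    # "No sliding window" encoding: left == -1 with right in {-1, 0}.
    no_window = left == -1 and right in (-1, 0)
    if cudnn_version < 90200:
        return no_window
    causal_ok = mask == "causal" or (
        mask == "causal_bottom_right" and max_seqlen_q == max_seqlen_kv
    )
    # Any left window (>= 0 or -1) is accepted once right == 0.
    swa_ok = (right == 0 and causal_ok and dropout == 0.0
              and bias_type == "no_bias" and qkv_format in ("bshd", "sbhd"))
    return no_window or swa_ok
```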
if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) {
......@@ -213,6 +227,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTETensor rng_state, size_t max_seqlen, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
using namespace transformer_engine;
......@@ -242,9 +257,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type,
dropout, h, h, max_seqlen, max_seqlen, d);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
max_seqlen, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -259,8 +274,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd_qkvpacked(
b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, input_QKV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens,
input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
attn_mask_type, window_size_left, window_size_right, input_QKV, input_Bias, output_O,
Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace,
stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -286,7 +302,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
const NVTETensor cu_seqlens_padded, size_t max_seqlen,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream) {
int64_t window_size_left, int64_t window_size_right,
bool deterministic, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
using namespace transformer_engine;
......@@ -317,9 +334,9 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type,
dropout, h, h, max_seqlen, max_seqlen, d);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
max_seqlen, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -341,9 +358,10 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
}
fused_attn_arbitrary_seqlen_bwd_qkvpacked(
b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_QKV,
input_O, input_dO, input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens,
input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
window_size_left, window_size_right, deterministic, input_QKV, input_O, input_dO,
input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens, input_cu_seqlens_padded,
input_rng_state, wkspace, stream, handle);
#else
const char *err_msg =
"cuDNN 8.9.0 is required for BF16/FP16 fused attention "
......@@ -375,6 +393,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
using namespace transformer_engine;
......@@ -409,9 +428,9 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout,
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
max_seqlen_kv, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -426,9 +445,10 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
#if (CUDNN_VERSION >= 8903)
fused_attn_arbitrary_seqlen_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_KV,
input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
handle);
#else
NVTE_ERROR(
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -454,7 +474,8 @@ void nvte_fused_attn_bwd_kvpacked(
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream) {
int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
......@@ -491,9 +512,9 @@ void nvte_fused_attn_bwd_kvpacked(
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout,
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
max_seqlen_kv, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -517,10 +538,10 @@ void nvte_fused_attn_bwd_kvpacked(
}
fused_attn_arbitrary_seqlen_bwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, input_Q, input_KV, input_O, input_dO, input_Bias, output_S, output_dQ,
output_dKV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
handle);
attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, input_KV,
input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias,
input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else
const char *err_msg =
"cuDNN 8.9.3 is required for BF16/FP16 fused attention "
......@@ -553,7 +574,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream) {
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
......@@ -579,9 +601,9 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout,
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
max_seqlen_kv, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -596,9 +618,10 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_K, input_V,
input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -625,7 +648,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream) {
NVTE_Mask_Type attn_mask_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
......@@ -655,9 +680,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout,
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
max_seqlen_kv, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -681,10 +706,10 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
}
fused_attn_arbitrary_seqlen_bwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, input_Q, input_K, input_V, input_O, input_dO, input_Bias, output_S,
output_dQ, output_dK, output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
handle);
attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, input_K,
input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, output_dV,
output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else
const char *err_msg =
"cuDNN 8.9.0 is required for BF16/FP16 fused attention "
......
......@@ -50,8 +50,9 @@ namespace fused_attn {
void fused_attn_arbitrary_seqlen_fwd_impl(
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, int64_t bias_b,
int64_t bias_h, bool is_training, float scaling_factor, float dropout_probability,
NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void *devPtrQ,
void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO,
NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK,
void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO,
void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV,
cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
......@@ -63,6 +64,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK));
if (is_bottom_right && s_q == s_kv) {
is_causal = true;
is_bottom_right = false;
}
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (is_training && dropout_probability != 0.0f);
......@@ -70,6 +75,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
if (is_ragged) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
}
if (window_size_left == -1) {
window_size_left = s_q;
}
auto cudnn_runtime_version = cudnnGetVersion();
try {
FADescriptor_v1 descriptor{b,
......@@ -86,6 +95,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
layout,
bias_type,
mask_type,
window_size_left,
window_size_right,
true,
tensorType,
tensorType};
......@@ -208,6 +220,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
.set_causal_mask_bottom_right(is_bottom_right)
.set_attn_scale(attn_scale);
if (cudnn_runtime_version >= 90200 && window_size_left != s_q) {
sdpa_options.set_sliding_window_length(window_size_left);
}
sdpa_options.set_alibi_mask(is_alibi);
if (is_bias) {
......@@ -367,7 +383,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
void fused_attn_arbitrary_seqlen_bwd_impl(
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, int64_t bias_b,
int64_t bias_h, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void *devPtrQ, void *devPtrKTranspose,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, void *devPtrQ, void *devPtrKTranspose,
void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias,
void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias,
void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
......@@ -381,10 +398,20 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK));
if (is_bottom_right && s_q == s_kv) {
is_causal = true;
is_bottom_right = false;
}
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (dropout_probability != 0.0f);
bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD);
if (window_size_left == -1) {
window_size_left = s_q;
}
auto cudnn_runtime_version = cudnnGetVersion();
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
try {
FADescriptor_v1 descriptor{b,
......@@ -401,6 +428,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
layout,
bias_type,
mask_type,
window_size_left,
window_size_right,
deterministic,
tensorType,
tensorType};
......@@ -552,6 +582,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
.set_causal_mask_bottom_right(is_bottom_right)
.set_attn_scale(attn_scale);
if (cudnn_runtime_version >= 90200 && window_size_left != s_q) {
sdpa_backward_options.set_sliding_window_length(window_size_left);
}
if (cudnn_runtime_version >= 90000 && sm_arch_ >= 90) {
sdpa_backward_options.set_deterministic_algorithm(deterministic);
}
sdpa_backward_options.set_alibi_mask(is_alibi);
if (is_bias) {
......@@ -746,7 +784,8 @@ using namespace transformer_engine::fused_attn;
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
......@@ -827,9 +866,9 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
fused_attn_arbitrary_seqlen_fwd_impl(
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, bias_b, bias_h,
is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK,
devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets,
is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left,
window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
......@@ -850,6 +889,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic,
const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
......@@ -902,11 +942,11 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
fused_attn_arbitrary_seqlen_bwd_impl(
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, bias_b, bias_h,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrO,
devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets,
devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream,
handle);
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ,
devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -926,7 +966,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
......@@ -1011,10 +1052,11 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
fused_attn_arbitrary_seqlen_fwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h,
is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK,
devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left,
window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ,
devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size,
stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1035,11 +1077,12 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
Tensor *output_dBias, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
......@@ -1089,11 +1132,11 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
fused_attn_arbitrary_seqlen_bwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrO,
devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ,
devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1110,17 +1153,15 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
}
}
void fused_attn_arbitrary_seqlen_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
void fused_attn_arbitrary_seqlen_fwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
......@@ -1193,10 +1234,11 @@ void fused_attn_arbitrary_seqlen_fwd(size_t batch, size_t num_attn_heads, size_t
fused_attn_arbitrary_seqlen_fwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h,
is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK,
devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left,
window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ,
devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size,
stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1217,9 +1259,10 @@ void fused_attn_arbitrary_seqlen_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK,
Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
......@@ -1260,11 +1303,11 @@ void fused_attn_arbitrary_seqlen_bwd(
fused_attn_arbitrary_seqlen_bwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrO,
devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ,
devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......
......@@ -21,13 +21,15 @@ namespace transformer_engine {
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic,
const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
......@@ -37,7 +39,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
......@@ -46,31 +49,31 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
Tensor *output_dBias, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK,
Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
......
......@@ -1687,6 +1687,9 @@ void fused_attn_fp8_fwd_impl_v1(
layout,
bias_type,
mask_type,
0,
0,
true,
fwd_tensor_type,
fwd_tensor_type};
......@@ -1981,6 +1984,9 @@ void fused_attn_fp8_bwd_impl_v1(
layout,
bias_type,
mask_type,
0,
0,
false,
fwd_tensor_type,
bwd_tensor_type};
......
......@@ -100,16 +100,20 @@ struct FADescriptor_v1 {
NVTE_QKV_Layout layout;
NVTE_Bias_Type bias_type;
NVTE_Mask_Type mask_type;
std::int64_t window_size_left;
std::int64_t window_size_right;
bool deterministic;
cudnn_frontend::DataType_t fwd_tensor_type;
cudnn_frontend::DataType_t bwd_tensor_type;
bool operator<(const FADescriptor_v1 &rhs) const {
return std::tie(b, h, hg, s_q, s_kv, d, bias_b, bias_h, attnScale, isTraining,
dropoutProbability, layout, mask_type, bias_type, fwd_tensor_type,
bwd_tensor_type) <
dropoutProbability, layout, mask_type, window_size_left, window_size_right,
deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) <
std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d, rhs.bias_b, rhs.bias_h,
rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout,
rhs.mask_type, rhs.bias_type, rhs.fwd_tensor_type, rhs.bwd_tensor_type);
rhs.mask_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic,
rhs.bias_type, rhs.fwd_tensor_type, rhs.bwd_tensor_type);
}
};
......
......@@ -11,6 +11,8 @@
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_
#include <cstdint>
#include "transformer_engine.h"
#ifdef __cplusplus
......@@ -135,22 +137,25 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout);
/*! \brief Get fused attention backend based on input parameters.
*
* \param[in] q_dtype The data type of Tensor Q.
* \param[in] kv_dtype The data type of Tensors K, V.
* \param[in] qkv_layout The layout of Tensors Q, K, V.
* \param[in] bias_type The attention bias type.
* \param[in] attn_mask_type The attention mask type.
* \param[in] dropout The dropout probability.
* \param[in] num_attn_heads The number of heads in Q.
* \param[in] num_gqa_groups The number of heads in K, V.
* \param[in] max_seqlen_q The sequence length of Q.
* \param[in] max_seqlen_kv The sequence length of K, V.
* \param[in] head_dim The head dimension of Q, K, V.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
*/
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left,
int64_t window_size_right);
/*! \brief Compute dot product attention with packed QKV input.
*
......@@ -197,6 +202,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
......@@ -206,6 +213,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTETensor rng_state, size_t max_seqlen, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed QKV input.
......@@ -248,6 +256,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
......@@ -258,7 +269,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
const NVTETensor cu_seqlens_padded, size_t max_seqlen,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right,
bool deterministic, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute dot product attention with packed KV input.
*
......@@ -310,6 +322,9 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
......@@ -321,6 +336,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed KV input.
......@@ -369,6 +385,9 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
......@@ -379,7 +398,8 @@ void nvte_fused_attn_bwd_kvpacked(
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute dot product attention with separate Q, K and V.
*
......@@ -435,6 +455,8 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
......@@ -446,7 +468,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with separate Q, K and V.
*
......@@ -499,6 +522,9 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
......@@ -510,7 +536,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, NVTETensor workspace,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
......
......@@ -162,5 +162,6 @@ class DelayedScaling:
f"format={str(self.fp8_format).split('.')[1]}, "
f"amax_history_len={self.amax_history_len}, "
f"wgrad_override={self.override_linear_precision.wgrad}, "
f"reduce_amax={self.reduce_amax}"
f"fp8_dpa={self.fp8_dpa}, "
f"fp8_mha={self.fp8_mha}"
)
......@@ -19,7 +19,7 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen,
head_dim, -1, -1);
return backend;
}
......@@ -154,22 +154,22 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, -1, -1, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, -1, -1, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
ragged_offset_tensor.data(), ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
-1, query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported QKVLayout.");
}
......@@ -258,7 +258,7 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto backend =
nvte_get_fused_attn_backend(static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype),
qkv_layout, bias_type, mask_type, dropout_probability, attn_heads,
num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim, -1, -1);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
/* Auxiliary tensors (to be propagated to the backward pass later) */
......@@ -277,11 +277,12 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto qkv = buffers[0];
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, is_training, descriptor.scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, -1, -1,
workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
......@@ -294,7 +295,7 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, -1, -1, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
......@@ -305,12 +306,13 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto v = buffers[2];
auto v_shape = k_shape;
auto v_tensor = TensorWrapper(v, v_shape, dtype);
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
-1, workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
......@@ -373,13 +375,13 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
nvte_fused_attn_bwd_qkvpacked(
qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, -1, -1, true, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
......@@ -388,8 +390,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen,
kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
-1, true, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
......@@ -399,8 +401,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
-1, true, query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
......@@ -487,7 +489,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto backend =
nvte_get_fused_attn_backend(static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype),
qkv_layout, bias_type, mask_type, dropout_probability, attn_heads,
num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim, -1, -1);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
rng_state, bias);
......@@ -509,13 +511,13 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::accumulate(qkv_shape.cbegin(), qkv_shape.cend(), 1, std::multiplies<size_t>());
cudaMemsetAsync(dqkv, 0, dqkv_size * typeToSize(dtype), stream);
}
nvte_fused_attn_bwd_qkvpacked(
qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, -1, -1, true, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
......@@ -542,7 +544,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, true,
workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
......@@ -577,7 +580,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, true,
workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
......
......@@ -134,7 +134,7 @@ inline NVTE_Fused_Attn_Backend get_fused_attn_backend(
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype),
qkv_layout, bias_type, attn_mask_type, p_dropout, num_attn_heads,
num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, -1, -1);
return fused_attention_backend;
}
......
......@@ -673,7 +673,7 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
&nvte_aux_tensor_pack, te_cu_seqlens.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, -1, -1, workspace.data(), QKV.stream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place());
......@@ -687,7 +687,7 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
&nvte_aux_tensor_pack, te_cu_seqlens.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, -1, -1, workspace.data(), QKV.stream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -758,7 +758,7 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
&nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(),
te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, -1, -1, true, workspace.data(), QKV.stream());
// allocate memory for workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place());
......@@ -769,7 +769,7 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
&nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(),
te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, -1, -1, true, workspace.data(), QKV.stream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -840,12 +840,12 @@ void te_fused_attn_fwd_kvpacked(
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd_kvpacked(
te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1,
workspace.data(), Q.stream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
......@@ -855,12 +855,12 @@ void te_fused_attn_fwd_kvpacked(
output_s->data.dptr = GetOptionalDataPtr(softmax_aux);
// execute the kernel
nvte_fused_attn_fwd_kvpacked(
te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1,
workspace.data(), Q.stream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -935,24 +935,24 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd_kvpacked(
te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
-1, -1, true, workspace.data(), Q.stream());
// allocate memory for workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
// execute kernel
nvte_fused_attn_bwd_kvpacked(
te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
-1, -1, true, workspace.data(), Q.stream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -1042,7 +1042,7 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale,
p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1,
workspace.data(), Q.stream());
// allocate memory for workspace and auxiliary output tensors
......@@ -1058,7 +1058,7 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale,
p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1,
workspace.data(), Q.stream());
// destroy tensor wrappers, but not allocated memory
......@@ -1140,7 +1140,7 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, -1, -1, true, workspace.data(), Q.stream());
// allocate memory for workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
......@@ -1152,7 +1152,7 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, -1, -1, true, workspace.data(), Q.stream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......
......@@ -12,6 +12,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import warnings
import logging
from dataclasses import dataclass, fields
import numpy as np
from packaging.version import Version as PkgVersion
......@@ -103,45 +104,24 @@ logging.basicConfig(
level=log_levels[log_level if log_level in [0, 1, 2] else 2],
)
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
_NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
_attention_backends = {
"attention_params": None,
"use_flash_attention": None,
"use_fused_attention": None,
"fused_attention_backend": None,
"use_unfused_attention": None,
"backend_selection_requires_update": False,
}
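# Editorial sketch (hedged, not the module's exact code): _attention_backends
# caches the most recent selection so repeated forward calls with unchanged
# AttentionParams can skip re-running the filters in get_attention_backend.
def _cached_backends(params):
    if (
        _attention_backends["attention_params"] != params
        or _attention_backends["backend_selection_requires_update"]
    ):
        _attention_backends["attention_params"] = params
        get_attention_backend(params)  # refreshes the cached fields above
    return (
        _attention_backends["use_flash_attention"],
        _attention_backends["use_fused_attention"],
        _attention_backends["use_unfused_attention"],
    )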
__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]
@dataclass(eq=True)
class AttentionParams:
"""
Attention parameters used to determine which attention backend to use.
Parameters
----------
......@@ -166,7 +146,7 @@ def get_attention_backend(
attn_mask_type: str, default = `no_mask`
Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`,
`causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
window_size: Tuple[int, int], default = None
Sliding window attention size.
alibi_slopes_shape: Optional[Union[torch.Size, List]], default = `None`
Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`.
......@@ -185,10 +165,62 @@ def get_attention_backend(
Whether context parallelism is used or not.
deterministic: bool, default = `False`
Whether to run `DotProductAttention` with determinism or not.
is_training: bool, default = `True`
Whether in training mode (`True`) or inference mode (`False`).
fp8: bool, default = `False`
Whether `DotProductAttention` is in an `fp8_autocast` region.
fp8_meta: Optional[Dict[str, Any]], default = `None`
The FP8 metadata tensor of `DotProductAttention`.
"""
qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
qkv_dtype: torch.dtype = torch.bfloat16
qkv_layout: str = "sbh3d"
batch_size: int = 1
num_heads: int = 16
num_gqa_groups: int = 16
max_seqlen_q: int = 128
max_seqlen_kv: int = 128
head_dim: int = 64
attn_mask_type: str = "no_mask"
window_size: Union[Tuple[int, int], None] = None
alibi_slopes_shape: Union[torch.Size, List, None] = None
core_attention_bias_type: str = "no_bias"
core_attention_bias_shape: str = "1hss"
core_attention_bias_requires_grad: bool = True
pad_between_seqs: bool = False
attention_dropout: float = 0.0
context_parallel: bool = False
deterministic: bool = False
is_training: bool = True
fp8: bool = False
fp8_meta: Union[Dict[str, Any], None] = None
_alibi_cache = {
"_num_heads": None,
"_alibi_slopes": None,
"_max_seqlen_q": None,
"_max_seqlen_kv": None,
"_bottom_right_alignment": True,
"_alibi_bias": None,
"_alibi_slopes_require_update": False,
"_alibi_bias_require_update": False,
}
__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]
def get_attention_backend(
attention_params: AttentionParams = None,
):
"""
Select the appropriate attention backend/sub-backend based on user input and runtime environment.
Parameters
----------
See `AttentionParams`.
Returns
----------
......@@ -196,20 +228,66 @@ def get_attention_backend(
Whether the `FlashAttention` backend has been selected.
use_fused_attention: bool
Whether the `FusedAttention` backend has been selected.
use_unfused_attention: bool
Whether the `UnfusedDotProductAttention` backend has been selected.
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
If `use_fused_attention = True`, the `FusedAttention` sub-backend, else `None`.
available_backends: List[bool]
All available backends that could support the provided input. A list of Booleans
in the form of [use_flash_attention, use_fused_attention, use_unfused_attention].
"""
qkv_type = attention_params.qkv_type
qkv_dtype = attention_params.qkv_dtype
qkv_layout = attention_params.qkv_layout
batch_size = attention_params.batch_size
num_heads = attention_params.num_heads
num_gqa_groups = attention_params.num_gqa_groups
max_seqlen_q = attention_params.max_seqlen_q
max_seqlen_kv = attention_params.max_seqlen_kv
head_dim = attention_params.head_dim
attn_mask_type = attention_params.attn_mask_type
window_size = attention_params.window_size
alibi_slopes_shape = attention_params.alibi_slopes_shape
core_attention_bias_type = attention_params.core_attention_bias_type
core_attention_bias_shape = attention_params.core_attention_bias_shape
core_attention_bias_requires_grad = attention_params.core_attention_bias_requires_grad
pad_between_seqs = attention_params.pad_between_seqs
attention_dropout = attention_params.attention_dropout
context_parallel = attention_params.context_parallel
deterministic = attention_params.deterministic
is_training = attention_params.is_training
fp8 = attention_params.fp8
fp8_meta = attention_params.fp8_meta
# Run config
logger = logging.getLogger("DotProductAttention")
device_compute_capability = get_device_compute_capability()
cudnn_version = get_cudnn_version()
run_config = {
"transformer_engine_version": te.__version__,
"compute_capability": "sm"
+ str(
(lambda x, y: x * 10 + y)(device_compute_capability[0], device_compute_capability[1])
),
"flash_attn_version": _flash_attn_version,
"cudnn_version": ".".join([str(i) for i in cudnn_version]),
}
attention_params_dict = {
field.name: getattr(attention_params, field.name) for field in fields(attention_params)
}
run_config.update(attention_params_dict)
if fp8:
run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
logger.debug("Running with config=%s", run_config)
# Filter: Environment variables
global _NVTE_FLASH_ATTN, _NVTE_FUSED_ATTN, _NVTE_UNFUSED_ATTN
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
_NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
use_flash_attention = _NVTE_FLASH_ATTN
use_fused_attention = _NVTE_FUSED_ATTN
use_unfused_attention = _NVTE_UNFUSED_ATTN
if not use_flash_attention:
logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
if not use_fused_attention:
......@@ -227,7 +305,6 @@ def get_attention_backend(
use_fused_attention = False
# Filter: Compute capability
if device_compute_capability < (8, 0):
if use_flash_attention:
logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
......@@ -325,12 +402,6 @@ def get_attention_backend(
if use_unfused_attention and "padding" in attn_mask_type:
logger.debug("Disabling UnfusedDotProductAttention for %s mask", attn_mask_type)
use_unfused_attention = False
if use_unfused_attention and attn_mask_type == "causal" and max_seqlen_q != max_seqlen_kv:
logger.debug(
"Disabling UnfusedDotProductAttention for "
"top-left-diagonal causal masks for cross-attention"
)
use_unfused_attention = False
if (
use_flash_attention
and _flash_attn_2_1_plus
......@@ -357,17 +428,58 @@ def get_attention_backend(
use_flash_attention = False
# Filter: Sliding window attention
# backend | window_size | diagonal alignment
# ---------------------------------------------------------------------------------
# FlashAttention | (-1, -1) or (>=0, >=0) | bottom right
# FusedAttention | (-1, 0) or (>=0, 0) | top left
# UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both;
# | | converts window_size to an 'arbitrary' mask
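# Editorial examples of the table above (illustrative only):
#   window_size = (-1, -1): no sliding window, full attention
#   window_size = (-1, 0):  causal mask (left side unbounded)
#   window_size = (127, 0): causal sliding window over the last 128 keys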
if window_size is None:
window_size = check_set_window_size(attn_mask_type, window_size)
else:
if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha):
logger.debug(
"Disabling FusedAttention as it does not support sliding window attention"
" for FP8"
)
use_fused_attention = False
elif window_size[1] != 0 or attention_dropout != 0.0 or qkv_format == "thd":
logger.debug(
"Disabling FusedAttention as it only supports sliding window attention "
"with causal mask, no dropout, and qkv_format = bshd/sbhd"
)
use_fused_attention = False
elif context_parallel:
logger.debug(
"Disabling FusedAttention as it does not support sliding window attention "
"with context parallelism"
)
use_fused_attention = False
elif max_seqlen_q != max_seqlen_kv and attn_mask_type in [
"no_mask",
"padding",
"causal_bottom_right",
"padding_causal_bottom_right",
]:
logger.debug(
"Disabling FusedAttention as it does not support sliding window attention "
"with attn_mask_type = %s for cross-attention",
attn_mask_type,
)
use_fused_attention = False
elif "padding" in attn_mask_type:
logger.debug(
"Disabling FusedAttention as it does not support sliding window attention "
"with attn_mask_type = %s",
attn_mask_type,
)
use_fused_attention = False
if (
use_flash_attention
and (window_size[0] != -1 or window_size[1] not in [-1, 0])
and (not _flash_attn_2_3_plus or context_parallel)
):
logger.debug(
"Disabling FlashAttention as sliding window attention requires "
"flash-attn 2.3+ and no context parallelism"
......@@ -375,6 +487,14 @@ def get_attention_backend(
use_flash_attention = False
# Filter: Attention bias
# backend | bias types | ALiBi diagonal alignment
# ---------------------------------------------------------------------------------
# FlashAttention | no_bias, alibi/alibi_slopes | bottom right
# FusedAttention | no_bias, post_scale_bias |
# | alibi/alibi_slopes | top left,
# | | bottom_right (converts to a 'post_scale_bias' bias)
# UnfusedDotProductAttention | no_bias, pre/post_scale_bias |
# | alibi/alibi_slopes | both; converts to a 'post_scale_bias' bias
if use_flash_attention and (
core_attention_bias_type not in ["no_bias", "alibi"]
or core_attention_bias_shape is not None
......@@ -388,18 +508,20 @@ def get_attention_backend(
if (
use_fused_attention
and core_attention_bias_type == "alibi"
and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv)
):
fu_core_attention_bias_type = "post_scale_bias"
fu_core_attention_bias_requires_grad = False
if alibi_slopes_shape is None:
fu_core_attention_bias_shape = "1hss"
elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
fu_core_attention_bias_shape = "1hss"
elif (
len(alibi_slopes_shape) == 2
and alibi_slopes_shape[0] == batch_size
and alibi_slopes_shape[1] == num_heads
):
fu_core_attention_bias_shape = "bhss"
if (
use_fused_attention
......@@ -435,14 +557,41 @@ def get_attention_backend(
max_seqlen_q,
max_seqlen_kv,
head_dim,
window_size[0],
window_size[1],
)
if fused_attention_backend == FusedAttnBackend["No_Backend"] or (
context_parallel and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"]
):
if fused_attention_backend == FusedAttnBackend["No_Backend"]:
logger.debug("Disabling FusedAttention as no backend supports the provided input")
use_fused_attention = False
fused_attention_backend = None
if (
use_fused_attention
and context_parallel
and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"]
):
logger.debug(
"Disabling FusedAttention as only sub-backend %s does not support "
"context parallellism",
int(fused_attention_backend),
)
use_fused_attention = False
fused_attention_backend = None
if (
use_fused_attention
and window_size is not None
and window_size[0] != -1
and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"]
):
logger.debug(
"Disabling FusedAttention as only sub-backend %s does not support "
"slidng window attention",
int(fused_attention_backend),
)
use_fused_attention = False
fused_attention_backend = None
if (
use_fused_attention
and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
and fu_core_attention_bias_type == "post_scale_bias"
and fu_core_attention_bias_shape != "1hss"
):
......@@ -451,6 +600,7 @@ def get_attention_backend(
" [1, H, S, S] shape"
)
use_fused_attention = False
fused_attention_backend = None
# Filter: Determinism
# backend | deterministic
......@@ -471,22 +621,36 @@ def get_attention_backend(
"please install flash-attn >= 2.4.1."
)
use_flash_attention = False
if use_fused_attention and deterministic:
if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
logger.debug("Disabling FusedAttention for determinism reasons")
use_fused_attention = False
if (
fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
and is_training
and (
device_compute_capability < (9, 0)
or core_attention_bias_requires_grad
or cudnn_version < (8, 9, 5)
)
):
logger.debug("Disabling FusedAttention for determinism reasons")
use_fused_attention = False
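# Editorial note: deterministic is derived later in this diff in
# DotProductAttention.__init__, from NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 or
# torch.are_deterministic_algorithms_enabled(); either route reaches this filter.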
# All available backends
available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention]
logger.debug(
"Available backends = {FlashAttention=%s, FusedAttention=%s%s,"
" UnfusedDotProductAttention=%s}",
bool(available_backends[0]),
bool(available_backends[1]),
(
f" (sub-backend {int(fused_attention_backend)})"
if fused_attention_backend is not None
else ""
),
bool(available_backends[2]),
)
# Select FusedAttention for performance
if (
......@@ -507,29 +671,28 @@ def get_attention_backend(
use_unfused_attention = False
elif use_fused_attention:
use_unfused_attention = False
selected_backend = "NoBackend"
if use_flash_attention:
selected_backend = "FlashAttention"
elif use_fused_attention:
selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})"
elif use_unfused_attention:
selected_backend = "UnfusedDotProductAttention"
logger.debug("Selected backend = %s", selected_backend)
global _attention_backends
_attention_backends["use_flash_attention"] = use_flash_attention
_attention_backends["use_fused_attention"] = use_fused_attention
_attention_backends["fused_attention_backend"] = fused_attention_backend
_attention_backends["use_unfused_attention"] = use_unfused_attention
_attention_backends["backend_selection_requires_update"] = False
return (
use_flash_attention,
use_fused_attention,
use_unfused_attention,
available_backends,
fused_attention_backend,
)
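# Editorial usage sketch (hedged; values illustrative, tuple order as returned
# above):
params = AttentionParams(
    qkv_dtype=torch.bfloat16,
    num_heads=16,
    num_gqa_groups=16,
    max_seqlen_q=512,
    max_seqlen_kv=512,
    attn_mask_type="causal",
    window_size=(127, 0),  # causal sliding window over the last 128 keys
    deterministic=True,
)
(
    use_flash_attention,
    use_fused_attention,
    use_unfused_attention,
    available_backends,
    fused_attention_backend,
) = get_attention_backend(params)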
......@@ -579,6 +742,64 @@ class InferenceParams: # pylint: disable=too-few-public-methods
)
@torch.no_grad()
def get_swa_mask(
window_size: Tuple[int, int],
max_seqlen_q: int,
max_seqlen_kv: int,
attn_mask_type: str = "no_mask",
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
) -> Tuple[str, torch.Tensor]:
"""
Convert sliding window `window_size` to an equivalent "`arbitrary`" mask.
For "`causal`" mask type, the sliding window diagonal is aligned to the top left corner,
and for other mask types, the bottom right corner.
Parameters
----------
window_size: Tuple[int, int]
Sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
`attn_mask_type`.
max_seqlen_q: int
Maximum sequence length for queries.
max_seqlen_kv: int
Maximum sequence length for keys and values.
attn_mask_type: str, default = `no_mask`
Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`",
"`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"}
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
default = `None`
Boolean tensor(s) used to mask out attention softmax input.
Returns
----------
attn_mask_type: str
The attention mask type to use downstream, set to "arbitrary" when the sliding
window is converted to an explicit mask.
attention_mask: torch.Tensor
Combined `attention_mask` (input) and sliding window attention mask.
The shape is [max_seqlen_q, max_seqlen_kv] when input `attention_mask` is None;
else, the same shape as input `attention_mask`.
"""
mask = torch.ones(max_seqlen_q, max_seqlen_kv, dtype=torch.bool, device="cuda")
if attn_mask_type in ["causal"]:
left = window_size[0] if window_size[0] != -1 else max_seqlen_q
right = window_size[1] if window_size[1] != -1 else max_seqlen_q
mask_upper = torch.triu(mask, diagonal=-left)
mask_lower = torch.tril(mask_upper, diagonal=right)
else:
left = window_size[0] if window_size[0] != -1 else max_seqlen_kv
right = window_size[1] if window_size[1] != -1 else max_seqlen_kv
mask_upper = torch.triu(mask, diagonal=max_seqlen_kv - max_seqlen_q - left)
mask_lower = torch.tril(mask_upper, diagonal=max_seqlen_kv - max_seqlen_q + right)
attn_mask_type = "arbitrary"
mask = mask_lower.logical_not()
if attention_mask is not None:
mask = torch.logical_and(attention_mask, mask)
return attn_mask_type, mask
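# Editorial usage sketch for get_swa_mask (hedged; requires a CUDA device).
# Each query sees itself and the two previous keys:
new_mask_type, swa_mask = get_swa_mask(
    window_size=(2, 0), max_seqlen_q=4, max_seqlen_kv=4, attn_mask_type="no_mask"
)
# new_mask_type == "arbitrary"; swa_mask is True where attention is masked out.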
@torch.no_grad()
def get_alibi(
num_heads: int,
......@@ -586,6 +807,7 @@ def get_alibi(
max_seqlen_kv: int,
alibi_slopes: Optional[torch.Tensor] = None,
bias_dtype: Optional[torch.dtype] = None,
bottom_right_alignment: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Parameters
......@@ -600,6 +822,9 @@ def get_alibi(
Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads].
bias_dtype: Optional[torch.dtype], default = `None`
Dtype of the generated ALiBi bias. If None, use torch.float32.
bottom_right_alignment: bool, default = `True`
Whether to align the diagonal of the ALiBi bias to the bottom right corner of
the matrix (`True`) or top left (`False`).
Returns
----------
......@@ -635,15 +860,21 @@ def get_alibi(
slopes_shape = torch.Size([1, _alibi_cache["_alibi_slopes"].shape[0], 1, 1])
if _alibi_cache["_alibi_slopes"].dim() == 2:
slopes_shape = torch.Size([*_alibi_cache["_alibi_slopes"].shape[:], 1, 1])
bias = torch.arange(1 - max_seqlen_kv, 1, dtype=torch.int32, device="cuda").view(
1, 1, 1, max_seqlen_kv
)
if bottom_right_alignment:
bias = torch.arange(1 - max_seqlen_kv, 1, dtype=torch.int32, device="cuda").view(
1, 1, 1, max_seqlen_kv
)
else:
bias = torch.arange(
1 - max_seqlen_q, max_seqlen_kv - max_seqlen_q + 1, dtype=torch.int32, device="cuda"
).view(1, 1, 1, max_seqlen_kv)
bias = bias - torch.arange(1 - max_seqlen_q, 1, dtype=torch.int32, device="cuda").view(
1, 1, max_seqlen_q, 1
)
bias = bias.abs().mul(-1)
bias = bias * _alibi_cache["_alibi_slopes"].view(slopes_shape)
_alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv
_alibi_cache["_bottom_right_alignment"] = bottom_right_alignment
bias_dtype = torch.float32 if bias_dtype is None else bias_dtype
_alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda")
_alibi_cache["_alibi_bias_require_update"] = False
......@@ -2570,7 +2801,11 @@ class UnfusedDotProductAttention(torch.nn.Module):
assert core_attention_bias is not None, "core_attention_bias should not be None!"
if core_attention_bias_type == "alibi":
_, core_attention_bias = get_alibi(
output_size[1],
output_size[2],
output_size[3],
alibi_slopes=alibi_slopes,
bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
)
matmul_result = torch.baddbmm(
matmul_result,
......@@ -2892,8 +3127,6 @@ class FlashAttention(torch.nn.Module):
) -> torch.Tensor:
"""flash-attn fprop"""
assert (
query_layer.dtype in [torch.float16, torch.bfloat16]
and key_layer.dtype in [torch.float16, torch.bfloat16]
......@@ -3120,11 +3353,13 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
qkv_layout,
attn_bias_type,
attn_mask_type,
window_size,
rng_gen,
fused_attention_backend,
use_FAv2_bwd,
fp8,
fp8_meta,
deterministic,
):
logger = logging.getLogger("FusedAttnFunc_qkvpacked")
if fp8:
......@@ -3168,6 +3403,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
qkv_layout,
attn_bias_type,
attn_mask_type,
window_size,
rng_gen,
)
if fp8_meta["recipe"].fp8_mha:
......@@ -3233,6 +3469,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
qkv_layout,
attn_bias_type,
attn_mask_type,
window_size,
rng_gen,
)
fp8_tensors = (None, None, None, None)
......@@ -3252,10 +3489,12 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type
ctx.window_size = window_size
ctx.fused_attention_backend = (
fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
)
ctx.use_FAv2_bwd = use_FAv2_bwd
ctx.deterministic = deterministic
return out_ret
......@@ -3357,6 +3596,8 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
ctx.qkv_layout,
ctx.attn_bias_type,
ctx.attn_mask_type,
ctx.window_size,
ctx.deterministic,
)
if ctx.fp8_meta["recipe"].fp8_mha:
dqkv = Float8Tensor(
......@@ -3409,6 +3650,8 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
ctx.qkv_layout,
ctx.attn_bias_type,
ctx.attn_mask_type,
ctx.window_size,
ctx.deterministic,
)
# if no_bias or alibi, return dqkv
......@@ -3434,6 +3677,8 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
None,
None,
None,
None,
None,
)
# else, return (dqkv, dbias)
return (
......@@ -3457,6 +3702,8 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
None,
None,
None,
None,
None,
)
......@@ -3483,11 +3730,13 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
qkv_layout,
attn_bias_type,
attn_mask_type,
window_size,
rng_gen,
fused_attention_backend,
use_FAv2_bwd,
fp8,
fp8_meta,
deterministic,
):
logger = logging.getLogger("FusedAttnFunc_kvpacked")
if fp8:
......@@ -3540,6 +3789,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
qkv_layout,
attn_bias_type,
attn_mask_type,
window_size,
rng_gen,
)
if fp8_meta["recipe"].fp8_mha:
......@@ -3613,6 +3863,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
qkv_layout,
attn_bias_type,
attn_mask_type,
window_size,
rng_gen,
)
out_save = out_ret
......@@ -3639,10 +3890,12 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type
ctx.window_size = window_size
ctx.fused_attention_backend = (
fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
)
ctx.use_FAv2_bwd = use_FAv2_bwd
ctx.deterministic = deterministic
return out_ret
......@@ -3752,6 +4005,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
ctx.qkv_layout,
ctx.attn_bias_type,
ctx.attn_mask_type,
ctx.window_size,
ctx.deterministic,
)
if ctx.fp8_meta["recipe"].fp8_mha:
dq = Float8Tensor(
......@@ -3823,6 +4078,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
ctx.qkv_layout,
ctx.attn_bias_type,
ctx.attn_mask_type,
ctx.window_size,
ctx.deterministic,
)
# if no_bias or alibi, return dqkv
......@@ -3852,6 +4109,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
None,
None,
None,
None,
None,
)
# else, return (dqkv, dbias)
return (
......@@ -3879,6 +4138,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
None,
None,
None,
None,
None,
)
......@@ -3906,11 +4167,13 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_layout,
attn_bias_type,
attn_mask_type,
window_size,
rng_gen,
fused_attention_backend,
use_FAv2_bwd,
fp8,
fp8_meta,
deterministic,
):
logger = logging.getLogger("FusedAttnFunc")
if fp8:
......@@ -3985,6 +4248,7 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_layout,
attn_bias_type,
attn_mask_type,
window_size,
rng_gen,
)
if fp8_meta["recipe"].fp8_mha:
......@@ -4108,6 +4372,7 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_layout,
attn_bias_type,
attn_mask_type,
window_size,
rng_gen,
)
out_save = out_ret
......@@ -4143,10 +4408,12 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type
ctx.window_size = window_size
ctx.fused_attention_backend = (
fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
)
ctx.use_FAv2_bwd = use_FAv2_bwd
ctx.deterministic = deterministic
return out_ret
......@@ -4261,6 +4528,8 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_layout,
ctx.attn_bias_type,
ctx.attn_mask_type,
ctx.window_size,
ctx.deterministic,
)
if ctx.fp8_meta["recipe"].fp8_mha:
......@@ -4385,6 +4654,8 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_layout,
ctx.attn_bias_type,
ctx.attn_mask_type,
ctx.window_size,
ctx.deterministic,
)
# if no_bias or alibi, return dqkv
......@@ -4415,6 +4686,8 @@ class FusedAttnFunc(torch.autograd.Function):
None,
None,
None,
None,
None,
)
# else, return (dqkv, dbias)
return (
......@@ -4443,6 +4716,8 @@ class FusedAttnFunc(torch.autograd.Function):
None,
None,
None,
None,
None,
)
......@@ -4494,21 +4769,7 @@ class FusedAttention(torch.nn.Module):
"NVTE_FUSED_ATTN_USE_FAv2_BWD", "0"
) == "1" and get_device_compute_capability() == (9, 0)
self.layer_number = 1 if layer_number is None else layer_number
self.deterministic = deterministic
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
"""
......@@ -4546,6 +4807,7 @@ class FusedAttention(torch.nn.Module):
max_seqlen_kv: Optional[int] = None,
attn_mask_type: str = "causal",
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
window_size: Optional[Tuple[int, int]] = None,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
......@@ -4698,11 +4960,13 @@ class FusedAttention(torch.nn.Module):
qkv_layout,
core_attention_bias_type,
attn_mask_type,
window_size,
None, # rng_gen
fused_attention_backend,
use_FAv2_bwd,
fp8,
fp8_meta,
self.deterministic,
)
# ...hd -> ...(hd)
......@@ -4772,7 +5036,9 @@ class DotProductAttention(TransformerEngineBaseModule):
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and "`causal`" mask specifically. Similar to :attr:`attn_mask_type`, it can
window and causal mask, respectively. Both `causal` and `causal_bottom_right` masks
map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
`attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
be overridden by :attr:`window_size` in `forward` as well; a short sketch of these semantics follows below.
attention_type: str, default = `self`
type of attention, either "`self`" or "`cross`".
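To make the window arithmetic in the `window_size` docstring above concrete, here is a minimal, self-contained sketch; the helper name `swa_key_range` is hypothetical and not part of this diff.

```python
# Hypothetical helper illustrating the documented window_size semantics.
def swa_key_range(i, seqlen_q, seqlen_kv, window_size):
    """Inclusive [lo, hi] range of key positions visible to query position i."""
    left, right = window_size
    lo = 0 if left == -1 else max(0, i + seqlen_kv - seqlen_q - left)
    hi = seqlen_kv - 1 if right == -1 else min(seqlen_kv - 1, i + seqlen_kv - seqlen_q + right)
    return lo, hi

# window_size = (-1, 0) reproduces the (bottom-right aligned) causal mask:
assert swa_key_range(0, 4, 6, (-1, 0)) == (0, 2)
assert swa_key_range(3, 4, 6, (-1, 0)) == (0, 5)
```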
......@@ -4875,35 +5141,29 @@ class DotProductAttention(TransformerEngineBaseModule):
if softmax_scale is None:
softmax_scale = 1.0 / math.sqrt(kv_channels)
self.device_compute_capability = get_device_compute_capability()
self.deterministic = (
not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
or torch.are_deterministic_algorithms_enabled()
)
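The flag above can be driven in two ways; a brief usage sketch (an assumption about typical configuration, not part of this diff):

```python
import os
import torch

# Either opt out of nondeterministic algorithms via the environment
# (must be set before DotProductAttention is constructed) ...
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
# ... or request determinism globally through PyTorch:
torch.use_deterministic_algorithms(True)
```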
self.use_flash_attention = int(
os.getenv("NVTE_FLASH_ATTN", "1")
) and self.device_compute_capability >= (8, 0)
if int(os.getenv("NVTE_FLASH_ATTN", "1")) == 0:
self.logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
if self.device_compute_capability < (8, 0):
self.logger.debug("Disabling FlashAttention for compute capability < sm80")
if not _flash_attn_2_4_1_plus and self.deterministic:
self.use_flash_attention = False
self.logger.warning(
"Disabling usage of FlashAttention since version <2.4.1 does not support "
"deterministic execution. In order to use FA with deterministic behavior,"
" please install FlashAttention version >=2.4.1."
)
self.use_fused_attention = int(
os.getenv("NVTE_FUSED_ATTN", "1")
) and self.device_compute_capability >= (8, 0)
if int(os.getenv("NVTE_FUSED_ATTN", "1")) == 0:
self.logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0")
if self.device_compute_capability < (8, 0):
self.logger.debug("Disabling FusedAttention for compute capability < sm80")
# To use the workspace optimization path for determinism, please
# set NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT=1 for cuDNN >=8.9.5 and <9.0.0,
# and set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 for cuDNN >=9.0.0.
cudnn_version = get_cudnn_version()
if (8, 9, 5) <= cudnn_version < (9, 0, 0):
if self.deterministic:
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1"
# CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
# - unset: enables workspace optimization when required workspace is <= 256MB
# or when bias gradient needs to be computed
# - n: enables workspace optimization when required workspace is <= n bytes
# - -1: enables workspace optimization always
# - 0: disables workspace optimization always
if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ:
if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0":
os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0"
if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"
assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
......@@ -4915,25 +5175,23 @@ class DotProductAttention(TransformerEngineBaseModule):
"attention_dropout_ctx": attention_dropout_ctx,
}
if self.use_flash_attention:
self.flash_attention = FlashAttention(
softmax_scale,
attention_type=attention_type,
layer_number=layer_number,
deterministic=self.deterministic,
**attn_kwargs,
)
self.flash_attention = FlashAttention(
softmax_scale,
attention_type=attention_type,
layer_number=layer_number,
deterministic=self.deterministic,
**attn_kwargs,
)
# Instantiating three types since use of flash-attn and FusedAttention
# might be ruled out due to forward inputs.
if self.use_fused_attention:
self.fused_attention = FusedAttention(
softmax_scale,
attention_type=attention_type,
layer_number=layer_number,
deterministic=self.deterministic,
**attn_kwargs,
)
self.fused_attention = FusedAttention(
softmax_scale,
attention_type=attention_type,
layer_number=layer_number,
deterministic=self.deterministic,
**attn_kwargs,
)
self.unfused_attention = UnfusedDotProductAttention(
softmax_scale, **attn_kwargs, layer_number=layer_number
......@@ -5162,11 +5420,13 @@ class DotProductAttention(TransformerEngineBaseModule):
) as query_layer:
if self.fp8:
forced_fp8_dpa = ""
if self.fp8_meta["recipe"].fp8_mha:
if not self.fp8_meta["recipe"].fp8_dpa:
self.fp8_meta["recipe"].fp8_dpa = True
forced_fp8_dpa = " (forced)"
self.logger.warning(
"""Forcing fp8_meta["recipe"].fp8_dpa=True due to """
"""fp8_meta["recipe"].fp8_mha=True"""
)
if self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
forward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True)
......@@ -5308,6 +5568,28 @@ class DotProductAttention(TransformerEngineBaseModule):
seqlens_kv <= max_seqlen_kv
), """Sequence lengths indicated by cu_seqlens_kv must be no greater than
the sequence dimension in 'key_layer' and 'value_layer'!"""
if cu_seqlens_q is None or cu_seqlens_kv is None:
if "padding" in attn_mask_type:
assert (
attention_mask is not None
), "Please provide attention_mask for padding!"
if max_seqlen_q == max_seqlen_kv:
cu_seqlens_q = get_cu_seqlens(attention_mask)
cu_seqlens_kv = cu_seqlens_q
else:
cu_seqlens_q = get_cu_seqlens(attention_mask[0])
cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
else:
cu_seqlens_q = _get_full_cu_seqlens(
batch_size,
max_seqlen_q,
query_layer.device,
)
cu_seqlens_kv = _get_full_cu_seqlens(
batch_size,
max_seqlen_kv,
key_layer.device,
)
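`get_cu_seqlens` and `_get_full_cu_seqlens` are defined elsewhere in this file; a minimal sketch of the padding-mask case, assuming `attention_mask` is a `[batch, 1, 1, seqlen]` boolean tensor in which `True` marks padded positions (the exact convention is an assumption):

```python
import torch

def cu_seqlens_from_padding_mask(attention_mask):
    """Cumulative sequence lengths, shape [batch + 1], from a padding mask."""
    seqlens = (~attention_mask).sum(dim=-1).reshape(-1).to(torch.int32)
    zero = torch.zeros(1, dtype=torch.int32, device=seqlens.device)
    return torch.cat([zero, torch.cumsum(seqlens, dim=0).to(torch.int32)])
```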
if (
isinstance(query_layer, Float8Tensor)
......@@ -5330,6 +5612,7 @@ class DotProductAttention(TransformerEngineBaseModule):
if self.layer_number == 1:
_alibi_cache["_alibi_slopes_require_update"] = True
_alibi_cache["_alibi_bias_require_update"] = True
bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"]
if core_attention_bias_type == "alibi":
assert (
core_attention_bias is None
......@@ -5338,16 +5621,12 @@ class DotProductAttention(TransformerEngineBaseModule):
_alibi_cache["_num_heads"] != query_layer.shape[-2]
or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment
or _alibi_cache["_alibi_slopes"] is None
):
_alibi_cache["_alibi_slopes_require_update"] = True
_alibi_cache["_alibi_bias_require_update"] = True
deterministic = (
not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
or torch.are_deterministic_algorithms_enabled()
)
context_parallel = (
self.cp_group is not None and get_distributed_world_size(self.cp_group) != 1
)
......@@ -5383,13 +5662,7 @@ class DotProductAttention(TransformerEngineBaseModule):
and not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv)
)
(
use_flash_attention,
use_fused_attention,
use_unfused_attention,
_,
fused_attention_backend,
) = get_attention_backend(
attention_params = AttentionParams(
qkv_type=type(query_layer),
qkv_dtype=query_layer.dtype,
qkv_layout=qkv_layout,
......@@ -5410,42 +5683,42 @@ class DotProductAttention(TransformerEngineBaseModule):
pad_between_seqs=pad_between_seqs,
attention_dropout=self.attention_dropout,
context_parallel=context_parallel,
deterministic=deterministic,
deterministic=self.deterministic,
is_training=self.training,
fp8=self.fp8,
fp8_meta=self.fp8_meta,
)
run_config = {
"compute_capability": "sm"
+ str(
(lambda x, y: x * 10 + y)(
self.device_compute_capability[0], self.device_compute_capability[1]
global _attention_backends
if (
_attention_backends["attention_params"] is None
or attention_params != _attention_backends["attention_params"]
):
_attention_backends["attention_params"] = attention_params
_attention_backends["backend_selection_requires_update"] = True
if _attention_backends["backend_selection_requires_update"]:
(
use_flash_attention,
use_fused_attention,
fused_attention_backend,
use_unfused_attention,
_,
) = get_attention_backend(attention_params)
if use_flash_attention:
self.logger.info("Running with FlashAttention backend")
elif use_fused_attention:
self.logger.info(
"Running with FusedAttention backend (sub-backend %s)",
int(fused_attention_backend),
)
),
"q_dtype": query_layer.dtype,
"k_dtype": key_layer.dtype,
"v_dtype": value_layer.dtype,
"q_shape": list(query_layer.shape),
"k_shape": list(key_layer.shape),
"v_shape": list(value_layer.shape),
"qkv_format": qkv_format,
"qkv_layout": qkv_layout,
"mask_type": attn_mask_type,
"bias_type": core_attention_bias_type,
"bias_shape": (
core_attention_bias.shape if core_attention_bias is not None else None
),
"dropout": self.attention_dropout,
"context_parallel": context_parallel,
"is_training": self.training,
"transformer_engine_version": te.__version__,
"flash_attn_version": _flash_attn_version,
"cudnn_version": ".".join([str(i) for i in get_cudnn_version()]),
}
elif use_unfused_attention:
self.logger.info("Running with UnfusedDotProductAttention backend")
else:
use_flash_attention = _attention_backends["use_flash_attention"]
use_fused_attention = _attention_backends["use_fused_attention"]
fused_attention_backend = _attention_backends["fused_attention_backend"]
use_unfused_attention = _attention_backends["use_unfused_attention"]
if use_flash_attention:
self.logger.info("Running with FlashAttention backend ")
self.logger.debug("Running with config=%s", run_config)
if core_attention_bias_type == "alibi":
alibi_slopes, _ = get_alibi(
query_layer.shape[-2],
......@@ -5472,23 +5745,11 @@ class DotProductAttention(TransformerEngineBaseModule):
)
if use_fused_attention:
self.logger.info(
"Running with FusedAttention backend (sub-backend %s)",
int(fused_attention_backend),
)
if self.fp8:
self.logger.debug(
"Running with fp8_recipe.fp8_mha=%s, "
"fp8_recipe.fp8_dpa=%s%s, and NVTE_FP8_DPA_BWD=%s",
self.fp8_meta["recipe"].fp8_mha,
self.fp8_meta["recipe"].fp8_dpa,
forced_fp8_dpa,
int(os.getenv("NVTE_FP8_DPA_BWD", "1")),
)
self.logger.debug("Running with config=%s", run_config)
fu_core_attention_bias_type = core_attention_bias_type
fu_core_attention_bias = core_attention_bias
if core_attention_bias_type == "alibi" and alibi_slopes is not None:
if core_attention_bias_type == "alibi" and (
alibi_slopes is not None or max_seqlen_q != max_seqlen_kv
):
fu_core_attention_bias_type = "post_scale_bias"
_, fu_core_attention_bias = get_alibi(
query_layer.shape[-2],
......@@ -5496,6 +5757,7 @@ class DotProductAttention(TransformerEngineBaseModule):
max_seqlen_kv,
alibi_slopes=alibi_slopes,
bias_dtype=query_layer.dtype,
bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
)
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
......@@ -5512,6 +5774,7 @@ class DotProductAttention(TransformerEngineBaseModule):
max_seqlen_kv=max_seqlen_kv,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
window_size=window_size,
fused_attention_backend=fused_attention_backend,
core_attention_bias_type=fu_core_attention_bias_type,
core_attention_bias=fu_core_attention_bias,
......@@ -5535,6 +5798,7 @@ class DotProductAttention(TransformerEngineBaseModule):
max_seqlen_kv=max_seqlen_kv,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
window_size=window_size,
fused_attention_backend=fused_attention_backend,
core_attention_bias_type=fu_core_attention_bias_type,
core_attention_bias=fu_core_attention_bias,
......@@ -5555,8 +5819,12 @@ class DotProductAttention(TransformerEngineBaseModule):
)
if use_unfused_attention:
self.logger.info("Running with UnfusedDotProductAttention backend")
self.logger.debug("Running with config=%s", run_config)
if window_size is not None and (
window_size[0] != -1 or window_size[1] not in [-1, 0]
):
attn_mask_type, attention_mask = get_swa_mask(
window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type, attention_mask
)
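For the unfused path, `get_swa_mask` materializes the window as an explicit boolean mask (and merges it with any existing `attention_mask`). A minimal sketch of the core computation, under the assumption that `True` marks masked-out positions:

```python
import torch

def make_swa_mask(window_size, seqlen_q, seqlen_kv, device="cuda"):
    """Boolean [seqlen_q, seqlen_kv] mask; True = masked out; -1 = unbounded."""
    left, right = window_size
    q = torch.arange(seqlen_q, device=device).view(-1, 1)
    k = torch.arange(seqlen_kv, device=device).view(1, -1)
    shift = k - q - (seqlen_kv - seqlen_q)  # key offset from the diagonal
    masked = torch.zeros(seqlen_q, seqlen_kv, dtype=torch.bool, device=device)
    if left != -1:
        masked |= shift < -left
    if right != -1:
        masked |= shift > right
    return masked
```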
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
self.unfused_attention,
......@@ -5636,7 +5904,9 @@ class MultiheadAttention(torch.nn.Module):
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and "`causal`" mask specifically. Similar to :attr:`attn_mask_type`, it can
window and causal mask, respectively. Both `causal` and `causal_bottom_right` masks
map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
`attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
be overridden by :attr:`window_size` in `forward` as well.
num_gqa_groups : int, default = `None`
number of GQA groups in the transformer layer.
......
......@@ -100,6 +100,7 @@ def fused_attn_fwd_qkvpacked(
qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
window_size: Tuple[int, int] = (-1, -1),
rng_gen: torch.Generator = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention FWD for packed QKV input.
......@@ -152,6 +153,11 @@ def fused_attn_fwd_qkvpacked(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask, respectively.
rng_gen: torch.Generator, default = None
random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
......@@ -236,6 +242,7 @@ def fused_attn_fwd_qkvpacked(
QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type],
window_size,
cu_seqlens,
qkv,
qkv_dtype,
......@@ -282,6 +289,8 @@ def fused_attn_bwd_qkvpacked(
qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
window_size: Tuple[int, int] = (-1, -1),
deterministic: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention BWD for packed QKV input.
......@@ -346,6 +355,13 @@ def fused_attn_bwd_qkvpacked(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask, respectively.
deterministic: bool, default = False
whether to execute the backward pass with deterministic behavior.
Returns
----------
......@@ -393,6 +409,8 @@ def fused_attn_bwd_qkvpacked(
QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type],
window_size,
deterministic,
cu_seqlens,
qkv,
o,
......@@ -441,6 +459,7 @@ def fused_attn_fwd_kvpacked(
qkv_layout: str = "sbhd_sbh2d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
window_size: Tuple[int, int] = (-1, -1),
rng_gen: torch.Generator = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention FWD for packed KV input.
......@@ -503,6 +522,11 @@ def fused_attn_fwd_kvpacked(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask, respectively.
rng_gen: torch.Generator, default = None
random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
......@@ -588,6 +612,7 @@ def fused_attn_fwd_kvpacked(
QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type],
window_size,
cu_seqlens_q,
cu_seqlens_kv,
q,
......@@ -641,6 +666,8 @@ def fused_attn_bwd_kvpacked(
qkv_layout: str = "sbhd_sbh2d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
window_size: Tuple[int, int] = (-1, -1),
deterministic: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention BWD for packed KV input.
......@@ -716,6 +743,13 @@ def fused_attn_bwd_kvpacked(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask, respectively.
deterministic: bool, default = False
whether to execute the backward pass with deterministic behavior.
Returns
----------
......@@ -766,6 +800,8 @@ def fused_attn_bwd_kvpacked(
QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type],
window_size,
deterministic,
cu_seqlens_q,
cu_seqlens_kv,
q,
......@@ -818,6 +854,7 @@ def fused_attn_fwd(
qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
window_size: Tuple[int, int] = (-1, -1),
rng_gen: torch.Generator = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention FWD for separate QKV input.
......@@ -886,6 +923,11 @@ def fused_attn_fwd(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask, respectively.
rng_gen: torch.Generator, default = None
random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
......@@ -971,6 +1013,7 @@ def fused_attn_fwd(
QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type],
window_size,
cu_seqlens_q,
cu_seqlens_kv,
q,
......@@ -1026,6 +1069,8 @@ def fused_attn_bwd(
qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
window_size: Tuple[int, int] = (-1, -1),
deterministic: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention BWD for packed KV input.
......@@ -1106,6 +1151,13 @@ def fused_attn_bwd(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask, respectively.
deterministic: bool, default = False
whether to execute the backward pass with deterministic behavior.
Returns
----------
......@@ -1158,6 +1210,8 @@ def fused_attn_bwd(
QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type],
window_size,
deterministic,
cu_seqlens_q,
cu_seqlens_kv,
q,
......
......@@ -14,30 +14,29 @@
* Attention
**************************************************************************************************/
NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q_dtype,
const transformer_engine::DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, float p_dropout,
size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv,
size_t head_dim);
NVTE_Fused_Attn_Backend get_fused_attn_backend(
const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left, int64_t window_size_right);
std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> cu_seqlens_padded, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O, c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O, const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
const std::vector<int64_t> window_size, const at::Tensor cu_seqlens, const at::Tensor QKV,
const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> cu_seqlens_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias, const c10::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread);
std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens,
const at::Tensor QKV, const at::Tensor O, const at::Tensor dO,
const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
bool deterministic, const at::Tensor cu_seqlens, const at::Tensor QKV, const at::Tensor O,
const at::Tensor dO, const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> cu_seqlens_padded, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP,
......@@ -48,8 +47,9 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
std::vector<at::Tensor> fused_attn_fwd_kvpacked(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv,
const at::Tensor Q, const at::Tensor KV, const transformer_engine::DType qkv_type,
NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
const at::Tensor KV, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
......@@ -61,10 +61,10 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
std::vector<at::Tensor> fused_attn_bwd_kvpacked(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
const at::Tensor KV, const at::Tensor O, const at::Tensor dO,
const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor KV, const at::Tensor O,
const at::Tensor dO, const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
......@@ -76,9 +76,10 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
std::vector<at::Tensor> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv,
const at::Tensor Q, const at::Tensor K, const at::Tensor V,
const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> cu_seqlens_q_padded,
NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
......@@ -89,10 +90,10 @@ std::vector<at::Tensor> fused_attn_fwd(
std::vector<at::Tensor> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
const at::Tensor K, const at::Tensor V, const at::Tensor O, const at::Tensor dO,
const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor K, const at::Tensor V,
const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
......
......@@ -10,17 +10,15 @@ constexpr int block_size = 512;
constexpr int ctas_per_sm = 4;
// get the fused attention backend
NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q_dtype,
const transformer_engine::DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, float p_dropout,
size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv,
size_t head_dim) {
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype),
qkv_layout, bias_type, attn_mask_type, p_dropout, num_attn_heads,
num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim);
NVTE_Fused_Attn_Backend get_fused_attn_backend(
const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left, int64_t window_size_right) {
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv,
head_dim, window_size_left, window_size_right);
return fused_attention_backend;
}
......@@ -82,12 +80,13 @@ at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl *gen, size_t elts_pe
std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> cu_seqlens_padded, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O, c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O, const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
const std::vector<int64_t> window_size, const at::Tensor cu_seqlens, const at::Tensor QKV,
const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> cu_seqlens_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias, const c10::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread) {
using namespace transformer_engine;
auto qkv_sizes = QKV.sizes().vec();
......@@ -173,11 +172,11 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
TensorWrapper workspace;
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd_qkvpacked(te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(),
&nvte_aux_tensor_pack, te_cu_seqlens.data(),
te_cu_seqlens_padded.data(), te_rng_state.data(), max_seqlen,
is_training, attn_scale, p_dropout, qkv_layout, bias_type,
attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream());
nvte_fused_attn_fwd_qkvpacked(
te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens.data(), te_cu_seqlens_padded.data(), te_rng_state.data(), max_seqlen,
is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0],
window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
......@@ -213,11 +212,11 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
}
// execute the kernel
nvte_fused_attn_fwd_qkvpacked(te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(),
&nvte_aux_tensor_pack, te_cu_seqlens.data(),
te_cu_seqlens_padded.data(), te_rng_state.data(), max_seqlen,
is_training, attn_scale, p_dropout, qkv_layout, bias_type,
attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream());
nvte_fused_attn_fwd_qkvpacked(
te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens.data(), te_cu_seqlens_padded.data(), te_rng_state.data(), max_seqlen,
is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0],
window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -229,10 +228,10 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
// fused attention BWD with packed QKV
std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens,
const at::Tensor QKV, const at::Tensor O, const at::Tensor dO,
const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
bool deterministic, const at::Tensor cu_seqlens, const at::Tensor QKV, const at::Tensor O,
const at::Tensor dO, const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> cu_seqlens_padded, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP,
......@@ -351,11 +350,11 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
TensorWrapper workspace;
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd_qkvpacked(te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(),
te_cu_seqlens.data(), te_cu_seqlens_padded.data(), max_seqlen,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
workspace.data(), at::cuda::getCurrentCUDAStream());
nvte_fused_attn_bwd_qkvpacked(
te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack,
te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), te_cu_seqlens_padded.data(),
max_seqlen, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0],
window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream());
// allocate memory for workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
......@@ -363,11 +362,11 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
// execute kernel
nvte_fused_attn_bwd_qkvpacked(te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(),
te_cu_seqlens.data(), te_cu_seqlens_padded.data(), max_seqlen,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
workspace.data(), at::cuda::getCurrentCUDAStream());
nvte_fused_attn_bwd_qkvpacked(
te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack,
te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), te_cu_seqlens_padded.data(),
max_seqlen, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0],
window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -379,8 +378,9 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
std::vector<at::Tensor> fused_attn_fwd_kvpacked(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv,
const at::Tensor Q, const at::Tensor KV, const transformer_engine::DType qkv_type,
NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
const at::Tensor KV, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
......@@ -487,8 +487,8 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, workspace.data(),
at::cuda::getCurrentCUDAStream());
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1],
workspace.data(), at::cuda::getCurrentCUDAStream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
......@@ -528,8 +528,8 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, workspace.data(),
at::cuda::getCurrentCUDAStream());
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1],
workspace.data(), at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -542,10 +542,10 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
std::vector<at::Tensor> fused_attn_bwd_kvpacked(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
const at::Tensor KV, const at::Tensor O, const at::Tensor dO,
const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor KV, const at::Tensor O,
const at::Tensor dO, const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
......@@ -690,12 +690,13 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
TensorWrapper workspace;
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd_kvpacked(
te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(),
max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
workspace.data(), at::cuda::getCurrentCUDAStream());
nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(),
max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout,
bias_type, attn_mask_type, window_size[0], window_size[1],
deterministic, workspace.data(), at::cuda::getCurrentCUDAStream());
// allocate memory for workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
......@@ -703,12 +704,13 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
// execute kernel
nvte_fused_attn_bwd_kvpacked(
te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(),
max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
workspace.data(), at::cuda::getCurrentCUDAStream());
nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(),
max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout,
bias_type, attn_mask_type, window_size[0], window_size[1],
deterministic, workspace.data(), at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -720,9 +722,10 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
std::vector<at::Tensor> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv,
const at::Tensor Q, const at::Tensor K, const at::Tensor V,
const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> cu_seqlens_q_padded,
NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
......@@ -833,7 +836,8 @@ std::vector<at::Tensor> fused_attn_fwd(
te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q,
max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type,
attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream());
attn_mask_type, window_size[0], window_size[1], workspace.data(),
at::cuda::getCurrentCUDAStream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
......@@ -874,7 +878,8 @@ std::vector<at::Tensor> fused_attn_fwd(
te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q,
max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type,
attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream());
attn_mask_type, window_size[0], window_size[1], workspace.data(),
at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -887,10 +892,10 @@ std::vector<at::Tensor> fused_attn_fwd(
std::vector<at::Tensor> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
const at::Tensor K, const at::Tensor V, const at::Tensor O, const at::Tensor dO,
const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor K, const at::Tensor V,
const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
......@@ -1115,7 +1120,8 @@ std::vector<at::Tensor> fused_attn_bwd(
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
workspace.data(), at::cuda::getCurrentCUDAStream());
window_size[0], window_size[1], deterministic, workspace.data(),
at::cuda::getCurrentCUDAStream());
// allocate memory for workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
......@@ -1128,7 +1134,8 @@ std::vector<at::Tensor> fused_attn_bwd(
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
workspace.data(), at::cuda::getCurrentCUDAStream());
window_size[0], window_size[1], deterministic, workspace.data(),
at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......
......@@ -21,18 +21,18 @@ THREADS_PER_BLOCK = 128
_default_causal_mask = {}
def _get_default_causal_mask(sq: int, sk: int) -> torch.Tensor:
def _get_default_causal_mask(mask_type: str, sq: int, sk: int) -> torch.Tensor:
"""Return the causal upper triangular mask for softmax input"""
if sq == 1:
return torch.zeros((1, sk), dtype=torch.bool, device="cuda")
matrix_shape = (sq, sk)
if matrix_shape not in _default_causal_mask:
diagonal_offset = sk - sq + 1
_default_causal_mask[matrix_shape] = torch.triu(
matrix_identifiers = (mask_type, sq, sk)
if matrix_identifiers not in _default_causal_mask:
diagonal_offset = sk - sq + 1 if "bottom_right" in mask_type else 1
_default_causal_mask[matrix_identifiers] = torch.triu(
torch.ones(sq, sk, dtype=torch.bool, device="cuda"), diagonal=diagonal_offset
)
return _default_causal_mask[matrix_shape]
return _default_causal_mask[matrix_identifiers]
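For intuition, an illustrative snippet (not part of the diff) contrasting the two diagonal offsets with `sq=2, sk=4`:

```python
import torch

sq, sk = 2, 4
# "causal_bottom_right": diagonal offset sk - sq + 1 = 3
torch.triu(torch.ones(sq, sk, dtype=torch.bool), diagonal=3)
# tensor([[False, False, False,  True],
#         [False, False, False, False]])
# plain "causal": diagonal offset 1, triangle pinned to the top-left
torch.triu(torch.ones(sq, sk, dtype=torch.bool), diagonal=1)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True]])
```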
def _get_onnx_export_causal_mask(
......@@ -332,10 +332,8 @@ class FusedScaleMaskSoftmax(nn.Module):
if self.attn_mask_type == "arbitrary":
return False # Custom masks not supported
if self.attn_mask_type == "causal_bottom_right" or (
self.attn_mask_type == "causal" and sq == sk
): # fused causal softmax kernel
return True
if self.attn_mask_type == "causal" and sq != sk:
return False  # Fused causal kernel only supports causal_bottom_right
if (
sq % 4 == 0  # sq must be divisible by 4
......@@ -387,7 +385,7 @@ class FusedScaleMaskSoftmax(nn.Module):
assert self.kvcache_max_seq >= seq_len_k
mask = _get_onnx_export_causal_mask(seq_len_q, seq_len_k, self.onnx_causal_mask)
else:
mask = _get_default_causal_mask(seq_len_q, seq_len_k)
mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k)
mask_output = inp
if mask is not None and self.attn_mask_type != "no_mask":
......
......@@ -142,9 +142,11 @@ class TransformerLayer(torch.nn.Module):
sliding window size for local attention in encoder, where query at position i
attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k
- seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean
no sliding window and "`causal`" mask specifically. Similar to
:attr:`self_attn_mask_type`, it can be overridden by :attr:`window_size`
in `forward` as well.
no sliding window and causal mask, respectively. Both `causal` and
`causal_bottom_right` masks map to `window_size = (-1, 0)` and Transformer Engine
distinguishes them based on `self_attn_mask_type` or `enc_dec_attn_mask_type`.
Similar to :attr:`self_attn_mask_type`, `window_size` can be overridden by
:attr:`window_size` in `forward` as well.
enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
default = `no_mask`
type of attention mask passed into softmax operation for decoder.
......