Unverified Commit 8e039fdc authored by Charlene Yang, committed by GitHub

Add cuDNN sliding window and set_deterministic_algorithm (#992)



* add cuDNN swa
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix SWA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add set_deterministic and minor fixes for swa
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add AttentionParams
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change window_size to int64_t; fix swa/determinism tests; cache _attention_backends
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add window_size to get_backend; fix jax and paddle
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes; add set_deter to bwd_impl
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix FP8 tests due to determinism
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add support matrix for SWA and bias
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fixes and lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add wording on window_size special cases
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor tweak on wording
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix jax assertion error
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix wording
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* call bwd with deterministic=true for jax/paddle
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add determinism words in documentation
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 166bb078
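Before the diff, a minimal sketch of what this change exposes at the PyTorch level, assembled from the test changes below. The `window_size` keyword and the `NVTE_ALLOW_NONDETERMINISTIC_ALGO` variable come from this commit; the tensor shapes, dtypes, and module dimensions here are illustrative assumptions, not taken from the diff:

```python
import os
import torch
from transformer_engine.pytorch import DotProductAttention

# cuDNN 9.0+ defaults to the deterministic backward algorithm; set this
# to "1" to allow the faster, non-deterministic path instead.
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"

attn = DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    attention_dropout=0.0,
    attn_mask_type="causal",
)

# Default qkv_format is "sbhd": [seq, batch, heads, head_dim].
q = torch.randn(2048, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# window_size = (left, right): each query attends to at most `left` keys
# before its position and `right` keys after it; (-1, -1) means no window.
out = attn(q, k, v, window_size=(128, 0))
```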
@@ -2,7 +2,7 @@
 "cells": [
 {
 "cell_type": "markdown",
-"id": "141fa8bd",
+"id": "8ae3bc43",
 "metadata": {},
 "source": [
 "# Attention Is All You Need!\n",
@@ -22,7 +22,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "09a60057",
+"id": "47421c01",
 "metadata": {},
 "source": [
 "## 1. Attention Backends\n",
@@ -38,7 +38,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "f387274e",
+"id": "e52f60f0",
 "metadata": {},
 "source": [
 "### 1.1 Flash vs. Non-Flash\n",
@@ -58,7 +58,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "f1389145",
+"id": "bb909ac4",
 "metadata": {},
 "source": [
 "### 1.2 flash-attention\n",
@@ -97,7 +97,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "bbc5c73f",
+"id": "9a380859",
 "metadata": {},
 "outputs": [],
 "source": [
@@ -113,7 +113,7 @@
 {
 "cell_type": "code",
 "execution_count": 2,
-"id": "173638b6",
+"id": "0584bb01",
 "metadata": {},
 "outputs": [
 {
@@ -140,7 +140,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "0f62d2fa",
+"id": "45e53fc9",
 "metadata": {},
 "source": [
 "## 2. Backend Selection\n",
@@ -160,7 +160,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "86e16a2b",
+"id": "6dfeade3",
 "metadata": {},
 "source": [
 "### 2.1 Debug Information\n",
@@ -177,7 +177,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "e439434e",
+"id": "7e3b7981",
 "metadata": {},
 "source": [
 "The [example_attention.py](./example_attention.py) script runs a very basic model with two attention backends, cuDNN attention and flash-attention. Here `NVTE_DEBUG_LEVEL=1` allows us to find out which backend/sub-backend was actually used during runtime."
@@ -186,7 +186,7 @@
 {
 "cell_type": "code",
 "execution_count": 22,
-"id": "9d002327",
+"id": "961c51d4",
 "metadata": {},
 "outputs": [
 {
@@ -210,7 +210,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "bbf1756c",
+"id": "11bfbbd7",
 "metadata": {},
 "source": [
 "To collect more information, users can turn on `NVTE_DEBUG_LEVEL=2`. In this example, it allows us to find out more about the run config. Users are encouraged to provide this information if they intend to file a bug with Transformer Engine. For example, "
@@ -219,7 +219,7 @@
 {
 "cell_type": "code",
 "execution_count": 25,
-"id": "66a2f34c",
+"id": "162a2be1",
 "metadata": {},
 "outputs": [
 {
@@ -249,7 +249,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "9f964732",
+"id": "779a51e6",
 "metadata": {},
 "source": [
 "### 2.2 User Control\n",
@@ -274,11 +274,16 @@
 "\n",
 "Users can experiment with these two paths through the following environment variable. However, please be aware of the possible Out-Of-Memory risks.\n",
 "```\n",
-"NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT = 0 # disables workspace optimization path\n",
-"NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT = 1 # enables workspace optimization path\n",
+"Before cuDNN 9.0:\n",
+" NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT = 0 # disables workspace optimization path\n",
+" NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT = 1 # enables workspace optimization path\n",
+"\n",
+"After cuDNN 9.0:\n",
+" NVTE_ALLOW_NONDETERMINISTIC_ALGO = 1 # disables workspace optimization path\n",
+" NVTE_ALLOW_NONDETERMINISTIC_ALGO = 0 # enables workspace optimization path\n",
 "```\n",
 "<div class=\"alert alert-info\">\n",
-"<b>Note:</b> Environment variables <code>NVTE_FLASH_ATTN</code>, <code>NVTE_FUSED_ATTN</code> and <code>NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT</code> are only supported in PyTorch, not JAX or PaddlePaddle.\n",
+"<b>Note:</b> Environment variables <code>NVTE_FLASH_ATTN</code>, <code>NVTE_FUSED_ATTN</code>, <code>NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT</code> and <code>NVTE_ALLOW_NONDETERMINISTIC_ALGO</code> are only supported in PyTorch, and will be added to JAX and PaddlePaddle in the future.\n",
 "</div>\n",
 "\n",
 "### 2.3 Example Tests\n",
@@ -290,7 +295,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "3ad85b86",
+"id": "ccd5650d",
 "metadata": {},
 "source": [
 "## 3. Backend Support\n",
@@ -311,7 +316,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "37920af4",
+"id": "8439b389",
 "metadata": {},
 "source": [
 "### 3.1 QKV Layout\n",
@@ -354,7 +359,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "94c69fae",
+"id": "0290f8e9",
 "metadata": {},
 "source": [
 "### 3.2 Attention Mask\n",
@@ -387,7 +392,7 @@
 {
 "cell_type": "code",
 "execution_count": 6,
-"id": "4c87df64",
+"id": "b1b7cdd4",
 "metadata": {},
 "outputs": [
 {
@@ -408,7 +413,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "5ec0c75d",
+"id": "e045c284",
 "metadata": {},
 "source": [
 "Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](../../tests/pytorch/fused_attention/test_fused_attn.py).\n",
@@ -432,7 +437,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "3f8f6f1c",
+"id": "8b8a4e40",
 "metadata": {},
 "source": [
 "### 3.4 FP8 Attention\n",
......
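The section 2.2 cell above pairs each cuDNN generation with a different control variable. As a reading aid, here is a small hypothetical helper (not part of the repository) that captures the mapping; it assumes the PyTorch-only scope stated in the note:

```python
import os

def set_fused_attn_determinism(cudnn_version: tuple, deterministic: bool) -> None:
    """Pick the env var that controls cuDNN fused attention determinism.

    Per the notebook cell above: before cuDNN 9.0 the workspace-optimization
    path is the deterministic one; from 9.0 on, NVTE_ALLOW_NONDETERMINISTIC_ALGO
    is the switch (0 = deterministic, 1 = allow the non-deterministic path).
    """
    if cudnn_version < (9, 0):
        os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if deterministic else "0"
    else:
        os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" if deterministic else "1"

set_fused_attn_determinism((9, 1), deterministic=True)
```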
@@ -22,6 +22,9 @@ from transformer_engine.pytorch.attention import (
 get_attention_backend,
 _flash_attn_2_plus,
 _flash_attn_2_3_plus,
+check_set_window_size,
+AttentionParams,
+_attention_backends,
 )
 from transformer_engine.pytorch.constants import TE_DType
 import transformer_engine.pytorch.cpp_extensions as ext
@@ -84,6 +87,7 @@ class ModelConfig:
 alibi_type: str = "none",
 num_layers: int = 1,
 bias_shape: str = "1hss",
+window_size: Tuple[int, int] = (-1, -1),
 ):
 self.batch_size = batch_size
 self.num_heads = num_heads
@@ -100,6 +104,7 @@
 self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
 self.num_layers = num_layers
 self.bias_shape = bias_shape
+self.window_size = window_size

 def _get_attention_backends(
@@ -118,6 +123,8 @@ def _get_attention_backends(
 os.environ["NVTE_FLASH_ATTN"] = "1"
 os.environ["NVTE_FUSED_ATTN"] = "1"
 os.environ["NVTE_UNFUSED_ATTN"] = "1"
+global _attention_backends
+_attention_backends["backend_selection_requires_update"] = True
 alibi_slopes_shape = None
 if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
@@ -136,8 +143,10 @@
 fused_attn_backends = []
 available_backends = None
-os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
-_, _, _, available_backends, fused_attention_backend = get_attention_backend(
+fused_attention_backend = None
+
+def test():
+attention_params = AttentionParams(
 qkv_dtype=qkv_dtype,
 qkv_layout=qkv_layout,
 batch_size=config.batch_size,
@@ -159,37 +168,17 @@
 fp8=fp8,
 fp8_meta=fp8_meta,
 )
-if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
-fused_attn_backends.append(fused_attention_backend)
-os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
-_, _, _, available_backends, fused_attention_backend = get_attention_backend(
-qkv_dtype=qkv_dtype,
-qkv_layout=qkv_layout,
-batch_size=config.batch_size,
-num_heads=config.num_heads,
-num_gqa_groups=config.num_gqa_groups,
-max_seqlen_q=config.max_seqlen_q,
-max_seqlen_kv=config.max_seqlen_kv,
-head_dim=config.head_dim,
-attn_mask_type=config.attn_mask_type,
-window_size=window_size,
-alibi_slopes_shape=alibi_slopes_shape,
-core_attention_bias_type=config.attn_bias_type,
-core_attention_bias_shape=core_attention_bias_shape,
-core_attention_bias_requires_grad=core_attention_bias_requires_grad,
-pad_between_seqs=pad_between_seqs,
-attention_dropout=config.dropout_p,
-context_parallel=context_parallel,
-deterministic=deterministic,
-fp8=fp8,
-fp8_meta=fp8_meta,
-)
-if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
-fused_attn_backends.append(fused_attention_backend)
-elif (
-fused_attention_backend != FusedAttnBackend["No_Backend"]
-and fused_attention_backend is not None
-):
+_, _, fused_attention_backend, _, available_backends = get_attention_backend(
+attention_params
+)
+return available_backends, fused_attention_backend
+
+backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
+for i in range(3):
+os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
+_attention_backends["backend_selection_requires_update"] = True
+available_backends, fused_attention_backend = test()
+if fused_attention_backend == FusedAttnBackend[backends[i]]:
 fused_attn_backends.append(fused_attention_backend)
 return available_backends, fused_attn_backends
@@ -211,18 +200,6 @@ if is_bf16_compatible(): # bf16 requires sm_80 or higher
 param_types_lean = [torch.bfloat16]

-def get_swa(seq_q, seq_kv, w=None):
-"""Generate a random sliding window size (left, right) if w is None,
-and create its equivalent attention mask in [seq_q, seq_kv] shape"""
-if w is None:
-w = torch.randint(0, seq_kv, [2], dtype=torch.int32, device="cuda")
-m = torch.ones(seq_q, seq_kv, dtype=torch.bool, device="cuda")
-mu = torch.triu(m, diagonal=seq_kv - seq_q - w[0])
-ml = torch.tril(mu, diagonal=seq_kv - seq_q + w[1])
-ml = ~ml
-return w, ml

 @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("model_configs", [model_configs_base])
@@ -251,15 +228,22 @@ def test_dot_product_attention(
 pytest.skip("No need to test this layout for cross attention")

 # Test backend availability
-window_size = (2, 2) if swa else (-1, -1)
+window_size = (-1, -1)
+if swa:
+window_size = tuple(torch.randint(0, config.max_seqlen_kv, [2], dtype=torch.int32).tolist())
+config.window_size = check_set_window_size(config.attn_mask_type, window_size)
 available_backends, fused_attn_backends = _get_attention_backends(
 config,
 qkv_dtype=dtype,
 qkv_layout=qkv_layout,
-window_size=window_size,
+window_size=config.window_size,
 pad_between_seqs=pad_between_seqs,
 )
 flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
+# FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
+# manually pads and unpads the input and output of FlashAttention for testing purposes
+if pad_between_seqs:
+flash_attn_supported = True

 # Skip if only unfused backend is supported
 if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
@@ -268,9 +252,6 @@
 is_training = config.head_dim <= 128
 # UnfusedDotProductAttention backend
 if unfused_attn_supported:
-if swa:
-attn_mask_type = config.attn_mask_type
-config.attn_mask_type = "arbitrary"
 unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
 dtype,
 config,
@@ -278,12 +259,9 @@
 ckpt_attn,
 qkv_layout,
 workspace_opt,
-swa,
 pad_between_seqs,
 is_training,
 )
-if swa:
-config.attn_mask_type = attn_mask_type

 # FusedAttention backend
 if fused_attn_supported:
@@ -295,7 +273,6 @@
 ckpt_attn,
 qkv_layout,
 workspace_opt,
-swa,
 pad_between_seqs,
 is_training,
 )
@@ -308,7 +285,6 @@
 ckpt_attn,
 qkv_layout,
 workspace_opt,
-swa,
 pad_between_seqs,
 is_training,
 )
@@ -320,7 +296,6 @@
 ckpt_attn,
 qkv_layout,
 workspace_opt,
-swa,
 pad_between_seqs,
 is_training,
 )
@@ -334,7 +309,6 @@
 ckpt_attn,
 qkv_layout,
 workspace_opt,
-swa,
 pad_between_seqs,
 is_training,
 )
@@ -504,6 +478,14 @@ model_configs_swa = {
 "swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
 "swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
 "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
+"swa_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
+"swa_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"),
+"swa_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
+"swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
+"swa_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal_bottom_right", "no_bias"),
+"swa_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal_bottom_right", "no_bias"),
+"swa_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
+"swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
 }
@@ -622,7 +604,6 @@ def _run_dot_product_attention(
 ckpt_attn: bool,
 qkv_layout: str,
 workspace_opt: bool,
-swa: bool,
 pad_between_seqs: bool,
 is_training: bool,
 ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
@@ -637,6 +618,8 @@
 if backend == "FusedAttention":
 os.environ["NVTE_FUSED_ATTN"] = "1"
 os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
+global _attention_backends
+_attention_backends["backend_selection_requires_update"] = True

 # Create seqlens
 qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
@@ -733,11 +716,6 @@
 attention_mask_q.to(device="cuda"),
 attention_mask_kv.to(device="cuda"),
 )
-window_size = None
-if swa:
-window_size, attention_mask = get_swa(config.max_seqlen_q, config.max_seqlen_kv)
-elif "causal" in config.attn_mask_type:
-window_size, attention_mask = (-1, 0), None
 alibi_slopes = None
 if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
@@ -902,7 +880,7 @@
 q,
 k,
 v,
-window_size=window_size,
+window_size=config.window_size,
 attention_mask=attention_mask,
 qkv_format=qkv_format,
 max_seqlen_q=config.max_seqlen_q,
@@ -1121,6 +1099,8 @@ def _run_transformer_layer(
 os.environ["NVTE_FLASH_ATTN"] = "1"
 if backend == "FusedAttention":
 os.environ["NVTE_FUSED_ATTN"] = "1"
+global _attention_backends
+_attention_backends["backend_selection_requires_update"] = True

 # Create input tensor
 inp = torch.randn(
@@ -1279,6 +1259,9 @@ def _rmse(a, b):
 def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd):
 os.environ["NVTE_FLASH_ATTN"] = "0"
 os.environ["NVTE_FUSED_ATTN"] = "1"
+os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
+global _attention_backends
+_attention_backends["backend_selection_requires_update"] = True
 config = model_configs_fp8_vs_f16[model]
 os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
@@ -1457,6 +1440,9 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd):
 pytest.skip("qkv_layout not applicable for MQA/GQA")
 os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
+os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
+global _attention_backends
+_attention_backends["backend_selection_requires_update"] = True
 logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
 fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(dtype, config, True, qkv_layout)
@@ -1741,6 +1727,8 @@ def _run_custom_mha_fp8(dtype, config, backend):
 os.environ["NVTE_FLASH_ATTN"] = "1"
 if backend == "FusedAttention":
 os.environ["NVTE_FUSED_ATTN"] = "1"
+global _attention_backends
+_attention_backends["backend_selection_requires_update"] = True
 inp = 0.0001 * torch.randint(
 -100,
@@ -1794,6 +1782,8 @@ def _run_ref_mha_f16(dtype, config, backend):
 os.environ["NVTE_FLASH_ATTN"] = "1"
 if backend == "FusedAttention":
 os.environ["NVTE_FUSED_ATTN"] = "1"
+global _attention_backends
+_attention_backends["backend_selection_requires_update"] = True
 inp = torch.load("qkv.pt").to(device="cuda")
 inp.requires_grad = True
......
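The test refactor above replaces two hand-rolled `get_attention_backend` calls with one `AttentionParams` bundle plus a loop over the three fused sub-backends, invalidating the module-level cache before each query. A condensed sketch of that pattern (names from the diff; the `AttentionParams` construction is elided, and the success check is simplified relative to the diff's `FusedAttnBackend` comparison):

```python
import os
from transformer_engine.pytorch.attention import (
    AttentionParams,
    get_attention_backend,
    _attention_backends,
)

def probe_fused_backends(attention_params: AttentionParams) -> list:
    """Return the cuDNN fused-attention sub-backends usable for these params."""
    names = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
    supported = []
    for i, name in names.items():
        os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
        # Selection is cached per AttentionParams; force a fresh decision.
        _attention_backends["backend_selection_requires_update"] = True
        _, _, fused_backend, _, _ = get_attention_backend(attention_params)
        if fused_backend is not None:
            supported.append(name)
    return supported
```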
@@ -72,7 +72,8 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) {
 NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
 NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
 NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
-size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim) {
+size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left,
+int64_t window_size_right) {
 using namespace transformer_engine;
 NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
 const int device_id = cuda::current_device();
@@ -116,7 +117,8 @@
 (qkv_layout == NVTE_QKV_Layout::NVTE_SBHD_SB2HD) ||
 (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) ||
 (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) ||
-(qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD))) {
+(qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) &&
+((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0))) {
 flag_m512 = true;
 }
 if ( // architecture
@@ -165,7 +167,19 @@
 ((qkv_format == NVTE_QKV_Format::NVTE_SBHD) ||
 (sm_arch_ >= 90 && cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups &&
 qkv_format == NVTE_QKV_Format::NVTE_THD) ||
-(qkv_format == NVTE_QKV_Format::NVTE_BSHD))) {
+(qkv_format == NVTE_QKV_Format::NVTE_BSHD)) &&
+// sliding window
+((cudnn_runtime_version < 90200 && window_size_left == -1 &&
+(window_size_right == -1 || window_size_right == 0)) ||
+(cudnn_runtime_version >= 90200 &&
+((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) ||
+((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 &&
+(attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
+(attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
+max_seqlen_q == max_seqlen_kv)) &&
+dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
+(qkv_format == NVTE_QKV_Format::NVTE_BSHD ||
+qkv_format == NVTE_QKV_Format::NVTE_SBHD)))))) {
 flag_arb = true;
 }
 if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) {
@@ -213,6 +227,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
 const NVTETensor rng_state, size_t max_seqlen, bool is_training,
 float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
 NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
+int64_t window_size_left, int64_t window_size_right,
 NVTETensor workspace, cudaStream_t stream) {
 NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
 using namespace transformer_engine;
@@ -242,9 +257,9 @@
 auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
 const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
-NVTE_Fused_Attn_Backend fused_attention_backend =
-nvte_get_fused_attn_backend(QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type,
-dropout, h, h, max_seqlen, max_seqlen, d);
+NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
+QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
+max_seqlen, d, window_size_left, window_size_right);
 if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -259,8 +274,9 @@
 #if (CUDNN_VERSION >= 8900)
 fused_attn_arbitrary_seqlen_fwd_qkvpacked(
 b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, bias_type,
-attn_mask_type, input_QKV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens,
-input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
+attn_mask_type, window_size_left, window_size_right, input_QKV, input_Bias, output_O,
+Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace,
+stream, handle);
 #else
 NVTE_ERROR(
 "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
@@ -286,7 +302,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
 const NVTETensor cu_seqlens_padded, size_t max_seqlen,
 float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
 NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-NVTETensor workspace, cudaStream_t stream) {
+int64_t window_size_left, int64_t window_size_right,
+bool deterministic, NVTETensor workspace, cudaStream_t stream) {
 NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
 using namespace transformer_engine;
@@ -317,9 +334,9 @@
 auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
 const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
-NVTE_Fused_Attn_Backend fused_attention_backend =
-nvte_get_fused_attn_backend(QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type,
-dropout, h, h, max_seqlen, max_seqlen, d);
+NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
+QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
+max_seqlen, d, window_size_left, window_size_right);
 if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -341,9 +358,10 @@
 input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
 }
 fused_attn_arbitrary_seqlen_bwd_qkvpacked(
-b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_QKV,
-input_O, input_dO, input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens,
-input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
+b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
+window_size_left, window_size_right, deterministic, input_QKV, input_O, input_dO,
+input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens, input_cu_seqlens_padded,
+input_rng_state, wkspace, stream, handle);
 #else
 const char *err_msg =
 "cuDNN 8.9.0 is required for BF16/FP16 fused attention "
@@ -375,6 +393,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
 size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
 float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
 NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
+int64_t window_size_left, int64_t window_size_right,
 NVTETensor workspace, cudaStream_t stream) {
 NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
 using namespace transformer_engine;
@@ -409,9 +428,9 @@
 const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
 const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
-NVTE_Fused_Attn_Backend fused_attention_backend =
-nvte_get_fused_attn_backend(Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout,
-h_q, h_kv, max_seqlen_q, max_seqlen_kv, d);
+NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
+Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
+max_seqlen_kv, d, window_size_left, window_size_right);
 if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -426,9 +445,10 @@
 #if (CUDNN_VERSION >= 8903)
 fused_attn_arbitrary_seqlen_fwd_kvpacked(
 b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout,
-bias_type, attn_mask_type, input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors,
-input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
-input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
+bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_KV,
+input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
+input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
+handle);
 #else
 NVTE_ERROR(
 "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
@@ -454,7 +474,8 @@ void nvte_fused_attn_bwd_kvpacked(
 const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
 size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
 NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-NVTETensor workspace, cudaStream_t stream) {
+int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace,
+cudaStream_t stream) {
 NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked);
 using namespace transformer_engine;
 const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
@@ -491,9 +512,9 @@
 const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
 const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
-NVTE_Fused_Attn_Backend fused_attention_backend =
-nvte_get_fused_attn_backend(Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout,
-h_q, h_kv, max_seqlen_q, max_seqlen_kv, d);
+NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
+Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
+max_seqlen_kv, d, window_size_left, window_size_right);
 if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -517,10 +538,10 @@
 }
 fused_attn_arbitrary_seqlen_bwd_kvpacked(
 b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type,
-attn_mask_type, input_Q, input_KV, input_O, input_dO, input_Bias, output_S, output_dQ,
-output_dKV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv,
-input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
-handle);
+attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, input_KV,
+input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias,
+input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
+input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
 #else
 const char *err_msg =
 "cuDNN 8.9.3 is required for BF16/FP16 fused attention "
@@ -553,7 +574,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
 size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
 float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
 NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-NVTETensor workspace, cudaStream_t stream) {
+int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
+cudaStream_t stream) {
 NVTE_API_CALL(nvte_flash_attn_fwd);
 using namespace transformer_engine;
 const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
@@ -579,9 +601,9 @@
 const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
 const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
-NVTE_Fused_Attn_Backend fused_attention_backend =
-nvte_get_fused_attn_backend(Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout,
-h_q, h_kv, max_seqlen_q, max_seqlen_kv, d);
+NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
+Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
+max_seqlen_kv, d, window_size_left, window_size_right);
 if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -596,9 +618,10 @@
 #if (CUDNN_VERSION >= 8900)
 fused_attn_arbitrary_seqlen_fwd(
 b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout,
-bias_type, attn_mask_type, input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors,
-input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
-input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
+bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_K, input_V,
+input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
+input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
+handle);
 #else
 NVTE_ERROR(
 "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
@@ -625,7 +648,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
 const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
 size_t max_seqlen_kv, float attn_scale, float dropout,
 NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream) {
+NVTE_Mask_Type attn_mask_type, int64_t window_size_left,
+int64_t window_size_right, bool deterministic, NVTETensor workspace,
+cudaStream_t stream) {
 NVTE_API_CALL(nvte_flash_attn_bwd);
 using namespace transformer_engine;
 const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
@@ -655,9 +680,9 @@
 const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
 const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
-NVTE_Fused_Attn_Backend fused_attention_backend =
-nvte_get_fused_attn_backend(Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout,
-h_q, h_kv, max_seqlen_q, max_seqlen_kv, d);
+NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
+Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
+max_seqlen_kv, d, window_size_left, window_size_right);
 if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -681,10 +706,10 @@
 }
 fused_attn_arbitrary_seqlen_bwd(
 b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type,
-attn_mask_type, input_Q, input_K, input_V, input_O, input_dO, input_Bias, output_S,
-output_dQ, output_dK, output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv,
-input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
-handle);
+attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, input_K,
+input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, output_dV,
+output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
+input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
 #else
 const char *err_msg =
 "cuDNN 8.9.0 is required for BF16/FP16 fused attention "
......
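The `// sliding window` clause added to `nvte_get_fused_attn_backend` above is dense. A Python restatement of the same predicate, purely as a reading aid (not part of the codebase, with string stand-ins for the enum values):

```python
def arbitrary_seqlen_allows_window(cudnn_version: int, left: int, right: int,
                                   mask: str, s_q: int, s_kv: int,
                                   dropout: float, bias: str, qkv_format: str) -> bool:
    """Mirror of the sliding-window condition in the C++ diff above."""
    no_window = left == -1 and right in (-1, 0)
    if cudnn_version < 90200:
        return no_window  # before cuDNN 9.2, only the no-window case passes
    causal_ok = mask == "causal" or (mask == "causal_bottom_right" and s_q == s_kv)
    window_ok = (
        right == 0           # left >= 0 or left == -1 is vacuous and dropped here
        and causal_ok
        and dropout == 0.0
        and bias == "no_bias"
        and qkv_format in ("bshd", "sbhd")
    )
    return no_window or window_ok

# E.g. cuDNN 9.2, left window of 128, causal mask, no dropout/bias:
assert arbitrary_seqlen_allows_window(90200, 128, 0, "causal", 4096, 4096, 0.0, "no_bias", "bshd")
```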
...@@ -50,8 +50,9 @@ namespace fused_attn { ...@@ -50,8 +50,9 @@ namespace fused_attn {
void fused_attn_arbitrary_seqlen_fwd_impl( void fused_attn_arbitrary_seqlen_fwd_impl(
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, int64_t bias_b, int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, int64_t bias_b,
int64_t bias_h, bool is_training, float scaling_factor, float dropout_probability, int64_t bias_h, bool is_training, float scaling_factor, float dropout_probability,
NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void *devPtrQ, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO, int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK,
void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO,
void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV,
cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
...@@ -63,6 +64,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -63,6 +64,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) || bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK));
if (is_bottom_right && s_q == s_kv) {
is_causal = true;
is_bottom_right = false;
}
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (is_training && dropout_probability != 0.0f); bool is_dropout = (is_training && dropout_probability != 0.0f);
...@@ -70,6 +75,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -70,6 +75,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
if (is_ragged) { if (is_ragged) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
} }
if (window_size_left == -1) {
window_size_left = s_q;
}
auto cudnn_runtime_version = cudnnGetVersion();
try { try {
FADescriptor_v1 descriptor{b, FADescriptor_v1 descriptor{b,
...@@ -86,6 +95,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -86,6 +95,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
layout, layout,
bias_type, bias_type,
mask_type, mask_type,
window_size_left,
window_size_right,
true,
tensorType, tensorType,
tensorType}; tensorType};
...@@ -208,6 +220,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -208,6 +220,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
.set_causal_mask_bottom_right(is_bottom_right) .set_causal_mask_bottom_right(is_bottom_right)
.set_attn_scale(attn_scale); .set_attn_scale(attn_scale);
if (cudnn_runtime_version >= 90200 && window_size_left != s_q) {
sdpa_options.set_sliding_window_length(window_size_left);
}
sdpa_options.set_alibi_mask(is_alibi); sdpa_options.set_alibi_mask(is_alibi);
if (is_bias) { if (is_bias) {
...@@ -367,7 +383,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -367,7 +383,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
void fused_attn_arbitrary_seqlen_bwd_impl( void fused_attn_arbitrary_seqlen_bwd_impl(
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, int64_t bias_b, int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, int64_t bias_b,
int64_t bias_h, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, int64_t bias_h, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void *devPtrQ, void *devPtrKTranspose, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, void *devPtrQ, void *devPtrKTranspose,
void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias,
void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias, void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias,
void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
...@@ -381,10 +398,20 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -381,10 +398,20 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) || bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK));
if (is_bottom_right && s_q == s_kv) {
is_causal = true;
is_bottom_right = false;
}
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (dropout_probability != 0.0f); bool is_dropout = (dropout_probability != 0.0f);
bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD);
if (window_size_left == -1) {
window_size_left = s_q;
}
auto cudnn_runtime_version = cudnnGetVersion();
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
try { try {
FADescriptor_v1 descriptor{b, FADescriptor_v1 descriptor{b,
...@@ -401,6 +428,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -401,6 +428,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
layout, layout,
bias_type, bias_type,
mask_type, mask_type,
window_size_left,
window_size_right,
deterministic,
tensorType, tensorType,
tensorType}; tensorType};
...@@ -552,6 +582,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -552,6 +582,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
.set_causal_mask_bottom_right(is_bottom_right) .set_causal_mask_bottom_right(is_bottom_right)
.set_attn_scale(attn_scale); .set_attn_scale(attn_scale);
if (cudnn_runtime_version >= 90200 && window_size_left != s_q) {
sdpa_backward_options.set_sliding_window_length(window_size_left);
}
if (cudnn_runtime_version >= 90000 && sm_arch_ >= 90) {
sdpa_backward_options.set_deterministic_algorithm(deterministic);
}
sdpa_backward_options.set_alibi_mask(is_alibi); sdpa_backward_options.set_alibi_mask(is_alibi);
if (is_bias) { if (is_bias) {
@@ -746,7 +784,8 @@ using namespace transformer_engine::fused_attn;
 void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
     size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training,
     float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
+    NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
+    const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
     NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
     const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;
@@ -827,9 +866,9 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
   fused_attn_arbitrary_seqlen_fwd_impl(
       batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, bias_b, bias_h,
-      is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK,
-      devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
-      devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets,
+      is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left,
+      window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
+      devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets,
       get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
   if (workspace_size > 0) {
@@ -850,6 +889,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
 void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
     size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale,
     float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    int64_t window_size_left, int64_t window_size_right, bool deterministic,
     const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
     const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
     const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
@@ -902,11 +942,11 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
   fused_attn_arbitrary_seqlen_bwd_impl(
       batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, bias_b, bias_h,
-      attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrO,
-      devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
-      devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets,
-      devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream,
-      handle);
+      attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
+      deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ,
+      devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset,
+      devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets,
+      get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
   if (workspace_size > 0) {
     if (workspace->data.dptr == nullptr) {
@@ -926,7 +966,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-    const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
+    int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
+    const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
     NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
     const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
@@ -1011,10 +1052,11 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
   fused_attn_arbitrary_seqlen_fwd_impl(
       batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h,
-      is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK,
-      devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
-      devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
-      get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
+      is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left,
+      window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
+      devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ,
+      devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size,
+      stream, handle);
   if (workspace_size > 0) {
     if (workspace->data.dptr == nullptr) {
@@ -1035,11 +1077,12 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-    const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO,
-    const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
-    Tensor *output_dBias, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
-    const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
-    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
+    int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
+    const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
+    Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
+    const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
+    const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
+    cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;
   const auto QKV_type = input_Q->data.dtype;
@@ -1089,11 +1132,11 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
   fused_attn_arbitrary_seqlen_bwd_impl(
       batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h,
-      attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrO,
-      devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
-      devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
-      devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
-      &workspace_size, stream, handle);
+      attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
+      deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ,
+      devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset,
+      devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
+      get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
   if (workspace_size > 0) {
     if (workspace->data.dptr == nullptr) {
@@ -1110,16 +1153,14 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
   }
 }
-void fused_attn_arbitrary_seqlen_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
-                                     size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
-                                     bool is_training, float attn_scale, float p_dropout,
-                                     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                                     NVTE_Mask_Type mask_type, const Tensor *input_Q,
-                                     const Tensor *input_K, const Tensor *input_V,
-                                     const Tensor *input_Bias, Tensor *output_O,
-                                     NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
-                                     const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
-                                     const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
+void fused_attn_arbitrary_seqlen_fwd(
+    size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
+    size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout,
+    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
+    const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O,
+    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
+    const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;
@@ -1193,10 +1234,11 @@ void fused_attn_arbitrary_seqlen_fwd(
   fused_attn_arbitrary_seqlen_fwd_impl(
       batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h,
-      is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK,
-      devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
-      devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
-      get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
+      is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left,
+      window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
+      devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ,
+      devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size,
+      stream, handle);
   if (workspace_size > 0) {
     if (workspace->data.dptr == nullptr) {
@@ -1217,9 +1259,10 @@ void fused_attn_arbitrary_seqlen_bwd(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-    const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O,
-    const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
-    Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
+    int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
+    const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
+    const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK,
+    Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
     const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
     const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
     cudaStream_t stream, cudnnHandle_t handle) {
@@ -1260,11 +1303,11 @@ void fused_attn_arbitrary_seqlen_bwd(
   fused_attn_arbitrary_seqlen_bwd_impl(
       batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h,
-      attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrO,
-      devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
-      devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
-      devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
-      &workspace_size, stream, handle);
+      attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
+      deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ,
+      devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset,
+      devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
+      get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
   if (workspace_size > 0) {
     if (workspace->data.dptr == nullptr) {
...
@@ -21,13 +21,15 @@ namespace transformer_engine {
 void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
     size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training,
     float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
+    NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
+    const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
     NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
     const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
 void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
     size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale,
     float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    int64_t window_size_left, int64_t window_size_right, bool deterministic,
     const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
     const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
     const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
@@ -37,7 +39,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-    const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
+    int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
+    const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
     NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
     const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
@@ -46,31 +49,31 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-    const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO,
-    const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
-    Tensor *output_dBias, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
-    const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
-    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+    int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
+    const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
+    Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
+    const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
+    const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
+    cudaStream_t stream, cudnnHandle_t handle);
-void fused_attn_arbitrary_seqlen_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
-                                     size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
-                                     bool is_training, float attn_scale, float p_dropout,
-                                     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                                     NVTE_Mask_Type mask_type, const Tensor *input_Q,
-                                     const Tensor *input_K, const Tensor *input_V,
-                                     const Tensor *input_Bias, Tensor *output_O,
-                                     NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
-                                     const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
-                                     const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
-                                     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+void fused_attn_arbitrary_seqlen_fwd(
+    size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
+    size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout,
+    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
+    const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O,
+    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
+    const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
+    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
 void fused_attn_arbitrary_seqlen_bwd(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-    const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O,
-    const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
-    Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
+    int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
+    const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
+    const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK,
+    Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
     const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
     const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
     cudaStream_t stream, cudnnHandle_t handle);
...
@@ -1687,6 +1687,9 @@ void fused_attn_fp8_fwd_impl_v1(
                                layout,
                                bias_type,
                                mask_type,
+                               0,
+                               0,
+                               true,
                                fwd_tensor_type,
                                fwd_tensor_type};
@@ -1981,6 +1984,9 @@ void fused_attn_fp8_bwd_impl_v1(
                                layout,
                                bias_type,
                                mask_type,
+                               0,
+                               0,
+                               false,
                                fwd_tensor_type,
                                bwd_tensor_type};
...
@@ -100,16 +100,20 @@ struct FADescriptor_v1 {
   NVTE_QKV_Layout layout;
   NVTE_Bias_Type bias_type;
   NVTE_Mask_Type mask_type;
+  std::int64_t window_size_left;
+  std::int64_t window_size_right;
+  bool deterministic;
   cudnn_frontend::DataType_t fwd_tensor_type;
   cudnn_frontend::DataType_t bwd_tensor_type;
   bool operator<(const FADescriptor_v1 &rhs) const {
     return std::tie(b, h, hg, s_q, s_kv, d, bias_b, bias_h, attnScale, isTraining,
-                    dropoutProbability, layout, mask_type, bias_type, fwd_tensor_type,
-                    bwd_tensor_type) <
+                    dropoutProbability, layout, mask_type, window_size_left, window_size_right,
+                    deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) <
            std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d, rhs.bias_b, rhs.bias_h,
                     rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout,
-                    rhs.mask_type, rhs.bias_type, rhs.fwd_tensor_type, rhs.bwd_tensor_type);
+                    rhs.mask_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic,
+                    rhs.bias_type, rhs.fwd_tensor_type, rhs.bwd_tensor_type);
   }
 };
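Why the three new fields must also enter operator<: graph caches of this style key compiled execution plans by the full FADescriptor_v1 (e.g. via a std::map, which uses operator< for lookup), so two configurations differing only in window size or determinism must compare as distinct, or they would wrongly share a cached plan. A hedged sketch (the cache type here is hypothetical, not the one used in this file):

#include <map>

// Hypothetical plan cache keyed by the descriptor; distinct window sizes or
// deterministic flags now map to distinct entries.
std::map<FADescriptor_v1, int /* compiled graph handle */> plan_cache;

int get_or_build_plan(const FADescriptor_v1 &desc) {
  auto it = plan_cache.find(desc);
  if (it != plan_cache.end()) return it->second;
  return plan_cache[desc] = 0;  // build and insert a new plan here
}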
...
@@ -11,6 +11,8 @@
 #ifndef TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_
 #define TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_
+#include <cstdint>
+
 #include "transformer_engine.h"
 #ifdef __cplusplus
@@ -146,11 +148,14 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout);
  *  \param[in]     max_seqlen_q         The sequence length of Q.
  *  \param[in]     max_seqlen_kv        The sequence length of K, V.
  *  \param[in]     head_dim             The head dimension of Q, K, V.
+ *  \param[in]     window_size_left     Sliding window size (the left half).
+ *  \param[in]     window_size_right    Sliding window size (the right half).
  */
 NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
     NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
     NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
-    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim);
+    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left,
+    int64_t window_size_right);
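A hedged usage sketch of the extended backend query (assuming this header is included; the wrapper name and the argument values are illustrative). The pair (-1, -1) is the sentinel the JAX and Paddle callers below pass to mean "no sliding window", so existing behavior is preserved:

// Hypothetical wrapper, illustrative values only.
NVTE_Fused_Attn_Backend query_backend(NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                      NVTE_Mask_Type mask_type, size_t heads, size_t gqa_groups,
                                      size_t s_q, size_t s_kv, size_t head_dim) {
  return nvte_get_fused_attn_backend(kNVTEFloat16, kNVTEFloat16, qkv_layout, bias_type, mask_type,
                                     /*dropout=*/0.0f, heads, gqa_groups, s_q, s_kv, head_dim,
                                     /*window_size_left=*/-1, /*window_size_right=*/-1);
}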
 /*! \brief Compute dot product attention with packed QKV input.
  *
@@ -197,6 +202,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
  *  \param[in]     qkv_layout           QKV tensor's layout.
  *  \param[in]     bias_type            Bias type.
  *  \param[in]     attn_mask_type       Attention mask type.
+ *  \param[in]     window_size_left     Sliding window size (the left half).
+ *  \param[in]     window_size_right    Sliding window size (the right half).
  *  \param[in]     workspace            Workspace tensor.
  *  \param[in]     stream               CUDA stream used for this operation.
  */
@@ -206,6 +213,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
                                    const NVTETensor rng_state, size_t max_seqlen, bool is_training,
                                    float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
                                    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
+                                   int64_t window_size_left, int64_t window_size_right,
                                    NVTETensor workspace, cudaStream_t stream);
 /*! \brief Compute the backward of the dot product attention with packed QKV input.
@@ -248,6 +256,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
  *  \param[in]     qkv_layout           QKV tensor's layout.
  *  \param[in]     bias_type            Bias type.
  *  \param[in]     attn_mask_type       Attention mask type.
+ *  \param[in]     window_size_left     Sliding window size (the left half).
+ *  \param[in]     window_size_right    Sliding window size (the right half).
+ *  \param[in]     deterministic        Whether to execute with deterministic behaviours.
  *  \param[in]     workspace            Workspace tensor.
  *  \param[in]     stream               CUDA stream used for this operation.
  */
@@ -258,7 +269,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
                                    const NVTETensor cu_seqlens_padded, size_t max_seqlen,
                                    float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
                                    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-                                   NVTETensor workspace, cudaStream_t stream);
+                                   int64_t window_size_left, int64_t window_size_right,
+                                   bool deterministic, NVTETensor workspace, cudaStream_t stream);
 /*! \brief Compute dot product attention with packed KV input.
  *
@@ -310,6 +322,9 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
  *  \param[in]     qkv_layout           QKV tensor's layout.
  *  \param[in]     bias_type            Bias type.
  *  \param[in]     attn_mask_type       Attention mask type.
+ *  \param[in]     window_size_left     Sliding window size (the left half).
+ *  \param[in]     window_size_right    Sliding window size (the right half).
+ *  \param[in]     deterministic        Whether to execute with deterministic behaviours.
  *  \param[in]     workspace            Workspace tensor.
  *  \param[in]     stream               CUDA stream used for this operation.
  */
@@ -321,6 +336,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
                                   size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
                                   float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
                                   NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
+                                  int64_t window_size_left, int64_t window_size_right,
                                   NVTETensor workspace, cudaStream_t stream);
 /*! \brief Compute the backward of the dot product attention with packed KV input.
@@ -369,6 +385,9 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
  *  \param[in]     qkv_layout           QKV tensor's layout.
  *  \param[in]     bias_type            Bias type.
  *  \param[in]     attn_mask_type       Attention mask type.
+ *  \param[in]     window_size_left     Sliding window size (the left half).
+ *  \param[in]     window_size_right    Sliding window size (the right half).
+ *  \param[in]     deterministic        Whether to execute with deterministic behaviours.
  *  \param[in]     workspace            Workspace tensor.
  *  \param[in]     stream               CUDA stream used for this operation.
  */
@@ -379,7 +398,8 @@ void nvte_fused_attn_bwd_kvpacked(
     const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
     size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    NVTETensor workspace, cudaStream_t stream);
+    int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace,
+    cudaStream_t stream);
 /*! \brief Compute dot product attention with separate Q, K and V.
  *
@@ -435,6 +455,8 @@ void nvte_fused_attn_bwd_kvpacked(
  *  \param[in]     qkv_layout           QKV tensors' layout.
  *  \param[in]     bias_type            Bias type.
  *  \param[in]     attn_mask_type       Attention mask type.
+ *  \param[in]     window_size_left     Sliding window size (the left half).
+ *  \param[in]     window_size_right    Sliding window size (the right half).
  *  \param[in]     workspace            Workspace tensor.
  *  \param[in]     stream               CUDA stream used for this operation.
  */
@@ -446,7 +468,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
                          size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
                          float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
                          NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-                         NVTETensor workspace, cudaStream_t stream);
+                         int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
+                         cudaStream_t stream);
 /*! \brief Compute the backward of the dot product attention with separate Q, K and V.
  *
@@ -499,6 +522,9 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
  *  \param[in]     qkv_layout           QKV tensors' layout.
  *  \param[in]     bias_type            Bias type.
  *  \param[in]     attn_mask_type       Attention mask type.
+ *  \param[in]     window_size_left     Sliding window size (the left half).
+ *  \param[in]     window_size_right    Sliding window size (the right half).
+ *  \param[in]     deterministic        Whether to execute with deterministic behaviours.
  *  \param[in]     workspace            Workspace tensor.
  *  \param[in]     stream               CUDA stream used for this operation.
  */
@@ -510,7 +536,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
                          const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
                          size_t max_seqlen_kv, float attn_scale, float dropout,
                          NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                         NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream);
+                         NVTE_Mask_Type attn_mask_type, int64_t window_size_left,
+                         int64_t window_size_right, bool deterministic, NVTETensor workspace,
+                         cudaStream_t stream);
 #ifdef __cplusplus
 }  // extern "C"
...
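Taken together, the C API changes follow one pattern: every forward entry point gains window_size_left and window_size_right immediately before workspace, and every backward entry point additionally gains deterministic. A hedged sketch of an updated call pair for the separate-Q/K/V variant (all handles are caller-provided placeholders, the window values are illustrative, and the argument order mirrors the declarations above and the JAX call sites below):

nvte_fused_attn_fwd(Q, K, V, Bias, S, O, &aux_ctx_tensors, cu_seqlens_q, cu_seqlens_kv,
                    cu_seqlens_q_padded, cu_seqlens_kv_padded, rng_state, max_seqlen_q,
                    max_seqlen_kv, /*is_training=*/true, attn_scale, dropout, qkv_layout,
                    bias_type, attn_mask_type,
                    /*window_size_left=*/-1, /*window_size_right=*/0,  // illustrative values
                    workspace, stream);
nvte_fused_attn_bwd(Q, K, V, O, dO, S, dP, &aux_ctx_tensors, dQ, dK, dV, dBias, cu_seqlens_q,
                    cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, max_seqlen_q,
                    max_seqlen_kv, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
                    /*window_size_left=*/-1, /*window_size_right=*/0,
                    /*deterministic=*/true, workspace, stream);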
@@ -162,5 +162,6 @@ class DelayedScaling:
             f"format={str(self.fp8_format).split('.')[1]}, "
             f"amax_history_len={self.amax_history_len}, "
             f"wgrad_override={self.override_linear_precision.wgrad}, "
-            f"reduce_amax={self.reduce_amax}"
+            f"fp8_dpa={self.fp8_dpa}, "
+            f"fp8_mha={self.fp8_mha}"
         )
@@ -19,7 +19,7 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
   auto backend = nvte_get_fused_attn_backend(
       static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
       mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen,
-      head_dim);
+      head_dim, -1, -1);
   return backend;
 }
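Note on the framework bindings that follow: neither the JAX nor the Paddle layer exposes sliding windows yet, so every call site passes the -1, -1 sentinel (full window, unchanged behavior), and the backward calls pass a literal true for the new deterministic argument, as visible below.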
@@ -154,22 +154,22 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
                                   q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
                                   dummy_rng_state_tensor.data(), q_max_seqlen, is_training,
                                   scaling_factor, dropout_probability, qkv_layout, bias_type,
-                                  mask_type, query_workspace_tensor.data(), nullptr);
+                                  mask_type, -1, -1, query_workspace_tensor.data(), nullptr);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
     nvte_fused_attn_fwd_kvpacked(
         q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
         &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
         ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_rng_state_tensor.data(),
         q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
-        bias_type, mask_type, query_workspace_tensor.data(), nullptr);
+        bias_type, mask_type, -1, -1, query_workspace_tensor.data(), nullptr);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
     nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
                         s_tensor.data(), o_tensor.data(), &aux_output_tensors,
                         q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
                         ragged_offset_tensor.data(), ragged_offset_tensor.data(),
                         dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
-                        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
-                        query_workspace_tensor.data(), nullptr);
+                        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
+                        -1, query_workspace_tensor.data(), nullptr);
   } else {
     NVTE_ERROR("Unsupported QKVLayout.");
   }
@@ -258,7 +258,7 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
   auto backend =
       nvte_get_fused_attn_backend(static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype),
                                   qkv_layout, bias_type, mask_type, dropout_probability, attn_heads,
-                                  num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim);
+                                  num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim, -1, -1);
   PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
   /* Auxiliary tensors (to be propagated to the backward pass later) */
@@ -277,11 +277,12 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
     auto qkv = buffers[0];
     auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
     auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
-    nvte_fused_attn_fwd_qkvpacked(
-        qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
-        &aux_output_tensors, q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
-        rng_state_tensor.data(), q_max_seqlen, is_training, descriptor.scaling_factor,
-        dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
+    nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
+                                  o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
+                                  q_seq_offsets_tensor.data(), rng_state_tensor.data(),
+                                  q_max_seqlen, is_training, descriptor.scaling_factor,
+                                  dropout_probability, qkv_layout, bias_type, mask_type, -1, -1,
+                                  workspace_tensor.data(), stream);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
     auto q = buffers[0];
     auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
@@ -294,7 +295,7 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
         &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
         q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(),
         q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
-        bias_type, mask_type, workspace_tensor.data(), stream);
+        bias_type, mask_type, -1, -1, workspace_tensor.data(), stream);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
     auto q = buffers[0];
     auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
@@ -305,12 +306,13 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
     auto v = buffers[2];
     auto v_shape = k_shape;
     auto v_tensor = TensorWrapper(v, v_shape, dtype);
-    nvte_fused_attn_fwd(
-        q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(),
-        o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
-        kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
-        rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
-        dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
+    nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
+                        s_tensor.data(), o_tensor.data(), &aux_output_tensors,
+                        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
+                        q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
+                        rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
+                        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
+                        -1, workspace_tensor.data(), stream);
   } else {
     NVTE_ERROR("Unsupported qkv_layout.");
   }
@@ -373,13 +375,13 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
   auto dummy_ragged_offset_tensor =
       TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
   if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
-    nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
-                                  s_tensor.data(),  // not used for F16
-                                  s_tensor.data(),  // not used for F16
-                                  &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
-                                  q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
-                                  q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
-                                  bias_type, mask_type, query_workspace_tensor.data(), nullptr);
+    nvte_fused_attn_bwd_qkvpacked(
+        qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
+        s_tensor.data(),  // not used for F16
+        s_tensor.data(),  // not used for F16
+        &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
+        dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability,
+        qkv_layout, bias_type, mask_type, -1, -1, true, query_workspace_tensor.data(), nullptr);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
     nvte_fused_attn_bwd_kvpacked(
         q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
@@ -388,8 +390,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
         &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
         q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
         dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen,
-        kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
-        query_workspace_tensor.data(), nullptr);
+        kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
+        -1, true, query_workspace_tensor.data(), nullptr);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
     nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
                         doutput_tensor.data(),
@@ -399,8 +401,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
                         dbias_tensor.data(), q_cu_seqlens_tensor.data(),
                         kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
                         dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
-                        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
-                        query_workspace_tensor.data(), nullptr);
+                        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
+                        -1, true, query_workspace_tensor.data(), nullptr);
   } else {
     NVTE_ERROR("Unsupported qkv_layout.");
   }
@@ -487,7 +489,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
   auto backend =
       nvte_get_fused_attn_backend(static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype),
                                   qkv_layout, bias_type, mask_type, dropout_probability, attn_heads,
-                                  num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim);
+                                  num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim, -1, -1);
   PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
                                      rng_state, bias);
@@ -509,13 +511,13 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
          std::accumulate(qkv_shape.cbegin(), qkv_shape.cend(), 1, std::multiplies<size_t>());
       cudaMemsetAsync(dqkv, 0, dqkv_size * typeToSize(dtype), stream);
     }
-    nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
-                                  s_tensor.data(),  // not used for F16
-                                  s_tensor.data(),  // not used for F16
-                                  &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
-                                  q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
-                                  q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
-                                  bias_type, mask_type, workspace_tensor.data(), stream);
+    nvte_fused_attn_bwd_qkvpacked(
+        qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
+        s_tensor.data(),  // not used for F16
+        s_tensor.data(),  // not used for F16
+        &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
+        q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
+        bias_type, mask_type, -1, -1, true, workspace_tensor.data(), stream);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
     auto q = buffers[0];
     auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
@@ -542,7 +544,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
         &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
         q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
         k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
-        dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
+        dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, true,
+        workspace_tensor.data(), stream);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
     auto q = buffers[0];
     auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
@@ -577,7 +580,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
                         dbias_tensor.data(), q_cu_seqlens_tensor.data(),
                         kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
                         k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
-                        dropout_probability, qkv_layout, bias_type, mask_type,
+                        dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, true,
                         workspace_tensor.data(), stream);
   } else {
     NVTE_ERROR("Unsupported qkv_layout.");
...
@@ -134,7 +134,7 @@ inline NVTE_Fused_Attn_Backend get_fused_attn_backend(
   NVTE_Fused_Attn_Backend fused_attention_backend =
       nvte_get_fused_attn_backend(static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype),
                                   qkv_layout, bias_type, attn_mask_type, p_dropout, num_attn_heads,
-                                  num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim);
+                                  num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, -1, -1);
   return fused_attention_backend;
 }
...
@@ -673,7 +673,7 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
                                &nvte_aux_tensor_pack, te_cu_seqlens.data(),
                                dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen,
                                is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
-                               attn_mask_type_enum, workspace.data(), QKV.stream());
+                               attn_mask_type_enum, -1, -1, workspace.data(), QKV.stream());
   // allocate memory for workspace and auxiliary output tensors
   auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place());
@@ -687,7 +687,7 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
                                &nvte_aux_tensor_pack, te_cu_seqlens.data(),
                                dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen,
                                is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
-                               attn_mask_type_enum, workspace.data(), QKV.stream());
+                               attn_mask_type_enum, -1, -1, workspace.data(), QKV.stream());
   // destroy tensor wrappers, but not allocated memory
   nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
@@ -758,7 +758,7 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
                                &nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(),
                                te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen,
                                attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
-                               attn_mask_type_enum, workspace.data(), QKV.stream());
+                               attn_mask_type_enum, -1, -1, true, workspace.data(), QKV.stream());
   // allocate memory for workspace
   auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place());
@@ -769,7 +769,7 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
                                &nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(),
                                te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen,
                                attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
-                               attn_mask_type_enum, workspace.data(), QKV.stream());
+                               attn_mask_type_enum, -1, -1, true, workspace.data(), QKV.stream());
   // destroy tensor wrappers
   nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
@@ -840,12 +840,12 @@ void te_fused_attn_fwd_kvpacked(
   auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
   // populate tensors with appropriate shapes and dtypes
-  nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(),
-                               &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
-                               te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
-                               dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q,
+  nvte_fused_attn_fwd_kvpacked(
+      te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
+      te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
+      dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream()); workspace.data(), Q.stream());
// allocate memory for workspace and auxiliary output tensors // allocate memory for workspace and auxiliary output tensors
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
...@@ -855,12 +855,12 @@ void te_fused_attn_fwd_kvpacked( ...@@ -855,12 +855,12 @@ void te_fused_attn_fwd_kvpacked(
output_s->data.dptr = GetOptionalDataPtr(softmax_aux); output_s->data.dptr = GetOptionalDataPtr(softmax_aux);
// execute the kernel // execute the kernel
nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), nvte_fused_attn_fwd_kvpacked(
&nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q, dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream()); workspace.data(), Q.stream());
// destroy tensor wrappers, but not allocated memory // destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
...@@ -935,24 +935,24 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K ...@@ -935,24 +935,24 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32); auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes // populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), nvte_fused_attn_bwd_kvpacked(
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream()); -1, -1, true, workspace.data(), Q.stream());
// allocate memory for workspace // allocate memory for workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
// execute kernel // execute kernel
nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), nvte_fused_attn_bwd_kvpacked(
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream()); -1, -1, true, workspace.data(), Q.stream());
// destroy tensor wrappers // destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
...@@ -1042,7 +1042,7 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p ...@@ -1042,7 +1042,7 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale,
p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1,
workspace.data(), Q.stream()); workspace.data(), Q.stream());
// allocate memory for workspace and auxiliary output tensors // allocate memory for workspace and auxiliary output tensors
...@@ -1058,7 +1058,7 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p ...@@ -1058,7 +1058,7 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale,
p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1,
workspace.data(), Q.stream()); workspace.data(), Q.stream());
// destroy tensor wrappers, but not allocated memory // destroy tensor wrappers, but not allocated memory
...@@ -1140,7 +1140,7 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p ...@@ -1140,7 +1140,7 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), Q.stream()); attn_mask_type_enum, -1, -1, true, workspace.data(), Q.stream());
// allocate memory for workspace // allocate memory for workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
...@@ -1152,7 +1152,7 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p ...@@ -1152,7 +1152,7 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), Q.stream()); attn_mask_type_enum, -1, -1, true, workspace.data(), Q.stream());
// destroy tensor wrappers // destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......
@@ -12,6 +12,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import warnings
import logging
+from dataclasses import dataclass, fields

import numpy as np
from packaging.version import Version as PkgVersion
@@ -103,45 +104,24 @@ logging.basicConfig(
    level=log_levels[log_level if log_level in [0, 1, 2] else 2],
)

-_alibi_cache = {
-    "_num_heads": None,
-    "_alibi_slopes": None,
-    "_max_seqlen_q": None,
-    "_max_seqlen_kv": None,
-    "_alibi_bias": None,
-    "_alibi_slopes_require_update": False,
-    "_alibi_bias_require_update": False,
-}
+_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
+_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
+_NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))

+_attention_backends = {
+    "attention_params": None,
+    "use_flash_attention": None,
+    "use_fused_attention": None,
+    "fused_attention_backend": None,
+    "use_unfused_attention": None,
+    "backend_selection_requires_update": False,
+}
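Because the three `NVTE_*` variables are re-read on every call to `get_attention_backend` (see the environment-variable filter below), a backend can be ruled out per-process without any code change. A minimal sketch:

```python
import os

# Rule out flash-attn for this process; the fused and unfused backends
# remain candidates. get_attention_backend() re-reads these on each call.
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "1"
```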
-__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]
-
-
-def get_attention_backend(
-    qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor,
-    qkv_dtype: torch.dtype = torch.bfloat16,
-    qkv_layout: str = "sbh3d",
-    batch_size: int = 1,
-    num_heads: int = 16,
-    num_gqa_groups: int = 16,
-    max_seqlen_q: int = 128,
-    max_seqlen_kv: int = 128,
-    head_dim: int = 64,
-    attn_mask_type: str = "no_mask",
-    window_size: Tuple[int, int] = (-1, -1),
-    alibi_slopes_shape: Optional[Union[torch.Size, List]] = None,
-    core_attention_bias_type: str = "no_bias",
-    core_attention_bias_shape: str = "1hss",
-    core_attention_bias_requires_grad: bool = True,
-    pad_between_seqs: bool = False,
-    attention_dropout: float = 0.0,
-    context_parallel: bool = False,
-    deterministic: bool = False,
-    fp8: bool = False,
-    fp8_meta: Optional[Dict[str, Any]] = None,
-):
+@dataclass(eq=True)
+class AttentionParams:
    """
-   Select an attention backend based on the user input and runtime environment.
+   Attention parameters used to determine which backend to use.

    Parameters
    ----------
@@ -166,7 +146,7 @@ def get_attention_backend(
    attn_mask_type: str, default = `no_mask`
        Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`,
        `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
-   window_size: Tuple[int, int], default = (-1, -1)
+   window_size: Tuple[int, int], default = None
        Sliding window attention size.
    alibi_slopes_shape: Optional[Union[torch.Size, List]], default = `None`
        Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`.
@@ -185,10 +165,62 @@ def get_attention_backend(
        Whether context parallelism is used or not.
    deterministic: bool, default = `False`
        Whether to run `DotProductAttention` with determinism or not.
+   is_training: bool, default = `True`
+       Whether in training mode (`True`) or inference mode (`False`).
    fp8: bool, default = `False`
        Whether `DotProductAttention` is in an `fp8_autocast` region.
    fp8_meta: Optional[Dict[str, Any]], default = `None`
        The FP8 metadata tensor of `DotProductAttention`.
+   """
+
+   qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
+   qkv_dtype: torch.dtype = torch.bfloat16
+   qkv_layout: str = "sbh3d"
+   batch_size: int = 1
+   num_heads: int = 16
+   num_gqa_groups: int = 16
+   max_seqlen_q: int = 128
+   max_seqlen_kv: int = 128
+   head_dim: int = 64
+   attn_mask_type: str = "no_mask"
+   window_size: Union[Tuple[int, int], None] = None
+   alibi_slopes_shape: Union[torch.Size, List, None] = None
+   core_attention_bias_type: str = "no_bias"
+   core_attention_bias_shape: str = "1hss"
+   core_attention_bias_requires_grad: bool = True
+   pad_between_seqs: bool = False
+   attention_dropout: float = 0.0
+   context_parallel: bool = False
+   deterministic: bool = False
+   is_training: bool = True
+   fp8: bool = False
+   fp8_meta: Union[Dict[str, Any], None] = None
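Since `AttentionParams` is declared with `@dataclass(eq=True)`, two parameter sets compare field by field, which is what lets the cached backend selection above be invalidated only when something actually changes. A small sketch with hypothetical values:

```python
a = AttentionParams(max_seqlen_q=128, max_seqlen_kv=128, window_size=(127, 0))
b = AttentionParams(max_seqlen_q=128, max_seqlen_kv=128, window_size=(127, 0))
assert a == b           # equal field by field
b.window_size = (63, 0)
assert a != b           # any changed field makes the cached selection stale
```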
+_alibi_cache = {
+    "_num_heads": None,
+    "_alibi_slopes": None,
+    "_max_seqlen_q": None,
+    "_max_seqlen_kv": None,
+    "_bottom_right_alignment": True,
+    "_alibi_bias": None,
+    "_alibi_slopes_require_update": False,
+    "_alibi_bias_require_update": False,
+}

+__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]

+def get_attention_backend(
+    attention_params: AttentionParams = None,
+):
+    """
+    Select the appropriate attention backend/sub-backend based on user input and runtime environment.
+
+    Parameters
+    ----------
+    See `AttentionParams`.

    Returns
    ----------
@@ -196,20 +228,66 @@ def get_attention_backend(
        Whether the `FlashAttention` backend has been selected.
    use_fused_attention: bool
        Whether the `FusedAttention` backend has been selected.
-   fused_attention_backend: tex.NVTE_Fused_Attn_Backend
-       If `use_fused_attention = True`, one of `FusedAttention` three sub-backends, else `None`.
    use_unfused_attention: bool
        Whether the `UnfusedDotProductAttention` backend has been selected.
+   fused_attention_backend: tex.NVTE_Fused_Attn_Backend
+       If `use_fused_attention = True`, the `FusedAttention` sub-backend, else `None`.
    available_backends: List[bool]
        All available backends that could support the provided input. A list of Booleans
        in the form of [use_flash_attention, use_fused_attention, use_unfused_attention].
    """
+   qkv_type = attention_params.qkv_type
+   qkv_dtype = attention_params.qkv_dtype
+   qkv_layout = attention_params.qkv_layout
+   batch_size = attention_params.batch_size
+   num_heads = attention_params.num_heads
+   num_gqa_groups = attention_params.num_gqa_groups
+   max_seqlen_q = attention_params.max_seqlen_q
+   max_seqlen_kv = attention_params.max_seqlen_kv
+   head_dim = attention_params.head_dim
+   attn_mask_type = attention_params.attn_mask_type
+   window_size = attention_params.window_size
+   alibi_slopes_shape = attention_params.alibi_slopes_shape
+   core_attention_bias_type = attention_params.core_attention_bias_type
+   core_attention_bias_shape = attention_params.core_attention_bias_shape
+   core_attention_bias_requires_grad = attention_params.core_attention_bias_requires_grad
+   pad_between_seqs = attention_params.pad_between_seqs
+   attention_dropout = attention_params.attention_dropout
+   context_parallel = attention_params.context_parallel
+   deterministic = attention_params.deterministic
+   is_training = attention_params.is_training
+   fp8 = attention_params.fp8
+   fp8_meta = attention_params.fp8_meta

+   # Run config
    logger = logging.getLogger("DotProductAttention")
+   device_compute_capability = get_device_compute_capability()
+   cudnn_version = get_cudnn_version()
+   run_config = {
+       "transformer_engine_version": te.__version__,
+       "compute_capability": "sm"
+       + str(
+           (lambda x, y: x * 10 + y)(device_compute_capability[0], device_compute_capability[1])
+       ),
+       "flash_attn_version": _flash_attn_version,
+       "cudnn_version": ".".join([str(i) for i in cudnn_version]),
+   }
+   attention_params_dict = {
+       field.name: getattr(attention_params, field.name) for field in fields(attention_params)
+   }
+   run_config.update(attention_params_dict)
+   if fp8:
+       run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
+   logger.debug("Running with config=%s", run_config)
    # Filter: Environment variables
-   use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1"))
-   use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1"))
-   use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
+   global _NVTE_FLASH_ATTN, _NVTE_FUSED_ATTN, _NVTE_UNFUSED_ATTN
+   _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
+   _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
+   _NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
+   use_flash_attention = _NVTE_FLASH_ATTN
+   use_fused_attention = _NVTE_FUSED_ATTN
+   use_unfused_attention = _NVTE_UNFUSED_ATTN
    if not use_flash_attention:
        logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
    if not use_fused_attention:
@@ -227,7 +305,6 @@ def get_attention_backend(
        use_fused_attention = False

    # Filter: Compute capability
-   device_compute_capability = get_device_compute_capability()
    if device_compute_capability < (8, 0):
        if use_flash_attention:
            logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
@@ -325,12 +402,6 @@ def get_attention_backend(
    if use_unfused_attention and "padding" in attn_mask_type:
        logger.debug("Disabling UnfusedDotProductAttention for %s mask", attn_mask_type)
        use_unfused_attention = False
-   if use_unfused_attention and attn_mask_type == "causal" and max_seqlen_q != max_seqlen_kv:
-       logger.debug(
-           "Disabling UnfusedDotProductAttention for "
-           "top-left-diagonal causal masks for cross-attention"
-       )
-       use_unfused_attention = False
    if (
        use_flash_attention
        and _flash_attn_2_1_plus
@@ -357,17 +428,58 @@ def get_attention_backend(
        use_flash_attention = False
    # Filter: Sliding window attention
-   if window_size is not None and window_size[0] != -1 and window_size[1] not in [-1, 0]:
-       if use_unfused_attention:
-           logger.debug(
-               "Disabling UnfusedDotProductAttention as "
-               "it does not support sliding window attention"
-           )
-           use_unfused_attention = False
-       if use_fused_attention:
-           logger.debug("Disabling FusedAttention as it does not support sliding window attention")
-           use_fused_attention = False
-       if use_flash_attention and (not _flash_attn_2_3_plus or context_parallel):
+   #    backend                 |      window_size       | diagonal alignment
+   # ---------------------------------------------------------------------------------
+   # FlashAttention             | (-1, -1) or (>=0, >=0) | bottom right
+   # FusedAttention             | (-1,  0) or (>=0,  0)  | top left
+   # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both;
+   #                            |                        | converts window_size to an 'arbitrary' mask
+   if window_size is None:
+       window_size = check_set_window_size(attn_mask_type, window_size)
+   else:
+       if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
+           if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha):
+               logger.debug(
+                   "Disabling FusedAttention as it does not support sliding window attention"
+                   " for FP8"
+               )
+               use_fused_attention = False
+           elif window_size[1] != 0 or attention_dropout != 0.0 or qkv_format == "thd":
+               logger.debug(
+                   "Disabling FusedAttention as it only supports sliding window attention "
+                   "with causal mask, no dropout, and qkv_format = bshd/sbhd"
+               )
+               use_fused_attention = False
+           elif context_parallel:
+               logger.debug(
+                   "Disabling FusedAttention as it does not support sliding window attention "
+                   "with context parallelism"
+               )
+               use_fused_attention = False
+           elif max_seqlen_q != max_seqlen_kv and attn_mask_type in [
+               "no_mask",
+               "padding",
+               "causal_bottom_right",
+               "padding_causal_bottom_right",
+           ]:
+               logger.debug(
+                   "Disabling FusedAttention as it does not support sliding window attention "
+                   "with attn_mask_type = %s for cross-attention",
+                   attn_mask_type,
+               )
+               use_fused_attention = False
+           elif "padding" in attn_mask_type:
+               logger.debug(
+                   "Disabling FusedAttention as it does not support sliding window attention "
+                   "with attn_mask_type = %s",
+                   attn_mask_type,
+               )
+               use_fused_attention = False
+       if (
+           use_flash_attention
+           and (window_size[0] != -1 or window_size[1] not in [-1, 0])
+           and (not _flash_attn_2_3_plus or context_parallel)
+       ):
            logger.debug(
                "Disabling FlashAttention as sliding window attention requires "
                "flash-attn 2.3+ and no context parallelism"
@@ -375,6 +487,14 @@ def get_attention_backend(
            use_flash_attention = False
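The window convention follows flash-attn: query `i` sees keys in `[i + seqlen_kv - seqlen_q - window_size[0], i + seqlen_kv - seqlen_q + window_size[1]]`, inclusive, with `-1` meaning unbounded on that side. A standalone sketch of that arithmetic (hypothetical helper, not part of this diff):

```python
def visible_key_range(i, seqlen_q, seqlen_kv, window_size):
    """Inclusive [lo, hi] key range for query i; -1 means unbounded on that side."""
    diag = i + seqlen_kv - seqlen_q  # the aligned diagonal
    lo = 0 if window_size[0] == -1 else max(0, diag - window_size[0])
    hi = seqlen_kv - 1 if window_size[1] == -1 else min(seqlen_kv - 1, diag + window_size[1])
    return lo, hi

# causal sliding window of width 4: query 10 sees keys 7..10 when seqlen_q == seqlen_kv
assert visible_key_range(10, 16, 16, (3, 0)) == (7, 10)
```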
    # Filter: Attention bias
+   #    backend                 |      bias types              | ALiBi diagonal alignment
+   # ---------------------------------------------------------------------------------
+   # FlashAttention             | no_bias, alibi/alibi_slopes  | bottom right
+   # FusedAttention             | no_bias, post_scale_bias     |
+   #                            | alibi/alibi_slopes           | top left,
+   #                            |                              | bottom_right (converts to a 'post_scale_bias' bias)
+   # UnfusedDotProductAttention | no_bias, pre/post_scale_bias |
+   #                            | alibi/alibi_slopes           | both; converts to a 'post_scale_bias' bias
    if use_flash_attention and (
        core_attention_bias_type not in ["no_bias", "alibi"]
        or core_attention_bias_shape is not None
@@ -388,18 +508,20 @@ def get_attention_backend(
    if (
        use_fused_attention
        and core_attention_bias_type == "alibi"
-       and alibi_slopes_shape is not None
+       and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv)
    ):
        fu_core_attention_bias_type = "post_scale_bias"
        fu_core_attention_bias_requires_grad = False
-       if (
+       if alibi_slopes_shape is None:
+           fu_core_attention_bias_shape = "1hss"
+       elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
+           fu_core_attention_bias_shape = "1hss"
+       elif (
            len(alibi_slopes_shape) == 2
            and alibi_slopes_shape[0] == batch_size
            and alibi_slopes_shape[1] == num_heads
        ):
            fu_core_attention_bias_shape = "bhss"
-       elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
-           fu_core_attention_bias_shape = "1hss"
    if (
        use_fused_attention
@@ -435,14 +557,41 @@ def get_attention_backend(
            max_seqlen_q,
            max_seqlen_kv,
            head_dim,
+           window_size[0],
+           window_size[1],
        )
-       if fused_attention_backend == FusedAttnBackend["No_Backend"] or (
-           context_parallel and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"]
-       ):
+       if fused_attention_backend == FusedAttnBackend["No_Backend"]:
            logger.debug("Disabling FusedAttention as no backend supports the provided input")
            use_fused_attention = False
-       elif (
-           fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
+           fused_attention_backend = None
+       if (
+           use_fused_attention
+           and context_parallel
+           and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"]
+       ):
+           logger.debug(
+               "Disabling FusedAttention as sub-backend %s does not support "
+               "context parallelism",
+               int(fused_attention_backend),
+           )
+           use_fused_attention = False
+           fused_attention_backend = None
+       if (
+           use_fused_attention
+           and window_size is not None
+           and window_size[0] != -1
+           and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"]
+       ):
+           logger.debug(
+               "Disabling FusedAttention as sub-backend %s does not support "
+               "sliding window attention",
+               int(fused_attention_backend),
+           )
+           use_fused_attention = False
+           fused_attention_backend = None
+       if (
+           use_fused_attention
+           and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
            and fu_core_attention_bias_type == "post_scale_bias"
            and fu_core_attention_bias_shape != "1hss"
        ):
@@ -451,6 +600,7 @@ def get_attention_backend(
                " [1, H, S, S] shape"
            )
            use_fused_attention = False
+           fused_attention_backend = None
    # Filter: Determinism
    # backend | deterministic
@@ -471,22 +621,36 @@ def get_attention_backend(
            "please install flash-attn >= 2.4.1."
        )
        use_flash_attention = False
-   if (
-       use_fused_attention
-       and (
-           fused_attention_backend == FusedAttnBackend["FP8"]
-           or (
-               fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
-               and device_compute_capability < (9, 0)
-           )
-       )
-       and deterministic
-   ):
-       logger.debug("Disabling FusedAttention for determinism reasons")
-       use_fused_attention = False
+   if use_fused_attention and deterministic:
+       if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
+           logger.debug("Disabling FusedAttention for determinism reasons")
+           use_fused_attention = False
+       if (
+           fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
+           and is_training
+           and (
+               device_compute_capability < (9, 0)
+               or core_attention_bias_requires_grad
+               or cudnn_version < (8, 9, 5)
+           )
+       ):
+           logger.debug("Disabling FusedAttention for determinism reasons")
+           use_fused_attention = False
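Either of the two switches below marks the module as deterministic; the constructor checks both the environment variable and `torch.are_deterministic_algorithms_enabled()` (see the `DotProductAttention.__init__` change later in this diff). A minimal sketch:

```python
import os
import torch

os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"  # TE-specific knob
torch.use_deterministic_algorithms(True)              # or the global PyTorch switch
```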
    # All available backends
    available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention]
+   logger.debug(
+       "Available backends = {FlashAttention=%s, FusedAttention=%s%s,"
+       " UnfusedDotProductAttention=%s}",
+       bool(available_backends[0]),
+       bool(available_backends[1]),
+       (
+           f" (sub-backend {int(fused_attention_backend)})"
+           if fused_attention_backend is not None
+           else ""
+       ),
+       bool(available_backends[2]),
+   )

    # Select FusedAttention for performance
    if (
@@ -507,29 +671,28 @@ def get_attention_backend(
        use_unfused_attention = False
    elif use_fused_attention:
        use_unfused_attention = False

+   selected_backend = "NoBackend"
    if use_flash_attention:
        selected_backend = "FlashAttention"
    elif use_fused_attention:
        selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})"
    elif use_unfused_attention:
        selected_backend = "UnfusedDotProductAttention"
-   else:
-       selected_backend = "NoBackend"
-   logger.debug(
-       "Available backends: FlashAttention=%s, FusedAttention=%s, UnfusedDotProductAttention=%s",
-       bool(available_backends[0]),
-       bool(available_backends[1]),
-       bool(available_backends[2]),
-   )
-   logger.debug("Selected backend: %s", selected_backend)
+   logger.debug("Selected backend = %s", selected_backend)

+   global _attention_backends
+   _attention_backends["use_flash_attention"] = use_flash_attention
+   _attention_backends["use_fused_attention"] = use_fused_attention
+   _attention_backends["fused_attention_backend"] = fused_attention_backend
+   _attention_backends["use_unfused_attention"] = use_unfused_attention
+   _attention_backends["backend_selection_requires_update"] = False

    return (
        use_flash_attention,
        use_fused_attention,
-       fused_attention_backend,
        use_unfused_attention,
        available_backends,
+       fused_attention_backend,
    )
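Putting the pieces together, a caller-side sketch of selection plus caching, under the names defined in this diff (`get_attention_backend` itself fills `_attention_backends` and clears the update flag; note the new return order, with `fused_attention_backend` last):

```python
params = AttentionParams(num_heads=16, num_gqa_groups=16, max_seqlen_q=512, max_seqlen_kv=512)
if _attention_backends["attention_params"] != params:
    _attention_backends["attention_params"] = params
    _attention_backends["backend_selection_requires_update"] = True
if _attention_backends["backend_selection_requires_update"]:
    use_flash, use_fused, use_unfused, available, fused_backend = get_attention_backend(params)
```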
@@ -579,6 +742,64 @@ class InferenceParams: # pylint: disable=too-few-public-methods
    )

+@torch.no_grad()
+def get_swa_mask(
+    window_size: Tuple[int, int],
+    max_seqlen_q: int,
+    max_seqlen_kv: int,
+    attn_mask_type: str = "no_mask",
+    attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
+) -> torch.Tensor:
+    """
+    Convert sliding window `window_size` to an equivalent "`arbitrary`" mask.
+    For "`causal`" mask type, the sliding window diagonal is aligned to the top left corner,
+    and for other mask types, the bottom right corner.
+
+    Parameters
+    ----------
+    window_size: Tuple[int, int]
+        Sliding window size for local attention, where query at position i attends to keys
+        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+        + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
+        window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
+        map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
+        `attn_mask_type`.
+    max_seqlen_q: int
+        Maximum sequence length for queries.
+    max_seqlen_kv: int
+        Maximum sequence length for keys and values.
+    attn_mask_type: str, default = `no_mask`
+        Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`",
+        "`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"}
+    attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
+        default = `None`
+        Boolean tensor(s) used to mask out attention softmax input.
+
+    Returns
+    ----------
+    attention_mask: torch.Tensor
+        Combined `attention_mask` (input) and sliding window attention mask.
+        The shape is [max_seqlen_q, max_seqlen_kv] when input `attention_mask` is None;
+        else, the same shape as input `attention_mask`.
+    """
+    mask = torch.ones(max_seqlen_q, max_seqlen_kv, dtype=torch.bool, device="cuda")
+    if attn_mask_type in ["causal"]:
+        left = window_size[0] if window_size[0] != -1 else max_seqlen_q
+        right = window_size[1] if window_size[1] != -1 else max_seqlen_q
+        mask_upper = torch.triu(mask, diagonal=-left)
+        mask_lower = torch.tril(mask_upper, diagonal=right)
+    else:
+        left = window_size[0] if window_size[0] != -1 else max_seqlen_kv
+        right = window_size[1] if window_size[1] != -1 else max_seqlen_kv
+        mask_upper = torch.triu(mask, diagonal=max_seqlen_kv - max_seqlen_q - left)
+        mask_lower = torch.tril(mask_upper, diagonal=max_seqlen_kv - max_seqlen_q + right)
+    attn_mask_type = "arbitrary"
+    mask = mask_lower.logical_not()
+    if attention_mask is not None:
+        mask = torch.logical_and(attention_mask, mask)
+    return attn_mask_type, mask
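A quick sketch of the conversion for a 2-token causal window on a 4x4 matrix (assumes a CUDA device, since the helper allocates the mask there):

```python
attn_mask_type, mask = get_swa_mask(
    window_size=(1, 0), max_seqlen_q=4, max_seqlen_kv=4, attn_mask_type="causal"
)
# attn_mask_type is now "arbitrary"; mask[i, j] is True where attention is
# masked OUT, so each query i attends only to keys i-1 and i:
# [[False,  True,  True,  True],
#  [False, False,  True,  True],
#  [ True, False, False,  True],
#  [ True,  True, False, False]]
```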
@torch.no_grad()
def get_alibi(
    num_heads: int,
@@ -586,6 +807,7 @@ def get_alibi(
    max_seqlen_kv: int,
    alibi_slopes: Optional[torch.Tensor] = None,
    bias_dtype: Optional[torch.dtype] = None,
+   bottom_right_alignment: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Parameters
@@ -600,6 +822,9 @@ def get_alibi(
        Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads].
    bias_dtype: Optional[torch.dtype], default = `None`
        Dtype of the generated ALiBi bias. If None, use torch.float32.
+   bottom_right_alignment: bool, default = `True`
+       Whether to align the diagonal of the ALiBi bias to the bottom right corner of
+       the matrix (`True`) or top left (`False`).

    Returns
    ----------
@@ -635,15 +860,21 @@ def get_alibi(
        slopes_shape = torch.Size([1, _alibi_cache["_alibi_slopes"].shape[0], 1, 1])
        if _alibi_cache["_alibi_slopes"].dim() == 2:
            slopes_shape = torch.Size([*_alibi_cache["_alibi_slopes"].shape[:], 1, 1])
+       if bottom_right_alignment:
            bias = torch.arange(1 - max_seqlen_kv, 1, dtype=torch.int32, device="cuda").view(
                1, 1, 1, max_seqlen_kv
            )
+       else:
+           bias = torch.arange(
+               1 - max_seqlen_q, max_seqlen_kv - max_seqlen_q + 1, dtype=torch.int32, device="cuda"
+           ).view(1, 1, 1, max_seqlen_kv)
        bias = bias - torch.arange(1 - max_seqlen_q, 1, dtype=torch.int32, device="cuda").view(
            1, 1, max_seqlen_q, 1
        )
        bias = bias.abs().mul(-1)
        bias = bias * _alibi_cache["_alibi_slopes"].view(slopes_shape)
        _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv
+       _alibi_cache["_bottom_right_alignment"] = bottom_right_alignment
        bias_dtype = torch.float32 if bias_dtype is None else bias_dtype
        _alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda")
        _alibi_cache["_alibi_bias_require_update"] = False
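The two alignments only move where the zero of the linear bias sits; everything else is the same `-|distance| * slope` product. A standalone sketch mirroring the arithmetic above (pure torch, hypothetical names, runs on CPU):

```python
import torch

max_seqlen_q, max_seqlen_kv = 3, 5
q = torch.arange(1 - max_seqlen_q, 1).view(max_seqlen_q, 1)
# bottom-right alignment: zero on the diagonal ending at the bottom right corner
br = (torch.arange(1 - max_seqlen_kv, 1).view(1, max_seqlen_kv) - q).abs().mul(-1)
# top-left alignment: zero on the diagonal starting at the top left corner
tl = (
    torch.arange(1 - max_seqlen_q, max_seqlen_kv - max_seqlen_q + 1).view(1, max_seqlen_kv) - q
).abs().mul(-1)
assert br[-1, -1] == 0 and tl[0, 0] == 0
```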
@@ -2570,7 +2801,11 @@ class UnfusedDotProductAttention(torch.nn.Module):
            assert core_attention_bias is not None, "core_attention_bias should not be None!"
        if core_attention_bias_type == "alibi":
            _, core_attention_bias = get_alibi(
-               output_size[1], output_size[2], output_size[3], alibi_slopes=alibi_slopes
+               output_size[1],
+               output_size[2],
+               output_size[3],
+               alibi_slopes=alibi_slopes,
+               bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
            )
        matmul_result = torch.baddbmm(
            matmul_result,
@@ -2892,8 +3127,6 @@ class FlashAttention(torch.nn.Module):
    ) -> torch.Tensor:
        """flash-attn fprop"""

-       window_size = check_set_window_size(attn_mask_type, window_size)
        assert (
            query_layer.dtype in [torch.float16, torch.bfloat16]
            and key_layer.dtype in [torch.float16, torch.bfloat16]
@@ -3120,11 +3353,13 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
+       window_size,
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
+       deterministic,
    ):
        logger = logging.getLogger("FusedAttnFunc_qkvpacked")
        if fp8:
@@ -3168,6 +3403,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
                    qkv_layout,
                    attn_bias_type,
                    attn_mask_type,
+                   window_size,
                    rng_gen,
                )
            if fp8_meta["recipe"].fp8_mha:
@@ -3233,6 +3469,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
+               window_size,
                rng_gen,
            )
            fp8_tensors = (None, None, None, None)
@@ -3252,10 +3489,12 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
+       ctx.window_size = window_size
        ctx.fused_attention_backend = (
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        )
        ctx.use_FAv2_bwd = use_FAv2_bwd
+       ctx.deterministic = deterministic

        return out_ret
@@ -3357,6 +3596,8 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
+                       ctx.window_size,
+                       ctx.deterministic,
                    )
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        dqkv = Float8Tensor(
@@ -3409,6 +3650,8 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
+                       ctx.window_size,
+                       ctx.deterministic,
                    )

        # if no_bias or alibi, return dqkv
@@ -3434,6 +3677,8 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
                None,
                None,
                None,
+               None,
+               None,
            )
        # else, return (dqkv, dbias)
        return (
@@ -3457,6 +3702,8 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
            None,
            None,
            None,
+           None,
+           None,
        )
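The extra `None` slots exist because `torch.autograd.Function.backward` must return exactly one gradient per argument passed to `forward`; the new `window_size` and `deterministic` inputs are non-differentiable, so their slots are `None`. A toy illustration (hypothetical `Scale` function, not from this diff):

```python
import torch

class Scale(torch.autograd.Function):
    """Toy Function with two non-differentiable extra inputs, mirroring how
    window_size and deterministic ride along in FusedAttnFunc_qkvpacked."""

    @staticmethod
    def forward(ctx, x, scale, deterministic):
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        # One return slot per forward input: grad for x, None for the extras.
        return grad_out * ctx.scale, None, None

x = torch.ones(2, requires_grad=True)
Scale.apply(x, 3.0, True).sum().backward()
assert torch.equal(x.grad, torch.full((2,), 3.0))
```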
@@ -3483,11 +3730,13 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
+       window_size,
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
+       deterministic,
    ):
        logger = logging.getLogger("FusedAttnFunc_kvpacked")
        if fp8:
@@ -3540,6 +3789,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
                    qkv_layout,
                    attn_bias_type,
                    attn_mask_type,
+                   window_size,
                    rng_gen,
                )
            if fp8_meta["recipe"].fp8_mha:
@@ -3613,6 +3863,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
+               window_size,
                rng_gen,
            )
        out_save = out_ret
@@ -3639,10 +3890,12 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
+       ctx.window_size = window_size
        ctx.fused_attention_backend = (
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        )
        ctx.use_FAv2_bwd = use_FAv2_bwd
+       ctx.deterministic = deterministic

        return out_ret
@@ -3752,6 +4005,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
+                       ctx.window_size,
+                       ctx.deterministic,
                    )
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        dq = Float8Tensor(
@@ -3823,6 +4078,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
+                       ctx.window_size,
+                       ctx.deterministic,
                    )

        # if no_bias or alibi, return dqkv
@@ -3852,6 +4109,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
                None,
                None,
                None,
+               None,
+               None,
            )
        # else, return (dqkv, dbias)
        return (
@@ -3879,6 +4138,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
            None,
            None,
            None,
+           None,
+           None,
        )
@@ -3906,11 +4167,13 @@ class FusedAttnFunc(torch.autograd.Function):
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
+       window_size,
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
+       deterministic,
    ):
        logger = logging.getLogger("FusedAttnFunc")
        if fp8:
@@ -3985,6 +4248,7 @@ class FusedAttnFunc(torch.autograd.Function):
                    qkv_layout,
                    attn_bias_type,
                    attn_mask_type,
+                   window_size,
                    rng_gen,
                )
            if fp8_meta["recipe"].fp8_mha:
@@ -4108,6 +4372,7 @@ class FusedAttnFunc(torch.autograd.Function):
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
+               window_size,
                rng_gen,
            )
        out_save = out_ret
@@ -4143,10 +4408,12 @@ class FusedAttnFunc(torch.autograd.Function):
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
+       ctx.window_size = window_size
        ctx.fused_attention_backend = (
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        )
        ctx.use_FAv2_bwd = use_FAv2_bwd
+       ctx.deterministic = deterministic

        return out_ret
@@ -4261,6 +4528,8 @@ class FusedAttnFunc(torch.autograd.Function):
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
+                       ctx.window_size,
+                       ctx.deterministic,
                    )
                    if ctx.fp8_meta["recipe"].fp8_mha:
@@ -4385,6 +4654,8 @@ class FusedAttnFunc(torch.autograd.Function):
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
+                       ctx.window_size,
+                       ctx.deterministic,
                    )

        # if no_bias or alibi, return dqkv
@@ -4415,6 +4686,8 @@ class FusedAttnFunc(torch.autograd.Function):
                None,
                None,
                None,
+               None,
+               None,
            )
        # else, return (dqkv, dbias)
        return (
@@ -4443,6 +4716,8 @@ class FusedAttnFunc(torch.autograd.Function):
            None,
            None,
            None,
+           None,
+           None,
        )
@@ -4494,21 +4769,7 @@ class FusedAttention(torch.nn.Module):
            "NVTE_FUSED_ATTN_USE_FAv2_BWD", "0"
        ) == "1" and get_device_compute_capability() == (9, 0)
        self.layer_number = 1 if layer_number is None else layer_number
-       if deterministic:
-           # workspace optimization path is deterministic
-           os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"
-
-       # CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
-       # - unset: enables workspace optimization when required workspace is <= 256MB
-       #          or when bias gradient needs to be computed
-       # - n: enables workspace optimization when required workspace is <= n bytes
-       # - -1: enables workspace optimization always
-       # - 0: disables workspace optimization always
-       if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ:
-           if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0":
-               os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0"
-           if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
-               os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"
+       self.deterministic = deterministic

    def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
        """
@@ -4546,6 +4807,7 @@ class FusedAttention(torch.nn.Module):
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
+       window_size: Optional[Tuple[int, int]] = None,
        fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
@@ -4698,11 +4960,13 @@ class FusedAttention(torch.nn.Module):
                qkv_layout,
                core_attention_bias_type,
                attn_mask_type,
+               window_size,
                None,  # rng_gen
                fused_attention_backend,
                use_FAv2_bwd,
                fp8,
                fp8_meta,
+               self.deterministic,
            )

        # ...hd -> ...(hd)
@@ -4772,7 +5036,9 @@ class DotProductAttention(TransformerEngineBaseModule):
                  sliding window size for local attention, where query at position i attends to keys
                  in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                  + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
-                 window and "`causal`" mask specifically. Similar to :attr:`attn_mask_type`, it can
+                 window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
+                 map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
+                 `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
                  be overridden by :attr:`window_size` in `forward` as well.
    attention_type: str, default = `self`
                   type of attention, either "`self`" or "`cross`".
@@ -4875,35 +5141,29 @@ class DotProductAttention(TransformerEngineBaseModule):
        if softmax_scale is None:
            softmax_scale = 1.0 / math.sqrt(kv_channels)

-       self.device_compute_capability = get_device_compute_capability()
        self.deterministic = (
            not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
            or torch.are_deterministic_algorithms_enabled()
        )
+       # To use the workspace optimization path for determinism, please
+       # set NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT=1 for cuDNN >=8.9.5 and <9.0.0,
+       # and set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 for cuDNN >=9.0.0.
+       cudnn_version = get_cudnn_version()
+       if (8, 9, 5) <= cudnn_version < (9, 0, 0):
+           if self.deterministic:
+               os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1"

-       self.use_flash_attention = int(
-           os.getenv("NVTE_FLASH_ATTN", "1")
-       ) and self.device_compute_capability >= (8, 0)
-       if int(os.getenv("NVTE_FLASH_ATTN", "1")) == 0:
-           self.logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
-       if self.device_compute_capability < (8, 0):
-           self.logger.debug("Disabling FlashAttention for compute capability < sm80")
-
-       if not _flash_attn_2_4_1_plus and self.deterministic:
-           self.use_flash_attention = False
-           self.logger.warning(
-               "Disabling usage of FlashAttention since version <2.4.1 does not support "
-               "deterministic execution. In order to use FA with deterministic behavior,"
-               " please install FlashAttention version >=2.4.1."
-           )
-
-       self.use_fused_attention = int(
-           os.getenv("NVTE_FUSED_ATTN", "1")
-       ) and self.device_compute_capability >= (8, 0)
-       if int(os.getenv("NVTE_FUSED_ATTN", "1")) == 0:
-           self.logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0")
-       if self.device_compute_capability < (8, 0):
-           self.logger.debug("Disabling FusedAttention for compute capability < sm80")
+           # CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
+           # - unset: enables workspace optimization when required workspace is <= 256MB
+           #          or when bias gradient needs to be computed
+           # - n: enables workspace optimization when required workspace is <= n bytes
+           # - -1: enables workspace optimization always
+           # - 0: disables workspace optimization always
+           if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ:
+               if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0":
+                   os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0"
+               if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
+                   os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
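In practice the two determinism knobs documented in the moved comments above combine as follows (a sketch, not the only valid configuration):

```python
import os

# cuDNN >= 9.0.0: request determinism directly through TE
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
# cuDNN >= 8.9.5 and < 9.0.0: force the deterministic workspace-optimization
# path (equivalent to CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT = -1)
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1"
```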
...@@ -4915,7 +5175,6 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -4915,7 +5175,6 @@ class DotProductAttention(TransformerEngineBaseModule):
"attention_dropout_ctx": attention_dropout_ctx, "attention_dropout_ctx": attention_dropout_ctx,
} }
if self.use_flash_attention:
self.flash_attention = FlashAttention( self.flash_attention = FlashAttention(
softmax_scale, softmax_scale,
attention_type=attention_type, attention_type=attention_type,
@@ -4926,7 +5185,6 @@ class DotProductAttention(TransformerEngineBaseModule):
         # Instantiating three types since use of flash-attn and FusedAttention
         # might be ruled out due to forward inputs.
-        if self.use_fused_attention:
         self.fused_attention = FusedAttention(
             softmax_scale,
             attention_type=attention_type,
@@ -5162,11 +5420,13 @@ class DotProductAttention(TransformerEngineBaseModule):
         ) as query_layer:
             if self.fp8:
-                forced_fp8_dpa = ""
                 if self.fp8_meta["recipe"].fp8_mha:
                     if not self.fp8_meta["recipe"].fp8_dpa:
                         self.fp8_meta["recipe"].fp8_dpa = True
-                        forced_fp8_dpa = " (forced)"
+                        self.logger.warning(
+                            """Forcing fp8_meta["recipe"].fp8_dpa=True due to """
+                            """fp8_meta["recipe"].fp8_mha=True"""
+                        )
             if self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
                 forward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True)
@@ -5308,6 +5568,28 @@ class DotProductAttention(TransformerEngineBaseModule):
                         seqlens_kv <= max_seqlen_kv
                     ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than
                         the sequence dimension in 'key_layer' and 'value_layer'!"""
+            if cu_seqlens_q is None or cu_seqlens_kv is None:
+                if "padding" in attn_mask_type:
+                    assert (
+                        attention_mask is not None
+                    ), "Please provide attention_mask for padding!"
+                    if max_seqlen_q == max_seqlen_kv:
+                        cu_seqlens_q = get_cu_seqlens(attention_mask)
+                        cu_seqlens_kv = cu_seqlens_q
+                    else:
+                        cu_seqlens_q = get_cu_seqlens(attention_mask[0])
+                        cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
+                else:
+                    cu_seqlens_q = _get_full_cu_seqlens(
+                        batch_size,
+                        max_seqlen_q,
+                        query_layer.device,
+                    )
+                    cu_seqlens_kv = _get_full_cu_seqlens(
+                        batch_size,
+                        max_seqlen_kv,
+                        key_layer.device,
+                    )
             if (
                 isinstance(query_layer, Float8Tensor)
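For illustration, a minimal standalone sketch of what `get_cu_seqlens` is assumed to compute here: cumulative sequence lengths from a `[batch, 1, 1, seqlen]` padding mask in which `True` marks masked-out positions:

```python
import torch

def cu_seqlens_from_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    """Cumulative sequence lengths, shape [batch + 1], dtype int32."""
    seqlens = (~attention_mask).sum(dim=-1).flatten()  # valid tokens per sequence
    cu_seqlens = torch.zeros(seqlens.numel() + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    return cu_seqlens

mask = torch.tensor([[[[False, False, True]]],
                     [[[False, True, True]]]])
print(cu_seqlens_from_mask(mask))  # tensor([0, 2, 3], dtype=torch.int32)
```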
@@ -5330,6 +5612,7 @@ class DotProductAttention(TransformerEngineBaseModule):
             if self.layer_number == 1:
                 _alibi_cache["_alibi_slopes_require_update"] = True
                 _alibi_cache["_alibi_bias_require_update"] = True
+            bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"]
             if core_attention_bias_type == "alibi":
                 assert (
                     core_attention_bias is None
@@ -5338,16 +5621,12 @@ class DotProductAttention(TransformerEngineBaseModule):
                     _alibi_cache["_num_heads"] != query_layer.shape[-2]
                     or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
                     or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
+                    or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment
                    or _alibi_cache["_alibi_slopes"] is None
                 ):
                     _alibi_cache["_alibi_slopes_require_update"] = True
                     _alibi_cache["_alibi_bias_require_update"] = True
-            deterministic = (
-                not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
-                or torch.are_deterministic_algorithms_enabled()
-            )
             context_parallel = (
                 self.cp_group is not None and get_distributed_world_size(self.cp_group) != 1
             )
@@ -5383,13 +5662,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                 and not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv)
             )
-            (
-                use_flash_attention,
-                use_fused_attention,
-                use_unfused_attention,
-                _,
-                fused_attention_backend,
-            ) = get_attention_backend(
+            attention_params = AttentionParams(
                 qkv_type=type(query_layer),
                 qkv_dtype=query_layer.dtype,
                 qkv_layout=qkv_layout,
@@ -5410,42 +5683,42 @@ class DotProductAttention(TransformerEngineBaseModule):
                 pad_between_seqs=pad_between_seqs,
                 attention_dropout=self.attention_dropout,
                 context_parallel=context_parallel,
-                deterministic=deterministic,
+                deterministic=self.deterministic,
+                is_training=self.training,
                 fp8=self.fp8,
                 fp8_meta=self.fp8_meta,
             )
+            global _attention_backends
-            run_config = {
-                "compute_capability": "sm"
-                + str(
-                    (lambda x, y: x * 10 + y)(
-                        self.device_compute_capability[0], self.device_compute_capability[1]
-                    )
-                ),
-                "q_dtype": query_layer.dtype,
-                "k_dtype": key_layer.dtype,
-                "v_dtype": value_layer.dtype,
-                "q_shape": list(query_layer.shape),
-                "k_shape": list(key_layer.shape),
-                "v_shape": list(value_layer.shape),
-                "qkv_format": qkv_format,
-                "qkv_layout": qkv_layout,
-                "mask_type": attn_mask_type,
-                "bias_type": core_attention_bias_type,
-                "bias_shape": (
-                    core_attention_bias.shape if core_attention_bias is not None else None
-                ),
-                "dropout": self.attention_dropout,
-                "context_parallel": context_parallel,
-                "is_training": self.training,
-                "transformer_engine_version": te.__version__,
-                "flash_attn_version": _flash_attn_version,
-                "cudnn_version": ".".join([str(i) for i in get_cudnn_version()]),
-            }
+            if (
+                _attention_backends["attention_params"] is None
+                or attention_params != _attention_backends["attention_params"]
+            ):
+                _attention_backends["attention_params"] = attention_params
+                _attention_backends["backend_selection_requires_update"] = True
+            if _attention_backends["backend_selection_requires_update"]:
+                (
+                    use_flash_attention,
+                    use_fused_attention,
+                    fused_attention_backend,
+                    use_unfused_attention,
+                    _,
+                ) = get_attention_backend(attention_params)
+                if use_flash_attention:
+                    self.logger.info("Running with FlashAttention backend")
+                elif use_fused_attention:
+                    self.logger.info(
+                        "Running with FusedAttention backend (sub-backend %s)",
+                        int(fused_attention_backend),
+                    )
+                elif use_unfused_attention:
+                    self.logger.info("Running with UnfusedDotProductAttention backend")
+            else:
+                use_flash_attention = _attention_backends["use_flash_attention"]
+                use_fused_attention = _attention_backends["use_fused_attention"]
+                fused_attention_backend = _attention_backends["fused_attention_backend"]
+                use_unfused_attention = _attention_backends["use_unfused_attention"]
             if use_flash_attention:
-                self.logger.info("Running with FlashAttention backend ")
-                self.logger.debug("Running with config=%s", run_config)
                 if core_attention_bias_type == "alibi":
                     alibi_slopes, _ = get_alibi(
                         query_layer.shape[-2],
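The caching above is the heart of this change: backend selection now reruns only when `AttentionParams` differs from the previous call. A minimal sketch of the pattern, with dictionary keys taken from this diff and the rest assumed:

```python
_attention_backends = {
    "attention_params": None,
    "backend_selection_requires_update": False,
    "backend": None,
}

def select_backend(attention_params, get_attention_backend):
    """Re-select only when the inputs that influence the choice change."""
    if _attention_backends["attention_params"] != attention_params:
        _attention_backends["attention_params"] = attention_params
        _attention_backends["backend_selection_requires_update"] = True
    if _attention_backends["backend_selection_requires_update"]:
        _attention_backends["backend"] = get_attention_backend(attention_params)
        _attention_backends["backend_selection_requires_update"] = False
    return _attention_backends["backend"]
```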
@@ -5472,23 +5745,11 @@ class DotProductAttention(TransformerEngineBaseModule):
                 )
             if use_fused_attention:
-                self.logger.info(
-                    "Running with FusedAttention backend (sub-backend %s)",
-                    int(fused_attention_backend),
-                )
-                if self.fp8:
-                    self.logger.debug(
-                        "Running with fp8_recipe.fp8_mha=%s, "
-                        "fp8_recipe.fp8_dpa=%s%s, and NVTE_FP8_DPA_BWD=%s",
-                        self.fp8_meta["recipe"].fp8_mha,
-                        self.fp8_meta["recipe"].fp8_dpa,
-                        forced_fp8_dpa,
-                        int(os.getenv("NVTE_FP8_DPA_BWD", "1")),
-                    )
-                self.logger.debug("Running with config=%s", run_config)
                 fu_core_attention_bias_type = core_attention_bias_type
                 fu_core_attention_bias = core_attention_bias
-                if core_attention_bias_type == "alibi" and alibi_slopes is not None:
+                if core_attention_bias_type == "alibi" and (
+                    alibi_slopes is not None or max_seqlen_q != max_seqlen_kv
+                ):
                     fu_core_attention_bias_type = "post_scale_bias"
                     _, fu_core_attention_bias = get_alibi(
                         query_layer.shape[-2],
@@ -5496,6 +5757,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                         max_seqlen_kv,
                         alibi_slopes=alibi_slopes,
                         bias_dtype=query_layer.dtype,
+                        bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
                     )
                 if checkpoint_core_attention:
                     return self._checkpointed_attention_forward(
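When the slopes are custom, or `max_seqlen_q != max_seqlen_kv`, the ALiBi bias is materialized as `post_scale_bias` with the alignment chosen above. A hedged sketch of what a bottom-right-aligned ALiBi bias looks like; the real `get_alibi` may differ in details:

```python
import torch

def alibi_bias(slopes: torch.Tensor, seqlen_q: int, seqlen_kv: int,
               bottom_right_alignment: bool = True) -> torch.Tensor:
    """Bias of shape [heads, seqlen_q, seqlen_kv]; zero on the aligned diagonal."""
    q = torch.arange(seqlen_q).view(1, -1, 1)
    k = torch.arange(seqlen_kv).view(1, 1, -1)
    shift = seqlen_kv - seqlen_q if bottom_right_alignment else 0
    return slopes.view(-1, 1, 1) * (k - q - shift).to(slopes.dtype)

slopes = torch.tensor([0.5, 0.25])
print(alibi_bias(slopes, 2, 4).shape)  # torch.Size([2, 2, 4])
```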
@@ -5512,6 +5774,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                         max_seqlen_kv=max_seqlen_kv,
                         attn_mask_type=attn_mask_type,
                         attention_mask=attention_mask,
+                        window_size=window_size,
                         fused_attention_backend=fused_attention_backend,
                         core_attention_bias_type=fu_core_attention_bias_type,
                         core_attention_bias=fu_core_attention_bias,
@@ -5535,6 +5798,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                         max_seqlen_kv=max_seqlen_kv,
                         attn_mask_type=attn_mask_type,
                         attention_mask=attention_mask,
+                        window_size=window_size,
                         fused_attention_backend=fused_attention_backend,
                         core_attention_bias_type=fu_core_attention_bias_type,
                         core_attention_bias=fu_core_attention_bias,
@@ -5555,8 +5819,12 @@ class DotProductAttention(TransformerEngineBaseModule):
                 )
             if use_unfused_attention:
-                self.logger.info("Running with UnfusedDotProductAttention backend")
-                self.logger.debug("Running with config=%s", run_config)
+                if window_size is not None and (
+                    window_size[0] != -1 or window_size[1] not in [-1, 0]
+                ):
+                    attn_mask_type, attention_mask = get_swa_mask(
+                        window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type, attention_mask
+                    )
                 if checkpoint_core_attention:
                     return self._checkpointed_attention_forward(
                         self.unfused_attention,
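For the unfused path, `get_swa_mask` folds the window into an explicit mask. A hedged, self-contained sketch of an equivalent sliding-window mask for the self-attention case, where `True` marks positions to drop:

```python
import torch

def swa_mask(seqlen: int, window_size) -> torch.Tensor:
    left, right = window_size
    rel = torch.arange(seqlen).view(1, -1) - torch.arange(seqlen).view(-1, 1)
    mask = torch.zeros(seqlen, seqlen, dtype=torch.bool)
    if left >= 0:
        mask |= rel < -left    # key too far behind the query
    if right >= 0:
        mask |= rel > right    # key too far ahead of the query
    return mask

print(swa_mask(5, (2, 0)).int())  # causal band: each query sees itself and 2 predecessors
```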
@@ -5636,7 +5904,9 @@ class MultiheadAttention(torch.nn.Module):
         sliding window size for local attention, where query at position i attends to keys
         in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
         + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
-        window and "`causal`" mask specifically. Similar to :attr:`attn_mask_type`, it can
+        window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
+        map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
+        `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
         be overridden by :attr:`window_size` in `forward` as well.
     num_gqa_groups : int, default = `None`
         number of GQA groups in the transformer layer.
...
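A worked example of the interval formula in the `window_size` docstring above; `attended_keys` is a hypothetical helper, written only to make the arithmetic concrete:

```python
def attended_keys(i: int, seqlen_q: int, seqlen_kv: int, window_size):
    """Inclusive [low, high] range of key positions visible to query i."""
    left, right = window_size
    low = i + seqlen_kv - seqlen_q - (left if left >= 0 else seqlen_kv)
    high = i + seqlen_kv - seqlen_q + (right if right >= 0 else seqlen_kv)
    return max(low, 0), min(high, seqlen_kv - 1)

print(attended_keys(3, 8, 8, (-1, 0)))  # (0, 3): the causal special case
print(attended_keys(3, 8, 8, (2, 0)))   # (1, 3): window of 2 keys to the left
```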
@@ -100,6 +100,7 @@ def fused_attn_fwd_qkvpacked(
     qkv_layout: str = "sbh3d",
     attn_bias_type: str = "no_bias",
     attn_mask_type: str = "padding",
+    window_size: Tuple[int, int] = (-1, -1),
     rng_gen: torch.Generator = None,
 ) -> Tuple[Union[torch.Tensor, None], ...]:
     """Fused Attention FWD for packed QKV input.
@@ -152,6 +153,11 @@ def fused_attn_fwd_qkvpacked(
         type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
     attn_mask_type: str, default = "padding"
         type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
+    window_size: Tuple[int, int], default = (-1, -1)
+        sliding window size for local attention, where query at position i attends to keys
+        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+        + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
+        window and causal mask specifically.
     rng_gen: torch.Generator, default = None
         random number generator;
         if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
@@ -236,6 +242,7 @@ def fused_attn_fwd_qkvpacked(
         QKVLayout[qkv_layout],
         AttnBiasType[attn_bias_type],
         AttnMaskType[attn_mask_type],
+        window_size,
         cu_seqlens,
         qkv,
         qkv_dtype,
@@ -282,6 +289,8 @@ def fused_attn_bwd_qkvpacked(
     qkv_layout: str = "sbh3d",
     attn_bias_type: str = "no_bias",
     attn_mask_type: str = "padding",
+    window_size: Tuple[int, int] = (-1, -1),
+    deterministic: bool = False,
 ) -> Tuple[Union[torch.Tensor, None], ...]:
     """Fused Attention BWD for packed QKV input.
@@ -346,6 +355,13 @@ def fused_attn_bwd_qkvpacked(
         type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
     attn_mask_type: str, default = "padding"
         type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
+    window_size: Tuple[int, int], default = (-1, -1)
+        sliding window size for local attention, where query at position i attends to keys
+        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+        + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
+        window and causal mask specifically.
+    deterministic: bool, default = False
+        whether to execute the backward pass with deterministic behaviours.
     Returns
     ----------
@@ -393,6 +409,8 @@ def fused_attn_bwd_qkvpacked(
         QKVLayout[qkv_layout],
         AttnBiasType[attn_bias_type],
         AttnMaskType[attn_mask_type],
+        window_size,
+        deterministic,
         cu_seqlens,
         qkv,
         o,
@@ -441,6 +459,7 @@ def fused_attn_fwd_kvpacked(
     qkv_layout: str = "sbhd_sbh2d",
     attn_bias_type: str = "no_bias",
     attn_mask_type: str = "padding",
+    window_size: Tuple[int, int] = (-1, -1),
     rng_gen: torch.Generator = None,
 ) -> Tuple[Union[torch.Tensor, None], ...]:
     """Fused Attention FWD for packed KV input.
@@ -503,6 +522,11 @@ def fused_attn_fwd_kvpacked(
         type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
     attn_mask_type: str, default = "padding"
         type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
+    window_size: Tuple[int, int], default = (-1, -1)
+        sliding window size for local attention, where query at position i attends to keys
+        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+        + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
+        window and causal mask specifically.
     rng_gen: torch.Generator, default = None
         random number generator;
         if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
@@ -588,6 +612,7 @@ def fused_attn_fwd_kvpacked(
         QKVLayout[qkv_layout],
         AttnBiasType[attn_bias_type],
         AttnMaskType[attn_mask_type],
+        window_size,
         cu_seqlens_q,
         cu_seqlens_kv,
         q,
@@ -641,6 +666,8 @@ def fused_attn_bwd_kvpacked(
     qkv_layout: str = "sbhd_sbh2d",
     attn_bias_type: str = "no_bias",
     attn_mask_type: str = "padding",
+    window_size: Tuple[int, int] = (-1, -1),
+    deterministic: bool = False,
 ) -> Tuple[Union[torch.Tensor, None], ...]:
     """Fused Attention BWD for packed KV input.
@@ -716,6 +743,13 @@ def fused_attn_bwd_kvpacked(
         type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
     attn_mask_type: str, default = "padding"
         type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
+    window_size: Tuple[int, int], default = (-1, -1)
+        sliding window size for local attention, where query at position i attends to keys
+        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+        + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
+        window and causal mask specifically.
+    deterministic: bool, default = False
+        whether to execute the backward pass with deterministic behaviours.
     Returns
     ----------
@@ -766,6 +800,8 @@ def fused_attn_bwd_kvpacked(
         QKVLayout[qkv_layout],
         AttnBiasType[attn_bias_type],
         AttnMaskType[attn_mask_type],
+        window_size,
+        deterministic,
         cu_seqlens_q,
         cu_seqlens_kv,
         q,
@@ -818,6 +854,7 @@ def fused_attn_fwd(
     qkv_layout: str = "sbh3d",
     attn_bias_type: str = "no_bias",
     attn_mask_type: str = "padding",
+    window_size: Tuple[int, int] = (-1, -1),
     rng_gen: torch.Generator = None,
 ) -> Tuple[Union[torch.Tensor, None], ...]:
     """Fused Attention FWD for separate QKV input.
@@ -886,6 +923,11 @@ def fused_attn_fwd(
         type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
     attn_mask_type: str, default = "padding"
         type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
+    window_size: Tuple[int, int], default = (-1, -1)
+        sliding window size for local attention, where query at position i attends to keys
+        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+        + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
+        window and causal mask specifically.
     rng_gen: torch.Generator, default = None
         random number generator;
         if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
@@ -971,6 +1013,7 @@ def fused_attn_fwd(
         QKVLayout[qkv_layout],
         AttnBiasType[attn_bias_type],
         AttnMaskType[attn_mask_type],
+        window_size,
         cu_seqlens_q,
         cu_seqlens_kv,
         q,
@@ -1026,6 +1069,8 @@ def fused_attn_bwd(
     qkv_layout: str = "sbh3d",
     attn_bias_type: str = "no_bias",
     attn_mask_type: str = "padding",
+    window_size: Tuple[int, int] = (-1, -1),
+    deterministic: bool = False,
 ) -> Tuple[Union[torch.Tensor, None], ...]:
     """Fused Attention BWD for separate QKV input.
@@ -1106,6 +1151,13 @@ def fused_attn_bwd(
         type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
     attn_mask_type: str, default = "padding"
         type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
+    window_size: Tuple[int, int], default = (-1, -1)
+        sliding window size for local attention, where query at position i attends to keys
+        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+        + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
+        window and causal mask specifically.
+    deterministic: bool, default = False
+        whether to execute the backward pass with deterministic behaviours.
     Returns
     ----------
@@ -1158,6 +1210,8 @@ def fused_attn_bwd(
         QKVLayout[qkv_layout],
         AttnBiasType[attn_bias_type],
         AttnMaskType[attn_mask_type],
+        window_size,
+        deterministic,
         cu_seqlens_q,
         cu_seqlens_kv,
         q,
...
@@ -14,30 +14,29 @@
  * Attention
  **************************************************************************************************/
-NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q_dtype,
-                                               const transformer_engine::DType kv_dtype,
-                                               NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                                               NVTE_Mask_Type attn_mask_type, float p_dropout,
-                                               size_t num_attn_heads, size_t num_gqa_groups,
-                                               size_t max_seqlen_q, size_t max_seqlen_kv,
-                                               size_t head_dim);
+NVTE_Fused_Attn_Backend get_fused_attn_backend(
+    const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype,
+    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
+    float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
+    size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left, int64_t window_size_right);

 std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
     size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type,
-    const c10::optional<at::Tensor> cu_seqlens_padded, const c10::optional<at::Tensor> descale_QKV,
-    const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> scale_S,
-    const c10::optional<at::Tensor> scale_O, c10::optional<at::Tensor> amax_S,
-    c10::optional<at::Tensor> amax_O, const c10::optional<at::Tensor> Bias,
-    const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
+    const std::vector<int64_t> window_size, const at::Tensor cu_seqlens, const at::Tensor QKV,
+    const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> cu_seqlens_padded,
+    const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
+    const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
+    c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O,
+    const c10::optional<at::Tensor> Bias, const c10::optional<at::Generator> rng_gen,
+    size_t rng_elts_per_thread);
 std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
     size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens,
-    const at::Tensor QKV, const at::Tensor O, const at::Tensor dO,
-    const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
-    const std::vector<at::Tensor> Aux_CTX_Tensors,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
+    bool deterministic, const at::Tensor cu_seqlens, const at::Tensor QKV, const at::Tensor O,
+    const at::Tensor dO, const transformer_engine::DType qkv_type,
+    const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
     const c10::optional<at::Tensor> cu_seqlens_padded, const c10::optional<at::Tensor> descale_QKV,
     const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O,
     const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP,
@@ -48,8 +47,9 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
 std::vector<at::Tensor> fused_attn_fwd_kvpacked(
     size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
     bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv,
-    const at::Tensor Q, const at::Tensor KV, const transformer_engine::DType qkv_type,
+    NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
+    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
+    const at::Tensor KV, const transformer_engine::DType qkv_type,
     const c10::optional<at::Tensor> cu_seqlens_q_padded,
     const c10::optional<at::Tensor> cu_seqlens_kv_padded,
     const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
@@ -61,10 +61,10 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
 std::vector<at::Tensor> fused_attn_bwd_kvpacked(
     size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
-    const at::Tensor KV, const at::Tensor O, const at::Tensor dO,
-    const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
-    const std::vector<at::Tensor> Aux_CTX_Tensors,
+    const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
+    const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor KV, const at::Tensor O,
+    const at::Tensor dO, const transformer_engine::DType qkv_type,
+    const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
     const c10::optional<at::Tensor> cu_seqlens_q_padded,
     const c10::optional<at::Tensor> cu_seqlens_kv_padded,
     const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
@@ -76,9 +76,10 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
 std::vector<at::Tensor> fused_attn_fwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
     bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv,
-    const at::Tensor Q, const at::Tensor K, const at::Tensor V,
-    const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> cu_seqlens_q_padded,
+    NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
+    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
+    const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type,
+    const c10::optional<at::Tensor> cu_seqlens_q_padded,
     const c10::optional<at::Tensor> cu_seqlens_kv_padded,
     const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
     const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
@@ -89,10 +90,10 @@ std::vector<at::Tensor> fused_attn_fwd(
 std::vector<at::Tensor> fused_attn_bwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
-    const at::Tensor K, const at::Tensor V, const at::Tensor O, const at::Tensor dO,
-    const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
-    const std::vector<at::Tensor> Aux_CTX_Tensors,
+    const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
+    const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor K, const at::Tensor V,
+    const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type,
+    const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
     const c10::optional<at::Tensor> cu_seqlens_q_padded,
     const c10::optional<at::Tensor> cu_seqlens_kv_padded,
     const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
...
@@ -10,17 +10,15 @@ constexpr int block_size = 512;
 constexpr int ctas_per_sm = 4;

 // get the fused attention backend
-NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q_dtype,
-                                               const transformer_engine::DType kv_dtype,
-                                               NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                                               NVTE_Mask_Type attn_mask_type, float p_dropout,
-                                               size_t num_attn_heads, size_t num_gqa_groups,
-                                               size_t max_seqlen_q, size_t max_seqlen_kv,
-                                               size_t head_dim) {
-  NVTE_Fused_Attn_Backend fused_attention_backend =
-      nvte_get_fused_attn_backend(static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype),
-                                  qkv_layout, bias_type, attn_mask_type, p_dropout, num_attn_heads,
-                                  num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim);
+NVTE_Fused_Attn_Backend get_fused_attn_backend(
+    const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype,
+    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
+    float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
+    size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left, int64_t window_size_right) {
+  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
+      static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
+      attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv,
+      head_dim, window_size_left, window_size_right);
   return fused_attention_backend;
 }
@@ -82,12 +80,13 @@ at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl *gen, size_t elts_pe
 std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
     size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type,
-    const c10::optional<at::Tensor> cu_seqlens_padded, const c10::optional<at::Tensor> descale_QKV,
-    const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> scale_S,
-    const c10::optional<at::Tensor> scale_O, c10::optional<at::Tensor> amax_S,
-    c10::optional<at::Tensor> amax_O, const c10::optional<at::Tensor> Bias,
-    const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
+    const std::vector<int64_t> window_size, const at::Tensor cu_seqlens, const at::Tensor QKV,
+    const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> cu_seqlens_padded,
+    const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
+    const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
+    c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O,
+    const c10::optional<at::Tensor> Bias, const c10::optional<at::Generator> rng_gen,
+    size_t rng_elts_per_thread) {
   using namespace transformer_engine;

   auto qkv_sizes = QKV.sizes().vec();
@@ -173,11 +172,11 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
   TensorWrapper workspace;

   // populate tensors with appropriate shapes and dtypes
-  nvte_fused_attn_fwd_qkvpacked(te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(),
-                                &nvte_aux_tensor_pack, te_cu_seqlens.data(),
-                                te_cu_seqlens_padded.data(), te_rng_state.data(), max_seqlen,
-                                is_training, attn_scale, p_dropout, qkv_layout, bias_type,
-                                attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream());
+  nvte_fused_attn_fwd_qkvpacked(
+      te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
+      te_cu_seqlens.data(), te_cu_seqlens_padded.data(), te_rng_state.data(), max_seqlen,
+      is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0],
+      window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream());

   // allocate memory for workspace and auxiliary output tensors
   auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
@@ -213,11 +212,11 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
   }

   // execute the kernel
-  nvte_fused_attn_fwd_qkvpacked(te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(),
-                                &nvte_aux_tensor_pack, te_cu_seqlens.data(),
-                                te_cu_seqlens_padded.data(), te_rng_state.data(), max_seqlen,
-                                is_training, attn_scale, p_dropout, qkv_layout, bias_type,
-                                attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream());
+  nvte_fused_attn_fwd_qkvpacked(
+      te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
+      te_cu_seqlens.data(), te_cu_seqlens_padded.data(), te_rng_state.data(), max_seqlen,
+      is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0],
+      window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream());

   // destroy tensor wrappers, but not allocated memory
   nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
@@ -229,10 +228,10 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
 // fused attention BWD with packed QKV
 std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
     size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens,
-    const at::Tensor QKV, const at::Tensor O, const at::Tensor dO,
-    const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
-    const std::vector<at::Tensor> Aux_CTX_Tensors,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
+    bool deterministic, const at::Tensor cu_seqlens, const at::Tensor QKV, const at::Tensor O,
+    const at::Tensor dO, const transformer_engine::DType qkv_type,
+    const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
     const c10::optional<at::Tensor> cu_seqlens_padded, const c10::optional<at::Tensor> descale_QKV,
     const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O,
     const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP,
@@ -351,11 +350,11 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
   TensorWrapper workspace;

   // populate tensors with appropriate shapes and dtypes
-  nvte_fused_attn_bwd_qkvpacked(te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
-                                &nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(),
-                                te_cu_seqlens.data(), te_cu_seqlens_padded.data(), max_seqlen,
-                                attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
-                                workspace.data(), at::cuda::getCurrentCUDAStream());
+  nvte_fused_attn_bwd_qkvpacked(
+      te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack,
+      te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), te_cu_seqlens_padded.data(),
+      max_seqlen, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0],
+      window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream());

   // allocate memory for workspace
   auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
@@ -363,11 +362,11 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
       makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());

   // execute kernel
-  nvte_fused_attn_bwd_qkvpacked(te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
-                                &nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(),
-                                te_cu_seqlens.data(), te_cu_seqlens_padded.data(), max_seqlen,
-                                attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
-                                workspace.data(), at::cuda::getCurrentCUDAStream());
+  nvte_fused_attn_bwd_qkvpacked(
+      te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack,
+      te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), te_cu_seqlens_padded.data(),
+      max_seqlen, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0],
+      window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream());

   // destroy tensor wrappers
   nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
@@ -379,8 +378,9 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
 std::vector<at::Tensor> fused_attn_fwd_kvpacked(
     size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
     bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv,
-    const at::Tensor Q, const at::Tensor KV, const transformer_engine::DType qkv_type,
+    NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
+    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
+    const at::Tensor KV, const transformer_engine::DType qkv_type,
     const c10::optional<at::Tensor> cu_seqlens_q_padded,
     const c10::optional<at::Tensor> cu_seqlens_kv_padded,
     const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
@@ -487,8 +487,8 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
       te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
       te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
       te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
-      attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, workspace.data(),
-      at::cuda::getCurrentCUDAStream());
+      attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1],
+      workspace.data(), at::cuda::getCurrentCUDAStream());

   // allocate memory for workspace and auxiliary output tensors
   auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
@@ -528,8 +528,8 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
       te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
       te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
       te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
-      attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, workspace.data(),
-      at::cuda::getCurrentCUDAStream());
+      attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1],
+      workspace.data(), at::cuda::getCurrentCUDAStream());

   // destroy tensor wrappers, but not allocated memory
   nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
@@ -542,10 +542,10 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
 std::vector<at::Tensor> fused_attn_bwd_kvpacked(
     size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
-    const at::Tensor KV, const at::Tensor O, const at::Tensor dO,
-    const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
-    const std::vector<at::Tensor> Aux_CTX_Tensors,
+    const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
+    const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor KV, const at::Tensor O,
+    const at::Tensor dO, const transformer_engine::DType qkv_type,
+    const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
     const c10::optional<at::Tensor> cu_seqlens_q_padded,
     const c10::optional<at::Tensor> cu_seqlens_kv_padded,
     const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
@@ -690,12 +690,13 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
   TensorWrapper workspace;

   // populate tensors with appropriate shapes and dtypes
-  nvte_fused_attn_bwd_kvpacked(
-      te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
-      &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
-      te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(),
-      max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
-      workspace.data(), at::cuda::getCurrentCUDAStream());
+  nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(),
+                               te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(),
+                               te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
+                               te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(),
+                               max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout,
+                               bias_type, attn_mask_type, window_size[0], window_size[1],
+                               deterministic, workspace.data(), at::cuda::getCurrentCUDAStream());

   // allocate memory for workspace
   auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
@@ -703,12 +704,13 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
       makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());

   // execute kernel
-  nvte_fused_attn_bwd_kvpacked(
-      te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
-      &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
-      te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(),
-      max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
-      workspace.data(), at::cuda::getCurrentCUDAStream());
+  nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(),
+                               te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(),
+                               te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
+                               te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(),
+                               max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout,
+                               bias_type, attn_mask_type, window_size[0], window_size[1],
+                               deterministic, workspace.data(), at::cuda::getCurrentCUDAStream());

   // destroy tensor wrappers
   nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
@@ -720,9 +722,10 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
 std::vector<at::Tensor> fused_attn_fwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
     bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv,
-    const at::Tensor Q, const at::Tensor K, const at::Tensor V,
-    const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> cu_seqlens_q_padded,
+    NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
+    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
+    const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type,
+    const c10::optional<at::Tensor> cu_seqlens_q_padded,
     const c10::optional<at::Tensor> cu_seqlens_kv_padded,
     const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
     const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
@@ -833,7 +836,8 @@ std::vector<at::Tensor> fused_attn_fwd(
       te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
       te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q,
       max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type,
-      attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream());
+      attn_mask_type, window_size[0], window_size[1], workspace.data(),
+      at::cuda::getCurrentCUDAStream());

   // allocate memory for workspace and auxiliary output tensors
   auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
@@ -874,7 +878,8 @@ std::vector<at::Tensor> fused_attn_fwd(
       te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
       te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q,
       max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type,
-      attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream());
+      attn_mask_type, window_size[0], window_size[1], workspace.data(),
+      at::cuda::getCurrentCUDAStream());

   // destroy tensor wrappers, but not allocated memory
   nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
@@ -887,10 +892,10 @@ std::vector<at::Tensor> fused_attn_fwd(
 std::vector<at::Tensor> fused_attn_bwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
-    const at::Tensor K, const at::Tensor V, const at::Tensor O, const at::Tensor dO,
-    const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
-    const std::vector<at::Tensor> Aux_CTX_Tensors,
+    const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
+    const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor K, const at::Tensor V,
+    const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type,
+    const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
     const c10::optional<at::Tensor> cu_seqlens_q_padded,
     const c10::optional<at::Tensor> cu_seqlens_kv_padded,
     const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
@@ -1115,7 +1120,8 @@ std::vector<at::Tensor> fused_attn_bwd(
                 te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
                 te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q,
                 max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
-                workspace.data(), at::cuda::getCurrentCUDAStream());
+                window_size[0], window_size[1], deterministic, workspace.data(),
+                at::cuda::getCurrentCUDAStream());

   // allocate memory for workspace
   auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
@@ -1128,7 +1134,8 @@ std::vector<at::Tensor> fused_attn_bwd(
                 te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
                 te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q,
                 max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
-                workspace.data(), at::cuda::getCurrentCUDAStream());
+                window_size[0], window_size[1], deterministic, workspace.data(),
+                at::cuda::getCurrentCUDAStream());

   // destroy tensor wrappers
   nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
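The backward call sites likewise gain the new `deterministic` flag, which restricts cuDNN to deterministic backward algorithms, typically at some performance cost. A sketch of how a frontend might derive this flag; the `NVTE_ALLOW_NONDETERMINISTIC_ALGO` environment variable follows this PR's naming, but treat the exact plumbing as an assumption:

```python
import os
import torch

# Determinism is requested either by disallowing nondeterministic algorithms
# via the environment, or by PyTorch's global deterministic-algorithms switch.
deterministic = (
    not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
    or torch.are_deterministic_algorithms_enabled()
)
```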
...
...@@ -21,18 +21,18 @@ THREADS_PER_BLOCK = 128 ...@@ -21,18 +21,18 @@ THREADS_PER_BLOCK = 128
_default_causal_mask = {} _default_causal_mask = {}
def _get_default_causal_mask(sq: int, sk: int) -> torch.Tensor: def _get_default_causal_mask(mask_type: str, sq: int, sk: int) -> torch.Tensor:
"""Return the causal upper triangular mask for softmax input""" """Return the causal upper triangular mask for softmax input"""
if sq == 1: if sq == 1:
return torch.zeros((1, sk), dtype=torch.bool, device="cuda") return torch.zeros((1, sk), dtype=torch.bool, device="cuda")
matrix_shape = (sq, sk) matrix_identifiers = (mask_type, sq, sk)
if matrix_shape not in _default_causal_mask: if matrix_identifiers not in _default_causal_mask:
diagonal_offset = sk - sq + 1 diagonal_offset = sk - sq + 1 if "bottom_right" in mask_type else 1
_default_causal_mask[matrix_shape] = torch.triu( _default_causal_mask[matrix_identifiers] = torch.triu(
torch.ones(sq, sk, dtype=torch.bool, device="cuda"), diagonal=diagonal_offset torch.ones(sq, sk, dtype=torch.bool, device="cuda"), diagonal=diagonal_offset
) )
return _default_causal_mask[matrix_shape] return _default_causal_mask[matrix_identifiers]
def _get_onnx_export_causal_mask( def _get_onnx_export_causal_mask(
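The cache is now keyed on `(mask_type, sq, sk)` because the same shape yields two different masks: top-left alignment uses `diagonal=1`, while bottom-right alignment shifts the diagonal by `sk - sq + 1`. A quick illustration for `sq=2, sk=4`, where True marks positions that are masked out:

```python
import torch

sq, sk = 2, 4
# Top-left "causal": masking starts one above the main diagonal.
top_left = torch.triu(torch.ones(sq, sk, dtype=torch.bool), diagonal=1)
# "causal_bottom_right": diagonal shifted by sk - sq + 1 = 3.
bottom_right = torch.triu(torch.ones(sq, sk, dtype=torch.bool), diagonal=sk - sq + 1)

print(top_left.int())      # query 0 sees only key 0; query 1 sees keys 0-1
print(bottom_right.int())  # query 0 sees keys 0-2; query 1 sees all 4 keys
```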
@@ -332,10 +332,8 @@ class FusedScaleMaskSoftmax(nn.Module):
         if self.attn_mask_type == "arbitrary":
             return False  # Custom masks not supported

-        if self.attn_mask_type == "causal_bottom_right" or (
-            self.attn_mask_type == "causal" and sq == sk
-        ):  # fused causal softmax kernel
-            return True
+        if self.attn_mask_type == "causal" and sq != sk:
+            return False  # Fused causal kernel only supports causal_bottom_right alignment

         if (
             sq % 4 == 0  # sq must be a multiple of 4
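The rewritten availability check turns the old allow-list into an early reject: the fused causal softmax kernel aligns its diagonal bottom-right, which coincides with top-left `causal` semantics only when `sq == sk`. A simplified, illustrative predicate (the real method performs additional shape and dtype checks):

```python
def fused_causal_kernel_ok(attn_mask_type: str, sq: int, sk: int) -> bool:
    """Illustrative sketch of the rewritten availability check, not the real method."""
    if attn_mask_type == "arbitrary":
        return False  # custom masks are unsupported by the fused kernel
    if attn_mask_type == "causal" and sq != sk:
        # Bottom-right alignment matches top-left "causal" only for square shapes.
        return False
    return sq % 4 == 0  # one of several remaining shape constraints
```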
@@ -387,7 +385,7 @@ class FusedScaleMaskSoftmax(nn.Module):
                 assert self.kvcache_max_seq >= seq_len_k
                 mask = _get_onnx_export_causal_mask(seq_len_q, seq_len_k, self.onnx_causal_mask)
             else:
-                mask = _get_default_causal_mask(seq_len_q, seq_len_k)
+                mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k)

         mask_output = inp
         if mask is not None and self.attn_mask_type != "no_mask":
...
@@ -142,9 +142,11 @@ class TransformerLayer(torch.nn.Module):
                  sliding window size for local attention in encoder, where query at position i
                  attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k
                  - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean
-                 no sliding window and "`causal`" mask specifically. Similar to
-                 :attr:`self_attn_mask_type`, it can be overridden by :attr:`window_size`
-                 in `forward` as well.
+                 no sliding window and causal mask specifically. Both `causal` and
+                 `causal_bottom_right` masks map to `window_size = (-1, 0)`, and Transformer
+                 Engine distinguishes them based on `self_attn_mask_type` or
+                 `enc_dec_attn_mask_type`. Similar to :attr:`self_attn_mask_type`,
+                 `window_size` can be overridden by :attr:`window_size` in `forward` as well.
     enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
                            default = `no_mask`
                            type of attention mask passed into softmax operation for decoder.
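A hedged usage sketch of the documented behavior; all sizes here are hypothetical, and sliding-window support still depends on the backend available at runtime (flash-attention or a sufficiently recent cuDNN):

```python
import torch
import transformer_engine.pytorch as te

# Causal sliding-window attention: each query looks back at most 128 tokens
# and never ahead. window_size=(-1, 0) would be plain causal attention, and
# (-1, -1) disables the window entirely.
layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    self_attn_mask_type="causal",
    window_size=(128, 0),
    params_dtype=torch.bfloat16,
)

x = torch.randn(512, 2, 1024, device="cuda", dtype=torch.bfloat16)  # (seq, batch, hidden)
y = layer(x)

# Per the docstring, the window can also be overridden per call, e.g.
# layer(x, window_size=(256, 0)).
```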