Unverified Commit 8e039fdc authored by Charlene Yang, committed by GitHub

Add cuDNN sliding window and set_deterministic_algorithm (#992)



* add cuDNN swa
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix SWA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add set_deterministic and minor fixes for swa
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add AttentionParams
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change window_size to int64_t; fix swa/determinism tests; cache _attention_backends
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add window_size to get_backend; fix jax and paddle
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes; add set_deter to bwd_impl
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix FP8 tests due to determinism
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add support matrix for SWA and bias
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fixes and lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add wording on window_size special cases
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor tweak on wording
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix jax assertion error
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix wording
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* call bwd with deterministic=true for jax/paddle
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add determinism words in documentation
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
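To make the diff below easier to follow, here is a minimal plain-Python sketch of the sliding-window semantics this PR threads through the stack. `window_size = (left, right)`, where `-1` means unbounded on that side; queries are aligned to the bottom-right diagonal (the `seq_kv - seq_q` offset used by the tests). The helper name is illustrative, not part of the Transformer Engine API:

```python
def sliding_window_mask(seq_q, seq_kv, left, right):
    """Return a seq_q x seq_kv boolean matrix; True = query may attend to key.

    Query i may attend to keys j with
        i + (seq_kv - seq_q) - left <= j <= i + (seq_kv - seq_q) + right,
    where left/right == -1 removes that bound.
    """
    offset = seq_kv - seq_q  # bottom-right alignment of the diagonal
    mask = []
    for i in range(seq_q):
        row = []
        for j in range(seq_kv):
            lo_ok = (left == -1) or (j >= i + offset - left)
            hi_ok = (right == -1) or (j <= i + offset + right)
            row.append(lo_ok and hi_ok)
        mask.append(row)
    return mask

# window_size == (-1, 0) reduces to an ordinary causal mask:
print(sliding_window_mask(3, 3, -1, 0))
# → [[True, False, False], [True, True, False], [True, True, True]]
```

A finite left window, e.g. `(1, 0)`, then drops keys older than one position behind the diagonal, which is the classic SWA pattern.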
parent 166bb078
......@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
"id": "141fa8bd",
"id": "8ae3bc43",
"metadata": {},
"source": [
"# Attention Is All You Need!\n",
......@@ -22,7 +22,7 @@
},
{
"cell_type": "markdown",
"id": "09a60057",
"id": "47421c01",
"metadata": {},
"source": [
"## 1. Attention Backends\n",
......@@ -38,7 +38,7 @@
},
{
"cell_type": "markdown",
"id": "f387274e",
"id": "e52f60f0",
"metadata": {},
"source": [
"### 1.1 Flash vs. Non-Flash\n",
......@@ -58,7 +58,7 @@
},
{
"cell_type": "markdown",
"id": "f1389145",
"id": "bb909ac4",
"metadata": {},
"source": [
"### 1.2 flash-attention\n",
......@@ -97,7 +97,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "bbc5c73f",
"id": "9a380859",
"metadata": {},
"outputs": [],
"source": [
......@@ -113,7 +113,7 @@
{
"cell_type": "code",
"execution_count": 2,
"id": "173638b6",
"id": "0584bb01",
"metadata": {},
"outputs": [
{
......@@ -140,7 +140,7 @@
},
{
"cell_type": "markdown",
"id": "0f62d2fa",
"id": "45e53fc9",
"metadata": {},
"source": [
"## 2. Backend Selection\n",
......@@ -160,7 +160,7 @@
},
{
"cell_type": "markdown",
"id": "86e16a2b",
"id": "6dfeade3",
"metadata": {},
"source": [
"### 2.1 Debug Information\n",
......@@ -177,7 +177,7 @@
},
{
"cell_type": "markdown",
"id": "e439434e",
"id": "7e3b7981",
"metadata": {},
"source": [
        "The [example_attention.py](./example_attention.py) script runs a very basic model with two attention backends, cuDNN attention and flash-attention. Here `NVTE_DEBUG_LEVEL=1` lets us see which backend/sub-backend was actually used at runtime."
......@@ -186,7 +186,7 @@
{
"cell_type": "code",
"execution_count": 22,
"id": "9d002327",
"id": "961c51d4",
"metadata": {},
"outputs": [
{
......@@ -210,7 +210,7 @@
},
{
"cell_type": "markdown",
"id": "bbf1756c",
"id": "11bfbbd7",
"metadata": {},
"source": [
        "To collect more information, users can turn on `NVTE_DEBUG_LEVEL=2`. In this example, it allows us to find out more about the run config. Users are encouraged to provide this information when filing a bug with Transformer Engine. For example, "
......@@ -219,7 +219,7 @@
{
"cell_type": "code",
"execution_count": 25,
"id": "66a2f34c",
"id": "162a2be1",
"metadata": {},
"outputs": [
{
......@@ -249,7 +249,7 @@
},
{
"cell_type": "markdown",
"id": "9f964732",
"id": "779a51e6",
"metadata": {},
"source": [
"### 2.2 User Control\n",
......@@ -274,11 +274,16 @@
"\n",
        "Users can experiment with these two paths through the following environment variables. However, please be aware of the possible out-of-memory risks.\n",
"```\n",
"NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT = 0 # disables workspace optimization path\n",
"NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT = 1 # enables workspace optimization path\n",
"Before cuDNN 9.0:\n",
" NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT = 0 # disables workspace optimization path\n",
" NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT = 1 # enables workspace optimization path\n",
"\n",
"After cuDNN 9.0:\n",
" NVTE_ALLOW_NONDETERMINISTIC_ALGO = 1 # disables workspace optimization path\n",
" NVTE_ALLOW_NONDETERMINISTIC_ALGO = 0 # enables workspace optimization path\n",
"```\n",
"<div class=\"alert alert-info\">\n",
"<b>Note:</b> Environment variables <code>NVTE_FLASH_ATTN</code>, <code>NVTE_FUSED_ATTN</code> and <code>NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT</code> are only supported in PyTorch, not JAX or PaddlePaddle.\n",
"<b>Note:</b> Environment variables <code>NVTE_FLASH_ATTN</code>, <code>NVTE_FUSED_ATTN</code>, <code>NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT</code> and <code>NVTE_ALLOW_NONDETERMINISTIC_ALGO</code> are only supported in PyTorch, and will be added to JAX and PaddlePaddle in the future.\n",
"</div>\n",
"\n",
"### 2.3 Example Tests\n",
......@@ -290,7 +295,7 @@
},
{
"cell_type": "markdown",
"id": "3ad85b86",
"id": "ccd5650d",
"metadata": {},
"source": [
"## 3. Backend Support\n",
......@@ -311,7 +316,7 @@
},
{
"cell_type": "markdown",
"id": "37920af4",
"id": "8439b389",
"metadata": {},
"source": [
"### 3.1 QKV Layout\n",
......@@ -354,7 +359,7 @@
},
{
"cell_type": "markdown",
"id": "94c69fae",
"id": "0290f8e9",
"metadata": {},
"source": [
"### 3.2 Attention Mask\n",
......@@ -387,7 +392,7 @@
{
"cell_type": "code",
"execution_count": 6,
"id": "4c87df64",
"id": "b1b7cdd4",
"metadata": {},
"outputs": [
{
......@@ -408,7 +413,7 @@
},
{
"cell_type": "markdown",
"id": "5ec0c75d",
"id": "e045c284",
"metadata": {},
"source": [
"Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](../../tests/pytorch/fused_attention/test_fused_attn.py).\n",
......@@ -432,7 +437,7 @@
},
{
"cell_type": "markdown",
"id": "3f8f6f1c",
"id": "8b8a4e40",
"metadata": {},
"source": [
"### 3.4 FP8 Attention\n",
......
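The environment-variable controls described in the notebook cell above can be sketched as follows. This is an illustrative helper (the function name is made up), encoding the semantics from the cell: before cuDNN 9.0 the knob is `NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT`, and from cuDNN 9.0 it is `NVTE_ALLOW_NONDETERMINISTIC_ALGO`:

```python
import os

def request_deterministic_backward(cudnn_version):
    """Ask fused attention for the deterministic (workspace-optimization)
    backward path, given a (major, minor) cuDNN version tuple.

    Illustrative sketch of the controls in the notebook; PyTorch only.
    """
    if cudnn_version >= (9, 0):
        # 0 = disallow non-deterministic algorithms (workspace-opt path)
        os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
    else:
        # 1 = force the workspace-optimization (deterministic) path
        os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1"

request_deterministic_backward((9, 1))
```

Note that these variables must be set before the first attention call: as the test changes below show, the backend selection is cached and only refreshed when `_attention_backends["backend_selection_requires_update"]` is set.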
......@@ -22,6 +22,9 @@ from transformer_engine.pytorch.attention import (
get_attention_backend,
_flash_attn_2_plus,
_flash_attn_2_3_plus,
check_set_window_size,
AttentionParams,
_attention_backends,
)
from transformer_engine.pytorch.constants import TE_DType
import transformer_engine.pytorch.cpp_extensions as ext
......@@ -84,6 +87,7 @@ class ModelConfig:
alibi_type: str = "none",
num_layers: int = 1,
bias_shape: str = "1hss",
window_size: Tuple[int, int] = (-1, -1),
):
self.batch_size = batch_size
self.num_heads = num_heads
......@@ -100,6 +104,7 @@ class ModelConfig:
self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
self.num_layers = num_layers
self.bias_shape = bias_shape
self.window_size = window_size
def _get_attention_backends(
......@@ -118,6 +123,8 @@ def _get_attention_backends(
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "1"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
alibi_slopes_shape = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
......@@ -136,33 +143,10 @@ def _get_attention_backends(
fused_attn_backends = []
available_backends = None
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
_, _, _, available_backends, fused_attention_backend = get_attention_backend(
qkv_dtype=qkv_dtype,
qkv_layout=qkv_layout,
batch_size=config.batch_size,
num_heads=config.num_heads,
num_gqa_groups=config.num_gqa_groups,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
head_dim=config.head_dim,
attn_mask_type=config.attn_mask_type,
window_size=window_size,
alibi_slopes_shape=alibi_slopes_shape,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias_shape=core_attention_bias_shape,
core_attention_bias_requires_grad=core_attention_bias_requires_grad,
pad_between_seqs=pad_between_seqs,
attention_dropout=config.dropout_p,
context_parallel=context_parallel,
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
)
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
fused_attn_backends.append(fused_attention_backend)
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
_, _, _, available_backends, fused_attention_backend = get_attention_backend(
fused_attention_backend = None
def test():
attention_params = AttentionParams(
qkv_dtype=qkv_dtype,
qkv_layout=qkv_layout,
batch_size=config.batch_size,
......@@ -184,13 +168,18 @@ def _get_attention_backends(
fp8=fp8,
fp8_meta=fp8_meta,
)
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
_, _, fused_attention_backend, _, available_backends = get_attention_backend(
attention_params
)
return available_backends, fused_attention_backend
backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
for i in range(3):
os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
_attention_backends["backend_selection_requires_update"] = True
available_backends, fused_attention_backend = test()
if fused_attention_backend == FusedAttnBackend[backends[i]]:
fused_attn_backends.append(fused_attention_backend)
elif (
fused_attention_backend != FusedAttnBackend["No_Backend"]
and fused_attention_backend is not None
):
fused_attn_backends.append(fused_attention_backend)
return available_backends, fused_attn_backends
......@@ -211,18 +200,6 @@ if is_bf16_compatible(): # bf16 requires sm_80 or higher
param_types_lean = [torch.bfloat16]
def get_swa(seq_q, seq_kv, w=None):
"""Generate a random sliding window size (left, right) if w is None,
and create its equivalent attention mask in [seq_q, seq_kv] shape"""
if w is None:
w = torch.randint(0, seq_kv, [2], dtype=torch.int32, device="cuda")
m = torch.ones(seq_q, seq_kv, dtype=torch.bool, device="cuda")
mu = torch.triu(m, diagonal=seq_kv - seq_q - w[0])
ml = torch.tril(mu, diagonal=seq_kv - seq_q + w[1])
ml = ~ml
return w, ml
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
......@@ -251,15 +228,22 @@ def test_dot_product_attention(
pytest.skip("No need to test this layout for cross attention")
# Test backend availability
window_size = (2, 2) if swa else (-1, -1)
window_size = (-1, -1)
if swa:
window_size = tuple(torch.randint(0, config.max_seqlen_kv, [2], dtype=torch.int32).tolist())
config.window_size = check_set_window_size(config.attn_mask_type, window_size)
available_backends, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=window_size,
window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
    # manually pads and unpads the input and output of FlashAttention for testing purposes
if pad_between_seqs:
flash_attn_supported = True
# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
......@@ -268,9 +252,6 @@ def test_dot_product_attention(
is_training = config.head_dim <= 128
# UnfusedDotProductAttention backend
if unfused_attn_supported:
if swa:
attn_mask_type = config.attn_mask_type
config.attn_mask_type = "arbitrary"
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
dtype,
config,
......@@ -278,12 +259,9 @@ def test_dot_product_attention(
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
if swa:
config.attn_mask_type = attn_mask_type
# FusedAttention backend
if fused_attn_supported:
......@@ -295,7 +273,6 @@ def test_dot_product_attention(
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
......@@ -308,7 +285,6 @@ def test_dot_product_attention(
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
......@@ -320,7 +296,6 @@ def test_dot_product_attention(
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
......@@ -334,7 +309,6 @@ def test_dot_product_attention(
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
......@@ -499,11 +473,19 @@ def test_dpa_bias_shapes(dtype, model_configs, model):
model_configs_swa = {
    # test: b, h, hg, d, sq, skv, p, mask, bias
"swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
"swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
"swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
"swa_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
"swa_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"),
"swa_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"swa_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
}
......@@ -622,7 +604,6 @@ def _run_dot_product_attention(
ckpt_attn: bool,
qkv_layout: str,
workspace_opt: bool,
swa: bool,
pad_between_seqs: bool,
is_training: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
......@@ -637,6 +618,8 @@ def _run_dot_product_attention(
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
# Create seqlens
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
......@@ -733,11 +716,6 @@ def _run_dot_product_attention(
attention_mask_q.to(device="cuda"),
attention_mask_kv.to(device="cuda"),
)
window_size = None
if swa:
window_size, attention_mask = get_swa(config.max_seqlen_q, config.max_seqlen_kv)
elif "causal" in config.attn_mask_type:
window_size, attention_mask = (-1, 0), None
alibi_slopes = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
......@@ -902,7 +880,7 @@ def _run_dot_product_attention(
q,
k,
v,
window_size=window_size,
window_size=config.window_size,
attention_mask=attention_mask,
qkv_format=qkv_format,
max_seqlen_q=config.max_seqlen_q,
......@@ -1121,6 +1099,8 @@ def _run_transformer_layer(
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
# Create input tensor
inp = torch.randn(
......@@ -1279,6 +1259,9 @@ def _rmse(a, b):
def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd):
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
config = model_configs_fp8_vs_f16[model]
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
......@@ -1457,6 +1440,9 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd):
pytest.skip("qkv_layout not applicable for MQA/GQA")
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(dtype, config, True, qkv_layout)
......@@ -1741,6 +1727,8 @@ def _run_custom_mha_fp8(dtype, config, backend):
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
inp = 0.0001 * torch.randint(
-100,
......@@ -1794,6 +1782,8 @@ def _run_ref_mha_f16(dtype, config, backend):
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
inp = torch.load("qkv.pt").to(device="cuda")
inp.requires_grad = True
......
......@@ -72,7 +72,8 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) {
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim) {
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left,
int64_t window_size_right) {
using namespace transformer_engine;
NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
const int device_id = cuda::current_device();
......@@ -116,7 +117,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(qkv_layout == NVTE_QKV_Layout::NVTE_SBHD_SB2HD) ||
(qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) ||
(qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) ||
(qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD))) {
(qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) &&
((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0))) {
flag_m512 = true;
}
if ( // architecture
......@@ -165,7 +167,19 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
((qkv_format == NVTE_QKV_Format::NVTE_SBHD) ||
(sm_arch_ >= 90 && cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups &&
qkv_format == NVTE_QKV_Format::NVTE_THD) ||
(qkv_format == NVTE_QKV_Format::NVTE_BSHD))) {
(qkv_format == NVTE_QKV_Format::NVTE_BSHD)) &&
// sliding window
((cudnn_runtime_version < 90200 && window_size_left == -1 &&
(window_size_right == -1 || window_size_right == 0)) ||
(cudnn_runtime_version >= 90200 &&
((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) ||
((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 &&
(attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
(attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
max_seqlen_q == max_seqlen_kv)) &&
dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
(qkv_format == NVTE_QKV_Format::NVTE_BSHD ||
qkv_format == NVTE_QKV_Format::NVTE_SBHD)))))) {
flag_arb = true;
}
if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) {
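The sliding-window condition added to `nvte_get_fused_attn_backend` above is dense; a hedged Python paraphrase of it (names illustrative, not part of the TE API) reads:

```python
def swa_ok_for_arbitrary_seqlen(cudnn_version, window, mask_type,
                                max_seqlen_q, max_seqlen_kv,
                                dropout, bias_type, qkv_format):
    """Paraphrase of the F16_arbitrary_seqlen sliding-window check.

    cudnn_version is the integer runtime version (e.g. 90200 for 9.2.0);
    window is (left, right) with -1 meaning unbounded.
    """
    left, right = window
    no_window = (left == -1) and (right in (-1, 0))
    if cudnn_version < 90200:
        # Before cuDNN 9.2.0 only the no-window / causal-diagonal cases pass.
        return no_window
    if no_window:
        return True
    # A real left window (left >= 0) needs right == 0, a causal-style mask
    # (bottom-right only when seqlens match), no dropout, no bias, and a
    # BSHD or SBHD layout.
    return (right == 0
            and (mask_type == "causal"
                 or (mask_type == "causal_bottom_right"
                     and max_seqlen_q == max_seqlen_kv))
            and dropout == 0.0
            and bias_type == "no_bias"
            and qkv_format in ("bshd", "sbhd"))
```

This mirrors the C++ predicate clause by clause; the authoritative logic is of course the diff itself.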
......@@ -213,6 +227,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTETensor rng_state, size_t max_seqlen, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
using namespace transformer_engine;
......@@ -242,9 +257,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type,
dropout, h, h, max_seqlen, max_seqlen, d);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
max_seqlen, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -259,8 +274,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd_qkvpacked(
b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, input_QKV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens,
input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
attn_mask_type, window_size_left, window_size_right, input_QKV, input_Bias, output_O,
Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace,
stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -286,7 +302,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
const NVTETensor cu_seqlens_padded, size_t max_seqlen,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream) {
int64_t window_size_left, int64_t window_size_right,
bool deterministic, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
using namespace transformer_engine;
......@@ -317,9 +334,9 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type,
dropout, h, h, max_seqlen, max_seqlen, d);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
max_seqlen, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -341,9 +358,10 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
}
fused_attn_arbitrary_seqlen_bwd_qkvpacked(
b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_QKV,
input_O, input_dO, input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens,
input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
window_size_left, window_size_right, deterministic, input_QKV, input_O, input_dO,
input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens, input_cu_seqlens_padded,
input_rng_state, wkspace, stream, handle);
#else
const char *err_msg =
"cuDNN 8.9.0 is required for BF16/FP16 fused attention "
......@@ -375,6 +393,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
using namespace transformer_engine;
......@@ -409,9 +428,9 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout,
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
max_seqlen_kv, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -426,9 +445,10 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
#if (CUDNN_VERSION >= 8903)
fused_attn_arbitrary_seqlen_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_KV,
input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
handle);
#else
NVTE_ERROR(
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -454,7 +474,8 @@ void nvte_fused_attn_bwd_kvpacked(
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream) {
int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
......@@ -491,9 +512,9 @@ void nvte_fused_attn_bwd_kvpacked(
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout,
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
max_seqlen_kv, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -517,10 +538,10 @@ void nvte_fused_attn_bwd_kvpacked(
}
fused_attn_arbitrary_seqlen_bwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, input_Q, input_KV, input_O, input_dO, input_Bias, output_S, output_dQ,
output_dKV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
handle);
attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, input_KV,
input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias,
input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else
const char *err_msg =
"cuDNN 8.9.3 is required for BF16/FP16 fused attention "
......@@ -553,7 +574,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream) {
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
......@@ -579,9 +601,9 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout,
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
max_seqlen_kv, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -596,9 +618,10 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_K, input_V,
input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -625,7 +648,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream) {
NVTE_Mask_Type attn_mask_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
......@@ -655,9 +680,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout,
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
max_seqlen_kv, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -681,10 +706,10 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
}
fused_attn_arbitrary_seqlen_bwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, input_K,
input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, output_dV,
output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else
const char *err_msg =
"cuDNN 8.9.0 is required for BF16/FP16 fused attention "
......
......@@ -50,8 +50,9 @@ namespace fused_attn {
void fused_attn_arbitrary_seqlen_fwd_impl(
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, int64_t bias_b,
int64_t bias_h, bool is_training, float scaling_factor, float dropout_probability,
NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void *devPtrQ,
void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO,
NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK,
void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO,
void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV,
cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
......@@ -63,6 +64,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK));
if (is_bottom_right && s_q == s_kv) {
is_causal = true;
is_bottom_right = false;
}
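The special case above rests on a simple identity: a bottom-right-aligned causal mask with equal query and key lengths is exactly the standard top-left causal mask, so the kernel can take the plain causal path. A small Python sketch of that equivalence (illustrative only, not part of the library):

```python
def causal_mask(s_q, s_kv):
    # Top-left aligned: query i attends to keys j <= i.
    return [[j <= i for j in range(s_kv)] for i in range(s_q)]

def causal_bottom_right_mask(s_q, s_kv):
    # Bottom-right aligned: the diagonal is shifted by s_kv - s_q.
    return [[j <= i + (s_kv - s_q) for j in range(s_kv)] for i in range(s_q)]

# With s_q == s_kv the shift vanishes and the two masks coincide,
# which is exactly the fallback taken in the code above.
assert causal_mask(8, 8) == causal_bottom_right_mask(8, 8)
assert causal_mask(4, 8) != causal_bottom_right_mask(4, 8)
```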
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (is_training && dropout_probability != 0.0f);
......@@ -70,6 +75,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
if (is_ragged) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
}
if (window_size_left == -1) {
window_size_left = s_q;
}
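The `-1` sentinel normalized above means "attend to the whole left context", so setting `window_size_left = s_q` makes the window cover every earlier position. A hedged Python sketch of the resulting mask semantics (the inclusive-bound, flash-attention-style convention is an assumption; the diff itself only shows the normalization):

```python
def sliding_window_causal_mask(s, window_size_left):
    # window_size_left == -1 means "attend to the whole left context";
    # normalize it to the sequence length, mirroring the guard above.
    if window_size_left == -1:
        window_size_left = s
    # Query i attends to keys j with i - window_size_left <= j <= i
    # (assumed inclusive bounds).
    return [[i - window_size_left <= j <= i for j in range(s)]
            for i in range(s)]

full = sliding_window_causal_mask(6, -1)
assert full == [[j <= i for j in range(6)] for i in range(6)]  # plain causal
windowed = sliding_window_causal_mask(6, 2)
assert windowed[5] == [False, False, False, True, True, True]
```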
auto cudnn_runtime_version = cudnnGetVersion();
try {
FADescriptor_v1 descriptor{b,
......@@ -86,6 +95,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
layout,
bias_type,
mask_type,
window_size_left,
window_size_right,
true,
tensorType,
tensorType};
......@@ -208,6 +220,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
.set_causal_mask_bottom_right(is_bottom_right)
.set_attn_scale(attn_scale);
if (cudnn_runtime_version >= 90200 && window_size_left != s_q) {
sdpa_options.set_sliding_window_length(window_size_left);
}
sdpa_options.set_alibi_mask(is_alibi);
if (is_bias) {
......@@ -367,7 +383,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
void fused_attn_arbitrary_seqlen_bwd_impl(
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, int64_t bias_b,
int64_t bias_h, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void *devPtrQ, void *devPtrKTranspose,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, void *devPtrQ, void *devPtrKTranspose,
void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias,
void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias,
void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
......@@ -381,10 +398,20 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK));
if (is_bottom_right && s_q == s_kv) {
is_causal = true;
is_bottom_right = false;
}
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (dropout_probability != 0.0f);
bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD);
if (window_size_left == -1) {
window_size_left = s_q;
}
auto cudnn_runtime_version = cudnnGetVersion();
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
try {
FADescriptor_v1 descriptor{b,
......@@ -401,6 +428,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
layout,
bias_type,
mask_type,
window_size_left,
window_size_right,
deterministic,
tensorType,
tensorType};
......@@ -552,6 +582,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
.set_causal_mask_bottom_right(is_bottom_right)
.set_attn_scale(attn_scale);
if (cudnn_runtime_version >= 90200 && window_size_left != s_q) {
sdpa_backward_options.set_sliding_window_length(window_size_left);
}
if (cudnn_runtime_version >= 90000 && sm_arch_ >= 90) {
sdpa_backward_options.set_deterministic_algorithm(deterministic);
}
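The two guards above gate the new features by cuDNN version and GPU architecture: sliding windows need cuDNN 9.2.0+, while the deterministic backward needs cuDNN 9.0.0+ on sm90 (Hopper) or newer. A small sketch of that gating logic (the function name is hypothetical; thresholds are taken from the guards above):

```python
def enabled_features(cudnn_version, sm_arch, window_size_left, s_q):
    # Mirrors the two guards above: sliding windows need cuDNN 9.2+,
    # the deterministic backward needs cuDNN 9.0+ on sm90 or newer.
    features = []
    if cudnn_version >= 90200 and window_size_left != s_q:
        features.append("sliding_window")
    if cudnn_version >= 90000 and sm_arch >= 90:
        features.append("deterministic")
    return features

assert enabled_features(90200, 90, 64, 128) == ["sliding_window", "deterministic"]
assert enabled_features(90000, 90, 128, 128) == ["deterministic"]
assert enabled_features(8907, 80, 64, 128) == []
```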
sdpa_backward_options.set_alibi_mask(is_alibi);
if (is_bias) {
......@@ -746,7 +784,8 @@ using namespace transformer_engine::fused_attn;
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
......@@ -827,9 +866,9 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
fused_attn_arbitrary_seqlen_fwd_impl(
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, bias_b, bias_h,
is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left,
window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
......@@ -850,6 +889,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic,
const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
......@@ -902,11 +942,11 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
fused_attn_arbitrary_seqlen_bwd_impl(
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, bias_b, bias_h,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ,
devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -926,7 +966,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
......@@ -1011,10 +1052,11 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
fused_attn_arbitrary_seqlen_fwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h,
is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left,
window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ,
devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size,
stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1035,11 +1077,12 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
......@@ -1089,11 +1132,11 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
fused_attn_arbitrary_seqlen_bwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ,
devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1110,17 +1153,15 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
}
}
void fused_attn_arbitrary_seqlen_fwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
......@@ -1193,10 +1234,11 @@ void fused_attn_arbitrary_seqlen_fwd(size_t batch, size_t num_attn_heads, size_t
fused_attn_arbitrary_seqlen_fwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h,
is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left,
window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ,
devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size,
stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1217,9 +1259,10 @@ void fused_attn_arbitrary_seqlen_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK,
Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
......@@ -1260,11 +1303,11 @@ void fused_attn_arbitrary_seqlen_bwd(
fused_attn_arbitrary_seqlen_bwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ,
devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......
......@@ -21,13 +21,15 @@ namespace transformer_engine {
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic,
const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
......@@ -37,7 +39,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
......@@ -46,31 +49,31 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK,
Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
......
......@@ -1687,6 +1687,9 @@ void fused_attn_fp8_fwd_impl_v1(
layout,
bias_type,
mask_type,
0,
0,
true,
fwd_tensor_type,
fwd_tensor_type};
......@@ -1981,6 +1984,9 @@ void fused_attn_fp8_bwd_impl_v1(
layout,
bias_type,
mask_type,
0,
0,
false,
fwd_tensor_type,
bwd_tensor_type};
......
......@@ -100,16 +100,20 @@ struct FADescriptor_v1 {
NVTE_QKV_Layout layout;
NVTE_Bias_Type bias_type;
NVTE_Mask_Type mask_type;
std::int64_t window_size_left;
std::int64_t window_size_right;
bool deterministic;
cudnn_frontend::DataType_t fwd_tensor_type;
cudnn_frontend::DataType_t bwd_tensor_type;
bool operator<(const FADescriptor_v1 &rhs) const {
return std::tie(b, h, hg, s_q, s_kv, d, bias_b, bias_h, attnScale, isTraining,
dropoutProbability, layout, mask_type, window_size_left, window_size_right,
deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) <
std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d, rhs.bias_b, rhs.bias_h,
rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout,
rhs.mask_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic,
rhs.bias_type, rhs.fwd_tensor_type, rhs.bwd_tensor_type);
}
};
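The new `window_size_left`, `window_size_right`, and `deterministic` fields are added to the `operator<` tuple because compiled execution plans are cached by this descriptor: any field that changes the compiled graph must participate in the ordering, or a cached plan could be reused with the wrong window size or determinism setting. A minimal Python sketch of the same idea (`build_plan` is a hypothetical stand-in, not a library API):

```python
from functools import lru_cache

# Hypothetical stand-in for an execution-plan cache keyed on descriptor
# fields: omitting window_size or deterministic from the key would
# silently return a plan built for different settings.
@lru_cache(maxsize=None)
def build_plan(b, h, s_q, s_kv, window_size_left, window_size_right,
               deterministic):
    return object()  # placeholder for a compiled plan

p_full = build_plan(2, 16, 128, 128, -1, 0, False)
p_swa = build_plan(2, 16, 128, 128, 64, 0, False)
assert p_full is not p_swa                                  # distinct plans
assert p_full is build_plan(2, 16, 128, 128, -1, 0, False)  # cache hit
```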
......
......@@ -11,6 +11,8 @@
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_
#include <cstdint>
#include "transformer_engine.h"
#ifdef __cplusplus
......@@ -135,22 +137,25 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout);
/*! \brief Get fused attention backend based on input parameters.
*
* \param[in] q_dtype The data type of Tensor Q.
* \param[in] kv_dtype The data type of Tensors K, V.
* \param[in] qkv_layout The layout of Tensors Q, K, V.
* \param[in] bias_type The attention bias type.
* \param[in] attn_mask_type The attention mask type.
* \param[in] dropout The dropout probability.
* \param[in] num_attn_heads The number of heads in Q.
* \param[in] num_gqa_groups The number of heads in K, V.
* \param[in] max_seqlen_q The sequence length of Q.
* \param[in] max_seqlen_kv The sequence length of K, V.
* \param[in] head_dim The head dimension of Q, K, V.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
*/
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left,
int64_t window_size_right);
/*! \brief Compute dot product attention with packed QKV input.
*
......@@ -197,6 +202,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
......@@ -206,6 +213,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTETensor rng_state, size_t max_seqlen, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed QKV input.
......@@ -248,6 +256,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
......@@ -258,7 +269,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
const NVTETensor cu_seqlens_padded, size_t max_seqlen,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right,
bool deterministic, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute dot product attention with packed KV input.
*
......@@ -310,6 +322,9 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
......@@ -321,6 +336,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed KV input.
......@@ -369,6 +385,9 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
......@@ -379,7 +398,8 @@ void nvte_fused_attn_bwd_kvpacked(
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute dot product attention with separate Q, K and V.
*
......@@ -435,6 +455,8 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
......@@ -446,7 +468,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with separate Q, K and V.
*
......@@ -499,6 +522,9 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
 * \param[in] deterministic Whether to execute with deterministic behaviour.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
......@@ -510,7 +536,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream);
NVTE_Mask_Type attn_mask_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, NVTETensor workspace,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
......
......@@ -162,5 +162,6 @@ class DelayedScaling:
f"format={str(self.fp8_format).split('.')[1]}, "
f"amax_history_len={self.amax_history_len}, "
f"wgrad_override={self.override_linear_precision.wgrad}, "
f"reduce_amax={self.reduce_amax}"
f"fp8_dpa={self.fp8_dpa}, "
f"fp8_mha={self.fp8_mha}"
)
......@@ -19,7 +19,7 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen,
head_dim);
head_dim, -1, -1);
return backend;
}
......@@ -154,22 +154,22 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, query_workspace_tensor.data(), nullptr);
mask_type, -1, -1, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, query_workspace_tensor.data(), nullptr);
bias_type, mask_type, -1, -1, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
ragged_offset_tensor.data(), ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
-1, query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported QKVLayout.");
}
......@@ -258,7 +258,7 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto backend =
nvte_get_fused_attn_backend(static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype),
qkv_layout, bias_type, mask_type, dropout_probability, attn_heads,
num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim);
num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim, -1, -1);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
/* Auxiliary tensors (to be propagated to the backward pass later) */
......@@ -277,11 +277,12 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto qkv = buffers[0];
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, is_training, descriptor.scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, is_training, descriptor.scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, -1, -1,
workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
......@@ -294,7 +295,7 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
bias_type, mask_type, -1, -1, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
......@@ -305,12 +306,13 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto v = buffers[2];
auto v_shape = k_shape;
auto v_tensor = TensorWrapper(v, v_shape, dtype);
nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
-1, workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
......@@ -373,13 +375,13 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, query_workspace_tensor.data(), nullptr);
nvte_fused_attn_bwd_qkvpacked(
qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, -1, -1, true, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
......@@ -388,8 +390,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen,
kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
-1, true, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
......@@ -399,8 +401,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
-1, true, query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
......@@ -487,7 +489,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto backend =
nvte_get_fused_attn_backend(static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype),
qkv_layout, bias_type, mask_type, dropout_probability, attn_heads,
num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim);
num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim, -1, -1);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
rng_state, bias);
......@@ -509,13 +511,13 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::accumulate(qkv_shape.cbegin(), qkv_shape.cend(), 1, std::multiplies<size_t>());
cudaMemsetAsync(dqkv, 0, dqkv_size * typeToSize(dtype), stream);
}
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
nvte_fused_attn_bwd_qkvpacked(
qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, -1, -1, true, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
......@@ -542,7 +544,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, true,
workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
......@@ -577,7 +580,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type,
dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, true,
workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
......
......@@ -134,7 +134,7 @@ inline NVTE_Fused_Attn_Backend get_fused_attn_backend(
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype),
qkv_layout, bias_type, attn_mask_type, p_dropout, num_attn_heads,
num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim);
num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, -1, -1);
return fused_attention_backend;
}
......
......@@ -673,7 +673,7 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
&nvte_aux_tensor_pack, te_cu_seqlens.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), QKV.stream());
attn_mask_type_enum, -1, -1, workspace.data(), QKV.stream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place());
......@@ -687,7 +687,7 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
&nvte_aux_tensor_pack, te_cu_seqlens.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), QKV.stream());
attn_mask_type_enum, -1, -1, workspace.data(), QKV.stream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -758,7 +758,7 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
&nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(),
te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), QKV.stream());
attn_mask_type_enum, -1, -1, true, workspace.data(), QKV.stream());
// allocate memory for workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place());
......@@ -769,7 +769,7 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
&nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(),
te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), QKV.stream());
attn_mask_type_enum, -1, -1, true, workspace.data(), QKV.stream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -840,12 +840,12 @@ void te_fused_attn_fwd_kvpacked(
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(),
&nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q,
max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
nvte_fused_attn_fwd_kvpacked(
te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1,
workspace.data(), Q.stream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
......@@ -855,12 +855,12 @@ void te_fused_attn_fwd_kvpacked(
output_s->data.dptr = GetOptionalDataPtr(softmax_aux);
// execute the kernel
nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(),
&nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q,
max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
nvte_fused_attn_fwd_kvpacked(
te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1,
workspace.data(), Q.stream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -935,24 +935,24 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
nvte_fused_attn_bwd_kvpacked(
te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
-1, -1, true, workspace.data(), Q.stream());
// allocate memory for workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
// execute kernel
nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
nvte_fused_attn_bwd_kvpacked(
te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
-1, -1, true, workspace.data(), Q.stream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -1042,7 +1042,7 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale,
p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1,
workspace.data(), Q.stream());
// allocate memory for workspace and auxiliary output tensors
......@@ -1058,7 +1058,7 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale,
p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1,
workspace.data(), Q.stream());
// destroy tensor wrappers, but not allocated memory
......@@ -1140,7 +1140,7 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), Q.stream());
attn_mask_type_enum, -1, -1, true, workspace.data(), Q.stream());
// allocate memory for workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
......@@ -1152,7 +1152,7 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), Q.stream());
attn_mask_type_enum, -1, -1, true, workspace.data(), Q.stream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......
......@@ -100,6 +100,7 @@ def fused_attn_fwd_qkvpacked(
qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
window_size: Tuple[int, int] = (-1, -1),
rng_gen: torch.Generator = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention FWD for packed QKV input.
......@@ -152,6 +153,11 @@ def fused_attn_fwd_qkvpacked(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where the query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]], inclusive. The special cases (-1, -1) and (-1, 0) denote no
sliding window and a causal mask, respectively.
rng_gen: torch.Generator, default = None
random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
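The window semantics in the docstring above can be sketched in plain Python. This is an illustrative helper, not part of the library's API: it builds the boolean attention mask implied by `window_size`, treating `-1` on either side as unbounded.

```python
def sliding_window_mask(seqlen_q, seqlen_k, window_size):
    """Boolean mask where True marks key positions a query may attend to.

    Query i attends to keys j with
      i + seqlen_k - seqlen_q - window_size[0] <= j
        <= i + seqlen_k - seqlen_q + window_size[1]
    (inclusive); -1 on either side means unbounded in that direction.
    """
    left, right = window_size
    mask = []
    for i in range(seqlen_q):
        diag = i + seqlen_k - seqlen_q  # key position aligned with query i
        lo = 0 if left < 0 else max(0, diag - left)
        hi = seqlen_k - 1 if right < 0 else min(seqlen_k - 1, diag + right)
        mask.append([lo <= j <= hi for j in range(seqlen_k)])
    return mask

# The two special cases from the docstring:
full = sliding_window_mask(4, 4, (-1, -1))   # no sliding window: all True
causal = sliding_window_mask(4, 4, (-1, 0))  # causal: query i sees keys j <= i
```

Note that `(-1, 0)` reproduces exactly the lower-triangular causal mask, which is why it is singled out as a special case.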
......@@ -236,6 +242,7 @@ def fused_attn_fwd_qkvpacked(
QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type],
window_size,
cu_seqlens,
qkv,
qkv_dtype,
......@@ -282,6 +289,8 @@ def fused_attn_bwd_qkvpacked(
qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
window_size: Tuple[int, int] = (-1, -1),
deterministic: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention BWD for packed QKV input.
......@@ -346,6 +355,13 @@ def fused_attn_bwd_qkvpacked(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where the query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]], inclusive. The special cases (-1, -1) and (-1, 0) denote no
sliding window and a causal mask, respectively.
deterministic: bool, default = False
whether to execute the backward pass with deterministic behaviour.
Returns
----------
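The `deterministic` flag exists because floating-point accumulation is order-sensitive: gradient reductions that run in a data-dependent order (e.g. via atomics) can differ bitwise from run to run. A small self-contained Python illustration of the underlying effect (not library code):

```python
# Floating-point addition is not associative: summing the same values in a
# different order can change the result, which is why parallel gradient
# accumulation is nondeterministic unless the reduction order is fixed.
vals = [1e16, 1.0, -1e16, 1.0]

left_to_right = 0.0
for v in vals:
    left_to_right += v  # 1e16 + 1.0 rounds back to 1e16, so one 1.0 survives

reordered = 0.0
for v in sorted(vals):  # [-1e16, 1.0, 1.0, 1e16]
    reordered += v      # both 1.0 terms are absorbed, leaving 0.0

assert left_to_right != reordered  # same inputs, different sums
```

A deterministic backward pass fixes the accumulation order so repeated runs with identical inputs produce bitwise-identical gradients, typically at some performance cost.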
......@@ -393,6 +409,8 @@ def fused_attn_bwd_qkvpacked(
QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type],
window_size,
deterministic,
cu_seqlens,
qkv,
o,
......@@ -441,6 +459,7 @@ def fused_attn_fwd_kvpacked(
qkv_layout: str = "sbhd_sbh2d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
window_size: Tuple[int, int] = (-1, -1),
rng_gen: torch.Generator = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention FWD for packed KV input.
......@@ -503,6 +522,11 @@ def fused_attn_fwd_kvpacked(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where the query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]], inclusive. The special cases (-1, -1) and (-1, 0) denote no
sliding window and a causal mask, respectively.
rng_gen: torch.Generator, default = None
random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
......@@ -588,6 +612,7 @@ def fused_attn_fwd_kvpacked(
QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type],
window_size,
cu_seqlens_q,
cu_seqlens_kv,
q,
......@@ -641,6 +666,8 @@ def fused_attn_bwd_kvpacked(
qkv_layout: str = "sbhd_sbh2d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
window_size: Tuple[int, int] = (-1, -1),
deterministic: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention BWD for packed KV input.
......@@ -716,6 +743,13 @@ def fused_attn_bwd_kvpacked(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where the query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]], inclusive. The special cases (-1, -1) and (-1, 0) denote no
sliding window and a causal mask, respectively.
deterministic: bool, default = False
whether to execute the backward pass with deterministic behaviour.
Returns
----------
......@@ -766,6 +800,8 @@ def fused_attn_bwd_kvpacked(
QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type],
window_size,
deterministic,
cu_seqlens_q,
cu_seqlens_kv,
q,
......@@ -818,6 +854,7 @@ def fused_attn_fwd(
qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
window_size: Tuple[int, int] = (-1, -1),
rng_gen: torch.Generator = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention FWD for separate QKV input.
......@@ -886,6 +923,11 @@ def fused_attn_fwd(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where the query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]], inclusive. The special cases (-1, -1) and (-1, 0) denote no
sliding window and a causal mask, respectively.
rng_gen: torch.Generator, default = None
random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
......@@ -971,6 +1013,7 @@ def fused_attn_fwd(
QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type],
window_size,
cu_seqlens_q,
cu_seqlens_kv,
q,
......@@ -1026,6 +1069,8 @@ def fused_attn_bwd(
qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
window_size: Tuple[int, int] = (-1, -1),
deterministic: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention BWD for packed KV input.
......@@ -1106,6 +1151,13 @@ def fused_attn_bwd(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where the query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]], inclusive. The special cases (-1, -1) and (-1, 0) denote no
sliding window and a causal mask, respectively.
deterministic: bool, default = False
whether to execute the backward pass with deterministic behaviour.
Returns
----------
......@@ -1158,6 +1210,8 @@ def fused_attn_bwd(
QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type],
window_size,
deterministic,
cu_seqlens_q,
cu_seqlens_kv,
q,
......
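Throughout the wrappers above, the Python-side `window_size` tuple is flattened into the two `int64_t` C parameters, with `-1` meaning unbounded on that side. A hypothetical normalization helper sketching that convention (the name and exact behaviour are illustrative, not the library's actual implementation):

```python
def normalize_window_size(window_size, attn_mask_type):
    """Map a user-facing window_size tuple to (window_size_left, window_size_right).

    Hypothetical helper: None defaults to no sliding window, and causal-style
    mask types force the right half of the window to 0 so that no query
    attends to keys past its own diagonal position.
    """
    if window_size is None:
        window_size = (-1, -1)
    left, right = int(window_size[0]), int(window_size[1])
    if "causal" in attn_mask_type:
        right = 0  # causal masks never look ahead
    return left, right

normalize_window_size(None, "no_mask")       # (-1, -1): no sliding window
normalize_window_size((128, 128), "causal")  # (128, 0): causal local attention
```

This mirrors how the C entry points in this diff receive `-1, -1` from callers that do not use sliding windows.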
......@@ -14,30 +14,29 @@
* Attention
**************************************************************************************************/
NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q_dtype,
const transformer_engine::DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, float p_dropout,
size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv,
size_t head_dim);
NVTE_Fused_Attn_Backend get_fused_attn_backend(
const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left, int64_t window_size_right);
std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> cu_seqlens_padded, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O, c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O, const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
const std::vector<int64_t> window_size, const at::Tensor cu_seqlens, const at::Tensor QKV,
const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> cu_seqlens_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias, const c10::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread);
std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens,
const at::Tensor QKV, const at::Tensor O, const at::Tensor dO,
const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
bool deterministic, const at::Tensor cu_seqlens, const at::Tensor QKV, const at::Tensor O,
const at::Tensor dO, const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> cu_seqlens_padded, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP,
@@ -48,8 +47,9 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
 std::vector<at::Tensor> fused_attn_fwd_kvpacked(
     size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
     bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv,
-    const at::Tensor Q, const at::Tensor KV, const transformer_engine::DType qkv_type,
+    NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
+    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
+    const at::Tensor KV, const transformer_engine::DType qkv_type,
     const c10::optional<at::Tensor> cu_seqlens_q_padded,
     const c10::optional<at::Tensor> cu_seqlens_kv_padded,
     const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
@@ -61,10 +61,10 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
 std::vector<at::Tensor> fused_attn_bwd_kvpacked(
     size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
-    const at::Tensor KV, const at::Tensor O, const at::Tensor dO,
-    const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
-    const std::vector<at::Tensor> Aux_CTX_Tensors,
+    const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
+    const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor KV, const at::Tensor O,
+    const at::Tensor dO, const transformer_engine::DType qkv_type,
+    const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
     const c10::optional<at::Tensor> cu_seqlens_q_padded,
     const c10::optional<at::Tensor> cu_seqlens_kv_padded,
     const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
@@ -76,9 +76,10 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
 std::vector<at::Tensor> fused_attn_fwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
     bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv,
-    const at::Tensor Q, const at::Tensor K, const at::Tensor V,
-    const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> cu_seqlens_q_padded,
+    NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
+    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
+    const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type,
+    const c10::optional<at::Tensor> cu_seqlens_q_padded,
     const c10::optional<at::Tensor> cu_seqlens_kv_padded,
     const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
     const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
@@ -89,10 +90,10 @@ std::vector<at::Tensor> fused_attn_fwd(
 std::vector<at::Tensor> fused_attn_bwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
-    const at::Tensor K, const at::Tensor V, const at::Tensor O, const at::Tensor dO,
-    const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
-    const std::vector<at::Tensor> Aux_CTX_Tensors,
+    const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
+    const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor K, const at::Tensor V,
+    const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type,
+    const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
     const c10::optional<at::Tensor> cu_seqlens_q_padded,
     const c10::optional<at::Tensor> cu_seqlens_kv_padded,
     const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
@@ -21,18 +21,18 @@ THREADS_PER_BLOCK = 128
 _default_causal_mask = {}

-def _get_default_causal_mask(sq: int, sk: int) -> torch.Tensor:
+def _get_default_causal_mask(mask_type: str, sq: int, sk: int) -> torch.Tensor:
     """Return the causal upper triangular mask for softmax input"""
     if sq == 1:
         return torch.zeros((1, sk), dtype=torch.bool, device="cuda")
-    matrix_shape = (sq, sk)
-    if matrix_shape not in _default_causal_mask:
-        diagonal_offset = sk - sq + 1
-        _default_causal_mask[matrix_shape] = torch.triu(
+    matrix_identifiers = (mask_type, sq, sk)
+    if matrix_identifiers not in _default_causal_mask:
+        diagonal_offset = sk - sq + 1 if "bottom_right" in mask_type else 1
+        _default_causal_mask[matrix_identifiers] = torch.triu(
             torch.ones(sq, sk, dtype=torch.bool, device="cuda"), diagonal=diagonal_offset
         )
-    return _default_causal_mask[matrix_shape]
+    return _default_causal_mask[matrix_identifiers]
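The cache above now keys on the mask type because the diagonal offset differs between top-left (`causal`) and bottom-right (`causal_bottom_right`) alignment. The helper below is an illustrative, torch-free sketch of that offset logic (the name `causal_mask` is ours, not from the PR); `True` marks positions to mask out, matching what `torch.triu(ones, diagonal=offset)` produces:

```python
def causal_mask(mask_type: str, sq: int, sk: int) -> list:
    """Boolean (sq, sk) mask; True = masked out (strictly above the diagonal)."""
    # Bottom-right alignment anchors the causal diagonal at the bottom-right
    # corner, which only matters when sq != sk; plain "causal" is top-left
    # aligned, i.e. query i sees keys 0..i.
    offset = sk - sq + 1 if "bottom_right" in mask_type else 1
    # Equivalent to torch.triu(torch.ones(sq, sk, dtype=bool), diagonal=offset)
    return [[(k - q) >= offset for k in range(sk)] for q in range(sq)]
```

For example, with `sq=2, sk=4`, `causal_bottom_right` lets the last query see every key, while the first query is cut off at key `sk - sq`.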
def _get_onnx_export_causal_mask(
@@ -332,10 +332,8 @@ class FusedScaleMaskSoftmax(nn.Module):
         if self.attn_mask_type == "arbitrary":
             return False  # Custom masks not supported
-        if self.attn_mask_type == "causal_bottom_right" or (
-            self.attn_mask_type == "causal" and sq == sk
-        ):  # fused causal softmax kernel
-            return True
+        if self.attn_mask_type == "causal" and sq != sk:
+            return False  # Fused causal kernel only supports causal_bottom_right
         if (
             sq % 4 == 0  # sq must be divisible by 4
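The availability check above can be read as a predicate. Here is a hedged sketch covering only the conditions visible in this hunk (the actual method checks further constraints that are truncated here, such as dtype and kernel size limits, so treat this as a partial illustration):

```python
def fused_causal_kernel_ok(mask_type: str, sq: int, sk: int) -> bool:
    """Partial sketch of the fused-softmax availability check in this hunk."""
    if mask_type == "arbitrary":
        return False  # custom masks are not supported by the fused kernel
    if mask_type == "causal" and sq != sk:
        return False  # fused kernel is bottom-right aligned; top-left 'causal'
                      # only coincides with it on square (sq == sk) matrices
    return sq % 4 == 0  # plus alignment constraints (incomplete list)
```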
@@ -387,7 +385,7 @@ class FusedScaleMaskSoftmax(nn.Module):
             assert self.kvcache_max_seq >= seq_len_k
             mask = _get_onnx_export_causal_mask(seq_len_q, seq_len_k, self.onnx_causal_mask)
         else:
-            mask = _get_default_causal_mask(seq_len_q, seq_len_k)
+            mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k)

         mask_output = inp
         if mask is not None and self.attn_mask_type != "no_mask":
@@ -142,9 +142,11 @@ class TransformerLayer(torch.nn.Module):
             sliding window size for local attention in encoder, where query at position i
             attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k
             - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean
-            no sliding window and "`causal`" mask specifically. Similar to
-            :attr:`self_attn_mask_type`, it can be overridden by :attr:`window_size`
-            in `forward` as well.
+            no sliding window and causal mask specifically. Both `causal` and
+            `causal_bottom_right` masks map to `window_size = (-1, 0)`, and Transformer Engine
+            distinguishes them based on `self_attn_mask_type` or `enc_dec_attn_mask_type`.
+            Similar to :attr:`self_attn_mask_type`, `window_size` can be overridden by
+            the :attr:`window_size` argument of `forward` as well.
     enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
                             default = `no_mask`
         type of attention mask passed into softmax operation for decoder.
......
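The window formula in the docstring above can be sanity-checked with a small helper (hypothetical, not part of Transformer Engine): it returns the inclusive key range visible to query `i` and confirms that `(-1, 0)` reduces to a (bottom-right-aligned) causal mask and `(-1, -1)` to no masking at all:

```python
def swa_key_range(i, seqlen_q, seqlen_kv, window_size):
    """Inclusive [low, high] key positions query i may attend to,
    following the docstring formula; -1 means unbounded on that side."""
    left, right = window_size
    diag = i + seqlen_kv - seqlen_q          # bottom-right-aligned diagonal
    low = 0 if left < 0 else max(0, diag - left)
    high = seqlen_kv - 1 if right < 0 else min(seqlen_kv - 1, diag + right)
    return low, high
```

With `window_size=(-1, 0)` the upper bound becomes `i + seqlen_kv - seqlen_q`, exactly the causal cut-off; a finite `window_size[0]` additionally trims how far back each query may look.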