Unverified commit 32db3928, authored by cyanguwa and committed by GitHub

Integrate cuDNN frontend v1 into fused attention (#497)



* Integrate cuDNN frontend v1 into fused attention, plus miscellaneous fixes

* fix lint

* fix jax/paddle for unit tests

* fix jax/pytorch lint

* simplify stride generation

* fix and/or logic in get_backend

* fix flag_max512 and test_numerics

* remove v.contiguous() since get_qkv_layout covers it

* skip fp8 tests for sm89

* further fix jax CI

* fix jax CI

* revert mask type to comma-separated list

* fix lint

* fix last two commits

* integrate v1/pre-release-5

* cleanup prerelease5 integration and fix FA2.1 commit

* force dropout to 0 if not training

* fix Jax CI

* testing bias/alibi and padding+causal; add alibi to unfused DPA

* set flag_arb to false when non-determinism is not allowed

* followup on prev commit; remove redundant python env var setting

* WIP: minor tweaks for tests

* prepare for tests

* fix determinism logic for fused attn

* add bias to bwd

* fix gpt_checkpointing/dpa_accuracy problem

* fix some seg fault issues

* add failure notes

* remove use of non-deter var for backend selection

* minor fix for lint and CI

* fix workspace size in bwd and uncomment bias test

* fix get_alibi and remove check_support

* update tests status

* remove workspace_opt from FADescriptor_v1

* disable arbitrary backend + post scale bias in Jax; waiting on PR 525

* clean up bhsd order

* swap bias/rng_state order in aux_ctx_tensor and add bias to aux_ctx_tensor in _qkvpacked/_kvpacked API

* remove support for padding_causal + cross for max512

* change alibi bias to float32 for bias_1_4/5 tests

* further clean up tests

* fix thd fwd output shape for FlashAttention and add backend info for DPA

* fix definition of workspace limit when dbias is present

* further tweak DP_WORKSPACE_LIMIT definition

* disallow alibi+no_mask for sdpa flash and update alibi tests

* update jax/paddle after PR525 and fix DP_WORKSPACE_LIMIT for dbias Jax tests

* disable dbias for non-hopper archs

* fix layernorm lint

* remove unused arg for lint

* remove build dir in setup.py

* change selection logic to prefer fused attn on sm90

* fix distributed jax test

* fix h and s order in header

* update to cudnn fe v1 branch

* remove manual setting of workopt path due to dbias after v1 update

* fix paddle CI

* add post_scale_bias and alibi to sdpa flash support matrix

* fix support matrix in header files

* move headers back to .cu and change seed/offset to int64

* update Megatron commit in L1 test and remove all prints in fused attn test

* fix L1 Megatron test

* fix fp8 arg in L1 Megatron script

* print only when debug flag is on

* remove checkpointing loading to avoid loading other tests results

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
parent ff760a9d
Subproject commit 12f35fa2be5994c1106367cac2fba21457b064f4
Subproject commit 9f82dda5c029d15a5f371f0fe003dc0c74a0c987
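The central API change in this PR is that fused-attention backend selection now takes the number of attention heads and the number of GQA groups in addition to the sequence lengths and head size. A minimal sketch of the PyTorch-side query, with illustrative values only; the import paths are assumptions based on the enum mappings used in the tests below, not a verbatim excerpt of this diff:

    import torch
    import transformer_engine_extensions as tex
    from transformer_engine.pytorch.constants import TE_DType              # path assumed
    from transformer_engine.pytorch.cpp_extensions.fused_attn import (     # path assumed
        AttnBiasType, AttnMaskType, FusedAttnBackend, QKVLayout)

    # Hypothetical attention configuration, for illustration only
    backend = tex.get_fused_attn_backend(
        TE_DType[torch.float16], TE_DType[torch.float16],   # Q and KV dtypes
        QKVLayout["sb3hd"], AttnBiasType["no_bias"], AttnMaskType["causal"],
        0.0,            # dropout probability
        16, 16,         # num_attn_heads, num_gqa_groups (new arguments)
        2048, 2048,     # max_seqlen_q, max_seqlen_kv
        64,             # head_dim
    )
    fused_attn_supported = backend != FusedAttnBackend["No_Backend"]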
......@@ -8,6 +8,7 @@ set -e
git clone https://github.com/NVIDIA/Megatron-LM.git
cd Megatron-LM
git checkout f24fac4ed0dcf0522056521a93445d9a82f501a9
git checkout bcce6f54e075e3c3374ea67adefe54f3f2da2b07
sed -i -e '1504,1505d' megatron/model/transformer.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_convergence.py
python $TE_PATH/tests/pytorch/distributed/print_logs.py
......@@ -77,10 +77,11 @@ class TestDistributedSelfAttn:
is_training = True
scaling_factor = 1.0
_, seqlen, _, _, hidden = data_shape
_, seqlen, _, num_head, hidden = data_shape
if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type,
attn_mask_type, dropout_prob, seqlen, seqlen, hidden):
attn_mask_type, dropout_prob, num_head, num_head,
seqlen, seqlen, hidden):
pytest.skip(f"No FusedAttn backwend found")
def target_func(qkv, bias, mask):
......@@ -182,10 +183,11 @@ class TestDistributedCrossAttn:
is_training = True
scaling_factor = 1.0
_, seqlen, _, hidden = data_shape
_, seqlen, num_head, hidden = data_shape
if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BSHD_BS2HD, attn_bias_type,
attn_mask_type, dropout_prob, seqlen, seqlen, hidden):
attn_mask_type, dropout_prob, num_head, num_head,
seqlen, seqlen, hidden):
pytest.skip(f"No FusedAttn backwend found")
def target_func(q, kv, mask):
......
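The JAX-side availability check picks up the same two head-count arguments, as the hunks above show. A minimal sketch with hypothetical values; the import path is an assumption for illustration:

    import jax.numpy as jnp
    from transformer_engine.jax.fused_attn import (      # import path assumed
        is_fused_attn_kernel_available, QKVLayout, AttnBiasType, AttnMaskType)

    # Hypothetical self-attention configuration, for illustration only
    kernel_available = is_fused_attn_kernel_available(
        jnp.bfloat16, jnp.bfloat16, QKVLayout.BS3HD,
        AttnBiasType.NO_BIAS, AttnMaskType.CAUSAL_MASK,
        0.0,            # dropout probability
        16, 16,         # num_heads_q, num_heads_kv (new arguments)
        2048, 2048,     # q_seqlen, kv_seqlen
        64,             # head_dim
    )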
......@@ -180,12 +180,14 @@ class TestSelfFusedAttn():
@staticmethod
def _check_inputs(s, *, attn_bias_type, attn_mask_type, backend, dropout_probability, dtype,
head_dim):
num_heads_q, num_heads_kv, head_dim):
assert isinstance(backend, Backend)
if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type,
attn_mask_type, dropout_probability, s, s, head_dim):
attn_mask_type, dropout_probability,
num_heads_q, num_heads_kv,
s, s, head_dim):
pytest.skip("Unsupported inputs combination or device compute capability.")
def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend,
......@@ -197,6 +199,8 @@ class TestSelfFusedAttn():
backend=backend,
dropout_probability=dropout_probability,
dtype=dtype,
num_heads_q=h,
num_heads_kv=h,
head_dim=d)
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
......
......@@ -48,11 +48,8 @@ class TestGroupSharding(unittest.TestCase):
def _get_model_and_optimizer(self, model, stage):
if stage == 1:
optimizer = DygraphShardingOptimizer(
hcg=fleet.get_hybrid_communicate_group(),
user_defined_strategy=self.strategy,
params=model.parameters(),
inner_optimizer_class=paddle.optimizer.AdamW,
learning_rate=0.01,
paddle.optimizer.AdamW(learning_rate=0.01, parameters=model.parameters()),
fleet.get_hybrid_communicate_group(),
)
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)
......
......@@ -634,9 +634,11 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
# Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported(
head_size=head_size,
num_heads=num_heads,
num_gqa_groups=num_heads,
q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen,
head_size=head_size,
dtype=math_dtype,
dropout=0.0,
qkv_layout="bs3hd" if attn_type == "self" else "bshd_bs2hd",
......@@ -762,9 +764,11 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
# Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported(
head_size=hidden_size // num_heads,
num_heads=num_heads,
num_gqa_groups=num_heads,
q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen,
head_size=hidden_size // num_heads,
dtype=math_dtype,
dropout=0.0,
qkv_layout="bs3hd",
......@@ -940,9 +944,11 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
# Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported(
head_size=hidden_size // num_heads,
num_heads=num_heads,
num_gqa_groups=num_heads,
q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen,
head_size=hidden_size // num_heads,
dtype=math_dtype,
dropout=0.0,
qkv_layout="bs3hd",
......@@ -952,6 +958,8 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
pytest.skip("cuDNN fused attention is not supported")
if not is_fused_attention_supported(
head_size=hidden_size // num_heads,
num_heads=num_heads,
num_gqa_groups=num_heads,
q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen,
dtype=math_dtype,
......
......@@ -688,9 +688,11 @@ class TestFusedAttn:
else "bshd_bs2hd"
)
fused_attention_backend = get_fused_attention_backend(
head_size=self.head_size,
num_heads=self.num_heads,
num_gqa_groups=self.num_heads,
q_seqlen=self.q_seqlen,
kv_seqlen=self.kv_seqlen,
head_size=self.head_size,
dtype=self.dtype,
dropout=self.dropout_prob,
qkv_layout=qkv_layout,
......@@ -774,9 +776,11 @@ class TestFusedAttn:
test self attention forward + backward
"""
if not is_fused_attention_supported(
head_size=d,
num_heads=h,
num_gqa_groups=h,
q_seqlen=s,
kv_seqlen=s,
head_size=d,
dtype=dtype,
dropout=0.0,
qkv_layout="bs3hd",
......@@ -799,9 +803,11 @@ class TestFusedAttn:
test cross attention forward + backward
"""
if not is_fused_attention_supported(
head_size=d,
num_heads=h,
num_gqa_groups=h,
q_seqlen=s_q,
kv_seqlen=s_kv,
head_size=d,
dtype=dtype,
dropout=0.0,
qkv_layout="bshd_bs2hd",
......@@ -825,9 +831,11 @@ class TestFusedAttn:
test flash attention forward + backward
"""
if not is_fused_attention_supported(
head_size=d,
num_heads=h,
num_gqa_groups=h,
q_seqlen=s,
kv_seqlen=s,
head_size=d,
dtype=dtype,
dropout=0.0,
qkv_layout="bs3hd",
......
......@@ -102,9 +102,11 @@ def set_random_seed(seed):
def get_fused_attention_backend(
head_size: int,
num_heads: int,
num_gqa_groups: int,
q_seqlen: int,
kv_seqlen: int,
head_size: int,
dtype: Union[paddle.dtype, str],
dropout: float,
qkv_layout: str = "bs3hd",
......@@ -125,6 +127,8 @@ def get_fused_attention_backend(
AttnBiasType[bias_type],
AttnMaskType[mask_type],
dropout,
num_heads,
num_gqa_groups,
q_seqlen,
kv_seqlen,
head_size,
......@@ -132,9 +136,11 @@ def get_fused_attention_backend(
def is_fused_attention_supported(
head_size: int,
num_heads: int,
num_gqa_groups: int,
q_seqlen: int,
kv_seqlen: int,
head_size: int,
dtype: Union[paddle.dtype, str],
dropout: float,
qkv_layout: str = "bs3hd",
......@@ -143,9 +149,11 @@ def is_fused_attention_supported(
) -> bool:
"""Check if cuDNN fused attention is supported for attention config"""
backend = get_fused_attention_backend(
head_size=head_size,
num_heads=num_heads,
num_gqa_groups=num_gqa_groups,
q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen,
head_size=head_size,
dtype=dtype,
dropout=dropout,
qkv_layout=qkv_layout,
......
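The Paddle helper mirrors the same change. The hunks above elide its trailing parameters (bias type, mask type, and so on), so this sketch only shows the arguments visible in the diff, with hypothetical values; any remaining keyword arguments are omitted here:

    # Illustrative only; not a complete call signature.
    supported = is_fused_attention_supported(
        num_heads=16,          # new argument
        num_gqa_groups=16,     # new argument
        q_seqlen=512,
        kv_seqlen=512,
        head_size=64,
        dtype="bfloat16",
        dropout=0.0,
        qkv_layout="bs3hd",
    )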
......@@ -81,7 +81,6 @@ options=" \
--merge-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-merges.txt \
--save-interval ${SAVE_INTERVAL} \
--save ${CHECKPOINT_DIR} \
--load ${CHECKPOINT_DIR} \
--split ${SPLIT} \
--clip-grad ${CLIP_GRAD} \
--weight-decay ${WEIGHT_DECAY} \
......@@ -90,8 +89,6 @@ options=" \
--init-method-std ${INIT_METHOD_STD} \
--log-params-norm \
--log-num-zeros-in-grad \
--no-query-key-layer-scaling \
--DDP-impl local \
--transformer-impl ${TRANSFORMER_IMPL} \
--tensorboard-dir ${TENSORBOARD_DIR} \
--fp8-margin 0 \
......@@ -108,7 +105,7 @@ if [[ "$WGRAD_FUSION" == "False" ]]; then
fi
if [[ "$FP8" != "False" ]]; then
options+=" --fp8-${FP8}"
options+=" --fp8-format ${FP8}"
fi
if [[ "$DTYPE" != "fp32" ]]; then
......
......@@ -2,8 +2,10 @@
#
# See LICENSE for license information.
import functools
from importlib.metadata import version
import os
import math
from typing import Any, Dict, List, Tuple, Union
from pkg_resources import packaging
......@@ -26,6 +28,10 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
fused_attn_bwd,
fused_attn_fwd,
)
from transformer_engine.pytorch.distributed import (
_set_cuda_rng_state,
CudaRNGStatesTracker,
)
import transformer_engine.pytorch.fp8 as fp8
from transformer_engine.pytorch.module.base import (
TransformerEngineBaseModule,
......@@ -36,231 +42,304 @@ from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
)
from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker
import transformer_engine_extensions as tex
from transformer_engine_extensions import NVTE_Fused_Attn_Backend
# Only run FP8 tests on H100.
# Only run FP8 tests on H100
fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()
_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
# Initialize RNG state
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Record initial RNG state from script run.
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
def _get_cudnn_version():
cudnn_version_encoded = ext.get_cudnn_version()
cudnn_major = cudnn_version_encoded // 1000
cudnn_minor = (cudnn_version_encoded - cudnn_major * 1000) // 100
cudnn_patch = cudnn_version_encoded - 1000 * cudnn_major - 100 * cudnn_minor
return [cudnn_major, cudnn_minor, cudnn_patch]
def reset_rng_states() -> None:
"""revert back to initial RNG state."""
"""Revert back to initial RNG state"""
torch.set_rng_state(_cpu_rng_state)
_set_cuda_rng_state(_cuda_rng_state)
_cudnn_version = _get_cudnn_version()
@functools.cache
def _cudnn_version() -> Tuple[int, int, int]:
"""Runtime cuDNN version (major, minor, patch)"""
encoded_version = ext.get_cudnn_version()
major, encoded_version = divmod(encoded_version, 1000)
minor, patch = divmod(encoded_version, 100)
return (major, minor, patch)
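# For example (illustrative only): an encoded value of 8907 from ext.get_cudnn_version()
# decodes to (8, 9, 7), and the version-gated skips below compare these tuples lexicographically.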
class ModelConfig:
def __init__(
self, num_layers, hidden_size, num_attention_heads, head_dim, seq_len,
dropout_p, attn_mask_type,
self,
batch_size: int,
num_heads: int,
num_gqa_groups: int,
head_dim: int,
max_seqlen_q: int,
max_seqlen_kv: int,
dropout_p: float,
attn_mask_type: str,
attn_bias_type: str,
num_layers: int = 1,
):
self.num_layers = num_layers
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.batch_size = batch_size
self.num_heads = num_heads
self.num_gqa_groups = num_gqa_groups
self.head_dim = head_dim
assert (hidden_size == num_attention_heads * head_dim
), """hidden_size must be = num_heads x head_dim."""
self.seq_len = seq_len
self.hidden_size = num_heads * head_dim
self.hidden_size_kv = num_gqa_groups * head_dim
self.max_seqlen_q = max_seqlen_q
self.max_seqlen_kv = max_seqlen_kv
self.dropout_p = dropout_p
self.attn_mask_type = attn_mask_type
model_configs = {
"test1": ModelConfig(1, 1024, 16, 64, 128, 0.0, "causal"),
"test2": ModelConfig(1, 1024, 16, 64, 2048, 0.0, "causal"),
"test3": ModelConfig(1, 2048, 16, 128, 128, 0.0, "causal"),
"test4": ModelConfig(1, 3072, 24, 128, 2048, 0.0, "causal"),
"test5": ModelConfig(1, 1024, 16, 64, 128, 0.0, "no_mask"),
}
param_types = [torch.float16]
if torch.cuda.is_bf16_supported():
param_types.append(torch.bfloat16)
batch_sizes = [1, 32]
model_configs_lean = {
"test6": ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask"),
"test7": ModelConfig(1, 2048, 16, 128, 2048, 0.0, "causal"),
}
param_types_lean = [torch.bfloat16]
batch_sizes_lean = [2]
self.attn_bias_type = attn_bias_type
self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
self.num_layers = num_layers
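# Illustrative only (not part of this diff): with the new signature, an entry such as
#   ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias")
# describes batch_size=2, num_heads=16, num_gqa_groups=16, head_dim=64,
# max_seqlen_q=128, max_seqlen_kv=256, hence hidden_size=1024 and attn_type="cross".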
def _is_fused_attention_supported(
config: ModelConfig,
dtype: torch.dtype,
qkv_layout: str = "sbh3d",
bias_type: str = "no_bias",
) -> bool:
) -> Tuple[bool, NVTE_Fused_Attn_Backend]:
"""Check if FusedAttention supports a model configuration"""
backends = []
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
backend = tex.get_fused_attn_backend(
TE_DType[dtype],
TE_DType[dtype],
QKVLayout[qkv_layout],
AttnBiasType[bias_type],
AttnBiasType[config.attn_bias_type],
AttnMaskType[config.attn_mask_type],
config.dropout_p,
config.seq_len,
config.seq_len,
config.num_heads,
config.num_gqa_groups,
config.max_seqlen_q,
config.max_seqlen_kv,
config.head_dim,
)
return backend != FusedAttnBackend["No_Backend"]
def _is_flash_attention_supported(bias_type: str = "no_bias") -> bool:
if backend == FusedAttnBackend["FP8"]:
backends.append(backend)
return True, backends
if backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
backends.append(backend)
return True, backends
if backend == FusedAttnBackend["F16_max512_seqlen"]:
backends.append(backend)
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
backend = tex.get_fused_attn_backend(
TE_DType[dtype],
TE_DType[dtype],
QKVLayout[qkv_layout],
AttnBiasType[config.attn_bias_type],
AttnMaskType[config.attn_mask_type],
config.dropout_p,
config.num_heads,
config.num_gqa_groups,
config.max_seqlen_q,
config.max_seqlen_kv,
config.head_dim,
)
if backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
backends.append(backend)
return True, backends
return False, backends
@functools.cache
def _is_flash_attention_2_available() -> bool:
"""Check if flash-attn 2.0+ is available"""
Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2")
@functools.cache
def _is_flash_attention_2_1() -> bool:
"""Check if flash-attn 2.0+ is available"""
Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2.1")
def _is_flash_attention_supported(config: ModelConfig) -> bool:
"""Check if FlashAttention supports a model configuration"""
if get_device_compute_capability() < (8, 0):
return False
if bias_type != "no_bias":
if config.attn_bias_type != "no_bias":
return False
if config.num_heads != config.num_gqa_groups and not _is_flash_attention_2_available():
return False
if "causal" in config.attn_mask_type and config.attn_type == "cross":
if _is_flash_attention_2_1():
# FAv2.1 implements causal mask for cross attention differently
# https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag
return False
return True
def _is_unfused_attention_supported(config: ModelConfig) -> bool:
"""Check if UnfusedDotProductAttention supports a model configuration"""
if ("padding" in config.attn_mask_type):
return False
if ("causal" in config.attn_mask_type and config.attn_type == 'cross'):
return False
return True
model_configs_base = {
# test: b, h, hg, d, sq, skv, p, mask, bias # attn , backend
"base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0
"base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), # cross, 0
"base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), # self , 1
"base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1
}
param_types = [torch.float16]
if torch.cuda.is_bf16_supported():
param_types.append(torch.bfloat16)
param_types_lean = [torch.bfloat16]
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_lean)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("ckpt_attn", [True, False])
@pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
def test_dot_product_attention(dtype, bs, model, ckpt_attn, bias_type):
"""Test DotProductAttention module with different backends"""
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", model_configs_base.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", [None])
def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout):
"""Test DotProductAttention module"""
# Get configs
config = model_configs[model]
tols = dict(atol=5e-3, rtol=5e-3)
if dtype == torch.bfloat16:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
config = model_configs[model]
if qkv_layout is None:
if config.attn_type == "self":
qkv_layout = "sb3hd"
else:
qkv_layout = "sbhd_sb2hd"
if "3" in qkv_layout and config.attn_type == "cross":
pytest.skip(
"No need to test this layout for cross attention"
)
# Skip if only unfused backend is supported
fused_attn_supported = _is_fused_attention_supported(
config,
dtype,
bias_type=bias_type,
)
flash_attn_supported = _is_flash_attention_supported(bias_type=bias_type)
if not (fused_attn_supported or flash_attn_supported):
pytest.skip(
"Neither FusedAttention nor FlashAttention support this model config"
unfused_attn_supported = _is_unfused_attention_supported(config)
if config.max_seqlen_q <= 512 and config.max_seqlen_kv <= 512:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
config, dtype, qkv_layout=qkv_layout,
)
flash_attn_supported = _is_flash_attention_supported(config)
if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
# UnfusedDotProductAttention backend
if unfused_attn_supported:
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
dtype,
bs,
config,
"UnfusedDotProductAttention",
ckpt_attn,
bias_type,
dtype, config, "UnfusedDotProductAttention", ckpt_attn, qkv_layout, workspace_opt,
)
# FusedAttention backend
if fused_attn_supported:
if len(fused_attn_backend) == 1:
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype,
bs,
config,
"FusedAttention",
ckpt_attn,
bias_type,
dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt,
)
if len(fused_attn_backend) == 2:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt,
)
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt,
)
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
# FlashAttention backend
if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype,
bs,
config,
"FlashAttention",
ckpt_attn,
bias_type,
dtype, config, "FlashAttention", ckpt_attn, qkv_layout, workspace_opt,
)
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, **tols)
def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type):
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1"
inp = torch.randn(
config.seq_len, bs, 3, config.num_attention_heads, config.head_dim,
dtype=dtype).cuda()
inp.requires_grad=True
seqlens = torch.empty(bs, dtype=torch.int32).cuda()
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
op_grad = torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda()
if bias_type != "no_bias":
bias = torch.randn(1, config.num_attention_heads, config.seq_len, config.seq_len,
dtype=dtype).cuda()
else:
bias = None
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
block = (
DotProductAttention(
config.num_attention_heads,
config.head_dim,
attention_dropout=config.dropout_p,
sequence_parallel=False,
tp_size=1,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
tp_group=None,
layer_number=1,
attention_type="self"
).to(dtype=dtype).cuda()
)
if unfused_attn_supported and fused_attn_supported:
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
for i,_ in enumerate(unfused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
if unfused_attn_supported and flash_attn_supported:
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
for i,_ in enumerate(flash_attn_bwd):
torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols)
if fused_attn_supported and flash_attn_supported:
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
for i,_ in enumerate(flash_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
if fused_attn_supported and len(fused_attn_backend) == 2:
torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
for i,_ in enumerate(fused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", ["base_1_1", "base_2_1"])
def test_dpa_checkpoint(dtype, model_configs, model):
"""Test DotProductAttention module with checkpointing"""
test_dot_product_attention(dtype, model_configs, model, True, True, None)
model_configs_mask = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
"mask_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"),
"mask_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"mask_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"mask_3_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"),
"mask_3_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
"mask_4_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"mask_4_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"mask_5_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
"mask_5_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding_causal", "no_bias"),
"mask_6_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"mask_6_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
}
q = inp[:, :,0,:,:]
k = inp[:, :,1,:,:]
v = inp[:, :,2,:,:]
op = block(q, k, v,
qkv_format='sbhd',
cu_seqlens_q = cu_seqlens,
cu_seqlens_kv = cu_seqlens,
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=ckpt_attn,
core_attention_bias_type=bias_type,
core_attention_bias=bias)
op.backward(op_grad)
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_mask])
@pytest.mark.parametrize("model", model_configs_mask.keys())
def test_dpa_mask(dtype, model_configs, model):
"""Test DotProductAttention module with different mask types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None)
model_configs_bias = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
"bias_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "post_scale_bias"),
"bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias"),
"bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "post_scale_bias"),
"bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "alibi"), # skipped
"bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "alibi"), # skipped
"bias_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"), # skipped
"bias_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "post_scale_bias"), # skipped
"bias_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding", "post_scale_bias"), # skipped
"bias_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "post_scale_bias"), # skipped
"bias_2_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding", "alibi"), # skipped
"bias_2_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "alibi"), # skipped
"bias_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
"bias_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "post_scale_bias"),
"bias_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
"bias_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "post_scale_bias"), # skipped
"bias_3_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi"),
"bias_3_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "alibi"), # skipped
"bias_4_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "padding_causal", "post_scale_bias"), # skipped
"bias_4_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"), # skipped
"bias_4_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "post_scale_bias"), # skipped
"bias_4_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"), # skipped
"bias_4_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "alibi"), # skipped
"bias_4_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "alibi"), # skipped
}
return op, inp.grad
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias])
@pytest.mark.parametrize("model", model_configs_bias.keys())
def test_dpa_bias(dtype, model_configs, model):
"""Test DotProductAttention module with different bias types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None)
qkv_layouts = [
'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd',
......@@ -269,54 +348,39 @@ qkv_layouts = [
#'t3hd', 'th3d', 'thd_t2hd', 'thd_th2d', 'thd_thd_thd',
]
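# Layout strings read per tensor, '_'-separated (illustrative note): 'sb3hd' is a single
# packed QKV tensor of shape [s, b, 3, h, d] split along the '3' dimension, while
# 'sbhd_sb2hd' is a separate Q tensor plus a packed KV tensor; 't...' layouts use
# total-token (thd) packing.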
@pytest.mark.skipif(
_cudnn_version < [8,9,5], reason="cuDNN 8.9.5+ is required.")
model_configs_layout = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"layout_0_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
"layout_0_1": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
"layout_0_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
"layout_0_3": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"),
"layout_1_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"layout_1_1": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
"layout_1_2": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"layout_1_3": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"),
}
@pytest.mark.skipif(_cudnn_version() < (8,9,5), reason="cuDNN 8.9.5+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("bs", batch_sizes_lean)
@pytest.mark.parametrize("model", model_configs_lean.keys())
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("model_configs", [model_configs_layout])
@pytest.mark.parametrize("model", model_configs_layout.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layouts)
def test_dpa_qkv_layout(dtype, bs, model, workspace_opt, qkv_layout):
def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with different QKV layouts"""
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout)
# Get configs
config = model_configs_lean[model]
tols = dict(atol=5e-3, rtol=5e-3)
if dtype == torch.bfloat16:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
# Skip if only unfused backend is supported
fused_attn_supported = _is_fused_attention_supported(config, dtype)
flash_attn_supported = _is_flash_attention_supported()
if not (fused_attn_supported or flash_attn_supported):
pytest.skip(
"Neither FusedAttention nor FlashAttention support this model config"
)
# UnfusedDotProductAttention backend
unfused_attn_fwd, unfused_attn_bwd = _run_dpa_qkv_layout(
dtype, bs, config, "UnfusedDotProductAttention", qkv_layout, workspace_opt)
# FusedAttention backend
if fused_attn_supported:
fused_attn_fwd, fused_attn_bwd = _run_dpa_qkv_layout(
dtype, bs, config, "FusedAttention", qkv_layout, workspace_opt)
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
for i in range(len(unfused_attn_bwd)):
torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
# FlashAttention backend
if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_dpa_qkv_layout(
dtype, bs, config, "FlashAttention", qkv_layout, workspace_opt)
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
for i in range(len(unfused_attn_bwd)):
torch.testing.assert_close(flash_attn_bwd[i], unfused_attn_bwd[i], **tols)
def _run_dpa_qkv_layout(dtype, bs, config, backend, qkv_layout, workspace_opt):
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
def _run_dot_product_attention(
dtype: torch.dtype,
config: ModelConfig,
backend: str,
ckpt_attn: bool,
qkv_layout: str,
workspace_opt: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run DotProductAttention module with one forward pass and one backward pass"""
# Set RNG and environment variables
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if backend == "FlashAttention":
......@@ -325,122 +389,193 @@ def _run_dpa_qkv_layout(dtype, bs, config, backend, qkv_layout, workspace_opt):
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
dim_to_num = {'b': bs,
's': config.seq_len,
'h': config.num_attention_heads,
'd': config.head_dim,
't': bs * config.seq_len,
'3': 3,
'2': 2}
# Create seqlens
qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
if "padding" in config.attn_mask_type or qkv_format == 'thd':
if config.attn_type == 'self':
seqlens_q = torch.randint(1, config.max_seqlen_q, [config.batch_size],
dtype=torch.int32, device="cuda")
seqlens_kv = seqlens_q
if config.attn_type == 'cross':
seqlens_q = torch.randint(1, config.max_seqlen_q, [config.batch_size],
dtype=torch.int32, device="cuda")
seqlens_kv = torch.randint(1, config.max_seqlen_kv, [config.batch_size],
dtype=torch.int32, device="cuda")
else:
seqlens_q = torch.full([config.batch_size], config.max_seqlen_q,
dtype=torch.int32, device="cuda")
seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv,
dtype=torch.int32, device="cuda")
cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
# Create attention mask if padding
attention_mask = None
if "padding" in config.attn_mask_type:
if config.attn_type == 'self':
attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
for i in range(config.batch_size):
attention_mask_q = torch.cat([attention_mask_q,
torch.Tensor([True]*seqlens_q[i] + [False]*(config.max_seqlen_q-seqlens_q[i]))
.to(dtype=torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
attention_mask = attention_mask_q.to(device="cuda")
if config.attn_type == 'cross':
attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool)
for i in range(config.batch_size):
attention_mask_q = torch.cat([attention_mask_q,
torch.Tensor([True]*seqlens_q[i] + [False]*(config.max_seqlen_q-seqlens_q[i]))
.to(dtype=torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
attention_mask_kv = torch.cat([attention_mask_kv, torch.Tensor(
[True]*seqlens_kv[i] + [False]*(config.max_seqlen_kv-seqlens_kv[i]))
.to(dtype=torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
attention_mask = (
attention_mask_q.to(device="cuda"), attention_mask_kv.to(device="cuda"))
# Create input tensors
dim_to_num = {
'b' : config.batch_size,
'sq' : config.max_seqlen_q,
'skv': config.max_seqlen_kv,
'h' : config.num_heads,
'hg' : config.num_gqa_groups,
'd' : config.head_dim,
't' : cu_seqlens_q[-1],
'tg' : cu_seqlens_kv[-1],
'3' : 3,
'2' : 2,
}
inp = []
for i,layout in enumerate(qkv_layout.split('_')):
tensor_shape = [dim_to_num[j] for j in layout]
tensor = 0.1 * torch.randn(tensor_shape, dtype = dtype).cuda()
layout = '_'.join(layout)
if i == 0:
layout = layout.replace('s', 'sq')
else:
layout = layout.replace('s', 'skv')
layout = layout.replace('h', 'hg')
layout = layout.replace('t', 'tg')
tensor_shape = [dim_to_num[j] for j in layout.split('_')]
tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
tensor_count = 1
split_dim = 0
for dim,l in enumerate(layout):
for dim, l in enumerate(layout.split('_')):
if l.isdigit():
tensor_count = int(l)
split_dim = dim
break
tensors = torch.split(tensor, 1, dim = split_dim) if split_dim != 0 else [tensor]
tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
for j in range(tensor_count):
if split_dim != 0:
inp.append(tensors[j].squeeze(split_dim))
else:
inp.append(tensors[j])
for i in range(3):
inp[i].requires_grad=True
inp[i].requires_grad = True
seqlens = torch.empty(bs, dtype = torch.int32).cuda()
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp[0].device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
qkv_format_no_thd = qkv_format if qkv_format != 'thd' else 'bshd'
op_grad_shape = [dim_to_num[i] for i in qkv_format_no_thd]
op_grad_shape_new = [*op_grad_shape[:-2], op_grad_shape[-2] * op_grad_shape[-1]]
op_grad = 0.001 * torch.randint(0, 200, op_grad_shape_new, dtype = dtype).cuda()
# Create output gradient
qkv_format_kv = '_'.join(qkv_format)
qkv_format_kv = qkv_format_kv.replace('s', 'sq')
out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split('_')]
out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda")
# Create bias
if config.attn_bias_type in ['no_bias', 'alibi']:
bias = None
if config.attn_bias_type == 'post_scale_bias':
bias = torch.randn(1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv,
dtype=dtype, device="cuda")
# Create RNG
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
# Set up model
block = (
DotProductAttention(
config.num_attention_heads,
config.num_heads,
config.head_dim,
attention_dropout = config.dropout_p,
attn_mask_type = config.attn_mask_type,
sequence_parallel = False,
tp_size = 1,
get_rng_state_tracker = None,
tp_group = None,
layer_number = 1,
attention_type = "self"
).to(dtype = dtype).cuda()
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=config.attn_mask_type,
sequence_parallel=False,
tp_size=1,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
tp_group=None,
layer_number=1,
attention_type=config.attn_type,
).to(dtype=dtype, device="cuda")
)
if qkv_format != 'thd':
op = block(inp[0], inp[1], inp[2], qkv_format=qkv_format)
else:
cu_seqlens_q = torch.arange(
0,
(bs + 1) * config.seq_len,
step=config.seq_len,
dtype=torch.int32,
device=inp[0].device)
cu_seqlens_kv = torch.arange(
0,
(bs + 1) * config.seq_len,
step=config.seq_len,
dtype=torch.int32,
device=inp[1].device)
op = block(inp[0], inp[1], inp[2],
# Run a forward and backward pass
out = block(inp[0], inp[1], inp[2],
attention_mask=attention_mask,
qkv_format=qkv_format,
cu_seqlens_q = cu_seqlens_q,
cu_seqlens_kv = cu_seqlens_kv)
op.backward(op_grad)
return op, (inp[0].grad, inp[1].grad, inp[2].grad)
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=ckpt_attn,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
fast_zero_fill=True)
out.backward(out_grad)
return out, (inp[0].grad, inp[1].grad, inp[2].grad)
model_configs_te_layer = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
"te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
"te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"),
"te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
"te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
}
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs_lean.keys())
@pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
@pytest.mark.parametrize("fused_qkv_params", [True, False])
@pytest.mark.parametrize("RoPE", [True, False])
def test_transformer_layer(dtype, bs, model, bias_type, fused_qkv_params, RoPE):
"""Test TransformerLayer module when its DotProductAttention is enabled with
FlashAttention, FusedAttention, or UnfusedDotProductAttention backend"""
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", model_configs_te_layer.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("qkv_format", ["sbhd"])
@pytest.mark.parametrize("fused_qkv_params", [False])
@pytest.mark.parametrize("RoPE", [False])
def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE):
"""Test TransformerLayer module"""
# Get configs
config = model_configs_lean[model]
config = model_configs[model]
tols = dict(atol=5e-1, rtol=5e-2)
# TODO @cyanguwa: Handle test cases more cleanly
if config.hidden_size > 1024:
pytest.skip(
"Tolerances for test_transformer_layer are intended for small test cases"
)
workspace_opt = True
# Skip if only unfused backend is supported
fused_attn_supported = _is_fused_attention_supported(
if config.max_seqlen_q <= 512 and config.max_seqlen_kv <= 512:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
config,
dtype,
qkv_layout="sbh3d" if fused_qkv_params else "sb3hd",
bias_type=bias_type,
)
flash_attn_supported = _is_flash_attention_supported(bias_type=bias_type)
if not (fused_attn_supported or flash_attn_supported):
pytest.skip(
"Neither FusedAttention nor FlashAttention support this model config"
)
flash_attn_supported = _is_flash_attention_supported(config)
unfused_attn_supported = _is_unfused_attention_supported(config)
if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
# UnfusedDotProductAttention backend
if unfused_attn_supported:
unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer(
dtype,
bs,
config,
"UnfusedDotProductAttention",
bias_type,
ckpt_attn,
qkv_format,
workspace_opt,
fused_qkv_params,
RoPE,
)
......@@ -449,32 +584,89 @@ def test_transformer_layer(dtype, bs, model, bias_type, fused_qkv_params, RoPE):
if fused_attn_supported:
fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
dtype,
bs,
config,
"FusedAttention",
bias_type,
ckpt_attn,
qkv_format,
workspace_opt,
fused_qkv_params,
RoPE,
)
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
# FlashAttention backend
if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
dtype,
bs,
config,
"FlashAttention",
bias_type,
ckpt_attn,
qkv_format,
workspace_opt,
fused_qkv_params,
RoPE,
)
if unfused_attn_supported and fused_attn_supported:
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
if unfused_attn_supported and flash_attn_supported:
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, **tols)
if fused_attn_supported and flash_attn_supported:
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)
def _run_transformer_layer(dtype, bs, config, backend, bias_type, fused_qkv_params, RoPE):
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_1_2", "te_2_0"])
def test_te_layer_misc(dtype, model_configs, model):
"""Test TransformerLayer module with miscellanous settings"""
ckpt_attn = True
qkv_format = "bshd"
fused_qkv_params = True
RoPE = True
test_transformer_layer(dtype, model_configs, model,
ckpt_attn, qkv_format, fused_qkv_params, RoPE)
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_2_0", "te_2_1", "te_2_2"])
def test_te_layer_mqa_gqa(dtype, model_configs, model):
"""Test TransformerLayer module with MQA/GQA"""
def find_factors(x):
f = []
for i in range(2, x + 1):
if x % i == 0:
f.append(i)
return f
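# e.g. find_factors(16) == [2, 4, 8, 16], so num_gqa_groups sweeps over {8, 4, 2, 1},
# covering GQA group sizes down to MQA (a single KV head).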
ckpt_attn = True
qkv_format = "bshd"
fused_qkv_params = True
RoPE = True
config = model_configs[model]
num_querys_per_gqa_group = find_factors(config.num_heads)
for num_q_per_gqa_group in num_querys_per_gqa_group:
config.num_gqa_groups=config.num_heads // num_q_per_gqa_group
test_transformer_layer(dtype, model_configs, model,
ckpt_attn, qkv_format, fused_qkv_params, RoPE)
def _run_transformer_layer(
dtype: torch.dtype,
config: ModelConfig,
backend: str,
ckpt_attn: bool,
qkv_layout: str,
workspace_opt: bool,
fused_qkv_params: bool,
RoPE: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run TransformerLayer module with one forward pass and one backward pass"""
# Set RNG and environment variables
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
......@@ -483,14 +675,27 @@ def _run_transformer_layer(dtype, bs, config, backend, bias_type, fused_qkv_para
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
inp = torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim,
dtype=dtype).cuda()
inp.requires_grad=True
seqlens = torch.empty(bs, dtype=torch.int32).cuda()
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
# Create input tensor
inp = torch.randn(config.max_seqlen_q, config.batch_size, config.hidden_size,
dtype=dtype, device="cuda", requires_grad = True)
# Create seqlens
if "padding" in config.attn_mask_type:
seqlens_q = torch.randint(1, config.max_seqlen_q, [config.batch_size],
dtype=torch.int32, device="cuda")
else:
seqlens_q = torch.full([config.batch_size], config.max_seqlen_q,
dtype=torch.int32, device="cuda")
# Create attention mask if padding
attention_mask = None
if "padding" in config.attn_mask_type:
attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
for i in range(config.batch_size):
attention_mask_q = torch.cat([attention_mask_q,
torch.Tensor([True]*seqlens_q[i] + [False]*(config.max_seqlen_q-seqlens_q[i]))
.to(torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
attention_mask = attention_mask_q.to(device="cuda")
sigma = 0.02
init_method = init_method_normal(sigma)
......@@ -500,22 +705,44 @@ def _run_transformer_layer(dtype, bs, config, backend, bias_type, fused_qkv_para
drop_path_rate = 0.0
drop_path_rates = [
rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
if bias_type != "no_bias":
bias = torch.randn(1, config.num_attention_heads, config.seq_len, config.seq_len,
dtype=dtype).cuda()
# Create bias
if config.attn_bias_type == 'no_bias':
bias = None
if config.attn_bias_type == 'post_scale_bias':
bias = torch.randn(1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv,
dtype=dtype, device="cuda")
elif config.attn_bias_type == 'alibi':
if os.environ['NVTE_FUSED_ATTN_BACKEND'] == '0':
config.attn_bias_type = 'post_scale_bias'
n = 2 ** math.floor(math.log2(config.num_heads))
m_0 = 2.0 ** (-8.0 / n)
m = torch.pow(m_0, torch.arange(1, 1 + n))
a = torch.ones(config.max_seqlen_q, config.max_seqlen_kv)
b = torch.triu(a,diagonal=1)
c = b.cumsum(dim=-1)
d = c - torch.transpose(c, 0, 1)
bias = d.expand(1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv)
for i in range(config.num_heads):
bias[0,i,:,:] = m[i] * bias[0,i,:,:]
bias = bias.to(dtype=dtype, device="cuda")
else:
bias = None
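# For reference (illustrative only): with num_heads = 16 the slopes above are
# m = [2**-0.5, 2**-1.0, ..., 2**-8.0], and d[q, k] = k - q, so the explicit ALiBi bias
# fed to the unfused reference is m[i] * (k - q) for head i.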
# Create RoPE
rotary_pos_emb = None
if RoPE:
PE = RotaryPositionEmbedding(dim=config.head_dim)
rotary_pos_emb = PE(config.seq_len).cuda().to(dtype=dtype)
rotary_pos_emb = PE(config.max_seqlen_q).to(dtype=dtype, device="cuda")
# Set up model
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
num_gqa_groups=config.num_gqa_groups,
layernorm_epsilon=1e-5,
hidden_dropout=0.0,
attention_dropout=config.dropout_p,
......@@ -523,13 +750,14 @@ def _run_transformer_layer(dtype, bs, config, backend, bias_type, fused_qkv_para
output_layer_init_method=output_layer_init_method,
layer_number=layer_number,
kv_channels=config.head_dim,
self_attn_mask_type=config.attn_mask_type,
tp_group=None,
tp_size=1,
params_dtype=dtype,
get_rng_state_tracker=None,
fuse_wgrad_accumulation=False,
seq_length=config.seq_len,
micro_batch_size=bs,
seq_length=config.max_seqlen_q,
micro_batch_size=config.batch_size,
sequence_parallel=False,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
......@@ -542,169 +770,65 @@ def _run_transformer_layer(dtype, bs, config, backend, bias_type, fused_qkv_para
ub_tp_comm_overlap=False,
bias=True,
)
.to(dtype=dtype)
.cuda()
.to(dtype=dtype, device="cuda")
)
num_iters = 5
for i in range(num_iters):
op = block(inp, self_attn_mask_type=config.attn_mask_type,
# Run a forward and backward pass
out = block(inp,
attention_mask=attention_mask,
self_attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=False,
rotary_pos_emb=rotary_pos_emb,
core_attention_bias_type=bias_type,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias)
loss = op.sum()
loss = out.sum()
loss.backward()
return op, inp.grad
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("bs", batch_sizes_lean)
@pytest.mark.parametrize("model", model_configs_lean.keys())
def test_transformer_layer_gqa(dtype, bs, model):
"""Test TransformerLayer module when its DotProductAttention is enabled with
FlashAttention or UnfusedDotProductAttention backend"""
config = model_configs_lean[model]
def find_factors(x):
f = []
for i in range(1, x + 1):
if x % i == 0:
f.append(i)
return f
# Skip if only unfused backend is supported
if not (_flash_attn_2_available and _is_flash_attention_supported()):
pytest.skip("FlashAttention does not support this model config")
num_querys_per_gqa_group = find_factors(config.num_attention_heads)
for num_q_per_gqa_group in num_querys_per_gqa_group:
flash_attn_fwd, flash_attn_bwd = _run_transformer_layer_gqa(
dtype, bs, config, "FlashAttention", num_q_per_gqa_group)
unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer_gqa(
dtype, bs, config, "UnfusedDotProductAttention", num_q_per_gqa_group)
atol, rtol = 5e-1, 5e-2
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_group):
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
inp = torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim,
dtype=dtype).cuda()
inp.requires_grad=True
seqlens = torch.empty(bs, dtype=torch.int32).cuda()
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
op_grad = torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim,
dtype=dtype).cuda()
sigma = 0.02
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
layer_number = 1
drop_path_rate = 0.0
drop_path_rates = [
rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
num_gqa_groups=config.num_attention_heads // num_querys_per_gqa_group,
layernorm_epsilon=1e-5,
hidden_dropout=0.0,
attention_dropout=config.dropout_p,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
layer_number=layer_number,
kv_channels=config.head_dim,
tp_group=None,
tp_size=1,
params_dtype=dtype,
get_rng_state_tracker=None,
fuse_wgrad_accumulation=False,
seq_length=config.seq_len,
micro_batch_size=bs,
sequence_parallel=False,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
layer_type="encoder",
drop_path_rate=drop_path_rates[layer_number - 1],
set_parallel_mode=True,
fuse_qkv_params=True,
zero_centered_gamma=False,
qkv_weight_interleaved=False,
ub_tp_comm_overlap=False,
bias=True,
)
.to(dtype=dtype)
.cuda()
)
op = block(inp, self_attn_mask_type=config.attn_mask_type)
op.backward(op_grad)
return out, inp.grad
return op, inp.grad
model_configs_fp8 = {
"test1": ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask"),
# test: b, h, hg, d, sq, skv, p, mask, bias
"fp8_1": ModelConfig(1, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
"fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
}
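ModelConfig is defined earlier in this test file; a hypothetical sketch of the field order implied by the comment above (b, h, hg, d, sq, skv, p, mask, bias) would be:

from typing import NamedTuple

class ModelConfigSketch(NamedTuple):  # hypothetical stand-in for the real ModelConfig
    batch_size: int       # b
    num_heads: int        # h
    num_gqa_groups: int   # hg
    head_dim: int         # d
    max_seqlen_q: int     # sq
    max_seqlen_kv: int    # skv
    dropout_p: float      # p
    attn_mask_type: str   # mask
    attn_bias_type: str   # bias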
batch_sizes_fp8 = [1, 4]
param_types_fp8 = [torch.float16]
@pytest.mark.skipif(_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("bs", batch_sizes_fp8)
@pytest.mark.parametrize("model", model_configs_fp8.keys())
def test_dpa_fp8(dtype, bs, model):
"""Test FP8 dot-product attention with different backends
def test_dpa_fp8(dtype, model):
"""Test FP8 dot product attention
FusedAttention uses fused_attn_fwd/bwd_qkvpacked from
cpp_extensions. UnfusedDotProductAttention uses plain PyTorch
operations.
FusedAttention uses fused_attn_fwd/bwd_qkvpacked from cpp_extensions,
and UnfusedDotProductAttention uses plain PyTorch operations in FP16
and converts inputs/outputs from/to FP8.
"""
config = model_configs_fp8[model]
# Skip if not supported
if not _is_fused_attention_supported(config, dtype):
fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
config, dtype)
if not fused_attn_supported:
pytest.skip("FusedAttention does not support this model config")
# Run dot-product attention with different backends
fused_attn_fwd, fused_attn_bwd = _run_dpa_fp8(
dtype,
bs,
config,
"FusedAttention"
)
dtype, config, "FusedAttention")
unfused_attn_fwd, unfused_attn_bwd = _run_dpa_fp8_ref(
dtype,
bs,
config,
"UnfusedDotProductAttention",
)
dtype, config, "UnfusedDotProductAttention")
# Check that results match
tols = dict(atol=2.5e-2, rtol=2.5e-2)
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
def _run_dpa_fp8(dtype, bs, config, backend):
def _run_dpa_fp8(dtype, config, backend):
"""Run FusedAttention FP8 backend, i.e.
fused_attn_fwd/bwd_qkvpacked from cpp_extensions"""
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
......@@ -715,17 +839,16 @@ def _run_dpa_fp8(dtype, bs, config, backend):
os.environ["NVTE_FUSED_ATTN"] = "1"
inp = 0.01 * torch.randn(
bs * config.seq_len, config.num_attention_heads * config.head_dim,
dtype=dtype).cuda()
inp.requires_grad=True
seqlens = torch.empty(bs, dtype=torch.int32).cuda()
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim,
dtype=dtype, device="cuda", requires_grad=True)
seqlens = torch.full([config.batch_size], config.max_seqlen_q,
dtype=torch.int32, device="cuda")
cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
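A worked example of the cu_seqlens convention used here: cumulative sequence lengths with a leading zero, so sequence i occupies rows cu_seqlens[i]:cu_seqlens[i+1] of the packed [total_seqs, ...] input.

import torch

seqlens = torch.tensor([512, 512, 512, 512], dtype=torch.int32)   # batch_size = 4
cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
# cu_seqlens -> tensor([0, 512, 1024, 1536, 2048], dtype=torch.int32)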
op_grad = 0.01 * torch.randn(
bs * config.seq_len, config.num_attention_heads * config.head_dim,
dtype=dtype).cuda()
torch.save(op_grad, 'op_grad.pt')
out_grad = 0.01 * torch.randn(
config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim,
dtype=dtype, device="cuda")
torch.save(out_grad, 'out_grad.pt')
fp8_recipe = recipe.DelayedScaling(
margin=0,
......@@ -735,17 +858,21 @@ def _run_dpa_fp8(dtype, bs, config, backend):
amax_compute_algo="most_recent",
)
dpa = DPA_FP8(config).to(dtype=torch.float16).cuda()
dpa = DPA_FP8(config).to(dtype=torch.float16, device="cuda")
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
op = dpa(inp, cu_seqlens, config.seq_len)
op.backward(op_grad)
out = dpa(inp, cu_seqlens, config.max_seqlen_q)
out.backward(out_grad)
context = torch.load("ctx.pt")
dqkv = torch.load('dqkv.pt')
return (context.view(bs, config.seq_len, -1).transpose(0,1),
dqkv.view(bs, config.seq_len, 3, config.num_attention_heads, config.head_dim).transpose(0,1).contiguous())
return (context.view(config.batch_size, config.max_seqlen_q, -1).transpose(0,1),
dqkv.view(config.batch_size, config.max_seqlen_q, 3,
config.num_heads, config.head_dim).transpose(0,1).contiguous())
def _run_dpa_fp8_ref(dtype, bs, config, backend):
def _run_dpa_fp8_ref(dtype, config, backend):
"""Run UnfusedDotProductAttention as a reference, i.e.
plain PyTorch implementation in FP16 and inputs/outputs
are converted from/to FP8"""
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
......@@ -754,13 +881,20 @@ def _run_dpa_fp8_ref(dtype, bs, config, backend):
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
inp = torch.load('qkv.pt').cuda()
inp.requires_grad=True
seqlens = torch.empty(bs, dtype=torch.int32).cuda()
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
inp = torch.load('qkv.pt').to(device="cuda")
inp.requires_grad = True
seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
op_grad = torch.load('op_grad.pt').cuda().view(bs, config.seq_len, -1).transpose(0,1)
out_grad = torch.load('out_grad.pt').to(device="cuda").view(
config.batch_size, config.max_seqlen_q, -1).transpose(0,1)
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
......@@ -771,7 +905,7 @@ def _run_dpa_fp8_ref(dtype, bs, config, backend):
block = (
DotProductAttention(
config.num_attention_heads,
config.num_heads,
config.head_dim,
attention_dropout=config.dropout_p,
sequence_parallel=False,
......@@ -780,16 +914,17 @@ def _run_dpa_fp8_ref(dtype, bs, config, backend):
tp_group=None,
layer_number=1,
attention_type="self"
).to(dtype=dtype).cuda()
).to(dtype=dtype, device="cuda")
)
q = inp[:, :, 0, :, :]
k = inp[:, :, 1, :, :]
v = inp[:, :, 2, :, :]
op = block(q, k, v, attn_mask_type=config.attn_mask_type)
op.backward(op_grad)
out = block(q, k, v, attn_mask_type=config.attn_mask_type)
out.backward(out_grad)
return out, inp.grad
return op, inp.grad
_CUBLASLT_WORKSPACE_SIZE_BYTES = 33_554_432 # 32MiB
_2X_ACC_FPROP = False
......@@ -812,7 +947,7 @@ class _dpa_fp8(torch.autograd.Function):
qkv_weight: torch.Tensor,
qkv_bias: torch.Tensor,
cu_seqlens: torch.Tensor,
num_attention_heads: int,
num_heads: int,
p_dropout: float,
max_s: int,
fast_zero_fill: bool,
......@@ -823,7 +958,7 @@ class _dpa_fp8(torch.autograd.Function):
assert inp.dim() == 2
in_features = qkv_weight.shape[-1]
h = num_attention_heads
h = num_heads
d = in_features // h
b = cu_seqlens.numel() - 1
is_nl = False
......@@ -921,7 +1056,7 @@ class _dpa_fp8(torch.autograd.Function):
ctx.fast_zero_fill = fast_zero_fill
ctx.is_nl = is_nl
ctx.hidden_size = in_features
ctx.num_attention_heads = num_attention_heads
ctx.num_heads = num_heads
context_fp16 = ext.cast_from_fp8(context, fp8_meta["scaling_fwd"],
META_O, fp8_dtype_forward, tex.DType.kFloat16)
......@@ -1050,7 +1185,7 @@ class DPA_FP8(TransformerEngineBaseModule):
params_dtype: torch.dtype = torch.float32):
super().__init__()
self.p_dropout = config.dropout_p
self.h = config.num_attention_heads
self.h = config.num_heads
self.hidden_size = config.hidden_size
self.head_dim = config.head_dim
self.fast_zero_fill = True
......
......@@ -508,6 +508,7 @@ def _test_e2e_checkpointing_get_model(config, dtype):
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
return (
TransformerLayer(
config.hidden_size,
......@@ -524,7 +525,6 @@ def _test_e2e_checkpointing_get_model(config, dtype):
params_dtype=dtype,
)
.cuda()
.eval()
)
......@@ -559,9 +559,14 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
if p.requires_grad:
param_grads.append(p.grad.clone())
global _cpu_rng_state, _cuda_rng_state
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
del block
block = _test_e2e_checkpointing_get_model(config, dtype)
block.load_state_dict(torch.load(path))
reset_rng_states()
for p in block.parameters():
if p.requires_grad:
......@@ -815,21 +820,19 @@ def test_dpa_accuracy(dtype, bs, model):
DotProductAttention(
config.num_attention_heads,
config.embed,
attention_dropout=0.1, # dropout
attention_dropout=0.0, # disable dropout; fused and unfused backends use RNG differently
)
.to(dtype=dtype)
.cuda()
.eval()
)
torch_dpa = (
TorchDotProductAttention(
config.embed,
0.1, # dropout
0.0, # dropout
)
.to(dtype=dtype)
.cuda()
.eval()
)
te_outputs = _test_dpa_accuracy(te_dpa, bs, dtype, config)
......
......@@ -11,6 +11,7 @@
#include "fused_attn_f16_arbitrary_seqlen.h"
#include "fused_attn_fp8.h"
#include "../util/cuda_runtime.h"
#include "../util/system.h"
// map NVTE_QKV_Layout to NVTE_QKV_Layout_Group
NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) {
......@@ -18,7 +19,6 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) {
case NVTE_QKV_Layout::NVTE_SB3HD:
case NVTE_QKV_Layout::NVTE_BS3HD:
case NVTE_QKV_Layout::NVTE_T3HD:
case NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED:
return NVTE_QKV_Layout_Group::NVTE_3HD;
case NVTE_QKV_Layout::NVTE_SBH3D:
case NVTE_QKV_Layout::NVTE_BSH3D:
......@@ -27,7 +27,6 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) {
case NVTE_QKV_Layout::NVTE_SBHD_SB2HD:
case NVTE_QKV_Layout::NVTE_BSHD_BS2HD:
case NVTE_QKV_Layout::NVTE_THD_T2HD:
case NVTE_QKV_Layout::NVTE_KV_INTERLEAVED:
return NVTE_QKV_Layout_Group::NVTE_HD_2HD;
case NVTE_QKV_Layout::NVTE_SBHD_SBH2D:
case NVTE_QKV_Layout::NVTE_BSHD_BSH2D:
......@@ -36,7 +35,6 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) {
case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_THD_THD_THD:
case NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED:
return NVTE_QKV_Layout_Group::NVTE_HD_HD_HD;
default:
NVTE_ERROR("qkv_layout not supported!");
......@@ -63,9 +61,6 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) {
case NVTE_QKV_Layout::NVTE_THD_T2HD:
case NVTE_QKV_Layout::NVTE_THD_TH2D:
case NVTE_QKV_Layout::NVTE_THD_THD_THD:
case NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED:
case NVTE_QKV_Layout::NVTE_KV_INTERLEAVED:
case NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED:
return NVTE_QKV_Format::NVTE_THD;
default:
NVTE_ERROR("qkv_layout not supported!");
......@@ -79,8 +74,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
float dropout, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim) {
float dropout,
size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv,
size_t head_dim) {
using namespace transformer_engine;
NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
const int device_id = cuda::current_device();
......@@ -91,56 +88,66 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
if (((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2))
&& (sm_arch_ >= 90)
&& (max_seqlen_q == max_seqlen_kv)
&& (num_attn_heads == num_gqa_groups)
&& (max_seqlen_q <= 512)
&& (head_dim == 64)
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
&& (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
&& ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD))) {
#if (CUDNN_VERSION >= 8900)
&& (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD)) {
if (cudnn_runtime_version >= 8900) {
backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
#else
} else {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
std::cout << "Warning: FP8 fused attention is supported by cuDNN 8.9.0+."
" Please upgrade your cuDNN version if possible." << std::endl;
#endif
}
} else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) {
bool flag_m512 = false;
bool flag_arb = false;
if ((sm_arch_ == 80 || sm_arch_ == 90)
&& (max_seqlen_q <= 512)
&& (max_seqlen_kv <= 512)
&& (head_dim == 64)
&& (num_attn_heads == num_gqa_groups)
&& ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
|| (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS))
&& ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK
&& max_seqlen_q == max_seqlen_kv)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))
&& ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD)
&& ((qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_SBHD_SB2HD)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD))) {
flag_m512 = true;
}
if (
#if (CUDNN_VERSION >= 8903)
(sm_arch_ >= 80)
#else
(sm_arch_ == 80 || sm_arch_ == 90)
#endif
&& (max_seqlen_q == max_seqlen_kv)
&& ((head_dim == 64) || (head_dim == 128))
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
if (((cudnn_runtime_version >= 8903 && sm_arch_ >= 80)
|| (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90)))
&& (max_seqlen_q % 64 == 0)
&& (max_seqlen_kv % 64 == 0)
&& ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups)
|| (cudnn_runtime_version >= 8907))
&& ((head_dim <= 128) && (head_dim % 8 == 0))
&& ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
|| ((cudnn_runtime_version >= 8906 && sm_arch_ == 90)
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS
|| (bias_type == NVTE_Bias_Type::NVTE_ALIBI
&& attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK)
|| bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)))
&& ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
|| ((cudnn_runtime_version >= 8906) &&
(attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)))
&& ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD))) {
|| ((cudnn_runtime_version >= 8906)
&& (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK
|| attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK
|| attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK
|| attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)))
&& (!(cudnn_runtime_version >= 8906
&& (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK
|| attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)
&& bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS))
&& ((qkv_format == NVTE_QKV_Format::NVTE_SBHD)
|| (qkv_format == NVTE_QKV_Format::NVTE_BSHD))) {
flag_arb = true;
}
if (((max_seqlen_q > 512) || (max_seqlen_kv > 512))
......@@ -148,34 +155,32 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
}
if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
if (flag_m512 == true) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen;
} else if ((flag_m512 == false) && (flag_arb == true)) {
if (flag_arb == true) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
} else if ((flag_arb == false) && (flag_m512 == true)) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen;
}
int env_backend = static_cast<int>(backend);
env_backend = transformer_engine::getenv<int>("NVTE_FUSED_ATTN_BACKEND", env_backend);
if (((env_backend == static_cast<int>(NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen))
&& flag_m512)
|| ((env_backend == static_cast<int>(NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen))
&& flag_arb)) {
backend = static_cast<NVTE_Fused_Attn_Backend>(env_backend);
}
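A usage sketch of the override implemented here, assuming the NVTE_Fused_Attn_Backend values map to 0 = F16_max512_seqlen, 1 = F16_arbitrary_seqlen, 2 = FP8 as in the public enum; the variable only takes effect when the requested backend is actually eligible for the problem:

import os

# Ask for the arbitrary-seqlen backend when both F16 backends could serve the request.
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"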
const char* env_backend = std::getenv("NVTE_FUSED_ATTN_BACKEND");
if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)
&& (flag_arb == true)
&& (env_backend != nullptr)
&& (std::string(env_backend) == std::to_string(
NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen))) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
}
#if (CUDNN_VERSION < 8901)
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
if (cudnn_runtime_version < 8901
&& backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.1+."
" Please upgrade your cuDNN version if possible." << std::endl;
}
#endif
#if (CUDNN_VERSION < 8900)
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
if (cudnn_runtime_version < 8900
&& backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.0+."
" Please upgrade your cuDNN version if possible." << std::endl;
}
#endif
} else {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
}
......@@ -208,10 +213,17 @@ void nvte_fused_attn_fwd_qkvpacked(
Tensor *output_O = reinterpret_cast<Tensor*>(O);
Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
// QKV shape is [total_seqs, 3, h, d] for 3HD layouts or [total_seqs, h, 3, d] for H3D layouts
auto ndim = input_QKV->data.shape.size();
size_t b = input_cu_seqlens->data.shape[0] - 1;
size_t h = input_QKV->data.shape[ndim - 2];
size_t h = 0;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
h = input_QKV->data.shape[ndim - 2];
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
h = input_QKV->data.shape[ndim - 3];
} else {
NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!");
}
size_t d = input_QKV->data.shape[ndim - 1];
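A small sketch of the shape convention being handled above: for packed QKV, the position of the heads axis depends on the layout group (3HD stores [..., 3, h, d], H3D stores [..., h, 3, d]).

def heads_from_packed_qkv(shape: tuple, layout_group: str) -> int:
    # 3HD -> heads at axis -2; H3D -> heads at axis -3.
    return shape[-2] if layout_group == "3HD" else shape[-3]

assert heads_from_packed_qkv((2048, 3, 16, 64), "3HD") == 16
assert heads_from_packed_qkv((2048, 16, 3, 64), "H3D") == 16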
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
......@@ -221,12 +233,12 @@ void nvte_fused_attn_fwd_qkvpacked(
nvte_get_fused_attn_backend(
QKV_type, QKV_type,
qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen, max_seqlen, d);
dropout, h, h, max_seqlen, max_seqlen, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
fused_attn_max_512_fwd_qkvpacked(
b, max_seqlen, h, d,
b, h, max_seqlen, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_Bias, output_O,
Aux_CTX_Tensors,
......@@ -239,7 +251,7 @@ void nvte_fused_attn_fwd_qkvpacked(
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd_qkvpacked(
b, max_seqlen, h, d,
b, h, max_seqlen, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_Bias, output_O,
Aux_CTX_Tensors,
......@@ -253,7 +265,7 @@ void nvte_fused_attn_fwd_qkvpacked(
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
fused_attn_fp8_fwd_qkvpacked(
b, max_seqlen, h, d,
b, h, max_seqlen, d,
is_training, attn_scale, dropout, qkv_layout,
input_QKV, input_output_S, output_O,
Aux_CTX_Tensors,
......@@ -297,10 +309,17 @@ void nvte_fused_attn_bwd_qkvpacked(
Tensor *output_dBias = reinterpret_cast<Tensor*>(dBias);
Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
// QKV shape is [total_seqs, 3, h, d] for 3HD layouts or [total_seqs, h, 3, d] for H3D layouts
auto ndim = input_QKV->data.shape.size();
size_t b = input_cu_seqlens->data.shape[0] - 1;
size_t h = input_QKV->data.shape[ndim - 2];
size_t h = 0;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
h = input_QKV->data.shape[ndim - 2];
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
h = input_QKV->data.shape[ndim - 3];
} else {
NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!");
}
size_t d = input_QKV->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
......@@ -310,13 +329,13 @@ void nvte_fused_attn_bwd_qkvpacked(
nvte_get_fused_attn_backend(
QKV_type, QKV_type,
qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen, max_seqlen, d);
dropout, h, h, max_seqlen, max_seqlen, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
fused_attn_max_512_bwd_qkvpacked(
b, max_seqlen, h, d,
b, h, max_seqlen, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_dO,
output_S,
......@@ -329,11 +348,17 @@ void nvte_fused_attn_bwd_qkvpacked(
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
Tensor *input_Bias = nullptr, *input_rng_state = nullptr;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
input_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
input_Bias = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
} else {
input_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
}
fused_attn_arbitrary_seqlen_bwd_qkvpacked(
b, max_seqlen, h, d,
b, h, max_seqlen, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_O, input_dO,
input_QKV, input_O, input_dO, input_Bias,
output_S,
output_dQKV, output_dBias,
input_cu_seqlens, input_rng_state,
......@@ -350,7 +375,7 @@ void nvte_fused_attn_bwd_qkvpacked(
const Tensor *input_ZInv = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
fused_attn_fp8_bwd_qkvpacked(
b, max_seqlen, h, d,
b, h, max_seqlen, d,
attn_scale, dropout, qkv_layout,
input_QKV, input_O, input_dO,
input_M, input_ZInv,
......@@ -395,12 +420,20 @@ void nvte_fused_attn_fwd_kvpacked(
Tensor *output_O = reinterpret_cast<Tensor*>(O);
Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
// Q shape is [total_seqs, h, d]
// KV shape is [total_seqs, 2, h, d] for HD_2HD layouts or [total_seqs, h, 2, d] for HD_H2D layouts
auto ndim = input_Q->data.shape.size();
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
size_t h = input_Q->data.shape[ndim - 2];
auto ndim = input_Q->data.shape.size();
size_t h_q = input_Q->data.shape[ndim - 2];
size_t d = input_Q->data.shape[ndim - 1];
auto ndim_kv = input_KV->data.shape.size();
size_t h_kv = 0;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
h_kv = input_KV->data.shape[ndim_kv - 2];
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
h_kv = input_KV->data.shape[ndim_kv - 3];
} else {
NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!");
}
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
......@@ -410,12 +443,12 @@ void nvte_fused_attn_fwd_kvpacked(
nvte_get_fused_attn_backend(
Q_type, KV_type,
qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen_q, max_seqlen_kv, d);
dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
fused_attn_max_512_fwd_kvpacked(
b, max_seqlen_q, max_seqlen_kv, h, d,
b, h_q, max_seqlen_q, max_seqlen_kv, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_KV, input_Bias, output_O,
Aux_CTX_Tensors,
......@@ -426,10 +459,19 @@ void nvte_fused_attn_fwd_kvpacked(
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
const char* err_msg =
"The FP16/BF16 fused attention (arbitrary seqlen) currently "
"only supports packed QKV input.\n";
NVTE_ERROR(err_msg);
#if (CUDNN_VERSION >= 8903)
fused_attn_arbitrary_seqlen_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_KV, input_Bias, output_O,
Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
} else {
......@@ -471,12 +513,20 @@ void nvte_fused_attn_bwd_kvpacked(
Tensor *output_dBias = reinterpret_cast<Tensor*>(dBias);
Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
// Q shape is [total_seqs, h, d]
// KV shape is [total_seqs, 2, h, d] for HD_2HD layouts or [total_seqs, h, 2, d] for HD_H2D layouts
auto ndim = input_Q->data.shape.size();
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
size_t h = input_Q->data.shape[ndim - 2];
auto ndim = input_Q->data.shape.size();
size_t h_q = input_Q->data.shape[ndim - 2];
size_t d = input_Q->data.shape[ndim - 1];
auto ndim_kv = input_KV->data.shape.size();
size_t h_kv = 0;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
h_kv = input_KV->data.shape[ndim_kv - 2];
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
h_kv = input_KV->data.shape[ndim_kv - 3];
} else {
NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!");
}
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
......@@ -486,13 +536,13 @@ void nvte_fused_attn_bwd_kvpacked(
nvte_get_fused_attn_backend(
Q_type, KV_type,
qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen_q, max_seqlen_kv, d);
dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
fused_attn_max_512_bwd_kvpacked(
b, max_seqlen_q, max_seqlen_kv, h, d,
b, h_q, max_seqlen_q, max_seqlen_kv, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_KV, input_dO,
output_S,
......@@ -503,10 +553,29 @@ void nvte_fused_attn_bwd_kvpacked(
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
const char* err_msg =
"The FP16/BF16 fused attention (arbitrary seqlen) currently "
"only supports packed QKV input.\n";
#if (CUDNN_VERSION >= 8903)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
Tensor *input_Bias = nullptr, *input_rng_state = nullptr;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
input_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
input_Bias = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
} else {
input_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
}
fused_attn_arbitrary_seqlen_bwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_KV, input_O, input_dO, input_Bias,
output_S,
output_dQ, output_dKV, output_dBias,
input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state, wkspace, stream, handle);
#else
const char *err_msg =
"cuDNN 8.9.3 is required for BF16/FP16 fused attention "
"with arbitrary sequence length. \n";
NVTE_ERROR(err_msg);
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
} else {
......@@ -546,7 +615,8 @@ void nvte_fused_attn_fwd(
auto ndim = input_Q->data.shape.size();
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
size_t h = input_Q->data.shape[ndim - 2];
size_t h_q = input_Q->data.shape[ndim - 2];
size_t h_kv = input_K->data.shape[ndim - 2];
size_t d = input_Q->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
......@@ -557,12 +627,12 @@ void nvte_fused_attn_fwd(
nvte_get_fused_attn_backend(
Q_type, KV_type,
qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen_q, max_seqlen_kv, d);
dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
fused_attn_max_512_fwd(
b, max_seqlen_q, max_seqlen_kv, h, d,
b, h_q, max_seqlen_q, max_seqlen_kv, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V, input_Bias, output_O,
Aux_CTX_Tensors,
......@@ -575,7 +645,7 @@ void nvte_fused_attn_fwd(
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd(
b, max_seqlen_q, max_seqlen_kv, h, d,
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V, input_Bias, output_O,
Aux_CTX_Tensors,
......@@ -589,7 +659,7 @@ void nvte_fused_attn_fwd(
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
fused_attn_fp8_fwd(
b, max_seqlen_q, max_seqlen_kv, h, d,
b, h_q, max_seqlen_q, max_seqlen_kv, d,
is_training, attn_scale, dropout, qkv_layout,
input_Q, input_K, input_V, input_output_S, output_O,
Aux_CTX_Tensors,
......@@ -644,7 +714,8 @@ void nvte_fused_attn_bwd(
auto ndim = input_Q->data.shape.size();
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
size_t h = input_Q->data.shape[ndim - 2];
size_t h_q = input_Q->data.shape[ndim - 2];
size_t h_kv = input_K->data.shape[ndim - 2];
size_t d = input_Q->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
......@@ -655,13 +726,13 @@ void nvte_fused_attn_bwd(
nvte_get_fused_attn_backend(
Q_type, KV_type,
qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen_q, max_seqlen_kv, d);
dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
fused_attn_max_512_bwd(
b, max_seqlen_q, max_seqlen_kv, h, d,
b, h_q, max_seqlen_q, max_seqlen_kv, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V, input_dO,
output_S,
......@@ -674,11 +745,17 @@ void nvte_fused_attn_bwd(
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
Tensor *input_Bias = nullptr, *input_rng_state = nullptr;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
input_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
input_Bias = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
} else {
input_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
}
fused_attn_arbitrary_seqlen_bwd(
b, max_seqlen_q, max_seqlen_kv, h, d,
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V, input_O, input_dO,
input_Q, input_K, input_V, input_O, input_dO, input_Bias,
output_S,
output_dQ, output_dK, output_dV, output_dBias,
input_cu_seqlens_q, input_cu_seqlens_kv,
......@@ -695,7 +772,7 @@ void nvte_fused_attn_bwd(
const Tensor *input_ZInv = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
fused_attn_fp8_bwd(
b, max_seqlen_q, max_seqlen_kv, h, d,
b, h_q, max_seqlen_q, max_seqlen_kv, d,
attn_scale, dropout, qkv_layout,
input_Q, input_K, input_V, input_O, input_dO,
input_M, input_ZInv,
......
......@@ -9,6 +9,7 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>
#include <map>
#include <vector>
......@@ -46,1476 +47,613 @@
namespace transformer_engine {
namespace fused_attn {
static cudnn_frontend::Tensor
createScale(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
const cudnn_frontend::Tensor& sTensor,
std::vector<cudnn_frontend::Operation>* ops) {
// scale
int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1};
int64_t s_dim[4] = {b, h, s_q, s_kv};
int64_t s_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, s_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix);
auto scaleTensor = tensor_create(
tensorType, S_CONST_ID, scale_dim,
scale_stride, false, true); // is by value
auto sScaleTensor = tensor_create(
tensorType, VIRTUAL_ID + 2000, s_dim,
s_stride, true, false); // is virtual
// Define the scale descriptor
auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a scale node
auto scale_op = binary_pw_op_create(sTensor, scaleTensor, sScaleTensor, scaleDesc);
ops->push_back(std::move(scale_op));
return sScaleTensor;
}
static cudnn_frontend::Tensor
createQKBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
bool padding_aware, NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops) {
// Creates the necessary tensor descriptors
int64_t q_dim[4] = {b, h, s_q, d};
int64_t q_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
int64_t k_dim[4] = {b, h, d, s_kv};
int64_t k_stride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, k_stride, layout, NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose);
int64_t s_dim[4] = {b, h, s_q, s_kv};
int64_t s_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, s_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix);
int64_t seqlen_dim[4] = {b, 1, 1, 1};
int64_t seqlen_stride[4] = {1, 1, 1, 1};
auto qTensor = tensor_create(tensorType, Q_ID, q_dim, q_stride, false, false);
auto kTransposeTensor = tensor_create(
tensorType, K_ID, k_dim, k_stride, false, false); // not virtual
// first GEMM output
auto sTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 1, s_dim, s_stride, true, false); // is virtual
// Define the matmul 1 desc
auto matmul_1_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setPaddingValue(0.0f)
.build();
auto seqlenQTensor = tensor_create(
CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
auto seqlenKTensor = tensor_create(
CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
// Create a matmul 1 node
auto&& matmul_op_builder =
cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR);
matmul_op_builder.setaMatDesc(qTensor)
.setbMatDesc(kTransposeTensor)
.setcMatDesc(sTensor)
.setmatmulDesc(matmul_1_Desc);
if (padding_aware) {
matmul_op_builder.setmOverrideDesc(seqlenQTensor).setnOverrideDesc(seqlenKTensor);
}
auto matmul_op1 = matmul_op_builder.build();
ops->push_back(std::move(matmul_op1));
return sTensor;
}
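A PyTorch reference for what the BMM1 node above and the scale node (createScale) compute together, assuming [b, h, s, d] tensors: S = (Q @ K^T) * attn_scale, with S of shape [b, h, s_q, s_kv].

import torch

def bmm1_scaled(q: torch.Tensor, k: torch.Tensor, attn_scale: float) -> torch.Tensor:
    # q: [b, h, s_q, d], k: [b, h, s_kv, d] -> scores: [b, h, s_q, s_kv]
    return torch.matmul(q, k.transpose(-2, -1)) * attn_scale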
static cudnn_frontend::Tensor
createPaddingMask(int64_t b,
int64_t h,
int64_t s_q,
int64_t s_kv,
int64_t d,
NVTE_QKV_Layout layout,
cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& prevBlockOutputTensor) {
CUDNN_FRONTEND_UNUSED(d);
CUDNN_FRONTEND_UNUSED(layout);
CUDNN_FRONTEND_UNUSED(tensorType);
NVTE_CHECK(ops->size() != 0, "Padding Mask constructed incorrectly as the first one");
// subtraction output
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t maskVal_dim[4] = {1, 1, 1, 1};
int64_t maskVal_stride[4] = {1, 1, 1, 1};
int64_t seqlen_dim[4] = {b, 1, 1, 1};
int64_t seqlen_stride[4] = {1, 1, 1, 1};
// mask value to put in the masked pixels
auto maskValTensor = tensor_create(
CUDNN_DATA_FLOAT, MASK_VAL_ID, maskVal_dim, maskVal_stride, false, true);
auto seqlenQTensor = tensor_create(
CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
auto seqlenKTensor = tensor_create(
CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
// gen index row output
auto rowIndexTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 300, afterBMM1_dim, afterBMM1_stride, true, false);
// gen index column output
auto columnIndexTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 301, afterBMM1_dim, afterBMM1_stride, true, false);
// less than row output
auto lessThanRowTensor = tensor_create(
CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 302, afterBMM1_dim, afterBMM1_stride, true, false);
// less than column output
auto lessThanColTensor = tensor_create(
CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 303, afterBMM1_dim, afterBMM1_stride, true, false);
// padding mask (lessthanRow && lessthanCol)
auto paddingMaskTensor = tensor_create(
CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 304, afterBMM1_dim, afterBMM1_stride, true, false);
// output after masking
auto maskOutputTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 305, afterBMM1_dim, afterBMM1_stride, true, false);
// Define the gen index for row descriptor
auto genIndexRowDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setAxis(2)
.setComputeType(CUDNN_DATA_FLOAT)
.build();
// Create a gen index Node.
auto genIndexRow_op = unary_pw_op_create(
prevBlockOutputTensor, rowIndexTensor, genIndexRowDesc);
// Define the gen index for column descriptor
auto genIndexColumnDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setAxis(3)
.setComputeType(CUDNN_DATA_FLOAT)
.build();
// Create a gen index Node.
auto genIndexColumn_op = unary_pw_op_create(
prevBlockOutputTensor, columnIndexTensor, genIndexColumnDesc);
// Define the less than comparison for row descriptor
auto lessThanRowDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_LT);
// Create a less than comparison for row Node.
auto lessThanRow_op = binary_pw_op_create(
rowIndexTensor, seqlenQTensor, lessThanRowTensor, lessThanRowDesc);
// Define the less than comparison for column descriptor
auto lessThanColDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_LT);
// Create a less than comparison for col Node.
auto lessThanCol_op = binary_pw_op_create(
columnIndexTensor, seqlenKTensor, lessThanColTensor, lessThanColDesc);
// Define the logical AND descriptor to combine the row and column masks
auto paddingMaskAndDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_LOGICAL_AND);
// Create a and node for combining lessThanRow and lessThanCol
auto paddingMaskAnd_op = binary_pw_op_create(
lessThanRowTensor, lessThanColTensor, paddingMaskTensor, paddingMaskAndDesc);
/////////////////// Apply the mask //////////////////////////
// Define the binary select to perform masking descriptor
auto maskDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_BINARY_SELECT);
// Create a binary select Node.
auto mask_op = ternary_pw_op_create(
prevBlockOutputTensor, maskValTensor, paddingMaskTensor, maskOutputTensor, maskDesc);
ops->push_back(std::move(genIndexRow_op));
ops->push_back(std::move(genIndexColumn_op));
ops->push_back(std::move(lessThanRow_op));
ops->push_back(std::move(lessThanCol_op));
ops->push_back(std::move(paddingMaskAnd_op));
ops->push_back(std::move(mask_op));
return maskOutputTensor;
}
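A PyTorch sketch of the padding-mask semantics built above: an element (row, col) of the scores survives only if row < seqlen_q[b] and col < seqlen_kv[b]; everything else is replaced by the mask value before softmax.

import torch

def apply_padding_mask(scores, seqlens_q, seqlens_kv, mask_value=-1e9):
    b, h, s_q, s_kv = scores.shape
    rows = torch.arange(s_q, device=scores.device).view(1, 1, s_q, 1)
    cols = torch.arange(s_kv, device=scores.device).view(1, 1, 1, s_kv)
    keep = (rows < seqlens_q.view(b, 1, 1, 1)) & (cols < seqlens_kv.view(b, 1, 1, 1))
    return torch.where(keep, scores, torch.full_like(scores, mask_value))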
static cudnn_frontend::Tensor
createCausalMask(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& prevBlockOutputTensor) {
CUDNN_FRONTEND_UNUSED(d);
CUDNN_FRONTEND_UNUSED(layout);
CUDNN_FRONTEND_UNUSED(tensorType);
NVTE_CHECK(ops->size() != 0, "Padding Mask constructed incorrectly as the first one");
// subtraction output
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t maskVal_dim[4] = {1, 1, 1, 1};
int64_t maskVal_stride[4] = {1, 1, 1, 1};
// mask value to put in the masked pixels
auto maskValTensor = tensor_create(
CUDNN_DATA_FLOAT, MASK_VAL_ID, maskVal_dim,
maskVal_stride, false, true); // is by value
// gen index row output
auto rowIndexTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 100, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// gen index column output
auto columnIndexTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 101, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// create causal mask (row >= col)
auto causalMaskTensor = tensor_create(
CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 106, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// output after masking
auto maskOutputTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 107, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// Define the gen index for row descriptor
auto genIndexRowDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setAxis(2)
.setComputeType(CUDNN_DATA_FLOAT)
.build();
// Create a gen index node
auto genIndexRow_op = unary_pw_op_create(
prevBlockOutputTensor, rowIndexTensor, genIndexRowDesc);
// Define the gen index for column descriptor
auto genIndexColumnDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setAxis(3)
.setComputeType(CUDNN_DATA_FLOAT)
.build();
// Create a gen index node
auto genIndexColumn_op = unary_pw_op_create(
prevBlockOutputTensor, columnIndexTensor, genIndexColumnDesc);
// Define the greater than equal to comparison descriptor
auto rowGreaterColDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_CMP_GE);
// Create a greater than equal to node
auto rowGreaterCol_op = binary_pw_op_create(
rowIndexTensor, columnIndexTensor, causalMaskTensor, rowGreaterColDesc);
// Define the binary select to perform masking descriptor
auto maskDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_BINARY_SELECT);
// Create a binary select node
auto mask_op = ternary_pw_op_create(
prevBlockOutputTensor, maskValTensor,
causalMaskTensor, maskOutputTensor, maskDesc);
ops->push_back(std::move(genIndexRow_op));
ops->push_back(std::move(genIndexColumn_op));
ops->push_back(std::move(rowGreaterCol_op));
ops->push_back(std::move(mask_op));
return maskOutputTensor;
}
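A PyTorch sketch of the causal mask built above: positions with col > row are replaced by the mask value, so each query only attends to keys at the same or earlier positions.

import torch

def apply_causal_mask(scores: torch.Tensor, mask_value: float = -1e9) -> torch.Tensor:
    s_q, s_kv = scores.shape[-2], scores.shape[-1]
    row = torch.arange(s_q, device=scores.device).unsqueeze(-1)
    col = torch.arange(s_kv, device=scores.device)
    return torch.where(row >= col, scores, torch.full_like(scores, mask_value))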
static cudnn_frontend::Tensor
createSoftmaxForward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, bool isTraining,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& sAfterMaskTensor) {
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t afterReduction_dim[4] = {b, h, s_q, 1};
int64_t afterReduction_stride[4] = {h * s_q, s_q, 1, 1};
// max (x)
auto afterMaxReductionTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 150, afterReduction_dim,
afterReduction_stride, true, false); // is virtual
// x - max(x)
auto afterSubtractionTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 151, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// e^(x - max(x))
auto afterExponentTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 152, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual;
// sum (e^(x - max(x)))
auto afterAddReductionTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 153, afterReduction_dim,
afterReduction_stride, true, false); // is virtual
// log (sum (e^(x - max(x))))
auto afterLogLTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 154, afterReduction_dim,
afterReduction_stride, true, false);
// M + log (sum (e^(x - max(x))))
auto softmaxStatsTensor = tensor_create(
CUDNN_DATA_FLOAT, S_STATS_ID, afterReduction_dim,
afterReduction_stride, !isTraining, false);
// not virtual if training is true, virtual if training is false
// divide (e/ sum(e))
auto afterSoftmaxTensor = cudnn_frontend::TensorBuilder()
.setDim(4, afterBMM1_dim)
.setStride(4, afterBMM1_stride)
.setId(VIRTUAL_ID + 156)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual(true)
.setByValue(false)
.setReorderType(
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16)
.build();
// Define the reduction descriptor
auto reductionMaxDesc = cudnn_frontend::ReductionDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_MAX)
.build();
// Create a reduction max node
auto reductionMax_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(sAfterMaskTensor)
.setyDesc(afterMaxReductionTensor)
.setreductionDesc(reductionMaxDesc)
.build();
// Define the subtract descriptor
auto subtractDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB);
// Create a subtract node
auto subtract_op = binary_pw_op_create(
sAfterMaskTensor, afterMaxReductionTensor,
afterSubtractionTensor, subtractDesc);
// Define the exponent descriptor
auto exponentDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP);
// Create an exponent node
auto exponent_op = unary_pw_op_create(
afterSubtractionTensor, afterExponentTensor, exponentDesc);
// Define the reduction descriptor
auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
.build();
// Create a reduction add node
auto reductionAdd_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(afterExponentTensor)
.setyDesc(afterAddReductionTensor)
.setreductionDesc(reductionAddDesc)
.build();
// Create log descriptor
auto logDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_LOG);
// Create log node
auto log_op = unary_pw_op_create(afterAddReductionTensor, afterLogLTensor, logDesc);
// Create add descriptor
auto addDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_ADD);
// Create add node
auto add_op = binary_pw_op_create(
afterMaxReductionTensor, afterLogLTensor,
softmaxStatsTensor, addDesc);
// Define the division descriptor
auto divisionDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_DIV);
// Create a division node
auto division_op = binary_pw_op_create(
afterExponentTensor, afterAddReductionTensor,
afterSoftmaxTensor, divisionDesc);
ops->push_back(std::move(reductionMax_op));
ops->push_back(std::move(subtract_op));
ops->push_back(std::move(exponent_op));
ops->push_back(std::move(reductionAdd_op));
ops->push_back(std::move(log_op));
ops->push_back(std::move(add_op));
ops->push_back(std::move(division_op));
return afterSoftmaxTensor;
}
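A PyTorch sketch of the math this graph encodes: a numerically stable softmax whose per-row statistics are saved (when training) as max(x) + log(sum(exp(x - max(x)))), i.e. the log-sum-exp, for reuse in the backward pass.

import torch

def softmax_with_stats(x: torch.Tensor):
    # x: [b, h, s_q, s_kv] attention scores after scaling and masking
    m = x.max(dim=-1, keepdim=True).values
    e = torch.exp(x - m)
    z = e.sum(dim=-1, keepdim=True)
    probs = e / z                            # softmax output
    stats = (m + torch.log(z)).squeeze(-1)   # [b, h, s_q] softmax stats
    return probs, stats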
static cudnn_frontend::Tensor
createDropoutForward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
double probability, cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& afterSoftmaxTensor) {
CUDNN_FRONTEND_UNUSED(d);
NVTE_CHECK(ops->size() != 0, "Dropout DAG constructed incorrectly as the first one");
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1};
auto dropoutSeed = tensor_create(
CUDNN_DATA_INT64, D_SEED_ID, scale_dim,
scale_stride, false, false); // not virtual
auto dropoutOffset = tensor_create(
CUDNN_DATA_INT64, D_OFFSET_ID, scale_dim,
scale_stride, false, false); // not virtual
// mask for the dropout
auto dropoutMaskTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 200, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// after dropout tensor
auto afterDropoutTensor = cudnn_frontend::TensorBuilder()
.setDim(4, afterBMM1_dim)
.setStride(4, afterBMM1_stride)
.setId(VIRTUAL_ID + 201)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(tensorType)
.setVirtual(true)
.setByValue(false)
.setReorderType(
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16)
.build();
// scale after dropout
auto scaleDropoutTensor = tensor_create(
CUDNN_DATA_FLOAT, D_CONST_ID, scale_dim,
scale_stride, false, true); // is by value
// after Scale
auto afterScaleTensor = tensor_create(
tensorType, VIRTUAL_ID + 202, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// Define the RNG (dropout mask) descriptor
auto rngDesc = cudnn_frontend::RngDescBuilder()
.setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI)
.setBernoulliDistProbability(1.0 - probability)
.build();
// Create a rng node
auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR)
.setyDesc(dropoutMaskTensor)
.setSeedDesc(dropoutSeed)
.setOffsetDesc(dropoutOffset)
.setRngDesc(rngDesc)
.build();
// Define the multiply mask descriptor
auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply mask node
auto maskMul_op = binary_pw_op_create(
afterSoftmaxTensor, dropoutMaskTensor,
afterDropoutTensor, maskMulDesc);
// Define the multiply scale descriptor
auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply scale node
auto scaleMul_op = binary_pw_op_create(
afterDropoutTensor, scaleDropoutTensor,
afterScaleTensor, scaleMulDesc);
ops->push_back(std::move(rng_op));
ops->push_back(std::move(maskMul_op));
ops->push_back(std::move(scaleMul_op));
return afterScaleTensor;
}
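A PyTorch sketch of the dropout node semantics: a Bernoulli(1 - p) keep mask followed by a 1/(1 - p) rescale. The real kernel drives this from the explicit seed/offset tensors above (a counter-based RNG) so the backward pass can regenerate the identical mask; the manual seeding here is only a stand-in for that idea.

import torch

def dropout_with_seed(x: torch.Tensor, p: float, seed: int) -> torch.Tensor:
    torch.manual_seed(seed)
    keep = torch.bernoulli(torch.full_like(x, 1.0 - p))
    return x * keep / (1.0 - p)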
static cudnn_frontend::Tensor
createDropoutBackward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
double probability, cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& afterSoftmaxTensor,
const cudnn_frontend::Tensor& dropoutMaskTensor) {
CUDNN_FRONTEND_UNUSED(d);
NVTE_CHECK(ops->size() != 0, "Dropout DAG constructed incorrectly as the first one");
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1};
auto dropoutSeed = tensor_create(
CUDNN_DATA_INT64, D_SEED_ID, scale_dim,
scale_stride, false, false); // not virtual
auto dropoutOffset = tensor_create(
CUDNN_DATA_INT64, D_OFFSET_ID, scale_dim,
scale_stride, false, false); // not virtual
// after dropout tensor
auto afterDropoutTensor = cudnn_frontend::TensorBuilder()
.setDim(4, afterBMM1_dim)
.setStride(4, afterBMM1_stride)
.setId(VIRTUAL_ID + 201)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(tensorType)
.setVirtual(true)
.setByValue(false)
.setReorderType(
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16)
.build();
// scale after dropout
auto scaleDropoutTensor = tensor_create(
CUDNN_DATA_FLOAT, D_CONST_ID, scale_dim,
scale_stride, false, true); // is by value
// after Scale
auto afterScaleTensor = tensor_create(
tensorType, VIRTUAL_ID + 202, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// Define the RNG (dropout mask) descriptor
auto rngDesc = cudnn_frontend::RngDescBuilder()
.setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI)
.setBernoulliDistProbability(1.0 - probability)
.build();
// Create a rng node
auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR)
.setyDesc(dropoutMaskTensor)
.setSeedDesc(dropoutSeed)
.setOffsetDesc(dropoutOffset)
.setRngDesc(rngDesc)
.build();
// Define the multiply mask descriptor
auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply mask node
auto maskMul_op = binary_pw_op_create(
afterSoftmaxTensor, dropoutMaskTensor,
afterDropoutTensor, maskMulDesc);
// Define the multiply scale descriptor
auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply scale node
auto scaleMul_op = binary_pw_op_create(
afterDropoutTensor, scaleDropoutTensor,
afterScaleTensor, scaleMulDesc);
ops->push_back(std::move(rng_op));
ops->push_back(std::move(maskMul_op));
ops->push_back(std::move(scaleMul_op));
return afterScaleTensor;
}
static void
createSVBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
bool padding_aware, NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops,
cudnn_frontend::Tensor const &afterScaleDropoutTensor) {
NVTE_CHECK(ops->size() != 0, "BMM2 op constructed incorrectly as the first one");
int64_t v_dim[4] = {b, h, s_kv, d};
int64_t v_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, v_stride, layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
int64_t o_dim[4] = {b, h, s_q, d};
int64_t o_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
int64_t seqlen_dim[4] = {b, 1, 1, 1};
int64_t seqlen_stride[4] = {1, 1, 1, 1};
auto seqlenQTensor = tensor_create(
CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
auto seqlenKTensor = tensor_create(
CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
auto vTensor = tensor_create(tensorType, V_ID, v_dim, v_stride, false, false);
// second GEMM output
auto oTensor = tensor_create(tensorType, O_ID, o_dim, o_stride, false, false);
// Define the matmul 2 desc
auto matmul_2_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setPaddingValue(0.0f)
.build();
// Create a matmul 2 node
auto&& matmul_op_builder =
cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR);
matmul_op_builder.setaMatDesc(afterScaleDropoutTensor)
.setbMatDesc(vTensor)
.setcMatDesc(oTensor)
.setmatmulDesc(matmul_2_Desc);
if (padding_aware) {
matmul_op_builder.setmOverrideDesc(seqlenQTensor).setkOverrideDesc(seqlenKTensor);
}
auto matmul_op2 = matmul_op_builder.build();
ops->push_back(std::move(matmul_op2));
}
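And a reference for BMM2: the attention output is O = P @ V, with P the post-softmax (and post-dropout) probabilities of shape [b, h, s_q, s_kv] and V of shape [b, h, s_kv, d].

import torch

def bmm2(p: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # p: [b, h, s_q, s_kv], v: [b, h, s_kv, d] -> o: [b, h, s_q, d]
    return torch.matmul(p, v)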
void fused_attn_arbitrary_seqlen_fwd_impl(
int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d,
bool is_training, float scaling_factor, float dropout_probability,
NVTE_QKV_Layout layout, NVTE_Mask_Type mask_type,
void *devPtrQ, void *devPtrK, void *devPtrV,
NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias,
void *devPtrSoftmaxStats, void *devPtrO,
void *devPtrCuSeqlenQ, void *devPtrCuSeqlenKV,
void* devPtrDropoutSeed, void* devPtrDropoutOffset,
cudnnDataType_t tensorType,
void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV,
cudnn_frontend::DataType_t tensorType,
void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle) {
try {
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
if (!is_training) {
dropout_probability = 0.0f;
}
// also known as variable_sequence_length
bool padding_aware = (mask_type == NVTE_PADDING_MASK) ||
(mask_type == NVTE_PADDING_CAUSAL_MASK);
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
|| (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
|| (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (is_training && dropout_probability != 0.0f);
try {
FADescriptor_v1 descriptor{b, h,
hg, s_q,
s_kv, d,
scaling_factor, is_training,
dropout_probability, layout,
bias_type, mask_type,
tensorType};
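// The descriptor doubles as the cache key: two calls with the same shapes, layout,
// bias/mask types, dropout setting and data type reuse the previously built graph, so
// graph construction and plan compilation are paid only once per unique problem.
// (FADescriptor_v1 is assumed to provide operator< so it can key a std::map.)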
namespace fe = cudnn_frontend;
using graph_and_tensors = std::tuple<std::shared_ptr<fe::graph::Graph>,
std::shared_ptr<fe::graph::Tensor_attributes>, // Q
std::shared_ptr<fe::graph::Tensor_attributes>, // K
std::shared_ptr<fe::graph::Tensor_attributes>, // V
std::shared_ptr<fe::graph::Tensor_attributes>, // attn_scale
std::shared_ptr<fe::graph::Tensor_attributes>, // O
std::shared_ptr<fe::graph::Tensor_attributes>, // Stats
std::shared_ptr<fe::graph::Tensor_attributes>, // bias
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_q
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_kv
std::shared_ptr<fe::graph::Tensor_attributes>, // dropout_seed
std::shared_ptr<fe::graph::Tensor_attributes> >; // dropout_offset
using CacheType = std::map<FADescriptor_v1, graph_and_tensors>;
static thread_local CacheType sdpa_flash_f16_fprop_cache;
// Get plan from cache if cache is available, otherwise create one
auto get_graph = [&](CacheType &cache, const FADescriptor_v1 &descriptor)
-> graph_and_tensors {
// if hit, return
auto it = cache.find(descriptor);
if (it != cache.end()) {
auto graph = it->second;
return graph;
}
// otherwise, build the op_graph and the plan. Then update cache
auto mha_graph = std::make_shared<fe::graph::Graph>();
mha_graph->set_io_data_type(tensorType)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
std::shared_ptr<fe::graph::Tensor_attributes> Q, K, V, attn_scale;
std::shared_ptr<fe::graph::Tensor_attributes> bias, seq_q, seq_kv;
std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset;
std::vector<int64_t> q_stride(4);
std::vector<int64_t> k_stride(4);
std::vector<int64_t> v_stride(4);
generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_K_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
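// The strides are derived from the QKV storage layout rather than assumed contiguous.
// As an illustrative example (not taken from this file): for a BS3HD buffer of shape
// [b, s, 3, h, d], the Q matrix viewed as [b, h, s_q, d] has strides
// {s_q * 3 * h * d, d, 3 * h * d, 1}, i.e. consecutive sequence positions sit
// 3 * h * d elements apart because K and V are interleaved in the same buffer.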
Q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d})
.set_stride(q_stride));
K = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_dim({b, hg, s_kv, d})
.set_stride(k_stride));
V = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim({b, hg, s_kv, d})
.set_stride(v_stride));
attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("attn_scale")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(fe::DataType_t::FLOAT));
fe::graph::Scaled_dot_product_flash_attention_attributes
scaled_dot_product_flash_attention_options;
scaled_dot_product_flash_attention_options =
fe::graph::Scaled_dot_product_flash_attention_attributes()
.set_name("flash_attention")
.set_is_inference(!is_training)
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale);
scaled_dot_product_flash_attention_options.set_alibi_mask(is_alibi);
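// The SDPA attributes collect every optional feature of the fused kernel: inference vs.
// training (whether Stats is produced), causal masking, ALiBi slopes, an explicit
// post-scale bias, padding masks driven by per-batch sequence lengths, and philox-based
// dropout. Each is attached below only when the corresponding flag is set.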
if (is_bias) {
bias = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("bias")
.set_dim({1, h, s_q, s_kv})
.set_stride({h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
scaled_dot_product_flash_attention_options.set_bias(bias);
}
if (is_padding) {
seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("seq_q")
.set_dim({b, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("seq_kv")
.set_dim({b, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
scaled_dot_product_flash_attention_options.set_padding_mask(is_padding)
.set_seq_len_q(seq_q)
.set_seq_len_kv(seq_kv);
}
NVTE_CHECK(dropout_probability != 1.0f,
"Dropout probability cannot be 1.0");
if (is_dropout) {
dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Seed")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT64));
dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Offset")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT64));
scaled_dot_product_flash_attention_options.set_dropout(
dropout_probability, dropout_seed, dropout_offset);
}
auto [O, Stats] = mha_graph->scaled_dot_product_flash_attention(
Q, K, V, scaled_dot_product_flash_attention_options);
std::vector<int64_t> o_stride(4);
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride);
if (is_training) {
Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT)
.set_dim({b, h, s_q, 1})
.set_stride({h * s_q, s_q, 1, 1});
}
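// Stats holds the per-row softmax statistics (effectively the log-sum-exp of each
// attention row) with shape [b, h, s_q, 1]; the backward pass reads it to recompute
// the softmax without materializing the full s_q x s_kv probability matrix.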
std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>, // Q
std::shared_ptr<fe::graph::Tensor_attributes>, // K
std::shared_ptr<fe::graph::Tensor_attributes>, // V
std::shared_ptr<fe::graph::Tensor_attributes>, // attn_scale
std::shared_ptr<fe::graph::Tensor_attributes> > // O
key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O);
auto Stats_tuple = is_training ? std::make_tuple(Stats) : std::make_tuple(nullptr);
auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr);
auto padding_tuple = is_padding ?
std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
auto dropout_tuple = is_dropout ?
std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr);
auto return_empty_tuple = std::tuple_cat(
std::make_tuple(nullptr), key_tensors_tuple,
Stats_tuple, bias_tuple, padding_tuple, dropout_tuple);
mha_graph->validate();
mha_graph->build_operation_graph(handle);
mha_graph->create_execution_plans({fe::HeurMode_t::A});
mha_graph->check_support(handle);
mha_graph->build_plans(handle);
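// Graph finalization follows the frontend v1 pipeline used here: validate() checks the
// tensor attributes, build_operation_graph() lowers the graph to backend descriptors,
// create_execution_plans({HeurMode_t::A}) queries the heuristics, check_support()
// rejects unsupported engine configs, and build_plans() compiles the final plan.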
auto return_tuple = std::tuple_cat(
std::make_tuple(mha_graph), key_tensors_tuple,
Stats_tuple, bias_tuple, padding_tuple, dropout_tuple);
cache.insert({descriptor, return_tuple});
return return_tuple;
};
auto [mha_graph, Q, K, V, attn_scale, O, Stats,
bias, seq_q, seq_kv, dropout_seed, dropout_offset] = get_graph(
sdpa_flash_f16_fprop_cache, descriptor);
auto plan_workspace_size = mha_graph->get_workspace_size();
// Exit to request upper level API to allocate memory if needed
size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t);
if (workspace == nullptr) {
*workspace_size = plan_workspace_size + actual_seqlen_workspace_size;
return;
}
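// Two-phase workspace protocol: when called with workspace == nullptr the function only
// reports the required size (graph workspace plus 2 * b int32 slots for the converted
// sequence lengths) and returns; the caller allocates and then invokes it again to run.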
// Build variant pack
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
{Q, devPtrQ},
{K, devPtrK},
{V, devPtrV},
{attn_scale, &scaling_factor},
{O, devPtrO}};
if (is_training) {
variant_pack[Stats] = devPtrSoftmaxStats;
}
if (is_bias) {
variant_pack[bias] = devPtrBias;
}
if (is_padding) {
constexpr size_t nthreads_per_block = 128;
const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
void *devActualSeqlenKV = static_cast<int8_t *>(devActualSeqlenQ) + b * sizeof(int32_t);
cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
static_cast<const int32_t *>(devPtrCuSeqlensKV),
static_cast<int32_t *>(devActualSeqlenQ),
static_cast<int32_t *>(devActualSeqlenKV));
NVTE_CHECK_CUDA(cudaGetLastError());
variant_pack[seq_q] = devActualSeqlenQ;
variant_pack[seq_kv] = devActualSeqlenKV;
}
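// The cu_seqlens_to_actual_seqlens launch above converts cumulative offsets into
// per-sequence lengths, e.g. (values for illustration only) cu_seqlens = {0, 3, 7, 12}
// for b = 3 yields actual_seqlens = {3, 4, 5}; the graph's seq_q/seq_kv tensors
// consume the latter.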
if (is_dropout) {
variant_pack[dropout_seed] = devPtrDropoutSeed;
variant_pack[dropout_offset] = devPtrDropoutOffset;
}
mha_graph->execute(handle, variant_pack, workspace);
} catch (cudnn_frontend::cudnnException &e) {
NVTE_ERROR(e.what());
}
}
void fused_attn_arbitrary_seqlen_bwd_impl(
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d,
float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
void* devPtrQ, void* devPtrKTranspose, void* devPtrVTranspose,
void* devPtrO, void* devPtrSoftmaxStats, void* devPtrBias,
void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrdO, void* devPtrdBias,
void* devPtrDropoutSeed, void* devPtrDropoutOffset,
void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV,
cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle) {
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
|| (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
|| (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (dropout_probability != 0.0f);
try {
FADescriptor_v1 descriptor{b, h,
hg, s_q,
s_kv, d,
scaling_factor, true,
dropout_probability, layout,
bias_type, mask_type,
tensorType};
namespace fe = cudnn_frontend;
using graph_and_tensors = std::tuple<std::shared_ptr<fe::graph::Graph>,
std::shared_ptr<fe::graph::Tensor_attributes>, // q
std::shared_ptr<fe::graph::Tensor_attributes>, // k
std::shared_ptr<fe::graph::Tensor_attributes>, // v
std::shared_ptr<fe::graph::Tensor_attributes>, // o
std::shared_ptr<fe::graph::Tensor_attributes>, // dO
std::shared_ptr<fe::graph::Tensor_attributes>, // stats
std::shared_ptr<fe::graph::Tensor_attributes>, // attn_scale
std::shared_ptr<fe::graph::Tensor_attributes>, // dQ
std::shared_ptr<fe::graph::Tensor_attributes>, // dK
std::shared_ptr<fe::graph::Tensor_attributes>, // dV
std::shared_ptr<fe::graph::Tensor_attributes>, // bias
std::shared_ptr<fe::graph::Tensor_attributes>, // dBias
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_q
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_kv
std::shared_ptr<fe::graph::Tensor_attributes>, // dropout_seed
std::shared_ptr<fe::graph::Tensor_attributes> >; // dropout_offset
using CacheType = std::map<FADescriptor_v1, graph_and_tensors>;
static thread_local CacheType sdpa_flash_f16_bprop_cache;
// Get plan from cache if cache is available, otherwise create one
auto get_graph = [&](CacheType &cache, const FADescriptor_v1 &descriptor)
-> graph_and_tensors {
// if hit, return
auto it = cache.find(descriptor);
if (it != cache.end()) {
return it->second;
}
// otherwise, build the op_graph and the plan. Then update cache
auto mha_graph = std::make_shared<fe::graph::Graph>();
mha_graph->set_io_data_type(tensorType)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
std::shared_ptr<fe::graph::Tensor_attributes> q, k, v, o, dO, stats, attn_scale;
std::shared_ptr<fe::graph::Tensor_attributes> bias, dBias, seq_q, seq_kv;
std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset;
std::vector<int64_t> q_stride(4);
std::vector<int64_t> k_stride(4);
std::vector<int64_t> v_stride(4);
std::vector<int64_t> o_stride(4);
generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_K_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d})
.set_stride(q_stride));
k = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_dim({b, hg, s_kv, d})
.set_stride(k_stride));
v = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim({b, hg, s_kv, d})
.set_stride(v_stride));
o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("O")
.set_dim({b, h, s_q, d})
.set_stride(o_stride));
dO = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("dO")
.set_dim({b, h, s_q, d})
.set_stride(o_stride));
stats = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("stats")
.set_dim({b, h, s_q, 1})
.set_stride({h * s_q, s_q, 1, 1})
.set_data_type(fe::DataType_t::FLOAT));
attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("attn_scale")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(fe::DataType_t::FLOAT));
fe::graph::Scaled_dot_product_flash_attention_backward_attributes
scaled_dot_product_flash_attention_backward_options;
scaled_dot_product_flash_attention_backward_options =
fe::graph::Scaled_dot_product_flash_attention_backward_attributes()
.set_name("flash_attention_backward")
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale);
scaled_dot_product_flash_attention_backward_options.set_alibi_mask(is_alibi);
if (is_bias) {
bias = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("bias")
.set_dim({1, h, s_q, s_kv})
.set_stride({h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
dBias = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("dBias")
.set_dim({1, h, s_q, s_kv})
.set_stride({h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
scaled_dot_product_flash_attention_backward_options.set_bias(bias);
scaled_dot_product_flash_attention_backward_options.set_dbias(dBias);
}
if (is_padding) {
seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("seq_q")
.set_dim({b, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("seq_kv")
.set_dim({b, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
scaled_dot_product_flash_attention_backward_options.set_padding_mask(is_padding)
.set_seq_len_q(seq_q)
.set_seq_len_kv(seq_kv);
}
if (is_dropout) {
dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Seed")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT64));
dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Offset")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT64));
scaled_dot_product_flash_attention_backward_options.set_dropout(
dropout_probability, dropout_seed, dropout_offset);
}
auto [dQ, dK, dV] = mha_graph->scaled_dot_product_flash_attention_backward(
q, k, v, o, dO, stats, scaled_dot_product_flash_attention_backward_options);
dQ->set_output(true)
.set_dim({b, h, s_q, d})
.set_stride(q_stride);
dK->set_output(true)
.set_dim({b, hg, s_kv, d})
.set_stride(k_stride);
dV->set_output(true)
.set_dim({b, hg, s_kv, d})
.set_stride(v_stride);
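// dQ, dK and dV are declared as graph outputs and written straight into the caller's
// gradient buffers; their strides mirror the forward Q/K/V layout so no transpose or
// repacking is needed afterwards (dK/dV use hg heads to cover the GQA/MQA case).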
std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>, // q
std::shared_ptr<fe::graph::Tensor_attributes>, // k
std::shared_ptr<fe::graph::Tensor_attributes>, // v
std::shared_ptr<fe::graph::Tensor_attributes>, // o
std::shared_ptr<fe::graph::Tensor_attributes>, // dO
std::shared_ptr<fe::graph::Tensor_attributes>, // stats
std::shared_ptr<fe::graph::Tensor_attributes>, // attn_scale
std::shared_ptr<fe::graph::Tensor_attributes>, // dQ
std::shared_ptr<fe::graph::Tensor_attributes>, // dK
std::shared_ptr<fe::graph::Tensor_attributes> > // dV
key_tensors_tuple = std::make_tuple(q, k, v, o, dO, stats, attn_scale, dQ, dK, dV);
auto bias_tuple = is_bias ?
std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr);
auto padding_tuple = is_padding ?
std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
auto dropout_tuple = is_dropout ?
std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr);
auto return_empty_tuple = std::tuple_cat(
std::make_tuple(nullptr), key_tensors_tuple,
bias_tuple, padding_tuple, dropout_tuple);
mha_graph->validate();
mha_graph->build_operation_graph(handle);
mha_graph->create_execution_plans({fe::HeurMode_t::A});
mha_graph->check_support(handle);
mha_graph->build_plans(handle);
auto return_tuple = std::tuple_cat(
std::make_tuple(mha_graph), key_tensors_tuple,
bias_tuple, padding_tuple, dropout_tuple);
cache.insert({descriptor, return_tuple});
return return_tuple;
};
auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV,
bias, dBias, seq_q, seq_kv, dropout_seed, dropout_offset] = get_graph(
sdpa_flash_f16_bprop_cache, descriptor);
auto plan_workspace_size = mha_graph->get_workspace_size();
// Exit to request upper level API to allocate memory if needed
size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t);
if (workspace == nullptr) {
*workspace_size = plan_workspace_size + actual_seqlen_workspace_size;
return;
}
// build variant pack
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
{q, devPtrQ},
{k, devPtrKTranspose},
{v, devPtrVTranspose},
{o, devPtrO},
{dO, devPtrdO},
{stats, devPtrSoftmaxStats},
{attn_scale, &scaling_factor},
{dQ, devPtrdQ},
{dK, devPtrdK},
{dV, devPtrdV},
};
if (is_bias) {
variant_pack[bias] = devPtrBias;
variant_pack[dBias] = devPtrdBias;
}
if (is_padding) {
constexpr size_t nthreads_per_block = 128;
const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
void *devActualSeqlenKV = static_cast<int8_t *>(devActualSeqlenQ) + b * sizeof(int32_t);
cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
static_cast<const int32_t *>(devPtrCuSeqlensKV),
static_cast<int32_t *>(devActualSeqlenQ),
static_cast<int32_t *>(devActualSeqlenKV));
NVTE_CHECK_CUDA(cudaGetLastError());
variant_pack[seq_q] = devActualSeqlenQ;
variant_pack[seq_kv] = devActualSeqlenKV;
}
if (is_dropout) {
variant_pack[dropout_seed] = devPtrDropoutSeed;
variant_pack[dropout_offset] = devPtrDropoutOffset;
}
mha_graph->execute(handle, variant_pack, workspace);
} catch (cudnn_frontend::cudnnException &e) {
NVTE_ERROR(e.what());
}
}
} // namespace fused_attn
using namespace transformer_engine::fused_attn;
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
// QKV shape is [b, s, 3, h, d]
const DType QKV_type = input_QKV->data.dtype;
void *devPtrQKV = input_QKV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = 2 * num_attn_heads * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
stride = 2 * head_dim;
}
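// The factor of 2 in the stride is the element size in bytes (this path handles
// FP16/BF16 inputs), because the offset is applied to an int8_t pointer below: for the
// 3HD layout K starts num_attn_heads * head_dim elements after Q within each token,
// while for H3D it starts head_dim elements after Q.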
void *devPtrQ = static_cast<void *>(devPtrQKV);
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
void *devPtrBias = input_Bias->data.dptr;
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
void *devPtrCuSeqlens = cu_seqlens->data.dptr;
if (Aux_CTX_Tensors->size == 0) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Aux_CTX_Tensors->size = 3;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1};
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {1, num_attn_heads, max_seqlen, max_seqlen};
output_bias->data.dtype = QKV_type;
} else {
Aux_CTX_Tensors->size = 2;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1};
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
}
} else if (Aux_CTX_Tensors->size == 2) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
} else if (Aux_CTX_Tensors->size == 3) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
output_bias->data.dptr = devPtrBias;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
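// Aux_CTX_Tensors protocol: on the first call (size == 0) only shapes/dtypes are filled
// in so the framework can allocate the softmax stats, RNG state and (optionally) bias
// tensors; on subsequent calls the provided buffers are wired up for the kernel to use.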
void* devPtrDropoutSeed = rng_state->data.dptr;
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_fwd_impl(batch, num_attn_heads, num_attn_heads,
max_seqlen, max_seqlen, head_dim,
is_training, attn_scale, p_dropout, qkv_layout,
bias_type, mask_type,
devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlens, devPtrCuSeqlens,
get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size,
stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
@@ -1532,29 +670,39 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
}
}
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_heads,
size_t max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias,
Tensor *output_S,
Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = 2 * num_attn_heads * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
stride = 2 * head_dim;
}
void *devPtrQ = devPtrQKV;
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
void* devPtrO = input_O->data.dptr;
void *devPtrdO = input_dO->data.dptr;
void *devPtrBias = nullptr;
if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
devPtrBias = input_Bias->data.dptr;
}
void *devPtrdBias = output_dBias->data.dptr;
// dQKV shape is [b, s, 3, h, d]
void *devPtrdQKV = output_dQKV->data.dptr;
void *devPtrdQ = devPtrdQKV;
void *devPtrdK = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + stride);
@@ -1563,50 +711,208 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen,
void *devPtrSoftmaxStats = nullptr;
devPtrSoftmaxStats = output_S->data.dptr;
void *devPtrCuSeqlens = cu_seqlens->data.dptr;
void* devPtrDropoutSeed = rng_state->data.dptr;
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
const auto qkv_type = input_QKV->data.dtype;
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_bwd_impl(batch, num_attn_heads, num_attn_heads,
max_seqlen, max_seqlen, head_dim,
attn_scale, p_dropout, qkv_layout,
bias_type, mask_type,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlens, devPtrCuSeqlens,
get_cudnn_fe_dtype(qkv_type), workspace->data.dptr,
&workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr;
void *devPtrKV = input_KV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = 2 * num_attn_heads * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = 2 * head_dim;
}
void *devPtrK = devPtrKV;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
void *devPtrBias = input_Bias->data.dptr;
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
if (Aux_CTX_Tensors->size == 0) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Aux_CTX_Tensors->size = 3;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {1, num_attn_heads, max_seqlen_q, max_seqlen_kv};
output_bias->data.dtype = QKV_type;
} else {
Aux_CTX_Tensors->size = 2;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
}
} else if (Aux_CTX_Tensors->size == 2) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
} else if (Aux_CTX_Tensors->size == 3) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
output_bias->data.dptr = devPtrBias;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
void* devPtrDropoutSeed = rng_state->data.dptr;
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_fwd_impl(batch, num_attn_heads, num_gqa_groups,
max_seqlen_q, max_seqlen_kv,
head_dim, is_training, attn_scale, p_dropout, qkv_layout,
bias_type, mask_type,
devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV,
get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size,
stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dKV,
Tensor *output_dBias, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv,
const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr;
void *devPtrKV = input_KV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = 2 * num_attn_heads * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = 2 * head_dim;
}
void *devPtrK = devPtrKV;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
void* devPtrO = input_O->data.dptr;
void *devPtrdO = input_dO->data.dptr;
void *devPtrBias = nullptr;
if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
devPtrBias = input_Bias->data.dptr;
}
void *devPtrdQ = output_dQ->data.dptr;
void *devPtrdKV = output_dKV->data.dptr;
void *devPtrdK = devPtrdKV;
void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdKV) + stride);
void *devPtrSoftmaxStats = nullptr;
devPtrSoftmaxStats = output_S->data.dptr;
void *devPtrdBias = output_dBias->data.dptr;
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void* devPtrDropoutSeed = rng_state->data.dptr;
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_bwd_impl(batch, num_attn_heads, num_gqa_groups,
max_seqlen_q, max_seqlen_kv,
head_dim, attn_scale, p_dropout, qkv_layout,
bias_type, mask_type,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1624,8 +930,8 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen,
}
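// [Editor's sketch, not part of the diff] How the kvpacked paths above carve
// the K and V device pointers out of one packed KV buffer, assuming 2-byte
// (FP16/BF16) elements. For ...2HD layouts KV is [..., 2, h, d], so V begins
// a whole (h, d) block after K; for ...H2D layouts KV is [..., h, 2, d], so V
// begins d elements after K. The helper below only restates that arithmetic:
#include <cstddef>
inline void split_packed_kv(void *devPtrKV, bool is_2hd_layout,
                            std::size_t num_heads, std::size_t head_dim,
                            std::size_t elem_bytes,  // 2 for FP16/BF16
                            void **devPtrK, void **devPtrV) {
  const std::size_t v_byte_offset = is_2hd_layout
                                        ? elem_bytes * num_heads * head_dim
                                        : elem_bytes * head_dim;
  *devPtrK = devPtrKV;
  *devPtrV = static_cast<void *>(static_cast<char *>(devPtrKV) + v_byte_offset);
}
// [End of editor's sketch]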
void fused_attn_arbitrary_seqlen_fwd(
size_t batch, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t num_head, size_t head_dim, bool is_training,
size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O,
......@@ -1640,22 +946,49 @@ void fused_attn_arbitrary_seqlen_fwd(
void *devPtrV = input_V->data.dptr;
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
void *devPtrBias = input_Bias->data.dptr;
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
if (Aux_CTX_Tensors->size == 0) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Aux_CTX_Tensors->size = 3;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {1, num_attn_heads, max_seqlen_q, max_seqlen_kv};
output_bias->data.dtype = QKV_type;
} else {
Aux_CTX_Tensors->size = 2;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, max_seqlen_q, 1};
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
}
} else if (Aux_CTX_Tensors->size == 2) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
} else if (Aux_CTX_Tensors->size == 3) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
output_bias->data.dptr = devPtrBias;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
......@@ -1664,18 +997,18 @@ void fused_attn_arbitrary_seqlen_fwd(
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_fwd_impl(batch, num_head, max_seqlen_q, max_seqlen_kv, head_dim,
is_training, attn_scale, p_dropout, qkv_layout, mask_type,
devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO,
devPtrCuSeqlensQ, devPtrCuSeqlensKV,
fused_attn_arbitrary_seqlen_fwd_impl(batch, num_attn_heads, num_gqa_groups,
max_seqlen_q, max_seqlen_kv,
head_dim, is_training, attn_scale, p_dropout, qkv_layout,
bias_type, mask_type,
devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
devPtrCuSeqlensQ, devPtrCuSeqlensKV,
get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size,
stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1692,13 +1025,14 @@ void fused_attn_arbitrary_seqlen_fwd(
}
}
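// [Editor's sketch, not part of the diff] Every forward wrapper above drives
// Aux_CTX_Tensors the same way: on the first call the pack size is 0, so the
// wrapper only records how many auxiliary tensors it needs (softmax stats,
// rng state, and the bias when one is used) plus their shapes and dtypes; the
// caller allocates them and calls again, and the wrapper then wires up the
// device pointers. A schematic caller, where fwd_call and allocate_tensor are
// hypothetical stand-ins for the bound wrapper and the framework's allocator:
#include <cstddef>
#include <functional>
inline void run_fwd_with_aux(NVTETensorPack *aux,
                             const std::function<void(NVTETensorPack *)> &fwd_call,
                             const std::function<void(Tensor *)> &allocate_tensor) {
  aux->size = 0;
  fwd_call(aux);  // pass 1: fills aux->size and each tensor's shape/dtype
  for (size_t i = 0; i < aux->size; ++i) {
    allocate_tensor(reinterpret_cast<Tensor *>(aux->tensors[i]));
  }
  fwd_call(aux);  // pass 2: auxiliary outputs are written to the new buffers
}
// [End of editor's sketch]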
void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t num_head, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V, const Tensor *input_O,
const Tensor *input_dO, Tensor *output_S,
const Tensor *input_dO, const Tensor *input_Bias,
Tensor *output_S,
Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV,
Tensor *output_dBias, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv,
......@@ -1712,57 +1046,37 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t max_seqlen_q, size_t m
void *devPtrV = input_V->data.dptr;
void* devPtrO = input_O->data.dptr;
void *devPtrdO = input_dO->data.dptr;
void *devPtrBias = nullptr;
if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
devPtrBias = input_Bias->data.dptr;
}
void *devPtrdQ = output_dQ->data.dptr;
void *devPtrdK = output_dK->data.dptr;
void *devPtrdV = output_dV->data.dptr;
void *devPtrSoftmaxStats = nullptr;
devPtrSoftmaxStats = output_S->data.dptr;
void *devPtrdBias = output_dBias->data.dptr;
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void* devPtrDropoutSeed = rng_state->data.dptr;
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
size_t workspace_size = 0;
bool use_workspace_opt = false;
#if (CUDNN_VERSION >= 8905)
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
if (sm_arch_ >= 90) {
// quick estimate of dp workspace size
size_t max_seqlen_div_up_q = ((max_seqlen_q + 64 - 1) / 64) * 64;
size_t max_seqlen_div_up_kv = ((max_seqlen_kv + 64 - 1) / 64) * 64;
size_t required_dp_workspace =
(batch * num_head * max_seqlen_div_up_q * max_seqlen_div_up_kv * 2 + 1048576 - 1) / 1048576;
// default upper limit for dp workspace 256MB
size_t max_allowed_dp_workspace = 256;
if (required_dp_workspace <= max_allowed_dp_workspace) {
use_workspace_opt = true;
}
use_workspace_opt = transformer_engine::getenv<bool>(
"NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT", use_workspace_opt);
#if (CUDNN_VERSION < 8906)
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
if ((layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD)
|| (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D)) {
use_workspace_opt = false;
}
#endif
}
#endif
fused_attn_arbitrary_seqlen_bwd_impl(batch, num_head, max_seqlen_q, max_seqlen_kv, head_dim,
attn_scale, p_dropout, qkv_layout, mask_type,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
devPtrdQ, devPtrdK, devPtrdV, devPtrdO,
devPtrCuSeqlensQ, devPtrCuSeqlensKV,
fused_attn_arbitrary_seqlen_bwd_impl(batch, num_attn_heads, num_gqa_groups,
max_seqlen_q, max_seqlen_kv,
head_dim, attn_scale, p_dropout, qkv_layout,
bias_type, mask_type,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle, use_workspace_opt);
devPtrCuSeqlensQ, devPtrCuSeqlensKV,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......
......@@ -12,14 +12,13 @@
#define TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_
#include "transformer_engine/fused_attn.h"
#include <cudnn.h>
#include "common/common.h"
namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen,
size_t head_size, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
......@@ -28,20 +27,47 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(size_t batch, size_t max_seqlen,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen,
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV,
const Tensor *input_O,
const Tensor *input_dO, Tensor *output_S,
const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S,
Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd(size_t batch, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t num_head, size_t head_size, bool is_training,
void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
......@@ -49,13 +75,14 @@ void fused_attn_arbitrary_seqlen_fwd(size_t batch, size_t max_seqlen_q, size_t m
const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t num_head, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
void fused_attn_arbitrary_seqlen_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V, const Tensor *input_O,
const Tensor *input_dO, Tensor *output_S,
const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dK,
Tensor *output_dV, Tensor *output_dBias,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
......
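// [Editor's sketch, not part of the diff] The new signatures above carry both
// num_attn_heads (query heads) and num_gqa_groups (key/value heads): the two
// are equal for plain multi-head attention, while grouped-query attention
// expects the query heads to be an integer multiple of the KV heads. A small
// sanity check in that spirit (the helper and its message are illustrative):
#include <cstddef>
#include <stdexcept>
inline void check_gqa_heads(std::size_t num_attn_heads, std::size_t num_gqa_groups) {
  if (num_gqa_groups == 0 || num_attn_heads % num_gqa_groups != 0) {
    throw std::runtime_error("num_attn_heads must be a positive multiple of num_gqa_groups");
  }
}
// [End of editor's sketch]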
......@@ -217,14 +217,14 @@ static cudnn_frontend::Tensor createMask(int64_t b, int64_t h, int64_t s_q, int6
int64_t maskOutputTensor_virtual = true;
cudnnDataType_t maskOutputTensor_dataType = CUDNN_DATA_FLOAT;
auto maskOutputTensor_reorderType =
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_NONE;
cudnn_frontend::TensorReordering_t::NONE;
if (is_bprop) {
maskOutputTensor_id = dS_ID;
maskOutputTensor_virtual = false;
maskOutputTensor_dataType = tensorType;
maskOutputTensor_reorderType =
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16;
cudnn_frontend::TensorReordering_t::F16x16;
}
auto maskOutputTensor =
......@@ -357,7 +357,7 @@ static cudnn_frontend::Tensor createSoftmaxForward(
// divide (e/ sum(e))
auto reorder_type =
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16;
cudnn_frontend::TensorReordering_t::F16x16;
auto afterDivisionTensor =
cudnn_frontend::TensorBuilder()
......@@ -448,7 +448,7 @@ static cudnn_frontend::Tensor createDropout(int64_t b, int64_t h, int64_t s_q, i
afterBMM1_stride, true, false); // is virtual
auto reorder_type =
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16;
cudnn_frontend::TensorReordering_t::F16x16;
// after dropout tensor
auto afterDropoutTensor =
......@@ -918,7 +918,7 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
auto doTensor = tensor_create(tensorType, dO_ID, o_dim, o_stride, false, false);
auto reorder_type =
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16;
cudnn_frontend::TensorReordering_t::F16x16;
// activation from fprop
auto pTensor =
......@@ -1246,7 +1246,7 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
using namespace transformer_engine::fused_attn;
void fused_attn_max_512_fwd_qkvpacked(
size_t batch, size_t max_seqlen, size_t num_head, size_t head_dim, bool is_training,
size_t batch, size_t num_head, size_t max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state,
......@@ -1312,8 +1312,8 @@ void fused_attn_max_512_fwd_qkvpacked(
}
}
void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_head, size_t head_dim, bool is_training,
void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
......@@ -1389,8 +1389,8 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
NVTE_ERROR("Unexpected workspace_size.");
}
}
void fused_attn_max_512_fwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_head, size_t head_dim, bool is_training,
void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K,
......@@ -1460,7 +1460,7 @@ void fused_attn_max_512_fwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seq
}
}
void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen,
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV,
......@@ -1519,8 +1519,8 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
}
}
void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_head, size_t head_dim, float attn_scale,
void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
......@@ -1580,8 +1580,8 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
NVTE_ERROR("Unexpected workspace_size.");
}
}
void fused_attn_max_512_bwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_head, size_t head_dim, float attn_scale,
void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K,
......
......@@ -19,7 +19,7 @@
namespace transformer_engine {
#if (CUDNN_VERSION >= 8901)
void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen,
size_t head_size, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
......@@ -28,8 +28,8 @@ void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_head, size_t head_dim, bool is_training,
void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
......@@ -38,8 +38,8 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_fwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_head, size_t head_dim, bool is_training,
void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K,
......@@ -49,7 +49,7 @@ void fused_attn_max_512_fwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seq
const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen,
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV,
......@@ -58,8 +58,8 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
const Tensor *cu_seqlens, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_head, size_t head_dim, float attn_scale,
void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
......@@ -68,8 +68,8 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_bwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_head, size_t head_dim, float attn_scale,
void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K,
......
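// [Editor's sketch, not part of the diff] The fused_attn_max_512_* kernels
// declared above serve sequence lengths up to 512, while the
// fused_attn_arbitrary_seqlen_* path handles longer sequences; the real
// backend choice is made elsewhere (nvte_get_fused_attn_backend) and also
// weighs dtype, bias type, mask type and head size. A toy dispatcher, purely
// to illustrate the split:
#include <cstddef>
enum class ToyFusedAttnBackend { Max512, ArbitrarySeqlen };
inline ToyFusedAttnBackend toy_pick_backend(std::size_t max_seqlen_q,
                                            std::size_t max_seqlen_kv) {
  return (max_seqlen_q <= 512 && max_seqlen_kv <= 512)
             ? ToyFusedAttnBackend::Max512
             : ToyFusedAttnBackend::ArbitrarySeqlen;
}
// [End of editor's sketch]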
......@@ -366,8 +366,7 @@ static cudnn_frontend::Tensor createDropoutForward(
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual(true)
.setByValue(false)
.setReorderType(cudnn_frontend::cudnnBackendTensorReordering_t::
CUDNN_TENSOR_REORDERING_F16x16)
.setReorderType(cudnn_frontend::TensorReordering_t::F16x16)
.build();
// Scale after dropout
auto scaleDropoutTensor = tensor_create(
......@@ -448,8 +447,7 @@ static cudnn_frontend::Tensor createDropoutBackward(
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual(true)
.setByValue(false)
.setReorderType(cudnn_frontend::cudnnBackendTensorReordering_t::
CUDNN_TENSOR_REORDERING_F16x16)
.setReorderType(cudnn_frontend::TensorReordering_t::F16x16)
.build();
// Scale after dropout (1 / (1 - p))
auto scaleDropoutTensor = tensor_create(
......@@ -992,7 +990,7 @@ static cudnn_frontend::Tensor createdSQBMM(
}
// fused attention FWD FP8
void fused_attn_fp8_fwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
bool isTraining, float attnScale,
float dropoutProbability, NVTE_QKV_Layout layout,
void* devPtrQ, void* devPtrK, void* devPtrV,
......@@ -1305,7 +1303,7 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, in
}
// fused attention BWD FP8
void fused_attn_fp8_bwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
float attnScale, float dropoutProbability, NVTE_QKV_Layout layout,
void* devPtrQ, void* devPtrK, void* devPtrV,
void* devPtrM, void* devPtrZInv,
......@@ -1935,7 +1933,7 @@ void fused_attn_fp8_fwd_qkvpacked(
size_t workspace_size = 0;
fused_attn::fused_attn_fp8_fwd_impl(
b, max_seqlen, max_seqlen, h, d,
b, h, max_seqlen, max_seqlen, d,
is_training, attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
......@@ -2025,7 +2023,7 @@ void fused_attn_fp8_bwd_qkvpacked(
size_t workspace_size = 0;
fused_attn::fused_attn_fp8_bwd_impl(
b, max_seqlen, max_seqlen, h, d,
b, h, max_seqlen, max_seqlen, d,
attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
......@@ -2131,7 +2129,7 @@ void fused_attn_fp8_fwd(
size_t workspace_size = 0;
fused_attn::fused_attn_fp8_fwd_impl(
b, max_seqlen_q, max_seqlen_kv, h, d,
b, h, max_seqlen_q, max_seqlen_kv, d,
is_training, attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
......@@ -2224,7 +2222,7 @@ void fused_attn_fp8_bwd(
size_t workspace_size = 0;
fused_attn::fused_attn_fp8_bwd_impl(
b, max_seqlen_q, max_seqlen_kv, h, d,
b, h, max_seqlen_q, max_seqlen_kv, d,
attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
......
......@@ -14,8 +14,7 @@ namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
// fused attention FWD FP8 with packed QKV
void fused_attn_fp8_fwd_qkvpacked(
size_t b, size_t max_seqlen,
size_t h, size_t d,
size_t b, size_t h, size_t max_seqlen, size_t d,
bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_QKV,
......@@ -30,8 +29,7 @@ void fused_attn_fp8_fwd_qkvpacked(
// fused attention BWD FP8 with packed QKV
void fused_attn_fp8_bwd_qkvpacked(
size_t b, size_t max_seqlen,
size_t h, size_t d,
size_t b, size_t h, size_t max_seqlen, size_t d,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_QKV,
const Tensor *input_O,
......@@ -49,8 +47,7 @@ void fused_attn_fp8_bwd_qkvpacked(
// fused attention FWD FP8 with separate Q, K, V
void fused_attn_fp8_fwd(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t h, size_t d,
size_t b, size_t h, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
......@@ -66,8 +63,7 @@ void fused_attn_fp8_fwd(
// fused attention BWD FP8 with separate Q, K, V
void fused_attn_fp8_bwd(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t h, size_t d,
size_t b, size_t h, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
const Tensor *input_O,
......
......@@ -30,109 +30,6 @@ void generateMatrixStrides(
constexpr int seqlen_q_dim_idx = 2;
constexpr int seqlen_kv_dim_idx = 3;
// to be deprecated in the future
switch (matrix) {
case NVTE_QKV_Matrix::NVTE_Q_Matrix:
if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) {
strideA[hidden_dim_idx] = 1;
strideA[seqlen_dim_idx] = 3 * h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_q * 3 * h * d;
} else if ((layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED)
|| (layout == NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED)) {
strideA[hidden_dim_idx] = 1;
strideA[seqlen_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_q * h * d;
}
break;
case NVTE_QKV_Matrix::NVTE_K_Matrix:
if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) {
strideA[seqlen_dim_idx] = 3 * h * d;
strideA[hidden_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 3 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) {
strideA[seqlen_dim_idx] = 2 * h * d;
strideA[hidden_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 2 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) {
strideA[seqlen_dim_idx] = h * d;
strideA[hidden_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * h * d;
}
break;
case NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose:
if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) {
strideA[seqlen_transpose_dim_idx] = 3 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 3 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) {
strideA[seqlen_transpose_dim_idx] = 2 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 2 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) {
strideA[seqlen_transpose_dim_idx] = h * d;
strideA[hidden_transpose_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * h * d;
}
break;
case NVTE_QKV_Matrix::NVTE_V_Matrix:
if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) {
strideA[hidden_dim_idx] = 1;
strideA[seqlen_dim_idx] = 3 * h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 3 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) {
strideA[hidden_dim_idx] = 1;
strideA[seqlen_dim_idx] = 2 * h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 2 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) {
strideA[hidden_dim_idx] = 1;
strideA[seqlen_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * h * d;
}
break;
case NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose:
if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) {
strideA[hidden_transpose_dim_idx] = 1;
strideA[seqlen_transpose_dim_idx] = 3 * h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 3 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) {
strideA[hidden_transpose_dim_idx] = 1;
strideA[seqlen_transpose_dim_idx] = 2 * h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 2 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) {
strideA[hidden_transpose_dim_idx] = 1;
strideA[seqlen_transpose_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * h * d;
}
break;
case NVTE_QKV_Matrix::NVTE_S_Matrix:
strideA[seqlen_kv_dim_idx] = 1;
strideA[seqlen_q_dim_idx] = s_kv;
strideA[head_dim_idx] = s_q * s_kv;
strideA[batch_dim_idx] = h * s_q * s_kv;
break;
case NVTE_QKV_Matrix::NVTE_O_Matrix:
strideA[seqlen_kv_dim_idx] = 1;
strideA[seqlen_q_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_q * h * d;
break;
}
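  // Worked example for the legacy path above: with NVTE_QKV_INTERLEAVED
  // (packed QKV, trailing dims [3, h, d]) and b = 2, h = 16, s_q = 128,
  // d = 64, the Q matrix viewed as [b, h, s_q, d] (the dim order implied by
  // the *_dim_idx constants) gets strides
  //   {s_q * 3 * h * d, d, 3 * h * d, 1} = {393216, 64, 3072, 1}:
  // consecutive head-dim elements are contiguous, moving one token skips the
  // packed K and V blocks, and moving one head skips d elements.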
// new way of getting strides
switch (layout) {
case NVTE_QKV_Layout::NVTE_SB3HD:
if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix)
......@@ -497,4 +394,27 @@ cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) {
NVTE_ERROR("Invalid cuDNN data type. \n");
}
}
// get cuDNN frontend data type
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kInt32:
return cudnn_frontend::DataType_t::INT32;
case DType::kInt64:
return cudnn_frontend::DataType_t::INT64;
case DType::kFloat16:
return cudnn_frontend::DataType_t::HALF;
case DType::kFloat32:
return cudnn_frontend::DataType_t::FLOAT;
case DType::kBFloat16:
return cudnn_frontend::DataType_t::BFLOAT16;
case DType::kFloat8E4M3:
return cudnn_frontend::DataType_t::FP8_E4M3;
case DType::kFloat8E5M2:
return cudnn_frontend::DataType_t::FP8_E5M2;
default:
NVTE_ERROR("Invalid cuDNN data type. \n");
}
}
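// [Editor's sketch, not part of the diff] get_cudnn_fe_dtype mirrors
// get_cudnn_dtype but returns the cuDNN frontend v1 enum, which is what the
// v1 graph builders consume as io/compute data types. A trivial use of the
// mapping:
inline bool is_half_precision_fe_dtype(const transformer_engine::DType t) {
  const auto fe_dtype = get_cudnn_fe_dtype(t);
  return (fe_dtype == cudnn_frontend::DataType_t::HALF) ||
         (fe_dtype == cudnn_frontend::DataType_t::BFLOAT16);
}
// [End of editor's sketch]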
} // namespace transformer_engine
......@@ -12,6 +12,7 @@
#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>
#include <cstdint>
#include <mutex>
......@@ -95,6 +96,34 @@ struct FADescriptor {
}
};
struct FADescriptor_v1 {
std::int64_t b;
std::int64_t h;
std::int64_t hg;
std::int64_t s_q;
std::int64_t s_kv;
std::int64_t d;
float attnScale;
bool isTraining;
float dropoutProbability;
NVTE_QKV_Layout layout;
NVTE_Bias_Type bias_type;
NVTE_Mask_Type mask_type;
cudnn_frontend::DataType_t tensor_type;
bool operator<(const FADescriptor_v1 &rhs) const {
return std::tie(b, h, hg, s_q, s_kv, d,
attnScale, isTraining, dropoutProbability,
layout, mask_type, bias_type, tensor_type)
< std::tie(
rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d,
rhs.attnScale, rhs.isTraining,
rhs.dropoutProbability, rhs.layout,
rhs.mask_type, rhs.bias_type,
rhs.tensor_type);
}
};
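// [Editor's sketch, not part of the diff] FADescriptor_v1 defines operator<
// so it can act as the key of an ordered cache mapping a problem description
// (b, h, hg, s_q, s_kv, d, scale, dropout, layout, bias/mask type, dtype) to
// an already-built cuDNN frontend graph or execution plan, so graphs are not
// rebuilt every iteration. GraphT below is a placeholder for whatever object
// is cached; the helper itself is only illustrative.
#include <functional>
#include <map>
template <typename GraphT>
GraphT &get_cached_fa_graph(const FADescriptor_v1 &key,
                            const std::function<GraphT()> &build) {
  static thread_local std::map<FADescriptor_v1, GraphT> cache;
  auto it = cache.find(key);
  if (it == cache.end()) {
    it = cache.emplace(key, build()).first;
  }
  return it->second;
}
// [End of editor's sketch]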
__global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d,
int32_t *cu_seqlens_q, int32_t *actual_seqlens_q,
int32_t *qkv_ragged_offset, int32_t *o_ragged_offset);
......@@ -107,6 +136,7 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b,
} // namespace fused_attn
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t);
class cudnnExecutionPlanManager {
public:
......