Unverified commit 5c58beaa authored by cyanguwa, committed by GitHub

Add long sequence support for fused attention (#237)



* add long sequence support and unify three backends for fused attention
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update cudnn-frontend to v0.9.1
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace cpu_float2half_rn with __float2half_rn
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix backend selection and NVTEDType
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix CI
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* make cudnn plan caches thread_local
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace cuDNN throw with NVTE_CHECK
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix replacement of cuDNN throw with NVTE_CHECK
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* force dropout probability to 0 in inference mode
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change negInfinity to be consistent with m512 fused attn
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove float2half conversion for scale_dropout
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add back runtime api for sm detection
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add gemm3 to enums FP8Fwd/BwdTensors
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change dropout from no to yes for fmha_v1
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove output_rng_state in m512 kernels
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix elts_per_thread calculation in kvpacked fwd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove dropout=0.0 restriction for m512 fused attn
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove output_rng_state completely from m512 kernels
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 4330e025
Subproject commit e7f64390e9bb4a3db622ffe11c973834f572b609 -> Subproject commit a4f05c1edcef453f5fd52f96218c29c7d420e511
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
import pytest
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
)
from transformer_engine.pytorch import TransformerLayer
from transformer_engine.pytorch.attention import DotProductAttention
import os
class ModelConfig:
def __init__(
self, num_layers, hidden_size, num_attention_heads, head_dim, seq_len,
dropout_p, attn_mask_type,
):
self.num_layers = num_layers
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
assert (hidden_size == num_attention_heads * head_dim
), """hidden_size must be = num_heads x head_dim."""
self.seq_len = seq_len
self.dropout_p = dropout_p
self.attn_mask_type = attn_mask_type
model_configs = {
"test1": ModelConfig(1, 1024, 16, 64, 128, 0.0, "causal"),
"test2": ModelConfig(1, 1024, 16, 64, 2048, 0.0, "causal"),
"test3": ModelConfig(1, 2048, 16, 128, 128, 0.0, "causal"),
"test4": ModelConfig(1, 2048, 16, 128, 2048, 0.0, "causal"),
"test5": ModelConfig(1, 1024, 16, 64, 128, 0.0, "no_mask"),
}
param_types = [torch.float16]
if torch.cuda.is_bf16_supported():
param_types.append(torch.bfloat16)
batch_sizes = [1, 2]
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_dot_product_attention(dtype, bs, model):
"""Test DotProductAttention module with three backends,
FlashAttention, FusedAttention and UnfusedDotProductAttention"""
config = model_configs[model]
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype, bs, config, "FlashAttention")
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype, bs, config, "FusedAttention")
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
dtype, bs, config, "UnfusedDotProductAttention")
atol, rtol = (2.5e-2, 2.5e-2) if dtype == torch.bfloat16 else (2.5e-3, 2.5e-3)
assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)
def _run_dot_product_attention(dtype, bs, config, backend):
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
inp = 0.1 * torch.randn(
config.seq_len, bs, 3, config.num_attention_heads, config.head_dim,
dtype = dtype).cuda()
inp.requires_grad=True
seqlens = torch.empty(bs, dtype = torch.int32).cuda()
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
op_grad = 0.001 * torch.randint(0, 200, (
config.seq_len, bs, config.num_attention_heads * config.head_dim
), dtype = dtype).cuda()
block = (
DotProductAttention(
config.num_attention_heads,
config.head_dim,
attention_dropout = config.dropout_p,
attn_mask_type = config.attn_mask_type,
sequence_parallel = False,
tp_size = 1,
get_rng_state_tracker = None,
tp_group = None,
layer_number = 1,
attention_type = "self"
).to(dtype = dtype).cuda()
)
q = inp[:, :,0,:,:]
k = inp[:, :,1,:,:]
v = inp[:, :,2,:,:]
op = block(q, k, v)
op.backward(op_grad)
return op, inp.grad
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_transformer_layer(dtype, bs, model):
"""Test TransformerLayer module when its DotProductAttention is enabled with
FlashAttention, FusedAttention, or UnfusedDotProductAttention backend"""
config = model_configs[model]
flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
dtype, bs, config, "FlashAttention")
fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
dtype, bs, config, "FusedAttention")
unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer(
dtype, bs, config, "UnfusedDotProductAttention")
atol, rtol = (5e-1, 5e-1) if dtype == torch.bfloat16 else (5e-1, 5e-1)
assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)
def _run_transformer_layer(dtype, bs, config, backend):
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
inp = 0.1 * torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda()
inp.requires_grad=True
seqlens = torch.empty(bs, dtype = torch.int32).cuda()
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
op_grad = 0.001 * torch.randint(0, 200, (
config.seq_len, bs, config.num_attention_heads * config.head_dim
), dtype = dtype).cuda()
sigma = 0.02
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
layer_number = 1
drop_path_rate = 0.0
drop_path_rates = [
rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon = 1e-5,
hidden_dropout = 0.0,
attention_dropout = config.dropout_p,
init_method = init_method,
output_layer_init_method = output_layer_init_method,
layer_number = layer_number,
kv_channels = config.head_dim,
self_attn_mask_type = config.attn_mask_type,
tp_group = None,
tp_size = 1,
params_dtype = dtype,
get_rng_state_tracker = None,
fuse_wgrad_accumulation = False,
seq_length = config.seq_len,
micro_batch_size = bs,
sequence_parallel = False,
apply_residual_connection_post_layernorm = False,
output_layernorm = False,
layer_type = "encoder",
drop_path_rate = drop_path_rates[layer_number - 1],
set_parallel_mode = True,
fuse_qkv_params = True,
zero_centered_gamma = False,
qkv_weight_interleaved = False,
ub_tp_comm_overlap = False,
bias = True,
)
.to(dtype = dtype)
.cuda()
)
op = block(inp)
op.backward(op_grad)
return op, inp.grad
model_configs_fp8 = {
"test1": ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask"),
}
batch_sizes_fp8 = [1, 4]
param_types_fp8 = [torch.float16]
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("bs", batch_sizes_fp8)
@pytest.mark.parametrize("model", model_configs_fp8.keys())
def test_dpa_fp8(dtype, bs, model):
"""Test DotProductAttention module with FP8,
using cpp_extensions fused_attn_fwd/bwd_qkvpacked, and compare against UnfusedDotProductAttention"""
config = model_configs_fp8[model]
fused_attn_fwd, fused_attn_bwd = _run_dpa_fp8(
dtype, bs, config, "FusedAttention")
unfused_attn_fwd, unfused_attn_bwd = _run_dpa_fp8_ref(
dtype, bs, config, "UnfusedDotProductAttention")
atol, rtol = (5e-2, 1e-1)
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)
def _run_dpa_fp8(dtype, bs, config, backend):
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
inp = 0.01 * torch.randn(
bs * config.seq_len, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda()
inp.requires_grad=True
seqlens = torch.empty(bs, dtype = torch.int32).cuda()
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
op_grad = 0.001 * torch.randint(0, 200, (
bs * config.seq_len, config.num_attention_heads * config.head_dim
), dtype = dtype).cuda()
torch.save(op_grad, 'op_grad.pt')
fp8_recipe = recipe.DelayedScaling(
margin=0,
interval=1,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
)
dpa = DPA_FP8(config).to(dtype = torch.float16).cuda()
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
op = dpa(inp, cu_seqlens, config.seq_len)
op.backward(op_grad)
context = torch.load("ctx.pt")
dqkv = torch.load('dqkv.pt')
return (context.view(bs, config.seq_len, -1).transpose(0,1),
dqkv.view(bs, config.seq_len, 3, config.num_attention_heads, config.head_dim).transpose(0,1).contiguous())
def _run_dpa_fp8_ref(dtype, bs, config, backend):
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
inp = torch.load('qkv.pt').cuda()
inp.requires_grad=True
seqlens = torch.empty(bs, dtype = torch.int32).cuda()
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
op_grad = torch.load('op_grad.pt').cuda().view(bs, config.seq_len, -1).transpose(0,1)
block = (
DotProductAttention(
config.num_attention_heads,
config.head_dim,
attention_dropout = config.dropout_p,
attn_mask_type = config.attn_mask_type,
sequence_parallel = False,
tp_size = 1,
get_rng_state_tracker = None,
tp_group = None,
layer_number = 1,
attention_type = "self"
).to(dtype = dtype).cuda()
)
q = inp[:, :,0,:,:]
k = inp[:, :,1,:,:]
v = inp[:, :,2,:,:]
op = block(q, k, v)
op.backward(op_grad)
torch.save(op,'ctx_ref.pt')
torch.save(inp.grad,'dqkv_ref.pt')
return op, inp.grad
from torch.nn.parameter import Parameter
import transformer_engine.pytorch.cpp_extensions as ext
import transformer_engine_extensions as tex
import transformer_engine.pytorch.fp8 as fp8
from transformer_engine.pytorch import fp8_autocast
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule, _prepare_backward
from transformer_engine.common import recipe
from typing import Union, Dict, Any, Tuple, List
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
fused_attn_fwd_qkvpacked,
fused_attn_bwd_qkvpacked,
FusedAttnBackend)
_CUBLASLT_WORKSPACE_SIZE_BYTES = 33_554_432 # 32MiB
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = False
_2X_ACC_WGRAD = False
META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_S = tex.FP8FwdTensors.GEMM3_WEIGHT
META_DS = tex.FP8BwdTensors.GRAD_INPUT3
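# The META_* constants above map the attention tensors onto slots of the FP8 scaling
# metadata: QKV, S and O use forward (FP8FwdTensors) entries, while dO, dS and dQKV use
# backward (FP8BwdTensors) entries; they index into fp8_meta["scaling_fwd"/"scaling_bwd"]
# in the autograd function below.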
class _dpa_fp8(torch.autograd.Function):
@staticmethod
def forward(
ctx,
inp: torch.Tensor,
qkv_weight: torch.Tensor,
qkv_bias: torch.Tensor,
cu_seqlens: torch.Tensor,
num_attention_heads: int,
p_dropout: float,
max_s: int,
fast_zero_fill: bool,
fp8_meta: Dict[str, Any],
workspace: torch.Tensor,
is_training: bool,
) -> torch.Tensor:
assert inp.dim() == 2
in_features = qkv_weight.shape[-1]
h = num_attention_heads
d = in_features // h
b = cu_seqlens.numel() - 1
is_nl = False
if b < 4 and b > 1:
max_s = 512
is_nl = True
fp8_dtype_forward = fp8.get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
inputmat, inputmat_t = ext.fp8_cast_transpose_fused(
inp,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
qkv_weight_fp8, qkv_weight_t_fp8 = ext.fp8_cast_transpose_fused(
qkv_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
)
M = None
ZInv = None
philox_unpacked = None
qkv_out = ext.fp8_gemm(
qkv_weight_fp8,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
inputmat,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
torch.uint8,
workspace,
bias=qkv_bias,
use_bias=True,
out_index = META_QKV,
fp8_meta_tensor = fp8_meta["scaling_fwd"],
use_split_accumulator=_2X_ACC_FPROP,
D_dtype=fp8_dtype_forward,
)
qkv_out = qkv_out.view(-1, 3, h, d)
qkv_out_fp16 = ext.cast_from_fp8(qkv_out, fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward,
tex.DType.kFloat16).view(b, max_s, 3, h, d).transpose(0,1).contiguous()
torch.save(qkv_out_fp16, 'qkv.pt')
# FMHA
context_, aux_ctx_tensors, *rest = fused_attn_fwd_qkvpacked(
is_training,
max_s,
cu_seqlens,
qkv_out,
fp8_dtype_forward,
FusedAttnBackend["FP8"],
None,
fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale[META_S],
fp8_meta["scaling_fwd"].scale[META_O],
fp8_meta["scaling_fwd"].amax_history[0][META_S],
fp8_meta["scaling_fwd"].amax_history[0][META_O],
attn_scale = None,
dropout = p_dropout,
fast_zero_fill = fast_zero_fill,
qkv_layout = "qkv_interleaved",
attn_bias_type = "no_bias",
attn_mask_type = "padding",
rng_gen = None,
)
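# unpack the auxiliary tensors saved by the forward kernel: softmax stats (M, ZInv) and the philox RNG state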
M, ZInv, philox_unpacked = aux_ctx_tensors
context = context_.view(-1, in_features)
context_t = tex.fp8_transpose(context, fp8_dtype_forward)
ctx.save_for_backward(
inputmat_t, qkv_weight_t_fp8, workspace,
qkv_out,
context_, context_t,
fp8_meta["scaling_fwd"].scale,
fp8_meta["scaling_fwd"].scale_inv,
)
ctx.aux_ctx_tensors = aux_ctx_tensors
ctx.fp8_meta = fp8_meta
ctx.cu_seqlens = cu_seqlens
ctx.p_dropout = p_dropout
ctx.max_s = max_s
ctx.fast_zero_fill = fast_zero_fill
ctx.is_nl = is_nl
ctx.hidden_size = in_features
ctx.num_attention_heads = num_attention_heads
context_fp16 = ext.cast_from_fp8(context, fp8_meta["scaling_fwd"],
META_O, fp8_dtype_forward, tex.DType.kFloat16)
torch.save(context_fp16, 'ctx.pt')
return context_fp16
@staticmethod
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
with _prepare_backward(True, ctx.fp8_meta, None, 1, name="_DPA"):
(
inputmat_t,
qkv_weight_t_fp8,
workspace,
qkv_out,
context, context_t,
fwd_scales,
fwd_scale_inverses,
) = ctx.saved_tensors
fp8_dtype_forward = fp8.get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=True
)
fp8_dtype_backward = fp8.get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
)
proj_dgrad = ext.cast_to_fp8(
grad_output, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
)
dqkv, *rest = fused_attn_bwd_qkvpacked(
ctx.max_s,
ctx.cu_seqlens,
qkv_out,
context,
proj_dgrad.view_as(context),
fp8_dtype_forward,
ctx.aux_ctx_tensors,
FusedAttnBackend["FP8"],
fwd_scale_inverses[META_QKV], # d_scale_qkv,
fwd_scale_inverses[META_S], # d_scale_s,
fwd_scale_inverses[META_O], # d_scale_o,
ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do
fwd_scales[META_S], # q_scale_s
ctx.fp8_meta['scaling_bwd'].scale[META_DS], # q_scale_ds
ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv
ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DS], # amax_ds
ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv
None,
ctx.p_dropout,
ctx.fast_zero_fill,
"qkv_interleaved",
"no_bias",
"padding",
)
dqkv_grad_output_c = dqkv.view(-1, 3*ctx.hidden_size)
dqkv_grad_output_c_fp16 = ext.cast_from_fp8(dqkv_grad_output_c,
ctx.fp8_meta["scaling_bwd"], META_DQKV,
fp8_dtype_backward, tex.DType.kFloat16)
torch.save(dqkv_grad_output_c_fp16, 'dqkv.pt')
qkv_bgrad, dqkv_grad_output_t = ext.fp8_transpose_bgrad_fused(
dqkv_grad_output_c,
ctx.fp8_meta["scaling_bwd"],
META_DQKV,
fp8_dtype_backward,
torch.float16,
)
# QKV DGRAD
qkv_dgrad = ext.fp8_gemm(
qkv_weight_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
dqkv_grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv,
META_DQKV,
fp8_dtype_backward,
torch.float16,
workspace,
use_split_accumulator=_2X_ACC_DGRAD,
)
# QKV WGRAD
qkv_wgrad = ext.fp8_gemm(
inputmat_t,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
dqkv_grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv,
META_DQKV,
fp8_dtype_backward,
torch.float16,
workspace,
use_split_accumulator=_2X_ACC_WGRAD,
)
return (qkv_dgrad,
qkv_wgrad,
qkv_bgrad,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None)
class DPA_FP8(TransformerEngineBaseModule):
def __init__(
self,
config,
params_dtype: torch.dtype = torch.float32):
super().__init__()
self.p_dropout = config.dropout_p
self.h = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_dim = config.head_dim
self.fast_zero_fill = True
self.qkv_weight = Parameter(
torch.empty(
self.hidden_size * 3,
self.hidden_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
self.fp8_weight_shapes.append(self.qkv_weight.shape)
self.qkv_bias = Parameter(
torch.empty(
self.hidden_size * 3,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
with torch.no_grad():
self.qkv_bias.zero_()
self.qkv_weight.fill_(1.0)
self.workspace = torch.empty(
_CUBLASLT_WORKSPACE_SIZE_BYTES, dtype=torch.int8, device="cuda"
)
def forward(
self, inp: torch.Tensor,
cu_seqlens, max_s,
) -> torch.Tensor:
with self.prepare_forward(inp, None, num_gemms=3) as inp:
out = _dpa_fp8.apply(
inp,
self.qkv_weight,
self.qkv_bias,
cu_seqlens,
self.h,
self.p_dropout,
max_s,
self.fast_zero_fill,
self.fp8_meta,
self.workspace,
self.training)
return out
def get_fp8_weights_scratchpad(
self,
is_first_microbatch: Union[bool, None],
) -> List[torch.Tensor]:
"""Needs override."""
@@ -12,9 +12,10 @@ list(APPEND transformer_engine_SOURCES
      transpose/transpose_fusion.cu
      transpose/multi_cast_transpose.cu
      activation/gelu.cu
+     fused_attn/fused_attn_f16_max512_seqlen.cu
+     fused_attn/fused_attn_f16_arbitrary_seqlen.cu
      activation/relu.cu
      activation/swiglu.cu
-     fused_attn/fused_attn_fp16_bf16_max_seqlen_512.cu
      fused_attn/fused_attn_fp8.cu
      fused_attn/fused_attn.cpp
      fused_attn/utils.cu
...
@@ -7,8 +7,80 @@
#include "transformer_engine/fused_attn.h"
#include "../common.h"
#include "utils.h"
-#include "fused_attn_fp16_bf16_max_seqlen_512.h"
+#include "fused_attn_f16_max512_seqlen.h"
+#include "fused_attn_f16_arbitrary_seqlen.h"
#include "fused_attn_fp8.h"
+#include "../util/cuda_runtime.h"
// select a backend for fused attention
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTEDType q_dtype,
NVTEDType kv_dtype,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
float dropout, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim) {
using namespace transformer_engine;
NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type.");
if (((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2))
&& (sm_arch_ >= 90)
&& (max_seqlen_q == max_seqlen_kv)
&& (max_seqlen_q <= 512)
&& (head_dim == 64)
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
&& (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)
&& (qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)) {
backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
} else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) {
bool flag_m512 = false;
bool flag_arb = false;
if ((sm_arch_ >= 80)
&& (head_dim == 64)
&& ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
|| (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS))
&& ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))
&& ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED))) {
flag_m512 = true;
}
if ((sm_arch_ >= 80)
&& (max_seqlen_q == max_seqlen_kv)
&& ((head_dim == 64) || (head_dim == 128))
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
&& (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
&& (qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)) {
flag_arb = true;
}
if (((max_seqlen_q > 512) || (max_seqlen_kv > 512))
&& (flag_arb == true)) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
}
if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
if (flag_m512 == true) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen;
} else if ((flag_m512 == false) && (flag_arb == true)) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
}
}
const char* env_backend = std::getenv("NVTE_FUSED_ATTN_BACKEND");
if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)
&& (flag_arb == true)
&& (env_backend != nullptr)
&& (std::string(env_backend) == std::to_string(
NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen))) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
}
} else {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
}
return backend;
}
// NVTE fused attention FWD FP8 with packed QKV
void nvte_fused_attn_fwd_qkvpacked(
@@ -16,7 +88,7 @@ void nvte_fused_attn_fwd_qkvpacked(
            const NVTETensor Bias,
            NVTETensor S,
            NVTETensor O,
-           NVTETensorPack* Aux_Output_Tensors,
+           NVTETensorPack* Aux_CTX_Tensors,
            const NVTETensor cu_seqlens,
            const NVTETensor rng_state,
            size_t max_seqlen,
@@ -43,54 +115,56 @@ void nvte_fused_attn_fwd_qkvpacked(
  size_t d = input_QKV->data.shape[ndim - 1];
  auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
- const DType QKV_type = input_QKV->data.dtype;
- if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
-                 && (max_seqlen <= 512)) {
-#if (CUDNN_VERSION >= 8900)
-   // FP8 API doesn't use input_Bias, bias_type or attn_mask_type
-   fused_attn_fwd_fp8_qkvpacked(
-           b, max_seqlen, h, d,
-           is_training, attn_scale, dropout, qkv_layout,
-           input_QKV, input_output_S, output_O,
-           Aux_Output_Tensors,
-           input_cu_seqlens,
-           input_rng_state,
-           wkspace, stream, handle);
-#else
-   NVTE_ERROR("cuDNN 8.9 is required to run FP8 fused attention. \n");
-#endif
- } else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
-                 && (max_seqlen <= 512)) {
-#if (CUDNN_VERSION >= 8901)
-   fused_attn_max_512_fwd_qkvpacked(
-       b,
-       max_seqlen,
-       h,
-       d,
-       is_training,
-       attn_scale,
-       dropout,
-       qkv_layout,
-       bias_type,
-       attn_mask_type,
-       input_QKV,
-       input_Bias,
-       output_O,
-       Aux_Output_Tensors,
-       input_cu_seqlens,
-       input_rng_state,
-       wkspace,
-       stream,
-       handle);
-#else
-   NVTE_ERROR(
-     "cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n");
-#endif
- } else if (max_seqlen > 512) {
-   NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
- } else {
-   NVTE_ERROR("Invalid combination of data type and sequence length! \n");
- }
+ const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
+ NVTE_Fused_Attn_Backend fused_attention_backend =
+             nvte_get_fused_attn_backend(
+                         QKV_type, QKV_type,
+                         qkv_layout, bias_type, attn_mask_type,
+                         dropout, max_seqlen, max_seqlen, d);
+ if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
+#if (CUDNN_VERSION >= 8901)
+   fused_attn_max_512_fwd_qkvpacked(
+           b, max_seqlen, h, d,
+           is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
+           input_QKV, input_Bias, output_O,
+           Aux_CTX_Tensors,
+           input_cu_seqlens,
+           input_rng_state,
+           wkspace, stream, handle);
+#else
+   NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
+#endif
+ } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
+#if (CUDNN_VERSION >= 8900)
+   fused_attn_arbitrary_seqlen_fwd_qkvpacked(
+           b, max_seqlen, h, d,
+           is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
+           input_QKV, input_Bias, output_O,
+           Aux_CTX_Tensors,
+           input_cu_seqlens,
+           input_rng_state,
+           wkspace, stream, handle);
+#else
+   NVTE_ERROR(
+     "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
+#endif
+ } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
+#if (CUDNN_VERSION >= 8900)
+   fused_attn_fp8_fwd_qkvpacked(
+           b, max_seqlen, h, d,
+           is_training, attn_scale, dropout, qkv_layout,
+           input_QKV, input_output_S, output_O,
+           Aux_CTX_Tensors,
+           input_cu_seqlens,
+           input_rng_state,
+           wkspace, stream, handle);
+#else
+   NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
+#endif
+ } else {
+   NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
+ }
 }
// NVTE fused attention BWD FP8 with packed QKV
@@ -130,18 +204,52 @@ void nvte_fused_attn_bwd_qkvpacked(
  size_t d = input_QKV->data.shape[ndim - 1];
  auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
- const DType QKV_type = input_QKV->data.dtype;
- if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
-                 && (max_seqlen <= 512)) {
+ const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
+ NVTE_Fused_Attn_Backend fused_attention_backend =
+             nvte_get_fused_attn_backend(
+                         QKV_type, QKV_type,
+                         qkv_layout, bias_type, attn_mask_type,
+                         dropout, max_seqlen, max_seqlen, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
fused_attn_max_512_bwd_qkvpacked(
b, max_seqlen, h, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_dO,
output_S,
output_dQKV, output_dBias,
input_cu_seqlens,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
fused_attn_arbitrary_seqlen_bwd_qkvpacked(
b, max_seqlen, h, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_O, input_dO,
output_S,
output_dQKV, output_dBias,
input_cu_seqlens, input_rng_state,
wkspace, stream, handle);
#else
const char *err_msg =
"cuDNN 8.9.0 is required for BF16/FP16 fused attention "
"with arbitrary sequence length. \n";
NVTE_ERROR(err_msg);
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
+   // Aux_CTX_Tensors contain [M, ZInv, rng_state] generated by the forward pass
    const Tensor *input_M = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[0]);
    const Tensor *input_ZInv = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
    const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
-   // FP8 API doesn't use input_dBias, bias_type or attn_mask_type
-   fused_attn_bwd_fp8_qkvpacked(
+   fused_attn_fp8_bwd_qkvpacked(
            b, max_seqlen, h, d,
            attn_scale, dropout, qkv_layout,
            input_QKV, input_O, input_dO,
@@ -152,38 +260,10 @@ void nvte_fused_attn_bwd_qkvpacked(
            input_rng_state,
            wkspace, stream, handle);
#else
-   NVTE_ERROR("cuDNN 8.9 is required to run FP8 fused attention. \n");
+   NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
#endif
- } else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
-                 && (max_seqlen <= 512)) {
-#if (CUDNN_VERSION >= 8901)
-   fused_attn_max_512_bwd_qkvpacked(
-       b,
-       max_seqlen,
-       h,
-       d,
-       attn_scale,
-       dropout,
-       qkv_layout,
-       bias_type,
-       attn_mask_type,
-       input_QKV,
-       input_dO,
-       Aux_CTX_Tensors,
-       output_dQKV,
-       output_dBias,
-       input_cu_seqlens,
-       wkspace,
-       stream,
-       handle);
-#else
-   NVTE_ERROR(
-     "cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n");
-#endif
- } else if (max_seqlen > 512) {
-   NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
  } else {
-   NVTE_ERROR("Invalid combination of data type and sequence length! \n");
+   NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
  }
 }
// NVTE fused attention FWD FP8 with packed KV
@@ -193,7 +273,7 @@ void nvte_fused_attn_fwd_kvpacked(
            const NVTETensor Bias,
            NVTETensor S,
            NVTETensor O,
-           NVTETensorPack* Aux_Output_Tensors,
+           NVTETensorPack* Aux_CTX_Tensors,
            const NVTETensor cu_seqlens_q,
            const NVTETensor cu_seqlens_kv,
            const NVTETensor rng_state,
@@ -223,45 +303,37 @@ void nvte_fused_attn_fwd_kvpacked(
  size_t d = input_Q->data.shape[ndim - 1];
  auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
- const DType QKV_type = input_Q->data.dtype;
- if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
-                 && (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
-   NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
- } else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
-                 && (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
-#if (CUDNN_VERSION >= 8901)
-   fused_attn_max_512_fwd_kvpacked(
-       b,
-       max_seqlen_q,
-       max_seqlen_kv,
-       h,
-       d,
-       is_training,
-       attn_scale,
-       dropout,
-       qkv_layout,
-       bias_type,
-       attn_mask_type,
-       input_Q,
-       input_KV,
-       input_Bias,
-       output_O,
-       Aux_Output_Tensors,
-       input_cu_seqlens_q,
-       input_cu_seqlens_kv,
-       input_rng_state,
-       wkspace,
-       stream,
-       handle);
-#else
-   NVTE_ERROR(
-     "cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n");
-#endif
- } else if ((max_seqlen_q > 512) || (max_seqlen_kv > 512)) {
-   NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
- } else {
-   NVTE_ERROR("Invalid combination of data type and sequence length! \n");
- }
+ const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
+ const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
+ NVTE_Fused_Attn_Backend fused_attention_backend =
+             nvte_get_fused_attn_backend(
+                         Q_type, KV_type,
+                         qkv_layout, bias_type, attn_mask_type,
+                         dropout, max_seqlen_q, max_seqlen_kv, d);
+ if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
+#if (CUDNN_VERSION >= 8901)
+   fused_attn_max_512_fwd_kvpacked(
+       b, max_seqlen_q, max_seqlen_kv, h, d,
+       is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
+       input_Q, input_KV, input_Bias, output_O,
+       Aux_CTX_Tensors,
+       input_cu_seqlens_q, input_cu_seqlens_kv,
+       input_rng_state,
+       wkspace, stream, handle);
+#else
+   NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
+#endif
+ } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
+   const char* err_msg =
+     "The FP16/BF16 fused attention (arbitrary seqlen) currently "
+     "only supports packed QKV input.\n";
+   NVTE_ERROR(err_msg);
+ } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
+   NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
+ } else {
+   NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
+ }
 }
// NVTE fused attention BWD FP8 with packed KV
@@ -307,44 +379,37 @@ void nvte_fused_attn_bwd_kvpacked(
  size_t d = input_Q->data.shape[ndim - 1];
  auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
- const DType QKV_type = input_Q->data.dtype;
- if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
-                 && (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
-   NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
- } else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
-                 && (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
-#if (CUDNN_VERSION >= 8901)
-   fused_attn_max_512_bwd_kvpacked(
-       b,
-       max_seqlen_q,
-       max_seqlen_kv,
-       h,
-       d,
-       attn_scale,
-       dropout,
-       qkv_layout,
-       bias_type,
-       attn_mask_type,
-       input_Q,
-       input_KV,
-       input_dO,
-       Aux_CTX_Tensors,
-       output_dQ,
-       output_dKV,
-       output_dBias,
-       input_cu_seqlens_q,
-       input_cu_seqlens_kv,
-       wkspace,
-       stream,
-       handle);
-#else
-   NVTE_ERROR(
-     "cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n");
-#endif
- } else if ((max_seqlen_q > 512) || (max_seqlen_kv > 512)) {
-   NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
- } else {
-   NVTE_ERROR("Invalid combination of data type and sequence length! \n");
- }
+ const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
+ const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
+ NVTE_Fused_Attn_Backend fused_attention_backend =
+             nvte_get_fused_attn_backend(
+                         Q_type, KV_type,
+                         qkv_layout, bias_type, attn_mask_type,
+                         dropout, max_seqlen_q, max_seqlen_kv, d);
+ if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
+#if (CUDNN_VERSION >= 8901)
+   Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+   fused_attn_max_512_bwd_kvpacked(
+       b, max_seqlen_q, max_seqlen_kv, h, d,
+       attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
+       input_Q, input_KV, input_dO,
+       output_S,
+       output_dQ, output_dKV, output_dBias,
+       input_cu_seqlens_q, input_cu_seqlens_kv,
+       wkspace, stream, handle);
+#else
+   NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
+#endif
+ } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
+   const char* err_msg =
+     "The FP16/BF16 fused attention (arbitrary seqlen) currently "
+     "only supports packed QKV input.\n";
+   NVTE_ERROR(err_msg);
+ } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
+   NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
+ } else {
+   NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
+ }
 }
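As a usage sketch (assuming nvte_get_fused_attn_backend and the NVTE_* enums above are declared in the public header transformer_engine/fused_attn.h), a caller can query which backend the new dispatch logic would select for a given problem shape; for a BF16, causal-mask, QKV-interleaved case with sequence length 2048 and head_dim 64 on sm80 or newer, the rules above should yield the arbitrary-seqlen backend:

#include <cstdio>
#include <transformer_engine/fused_attn.h>

int main() {
  // Query the backend for a long-sequence BF16 self-attention problem.
  NVTE_Fused_Attn_Backend backend = nvte_get_fused_attn_backend(
      NVTEDType::kNVTEBFloat16, NVTEDType::kNVTEBFloat16,
      NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED,
      NVTE_Bias_Type::NVTE_NO_BIAS,
      NVTE_Mask_Type::NVTE_CAUSAL_MASK,
      /*dropout=*/0.0f, /*max_seqlen_q=*/2048, /*max_seqlen_kv=*/2048,
      /*head_dim=*/64);
  // Expected: NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen for this configuration.
  std::printf("selected backend: %d\n", static_cast<int>(backend));
  return 0;
}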
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "fused_attn_f16_arbitrary_seqlen.h"
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cudnn_frontend.h>
#include <map>
#include <vector>
#include "../common.h"
#include "utils.h"
#if (CUDNN_VERSION >= 8900)
#define Q_ID 1
#define K_ID 2
#define V_ID 3
#define O_ID 4
#define S_ID 5
#define B_ID 6
#define D_CONST_ID 7
#define S_CONST_ID 8
#define Q_SEQLEN_ID 9
#define K_SEQLEN_ID 10
#define dQ_ID 11
#define dK_ID 12
#define dV_ID 13
#define dO_ID 14
#define MASK_VAL_ID 15
#define dS_ID 16
#define D_SEED_ID 17
#define D_OFFSET_ID 18
#define S_STATS_ID 19
#define S_SUM_ID 20
#define SCALE_PROB 21
#define K_TRANSPOSE_ID 22
#define dQ_ACCUM_ID 23
#define VIRTUAL_ID 30
namespace transformer_engine {
namespace fused_attn {
static cudnn_frontend::Tensor
createScale(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
const cudnn_frontend::Tensor& sTensor,
std::vector<cudnn_frontend::Operation>* ops) {
// scale
int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1};
int64_t s_dim[4] = {b, h, s_q, s_kv};
int64_t s_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, s_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix);
auto scaleTensor = tensor_create(
tensorType, S_CONST_ID, scale_dim,
scale_stride, false, true); // is by value
auto sScaleTensor = tensor_create(
tensorType, VIRTUAL_ID + 2000, s_dim,
s_stride, true, false); // is virtual
// Define the scale descriptor
auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a scale node
auto scale_op = binary_pw_op_create(sTensor, scaleTensor, sScaleTensor, scaleDesc);
ops->push_back(std::move(scale_op));
return sScaleTensor;
}
static cudnn_frontend::Tensor
createQKBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops) {
// Creates the necessary tensor descriptors
int64_t q_dim[4] = {b, h, s_q, d};
int64_t q_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
int64_t k_dim[4] = {b, h, d, s_kv};
int64_t k_stride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, k_stride, layout, NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose);
int64_t s_dim[4] = {b, h, s_q, s_kv};
int64_t s_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, s_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix);
auto qTensor = tensor_create(tensorType, Q_ID, q_dim, q_stride, false, false);
auto kTransposeTensor = tensor_create(
tensorType, K_ID, k_dim, k_stride, false, false); // not virtual
// first GEMM output
auto sTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 1, s_dim, s_stride, true, false); // is virtual
// Define the matmul 1 desc
auto matmul_1_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.build();
// Create a matmul 1 node
auto matmul_op1 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(qTensor)
.setbMatDesc(kTransposeTensor)
.setcMatDesc(sTensor)
.setmatmulDesc(matmul_1_Desc)
.build();
ops->push_back(std::move(matmul_op1));
return sTensor;
}
static cudnn_frontend::Tensor
createCausalMask(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& prevBlockOutputTensor) {
CUDNN_FRONTEND_UNUSED(d);
CUDNN_FRONTEND_UNUSED(layout);
CUDNN_FRONTEND_UNUSED(tensorType);
NVTE_CHECK(ops->size() != 0, "Padding Mask constructed incorrectly as the first one");
// subtraction output
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t maskVal_dim[4] = {1, 1, 1, 1};
int64_t maskVal_stride[4] = {1, 1, 1, 1};
// mask value to put in the masked pixels
auto maskValTensor = tensor_create(
CUDNN_DATA_FLOAT, MASK_VAL_ID, maskVal_dim,
maskVal_stride, false, true); // is by value
// gen index row output
auto rowIndexTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 100, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// gen index column output
auto columnIndexTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 101, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// create causal mask (row >= col)
auto causalMaskTensor = tensor_create(
CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 106, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// output after masking
auto maskOutputTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 107, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// Define the gen index for row descriptor
auto genIndexRowDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setAxis(2)
.setComputeType(CUDNN_DATA_FLOAT)
.build();
// Create a gen index node
auto genIndexRow_op = unary_pw_op_create(
prevBlockOutputTensor, rowIndexTensor, genIndexRowDesc);
// Define the gen index for column descriptor
auto genIndexColumnDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setAxis(3)
.setComputeType(CUDNN_DATA_FLOAT)
.build();
// Create a gen index node
auto genIndexColumn_op = unary_pw_op_create(
prevBlockOutputTensor, columnIndexTensor, genIndexColumnDesc);
// Define the greater than equal to comparison descriptor
auto rowGreaterColDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_CMP_GE);
// Create a greater than equal to node
auto rowGreaterCol_op = binary_pw_op_create(
rowIndexTensor, columnIndexTensor, causalMaskTensor, rowGreaterColDesc);
// Define the binary select to perform masking descriptor
auto maskDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_BINARY_SELECT);
// Create a binary select node
auto mask_op = ternary_pw_op_create(
prevBlockOutputTensor, maskValTensor,
causalMaskTensor, maskOutputTensor, maskDesc);
ops->push_back(std::move(genIndexRow_op));
ops->push_back(std::move(genIndexColumn_op));
ops->push_back(std::move(rowGreaterCol_op));
ops->push_back(std::move(mask_op));
return maskOutputTensor;
}
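// In effect, the ops appended above compute, for each (batch, head) slice of the scaled
// BMM1 output S:  masked(i, j) = S(i, j) if row i >= column j, else maskVal, where the
// caller binds maskVal (MASK_VAL_ID) to a large negative constant (negInfinity in the
// forward entry point below) rather than -inf, for consistency with the m512 kernels.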
static cudnn_frontend::Tensor
createSoftmaxForward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, bool isTraining,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& sAfterMaskTensor) {
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t afterReduction_dim[4] = {b, h, s_q, 1};
int64_t afterReduction_stride[4] = {h * s_q, s_q, 1, 1};
// max (x)
auto afterMaxReductionTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 150, afterReduction_dim,
afterReduction_stride, true, false); // is virtual
// x - max(x)
auto afterSubtractionTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 151, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// e^(x - max(x))
auto afterExponentTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 152, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual;
// sum (e^(x - max(x)))
auto afterAddReductionTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 153, afterReduction_dim,
afterReduction_stride, true, false); // is virtual
// log (sum (e^(x - max(x))))
auto afterLogLTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 154, afterReduction_dim,
afterReduction_stride, true, false);
// M + log (sum (e^(x - max(x))))
auto softmaxStatsTensor = tensor_create(
CUDNN_DATA_FLOAT, S_STATS_ID, afterReduction_dim,
afterReduction_stride, !isTraining, false);
// not virtual if training is true, virtual if training is false
// divide (e/ sum(e))
auto afterSoftmaxTensor = cudnn_frontend::TensorBuilder()
.setDim(4, afterBMM1_dim)
.setStride(4, afterBMM1_stride)
.setId(VIRTUAL_ID + 156)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual(true)
.setByValue(false)
.setReorderType(
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16)
.build();
// Define the reduction descriptor
auto reductionMaxDesc = cudnn_frontend::ReductionDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_MAX)
.build();
// Create a reduction max node
auto reductionMax_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(sAfterMaskTensor)
.setyDesc(afterMaxReductionTensor)
.setreductionDesc(reductionMaxDesc)
.build();
// Define the subtract descriptor
auto subtractDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB);
// Create a subtract node
auto subtract_op = binary_pw_op_create(
sAfterMaskTensor, afterMaxReductionTensor,
afterSubtractionTensor, subtractDesc);
// Define the exponent descriptor
auto exponentDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP);
// Create an exponent node
auto exponent_op = unary_pw_op_create(
afterSubtractionTensor, afterExponentTensor, exponentDesc);
// Define the reduction descriptor
auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
.build();
// Create a reduction add node
auto reductionAdd_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(afterExponentTensor)
.setyDesc(afterAddReductionTensor)
.setreductionDesc(reductionAddDesc)
.build();
// Create log descriptor
auto logDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_LOG);
// Create log node
auto log_op = unary_pw_op_create(afterAddReductionTensor, afterLogLTensor, logDesc);
// Create add descriptor
auto addDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_ADD);
// Create add node
auto add_op = binary_pw_op_create(
afterMaxReductionTensor, afterLogLTensor,
softmaxStatsTensor, addDesc);
// Define the division descriptor
auto divisionDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_DIV);
// Create a division node
auto division_op = binary_pw_op_create(
afterExponentTensor, afterAddReductionTensor,
afterSoftmaxTensor, divisionDesc);
ops->push_back(std::move(reductionMax_op));
ops->push_back(std::move(subtract_op));
ops->push_back(std::move(exponent_op));
ops->push_back(std::move(reductionAdd_op));
ops->push_back(std::move(log_op));
ops->push_back(std::move(add_op));
ops->push_back(std::move(division_op));
return afterSoftmaxTensor;
}
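// For each (batch, head, query row) the subgraph above computes a numerically stable
// softmax over the key dimension:
//   M         = max_j x_j
//   stats     = M + log( sum_j exp(x_j - M) )   // written to S_STATS_ID only when training
//   softmax_j = exp(x_j - M) / sum_k exp(x_k - M)
// where x is the scaled, masked BMM1 output; "stats" is what the backward pass reloads
// through devPtrSoftmaxStats.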
static cudnn_frontend::Tensor
createDropoutForward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
double probability, cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& afterSoftmaxTensor) {
CUDNN_FRONTEND_UNUSED(d);
NVTE_CHECK(ops->size() != 0, "Dropout DAG constructed incorrectly as the first one");
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1};
auto dropoutSeed = tensor_create(
CUDNN_DATA_INT64, D_SEED_ID, scale_dim,
scale_stride, false, false); // not virtual
auto dropoutOffset = tensor_create(
CUDNN_DATA_INT64, D_OFFSET_ID, scale_dim,
scale_stride, false, false); // not virtual
// mask for the dropout
auto dropoutMaskTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 200, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// after dropout tensor
auto afterDropoutTensor = cudnn_frontend::TensorBuilder()
.setDim(4, afterBMM1_dim)
.setStride(4, afterBMM1_stride)
.setId(VIRTUAL_ID + 201)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(tensorType)
.setVirtual(true)
.setByValue(false)
.setReorderType(
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16)
.build();
// scale after dropout
auto scaleDropoutTensor = tensor_create(
tensorType, D_CONST_ID, scale_dim,
scale_stride, false, true); // is by value
// after Scale
auto afterScaleTensor = tensor_create(
tensorType, VIRTUAL_ID + 202, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// Define the RNG descriptor
auto rngDesc = cudnn_frontend::RngDescBuilder()
.setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI)
.setBernoulliDistProbability(1.0 - probability)
.build();
// Create a rng node
auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR)
.setyDesc(dropoutMaskTensor)
.setSeedDesc(dropoutSeed)
.setOffsetDesc(dropoutOffset)
.setRngDesc(rngDesc)
.build();
// Define the multiply mask descriptor
auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply mask node
auto maskMul_op = binary_pw_op_create(
afterSoftmaxTensor, dropoutMaskTensor,
afterDropoutTensor, maskMulDesc);
// Define the multiply scale descriptor
auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply scale node
auto scaleMul_op = binary_pw_op_create(
afterDropoutTensor, scaleDropoutTensor,
afterScaleTensor, scaleMulDesc);
ops->push_back(std::move(rng_op));
ops->push_back(std::move(maskMul_op));
ops->push_back(std::move(scaleMul_op));
return afterScaleTensor;
}
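// The dropout subgraph above draws a Bernoulli(1 - p) mask from the philox seed/offset
// (D_SEED_ID / D_OFFSET_ID), multiplies it into the softmax output, and then rescales by
// 1 / (1 - p); the scale is passed in by value through D_CONST_ID (scale_dropout in the
// forward entry point below).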
static cudnn_frontend::Tensor
createDropoutBackward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
double probability, cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& afterSoftmaxTensor,
const cudnn_frontend::Tensor& dropoutMaskTensor) {
CUDNN_FRONTEND_UNUSED(d);
NVTE_CHECK(ops->size() != 0, "Dropout DAG constructed incorrectly as the first one");
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1};
auto dropoutSeed = tensor_create(
CUDNN_DATA_INT64, D_SEED_ID, scale_dim,
scale_stride, false, false); // not virtual
auto dropoutOffset = tensor_create(
CUDNN_DATA_INT64, D_OFFSET_ID, scale_dim,
scale_stride, false, false); // not virtual
// after dropout tensor
auto afterDropoutTensor = cudnn_frontend::TensorBuilder()
.setDim(4, afterBMM1_dim)
.setStride(4, afterBMM1_stride)
.setId(VIRTUAL_ID + 201)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(tensorType)
.setVirtual(true)
.setByValue(false)
.setReorderType(
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16)
.build();
// scale after dropout
auto scaleDropoutTensor = tensor_create(
tensorType, D_CONST_ID, scale_dim,
scale_stride, false, true); // is by value
// after Scale
auto afterScaleTensor = tensor_create(
tensorType, VIRTUAL_ID + 202, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// Define the RNG descriptor
auto rngDesc = cudnn_frontend::RngDescBuilder()
.setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI)
.setBernoulliDistProbability(1.0 - probability)
.build();
// Create a rng node
auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR)
.setyDesc(dropoutMaskTensor)
.setSeedDesc(dropoutSeed)
.setOffsetDesc(dropoutOffset)
.setRngDesc(rngDesc)
.build();
// Define the multiply mask descriptor
auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply mask node
auto maskMul_op = binary_pw_op_create(
afterSoftmaxTensor, dropoutMaskTensor,
afterDropoutTensor, maskMulDesc);
// Define the multiply scale descriptor
auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply scale node
auto scaleMul_op = binary_pw_op_create(
afterDropoutTensor, scaleDropoutTensor,
afterScaleTensor, scaleMulDesc);
ops->push_back(std::move(rng_op));
ops->push_back(std::move(maskMul_op));
ops->push_back(std::move(scaleMul_op));
return afterScaleTensor;
}
static void
createSVBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops,
cudnn_frontend::Tensor const &afterScaleDropoutTensor) {
NVTE_CHECK(ops->size() != 0, "BMM2 op constructed incorrectly as the first one");
int64_t v_dim[4] = {b, h, s_kv, d};
int64_t v_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, v_stride, layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
int64_t o_dim[4] = {b, h, s_q, d};
int64_t o_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
auto vTensor = tensor_create(tensorType, V_ID, v_dim, v_stride, false, false);
// second GEMM output
auto oTensor = tensor_create(tensorType, O_ID, o_dim, o_stride, false, false);
// Define the matmul 2 desc
auto matmul_2_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.build();
// Create a matmul 2 node
auto matmul_op2 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(afterScaleDropoutTensor)
.setbMatDesc(vTensor)
.setcMatDesc(oTensor)
.setmatmulDesc(matmul_2_Desc)
.build();
ops->push_back(std::move(matmul_op2));
}
void fused_attn_arbitrary_seqlen_fwd_impl(
int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
bool is_training, float scaling_factor, float dropout_probability,
NVTE_QKV_Layout layout,
void *devPtrQ, void *devPtrK, void *devPtrV,
void *devPtrSoftmaxStats, void *devPtrO,
void* devPtrDropoutSeed, void* devPtrDropoutOffset,
cudnnDataType_t tensorType,
void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle) {
try {
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
if (!is_training) {
dropout_probability = 0.0f;
}
FADescriptor descriptor{b, h,
s_q, s_kv,
d, scaling_factor,
is_training, dropout_probability,
layout, NVTE_Bias_Type::NVTE_NO_BIAS,
NVTE_Mask_Type::NVTE_CAUSAL_MASK, tensorType};
using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
static thread_local CacheType fmha_fprop_cache;
// Get plan from cache if cache is available, otherwise create one
auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) {
// if hit, return
auto it = cache.find(descriptor);
if (it != cache.end()) {
auto plan = it->second;
return plan;
}
// otherwise, build the op_graph and the plan. Then update cache
std::vector<cudnn_frontend::Operation const*> all_ops;
std::vector<cudnn_frontend::Operation> ops;
// Q * K^T
auto sTensor = createQKBMM(b, h, s_q, s_kv, d, layout, tensorType, &ops);
// Q * K^T * bmmScale
auto sScaleTensor = createScale(
b, h, s_q, s_kv, d, layout, CUDNN_DATA_FLOAT, sTensor, &ops);
// Causal mask
auto sAfterMaskTensor = createCausalMask(
b, h, s_q, s_kv, d, layout, tensorType, &ops, sScaleTensor);
NVTE_CHECK(dropout_probability != 1.0f,
"Dropout probability cannot be 1.0");
auto softmax_output = createSoftmaxForward(
b, h, s_q, s_kv, is_training, &ops, sAfterMaskTensor);
// Dropout(softmax)
auto dropout_output = createDropoutForward(
b, h, s_q, s_kv, d,
dropout_probability, tensorType, &ops, softmax_output);
createSVBMM(b, h, s_q, s_kv, d, layout, tensorType, &ops, dropout_output);
for (unsigned int i = 0; i < ops.size(); i++) {
all_ops.push_back(&ops[i]);
}
// Create an Operation Graph
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle)
.setOperationGraph(all_ops.size(), all_ops.data())
.build();
cudnn_frontend::EngineConfigList filtered_configs;
auto statuses = cudnn_frontend::get_heuristics_list<1>(
{"heuristics_instant"}, opGraph, allowAllConfig,
filtered_configs, true);
if (filtered_configs.size() == 0) {
cudnn_frontend::set_error_and_throw_exception(
nullptr,
CUDNN_STATUS_NOT_SUPPORTED,
"run_mha_fprop: No config returned by the heuristics");
}
auto plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(filtered_configs[0], opGraph.getTag())
.build();
cache.insert({descriptor, plan});
return plan;
};
auto plan = get_plan(fmha_fprop_cache, descriptor);
auto plan_workspace_size = plan.getWorkspaceSize();
// Exit to request upper level API to allocate memory if needed
if (workspace == nullptr) {
*workspace_size = plan_workspace_size;
return;
}
std::set<std::pair<uint64_t, void*>> data_ptrs;
// Add all the data pointers to be used in the variant pack
float negInfinity = -1.0E+10f;
float scale_dropout = 1.0f/(1.0f - dropout_probability);
data_ptrs.insert(std::pair<uint64_t, void*>(Q_ID, devPtrQ));
data_ptrs.insert(std::pair<uint64_t, void*>(K_ID, devPtrK));
data_ptrs.insert(std::pair<uint64_t, void*>(V_ID, devPtrV));
data_ptrs.insert(std::pair<uint64_t, void*>(MASK_VAL_ID, &negInfinity));
data_ptrs.insert(std::pair<uint64_t, void*>(S_CONST_ID, &scaling_factor));
data_ptrs.insert(std::pair<uint64_t, void*>(O_ID, devPtrO));
data_ptrs.insert(std::pair<uint64_t, void*>(D_SEED_ID, devPtrDropoutSeed));
data_ptrs.insert(std::pair<uint64_t, void*>(D_OFFSET_ID, devPtrDropoutOffset));
data_ptrs.insert(std::pair<uint64_t, void*>(D_CONST_ID, &scale_dropout));
// If training mode, we write out softmax stats
if (is_training) {
data_ptrs.insert(std::pair<uint64_t, void*>(S_STATS_ID, devPtrSoftmaxStats));
}
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace)
.setDataPointers(data_ptrs)
.build();
NVTE_CHECK_CUDNN(
cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc()));
} catch (cudnn_frontend::cudnnException &e) {
NVTE_ERROR(e.what());
}
}
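// In summary, the forward graph assembled above computes
//   S      = causal_mask(scaling_factor * Q @ K^T)
//   stats  = row-wise logsumexp of S   (written out via S_STATS_ID when is_training)
//   P      = exp(S - stats)            (softmax over the key dimension)
//   P_drop = (dropout_mask * P) * 1/(1 - dropout_probability)
//   O      = P_drop @ V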
void fused_attn_arbitrary_seqlen_bwd_impl(
int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
void* devPtrQ, void* devPtrKTranspose, void* devPtrVTranspose,
void* devPtrO, void* devPtrSoftmaxStats,
void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrdO,
void* devPtrDropoutSeed, void* devPtrDropoutOffset,
cudnnDataType_t tensorType, void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle) {
try {
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
FADescriptor descriptor{b, h,
s_q, s_kv,
d, scaling_factor,
true, dropout_probability,
layout, NVTE_Bias_Type::NVTE_NO_BIAS,
NVTE_Mask_Type::NVTE_CAUSAL_MASK, tensorType};
using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
static thread_local CacheType fmha_bprop_cache;
auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) {
auto it = cache.find(descriptor);
if (it != cache.end()) {
return it->second;
}
std::vector<cudnn_frontend::Operation const*> all_ops;
std::vector<cudnn_frontend::Operation> ops;
// Creates the necessary tensor descriptors
int64_t q_dim[4] = {b, h, s_q, d};
int64_t q_stride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, q_stride,
layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
int64_t k_transpose_dim[4] = {b, h, d, s_kv};
int64_t k_transpose_stride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, k_transpose_stride,
layout, NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose);
int64_t v_transpose_dim[4] = {b, h, d, s_kv};
int64_t v_transpose_stride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, v_transpose_stride,
layout, NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose);
int64_t p_dim[4] = {b, h, s_q, s_kv};
int64_t p_stride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, p_stride,
layout, NVTE_QKV_Matrix::NVTE_S_Matrix);
int64_t p_transpose_dim[4] = {b, h, s_kv, s_q};
int64_t p_transpose_stride[4];
p_transpose_stride[0] = p_stride[0];
p_transpose_stride[1] = p_stride[1];
p_transpose_stride[2] = p_stride[3];
p_transpose_stride[3] = p_stride[2];
int64_t o_dim[4] = {b, h, s_q, d};
int64_t o_stride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, o_stride,
layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1};
/*******************************************************************************
* Dot product dO * O */
// output and gradient of the output
auto oTensor = tensor_create(tensorType, O_ID, o_dim, o_stride, false, false);
auto dOTensor = tensor_create(tensorType, dO_ID, o_dim, o_stride, false, false);
auto dotProductTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID, o_dim,
o_stride, true, false); // is virtual
// Create pointwise mul
auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// dO * O
auto dotProductOp = binary_pw_op_create(
dOTensor, oTensor, dotProductTensor, multiplyDesc);
ops.push_back(std::move(dotProductOp));
/*******************************************************************************
* Reduction(dO * O) */
int64_t reduction_dim[4] = {b, h, s_q, 1};
int64_t reduction_stride[4] = {h * s_q, s_q, 1, 1};
// reduction(dO * O)
auto afterReductionTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 1, reduction_dim,
reduction_stride, true, false); // is virtual
auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
.build();
// Create a reduction add node (row-wise sum of dO * O over the head dimension)
auto reductionAdd_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(dotProductTensor)
.setyDesc(afterReductionTensor)
.setreductionDesc(reductionAddDesc)
.build();
ops.push_back(std::move(reductionAdd_op));
/*******************************************************************************
* reduction(dO * O) * scale prob -> softmaxSum */
auto softmaxSumTensor = tensor_create(
CUDNN_DATA_FLOAT, S_SUM_ID, reduction_dim,
reduction_stride, false, false); // not virtual
auto scaleProbTensor = tensor_create(
CUDNN_DATA_FLOAT, SCALE_PROB, scale_dim,
scale_stride, false, true); // not virtual
auto softmaxSumOp = binary_pw_op_create(
afterReductionTensor, scaleProbTensor,
softmaxSumTensor, multiplyDesc);
ops.push_back(std::move(softmaxSumOp));
/*******************************************************************************
* Q @ K.T -> P */
// Inputs from fprop
auto qTensor = tensor_create(tensorType, Q_ID, q_dim, q_stride, false, false);
auto kTransposeTensor = tensor_create(
tensorType, K_ID, k_transpose_dim,
k_transpose_stride, false, false);
auto pTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 2, p_dim,
p_stride, true, false); // is virtual
// matmul to calculate dvTensor
auto matmul_0_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.build();
auto matmul_op0 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(qTensor)
.setbMatDesc(kTransposeTensor)
.setcMatDesc(pTensor)
.setmatmulDesc(matmul_0_Desc)
.build();
ops.push_back(std::move(matmul_op0));
/*******************************************************************************
* P * bmmScale -> pAfterScale */
auto bmmScaleTensor = tensor_create(
CUDNN_DATA_FLOAT, S_CONST_ID, scale_dim,
scale_stride, false, true); // not virtual and by value
auto pAfterScaleTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 2000, p_dim,
p_stride, true, false); // virtual
auto scaleOp = binary_pw_op_create(
pTensor, bmmScaleTensor, pAfterScaleTensor, multiplyDesc);
ops.push_back(std::move(scaleOp));
/*******************************************************************************
* Causal masking -> pAfterMaskTensor */
auto pAfterMaskTensor = createCausalMask(
b, h, s_q, s_kv, d, layout, tensorType, &ops, pAfterScaleTensor);
/*******************************************************************************
* pAfterMaskTensor - softmaxStats -> pAfterSubtract */
auto pAfterSubtractTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 3, p_dim,
p_stride, true, false); // is virtual
auto softmaxStatsTensor = tensor_create(
CUDNN_DATA_FLOAT, S_STATS_ID, reduction_dim,
reduction_stride, false, false); // not virtual
auto subtractDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB);
auto subtract_op = binary_pw_op_create(
pAfterMaskTensor, softmaxStatsTensor,
pAfterSubtractTensor, subtractDesc);
ops.push_back(std::move(subtract_op));
/*******************************************************************************
* e^(pAfterSubtract) -> pAfterSoftmax */
auto pAfterSoftmaxTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 4, p_dim,
p_stride, true, false); // is virtual
auto expDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP);
auto exp_op = unary_pw_op_create(
pAfterSubtractTensor, pAfterSoftmaxTensor, expDesc);
ops.push_back(std::move(exp_op));
/*******************************************************************************
* Dropout -> afterScaleDropout */
auto dropoutMaskTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 5, p_dim,
p_stride, true, false); // is virtual
auto afterScaleDropoutTensor = createDropoutBackward(
b, h, s_q, s_kv, d, dropout_probability, tensorType,
&ops, pAfterSoftmaxTensor, dropoutMaskTensor);
/*******************************************************************************
* afterScaleDropout -> sTransposeTensor */
auto sTransposeTensor = tensor_create(
tensorType, VIRTUAL_ID + 6, p_transpose_dim,
p_transpose_stride, true, false); // is virtual
auto reshape_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR)
.setxDesc(afterScaleDropoutTensor)
.setyDesc(sTransposeTensor)
.build();
ops.push_back(std::move(reshape_op));
// Outputs of bprop
int64_t dqkv_dim[4] = {b, h, s_kv, d};
int64_t dqkv_stride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, dqkv_stride,
layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
// Outputs of backprop
auto dQTensor = tensor_create(tensorType, dQ_ID, dqkv_dim, dqkv_stride, false, false);
auto dKTensor = tensor_create(tensorType, dK_ID, dqkv_dim, dqkv_stride, false, false);
auto dVTensor = tensor_create(tensorType, dV_ID, dqkv_dim, dqkv_stride, false, false);
// not virtual
/*******************************************************************************
* sTransposeTensor @ dO -> dV */
auto matmul_1_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.build();
auto matmul_op1 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(sTransposeTensor)
.setbMatDesc(dOTensor)
.setcMatDesc(dVTensor)
.setmatmulDesc(matmul_1_Desc)
.build();
ops.push_back(std::move(matmul_op1));
/*******************************************************************************
* dO @ V.T -> dS */
auto vTransposeTensor = tensor_create(
tensorType, V_ID, v_transpose_dim,
v_transpose_stride, false, false);
auto dSTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 7, p_dim,
p_stride, true, false); // is virtual
auto matmul_2_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.build();
auto matmul_op2 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(dOTensor)
.setbMatDesc(vTransposeTensor)
.setcMatDesc(dSTensor)
.setmatmulDesc(matmul_2_Desc)
.build();
ops.push_back(std::move(matmul_op2));
/*******************************************************************************
* dS * dropoutMask -> dSAfterDropout */
auto dSAfterDropoutTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 8, p_dim,
p_stride, true, false); // is virtual
auto multiply_op = binary_pw_op_create(
dSTensor, dropoutMaskTensor,
dSAfterDropoutTensor, multiplyDesc);
ops.push_back(std::move(multiply_op));
/*******************************************************************************
* dSAfterDropout - softmaxSum -> dsAfterSubtract */
auto dsAfterSubtractTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 9, p_dim,
p_stride, true, false); // is virtual
auto subtract_op2 = binary_pw_op_create(
dSAfterDropoutTensor, softmaxSumTensor,
dsAfterSubtractTensor, subtractDesc);
ops.push_back(std::move(subtract_op2));
/*******************************************************************************
* dsAfterSubtract * afterSoftmax -> dP */
auto dPTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 10, p_dim,
p_stride, true, false); // is virtual
auto multiply_op2 = binary_pw_op_create(
dsAfterSubtractTensor, pAfterSoftmaxTensor,
dPTensor, multiplyDesc);
ops.push_back(std::move(multiply_op2));
/*******************************************************************************
* dP * scaleDropout -> dPAfterDropoutScale */
auto dPAfterDropoutScaleTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 11, p_dim,
p_stride, true, false); // is virtual
auto scaleDropoutTensor = tensor_create(
CUDNN_DATA_FLOAT, D_CONST_ID, scale_dim,
scale_stride, false, true); // is by value
auto multiply_op3 = binary_pw_op_create(
dPTensor, scaleDropoutTensor,
dPAfterDropoutScaleTensor, multiplyDesc);
ops.push_back(std::move(multiply_op3));
/*******************************************************************************
* dPAfterDropoutScale * bmmScale -> dPScaledTensor */
auto dPScaledTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 12, p_dim,
p_stride, true, false); // is virtual
auto multiply_op4 = binary_pw_op_create(
dPAfterDropoutScaleTensor, bmmScaleTensor,
dPScaledTensor, multiplyDesc);
ops.push_back(std::move(multiply_op4));
/*******************************************************************************
* K.T -> K */
int64_t kDim[4] = {b, h, s_kv, d};
int64_t kStride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, kStride,
layout, NVTE_QKV_Matrix::NVTE_K_Matrix);
auto kTensor = tensor_create(
tensorType, VIRTUAL_ID + 13, kDim,
kStride, true, false); // is virtual
auto reshape_op2 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR)
.setxDesc(kTransposeTensor)
.setyDesc(kTensor)
.build();
ops.push_back(std::move(reshape_op2));
/*******************************************************************************
* dP @ K -> dqAccumTensor */
auto dqAccumTensor = cudnn_frontend::TensorBuilder()
.setDim(4, dqkv_dim)
.setStride(4, dqkv_stride)
.setId(dQ_ACCUM_ID)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual(false)
.setByValue(false)
.setReorderType(
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16)
.build();
auto matmul_3_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.build();
auto matmul_op3 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(dPScaledTensor)  // use dP with dropout and bmm scales applied
.setbMatDesc(kTensor)
.setcMatDesc(dqAccumTensor)
.setmatmulDesc(matmul_3_Desc)
.build();
ops.push_back(std::move(matmul_op3));
/*******************************************************************************
* dP.T @ Q -> dK */
auto dPTransposeTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 14, p_transpose_dim,
p_transpose_stride, true, false); // is virtual
auto reshape_op3 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR)
.setxDesc(dPScaledTensor)
.setyDesc(dPTransposeTensor)
.build();
ops.push_back(std::move(reshape_op3));
auto matmul_4_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.build();
auto matmul_op4 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(dPTransposeTensor)
.setbMatDesc(qTensor)
.setcMatDesc(dKTensor)
.setmatmulDesc(matmul_4_Desc)
.build();
ops.push_back(std::move(matmul_op4));
/*******************************************************************************
* dqAccumTensor @ identity -> dqTensor */
auto identityDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_IDENTITY);
auto identity_op = unary_pw_op_create(dqAccumTensor, dQTensor, identityDesc);
ops.push_back(std::move(identity_op));
for (unsigned int i = 0; i < ops.size(); i++) {
all_ops.push_back(&ops[i]);
}
// Create an Operation Graph
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle)
.setOperationGraph(all_ops.size(), all_ops.data())
.build();
cudnn_frontend::EngineConfigList filtered_configs;
auto statuses = cudnn_frontend::get_heuristics_list<1>(
{"heuristics_instant"}, opGraph, allowAllConfig, filtered_configs, true);
if (filtered_configs.size() == 0) {
cudnn_frontend::set_error_and_throw_exception(
nullptr, CUDNN_STATUS_NOT_SUPPORTED,
"run_mha_bprop: No config returned by the heuristics");
}
auto plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(filtered_configs[0], opGraph.getTag())
.build();
cache.insert({descriptor, plan});
return plan;
};
auto plan = get_plan(fmha_bprop_cache, descriptor);
auto plan_workspace_size = plan.getWorkspaceSize();
// Exit to request upper level API to allocate memory if needed
size_t softmaxSum_workspace_size = b * h * s_q * sizeof(float);
size_t dqAccum_workspace_size = b * s_q * h * d * sizeof(float);
if (workspace == nullptr) {
*workspace_size = plan_workspace_size + softmaxSum_workspace_size
+ dqAccum_workspace_size;
return;
}
void *devPtrSoftmaxSum = static_cast<int8_t *>(workspace) + plan_workspace_size;
void *devPtrdQAccumulator = static_cast<int8_t *>(devPtrSoftmaxSum)
+ softmaxSum_workspace_size;
NVTE_CHECK_CUDA(cudaMemset(devPtrdQAccumulator, 0, dqAccum_workspace_size));
std::set<std::pair<uint64_t, void *>> data_ptrs;
// add all the data pointers to be used in the variant pack
float negInfinity = -1.0E+10f;
float scale_dropout = 1.0f/(1.0f - dropout_probability);
data_ptrs.insert(std::pair<uint64_t, void*>(dQ_ID, devPtrdQ));
data_ptrs.insert(std::pair<uint64_t, void*>(dQ_ACCUM_ID, devPtrdQAccumulator));
data_ptrs.insert(std::pair<uint64_t, void*>(dK_ID, devPtrdK));
data_ptrs.insert(std::pair<uint64_t, void*>(dV_ID, devPtrdV));
data_ptrs.insert(std::pair<uint64_t, void*>(Q_ID, devPtrQ));
data_ptrs.insert(std::pair<uint64_t, void*>(K_ID, devPtrKTranspose));
data_ptrs.insert(std::pair<uint64_t, void*>(V_ID, devPtrVTranspose));
data_ptrs.insert(std::pair<uint64_t, void*>(O_ID, devPtrO));
data_ptrs.insert(std::pair<uint64_t, void*>(dO_ID, devPtrdO));
data_ptrs.insert(std::pair<uint64_t, void*>(S_STATS_ID, devPtrSoftmaxStats));
data_ptrs.insert(std::pair<uint64_t, void*>(S_SUM_ID, devPtrSoftmaxSum));
data_ptrs.insert(std::pair<uint64_t, void*>(D_SEED_ID, devPtrDropoutSeed));
data_ptrs.insert(std::pair<uint64_t, void*>(D_OFFSET_ID, devPtrDropoutOffset));
data_ptrs.insert(std::pair<uint64_t, void*>(MASK_VAL_ID, &negInfinity));
float scaleProb = 1.0f - dropout_probability;
data_ptrs.insert(std::pair<uint64_t, void*>(D_CONST_ID, &scale_dropout));
data_ptrs.insert(std::pair<uint64_t, void*>(S_CONST_ID, &scaling_factor));
data_ptrs.insert(std::pair<uint64_t, void*>(SCALE_PROB, &scaleProb));
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace)
.setDataPointers(data_ptrs)
.build();
NVTE_CHECK_CUDNN(
cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc()));
} catch (cudnn_frontend::cudnnException &e) {
NVTE_ERROR(e.what());
}
}
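// In summary, the backward graph assembled above recomputes
//   P = exp(causal_mask(scaling_factor * Q @ K^T) - softmaxStats)
// from the saved forward statistics and then applies the attention backward
// identities, with the dropout mask and 1/(1-p) scaling folded in through
// D_CONST_ID and SCALE_PROB (elementwise products written as "*"):
//   D        = row-wise sum(dO * O)
//   dV       = P_drop^T @ dO
//   dS       = dO @ V^T
//   dP_final = scaling_factor * P * (dropout_mask * dS - (1-p) * D) / (1-p)
//   dQ       = dP_final @ K     (accumulated in FP32 via dQ_ACCUM_ID)
//   dK       = dP_final^T @ Q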
} // namespace fused_attn
using namespace transformer_engine::fused_attn;
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t max_seqlen, size_t num_head, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED,
"qkv_layout must be NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED.");
// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;
const auto stride = num_head * head_dim;
void *devPtrQ = static_cast<void *>(devPtrQKV);
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 2;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, max_seqlen, 1};
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
} else if (Aux_CTX_Tensors->size == 2) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
}
void* devPtrDropoutSeed = rng_state->data.dptr;
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
const DType QKV_type = input_QKV->data.dtype;
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_fwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim,
is_training, attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
}
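// Two-pass convention: a first call with workspace->data.dptr == nullptr only
// reports the required workspace size through workspace->data.shape; the caller
// allocates that buffer and calls this function again to launch the kernels.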
}
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_O,
const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED,
"qkv_layout must be NVTE_QKV_INTERLEAVED.");
// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;
auto stride = num_head * head_dim;
void *devPtrQ = devPtrQKV;
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
void* devPtrO = input_O->data.dptr;
void *devPtrdO = input_dO->data.dptr;
// dQKV shape is [b, s, 3, h, d]
void *devPtrdQKV = output_dQKV->data.dptr;
void *devPtrdQ = devPtrdQKV;
void *devPtrdK = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + stride);
void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + 2 * stride);
void *devPtrSoftmaxStats = nullptr;
devPtrSoftmaxStats = output_S->data.dptr;
void* devPtrDropoutSeed = rng_state->data.dptr;
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
const auto qkv_type = input_QKV->data.dtype;
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_bwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim,
attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
devPtrdQ, devPtrdK, devPtrdV, devPtrdO,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(qkv_type),
workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
}
}
} // namespace transformer_engine
#endif // CUDNN_VERSION >= 8900
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file fused_attn_arbitrary_seqlen.h
* \brief Functions for fused attention with seqlen > 512
*/
#ifndef TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_
#define TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_
#include "transformer_engine/fused_attn.h"
#include <cudnn.h>
#include "common/common.h"
namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
size_t head_size, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV,
const Tensor *input_O,
const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
#endif // CUDNN_VERSION >= 8900
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_
@@ -4,7 +4,7 @@
 * See LICENSE for license information.
 ************************************************************************/
-#include "fused_attn_fp16_bf16_max_seqlen_512.h"
+#include "fused_attn_f16_max512_seqlen.h"
#include <cuda_bf16.h>
#include <cuda_fp16.h>
@@ -1239,7 +1239,7 @@ void fused_attn_max_512_fwd_qkvpacked(
size_t batch, size_t max_seqlen, size_t num_head, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
-NVTETensorPack *Aux_Output_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state,
+NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
@@ -1260,14 +1260,14 @@ void fused_attn_max_512_fwd_qkvpacked(
void *devPtrS = nullptr;
-if (Aux_Output_Tensors->size == 0) {
-Aux_Output_Tensors->size = 1;
-Tensor *output_S = reinterpret_cast<Tensor *>(Aux_Output_Tensors->tensors[0]);
+if (Aux_CTX_Tensors->size == 0) {
+Aux_CTX_Tensors->size = 1;
+Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, max_seqlen, max_seqlen};
output_S->data.dtype = input_QKV->data.dtype;
-} else if (Aux_Output_Tensors->size == 1) {
-Tensor *output_S = reinterpret_cast<Tensor *>(Aux_Output_Tensors->tensors[0]);
+} else if (Aux_CTX_Tensors->size == 1) {
+Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
}
@@ -1307,7 +1307,7 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O,
-NVTETensorPack *Aux_Output_Tensors, const Tensor *q_cu_seqlens,
+NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens,
const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
@@ -1336,14 +1336,14 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
const DType kv_type = input_KV->data.dtype;
NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV.");
-if (Aux_Output_Tensors->size == 0) {
-Aux_Output_Tensors->size = 1;
-Tensor *output_S = reinterpret_cast<Tensor *>(Aux_Output_Tensors->tensors[0]);
+if (Aux_CTX_Tensors->size == 0) {
+Aux_CTX_Tensors->size = 1;
+Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen};
output_S->data.dtype = q_type;
-} else if (Aux_Output_Tensors->size == 1) {
-Tensor *output_S = reinterpret_cast<Tensor *>(Aux_Output_Tensors->tensors[0]);
+} else if (Aux_CTX_Tensors->size == 1) {
+Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
}
@@ -1381,7 +1381,7 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV,
-const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors,
+const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
@@ -1408,12 +1408,8 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
void *devPtrdBias = output_dBias->data.dptr;
-NVTE_CHECK(Aux_CTX_Tensors->size == 1);
-void *devPtrS = nullptr;
-if (Aux_CTX_Tensors->size == 1) {
-Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
-devPtrS = output_S->data.dptr;
-}
+void *devPtrS = output_S->data.dptr;
// devPtrdS reuses the memory of devPtrS
void *devPtrdS = devPtrS;
@@ -1446,7 +1442,7 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
-const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors,
+const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
@@ -1472,12 +1468,8 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
void *devPtrdBias = output_dBias->data.dptr;
-NVTE_CHECK(Aux_CTX_Tensors->size == 1);
-void *devPtrS = nullptr;
-if (Aux_CTX_Tensors->size == 1) {
-Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
-devPtrS = output_S->data.dptr;
-}
+void *devPtrS = output_S->data.dptr;
// devPtrdS reuses the memory of devPtrS
void *devPtrdS = devPtrS;
...
@@ -24,7 +24,7 @@ void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_Bias,
-Tensor *output_O, NVTETensorPack *Aux_Output_Tensors,
+Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
@@ -34,7 +34,7 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O,
-NVTETensorPack *Aux_Output_Tensors, const Tensor *q_cu_seqlens,
+NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens,
const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
@@ -42,7 +42,7 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV,
-const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors,
+const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
@@ -52,7 +52,7 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
-const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors,
+const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
...
@@ -991,7 +991,7 @@ static cudnn_frontend::Tensor createdSQBMM(
}
// fused attention FWD FP8
-void fa_fwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
+void fused_attn_fp8_fwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
bool isTraining, float attnScale,
float dropoutProbability, NVTE_QKV_Layout layout,
void* devPtrQ, void* devPtrK, void* devPtrV,
@@ -1303,7 +1303,7 @@ void fa_fwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
}
// fused attention BWD FP8
-void fa_bwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
+void fused_attn_fp8_bwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
float attnScale, float dropoutProbability, NVTE_QKV_Layout layout,
void* devPtrQ, void* devPtrK, void* devPtrV,
void* devPtrM, void* devPtrZInv,
@@ -1858,7 +1858,7 @@ void fa_bwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
#if (CUDNN_VERSION >= 8900)
// fused attention FWD FP8 with packed QKV
-void fused_attn_fwd_fp8_qkvpacked(
+void fused_attn_fp8_fwd_qkvpacked(
size_t b, size_t max_seqlen,
size_t h, size_t d,
bool is_training, float attn_scale,
@@ -1866,7 +1866,7 @@ void fused_attn_fwd_fp8_qkvpacked(
const Tensor *input_QKV,
Tensor *input_output_S,
Tensor *output_O,
-NVTETensorPack* Aux_Output_Tensors,
+NVTETensorPack* Aux_CTX_Tensors,
const Tensor *cu_seqlens,
const Tensor *rng_state,
Tensor *workspace,
@@ -1888,23 +1888,29 @@ void fused_attn_fwd_fp8_qkvpacked(
void* devPtrM = nullptr;
void* devPtrZInv = nullptr;
-if (Aux_Output_Tensors->size == 0) {
+if (Aux_CTX_Tensors->size == 0) {
if (is_training) {
-Aux_Output_Tensors->size = 2;
-Tensor *output_M = reinterpret_cast<Tensor*>(Aux_Output_Tensors->tensors[0]);
-Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_Output_Tensors->tensors[1]);
+Aux_CTX_Tensors->size = 3;
+Tensor *output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
+Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
+Tensor *output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
output_M->data.dptr = nullptr;
output_M->data.shape = {b, h, max_seqlen, 1};
output_M->data.dtype = DType::kFloat32;
output_ZInv->data.dptr = nullptr;
output_ZInv->data.shape = {b, h, max_seqlen, 1};
output_ZInv->data.dtype = DType::kFloat32;
+output_rng_state->data.dptr = nullptr;
+output_rng_state->data.shape = {2};
+output_rng_state->data.dtype = DType::kInt64;
}
-} else if (Aux_Output_Tensors->size == 2) {
-Tensor *output_M = reinterpret_cast<Tensor*>(Aux_Output_Tensors->tensors[0]);
-Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_Output_Tensors->tensors[1]);
+} else if (Aux_CTX_Tensors->size == 3) {
+Tensor *output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
+Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
+Tensor *output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
devPtrM = output_M->data.dptr;
devPtrZInv = output_ZInv->data.dptr;
+output_rng_state->data.dptr = rng_state->data.dptr;
}
void* devPtrAmaxS = input_output_S->amax.dptr;
@@ -1921,7 +1927,7 @@ void fused_attn_fwd_fp8_qkvpacked(
const DType QKV_type = input_QKV->data.dtype;
size_t workspace_size = 0;
-fused_attn::fa_fwd_fp8(
+fused_attn::fused_attn_fp8_fwd_impl(
b, max_seqlen, max_seqlen, h, d,
is_training, attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
@@ -1948,7 +1954,7 @@ void fused_attn_fwd_fp8_qkvpacked(
}
}
// fused attention BWD FP8 with packed QKV
-void fused_attn_bwd_fp8_qkvpacked(
+void fused_attn_fp8_bwd_qkvpacked(
size_t b, size_t max_seqlen,
size_t h, size_t d,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
@@ -2011,7 +2017,7 @@ void fused_attn_bwd_fp8_qkvpacked(
const DType QKV_type = input_QKV->data.dtype;
size_t workspace_size = 0;
-fused_attn::fa_bwd_fp8(
+fused_attn::fused_attn_fp8_bwd_impl(
b, max_seqlen, max_seqlen, h, d,
attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
...
@@ -13,7 +13,7 @@
namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
// fused attention FWD FP8 with packed QKV
-void fused_attn_fwd_fp8_qkvpacked(
+void fused_attn_fp8_fwd_qkvpacked(
size_t b, size_t max_seqlen,
size_t h, size_t d,
bool is_training, float attn_scale,
@@ -21,7 +21,7 @@ void fused_attn_fwd_fp8_qkvpacked(
const Tensor *input_QKV,
Tensor *input_output_S,
Tensor *output_O,
-NVTETensorPack* Aux_Output_Tensors,
+NVTETensorPack* Aux_CTX_Tensors,
const Tensor *cu_seqlens,
const Tensor *rng_state,
Tensor *workspace,
@@ -29,7 +29,7 @@ void fused_attn_fwd_fp8_qkvpacked(
cudnnHandle_t handle);
// fused attention BWD FP8 with packed QKV
-void fused_attn_bwd_fp8_qkvpacked(
+void fused_attn_fp8_bwd_qkvpacked(
size_t b, size_t max_seqlen,
size_t h, size_t d,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
...
@@ -249,7 +249,6 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b,
kv_seqlens[tid] = kv_cu_seqlens[tid + 1] - kv_cu_seqlens[tid];
}
}
-}
} // namespace fused_attn
// get cuDNN data type
...
@@ -94,6 +94,38 @@ enum NVTE_Mask_Type {
NVTE_CAUSAL_MASK = 2,
};
enum NVTE_Fused_Attn_Backend {
/*! No supported backend */
NVTE_No_Backend = -1,
/*! cuDNN-based FP16/BF16 fused attention for <= 512 sequence length */
NVTE_F16_max512_seqlen = 0,
/*! cuDNN-based FP16/BF16 fused attention for any sequence length */
NVTE_F16_arbitrary_seqlen = 1,
/*! cuDNN-based FP8 fused attention for <= 512 sequence length */
NVTE_FP8 = 2,
};
/*! \brief Get fused attention backend based on input parameters.
*
* \param[in] q_dtype The data type of Tensor Q.
* \param[in] kv_dtype The data type of Tensors K, V.
* \param[in] qkv_layout The layout of Tensors Q, K, V.
* \param[in] bias_type The attention bias type.
* \param[in] attn_mask_type The attention mask type.
* \param[in] dropout The dropout probability.
* \param[in] max_seqlen_q The sequence length of Q.
* \param[in] max_seqlen_kv The sequence length of K, V.
* \param[in] head_dim The head dimension of Q, K, V.
*/
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTEDType q_dtype,
NVTEDType kv_dtype,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
float dropout, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim);
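A minimal usage sketch for this query (not part of the patch; the NVTEDType enumerator name kNVTEBFloat16 is an assumption, since the dtype enum is not shown here):

#include <cstdio>
#include "transformer_engine/fused_attn.h"

// Hypothetical probe: ask which fused attention backend would serve a
// BF16, causal, no-bias problem with 2048-token sequences and head_dim 64.
void print_selected_backend() {
    NVTE_Fused_Attn_Backend backend = nvte_get_fused_attn_backend(
        kNVTEBFloat16, kNVTEBFloat16,                 // assumed NVTEDType names
        NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED,
        NVTE_Bias_Type::NVTE_NO_BIAS,
        NVTE_Mask_Type::NVTE_CAUSAL_MASK,
        /*dropout=*/0.1f,
        /*max_seqlen_q=*/2048, /*max_seqlen_kv=*/2048,
        /*head_dim=*/64);
    if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
        std::printf("arbitrary-seqlen backend selected\n");
    }
}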
/*! \brief Compute dot product attention with packed QKV input.
 *
 * Computes:
@@ -104,9 +136,10 @@ enum NVTE_Mask_Type {
 *
 * Support Matrix:
 \verbatim
-    | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
-    | FP8       | QKV_INTERLEAVED | NO_BIAS                 | PADDING        | Yes     | <= 512          | 64       |
-    | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes     | <= 512          | 64       |
+    | backend | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
+    | 0       | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes     | <= 512          | 64       |
+    | 1       | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS                 | CAUSAL         | Yes     | > 512           | 64, 128  |
+    | 2       | FP8       | QKV_INTERLEAVED | NO_BIAS                 | PADDING        | Yes     | <= 512          | 64       |
 \endverbatim
 *
 * \param[in]     QKV              The QKV tensor in packed format,
@@ -114,11 +147,12 @@ enum NVTE_Mask_Type {
 * \param[in]     Bias             The Bias tensor.
 * \param[in,out] S                The S tensor.
 * \param[out]    O                The output O tensor.
-* \param[out]    Aux_Output_Tensors    Auxiliary output tensors when training, e.g. M, ZInv.
+* \param[out]    Aux_CTX_Tensors  Auxiliary output tensors when training,
+*                                 e.g. M, ZInv, rng_state.
 * \param[in]     cu_seqlens       Accumulative sequence lengths, [batch_size + 1].
 * \param[in]     rng_state        Seed and offset of CUDA random number generator.
-* \param[in]     max_seqlen       Max sequence length used for computing.
-*                                 It may be >= max(cu_seqlens).
+* \param[in]     max_seqlen       Max sequence length used for computing,
+*                                 it may be >= max(cu_seqlens).
 * \param[in]     is_training      Whether this is in training mode or inference.
 * \param[in]     attn_scale       Scaling factor for Q * K.T.
 * \param[in]     dropout          Dropout probability.
@@ -133,7 +167,7 @@ void nvte_fused_attn_fwd_qkvpacked(
const NVTETensor Bias,
NVTETensor S,
NVTETensor O,
-NVTETensorPack* Aux_Output_Tensors,
+NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens,
const NVTETensor rng_state,
size_t max_seqlen,
@@ -147,9 +181,10 @@ void nvte_fused_attn_fwd_qkvpacked(
 *
 * Support Matrix:
 \verbatim
-    | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
-    | FP8       | QKV_INTERLEAVED | NO_BIAS                 | PADDING        | Yes     | <= 512          | 64       |
-    | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes     | <= 512          | 64       |
+    | backend | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
+    | 0       | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes     | <= 512          | 64       |
+    | 1       | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS                 | CAUSAL         | Yes     | > 512           | 64, 128  |
+    | 2       | FP8       | QKV_INTERLEAVED | NO_BIAS                 | PADDING        | Yes     | <= 512          | 64       |
 \endverbatim
 *
 * \param[in]     QKV              The QKV tensor in packed format,
@@ -158,12 +193,13 @@ void nvte_fused_attn_fwd_qkvpacked(
 * \param[in]     dO               The gradient of the O tensor.
 * \param[in]     S                The S tensor.
 * \param[in,out] dP               The gradient of the P tensor.
-* \param[in]     Aux_CTX_Tensors  Auxiliary tensors from forward when in training mode.
+* \param[in]     Aux_CTX_Tensors  Auxiliary tensors from context when in training mode,
+*                                 e.g. M, ZInv, rng_state.
 * \param[out]    dQKV             The gradient of the QKV tensor.
 * \param[out]    dBias            The gradient of the Bias tensor.
 * \param[in]     cu_seqlens       Accumulative sequence lengths, [batch_size + 1].
-* \param[in]     max_seqlen       Max sequence length used for computing.
-*                                 It may be >= max(cu_seqlens).
+* \param[in]     max_seqlen       Max sequence length used for computing,
+*                                 it may be >= max(cu_seqlens).
 * \param[in]     attn_scale       Scaling factor for Q * K.T.
 * \param[in]     dropout          Dropout probability.
 * \param[in]     qkv_layout       QKV tensor's layout.
@@ -199,8 +235,8 @@ void nvte_fused_attn_bwd_qkvpacked(
 *
 * Support Matrix:
 \verbatim
-    | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
-    | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes     | <= 512          | 64       |
+    | backend | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
+    | 0       | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes     | <= 512          | 64       |
 \endverbatim
 *
 * \param[in]     Q                The Q tensor, [total_seqs_q, num_heads, head_dim].
@@ -208,14 +244,15 @@ void nvte_fused_attn_bwd_qkvpacked(
 * \param[in]     Bias             The Bias tensor.
 * \param[in,out] S                The S tensor.
 * \param[out]    O                The output O tensor.
-* \param[out]    Aux_Output_Tensors    Auxiliary output tensors when training, e.g. M, ZInv.
+* \param[out]    Aux_CTX_Tensors  Auxiliary output tensors when training,
+*                                 e.g. M, ZInv, rng_state.
 * \param[in]     cu_seqlens_q     Accumulative sequence lengths for Q, [batch_size + 1].
 * \param[in]     cu_seqlens_kv    Accumulative sequence lengths for KV, [batch_size + 1].
 * \param[in]     rng_state        Seed and offset of CUDA random number generator.
-* \param[in]     max_seqlen_q     Max sequence length used for computing
-*                                 for Q. It may be >= max(cu_seqlens_q).
+* \param[in]     max_seqlen_q     Max sequence length used for computing for Q.
+*                                 it may be >= max(cu_seqlens_q).
-* \param[in]     max_seqlen_kv    Max sequence length used for computing
-*                                 for KV. It may be >= max(cu_seqlens_kv).
+* \param[in]     max_seqlen_kv    Max sequence length used for computing for KV.
+*                                 it may be >= max(cu_seqlens_kv).
 * \param[in]     is_training      Whether this is in training mode or inference.
 * \param[in]     attn_scale       Scaling factor for Q * K.T.
 * \param[in]     dropout          Dropout probability.
@@ -231,7 +268,7 @@ void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Bias,
NVTETensor S,
NVTETensor O,
-NVTETensorPack* Aux_Output_Tensors,
+NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
const NVTETensor rng_state,
@@ -246,8 +283,8 @@ void nvte_fused_attn_fwd_kvpacked(
 *
 * Support Matrix:
 \verbatim
-    | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
-    | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes     | <= 512          | 64       |
+    | backend | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
+    | 0       | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes     | <= 512          | 64       |
 \endverbatim
 *
 * \param[in]     Q                The Q tensor, [total_seqs_q, num_heads, head_dim].
@@ -256,16 +293,17 @@ void nvte_fused_attn_fwd_kvpacked(
 * \param[in]     dO               The gradient of the O tensor.
 * \param[in]     S                The S tensor.
 * \param[in,out] dP               The gradient of the P tensor.
-* \param[in]     Aux_CTX_Tensors  Auxiliary tensors from forward when in training mode.
+* \param[in]     Aux_CTX_Tensors  Auxiliary tensors from context when in training mode,
+*                                 e.g. M, ZInv, rng_state.
 * \param[out]    dQ               The gradient of the Q tensor.
 * \param[out]    dKV              The gradient of the KV tensor.
 * \param[out]    dBias            The gradient of the Bias tensor.
 * \param[in]     cu_seqlens_q     Accumulative sequence lengths for Q, [batch_size + 1].
 * \param[in]     cu_seqlens_kv    Accumulative sequence lengths for KV, [batch_size + 1].
-* \param[in]     max_seqlen_q     Max sequence length used for computing
-*                                 for Q. It may be >= max(cu_seqlens_q).
+* \param[in]     max_seqlen_q     Max sequence length used for computing for Q.
+*                                 it may be >= max(cu_seqlens_q).
-* \param[in]     max_seqlen_kv    Max sequence length used for computing
-*                                 for KV. It may be >= max(cu_seqlens_kv).
+* \param[in]     max_seqlen_kv    Max sequence length used for computing for KV.
+*                                 it may be >= max(cu_seqlens_kv).
 * \param[in]     attn_scale       Scaling factor for Q * K.T.
 * \param[in]     dropout          Dropout probability.
 * \param[in]     qkv_layout       QKV tensor's layout.
...
...@@ -15,6 +15,16 @@ import torch ...@@ -15,6 +15,16 @@ import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_func from flash_attn.flash_attn_interface import flash_attn_unpadded_func
import transformer_engine_extensions as tex import transformer_engine_extensions as tex
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
fused_attn_fwd_qkvpacked,
fused_attn_bwd_qkvpacked,
fused_attn_fwd_kvpacked,
fused_attn_bwd_kvpacked,
QKVLayout,
AttnBiasType,
AttnMaskType,
FusedAttnBackend,
)
from transformer_engine.pytorch.module import LayerNormLinear, Linear from transformer_engine.pytorch.module import LayerNormLinear, Linear
from transformer_engine.pytorch.utils import ( from transformer_engine.pytorch.utils import (
divide, divide,
...@@ -26,6 +36,7 @@ from transformer_engine.pytorch.constants import ( ...@@ -26,6 +36,7 @@ from transformer_engine.pytorch.constants import (
AttnMaskTypes, AttnMaskTypes,
AttnTypes, AttnTypes,
dist_group_type, dist_group_type,
TE_DType,
) )
from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.distributed import ( from transformer_engine.pytorch.distributed import (
...@@ -272,9 +283,9 @@ class _PrepareQKVForFA(torch.autograd.Function): ...@@ -272,9 +283,9 @@ class _PrepareQKVForFA(torch.autograd.Function):
return dq, dk, dv return dq, dk, dv
def _check_if_interleaved(q, k, v): def _check_if_interleaved_qkv(q, k, v):
data_ptr = q.storage().data_ptr() data_ptr = q.untyped_storage().data_ptr()
check_ptrs = all(x.storage().data_ptr() == data_ptr for x in [q, k, v]) check_ptrs = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
if not check_ptrs: if not check_ptrs:
return False return False
...@@ -293,9 +304,32 @@ def _check_if_interleaved(q, k, v): ...@@ -293,9 +304,32 @@ def _check_if_interleaved(q, k, v):
for i, x in enumerate([q, k, v])) for i, x in enumerate([q, k, v]))
return check_offsets return check_offsets
def _check_if_interleaved_kv(k, v):
data_ptr = k.untyped_storage().data_ptr()
check_ptrs = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])
if not check_ptrs:
return False
stride = k.stride()
check_strides = all(stride == x.stride() for x in [k, v])
if not check_strides:
return False
shape = k.shape
check_shapes = all(shape == x.shape for x in [k, v])
if not check_shapes:
return False
last_dim_size = shape[-1]
check_offsets = all(i * last_dim_size == x.storage_offset()
for i, x in enumerate([k, v]))
return check_offsets
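As a quick illustration of the storage layout these two helpers accept, here is a small hypothetical sketch (not part of the patch): K and V views carved out of one freshly allocated buffer share storage, strides and shape, and are offset by head_dim in the packed dimension, so _check_if_interleaved_kv() returns True. It assumes this module's dependencies (e.g. the flash-attn package) are installed and PyTorch >= 2.0 for untyped_storage().

    import torch
    from transformer_engine.pytorch.attention import _check_if_interleaved_kv

    s, b, h, d = 8, 2, 4, 64
    kv = torch.empty(s, b, h, 2, d, dtype=torch.float16)   # one shared buffer
    k, v = kv[..., 0, :], kv[..., 1, :]                     # views into that buffer
    assert k.storage_offset() == 0 and v.storage_offset() == d
    assert _check_if_interleaved_kv(k, v)
    # Independently allocated tensors fail the data_ptr check:
    assert not _check_if_interleaved_kv(torch.empty_like(k), torch.empty_like(v))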
class FlashAttention(torch.nn.Module): class FlashAttention(torch.nn.Module):
"""Dot product attention implementation by using the flash-attn package. """Dot product attention, using HazyResearch flash-attn package:
https://github.com/HazyResearch/flash-attention
""" """
def __init__( def __init__(
...@@ -326,9 +360,9 @@ class FlashAttention(torch.nn.Module): ...@@ -326,9 +360,9 @@ class FlashAttention(torch.nn.Module):
"""flash-attn fprop""" """flash-attn fprop"""
assert ( assert (
(query_layer.dtype in [torch.float16, torch.bfloat16]) query_layer.dtype in [torch.float16, torch.bfloat16]
and (key_layer.dtype in [torch.float16, torch.bfloat16]) and key_layer.dtype in [torch.float16, torch.bfloat16]
and (value_layer.dtype in [torch.float16, torch.bfloat16]) and value_layer.dtype in [torch.float16, torch.bfloat16]
), 'FlashAttention currently only supports FP16 and BF16.' ), 'FlashAttention currently only supports FP16 and BF16.'
assert ( assert (
query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
...@@ -338,7 +372,7 @@ class FlashAttention(torch.nn.Module): ...@@ -338,7 +372,7 @@ class FlashAttention(torch.nn.Module):
if (query_layer.shape[-1] == 128 and if (query_layer.shape[-1] == 128 and
query_layer.shape[0] * query_layer.shape[1] >= 512 and query_layer.shape[0] * query_layer.shape[1] >= 512 and
_check_if_interleaved(query_layer, key_layer, value_layer)): _check_if_interleaved_qkv(query_layer, key_layer, value_layer)):
query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(query_layer, query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(query_layer,
key_layer, key_layer,
value_layer) value_layer)
...@@ -374,6 +408,286 @@ class FlashAttention(torch.nn.Module): ...@@ -374,6 +408,286 @@ class FlashAttention(torch.nn.Module):
return output.view(batch_size, seqlen, -1).transpose(0, 1).contiguous() return output.view(batch_size, seqlen, -1).transpose(0, 1).contiguous()
class FusedAttnFunc_qkvpacked(torch.autograd.Function):
"""Function for FusedAttention with packed QKV input"""
@staticmethod
def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias, attn_scale,
dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen, fused_attention_backend):
out, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype,
fused_attention_backend, attn_bias,
None, None, None, None, None,
attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen)
ctx.save_for_backward(qkv, out, cu_seqlens)
ctx.aux_ctx_tensors = aux_ctx_tensors
ctx.max_seqlen = max_seqlen
ctx.qkv_dtype = qkv_dtype
ctx.attn_scale = attn_scale
ctx.dropout_p = dropout_p
ctx.fast_zero_fill = fast_zero_fill
ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type
ctx.fused_attention_backend = fused_attention_backend
return out
@staticmethod
def backward(ctx, d_out):
qkv, out, cu_seqlens = ctx.saved_tensors
dqkv, *rest = fused_attn_bwd_qkvpacked(
ctx.max_seqlen, cu_seqlens, qkv, out, d_out,
ctx.qkv_dtype, ctx.aux_ctx_tensors,
ctx.fused_attention_backend,
None, None, None, None, None, None, None, None, None,
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
# if no_bias, return dqkv
if ctx.attn_bias_type == "no_bias":
return (None, None, None, dqkv, None, None, None,
None, None, None, None, None, None,
None, None, None, None, None, None)
# else, return (dqkv, dbias)
return (None, None, None, dqkv, None, rest[0], None,
None, None, None, None, None, None,
None, None, None, None, None, None)
class FusedAttnFunc_kvpacked(torch.autograd.Function):
"""Function for FusedAttention with packed KV input"""
@staticmethod
def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, kv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
qkv_layout, attn_bias_type, attn_mask_type,
rng_gen, fused_attention_backend):
out, aux_ctx_tensors = fused_attn_fwd_kvpacked(
is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, kv, qkv_dtype, fused_attention_backend, attn_bias,
None, None, None, None, None,
attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen)
ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv)
ctx.aux_ctx_tensors = aux_ctx_tensors
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_kv = max_seqlen_kv
ctx.qkv_dtype = qkv_dtype
ctx.attn_scale = attn_scale
ctx.dropout_p = dropout_p
ctx.fast_zero_fill = fast_zero_fill
ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type
ctx.fused_attention_backend = fused_attention_backend
return out
@staticmethod
def backward(ctx, d_out):
q, kv, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
dq, dkv, *rest = fused_attn_bwd_kvpacked(
ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, kv, out, d_out,
ctx.qkv_dtype, ctx.aux_ctx_tensors,
ctx.fused_attention_backend,
None, None, None, None, None, None, None, None, None,
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
# if no_bias, return (dq, dkv)
if ctx.attn_bias_type == "no_bias":
return (None, None, None, None, None, dq, dkv, None, None, None,
None, None, None, None, None, None,
None, None, None, None, None, None)
# else, return (dq, dkv, dbias)
return (None, None, None, None, None, dq, dkv, None, rest[0], None,
None, None, None, None, None, None,
None, None, None, None, None, None)
class FusedAttention(torch.nn.Module):
"""Dot product attention, with multiple backends:
1. FusedAttnBackend["F16_max512_seqlen"]
cuDNN based fused attention for FP16/BF16 and <=512 sequence length.
2. FusedAttnBackend["F16_arbitrary_seqlen"]
cuDNN based fused attention for FP16/BF16 and any sequence length.
Support matrix:
| backend | 1 | 2 |
| flash based | no | yes |
| cuDNN based | yes | yes |
| qkv dtype | fp16/bf16 | fp16/bf16 |
| attn_type | self/cross | self |
| qkv_layout | | |
| - qkv | qkv_interleaved | qkv_interleaved |
| - (q,kv) | kv_interleaved | |
| mask_type | causal/no_mask | causal |
| bias_type | no_bias/post_scale_bias | no_bias |
| dropout | yes | yes |
| max_seqlen | <=512 | any |
| head_dim | 64 | 64,128 |
| output dtype | fp16/bf16 | fp16/bf16 |
"""
def __init__(
self,
norm_factor: float,
attention_dropout: float = 0.0,
attention_dropout_ctx: Optional[Callable] = nullcontext,
attn_mask_type: str = "causal",
attention_type: str = "self",
) -> None:
super().__init__()
self.norm_factor = norm_factor
self.attention_dropout = attention_dropout
self.attention_dropout_ctx = attention_dropout_ctx
self.attn_mask_type = attn_mask_type
self.attention_type = attention_type
def forward(
self,
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True,
) -> torch.Tensor:
"""fused attention fprop"""
assert (
(query_layer.dtype in [torch.float16, torch.bfloat16])
and (key_layer.dtype in [torch.float16, torch.bfloat16])
and (value_layer.dtype in [torch.float16, torch.bfloat16])
), 'FusedAttention only supports FP16 and BF16 data types.'
assert (
query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
), 'FusedAttention only supports CUDA tensors.'
qkv_dtype = TE_DType[query_layer.dtype]
seqlen_q, batch_size = query_layer.shape[0], query_layer.shape[1]
seqlen_kv = key_layer.shape[0]
max_seqlen_q = seqlen_q
max_seqlen_kv = seqlen_kv
if self.attention_type == "self":
if _check_if_interleaved_qkv(query_layer, key_layer, value_layer):
query_layer = query_layer.unsqueeze(3)
key_layer = key_layer.unsqueeze(3)
value_layer = value_layer.unsqueeze(3)
# [s, b, h, 3, d]
mixed_layer = torch.cat([query_layer, key_layer, value_layer], dim = 3)
# [b, s, 3, h, d]
mixed_layer = mixed_layer.transpose(2, 3).transpose(0, 1).contiguous()
else:
query_layer = query_layer.unsqueeze(2)
key_layer = key_layer.unsqueeze(2)
value_layer = value_layer.unsqueeze(2)
# [s, b, 3, h, d]
mixed_layer = torch.cat([query_layer, key_layer, value_layer], dim = 2)
# [b, s, 3, h, d]
mixed_layer = mixed_layer.transpose(0, 1).contiguous()
# [total_seqs, 3, h, d]
mixed_layer = mixed_layer.view(
mixed_layer.shape[0] * mixed_layer.shape[1], *mixed_layer.shape[2:]).contiguous()
qkv_layout = "qkv_interleaved"
max_seqlen = seqlen_q
cu_seqlens = torch.arange(
0,
(batch_size + 1) * seqlen_q,
step=seqlen_q,
dtype=torch.int32,
device=query_layer.device)
with self.attention_dropout_ctx():
output = FusedAttnFunc_qkvpacked.apply(
self.training,
max_seqlen,
cu_seqlens,
mixed_layer,
qkv_dtype,
core_attention_bias,
1.0/self.norm_factor,
self.attention_dropout if self.training else 0.0,
fast_zero_fill,
qkv_layout,
core_attention_bias_type,
self.attn_mask_type,
None, # rng_gen
fused_attention_backend,
)
output = output.view(batch_size, seqlen_q, -1).transpose(0, 1).contiguous()
if self.attention_type == "cross":
if _check_if_interleaved_kv(key_layer, value_layer):
# [s, b, h, 2, d]
key_layer = key_layer.unsqueeze(3)
value_layer = value_layer.unsqueeze(3)
key_value = torch.cat([key_layer, value_layer], dim = 3)
# [b, s, 2, h, d]
key_value = key_value.transpose(2, 3).transpose(0, 1).contiguous()
else:
# [s, b, 2, h, d]
key_layer = key_layer.unsqueeze(2)
value_layer = value_layer.unsqueeze(2)
key_value = torch.cat([key_layer, value_layer], dim = 2)
# [b, s, 2, h, d]
key_value = key_value.transpose(0, 1).contiguous()
# [total_seqs, 2, h, d]
query_layer = query_layer.transpose(0, 1).contiguous()
query_layer = query_layer.view(
query_layer.shape[0] * query_layer.shape[1], *query_layer.shape[2:])
key_value = key_value.view(key_value.shape[0] * key_value.shape[1],
*key_value.shape[2:]).contiguous()
qkv_layout = "kv_interleaved"
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * seqlen_q,
step=seqlen_q,
dtype=torch.int32,
device=query_layer.device)
cu_seqlens_kv = torch.arange(
0,
(batch_size + 1) * seqlen_kv,
step=seqlen_kv,
dtype=torch.int32,
device=key_layer.device)
with self.attention_dropout_ctx():
outputs = FusedAttnFunc_kvpacked.apply(
self.training,
max_seqlen_q, max_seqlen_kv,
cu_seqlens_q, cu_seqlens_kv,
query_layer, key_value,
qkv_dtype,
core_attention_bias,
1.0/self.norm_factor,
self.attention_dropout if self.training else 0.0,
fast_zero_fill,
qkv_layout,
core_attention_bias_type,
self.attn_mask_type,
None, # rng_gen
fused_attention_backend,
)
output = (outputs[0].view(batch_size, seqlen_q, -1).transpose(0, 1).contiguous(),
outputs[1].view(batch_size, seqlen_q, -1).transpose(0, 1).contiguous())
return output
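A minimal self-attention usage sketch for this module (illustrative only, not from the patch): it queries the backend the same way DotProductAttention.forward does below, and assumes an sm80+ GPU, FP16 inputs, and the arbitrary shapes shown.

    import math
    import torch
    import transformer_engine_extensions as tex
    from transformer_engine.pytorch.attention import FusedAttention
    from transformer_engine.pytorch.constants import TE_DType
    from transformer_engine.pytorch.cpp_extensions.fused_attn import (
        QKVLayout, AttnBiasType, AttnMaskType, FusedAttnBackend)

    s, b, h, d = 512, 2, 16, 64
    q = torch.randn(s, b, h, d, dtype=torch.float16, device="cuda")
    k, v = torch.randn_like(q), torch.randn_like(q)

    # Ask the C++ side which fused attention backend (if any) supports this input combination.
    backend = tex.get_fused_attn_backend(
        TE_DType[q.dtype], TE_DType[k.dtype],
        QKVLayout["qkv_interleaved"], AttnBiasType["no_bias"], AttnMaskType["causal"],
        0.1, s, s, d)

    if backend != FusedAttnBackend["No_Backend"]:
        attn = FusedAttention(norm_factor=math.sqrt(d), attention_dropout=0.1)
        out = attn(q, k, v, backend)   # [s, b, h * d]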
class DotProductAttention(torch.nn.Module): class DotProductAttention(torch.nn.Module):
"""Allows the model to jointly attend to information from different """Allows the model to jointly attend to information from different
representation subspaces as described in the paper: representation subspaces as described in the paper:
...@@ -427,15 +741,16 @@ class DotProductAttention(torch.nn.Module): ...@@ -427,15 +741,16 @@ class DotProductAttention(torch.nn.Module):
get_rng_state_tracker: Optional[Callable] = None, get_rng_state_tracker: Optional[Callable] = None,
tp_group: Optional[dist_group_type] = None, tp_group: Optional[dist_group_type] = None,
layer_number: Optional[int] = None, layer_number: Optional[int] = None,
attention_type: str = "self",
) -> None: ) -> None:
super().__init__() super().__init__()
tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group) self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
self.tp_group = tp_group self.tp_group = tp_group
self.get_rng_state_tracker = get_rng_state_tracker self.get_rng_state_tracker = get_rng_state_tracker
projection_size = kv_channels * num_attention_heads projection_size = kv_channels * num_attention_heads
self.hidden_size_per_partition = divide(projection_size, tp_size) self.hidden_size_per_partition = divide(projection_size, self.tp_size)
self.hidden_size_per_attention_head = divide( self.hidden_size_per_attention_head = divide(
projection_size, num_attention_heads projection_size, num_attention_heads
) )
...@@ -452,18 +767,28 @@ class DotProductAttention(torch.nn.Module): ...@@ -452,18 +767,28 @@ class DotProductAttention(torch.nn.Module):
int(os.getenv("NVTE_FLASH_ATTN", "1")) int(os.getenv("NVTE_FLASH_ATTN", "1"))
and self.device_compute_capability >= 8.0 and self.device_compute_capability >= 8.0
) )
self.use_fused_attention = (
int(os.getenv("NVTE_FUSED_ATTN", "1"))
and self.device_compute_capability >= 8.0
)
attn_kwargs = { attn_kwargs = {
"attention_dropout": attention_dropout, "attention_dropout": attention_dropout,
"attention_dropout_ctx": attention_dropout_ctx, "attention_dropout_ctx": attention_dropout_ctx,
"attn_mask_type": attn_mask_type, "attn_mask_type": attn_mask_type,
} }
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type self.attn_mask_type = attn_mask_type
self.attention_dropout = attention_dropout
if self.use_flash_attention: if self.use_flash_attention:
self.flash_attention = FlashAttention(norm_factor, **attn_kwargs) self.flash_attention = FlashAttention(norm_factor, **attn_kwargs)
# Instantiating both types since use of flash-attn # Instantiating three types since use of flash-attn and FusedAttention
# might be ruled out due to forward inputs. # might be ruled out due to forward inputs.
if self.use_fused_attention:
self.fused_attention = FusedAttention(
norm_factor, **attn_kwargs,
attention_type = attention_type)
self.unfused_attention = UnfusedDotProductAttention( self.unfused_attention = UnfusedDotProductAttention(
norm_factor, **attn_kwargs, layer_number=layer_number) norm_factor, **attn_kwargs, layer_number=layer_number)
...@@ -494,6 +819,9 @@ class DotProductAttention(torch.nn.Module): ...@@ -494,6 +819,9 @@ class DotProductAttention(torch.nn.Module):
value_layer: torch.Tensor, value_layer: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
checkpoint_core_attention: bool = False, checkpoint_core_attention: bool = False,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Dot Product Attention Layer. Dot Product Attention Layer.
...@@ -511,6 +839,17 @@ class DotProductAttention(torch.nn.Module): ...@@ -511,6 +839,17 @@ class DotProductAttention(torch.nn.Module):
(:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads` (:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`
* :attr:`kv_channels`) is returned. * :attr:`kv_channels`) is returned.
.. note::
`DotProductAttention` supports three backends: 1) `FlashAttention` which calls
HazyResearch's FlashAttention PyTorch API, 2) `FusedAttention` which has multiple
fused attention implementations as its backends (see `FusedAttention` for
more details), and 3) `UnfusedDotProductAttention` which is the native PyTorch
implementation with fused scaled masked softmax. The environment variables
`NVTE_FLASH_ATTN`, `NVTE_FUSED_ATTN`, and `NVTE_FUSED_ATTN_BACKEND` control which
`DotProductAttention` backend is used and, if applicable, which `FusedAttention`
backend is used (see the environment-variable sketch below).
The default `DotProductAttention` backend is 1 (`FlashAttention`).
Parameters Parameters
---------- ----------
query_layer : torch.Tensor query_layer : torch.Tensor
...@@ -526,9 +865,17 @@ class DotProductAttention(torch.nn.Module): ...@@ -526,9 +865,17 @@ class DotProductAttention(torch.nn.Module):
during the backward pass in order to save memory that would during the backward pass in order to save memory that would
otherwise be occupied to store the forward activations until otherwise be occupied to store the forward activations until
backprop. backprop.
core_attention_bias_type: str, default = `no_bias`
Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`}
core_attention_bias: Optional[torch.Tensor], default = `None`
Bias tensor for Q * K.T
fast_zero_fill: bool, default = `True`
Whether to use the fast path to set output tensors to 0.
""" """
use_flash_attention = self.use_flash_attention use_flash_attention = self.use_flash_attention
use_fused_attention = self.use_fused_attention
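The environment-variable sketch referenced in the docstring note above (hypothetical, not part of the patch): the variables must be set before DotProductAttention is constructed, because they are read via os.getenv in __init__.

    import os
    os.environ["NVTE_FLASH_ATTN"] = "0"   # rule out backend 1 (FlashAttention)
    os.environ["NVTE_FUSED_ATTN"] = "1"   # keep backend 2 (FusedAttention) as a candidate
    # NVTE_FUSED_ATTN_BACKEND can additionally pin a particular FusedAttention backend;
    # its accepted values are defined on the C++ side and are not shown in this file.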
if (query_layer.dtype not in [torch.bfloat16, torch.float16] if (query_layer.dtype not in [torch.bfloat16, torch.float16]
or key_layer.dtype not in [torch.bfloat16, torch.float16] or key_layer.dtype not in [torch.bfloat16, torch.float16]
or value_layer.dtype not in [torch.bfloat16, torch.float16] or value_layer.dtype not in [torch.bfloat16, torch.float16]
...@@ -538,9 +885,26 @@ class DotProductAttention(torch.nn.Module): ...@@ -538,9 +885,26 @@ class DotProductAttention(torch.nn.Module):
if self.attn_mask_type == "padding" and attention_mask is not None: if self.attn_mask_type == "padding" and attention_mask is not None:
use_flash_attention = False use_flash_attention = False
use_fused_attention = False
if is_in_onnx_export_mode(): if is_in_onnx_export_mode():
use_flash_attention = False use_flash_attention = False
use_fused_attention = False
qkv_layout = "qkv_interleaved" if self.attention_type == "self" else "kv_interleaved"
fused_attention_backend = tex.get_fused_attn_backend(
TE_DType[query_layer.dtype],
TE_DType[key_layer.dtype],
QKVLayout[qkv_layout],
AttnBiasType[core_attention_bias_type],
AttnMaskType[self.attn_mask_type],
self.attention_dropout,
query_layer.shape[0], key_layer.shape[0],
query_layer.shape[-1])
# DPA does not support FP8; for FP8, use cpp_extensions modules directly
is_backend_avail = (fused_attention_backend in
[FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]])
use_fused_attention = use_fused_attention and is_backend_avail
if use_flash_attention: if use_flash_attention:
if checkpoint_core_attention: if checkpoint_core_attention:
...@@ -550,6 +914,22 @@ class DotProductAttention(torch.nn.Module): ...@@ -550,6 +914,22 @@ class DotProductAttention(torch.nn.Module):
value_layer) value_layer)
return self.flash_attention(query_layer, key_layer, value_layer) return self.flash_attention(query_layer, key_layer, value_layer)
if use_fused_attention:
if checkpoint_core_attention:
return self._checkpointed_attention_forward(self.fused_attention,
query_layer,
key_layer,
value_layer,
fused_attention_backend,
core_attention_bias_type,
core_attention_bias,
fast_zero_fill)
return self.fused_attention(query_layer, key_layer, value_layer,
fused_attention_backend,
core_attention_bias_type,
core_attention_bias,
fast_zero_fill)
if checkpoint_core_attention: if checkpoint_core_attention:
return self._checkpointed_attention_forward( return self._checkpointed_attention_forward(
self.unfused_attention, self.unfused_attention,
...@@ -752,6 +1132,9 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -752,6 +1132,9 @@ class MultiHeadAttention(torch.nn.Module):
checkpoint_core_attention: bool = False, checkpoint_core_attention: bool = False,
inference_params: Optional[Any] = None, inference_params: Optional[Any] = None,
rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True,
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
"""MultiHeadAttention FWD""" """MultiHeadAttention FWD"""
# hidden_states: [sq, b, h] # hidden_states: [sq, b, h]
...@@ -952,7 +1335,10 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -952,7 +1335,10 @@ class MultiHeadAttention(torch.nn.Module):
key_layer, key_layer,
value_layer, value_layer,
attention_mask, attention_mask,
checkpoint_core_attention=checkpoint_core_attention, checkpoint_core_attention = checkpoint_core_attention,
core_attention_bias_type = core_attention_bias_type,
core_attention_bias = core_attention_bias,
fast_zero_fill = fast_zero_fill,
) )
# ================= # =================
......
...@@ -22,7 +22,7 @@ TE_DType = { ...@@ -22,7 +22,7 @@ TE_DType = {
torch.bfloat16: tex.DType.kBFloat16, torch.bfloat16: tex.DType.kBFloat16,
} }
AttnMaskTypes = ("causal", "padding") AttnMaskTypes = ("causal", "padding", "no_mask")
AttnTypes = ("self", "cross") AttnTypes = ("self", "cross")
......
...@@ -7,6 +7,12 @@ import math ...@@ -7,6 +7,12 @@ import math
from typing import Tuple, List, Union from typing import Tuple, List, Union
import torch import torch
import transformer_engine_extensions as tex import transformer_engine_extensions as tex
from transformer_engine_extensions import (
NVTE_QKV_Layout,
NVTE_Bias_Type,
NVTE_Mask_Type,
NVTE_Fused_Attn_Backend
)
__all__ = ['fused_attn_fwd_qkvpacked', __all__ = ['fused_attn_fwd_qkvpacked',
...@@ -24,6 +30,34 @@ TORCH_DType = { ...@@ -24,6 +30,34 @@ TORCH_DType = {
tex.DType.kInt32: torch.int32, tex.DType.kInt32: torch.int32,
} }
QKVLayout = {
"not_interleaved": NVTE_QKV_Layout.NVTE_NOT_INTERLEAVED,
"qkv_interleaved": NVTE_QKV_Layout.NVTE_QKV_INTERLEAVED,
"kv_interleaved": NVTE_QKV_Layout.NVTE_KV_INTERLEAVED,
}
AttnBiasType = {
"no_bias": NVTE_Bias_Type.NVTE_NO_BIAS,
"pre_scale_bias": NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS,
"post_scale_bias": NVTE_Bias_Type.NVTE_POST_SCALE_BIAS,
}
AttnMaskType = {
"no_mask": NVTE_Mask_Type.NVTE_NO_MASK,
"padding": NVTE_Mask_Type.NVTE_PADDING_MASK,
"causal": NVTE_Mask_Type.NVTE_CAUSAL_MASK,
}
FusedAttnBackend = {
"F16_max512_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen,
"F16_arbitrary_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
"FP8": NVTE_Fused_Attn_Backend.NVTE_FP8,
"No_Backend": NVTE_Fused_Attn_Backend.NVTE_No_Backend,
}
BACKEND_F16m512_FP8_THREADS_PER_CTA = 128
BACKEND_F16arb_ELTS_PER_THREADS = 16
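A worked example of how these constants feed the rng_elts_per_thread computation further down (a sketch that simply mirrors the arithmetic used in fused_attn_fwd_qkvpacked/_kvpacked):

    max_seqlen = 512
    threads_per_cta = 128   # BACKEND_F16m512_FP8_THREADS_PER_CTA
    # F16_max512_seqlen and FP8 backends: ceil(max_seqlen * max_seqlen / threads_per_cta)
    rng_elts_per_thread = (max_seqlen * max_seqlen + threads_per_cta - 1) // threads_per_cta
    assert rng_elts_per_thread == 2048
    # F16_arbitrary_seqlen backend: a fixed BACKEND_F16arb_ELTS_PER_THREADS (= 16) per thread.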
def check_tensor(x: torch.Tensor): def check_tensor(x: torch.Tensor):
"""Check tensor properties.""" """Check tensor properties."""
...@@ -109,7 +143,8 @@ def fused_attn_fwd_qkvpacked( ...@@ -109,7 +143,8 @@ def fused_attn_fwd_qkvpacked(
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
qkv: torch.Tensor, qkv: torch.Tensor,
qkv_dtype: tex.DType, qkv_dtype: tex.DType,
bias: torch.Tensor = None, fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None, d_scale_qkv: torch.Tensor = None,
q_scale_s: torch.Tensor = None, q_scale_s: torch.Tensor = None,
q_scale_o: torch.Tensor = None, q_scale_o: torch.Tensor = None,
...@@ -117,9 +152,9 @@ def fused_attn_fwd_qkvpacked( ...@@ -117,9 +152,9 @@ def fused_attn_fwd_qkvpacked(
amax_o: torch.Tensor = None, amax_o: torch.Tensor = None,
attn_scale: float = None, attn_scale: float = None,
dropout: float = 0.0, dropout: float = 0.0,
set_zero: bool = True, fast_zero_fill: bool = True,
qkv_layout: str = "qkv_interleaved", qkv_layout: str = "qkv_interleaved",
bias_type: str = "no_bias", attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding", attn_mask_type: str = "padding",
rng_gen: torch.Generator = None, rng_gen: torch.Generator = None,
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
...@@ -139,8 +174,10 @@ def fused_attn_fwd_qkvpacked( ...@@ -139,8 +174,10 @@ def fused_attn_fwd_qkvpacked(
shape [total_seqs, 3, num_heads, head_dim], where total_seqs = cu_seqlens[-1] shape [total_seqs, 3, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
qkv_dtype: tex.DType qkv_dtype: tex.DType
data type of QKV; in tex.DType, not torch.dtype data type of QKV; in tex.DType, not torch.dtype
bias: torch.Tensor, default = None fused_attention_backend: tex.NVTE_Fused_Attn_Backend
input tensor Bias when bias_type is "pre_scale_bias" or "post_scale_bias"; please see FusedAttention module for details on supported backends.
attn_bias: torch.Tensor, default = None
input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv
d_scale_qkv: torch.Tensor, default = None d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations input tensor for the dequantization of QKV in FP8 computations
...@@ -158,12 +195,12 @@ def fused_attn_fwd_qkvpacked( ...@@ -158,12 +195,12 @@ def fused_attn_fwd_qkvpacked(
dropout: float, default = 0.0 dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output; dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False dropout must be 0.0 if is_training is False
set_zero: bool, default = True fast_zero_fill: bool, default = True
if True, initializes the output tensor O to zero using the mha_fill method; if True, initializes the output tensor O to zero using the fast filling method;
if False, doesn't initialize O after its allocation if False, uses PyTorch's .fill_() method
qkv_layout: str, default = "qkv_interleaved" qkv_layout: str, default = "qkv_interleaved"
layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"} layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
bias_type: str, default = "no_bias" attn_bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"} type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
attn_mask_type: str, default = "padding" attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "no_mask"} type of the attention mask; {"padding", "causal", "no_mask"}
...@@ -178,15 +215,26 @@ def fused_attn_fwd_qkvpacked( ...@@ -178,15 +215,26 @@ def fused_attn_fwd_qkvpacked(
shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1] shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
aux_ctx_tensors: List[torch.Tensor] aux_ctx_tensors: List[torch.Tensor]
auxiliary output tensors used for the backward; auxiliary output tensors used for the backward;
if is_training is True, aux_ctx_tensors = [M, ZInv, rng_state] if is_training is True, aux_ctx_tensors = [softmax-related tensors, rng_state]
if is_training is False, aux_ctx_tensors = [rng_state] if is_training is False, aux_ctx_tensors = None
softmax-related tensors:
1. if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
softmax: torch.Tensor
Softmax(Q*K.T)
shape [batch_size, num_heads, max_seqlen, max_seqlen], dtype float32
2. if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
softmaxStats: torch.Tensor
log(sum(e^(x - max(x)))), where x=Q*K.T
shape [batch_size, num_heads, max_seqlen, 1], dtype float32
3. if fused_attention_backend == FusedAttnBackend["FP8"]
M: torch.Tensor M: torch.Tensor
max(Q*K.T) max(Q*K.T)
shape [batch_size, num_heads, max_seqlen, 1], dtype float32 shape [batch_size, num_heads, max_seqlen, 1], dtype float32
ZInv: torch.Tensor ZInv: torch.Tensor
1/sum(e^(x - max(x))), where x=Q*K.T 1/sum(e^(x - max(x))), where x=Q*K.T
shape [batch_size, num_heads, max_seqlen, 1], dtype float32 shape [batch_size, num_heads, max_seqlen, 1], dtype float32
rng_state: torch.Tensor rng_state: torch.Tensor, optional; present only when the backend is not F16_max512_seqlen
state of the random number generator; state of the random number generator;
[seed, offset], dtype uint64 [seed, offset], dtype uint64
""" """
...@@ -203,60 +251,58 @@ def fused_attn_fwd_qkvpacked( ...@@ -203,60 +251,58 @@ def fused_attn_fwd_qkvpacked(
if attn_scale is None: if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d) attn_scale = 1.0 / math.sqrt(d)
if bias_type != "no_bias": if attn_bias_type != "no_bias":
assert bias is not None, "bias tensor cannot be None when bias_type is not no_bias." assert (attn_bias is not None
assert (bias.shape == [1, h, max_seqlen, max_seqlen] ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias."
), "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape." assert (attn_bias.shape == [1, h, max_seqlen, max_seqlen]
assert (bias.dtype == qkv.dtype ), "attn_bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
), "bias tensor must be in the same dtype as qkv." assert (attn_bias.dtype == qkv.dtype
), "attn_bias tensor must be in the same dtype as qkv."
# FP8 fused attention API
if (qkv_type is torch.uint8) and (max_seqlen <= 512) and (d == 64): assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
assert (qkv_layout == "qkv_interleaved" ), "Fused attention does not support this input combination."
and bias_type == "no_bias"
and attn_mask_type == "padding" # BF16/FP16 fused attention API from fmha_v1 apex
), """The FP8 fused attention API currently only supports qkv_interleaved layout, if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
no_bias type, and padding attention mask type.""" rng_elts_per_thread = (max_seqlen * max_seqlen
assert (d_scale_qkv is not None), "d_scale_qkv is required for the FP8 API." + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA
assert (q_scale_s is not None), "q_scale_s is required for the FP8 API."
assert (q_scale_o is not None), "q_scale_o is required for the FP8 API." # BF16/FP16 fused attention API from fmha_v2
assert (amax_s is not None), "amax_s is required for the FP8 API." if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
assert (amax_o is not None), "amax_o is required for the FP8 API." rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
# FP8 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["FP8"]:
rng_elts_per_thread = (max_seqlen * max_seqlen
+ BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA
assert (d_scale_qkv is not None
), "d_scale_qkv is required as an input for FP8 fused attention."
assert (q_scale_s is not None
), "q_scale_s is required as an input for FP8 fused attention."
assert (q_scale_o is not None
), "q_scale_o is required as an input for FP8 fused attention."
assert (amax_s is not None
), "amax_s is required as an input for FP8 fused attention."
assert (amax_o is not None
), "amax_o is required as an input for FP8 fused attention."
check_scalar(d_scale_qkv) check_scalar(d_scale_qkv)
check_scalar(q_scale_s) check_scalar(q_scale_s)
check_scalar(q_scale_o) check_scalar(q_scale_o)
check_scalar(amax_s) check_scalar(amax_s)
check_scalar(amax_o) check_scalar(amax_o)
# BF16/FP16 fused attention API from fmha_v2
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen > 512):
# add BF/FP16 support for >512 sequence length
assert False, "The BF16/FP16 support for >512 sequence length is coming!"
# BF16/FP16 fused attention API from fmha_v1 apex
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen <= 512):
# add BF/FP16 support for <=512 sequence length
assert False, "The BF16/FP16 support for <=512 sequence length is coming!"
else:
assert False, "No support for this dtype and max_seqlen combination."
# execute kernel # execute kernel
output_tensors = tex.fused_attn_fwd_qkvpacked( output_tensors = tex.fused_attn_fwd_qkvpacked(
b, max_seqlen, total_seqs, h, d, b, max_seqlen, total_seqs, h, d,
is_training, attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type, is_training, attn_scale, dropout, fast_zero_fill,
cu_seqlens, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
qkv, cu_seqlens, qkv, qkv_dtype,
qkv_dtype, d_scale_qkv, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias,
d_scale_qkv, rng_gen, rng_elts_per_thread,
q_scale_s,
q_scale_o,
amax_s,
amax_o,
bias,
rng_gen,
) )
# out, aux_ctx_tensors
return output_tensors[0], output_tensors[1:] return output_tensors[0], output_tensors[1:]
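A minimal call sketch for the forward function above (illustrative, not from the patch): it assumes an sm80+ GPU, FP16 inputs, and shapes for which the F16_max512_seqlen backend is available (b=2, s=512, h=16, d=64).

    import math
    import torch
    import transformer_engine_extensions as tex
    from transformer_engine.pytorch.cpp_extensions.fused_attn import (
        fused_attn_fwd_qkvpacked, FusedAttnBackend)

    b, s, h, d = 2, 512, 16, 64
    qkv = torch.randn(b * s, 3, h, d, dtype=torch.float16, device="cuda")
    cu_seqlens = torch.arange(0, (b + 1) * s, step=s, dtype=torch.int32, device="cuda")

    out, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
        True,                                  # is_training
        s, cu_seqlens, qkv, tex.DType.kFloat16,
        FusedAttnBackend["F16_max512_seqlen"],
        None,                                  # attn_bias (no_bias)
        None, None, None, None, None,          # FP8 scales/amaxes, unused here
        attn_scale=1.0 / math.sqrt(d),
        dropout=0.0,
        qkv_layout="qkv_interleaved",
        attn_bias_type="no_bias",
        attn_mask_type="causal",
    )
    # out: [b * s, h, d]; for this backend aux_ctx_tensors carries the softmax tensor
    # described in the Returns section above.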
...@@ -267,7 +313,8 @@ def fused_attn_bwd_qkvpacked( ...@@ -267,7 +313,8 @@ def fused_attn_bwd_qkvpacked(
o: torch.Tensor, o: torch.Tensor,
d_o: torch.Tensor, d_o: torch.Tensor,
qkv_dtype: tex.DType, qkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor] = None, aux_ctx_tensors: List[torch.Tensor],
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
d_scale_qkv: torch.Tensor = None, d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None, d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None, d_scale_o: torch.Tensor = None,
...@@ -279,9 +326,9 @@ def fused_attn_bwd_qkvpacked( ...@@ -279,9 +326,9 @@ def fused_attn_bwd_qkvpacked(
amax_dqkv: torch.Tensor = None, amax_dqkv: torch.Tensor = None,
attn_scale: float = None, attn_scale: float = None,
dropout: float = 0.0, dropout: float = 0.0,
set_zero: bool = True, fast_zero_fill: bool = True,
qkv_layout: str = "qkv_interleaved", qkv_layout: str = "qkv_interleaved",
bias_type: str = "no_bias", attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding", attn_mask_type: str = "padding",
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention BWD for packed QKV input. """Fused Attention BWD for packed QKV input.
...@@ -306,6 +353,8 @@ def fused_attn_bwd_qkvpacked( ...@@ -306,6 +353,8 @@ def fused_attn_bwd_qkvpacked(
aux_ctx_tensors: List[torch.Tensor] aux_ctx_tensors: List[torch.Tensor]
auxiliary output tensors of the forward pass when its is_training is True, auxiliary output tensors of the forward pass when its is_training is True,
e.g. aux_ctx_tensors = [M, ZInv, rng_state] e.g. aux_ctx_tensors = [M, ZInv, rng_state]
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends.
d_scale_qkv: torch.Tensor, default = None d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None d_scale_s: torch.Tensor, default = None
...@@ -330,12 +379,12 @@ def fused_attn_bwd_qkvpacked( ...@@ -330,12 +379,12 @@ def fused_attn_bwd_qkvpacked(
dropout: float, default = 0.0 dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output; dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False dropout must be 0.0 if is_training is False
set_zero: bool, default = True fast_zero_fill: bool, default = True
if True, initializes the output tensor O to zero using the mha_fill method; if True, initializes the output tensor O to zero using the fast filling method;
if False, doesn't initialize O after its allocation if False, uses PyTorch's .fill_() method
qkv_layout: str, default = "qkv_interleaved" qkv_layout: str, default = "qkv_interleaved"
layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"} layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
bias_type: str, default = "no_bias" attn_bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"} type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
attn_mask_type: str, default = "padding" attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "no_mask"} type of the attention mask; {"padding", "causal", "no_mask"}
...@@ -345,8 +394,8 @@ def fused_attn_bwd_qkvpacked( ...@@ -345,8 +394,8 @@ def fused_attn_bwd_qkvpacked(
d_qkv: torch.Tensor d_qkv: torch.Tensor
gradient tensor of QKV; same data type and shape as QKV gradient tensor of QKV; same data type and shape as QKV
d_bias: torch.Tensor, optional d_bias: torch.Tensor, optional
gradient tensor of Bias when bias_type is "pre_scale_bias" or "post_scale_bias"; gradient tensor of Bias when attn_bias_type is "pre_scale_bias"
same data type and shape as Bias or "post_scale_bias"; same data type and shape as Bias
""" """
check_cu_seqlens(cu_seqlens) check_cu_seqlens(cu_seqlens)
...@@ -363,29 +412,27 @@ def fused_attn_bwd_qkvpacked( ...@@ -363,29 +412,27 @@ def fused_attn_bwd_qkvpacked(
if attn_scale is None: if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d) attn_scale = 1.0 / math.sqrt(d)
assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]:
assert (len(aux_ctx_tensors) >= 1 assert (len(aux_ctx_tensors) >= 1
), "aux_ctx_tensors must contain rng_state as its last element." ), "aux_ctx_tensors must contain rng_state as its last element."
rng_state = aux_ctx_tensors[-1] rng_state = aux_ctx_tensors[-1]
check_rng_state(rng_state) check_rng_state(rng_state)
# FP8 fused attention API if fused_attention_backend == FusedAttnBackend["FP8"]:
if (qkv_type is torch.uint8) and (max_seqlen <= 512) and d == 64: assert (d_scale_qkv is not None), "d_scale_qkv is required for FP8 fused attention."
assert (qkv_layout == "qkv_interleaved" assert (d_scale_s is not None), "d_scale_s is required for FP8 fused attention."
and bias_type == "no_bias" assert (d_scale_o is not None), "d_scale_o is required for FP8 fused attention."
and attn_mask_type == "padding" assert (d_scale_do is not None), "d_scale_do is required for FP8 fused attention."
), """The FP8 fused attention API currently only supports qkv_interleaved layout, assert (q_scale_s is not None), "q_scale_s is required for FP8 fused attention."
no_bias type, and padding attention mask type.""" assert (q_scale_dp is not None), "q_scale_dp is required for FP8 fused attention."
assert (d_scale_qkv is not None), "d_scale_qkv is required for the FP8 API." assert (q_scale_dqkv is not None), "q_scale_dqkv is required for FP8 fused attention."
assert (d_scale_s is not None), "d_scale_s is required for the FP8 API." assert (amax_dp is not None), "amax_dp is required for FP8 fused attention."
assert (d_scale_o is not None), "d_scale_o is required for the FP8 API." assert (amax_dqkv is not None), "amax_dqkv is required for FP8 fused attention."
assert (d_scale_do is not None), "d_scale_do is required for the FP8 API."
assert (q_scale_s is not None), "q_scale_s is required for the FP8 API."
assert (q_scale_dp is not None), "q_scale_dp is required for the FP8 API."
assert (q_scale_dqkv is not None), "q_scale_dqkv is required for the FP8 API."
assert (amax_dp is not None), "amax_dp is required for the FP8 API."
assert (amax_dqkv is not None), "amax_dqkv is required for the FP8 API."
assert (len(aux_ctx_tensors) == 3 assert (len(aux_ctx_tensors) == 3
), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for the FP8 API." ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention."
check_scalar(d_scale_qkv) check_scalar(d_scale_qkv)
check_scalar(d_scale_s) check_scalar(d_scale_s)
check_scalar(d_scale_o) check_scalar(d_scale_o)
...@@ -399,37 +446,21 @@ def fused_attn_bwd_qkvpacked( ...@@ -399,37 +446,21 @@ def fused_attn_bwd_qkvpacked(
check_stats(m, b, h, max_seqlen) check_stats(m, b, h, max_seqlen)
check_stats(z_inv, b, h, max_seqlen) check_stats(z_inv, b, h, max_seqlen)
# BF16/FP16 fused attention API from fmha_v2
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen > 512):
# add BF/FP16 support for >512 sequence length
assert False, "The BF16/FP16 support for >512 sequence length is coming!"
# BF16/FP16 fused attention API from fmha_v1 apex
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen <= 512):
# add BF/FP16 support for <=512 sequence length
assert False, "The BF16/FP16 support for <=512 sequence length is coming!"
else:
assert False, "No support for this dtype and max_seqlen combination."
# execute kernel # execute kernel
output_tensors = tex.fused_attn_bwd_qkvpacked( output_tensors = tex.fused_attn_bwd_qkvpacked(
b, max_seqlen, total_seqs, h, d, b, max_seqlen, total_seqs, h, d,
attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type, attn_scale, dropout, fast_zero_fill,
cu_seqlens, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
qkv, o, d_o, cu_seqlens, qkv, o, d_o, qkv_dtype, aux_ctx_tensors,
qkv_dtype,
aux_ctx_tensors,
d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_qkv, d_scale_s, d_scale_o, d_scale_do,
q_scale_s, q_scale_dp, q_scale_dqkv, q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
amax_dp, amax_dqkv,
) )
if bias_type == "no_bias": if attn_bias_type == "no_bias":
# return d_qkv when bias_type is no_bias # return d_qkv when attn_bias_type is no_bias
return output_tensors[0]
# otherwise return (d_qkv, d_bias)
return output_tensors return output_tensors
# otherwise return (d_qkv, d_bias)
return output_tensors[0], output_tensors[1]
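Continuing the forward sketch shown after fused_attn_fwd_qkvpacked (same illustrative shapes and assumptions), the saved aux_ctx_tensors are fed straight back into this backward function:

    from transformer_engine.pytorch.cpp_extensions.fused_attn import fused_attn_bwd_qkvpacked

    d_out = torch.randn_like(out)
    dqkv, *_ = fused_attn_bwd_qkvpacked(
        s, cu_seqlens, qkv, out, d_out, tex.DType.kFloat16,
        aux_ctx_tensors, FusedAttnBackend["F16_max512_seqlen"],
        None, None, None, None, None, None, None, None, None,   # FP8 scales/amaxes
        attn_scale=1.0 / math.sqrt(d),
        dropout=0.0,
        qkv_layout="qkv_interleaved",
        attn_bias_type="no_bias",
        attn_mask_type="causal",
    )
    # With attn_bias_type == "no_bias" only d_qkv is returned (see the return logic above).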
def fused_attn_fwd_kvpacked( def fused_attn_fwd_kvpacked(
...@@ -441,7 +472,8 @@ def fused_attn_fwd_kvpacked( ...@@ -441,7 +472,8 @@ def fused_attn_fwd_kvpacked(
q: torch.Tensor, q: torch.Tensor,
kv: torch.Tensor, kv: torch.Tensor,
qkv_dtype: tex.DType, qkv_dtype: tex.DType,
bias: torch.Tensor = None, fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None, d_scale_qkv: torch.Tensor = None,
q_scale_s: torch.Tensor = None, q_scale_s: torch.Tensor = None,
q_scale_o: torch.Tensor = None, q_scale_o: torch.Tensor = None,
...@@ -449,9 +481,9 @@ def fused_attn_fwd_kvpacked( ...@@ -449,9 +481,9 @@ def fused_attn_fwd_kvpacked(
amax_o: torch.Tensor = None, amax_o: torch.Tensor = None,
attn_scale: float = None, attn_scale: float = None,
dropout: float = 0.0, dropout: float = 0.0,
set_zero: bool = True, fast_zero_fill: bool = True,
qkv_layout: str = "qkv_interleaved", qkv_layout: str = "qkv_interleaved",
bias_type: str = "no_bias", attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding", attn_mask_type: str = "padding",
rng_gen: torch.Generator = None, rng_gen: torch.Generator = None,
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
...@@ -479,8 +511,10 @@ def fused_attn_fwd_kvpacked( ...@@ -479,8 +511,10 @@ def fused_attn_fwd_kvpacked(
where total_seqs_kv = cu_seqlens_kv[-1] where total_seqs_kv = cu_seqlens_kv[-1]
qkv_dtype: tex.DType qkv_dtype: tex.DType
data type of Q and KV; in tex.DType, not torch.dtype data type of Q and KV; in tex.DType, not torch.dtype
bias: torch.Tensor, default = None fused_attention_backend: tex.NVTE_Fused_Attn_Backend
input tensor Bias when bias_type is "pre_scale_bias" or "post_scale_bias"; please see FusedAttention module for details on supported backends.
attn_bias: torch.Tensor, default = None
input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv
d_scale_qkv: torch.Tensor, default = None d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations input tensor for the dequantization of QKV in FP8 computations
...@@ -498,12 +532,12 @@ def fused_attn_fwd_kvpacked( ...@@ -498,12 +532,12 @@ def fused_attn_fwd_kvpacked(
dropout: float, default = 0.0 dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output; dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False dropout must be 0.0 if is_training is False
set_zero: bool, default = True fast_zero_fill: bool, default = True
if True, initializes the output tensor O to zero using the mha_fill method; if True, initializes the output tensor O to zero using the fast filling method;
if False, doesn't initialize O after its allocation if False, uses PyTorch's .fill_() method
qkv_layout: str, default = "qkv_interleaved" qkv_layout: str, default = "qkv_interleaved"
layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"} layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
bias_type: str, default = "no_bias" attn_bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"} type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
attn_mask_type: str, default = "padding" attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "no_mask"} type of the attention mask; {"padding", "causal", "no_mask"}
...@@ -518,15 +552,26 @@ def fused_attn_fwd_kvpacked( ...@@ -518,15 +552,26 @@ def fused_attn_fwd_kvpacked(
shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1] shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
aux_ctx_tensors: List[torch.Tensor] aux_ctx_tensors: List[torch.Tensor]
auxiliary output tensors used for the backward; auxiliary output tensors used for the backward;
if is_training is True, aux_ctx_tensors = [M, ZInv, rng_state] if is_training is True, aux_ctx_tensors = [softmax-related tensors, rng_state]
if is_training is False, aux_ctx_tensors = [rng_state] if is_training is False, aux_ctx_tensors = None
softmax-related tensors:
1. if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
softmax: torch.Tensor
Softmax(Q*K.T)
shape [batch_size, num_heads, max_seqlen_q, max_seqlen_kv], dtype float32
2. if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
softmaxStats: torch.Tensor
log(sum(e^(x - max(x)))), where x=Q*K.T
shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32
3. if fused_attention_backend == FusedAttnBackend["FP8"]
M: torch.Tensor M: torch.Tensor
max(Q*K.T) max(Q*K.T)
shape [batch_size, num_heads, max_seqlen, 1], dtype float32 shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32
ZInv: torch.Tensor ZInv: torch.Tensor
1/sum(e^(x - max(x))), where x=Q*K.T 1/sum(e^(x - max(x))), where x=Q*K.T
shape [batch_size, num_heads, max_seqlen, 1], dtype float32 shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32
rng_state: torch.Tensor rng_state: torch.Tensor, optional; present only when the backend is not F16_max512_seqlen
state of the random number generator; state of the random number generator;
[seed, offset], dtype uint64 [seed, offset], dtype uint64
""" """
...@@ -551,49 +596,42 @@ def fused_attn_fwd_kvpacked( ...@@ -551,49 +596,42 @@ def fused_attn_fwd_kvpacked(
if attn_scale is None: if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d) attn_scale = 1.0 / math.sqrt(d)
if bias_type != "no_bias": if attn_bias_type != "no_bias":
assert bias is not None, "bias tensor cannot be None when bias_type is not no_bias." assert (attn_bias is not None
assert (bias.shape == [1, h, max_seqlen_q, max_seqlen_kv] ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias."
), "bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape." assert (attn_bias.shape == [1, h, max_seqlen_q, max_seqlen_kv]
assert (bias.dtype == q.dtype ), "attn_bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
), "bias tensor must be in the same dtype as q and kv." assert (attn_bias.dtype == q.dtype
), "attn_bias tensor must be in the same dtype as q and kv."
# FP8 fused attention API assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
if (qkv_type is torch.uint8) and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512) \ ), "Fused attention does not support this input combination."
and (d == 64):
assert False, "The FP8 fused attention API currently only supports packed QKV input."
# BF16/FP16 fused attention API from fmha_v2
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \
and (max_seqlen_q > 512) and (max_seqlen_kv > 512):
# add BF/FP16 support for >512 sequence length
assert False, "The BF16/FP16 support for >512 sequence length is coming!"
# BF16/FP16 fused attention API from fmha_v1 apex # BF16/FP16 fused attention API from fmha_v1 apex
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \ if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512): rng_elts_per_thread = (max_seqlen_q * max_seqlen_kv
# add BF/FP16 support for <=512 sequence length + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA
assert False, "The BF16/FP16 support for <=512 sequence length is coming!"
else: # BF16/FP16 fused attention API from fmha_v2
assert False, "No support for this dtype and max_seqlen combination." if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
# FP8 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["FP8"]:
rng_elts_per_thread = (max_seqlen_q * max_seqlen_q
+ BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA
# execute kernel # execute kernel
output_tensors = tex.fused_attn_fwd_kvpacked( output_tensors = tex.fused_attn_fwd_kvpacked(
b, max_seqlen_q, max_seqlen_kv, total_seqs_q, total_seqs_kv, h, d, b, max_seqlen_q, max_seqlen_kv, total_seqs_q, total_seqs_kv, h, d,
is_training, attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type, is_training, attn_scale, dropout, fast_zero_fill,
cu_seqlens_q, cu_seqlens_kv, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
q, kv, cu_seqlens_q, cu_seqlens_kv, q, kv, qkv_dtype,
qkv_dtype, d_scale_qkv, q_scale_s, q_scale_o, amax_s, amax_o,
d_scale_qkv, attn_bias, rng_gen, rng_elts_per_thread,
q_scale_s,
q_scale_o,
amax_s,
amax_o,
bias,
rng_gen,
) )
# out, aux_ctx_tensors
return output_tensors[0], output_tensors[1:] return output_tensors[0], output_tensors[1:]
...@@ -607,7 +645,8 @@ def fused_attn_bwd_kvpacked( ...@@ -607,7 +645,8 @@ def fused_attn_bwd_kvpacked(
    o: torch.Tensor,
    d_o: torch.Tensor,
    qkv_dtype: tex.DType,
    aux_ctx_tensors: List[torch.Tensor],
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
    d_scale_qkv: torch.Tensor = None,
    d_scale_s: torch.Tensor = None,
    d_scale_o: torch.Tensor = None,
@@ -619,9 +658,9 @@ def fused_attn_bwd_kvpacked(
    amax_dqkv: torch.Tensor = None,
    attn_scale: float = None,
    dropout: float = 0.0,
    fast_zero_fill: bool = True,
    qkv_layout: str = "qkv_interleaved",
    attn_bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
) -> Tuple[Union[torch.Tensor, None], ...]:
    """Fused Attention BWD for packed KV input.
@@ -654,6 +693,8 @@ def fused_attn_bwd_kvpacked(
    aux_ctx_tensors: List[torch.Tensor]
                auxiliary output tensors of the forward pass when is_training is True,
                e.g. aux_ctx_tensors = [M, ZInv, rng_state]
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend
                please see FusedAttention module for details on supported backends.
    d_scale_qkv: torch.Tensor, default = None
                input tensor for the dequantization of QKV in FP8 computations
    d_scale_s: torch.Tensor, default = None
@@ -679,12 +720,12 @@ def fused_attn_bwd_kvpacked(
    dropout: float, default = 0.0
                dropout probability, 0.0 means no dropout, 1.0 means no output;
                dropout must be 0.0 if is_training is False
    fast_zero_fill: bool, default = True
                if True, initializes the output tensor O to zero using the fast filling method;
                if False, uses PyTorch's .fill_() method
    qkv_layout: str, default = "qkv_interleaved"
                layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
    attn_bias_type: str, default = "no_bias"
                type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
    attn_mask_type: str, default = "padding"
                type of the attention mask; {"padding", "causal", "no_mask"}
@@ -696,8 +737,8 @@ def fused_attn_bwd_kvpacked(
    d_kv: torch.Tensor
                gradient tensor of KV; same data type and shape as KV
    d_bias: torch.Tensor, optional
                gradient tensor of Bias when attn_bias_type is "pre_scale_bias"
                or "post_scale_bias"; same data type and shape as Bias
    """
    check_cu_seqlens(cu_seqlens_q)
@@ -722,45 +763,52 @@ def fused_attn_bwd_kvpacked(
    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
            ), "Fused attention does not support this input combination."

    if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]:
        assert (len(aux_ctx_tensors) >= 1
                ), "aux_ctx_tensors must contain rng_state as its last element."
        rng_state = aux_ctx_tensors[-1]
        check_rng_state(rng_state)

    if fused_attention_backend == FusedAttnBackend["FP8"]:
        assert (d_scale_qkv is not None), "d_scale_qkv is required for FP8 fused attention."
        assert (d_scale_s is not None), "d_scale_s is required for FP8 fused attention."
        assert (d_scale_o is not None), "d_scale_o is required for FP8 fused attention."
        assert (d_scale_do is not None), "d_scale_do is required for FP8 fused attention."
        assert (q_scale_s is not None), "q_scale_s is required for FP8 fused attention."
        assert (q_scale_dp is not None), "q_scale_dp is required for FP8 fused attention."
        assert (q_scale_dqkv is not None), "q_scale_dqkv is required for FP8 fused attention."
        assert (amax_dp is not None), "amax_dp is required for FP8 fused attention."
        assert (amax_dqkv is not None), "amax_dqkv is required for FP8 fused attention."
        assert (len(aux_ctx_tensors) == 3
                ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention."
        check_scalar(d_scale_qkv)
        check_scalar(d_scale_s)
        check_scalar(d_scale_o)
        check_scalar(d_scale_do)
        check_scalar(q_scale_s)
        check_scalar(q_scale_dp)
        check_scalar(q_scale_dqkv)
        check_scalar(amax_dp)
        check_scalar(amax_dqkv)
        m, z_inv = aux_ctx_tensors[:2]
        check_stats(m, b, h, max_seqlen_q)
        check_stats(z_inv, b, h, max_seqlen_q)
    # execute kernel
    output_tensors = tex.fused_attn_bwd_kvpacked(
            b, max_seqlen_q, max_seqlen_kv, total_seqs_q, total_seqs_kv, h, d,
            attn_scale, dropout, fast_zero_fill,
            QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
            cu_seqlens_q, cu_seqlens_kv, q, kv, o, d_o, qkv_dtype, aux_ctx_tensors,
            d_scale_qkv, d_scale_s, d_scale_o, d_scale_do,
            q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
    )

    if attn_bias_type == "no_bias":
        # return (d_q, d_kv) when attn_bias_type is no_bias
        return output_tensors
    # otherwise return (d_q, d_kv), d_bias
    return output_tensors[:2], output_tensors[2]
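
For context, a minimal sketch of the string-to-enum lookup tables this call assumes (QKVLayout, AttnBiasType, AttnMaskType, FusedAttnBackend). The dictionary names match the identifiers used above, and the enum members come from the pybind11 registrations later in this change; the exact module-level definitions may differ in detail.

# A hedged sketch, assuming `tex` is the compiled transformer_engine_extensions module
# (aliased as elsewhere in this file); only the member names are taken from this change.
import transformer_engine_extensions as tex

QKVLayout = {
    "not_interleaved": tex.NVTE_QKV_Layout.NVTE_NOT_INTERLEAVED,
    "qkv_interleaved": tex.NVTE_QKV_Layout.NVTE_QKV_INTERLEAVED,
    "kv_interleaved": tex.NVTE_QKV_Layout.NVTE_KV_INTERLEAVED,
}
AttnBiasType = {
    "no_bias": tex.NVTE_Bias_Type.NVTE_NO_BIAS,
    "pre_scale_bias": tex.NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS,
    "post_scale_bias": tex.NVTE_Bias_Type.NVTE_POST_SCALE_BIAS,
}
AttnMaskType = {
    "no_mask": tex.NVTE_Mask_Type.NVTE_NO_MASK,
    "padding": tex.NVTE_Mask_Type.NVTE_PADDING_MASK,
    "causal": tex.NVTE_Mask_Type.NVTE_CAUSAL_MASK,
}
FusedAttnBackend = {
    "F16_max512_seqlen": tex.NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen,
    "F16_arbitrary_seqlen": tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
    "FP8": tex.NVTE_Fused_Attn_Backend.NVTE_FP8,
    "No_Backend": tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
}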
@@ -58,7 +58,10 @@ enum FP8FwdTensors {
  GEMM1_OUTPUT = 2,
  GEMM2_INPUT = 3,
  GEMM2_WEIGHT = 4,
  GEMM2_OUTPUT = 5,
  GEMM3_INPUT = 6,
  GEMM3_WEIGHT = 7,
  GEMM3_OUTPUT = 8
};

// Used as named indices on the `scale`, `scale_inv`,
@@ -67,7 +70,9 @@ enum FP8BwdTensors {
  GRAD_OUTPUT1 = 0,
  GRAD_INPUT1 = 1,
  GRAD_OUTPUT2 = 2,
  GRAD_INPUT2 = 3,
  GRAD_OUTPUT3 = 4,
  GRAD_INPUT3 = 5
};
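
A hedged illustration (not taken from this diff) of how these named indices are used from Python: they select per-GEMM rows in the FP8 metadata tensors (scale, scale_inv, amax_history). With the third GEMM added here, the forward metadata needs at least 9 slots and the backward metadata at least 6.

# Sketch under the assumption that the enums are exposed via the compiled
# transformer_engine_extensions module; the metadata tensors below are placeholders.
import torch
import transformer_engine_extensions as tex

scale = torch.ones(9)            # one scaling factor per FP8FwdTensors slot
amax_history = torch.zeros(1, 9) # amax history, one column per slot

idx = int(tex.FP8FwdTensors.GEMM3_INPUT)   # == 6 after this change
scale_gemm3_input = scale[idx]
assert int(tex.FP8BwdTensors.GRAD_INPUT3) == 5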
@@ -81,6 +86,9 @@ transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
inline at::ScalarType GetATenDType(transformer_engine::DType t) {
  switch (t) {
    case transformer_engine::DType::kInt32:
      return torch::kInt32;
    case transformer_engine::DType::kInt64:
      return torch::kInt64;
    case transformer_engine::DType::kFloat32:
      return at::kFloat;
    case transformer_engine::DType::kFloat16:
...
@@ -12,43 +12,21 @@
constexpr int block_size = 512;
constexpr int ctas_per_sm = 4;

// get the fused attention backend
NVTE_Fused_Attn_Backend get_fused_attn_backend(
                const transformer_engine::DType q_dtype,
                const transformer_engine::DType kv_dtype,
                NVTE_QKV_Layout qkv_layout,
                NVTE_Bias_Type bias_type,
                NVTE_Mask_Type attn_mask_type,
                float p_dropout, size_t max_seqlen_q,
                size_t max_seqlen_kv, size_t head_dim) {
  NVTE_Fused_Attn_Backend fused_attention_backend =
          nvte_get_fused_attn_backend(
                  static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype),
                  qkv_layout, bias_type, attn_mask_type,
                  p_dropout, max_seqlen_q, max_seqlen_kv, head_dim);
  return fused_attention_backend;
}
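
Because this helper is also exported to Python (see the pybind11 section below), callers can ask which backend supports a configuration before building any tensors. A hedged usage sketch, assuming the compiled extension module and its DType enum follow their usual names:

import transformer_engine_extensions as tex

backend = tex.get_fused_attn_backend(
    tex.DType.kBFloat16,                        # q dtype
    tex.DType.kBFloat16,                        # kv dtype
    tex.NVTE_QKV_Layout.NVTE_QKV_INTERLEAVED,
    tex.NVTE_Bias_Type.NVTE_NO_BIAS,
    tex.NVTE_Mask_Type.NVTE_CAUSAL_MASK,
    0.1,      # dropout probability
    2048,     # max_seqlen_q
    2048,     # max_seqlen_kv
    64,       # head_dim
)
if backend == tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend:
    print("no fused attention backend for this configuration")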
// fast zero-fills of tensors
@@ -103,10 +81,8 @@ __global__ void unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) {
// extract PhiloxCudaState from CUDA random number generator
at::PhiloxCudaState init_philox_state(
                at::CUDAGeneratorImpl* gen,
                size_t elts_per_thread) {
  at::PhiloxCudaState philox_args;
  std::lock_guard<std::mutex> lock(gen->mutex_);
  philox_args = gen->philox_cuda_state(elts_per_thread);
  return philox_args;
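
The number of random elements consumed per thread is now passed in by the caller instead of being derived from the sequence length inside this helper, since each backend has a different dropout pattern. A hedged sketch of what a caller might compute: the max512 formula mirrors the expression that used to live here (threads_per_cta = 128); the constant for the other backends is a hypothetical placeholder, not taken from this diff.

def rng_elts_per_thread(max_seqlen_q: int, max_seqlen_kv: int, backend: str) -> int:
    threads_per_cta = 128
    if backend == "F16_max512_seqlen":
        m = max(max_seqlen_q, max_seqlen_kv)
        return (m * m + threads_per_cta - 1) // threads_per_cta
    # other backends consume a fixed number of random elements per thread
    ARBITRARY_SEQLEN_ELTS_PER_THREAD = 16  # placeholder value, not from this diff
    return ARBITRARY_SEQLEN_ELTS_PER_THREAD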
@@ -117,7 +93,7 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
                size_t b, size_t max_seqlen, size_t total_seqs,
                size_t h, size_t d,
                bool is_training, float attn_scale, float p_dropout, bool set_zero,
                NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
                const at::Tensor cu_seqlens,
                const at::Tensor QKV,
                const transformer_engine::DType qkv_type,
@@ -127,15 +103,18 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
                c10::optional<at::Tensor> amax_S,
                c10::optional<at::Tensor> amax_O,
                const c10::optional<at::Tensor> Bias,
                const c10::optional<at::Generator> rng_gen,
                size_t rng_elts_per_thread) {
  using namespace transformer_engine;

  // create output tensor O
  auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
  auto O = torch::empty({static_cast<int64_t>(total_seqs),
                  static_cast<int64_t>(h), static_cast<int64_t>(d)}, options);
  if (set_zero && (h * d % block_size == 0)) {
    mha_fill(O, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)}));
  } else {
    O.fill_(0);
  }
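
The fast mha_fill path is only taken when each token's row (h * d elements) is a multiple of the 512-element block the fill kernel writes; otherwise the code falls back to a plain .fill_(0). A minimal sketch of that predicate, with illustrative sizes:

block_size = 512

def use_fast_zero_fill(fast_zero_fill: bool, h: int, d: int) -> bool:
    # Mirrors the C++ condition above: set_zero && (h * d % block_size == 0).
    return fast_zero_fill and (h * d) % block_size == 0

assert use_fast_zero_fill(True, h=16, d=64)       # 1024 % 512 == 0 -> fast path
assert not use_fast_zero_fill(True, h=12, d=80)   # 960 % 512 != 0 -> .fill_(0)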
  // construct NVTE tensors
@@ -166,7 +145,7 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
  } else {
    NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
  }
  if ((bias_type != NVTE_NO_BIAS) && (Bias.has_value())) {
    auto bias_shape = Bias.value().sizes().vec();
    std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
    te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape,
@@ -175,23 +154,16 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
  te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), {b+1},
                  DType::kInt32, nullptr, nullptr, nullptr);

  // extract random number generator seed and offset
  auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
                  rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
  at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
  auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
  unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
                  philox_args, static_cast<int64_t*>(rng_state.data_ptr()));
  auto te_rng_state = makeTransformerEngineTensor(rng_state);

  // create auxiliary output tensors
  NVTETensorPack nvte_aux_tensor_pack;
  nvte_tensor_pack_create(&nvte_aux_tensor_pack);
@@ -209,7 +181,7 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
                  te_rng_state.data(),
                  max_seqlen,
                  is_training, attn_scale, p_dropout,
                  qkv_layout, bias_type, attn_mask_type,
                  workspace.data(),
                  at::cuda::getCurrentCUDAStream());
@@ -219,10 +191,9 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
                  workspace_data.data_ptr(),
                  workspace.shape(), workspace.dtype());

  // output_tensors = [O, nvte_aux_tensor_pack.tensors]
  std::vector<at::Tensor> output_tensors;
  output_tensors.push_back(O);
  for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
    auto tensor = reinterpret_cast<transformer_engine::Tensor*>(nvte_aux_tensor_pack.tensors[i]);
    // allocate memory for nvte_aux_tensor_pack.tensors
@@ -230,9 +201,6 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
    output_tensors.push_back(output_tensor);
    tensor->data.dptr = output_tensor.data_ptr();
  }

  // execute the kernel
  nvte_fused_attn_fwd_qkvpacked(
@@ -245,14 +213,14 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
                  te_rng_state.data(),
                  max_seqlen,
                  is_training, attn_scale, p_dropout,
                  qkv_layout, bias_type, attn_mask_type,
                  workspace.data(),
                  at::cuda::getCurrentCUDAStream());

  // destroy tensor wrappers, but not allocated memory
  nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);

  // if training, [O, softmax-related tensors, rng_state]; if inference, [O]
  return output_tensors;
}
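
With this change the rng_state no longer gets appended separately; in training mode it travels inside the auxiliary tensor pack that follows O. A hedged sketch of the return convention only, with strings standing in for tensors:

def split_fwd_outputs(output_tensors):
    # output_tensors = [O, <softmax-related tensors>, rng_state] in training, [O] in inference.
    o = output_tensors[0]
    aux_ctx_tensors = output_tensors[1:]    # empty list in inference mode
    return o, aux_ctx_tensors

o, aux_ctx_tensors = split_fwd_outputs(["O", "M", "ZInv", "rng_state"])
assert aux_ctx_tensors[-1] == "rng_state"   # consumed later by the backward pass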
@@ -261,7 +229,7 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
                size_t b, size_t max_seqlen, size_t total_seqs,
                size_t h, size_t d,
                float attn_scale, float p_dropout, bool set_zero,
                NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
                const at::Tensor cu_seqlens,
                const at::Tensor QKV,
                const at::Tensor O,
@@ -281,13 +249,18 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
  // create output tensor dQKV
  at::Tensor dQKV = torch::empty_like(QKV);
  auto max_tokens = dQKV.size(0);
  auto self_2d = dQKV.view({max_tokens, -1});
  auto fcd_size = self_2d.size(1);
  if (set_zero && (fcd_size % block_size == 0)) {
    mha_fill(dQKV, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)}));
  } else {
    dQKV.fill_(0);
  }

  auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
  at::Tensor dBias;
  TensorWrapper te_dBias;
  if (bias_type != NVTE_NO_BIAS) {
    dBias = torch::zeros({1, static_cast<int64_t>(h),
                    static_cast<int64_t>(max_seqlen),
                    static_cast<int64_t>(max_seqlen)}, options);
@@ -341,13 +314,7 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
    NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
  }

  // convert auxiliary tensors from forward into NVTETensors
  NVTETensorPack nvte_aux_tensor_pack;
  nvte_tensor_pack_create(&nvte_aux_tensor_pack);
  nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size();
@@ -380,7 +347,7 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
                  te_cu_seqlens.data(),
                  max_seqlen,
                  attn_scale, p_dropout,
                  qkv_layout, bias_type, attn_mask_type,
                  workspace.data(),
                  at::cuda::getCurrentCUDAStream());
@@ -403,7 +370,7 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
                  te_cu_seqlens.data(),
                  max_seqlen,
                  attn_scale, p_dropout,
                  qkv_layout, bias_type, attn_mask_type,
                  workspace.data(),
                  at::cuda::getCurrentCUDAStream());
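
The dBias allocated above has shape [1, h, max_seqlen, max_seqlen]: one gradient value per head and per (query, key) position, broadcast over the batch. A hedged PyTorch reference for what that shape corresponds to, under the reading that "post_scale_bias" means the bias is added to the already-scaled Q*K^T scores before the softmax (an interpretation, not stated verbatim in this diff); all sizes are illustrative.

import torch

b, h, sq, sk, dhead = 2, 4, 8, 8, 16
q = torch.randn(b, h, sq, dhead)
k = torch.randn(b, h, sk, dhead)
bias = torch.randn(1, h, sq, sk, requires_grad=True)

scores = torch.einsum("bhqd,bhkd->bhqk", q, k) / dhead ** 0.5
probs = torch.softmax(scores + bias, dim=-1)   # bias applied after scaling
probs.sum().backward()
assert bias.grad.shape == (1, h, sq, sk)       # matches the dBias allocation above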
@@ -419,7 +386,7 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
                size_t total_seqs_q, size_t total_seqs_kv,
                size_t h, size_t d,
                bool is_training, float attn_scale, float p_dropout, bool set_zero,
                NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
                const at::Tensor cu_seqlens_q,
                const at::Tensor cu_seqlens_kv,
                const at::Tensor Q,
@@ -431,15 +398,18 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
                c10::optional<at::Tensor> amax_S,
                c10::optional<at::Tensor> amax_O,
                const c10::optional<at::Tensor> Bias,
                const c10::optional<at::Generator> rng_gen,
                size_t rng_elts_per_thread) {
  using namespace transformer_engine;

  // create output tensor O
  auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
  auto O = torch::empty({static_cast<int64_t>(total_seqs_q),
                  static_cast<int64_t>(h), static_cast<int64_t>(d)}, options);
  if (set_zero && (h * d % block_size == 0)) {
    mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
  } else {
    O.fill_(0);
  }

  // construct NVTE tensors
@@ -474,7 +444,7 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
  } else {
    NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
  }
  if ((bias_type != NVTE_NO_BIAS) && (Bias.has_value())) {
    auto bias_shape = Bias.value().sizes().vec();
    std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
    te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape,
@@ -485,24 +455,16 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
  te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), {b+1},
                  DType::kInt32, nullptr, nullptr, nullptr);

  // extract rng seed and offset
  auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
                  rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
  at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
  auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
  unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
                  philox_args, static_cast<int64_t*>(rng_state.data_ptr()));
  auto te_rng_state = makeTransformerEngineTensor(rng_state);

  // create auxiliary output tensors
  NVTETensorPack nvte_aux_tensor_pack;
  nvte_tensor_pack_create(&nvte_aux_tensor_pack);
@@ -522,7 +484,7 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
                  te_rng_state.data(),
                  max_seqlen_q, max_seqlen_kv,
                  is_training, attn_scale, p_dropout,
                  qkv_layout, bias_type, attn_mask_type,
                  workspace.data(),
                  at::cuda::getCurrentCUDAStream());
@@ -532,10 +494,9 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
                  workspace_data.data_ptr(),
                  workspace.shape(), workspace.dtype());

  // output_tensors = [O, nvte_aux_tensor_pack.tensors]
  std::vector<at::Tensor> output_tensors;
  output_tensors.push_back(O);
  for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
    auto tensor = reinterpret_cast<transformer_engine::Tensor*>(nvte_aux_tensor_pack.tensors[i]);
    // allocate memory for nvte_aux_tensor_pack.tensors
@@ -543,9 +504,6 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
    output_tensors.push_back(output_tensor);
    tensor->data.dptr = output_tensor.data_ptr();
  }

  // execute the kernel
  nvte_fused_attn_fwd_kvpacked(
@@ -560,14 +518,14 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
                  te_rng_state.data(),
                  max_seqlen_q, max_seqlen_kv,
                  is_training, attn_scale, p_dropout,
                  qkv_layout, bias_type, attn_mask_type,
                  workspace.data(),
                  at::cuda::getCurrentCUDAStream());

  // destroy tensor wrappers, but not allocated memory
  nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);

  // if training, [O, softmax-related tensors, rng_state]; if inference, [O]
  return output_tensors;
}
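
The packed Q/KV paths are driven by cumulative-sequence-length tensors of shape [b+1] and dtype int32, as constructed above with makeTransformerEngineTensor. A hedged sketch of how such tensors can be built from per-sample lengths; the lengths are illustrative, and whether total_seqs counts only valid tokens or b * max_seqlen of a padded layout is not settled by this diff alone.

import torch

seqlens_q = torch.tensor([384, 640, 1024], dtype=torch.int32)
seqlens_kv = torch.tensor([384, 640, 1024], dtype=torch.int32)

cu_seqlens_q = torch.cat([torch.zeros(1, dtype=torch.int32),
                          torch.cumsum(seqlens_q, dim=0, dtype=torch.int32)])
cu_seqlens_kv = torch.cat([torch.zeros(1, dtype=torch.int32),
                           torch.cumsum(seqlens_kv, dim=0, dtype=torch.int32)])

assert cu_seqlens_q.shape[0] == len(seqlens_q) + 1   # shape [b+1], as the C++ side expects
valid_tokens_q = int(cu_seqlens_q[-1])               # total number of non-padded q tokens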
@@ -577,7 +535,7 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
                size_t total_seqs_q, size_t total_seqs_kv,
                size_t h, size_t d,
                float attn_scale, float p_dropout, bool set_zero,
                NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
                const at::Tensor cu_seqlens_q,
                const at::Tensor cu_seqlens_kv,
                const at::Tensor Q,
@@ -600,14 +558,23 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
  // create output tensors dQ and dKV
  at::Tensor dQ = torch::empty_like(Q);
  at::Tensor dKV = torch::empty_like(KV);
  auto max_tokens_q = dQ.size(0);
  auto self_2d_q = dQ.view({max_tokens_q, -1});
  auto fcd_size_q = self_2d_q.size(1);
  auto max_tokens_kv = dKV.size(0);
  auto self_2d_kv = dKV.view({max_tokens_kv, -1});
  auto fcd_size_kv = self_2d_kv.size(1);
  if (set_zero && (fcd_size_q % block_size == 0) && (fcd_size_kv % block_size == 0)) {
    mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
    mha_fill(dKV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
  } else {
    dQ.fill_(0);
    dKV.fill_(0);
  }

  auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
  at::Tensor dBias;
  TensorWrapper te_dBias;
  if (bias_type != NVTE_NO_BIAS) {
    dBias = torch::zeros({1, static_cast<int64_t>(h),
                    static_cast<int64_t>(max_seqlen_q),
                    static_cast<int64_t>(max_seqlen_kv)}, options);
@@ -674,13 +641,7 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
  te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), {b+1},
                  DType::kInt32, nullptr, nullptr, nullptr);

  // convert auxiliary tensors from forward to NVTETensors
  NVTETensorPack nvte_aux_tensor_pack;
  nvte_tensor_pack_create(&nvte_aux_tensor_pack);
  nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size();
@@ -711,7 +672,7 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
                  te_cu_seqlens_kv.data(),
                  max_seqlen_q, max_seqlen_kv,
                  attn_scale, p_dropout,
                  qkv_layout, bias_type, attn_mask_type,
                  workspace.data(),
                  at::cuda::getCurrentCUDAStream());
@@ -737,7 +698,7 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
                  te_cu_seqlens_kv.data(),
                  max_seqlen_q, max_seqlen_kv,
                  attn_scale, p_dropout,
                  qkv_layout, bias_type, attn_mask_type,
                  workspace.data(),
                  at::cuda::getCurrentCUDAStream());
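
The two back-to-back calls to nvte_fused_attn_bwd_kvpacked above follow a query-then-execute pattern: the first call (with an empty workspace) only reports the workspace shape and dtype, the wrapper allocates it, and the second call does the real computation. A hedged, self-contained stand-in that demonstrates the pattern; the function names and return fields here are placeholders, not the library's API.

from types import SimpleNamespace

def fake_kernel(workspace):
    # Query pass: when no workspace is supplied, report the requirements instead of computing.
    if workspace is None:
        return SimpleNamespace(shape=(1024,), dtype="uint8")
    return "result"

req = fake_kernel(workspace=None)
buffer = bytearray(req.shape[0])       # stand-in for a torch.empty(...) allocation
result = fake_kernel(workspace=buffer) # execute pass: real work happens here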
@@ -2227,6 +2188,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("dswiglu", &dswiglu, "Backward of SwiGLU");
  m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention");
  m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention");
  m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend");

  // Misc
  m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version");
@@ -2279,11 +2241,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    .value("GEMM1_OUTPUT", transformer_engine::FP8FwdTensors::GEMM1_OUTPUT)
    .value("GEMM2_INPUT", transformer_engine::FP8FwdTensors::GEMM2_INPUT)
    .value("GEMM2_WEIGHT", transformer_engine::FP8FwdTensors::GEMM2_WEIGHT)
    .value("GEMM2_OUTPUT", transformer_engine::FP8FwdTensors::GEMM2_OUTPUT)
    .value("GEMM3_INPUT", transformer_engine::FP8FwdTensors::GEMM3_INPUT)
    .value("GEMM3_WEIGHT", transformer_engine::FP8FwdTensors::GEMM3_WEIGHT)
    .value("GEMM3_OUTPUT", transformer_engine::FP8FwdTensors::GEMM3_OUTPUT);

  py::enum_<transformer_engine::FP8BwdTensors>(m, "FP8BwdTensors")
    .value("GRAD_OUTPUT1", transformer_engine::FP8BwdTensors::GRAD_OUTPUT1)
    .value("GRAD_INPUT1", transformer_engine::FP8BwdTensors::GRAD_INPUT1)
    .value("GRAD_OUTPUT2", transformer_engine::FP8BwdTensors::GRAD_OUTPUT2)
    .value("GRAD_INPUT2", transformer_engine::FP8BwdTensors::GRAD_INPUT2)
    .value("GRAD_OUTPUT3", transformer_engine::FP8BwdTensors::GRAD_OUTPUT3)
    .value("GRAD_INPUT3", transformer_engine::FP8BwdTensors::GRAD_INPUT3);

  py::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type")
    .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS)
    .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS)
    .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);

  py::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type")
    .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
    .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
    .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK);

  py::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout")
    .value("NVTE_NOT_INTERLEAVED", NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED)
    .value("NVTE_QKV_INTERLEAVED", NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
    .value("NVTE_KV_INTERLEAVED", NVTE_QKV_Layout::NVTE_KV_INTERLEAVED);

  py::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend")
    .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)
    .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)
    .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8)
    .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend);
}
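
Once registered, these enums are plain Python objects, so higher-level code can branch on the selected backend instead of threading strings through every layer. A hedged sketch, again assuming the extension module is importable under its usual name:

import transformer_engine_extensions as tex

backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
use_fused_attention = backend != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
is_long_sequence_kernel = backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen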
@@ -7,17 +7,22 @@
#include "common.h"
#include "../common.h"

NVTE_Fused_Attn_Backend get_fused_attn_backend(
                const transformer_engine::DType q_dtype,
                const transformer_engine::DType kv_dtype,
                NVTE_QKV_Layout qkv_layout,
                NVTE_Bias_Type bias_type,
                NVTE_Mask_Type attn_mask_type,
                float p_dropout, size_t max_seqlen_q,
                size_t max_seqlen_kv, size_t head_dim);

std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
                size_t b, size_t max_seqlen, size_t total_seqs,
                size_t h, size_t d, bool is_training,
                float attn_scale, float p_dropout, bool set_zero,
                NVTE_QKV_Layout qkv_layout,
                NVTE_Bias_Type bias_type,
                NVTE_Mask_Type attn_mask_type,
                const at::Tensor cu_seqlens,
                const at::Tensor QKV,
                const transformer_engine::DType qkv_type,
@@ -27,13 +32,16 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
                c10::optional<at::Tensor> amax_S,
                c10::optional<at::Tensor> amax_O,
                const c10::optional<at::Tensor> Bias,
                const c10::optional<at::Generator> rng_gen,
                size_t rng_elts_per_thread);

std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
                size_t b, size_t max_seqlen, size_t total_seqs,
                size_t h, size_t d, float attn_scale,
                float p_dropout, bool set_zero,
                NVTE_QKV_Layout qkv_layout,
                NVTE_Bias_Type bias_type,
                NVTE_Mask_Type attn_mask_type,
                const at::Tensor cu_seqlens,
                const at::Tensor QKV,
                const at::Tensor O,
@@ -53,9 +61,11 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
std::vector<at::Tensor> fused_attn_fwd_kvpacked(
                size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
                size_t total_seqs_q, size_t total_seqs_kv,
                size_t h, size_t d, bool is_training,
                float attn_scale, float p_dropout, bool set_zero,
                NVTE_QKV_Layout qkv_layout,
                NVTE_Bias_Type bias_type,
                NVTE_Mask_Type attn_mask_type,
                const at::Tensor cu_seqlens_q,
                const at::Tensor cu_seqlens_kv,
                const at::Tensor Q,
@@ -67,14 +77,17 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
                c10::optional<at::Tensor> amax_S,
                c10::optional<at::Tensor> amax_O,
                const c10::optional<at::Tensor> Bias,
                const c10::optional<at::Generator> rng_gen,
                size_t rng_elts_per_thread);

std::vector<at::Tensor> fused_attn_bwd_kvpacked(
                size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
                size_t total_seqs_q, size_t total_seqs_kv,
                size_t h, size_t d, float attn_scale,
                float p_dropout, bool set_zero,
                NVTE_QKV_Layout qkv_layout,
                NVTE_Bias_Type bias_type,
                NVTE_Mask_Type attn_mask_type,
                const at::Tensor cu_seqlens_q,
                const at::Tensor cu_seqlens_kv,
                const at::Tensor Q,
...
@@ -403,6 +403,9 @@ class TransformerLayer(torch.nn.Module):
        checkpoint_core_attention: bool = False,
        inference_params: Optional[Any] = None,
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        fast_zero_fill: bool = True,
    ) -> torch.Tensor:
        """
        Transformer Layer: attention block and a feedforward network (MLP)
@@ -445,6 +448,12 @@ class TransformerLayer(torch.nn.Module):
        rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
                       Embeddings for query and key tensors for applying rotary position
                       embedding. By default no input embedding is applied.
        core_attention_bias_type: str, default = `no_bias`
                       Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`}
        core_attention_bias: Optional[torch.Tensor], default = `None`
                       Bias tensor for Q * K.T
        fast_zero_fill: bool, default = `True`
                       Whether to zero-initialize output tensors before use.
        """
        hidden_states = hidden_states.contiguous()
@@ -473,6 +482,9 @@ class TransformerLayer(torch.nn.Module):
                is_first_microbatch=is_first_microbatch,
                checkpoint_core_attention=checkpoint_core_attention,
                rotary_pos_emb=rotary_pos_emb,
                core_attention_bias_type=core_attention_bias_type,
                core_attention_bias=core_attention_bias,
                fast_zero_fill=fast_zero_fill,
            )

        if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
@@ -516,6 +528,9 @@ class TransformerLayer(torch.nn.Module):
                encoder_output=encoder_output,
                is_first_microbatch=is_first_microbatch,
                checkpoint_core_attention=checkpoint_core_attention,
                core_attention_bias_type=core_attention_bias_type,
                core_attention_bias=core_attention_bias,
                fast_zero_fill=fast_zero_fill,
            )

        if self.apply_residual_connection_post_layernorm:
            attention_output, attention_bias, residual = inter_attention_outputs
...
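
To close, a hedged usage sketch of the new TransformerLayer.forward arguments. Model sizes, dtypes and the bias values are illustrative rather than taken from this change, and whether the selected attention backend actually consumes the bias depends on the runtime backend selection.

import torch
from transformer_engine.pytorch import TransformerLayer

hidden_size, ffn_hidden_size, num_heads = 512, 2048, 8
seq_len, batch_size = 512, 2

layer = TransformerLayer(hidden_size, ffn_hidden_size, num_heads).cuda()
x = torch.randn(seq_len, batch_size, hidden_size, device="cuda")

# Additive bias on the attention scores, broadcast over the batch dimension.
attn_bias = torch.randn(1, num_heads, seq_len, seq_len, device="cuda")

y = layer(
    x,
    core_attention_bias_type="post_scale_bias",
    core_attention_bias=attn_bias,
    fast_zero_fill=True,
)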