Unverified Commit 5c58beaa authored by cyanguwa, committed by GitHub

Add long sequence support for fused attention (#237)

* add long sequence support and unify three backends for fused attention
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update cudnn-frontend to v0.9.1
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace cpu_float2half_rn with __float2half_rn
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix backend selection and NVTEDType
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix CI
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* make cudnn plan caches thread_local
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace cuDNN throw with NVTE_CHECK
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix replacement of cuDNN throw with NVTE_CHECK
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* force dropout probability to 0 in inference mode
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change negInfinity to be consistent with m512 fused attn
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove float2half conversion for scale_dropout
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add back runtime api for sm detection
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add gemm3 to enums FP8Fwd/BwdTensors
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change dropout from no to yes for fmha_v1
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove output_rng_state in m512 kernels
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix elts_per_thread calculation in kvpacked fwd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove dropout=0.0 restriction for m512 fused attn
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove output_rng_state completely from m512 kernels
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 4330e025
Subproject commit e7f64390e9bb4a3db622ffe11c973834f572b609
Subproject commit a4f05c1edcef453f5fd52f96218c29c7d420e511
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
import pytest
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
)
from transformer_engine.pytorch import TransformerLayer
from transformer_engine.pytorch.attention import DotProductAttention
import os
class ModelConfig:
def __init__(
self, num_layers, hidden_size, num_attention_heads, head_dim, seq_len,
dropout_p, attn_mask_type,
):
self.num_layers = num_layers
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
assert (hidden_size == num_attention_heads * head_dim
), """hidden_size must be = num_heads x head_dim."""
self.seq_len = seq_len
self.dropout_p = dropout_p
self.attn_mask_type = attn_mask_type
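# Each entry: ModelConfig(num_layers, hidden_size, num_attention_heads, head_dim,
#                         seq_len, dropout_p, attn_mask_type)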
model_configs = {
"test1": ModelConfig(1, 1024, 16, 64, 128, 0.0, "causal"),
"test2": ModelConfig(1, 1024, 16, 64, 2048, 0.0, "causal"),
"test3": ModelConfig(1, 2048, 16, 128, 128, 0.0, "causal"),
"test4": ModelConfig(1, 2048, 16, 128, 2048, 0.0, "causal"),
"test5": ModelConfig(1, 1024, 16, 64, 128, 0.0, "no_mask"),
}
param_types = [torch.float16]
if torch.cuda.is_bf16_supported():
param_types.append(torch.bfloat16)
batch_sizes = [1, 2]
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_dot_product_attention(dtype, bs, model):
"""Test DotProductAttention module with three backends,
FlashAttention, FusedAttention and UnfusedDotProductAttention"""
config = model_configs[model]
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype, bs, config, "FlashAttention")
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype, bs, config, "FusedAttention")
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
dtype, bs, config, "UnfusedDotProductAttention")
atol, rtol = (2.5e-2, 2.5e-2) if dtype == torch.bfloat16 else (2.5e-3, 2.5e-3)
assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)
def _run_dot_product_attention(dtype, bs, config, backend):
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
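# Backend selection: NVTE_FLASH_ATTN=1 enables FlashAttention, NVTE_FUSED_ATTN=1
# enables the cuDNN fused-attention backend; with both set to 0 the unfused
# PyTorch implementation is used.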
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
inp = 0.1 * torch.randn(
config.seq_len, bs, 3, config.num_attention_heads, config.head_dim,
dtype = dtype).cuda()
inp.requires_grad=True
seqlens = torch.empty(bs, dtype = torch.int32).cuda()
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
op_grad = 0.001 * torch.randint(0, 200, (
config.seq_len, bs, config.num_attention_heads * config.head_dim
), dtype = dtype).cuda()
block = (
DotProductAttention(
config.num_attention_heads,
config.head_dim,
attention_dropout = config.dropout_p,
attn_mask_type = config.attn_mask_type,
sequence_parallel = False,
tp_size = 1,
get_rng_state_tracker = None,
tp_group = None,
layer_number = 1,
attention_type = "self"
).to(dtype = dtype).cuda()
)
q = inp[:, :,0,:,:]
k = inp[:, :,1,:,:]
v = inp[:, :,2,:,:]
op = block(q, k, v)
op.backward(op_grad)
return op, inp.grad
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_transformer_layer(dtype, bs, model):
"""Test TransformerLayer module when its DotProductAttention is enabled with
FlashAttention, FusedAttention, or UnfusedDotProductAttention backend"""
config = model_configs[model]
flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
dtype, bs, config, "FlashAttention")
fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
dtype, bs, config, "FusedAttention")
unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer(
dtype, bs, config, "UnfusedDotProductAttention")
atol, rtol = (5e-1, 5e-1) if dtype == torch.bfloat16 else (5e-1, 5e-1)
assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)
def _run_transformer_layer(dtype, bs, config, backend):
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
inp = 0.1 * torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda()
inp.requires_grad=True
seqlens = torch.empty(bs, dtype = torch.int32).cuda()
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
op_grad = 0.001 * torch.randint(0, 200, (
config.seq_len, bs, config.num_attention_heads * config.head_dim
), dtype = dtype).cuda()
sigma = 0.02
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
layer_number = 1
drop_path_rate = 0.0
drop_path_rates = [
rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon = 1e-5,
hidden_dropout = 0.0,
attention_dropout = config.dropout_p,
init_method = init_method,
output_layer_init_method = output_layer_init_method,
layer_number = layer_number,
kv_channels = config.head_dim,
self_attn_mask_type = config.attn_mask_type,
tp_group = None,
tp_size = 1,
params_dtype = dtype,
get_rng_state_tracker = None,
fuse_wgrad_accumulation = False,
seq_length = config.seq_len,
micro_batch_size = bs,
sequence_parallel = False,
apply_residual_connection_post_layernorm = False,
output_layernorm = False,
layer_type = "encoder",
drop_path_rate = drop_path_rates[layer_number - 1],
set_parallel_mode = True,
fuse_qkv_params = True,
zero_centered_gamma = False,
qkv_weight_interleaved = False,
ub_tp_comm_overlap = False,
bias = True,
)
.to(dtype = dtype)
.cuda()
)
op = block(inp)
op.backward(op_grad)
return op, inp.grad
model_configs_fp8 = {
"test1": ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask"),
}
batch_sizes_fp8 = [1, 4]
param_types_fp8 = [torch.float16]
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("bs", batch_sizes_fp8)
@pytest.mark.parametrize("model", model_configs_fp8.keys())
def test_dpa_fp8(dtype, bs, model):
"""Test DotProductAttention module with FP8,
using fused_attn_fwd/bwd_qkvpacked from cpp_extensions and an UnfusedDotProductAttention reference"""
config = model_configs_fp8[model]
fused_attn_fwd, fused_attn_bwd = _run_dpa_fp8(
dtype, bs, config, "FusedAttention")
unfused_attn_fwd, unfused_attn_bwd = _run_dpa_fp8_ref(
dtype, bs, config, "UnfusedDotProductAttention")
atol, rtol = (5e-2, 1e-1)
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)
def _run_dpa_fp8(dtype, bs, config, backend):
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
inp = 0.01 * torch.randn(
bs * config.seq_len, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda()
inp.requires_grad=True
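# cu_seqlens holds cumulative sequence lengths [0, s, 2s, ...]; the packed
# (varlen) attention API uses it to locate each sequence within the batch.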
seqlens = torch.empty(bs, dtype = torch.int32).cuda()
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
op_grad = 0.001 * torch.randint(0, 200, (
bs * config.seq_len, config.num_attention_heads * config.head_dim
), dtype = dtype).cuda()
torch.save(op_grad, 'op_grad.pt')
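# DelayedScaling with Format.HYBRID keeps forward tensors in E4M3 and gradients in
# E5M2; amax_history_len=1 with "most_recent" scales using the latest amax.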
fp8_recipe = recipe.DelayedScaling(
margin=0,
interval=1,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
)
dpa = DPA_FP8(config).to(dtype = torch.float16).cuda()
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
op = dpa(inp, cu_seqlens, config.seq_len)
op.backward(op_grad)
context = torch.load("ctx.pt")
dqkv = torch.load('dqkv.pt')
return (context.view(bs, config.seq_len, -1).transpose(0,1),
dqkv.view(bs, config.seq_len, 3, config.num_attention_heads, config.head_dim).transpose(0,1).contiguous())
def _run_dpa_fp8_ref(dtype, bs, config, backend):
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
inp = torch.load('qkv.pt').cuda()
inp.requires_grad=True
seqlens = torch.empty(bs, dtype = torch.int32).cuda()
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
op_grad = torch.load('op_grad.pt').cuda().view(bs, config.seq_len, -1).transpose(0,1)
block = (
DotProductAttention(
config.num_attention_heads,
config.head_dim,
attention_dropout = config.dropout_p,
attn_mask_type = config.attn_mask_type,
sequence_parallel = False,
tp_size = 1,
get_rng_state_tracker = None,
tp_group = None,
layer_number = 1,
attention_type = "self"
).to(dtype = dtype).cuda()
)
q = inp[:, :,0,:,:]
k = inp[:, :,1,:,:]
v = inp[:, :,2,:,:]
op = block(q, k, v)
op.backward(op_grad)
torch.save(op,'ctx_ref.pt')
torch.save(inp.grad,'dqkv_ref.pt')
return op, inp.grad
from torch.nn.parameter import Parameter
import transformer_engine.pytorch.cpp_extensions as ext
import transformer_engine_extensions as tex
import transformer_engine.pytorch.fp8 as fp8
from transformer_engine.pytorch import fp8_autocast
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule, _prepare_backward
from transformer_engine.common import recipe
from typing import Union, Dict, Any, Tuple, List
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
fused_attn_fwd_qkvpacked,
fused_attn_bwd_qkvpacked,
FusedAttnBackend)
_CUBLASLT_WORKSPACE_SIZE_BYTES = 33_554_432 # 32MiB
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = False
_2X_ACC_WGRAD = False
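# Indices into the FP8 scaling/amax metadata for the tensors used by fused attention:
# the packed QKV, the attention output O, the softmax output S, and their gradients.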
META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_S = tex.FP8FwdTensors.GEMM3_WEIGHT
META_DS = tex.FP8BwdTensors.GRAD_INPUT3
class _dpa_fp8(torch.autograd.Function):
@staticmethod
def forward(
ctx,
inp: torch.Tensor,
qkv_weight: torch.Tensor,
qkv_bias: torch.Tensor,
cu_seqlens: torch.Tensor,
num_attention_heads: int,
p_dropout: float,
max_s: int,
fast_zero_fill: bool,
fp8_meta: Dict[str, Any],
workspace: torch.Tensor,
is_training: bool,
) -> torch.Tensor:
assert inp.dim() == 2
in_features = qkv_weight.shape[-1]
h = num_attention_heads
d = in_features // h
b = cu_seqlens.numel() - 1
is_nl = False
if b < 4 and b > 1:
max_s = 512
is_nl = True
fp8_dtype_forward = fp8.get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
inputmat, inputmat_t = ext.fp8_cast_transpose_fused(
inp,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
qkv_weight_fp8, qkv_weight_t_fp8 = ext.fp8_cast_transpose_fused(
qkv_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
)
M = None
ZInv = None
philox_unpacked = None
qkv_out = ext.fp8_gemm(
qkv_weight_fp8,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
inputmat,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
torch.uint8,
workspace,
bias=qkv_bias,
use_bias=True,
out_index = META_QKV,
fp8_meta_tensor = fp8_meta["scaling_fwd"],
use_split_accumulator=_2X_ACC_FPROP,
D_dtype=fp8_dtype_forward,
)
qkv_out = qkv_out.view(-1, 3, h, d)
qkv_out_fp16 = ext.cast_from_fp8(qkv_out, fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward,
tex.DType.kFloat16).view(b, max_s, 3, h, d).transpose(0,1).contiguous()
torch.save(qkv_out_fp16, 'qkv.pt')
# FMHA
context_, aux_ctx_tensors, *rest = fused_attn_fwd_qkvpacked(
is_training,
max_s,
cu_seqlens,
qkv_out,
fp8_dtype_forward,
FusedAttnBackend["FP8"],
None,
fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale[META_S],
fp8_meta["scaling_fwd"].scale[META_O],
fp8_meta["scaling_fwd"].amax_history[0][META_S],
fp8_meta["scaling_fwd"].amax_history[0][META_O],
attn_scale = None,
dropout = p_dropout,
fast_zero_fill = fast_zero_fill,
qkv_layout = "qkv_interleaved",
attn_bias_type = "no_bias",
attn_mask_type = "padding",
rng_gen = None,
)
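# aux_ctx_tensors holds the softmax statistics (M, ZInv) and the unpacked philox
# RNG state from the forward pass; they are kept in ctx for the FP8 backward.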
M, ZInv, philox_unpacked = aux_ctx_tensors
context = context_.view(-1, in_features)
context_t = tex.fp8_transpose(context, fp8_dtype_forward)
ctx.save_for_backward(
inputmat_t, qkv_weight_t_fp8, workspace,
qkv_out,
context_, context_t,
fp8_meta["scaling_fwd"].scale,
fp8_meta["scaling_fwd"].scale_inv,
)
ctx.aux_ctx_tensors = aux_ctx_tensors
ctx.fp8_meta = fp8_meta
ctx.cu_seqlens = cu_seqlens
ctx.p_dropout = p_dropout
ctx.max_s = max_s
ctx.fast_zero_fill = fast_zero_fill
ctx.is_nl = is_nl
ctx.hidden_size = in_features
ctx.num_attention_heads = num_attention_heads
context_fp16 = ext.cast_from_fp8(context, fp8_meta["scaling_fwd"],
META_O, fp8_dtype_forward, tex.DType.kFloat16)
torch.save(context_fp16, 'ctx.pt')
return context_fp16
@staticmethod
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
with _prepare_backward(True, ctx.fp8_meta, None, 1, name="_DPA"):
(
inputmat_t,
qkv_weight_t_fp8,
workspace,
qkv_out,
context, context_t,
fwd_scales,
fwd_scale_inverses,
) = ctx.saved_tensors
fp8_dtype_forward = fp8.get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=True
)
fp8_dtype_backward = fp8.get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
)
proj_dgrad = ext.cast_to_fp8(
grad_output, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
)
dqkv, *rest = fused_attn_bwd_qkvpacked(
ctx.max_s,
ctx.cu_seqlens,
qkv_out,
context,
proj_dgrad.view_as(context),
fp8_dtype_forward,
ctx.aux_ctx_tensors,
FusedAttnBackend["FP8"],
fwd_scale_inverses[META_QKV], # d_scale_qkv,
fwd_scale_inverses[META_S], # d_scale_s,
fwd_scale_inverses[META_O], # d_scale_o,
ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do
fwd_scales[META_S], # q_scale_s
ctx.fp8_meta['scaling_bwd'].scale[META_DS], # q_scale_ds
ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv
ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DS], # amax_ds
ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv
None,
ctx.p_dropout,
ctx.fast_zero_fill,
"qkv_interleaved",
"no_bias",
"padding",
)
dqkv_grad_output_c = dqkv.view(-1, 3*ctx.hidden_size)
dqkv_grad_output_c_fp16 = ext.cast_from_fp8(dqkv_grad_output_c,
ctx.fp8_meta["scaling_bwd"], META_DQKV,
fp8_dtype_backward, tex.DType.kFloat16)
torch.save(dqkv_grad_output_c_fp16, 'dqkv.pt')
qkv_bgrad, dqkv_grad_output_t = ext.fp8_transpose_bgrad_fused(
dqkv_grad_output_c,
ctx.fp8_meta["scaling_bwd"],
META_DQKV,
fp8_dtype_backward,
torch.float16,
)
# QKV DGRAD
qkv_dgrad = ext.fp8_gemm(
qkv_weight_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
dqkv_grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv,
META_DQKV,
fp8_dtype_backward,
torch.float16,
workspace,
use_split_accumulator=_2X_ACC_DGRAD,
)
# QKV WGRAD
qkv_wgrad = ext.fp8_gemm(
inputmat_t,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
dqkv_grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv,
META_DQKV,
fp8_dtype_backward,
torch.float16,
workspace,
use_split_accumulator=_2X_ACC_WGRAD,
)
return (qkv_dgrad,
qkv_wgrad,
qkv_bgrad,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None)
class DPA_FP8(TransformerEngineBaseModule):
def __init__(
self,
config,
params_dtype: torch.dtype = torch.float32):
super().__init__()
self.p_dropout = config.dropout_p
self.h = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_dim = config.head_dim
self.fast_zero_fill = True
self.qkv_weight = Parameter(
torch.empty(
self.hidden_size * 3,
self.hidden_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
self.fp8_weight_shapes.append(self.qkv_weight.shape)
self.qkv_bias = Parameter(
torch.empty(
self.hidden_size * 3,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
with torch.no_grad():
self.qkv_bias.zero_()
self.qkv_weight.fill_(1.0)
self.workspace = torch.empty(
_CUBLASLT_WORKSPACE_SIZE_BYTES, dtype=torch.int8, device="cuda"
)
def forward(
self, inp: torch.Tensor,
cu_seqlens, max_s,
) -> torch.Tensor:
with self.prepare_forward(inp, None, num_gemms=3) as inp:
out = _dpa_fp8.apply(
inp,
self.qkv_weight,
self.qkv_bias,
cu_seqlens,
self.h,
self.p_dropout,
max_s,
self.fast_zero_fill,
self.fp8_meta,
self.workspace,
self.training)
return out
def get_fp8_weights_scratchpad(
self,
is_first_microbatch: Union[bool, None],
) -> List[torch.Tensor]:
"""Needs override."""
@@ -12,9 +12,10 @@ list(APPEND transformer_engine_SOURCES
transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu
activation/gelu.cu
fused_attn/fused_attn_f16_max512_seqlen.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu
activation/relu.cu
activation/swiglu.cu
fused_attn/fused_attn_fp16_bf16_max_seqlen_512.cu
fused_attn/fused_attn_fp8.cu
fused_attn/fused_attn.cpp
fused_attn/utils.cu
@@ -7,8 +7,80 @@
#include "transformer_engine/fused_attn.h"
#include "../common.h"
#include "utils.h"
#include "fused_attn_fp16_bf16_max_seqlen_512.h"
#include "fused_attn_f16_max512_seqlen.h"
#include "fused_attn_f16_arbitrary_seqlen.h"
#include "fused_attn_fp8.h"
#include "../util/cuda_runtime.h"
// select a backend for fused attention
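// Selection rules implemented below:
//   * NVTE_FP8: FP8 dtypes on sm90+, max_seqlen_q == max_seqlen_kv <= 512,
//     head_dim 64, no bias, no mask, interleaved QKV layout.
//   * NVTE_F16_max512_seqlen: FP16/BF16 on sm80+, head_dim 64, seqlen <= 512,
//     no/post-scale bias, causal/padding/no mask, QKV- or KV-interleaved layout.
//   * NVTE_F16_arbitrary_seqlen: FP16/BF16 on sm80+, head_dim 64 or 128, causal
//     mask, no bias, interleaved QKV; selected when seqlen > 512, and can be
//     forced for seqlen <= 512 via the NVTE_FUSED_ATTN_BACKEND environment variable.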
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTEDType q_dtype,
NVTEDType kv_dtype,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
float dropout, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim) {
using namespace transformer_engine;
NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type.");
if (((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2))
&& (sm_arch_ >= 90)
&& (max_seqlen_q == max_seqlen_kv)
&& (max_seqlen_q <= 512)
&& (head_dim == 64)
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
&& (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)
&& (qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)) {
backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
} else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) {
bool flag_m512 = false;
bool flag_arb = false;
if ((sm_arch_ >= 80)
&& (head_dim == 64)
&& ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
|| (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS))
&& ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))
&& ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED))) {
flag_m512 = true;
}
if ((sm_arch_ >= 80)
&& (max_seqlen_q == max_seqlen_kv)
&& ((head_dim == 64) || (head_dim == 128))
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
&& (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
&& (qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)) {
flag_arb = true;
}
if (((max_seqlen_q > 512) || (max_seqlen_kv > 512))
&& (flag_arb == true)) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
}
if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
if (flag_m512 == true) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen;
} else if ((flag_m512 == false) && (flag_arb == true)) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
}
}
const char* env_backend = std::getenv("NVTE_FUSED_ATTN_BACKEND");
if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)
&& (flag_arb == true)
&& (env_backend != nullptr)
&& (std::string(env_backend) == std::to_string(
NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen))) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
}
} else {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
}
return backend;
}
// NVTE fused attention FWD FP8 with packed QKV
void nvte_fused_attn_fwd_qkvpacked(
@@ -16,7 +88,7 @@ void nvte_fused_attn_fwd_qkvpacked(
const NVTETensor Bias,
NVTETensor S,
NVTETensor O,
NVTETensorPack* Aux_Output_Tensors,
NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens,
const NVTETensor rng_state,
size_t max_seqlen,
@@ -43,54 +115,56 @@ void nvte_fused_attn_fwd_qkvpacked(
size_t d = input_QKV->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const DType QKV_type = input_QKV->data.dtype;
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
&& (max_seqlen <= 512)) {
#if (CUDNN_VERSION >= 8900)
// FP8 API doesn't use input_Bias, bias_type or attn_mask_type
fused_attn_fwd_fp8_qkvpacked(
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(
QKV_type, QKV_type,
qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen, max_seqlen, d);
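// Dispatch to the backend chosen above; each branch also guards on the minimum
// cuDNN version it requires at compile time.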
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
fused_attn_max_512_fwd_qkvpacked(
b, max_seqlen, h, d,
is_training, attn_scale, dropout, qkv_layout,
input_QKV, input_output_S, output_O,
Aux_Output_Tensors,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_Bias, output_O,
Aux_CTX_Tensors,
input_cu_seqlens,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9 is required to run FP8 fused attention. \n");
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
} else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
&& (max_seqlen <= 512)) {
#if (CUDNN_VERSION >= 8901)
fused_attn_max_512_fwd_qkvpacked(
b,
max_seqlen,
h,
d,
is_training,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
input_QKV,
input_Bias,
output_O,
Aux_Output_Tensors,
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd_qkvpacked(
b, max_seqlen, h, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_Bias, output_O,
Aux_CTX_Tensors,
input_cu_seqlens,
input_rng_state,
wkspace,
stream,
handle);
wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n");
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
fused_attn_fp8_fwd_qkvpacked(
b, max_seqlen, h, d,
is_training, attn_scale, dropout, qkv_layout,
input_QKV, input_output_S, output_O,
Aux_CTX_Tensors,
input_cu_seqlens,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
#endif
} else if (max_seqlen > 512) {
NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
} else {
NVTE_ERROR("Invalid combination of data type and sequence length! \n");
NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
}
}
// NVTE fused attention BWD FP8 with packed QKV
@@ -130,18 +204,52 @@ void nvte_fused_attn_bwd_qkvpacked(
size_t d = input_QKV->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const DType QKV_type = input_QKV->data.dtype;
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
&& (max_seqlen <= 512)) {
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(
QKV_type, QKV_type,
qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen, max_seqlen, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
fused_attn_max_512_bwd_qkvpacked(
b, max_seqlen, h, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_dO,
output_S,
output_dQKV, output_dBias,
input_cu_seqlens,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
fused_attn_arbitrary_seqlen_bwd_qkvpacked(
b, max_seqlen, h, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_O, input_dO,
output_S,
output_dQKV, output_dBias,
input_cu_seqlens, input_rng_state,
wkspace, stream, handle);
#else
const char *err_msg =
"cuDNN 8.9.0 is required for BF16/FP16 fused attention "
"with arbitrary sequence length. \n";
NVTE_ERROR(err_msg);
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
// Aux_CTX_Tensors contain [M, ZInv, rng_state] generated by the forward pass
const Tensor *input_M = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[0]);
const Tensor *input_ZInv = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
// FP8 API doesn't use input_dBias, bias_type or attn_mask_type
fused_attn_bwd_fp8_qkvpacked(
fused_attn_fp8_bwd_qkvpacked(
b, max_seqlen, h, d,
attn_scale, dropout, qkv_layout,
input_QKV, input_O, input_dO,
@@ -152,38 +260,10 @@ void nvte_fused_attn_bwd_qkvpacked(
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9 is required to run FP8 fused attention. \n");
NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
#endif
} else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
&& (max_seqlen <= 512)) {
#if (CUDNN_VERSION >= 8901)
fused_attn_max_512_bwd_qkvpacked(
b,
max_seqlen,
h,
d,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
input_QKV,
input_dO,
Aux_CTX_Tensors,
output_dQKV,
output_dBias,
input_cu_seqlens,
wkspace,
stream,
handle);
#else
NVTE_ERROR(
"cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
} else if (max_seqlen > 512) {
NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
} else {
NVTE_ERROR("Invalid combination of data type and sequence length! \n");
NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
}
}
// NVTE fused attention FWD FP8 with packed KV
@@ -193,7 +273,7 @@ void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Bias,
NVTETensor S,
NVTETensor O,
NVTETensorPack* Aux_Output_Tensors,
NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
const NVTETensor rng_state,
@@ -223,45 +303,37 @@ void nvte_fused_attn_fwd_kvpacked(
size_t d = input_Q->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const DType QKV_type = input_Q->data.dtype;
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
&& (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
} else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
&& (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(
Q_type, KV_type,
qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen_q, max_seqlen_kv, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
fused_attn_max_512_fwd_kvpacked(
b,
max_seqlen_q,
max_seqlen_kv,
h,
d,
is_training,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
input_Q,
input_KV,
input_Bias,
output_O,
Aux_Output_Tensors,
input_cu_seqlens_q,
input_cu_seqlens_kv,
b, max_seqlen_q, max_seqlen_kv, h, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_KV, input_Bias, output_O,
Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state,
wkspace,
stream,
handle);
wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n");
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
} else if ((max_seqlen_q > 512) || (max_seqlen_kv > 512)) {
NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
const char* err_msg =
"The FP16/BF16 fused attention (arbitrary seqlen) currently "
"only supports packed QKV input.\n";
NVTE_ERROR(err_msg);
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
} else {
NVTE_ERROR("Invalid combination of data type and sequence length! \n");
NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
}
}
// NVTE fused attention BWD FP8 with packed KV
@@ -307,44 +379,37 @@ void nvte_fused_attn_bwd_kvpacked(
size_t d = input_Q->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const DType QKV_type = input_Q->data.dtype;
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
&& (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
} else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
&& (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(
Q_type, KV_type,
qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen_q, max_seqlen_kv, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
fused_attn_max_512_bwd_kvpacked(
b,
max_seqlen_q,
max_seqlen_kv,
h,
d,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
input_Q,
input_KV,
input_dO,
Aux_CTX_Tensors,
output_dQ,
output_dKV,
output_dBias,
input_cu_seqlens_q,
input_cu_seqlens_kv,
wkspace,
stream,
handle);
b, max_seqlen_q, max_seqlen_kv, h, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_KV, input_dO,
output_S,
output_dQ, output_dKV, output_dBias,
input_cu_seqlens_q, input_cu_seqlens_kv,
wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n");
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
} else if ((max_seqlen_q > 512) || (max_seqlen_kv > 512)) {
NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
const char* err_msg =
"The FP16/BF16 fused attention (arbitrary seqlen) currently "
"only supports packed QKV input.\n";
NVTE_ERROR(err_msg);
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
} else {
NVTE_ERROR("Invalid combination of data type and sequence length! \n");
NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
}
}
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "fused_attn_f16_arbitrary_seqlen.h"
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cudnn_frontend.h>
#include <map>
#include <vector>
#include "../common.h"
#include "utils.h"
#if (CUDNN_VERSION >= 8900)
#define Q_ID 1
#define K_ID 2
#define V_ID 3
#define O_ID 4
#define S_ID 5
#define B_ID 6
#define D_CONST_ID 7
#define S_CONST_ID 8
#define Q_SEQLEN_ID 9
#define K_SEQLEN_ID 10
#define dQ_ID 11
#define dK_ID 12
#define dV_ID 13
#define dO_ID 14
#define MASK_VAL_ID 15
#define dS_ID 16
#define D_SEED_ID 17
#define D_OFFSET_ID 18
#define S_STATS_ID 19
#define S_SUM_ID 20
#define SCALE_PROB 21
#define K_TRANSPOSE_ID 22
#define dQ_ACCUM_ID 23
#define VIRTUAL_ID 30
namespace transformer_engine {
namespace fused_attn {
static cudnn_frontend::Tensor
createScale(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
const cudnn_frontend::Tensor& sTensor,
std::vector<cudnn_frontend::Operation>* ops) {
// scale
int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1};
int64_t s_dim[4] = {b, h, s_q, s_kv};
int64_t s_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, s_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix);
auto scaleTensor = tensor_create(
tensorType, S_CONST_ID, scale_dim,
scale_stride, false, true); // is by value
auto sScaleTensor = tensor_create(
tensorType, VIRTUAL_ID + 2000, s_dim,
s_stride, true, false); // is virtual
// Define the scale descriptor
auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a scale node
auto scale_op = binary_pw_op_create(sTensor, scaleTensor, sScaleTensor, scaleDesc);
ops->push_back(std::move(scale_op));
return sScaleTensor;
}
static cudnn_frontend::Tensor
createQKBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops) {
// Creates the necessary tensor descriptors
int64_t q_dim[4] = {b, h, s_q, d};
int64_t q_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
int64_t k_dim[4] = {b, h, d, s_kv};
int64_t k_stride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, k_stride, layout, NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose);
int64_t s_dim[4] = {b, h, s_q, s_kv};
int64_t s_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, s_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix);
auto qTensor = tensor_create(tensorType, Q_ID, q_dim, q_stride, false, false);
auto kTransposeTensor = tensor_create(
tensorType, K_ID, k_dim, k_stride, false, false); // not virtual
// first GEMM output
auto sTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 1, s_dim, s_stride, true, false); // is virtual
// Define the matmul 1 desc
auto matmul_1_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.build();
// Create a matmul 1 node
auto matmul_op1 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(qTensor)
.setbMatDesc(kTransposeTensor)
.setcMatDesc(sTensor)
.setmatmulDesc(matmul_1_Desc)
.build();
ops->push_back(std::move(matmul_op1));
return sTensor;
}
static cudnn_frontend::Tensor
createCausalMask(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& prevBlockOutputTensor) {
CUDNN_FRONTEND_UNUSED(d);
CUDNN_FRONTEND_UNUSED(layout);
CUDNN_FRONTEND_UNUSED(tensorType);
NVTE_CHECK(ops->size() != 0, "Padding Mask constructed incorrectly as the first one");
// subtraction output
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t maskVal_dim[4] = {1, 1, 1, 1};
int64_t maskVal_stride[4] = {1, 1, 1, 1};
// mask value to put in the masked pixels
auto maskValTensor = tensor_create(
CUDNN_DATA_FLOAT, MASK_VAL_ID, maskVal_dim,
maskVal_stride, false, true); // is by value
// gen index row output
auto rowIndexTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 100, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// gen index column output
auto columnIndexTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 101, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// create causal mask (row >= col)
auto causalMaskTensor = tensor_create(
CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 106, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// output after masking
auto maskOutputTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 107, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// Define the gen index for row descriptor
auto genIndexRowDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setAxis(2)
.setComputeType(CUDNN_DATA_FLOAT)
.build();
// Create a gen index node
auto genIndexRow_op = unary_pw_op_create(
prevBlockOutputTensor, rowIndexTensor, genIndexRowDesc);
// Define the gen index for column descriptor
auto genIndexColumnDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setAxis(3)
.setComputeType(CUDNN_DATA_FLOAT)
.build();
// Create a gen index node
auto genIndexColumn_op = unary_pw_op_create(
prevBlockOutputTensor, columnIndexTensor, genIndexColumnDesc);
// Define the greater than equal to comparison descriptor
auto rowGreaterColDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_CMP_GE);
// Create a greater than equal to node
auto rowGreaterCol_op = binary_pw_op_create(
rowIndexTensor, columnIndexTensor, causalMaskTensor, rowGreaterColDesc);
// Define the binary select to perform masking descriptor
auto maskDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_BINARY_SELECT);
// Create a binary select node
auto mask_op = ternary_pw_op_create(
prevBlockOutputTensor, maskValTensor,
causalMaskTensor, maskOutputTensor, maskDesc);
ops->push_back(std::move(genIndexRow_op));
ops->push_back(std::move(genIndexColumn_op));
ops->push_back(std::move(rowGreaterCol_op));
ops->push_back(std::move(mask_op));
return maskOutputTensor;
}
static cudnn_frontend::Tensor
createSoftmaxForward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, bool isTraining,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& sAfterMaskTensor) {
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t afterReduction_dim[4] = {b, h, s_q, 1};
int64_t afterReduction_stride[4] = {h * s_q, s_q, 1, 1};
// max (x)
auto afterMaxReductionTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 150, afterReduction_dim,
afterReduction_stride, true, false); // is virtual
// x - max(x)
auto afterSubtractionTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 151, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// e^(x - max(x))
auto afterExponentTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 152, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual;
// sum (e^(x - max(x)))
auto afterAddReductionTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 153, afterReduction_dim,
afterReduction_stride, true, false); // is virtual
// log (sum (e^(x - max(x))))
auto afterLogLTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 154, afterReduction_dim,
afterReduction_stride, true, false);
// M + log (sum (e^(x - max(x))))
auto softmaxStatsTensor = tensor_create(
CUDNN_DATA_FLOAT, S_STATS_ID, afterReduction_dim,
afterReduction_stride, !isTraining, false);
// not virtual if training is true, virtual if training is false
// divide (e/ sum(e))
auto afterSoftmaxTensor = cudnn_frontend::TensorBuilder()
.setDim(4, afterBMM1_dim)
.setStride(4, afterBMM1_stride)
.setId(VIRTUAL_ID + 156)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual(true)
.setByValue(false)
.setReorderType(
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16)
.build();
// Define the reduction descriptor
auto reductionMaxDesc = cudnn_frontend::ReductionDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_MAX)
.build();
// Create a reduction max node
auto reductionMax_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(sAfterMaskTensor)
.setyDesc(afterMaxReductionTensor)
.setreductionDesc(reductionMaxDesc)
.build();
// Define the subtract descriptor
auto subtractDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB);
// Create a subtract node
auto subtract_op = binary_pw_op_create(
sAfterMaskTensor, afterMaxReductionTensor,
afterSubtractionTensor, subtractDesc);
// Define the exponent descriptor
auto exponentDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP);
// Create an exponent node
auto exponent_op = unary_pw_op_create(
afterSubtractionTensor, afterExponentTensor, exponentDesc);
// Define the reduction descriptor
auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
.build();
// Create a reduction add node
auto reductionAdd_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(afterExponentTensor)
.setyDesc(afterAddReductionTensor)
.setreductionDesc(reductionAddDesc)
.build();
// Create log descriptor
auto logDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_LOG);
// Create log node
auto log_op = unary_pw_op_create(afterAddReductionTensor, afterLogLTensor, logDesc);
// Create add descriptor
auto addDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_ADD);
// Create add node
auto add_op = binary_pw_op_create(
afterMaxReductionTensor, afterLogLTensor,
softmaxStatsTensor, addDesc);
// Define the division descriptor
auto divisionDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_DIV);
// Create a division node
auto division_op = binary_pw_op_create(
afterExponentTensor, afterAddReductionTensor,
afterSoftmaxTensor, divisionDesc);
ops->push_back(std::move(reductionMax_op));
ops->push_back(std::move(subtract_op));
ops->push_back(std::move(exponent_op));
ops->push_back(std::move(reductionAdd_op));
ops->push_back(std::move(log_op));
ops->push_back(std::move(add_op));
ops->push_back(std::move(division_op));
return afterSoftmaxTensor;
}
static cudnn_frontend::Tensor
createDropoutForward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
double probability, cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& afterSoftmaxTensor) {
CUDNN_FRONTEND_UNUSED(d);
NVTE_CHECK(ops->size() != 0, "Dropout DAG constructed incorrectly as the first one");
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1};
auto dropoutSeed = tensor_create(
CUDNN_DATA_INT64, D_SEED_ID, scale_dim,
scale_stride, false, false); // not virtual
auto dropoutOffset = tensor_create(
CUDNN_DATA_INT64, D_OFFSET_ID, scale_dim,
scale_stride, false, false); // not virtual
// mask for the dropout
auto dropoutMaskTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 200, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// after dropout tensor
auto afterDropoutTensor = cudnn_frontend::TensorBuilder()
.setDim(4, afterBMM1_dim)
.setStride(4, afterBMM1_stride)
.setId(VIRTUAL_ID + 201)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(tensorType)
.setVirtual(true)
.setByValue(false)
.setReorderType(
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16)
.build();
// scale after dropout
auto scaleDropoutTensor = tensor_create(
tensorType, D_CONST_ID, scale_dim,
scale_stride, false, true); // is by value
// after Scale
auto afterScaleTensor = tensor_create(
tensorType, VIRTUAL_ID + 202, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// Define the RNG (dropout mask) descriptor
auto rngDesc = cudnn_frontend::RngDescBuilder()
.setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI)
.setBernoulliDistProbability(1.0 - probability)
.build();
// Create a rng node
auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR)
.setyDesc(dropoutMaskTensor)
.setSeedDesc(dropoutSeed)
.setOffsetDesc(dropoutOffset)
.setRngDesc(rngDesc)
.build();
// Define the multiply mask descriptor
auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply mask node
auto maskMul_op = binary_pw_op_create(
afterSoftmaxTensor, dropoutMaskTensor,
afterDropoutTensor, maskMulDesc);
// Define the multiply scale descriptor
auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply scale node
auto scaleMul_op = binary_pw_op_create(
afterDropoutTensor, scaleDropoutTensor,
afterScaleTensor, scaleMulDesc);
ops->push_back(std::move(rng_op));
ops->push_back(std::move(maskMul_op));
ops->push_back(std::move(scaleMul_op));
return afterScaleTensor;
}
static cudnn_frontend::Tensor
createDropoutBackward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
double probability, cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& afterSoftmaxTensor,
const cudnn_frontend::Tensor& dropoutMaskTensor) {
CUDNN_FRONTEND_UNUSED(d);
NVTE_CHECK(ops->size() != 0, "Dropout DAG constructed incorrectly as the first one");
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1};
auto dropoutSeed = tensor_create(
CUDNN_DATA_INT64, D_SEED_ID, scale_dim,
scale_stride, false, false); // not virtual
auto dropoutOffset = tensor_create(
CUDNN_DATA_INT64, D_OFFSET_ID, scale_dim,
scale_stride, false, false); // not virtual
// after dropout tensor
auto afterDropoutTensor = cudnn_frontend::TensorBuilder()
.setDim(4, afterBMM1_dim)
.setStride(4, afterBMM1_stride)
.setId(VIRTUAL_ID + 201)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(tensorType)
.setVirtual(true)
.setByValue(false)
.setReorderType(
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16)
.build();
// scale after dropout
auto scaleDropoutTensor = tensor_create(
tensorType, D_CONST_ID, scale_dim,
scale_stride, false, true); // is by value
// after Scale
auto afterScaleTensor = tensor_create(
tensorType, VIRTUAL_ID + 202, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// Define the RNG (dropout mask) descriptor
auto rngDesc = cudnn_frontend::RngDescBuilder()
.setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI)
.setBernoulliDistProbability(1.0 - probability)
.build();
// Create a rng node
auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR)
.setyDesc(dropoutMaskTensor)
.setSeedDesc(dropoutSeed)
.setOffsetDesc(dropoutOffset)
.setRngDesc(rngDesc)
.build();
// Define the multiply mask descriptor
auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply mask node
auto maskMul_op = binary_pw_op_create(
afterSoftmaxTensor, dropoutMaskTensor,
afterDropoutTensor, maskMulDesc);
// Define the multiply scale descriptor
auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply scale node
auto scaleMul_op = binary_pw_op_create(
afterDropoutTensor, scaleDropoutTensor,
afterScaleTensor, scaleMulDesc);
ops->push_back(std::move(rng_op));
ops->push_back(std::move(maskMul_op));
ops->push_back(std::move(scaleMul_op));
return afterScaleTensor;
}
static void
createSVBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops,
cudnn_frontend::Tensor const &afterScaleDropoutTensor) {
NVTE_CHECK(ops->size() != 0, "BMM2 op constructed incorrectly as the first one");
int64_t v_dim[4] = {b, h, s_kv, d};
int64_t v_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, v_stride, layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
int64_t o_dim[4] = {b, h, s_q, d};
int64_t o_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
auto vTensor = tensor_create(tensorType, V_ID, v_dim, v_stride, false, false);
// second GEMM output
auto oTensor = tensor_create(tensorType, O_ID, o_dim, o_stride, false, false);
// Define the matmul 2 desc
auto matmul_2_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.build();
// Create a matmul 2 node
auto matmul_op2 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(afterScaleDropoutTensor)
.setbMatDesc(vTensor)
.setcMatDesc(oTensor)
.setmatmulDesc(matmul_2_Desc)
.build();
ops->push_back(std::move(matmul_op2));
}
void fused_attn_arbitrary_seqlen_fwd_impl(
int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
bool is_training, float scaling_factor, float dropout_probability,
NVTE_QKV_Layout layout,
void *devPtrQ, void *devPtrK, void *devPtrV,
void *devPtrSoftmaxStats, void *devPtrO,
void* devPtrDropoutSeed, void* devPtrDropoutOffset,
cudnnDataType_t tensorType,
void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle) {
try {
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
if (!is_training) {
dropout_probability = 0.0f;
}
FADescriptor descriptor{b, h,
s_q, s_kv,
d, scaling_factor,
is_training, dropout_probability,
layout, NVTE_Bias_Type::NVTE_NO_BIAS,
NVTE_Mask_Type::NVTE_CAUSAL_MASK, tensorType};
using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
static thread_local CacheType fmha_fprop_cache;
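// The execution plan cache is thread_local so concurrent threads can build
// and reuse cuDNN plans without racing on a shared cache; FADescriptor is the key.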
// Get plan from cache if cache is available, otherwise create one
auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) {
// if hit, return
auto it = cache.find(descriptor);
if (it != cache.end()) {
auto plan = it->second;
return plan;
}
// otherwise, build the op_graph and the plan. Then update cache
std::vector<cudnn_frontend::Operation const*> all_ops;
std::vector<cudnn_frontend::Operation> ops;
// Q * K^T
auto sTensor = createQKBMM(b, h, s_q, s_kv, d, layout, tensorType, &ops);
// Q * K^T * bmmScale
auto sScaleTensor = createScale(
b, h, s_q, s_kv, d, layout, CUDNN_DATA_FLOAT, sTensor, &ops);
// Causal mask
auto sAfterMaskTensor = createCausalMask(
b, h, s_q, s_kv, d, layout, tensorType, &ops, sScaleTensor);
NVTE_CHECK(dropout_probability != 1.0f,
"Dropout probability cannot be 1.0");
auto softmax_output = createSoftmaxForward(
b, h, s_q, s_kv, is_training, &ops, sAfterMaskTensor);
// Dropout(softmax)
auto dropout_output = createDropoutForward(
b, h, s_q, s_kv, d,
dropout_probability, tensorType, &ops, softmax_output);
createSVBMM(b, h, s_q, s_kv, d, layout, tensorType, &ops, dropout_output);
for (unsigned int i = 0; i < ops.size(); i++) {
all_ops.push_back(&ops[i]);
}
// Create an Operation Graph
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle)
.setOperationGraph(all_ops.size(), all_ops.data())
.build();
cudnn_frontend::EngineConfigList filtered_configs;
auto statuses = cudnn_frontend::get_heuristics_list<1>(
{"heuristics_instant"}, opGraph, allowAllConfig,
filtered_configs, true);
if (filtered_configs.size() == 0) {
cudnn_frontend::set_error_and_throw_exception(
nullptr,
CUDNN_STATUS_NOT_SUPPORTED,
"run_mha_fprop: No config returned by the heuristics");
}
auto plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(filtered_configs[0], opGraph.getTag())
.build();
cache.insert({descriptor, plan});
return plan;
};
auto plan = get_plan(fmha_fprop_cache, descriptor);
auto plan_workspace_size = plan.getWorkspaceSize();
// Exit to request upper level API to allocate memory if needed
if (workspace == nullptr) {
*workspace_size = plan_workspace_size;
return;
}
std::set<std::pair<uint64_t, void*>> data_ptrs;
// Add all the data pointers to be used in the variant pack
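// negInfinity is the fill value for masked-out positions (kept consistent with the
// max512 kernels); scale_dropout = 1/(1-p) implements inverted dropout so the
// expected activation magnitude is unchanged.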
float negInfinity = -1.0E+10f;
float scale_dropout = 1.0f/(1.0f - dropout_probability);
data_ptrs.insert(std::pair<uint64_t, void*>(Q_ID, devPtrQ));
data_ptrs.insert(std::pair<uint64_t, void*>(K_ID, devPtrK));
data_ptrs.insert(std::pair<uint64_t, void*>(V_ID, devPtrV));
data_ptrs.insert(std::pair<uint64_t, void*>(MASK_VAL_ID, &negInfinity));
data_ptrs.insert(std::pair<uint64_t, void*>(S_CONST_ID, &scaling_factor));
data_ptrs.insert(std::pair<uint64_t, void*>(O_ID, devPtrO));
data_ptrs.insert(std::pair<uint64_t, void*>(D_SEED_ID, devPtrDropoutSeed));
data_ptrs.insert(std::pair<uint64_t, void*>(D_OFFSET_ID, devPtrDropoutOffset));
data_ptrs.insert(std::pair<uint64_t, void*>(D_CONST_ID, &scale_dropout));
// If training mode, we write out softmax stats
if (is_training) {
data_ptrs.insert(std::pair<uint64_t, void*>(S_STATS_ID, devPtrSoftmaxStats));
}
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace)
.setDataPointers(data_ptrs)
.build();
NVTE_CHECK_CUDNN(
cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc()));
} catch (cudnn_frontend::cudnnException &e) {
NVTE_ERROR(e.what());
}
}
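/*
 * Backward graph built below, in op order:
 *   reduce(dO * O) * (1 - p)                        -> softmaxSum
 *   exp(causal_mask(scale * (Q @ K^T)) - stats)     -> S   (softmax recomputed from the fprop stats)
 *   dropout(S)^T @ dO                               -> dV
 *   dO @ V^T                                        -> dS
 *   (dS * dropoutMask - softmaxSum) * S             -> dP  (then rescaled by 1/(1-p) and the bmm scale)
 *   dP @ K  (accumulated in FP32)                   -> dQ
 *   dP^T @ Q                                        -> dK
 */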
void fused_attn_arbitrary_seqlen_bwd_impl(
int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
void* devPtrQ, void* devPtrKTranspose, void* devPtrVTranspose,
void* devPtrO, void* devPtrSoftmaxStats,
void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrdO,
void* devPtrDropoutSeed, void* devPtrDropoutOffset,
cudnnDataType_t tensorType, void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle) {
try {
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
FADescriptor descriptor{b, h,
s_q, s_kv,
d, scaling_factor,
true, dropout_probability,
layout, NVTE_Bias_Type::NVTE_NO_BIAS,
NVTE_Mask_Type::NVTE_CAUSAL_MASK, tensorType};
using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
static thread_local CacheType fmha_bprop_cache;
auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) {
auto it = cache.find(descriptor);
if (it != cache.end()) {
return it->second;
}
std::vector<cudnn_frontend::Operation const*> all_ops;
std::vector<cudnn_frontend::Operation> ops;
// Creates the necessary tensor descriptors
int64_t q_dim[4] = {b, h, s_q, d};
int64_t q_stride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, q_stride,
layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
int64_t k_transpose_dim[4] = {b, h, d, s_kv};
int64_t k_transpose_stride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, k_transpose_stride,
layout, NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose);
int64_t v_transpose_dim[4] = {b, h, d, s_kv};
int64_t v_transpose_stride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, v_transpose_stride,
layout, NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose);
int64_t p_dim[4] = {b, h, s_q, s_kv};
int64_t p_stride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, p_stride,
layout, NVTE_QKV_Matrix::NVTE_S_Matrix);
int64_t p_transpose_dim[4] = {b, h, s_kv, s_q};
int64_t p_transpose_stride[4];
p_transpose_stride[0] = p_stride[0];
p_transpose_stride[1] = p_stride[1];
p_transpose_stride[2] = p_stride[3];
p_transpose_stride[3] = p_stride[2];
int64_t o_dim[4] = {b, h, s_q, d};
int64_t o_stride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, o_stride,
layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1};
/*******************************************************************************
* Dot product dO * O */
// output and gradient of the output
auto oTensor = tensor_create(tensorType, O_ID, o_dim, o_stride, false, false);
auto dOTensor = tensor_create(tensorType, dO_ID, o_dim, o_stride, false, false);
auto dotProductTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID, o_dim,
o_stride, true, false); // is virtual
// Create pointwise mul
auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// do * O
auto dotProductOp = binary_pw_op_create(
dOTensor, oTensor, dotProductTensor, multiplyDesc);
ops.push_back(std::move(dotProductOp));
/*******************************************************************************
* Reduction(dO * O) */
int64_t reduction_dim[4] = {b, h, s_q, 1};
int64_t reduction_stride[4] = {h * s_q, s_q, 1, 1};
// reduction(dO * O)
auto afterReductionTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 1, reduction_dim,
reduction_stride, true, false); // is virtual
auto reductionMaxDesc = cudnn_frontend::ReductionDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_MAX)
.build();
// Create a reduction max node
auto reductionMax_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(dotProductTensor)
.setyDesc(afterReductionTensor)
.setreductionDesc(reductionMaxDesc)
.build();
ops.push_back(std::move(reductionMax_op));
/*******************************************************************************
* reduction(dO * O) * scale prob -> softmaxSum */
auto softmaxSumTensor = tensor_create(
CUDNN_DATA_FLOAT, S_SUM_ID, reduction_dim,
reduction_stride, false, false); // not virtual
auto scaleProbTensor = tensor_create(
CUDNN_DATA_FLOAT, SCALE_PROB, scale_dim,
scale_stride, false, true); // not virtual
auto softmaxSumOp = binary_pw_op_create(
afterReductionTensor, scaleProbTensor,
softmaxSumTensor, multiplyDesc);
ops.push_back(std::move(softmaxSumOp));
/*******************************************************************************
* Q @ K.T -> P */
// Inputs from fprop
auto qTensor = tensor_create(tensorType, Q_ID, q_dim, q_stride, false, false);
auto kTransposeTensor = tensor_create(
tensorType, K_ID, k_transpose_dim,
k_transpose_stride, false, false);
auto pTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 2, p_dim,
p_stride, true, false); // is virtual
// matmul to calculate dvTensor
auto matmul_0_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.build();
auto matmul_op0 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(qTensor)
.setbMatDesc(kTransposeTensor)
.setcMatDesc(pTensor)
.setmatmulDesc(matmul_0_Desc)
.build();
ops.push_back(std::move(matmul_op0));
/*******************************************************************************
* P * bmmScale -> pAfterScale */
auto bmmScaleTensor = tensor_create(
CUDNN_DATA_FLOAT, S_CONST_ID, scale_dim,
scale_stride, false, true); // not virtual and by value
auto pAfterScaleTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 2000, p_dim,
p_stride, true, false); // virtual
auto scaleOp = binary_pw_op_create(
pTensor, bmmScaleTensor, pAfterScaleTensor, multiplyDesc);
ops.push_back(std::move(scaleOp));
/*******************************************************************************
* Causal masking -> pAfterMaskTensor */
auto pAfterMaskTensor = createCausalMask(
b, h, s_q, s_kv, d, layout, tensorType, &ops, pAfterScaleTensor);
/*******************************************************************************
* pAfterMaskTensor - softmaxStats -> pAfterSubtract */
auto pAfterSubtractTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 3, p_dim,
p_stride, true, false); // is virtual
auto softmaxStatsTensor = tensor_create(
CUDNN_DATA_FLOAT, S_STATS_ID, reduction_dim,
reduction_stride, false, false); // not virtual
auto subtractDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB);
auto subtract_op = binary_pw_op_create(
pAfterMaskTensor, softmaxStatsTensor,
pAfterSubtractTensor, subtractDesc);
ops.push_back(std::move(subtract_op));
/*******************************************************************************
* e^(pAfterSubtract) -> pAfterSoftmax */
auto pAfterSoftmaxTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 4, p_dim,
p_stride, true, false); // is virtual
auto expDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP);
auto exp_op = unary_pw_op_create(
pAfterSubtractTensor, pAfterSoftmaxTensor, expDesc);
ops.push_back(std::move(exp_op));
/*******************************************************************************
* Dropout -> afterScaleDropout */
auto dropoutMaskTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 5, p_dim,
p_stride, true, false); // is virtual
auto afterScaleDropoutTensor = createDropoutBackward(
b, h, s_q, s_kv, d, dropout_probability, tensorType,
&ops, pAfterSoftmaxTensor, dropoutMaskTensor);
/*******************************************************************************
* afterScaleDropout -> sTransposeTensor */
auto sTransposeTensor = tensor_create(
tensorType, VIRTUAL_ID + 6, p_transpose_dim,
p_transpose_stride, true, false); // is virtual
auto reshape_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR)
.setxDesc(afterScaleDropoutTensor)
.setyDesc(sTransposeTensor)
.build();
ops.push_back(std::move(reshape_op));
// Outputs of bprop
int64_t dqkv_dim[4] = {b, h, s_kv, d};
int64_t dqkv_stride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, dqkv_stride,
layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
// Outputs of backprop
auto dQTensor = tensor_create(tensorType, dQ_ID, dqkv_dim, dqkv_stride, false, false);
auto dKTensor = tensor_create(tensorType, dK_ID, dqkv_dim, dqkv_stride, false, false);
auto dVTensor = tensor_create(tensorType, dV_ID, dqkv_dim, dqkv_stride, false, false);
// not virtual
/*******************************************************************************
* sTransposeTensor @ dO -> dV */
auto matmul_1_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.build();
auto matmul_op1 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(sTransposeTensor)
.setbMatDesc(dOTensor)
.setcMatDesc(dVTensor)
.setmatmulDesc(matmul_1_Desc)
.build();
ops.push_back(std::move(matmul_op1));
/*******************************************************************************
* dO @ V.T -> dS */
auto vTransposeTensor = tensor_create(
tensorType, V_ID, v_transpose_dim,
v_transpose_stride, false, false);
auto dSTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 7, p_dim,
p_stride, true, false); // is virtual
auto matmul_2_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.build();
auto matmul_op2 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(dOTensor)
.setbMatDesc(vTransposeTensor)
.setcMatDesc(dSTensor)
.setmatmulDesc(matmul_2_Desc)
.build();
ops.push_back(std::move(matmul_op2));
/*******************************************************************************
* dS * dropoutMask -> dSAfterDropout */
auto dSAfterDropoutTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 8, p_dim,
p_stride, true, false); // is virtual
auto multiply_op = binary_pw_op_create(
dSTensor, dropoutMaskTensor,
dSAfterDropoutTensor, multiplyDesc);
ops.push_back(std::move(multiply_op));
/*******************************************************************************
* dSAfterDropout - softmaxSum -> dsAfterSubtract */
auto dsAfterSubtractTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 9, p_dim,
p_stride, true, false); // is virtual
auto subtract_op2 = binary_pw_op_create(
dSAfterDropoutTensor, softmaxSumTensor,
dsAfterSubtractTensor, subtractDesc);
ops.push_back(std::move(subtract_op2));
/*******************************************************************************
* dsAfterSubtract * afterSoftmax -> dP */
auto dPTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 10, p_dim,
p_stride, true, false); // is virtual
auto multiply_op2 = binary_pw_op_create(
dsAfterSubtractTensor, pAfterSoftmaxTensor,
dPTensor, multiplyDesc);
ops.push_back(std::move(multiply_op2));
/*******************************************************************************
* dP * scaleDropout -> dPAfterDropoutScale */
auto dPAfterDropoutScaleTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 11, p_dim,
p_stride, true, false); // is virtual
auto scaleDropoutTensor = tensor_create(
CUDNN_DATA_FLOAT, D_CONST_ID, scale_dim,
scale_stride, false, true); // is by value
auto multiply_op3 = binary_pw_op_create(
dPTensor, scaleDropoutTensor,
dPAfterDropoutScaleTensor, multiplyDesc);
ops.push_back(std::move(multiply_op3));
/*******************************************************************************
* dPAfterDropoutScale * bmmScale -> dPScaledTensor */
auto dPScaledTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 12, p_dim,
p_stride, true, false); // is virtual
auto multiply_op4 = binary_pw_op_create(
dPAfterDropoutScaleTensor, bmmScaleTensor,
dPScaledTensor, multiplyDesc);
ops.push_back(std::move(multiply_op4));
/*******************************************************************************
* K.T -> K */
int64_t kDim[4] = {b, h, s_kv, d};
int64_t kStride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, kStride,
layout, NVTE_QKV_Matrix::NVTE_K_Matrix);
auto kTensor = tensor_create(
tensorType, VIRTUAL_ID + 13, kDim,
kStride, true, false); // is virtual
auto reshape_op2 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR)
.setxDesc(kTransposeTensor)
.setyDesc(kTensor)
.build();
ops.push_back(std::move(reshape_op2));
/*******************************************************************************
* dP @ K -> dqAccumTensor */
auto dqAccumTensor = cudnn_frontend::TensorBuilder()
.setDim(4, dqkv_dim)
.setStride(4, dqkv_stride)
.setId(dQ_ACCUM_ID)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual(false)
.setByValue(false)
.setReorderType(
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16)
.build();
auto matmul_3_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.build();
auto matmul_op3 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(dPTensor)
.setbMatDesc(kTensor)
.setcMatDesc(dqAccumTensor)
.setmatmulDesc(matmul_3_Desc)
.build();
ops.push_back(std::move(matmul_op3));
/*******************************************************************************
* dP.T @ Q -> dK */
auto dPTransposeTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 14, p_transpose_dim,
p_transpose_stride, true, false); // is virtual
auto reshape_op3 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR)
.setxDesc(dPTensor)
.setyDesc(dPTransposeTensor)
.build();
ops.push_back(std::move(reshape_op3));
auto matmul_4_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.build();
auto matmul_op4 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(dPTransposeTensor)
.setbMatDesc(qTensor)
.setcMatDesc(dKTensor)
.setmatmulDesc(matmul_4_Desc)
.build();
ops.push_back(std::move(matmul_op4));
/*******************************************************************************
* dqAccumTensor @ identity -> dqTensor */
auto identityDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_IDENTITY);
auto identity_op = unary_pw_op_create(dqAccumTensor, dQTensor, identityDesc);
ops.push_back(std::move(identity_op));
for (unsigned int i = 0; i < ops.size(); i++) {
all_ops.push_back(&ops[i]);
}
// Create an Operation Graph
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle)
.setOperationGraph(all_ops.size(), all_ops.data())
.build();
cudnn_frontend::EngineConfigList filtered_configs;
auto statuses = cudnn_frontend::get_heuristics_list<1>(
{"heuristics_instant"}, opGraph, allowAllConfig, filtered_configs, true);
if (filtered_configs.size() == 0) {
cudnn_frontend::set_error_and_throw_exception(
nullptr, CUDNN_STATUS_NOT_SUPPORTED,
"run_mha_bprop: No config returned by the heuristics");
}
auto plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(filtered_configs[0], opGraph.getTag())
.build();
cache.insert({descriptor, plan});
return plan;
};
auto plan = get_plan(fmha_bprop_cache, descriptor);
auto plan_workspace_size = plan.getWorkspaceSize();
// If no workspace was provided, report the required total size and return so the caller can allocate it
size_t softmaxSum_workspace_size = b * h * s_q * sizeof(float);
size_t dqAccum_workspace_size = b * s_q * h * d * sizeof(float);
if (workspace == nullptr) {
*workspace_size = plan_workspace_size + softmaxSum_workspace_size
+ dqAccum_workspace_size;
return;
}
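// Workspace layout: [ cuDNN plan workspace | softmaxSum (b*h*s_q floats) | dQ accumulator (b*s_q*h*d floats) ].
// The FP32 dQ accumulator is zero-initialized below before the plan is executed.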
void *devPtrSoftmaxSum = static_cast<int8_t *>(workspace) + plan_workspace_size;
void *devPtrdQAccumulator = static_cast<int8_t *>(devPtrSoftmaxSum)
+ softmaxSum_workspace_size;
NVTE_CHECK_CUDA(cudaMemset(devPtrdQAccumulator, 0, dqAccum_workspace_size));
std::set<std::pair<uint64_t, void *>> data_ptrs;
// add all the data pointers to be used in the variant pack
float negInfinity = -1.0E+10f;
float scale_dropout = 1.0f/(1.0f - dropout_probability);
data_ptrs.insert(std::pair<uint64_t, void*>(dQ_ID, devPtrdQ));
data_ptrs.insert(std::pair<uint64_t, void*>(dQ_ACCUM_ID, devPtrdQAccumulator));
data_ptrs.insert(std::pair<uint64_t, void*>(dK_ID, devPtrdK));
data_ptrs.insert(std::pair<uint64_t, void*>(dV_ID, devPtrdV));
data_ptrs.insert(std::pair<uint64_t, void*>(Q_ID, devPtrQ));
data_ptrs.insert(std::pair<uint64_t, void*>(K_ID, devPtrKTranspose));
data_ptrs.insert(std::pair<uint64_t, void*>(V_ID, devPtrVTranspose));
data_ptrs.insert(std::pair<uint64_t, void*>(O_ID, devPtrO));
data_ptrs.insert(std::pair<uint64_t, void*>(dO_ID, devPtrdO));
data_ptrs.insert(std::pair<uint64_t, void*>(S_STATS_ID, devPtrSoftmaxStats));
data_ptrs.insert(std::pair<uint64_t, void*>(S_SUM_ID, devPtrSoftmaxSum));
data_ptrs.insert(std::pair<uint64_t, void*>(D_SEED_ID, devPtrDropoutSeed));
data_ptrs.insert(std::pair<uint64_t, void*>(D_OFFSET_ID, devPtrDropoutOffset));
data_ptrs.insert(std::pair<uint64_t, void*>(MASK_VAL_ID, &negInfinity));
float scaleProb = 1.0f - dropout_probability;
data_ptrs.insert(std::pair<uint64_t, void*>(D_CONST_ID, &scale_dropout));
data_ptrs.insert(std::pair<uint64_t, void*>(S_CONST_ID, &scaling_factor));
data_ptrs.insert(std::pair<uint64_t, void*>(SCALE_PROB, &scaleProb));
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace)
.setDataPointers(data_ptrs)
.build();
NVTE_CHECK_CUDNN(
cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc()));
} catch (cudnn_frontend::cudnnException &e) {
NVTE_ERROR(e.what());
}
}
} // namespace fused_attn
using namespace transformer_engine::fused_attn;
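// The *_qkvpacked wrappers below follow a two-phase protocol: when called with an
// unallocated workspace (dptr == nullptr), they only report the required workspace
// shape/dtype and return; the caller then allocates the workspace and invokes the
// function again to run the actual cuDNN execution plan.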
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t max_seqlen, size_t num_head, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED,
"qkv_layout must be NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED.");
// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;
const auto stride = num_head * head_dim;
void *devPtrQ = static_cast<void *>(devPtrQKV);
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
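// Two-phase aux-tensor handling: on the size-query call (size == 0) only the shapes and
// dtypes of the softmax stats and rng_state tensors are registered so the caller can
// allocate them; on the execution call (size == 2) the caller-provided pointers are used.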
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 2;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, max_seqlen, 1};
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
} else if (Aux_CTX_Tensors->size == 2) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
}
void* devPtrDropoutSeed = rng_state->data.dptr;
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
const DType QKV_type = input_QKV->data.dtype;
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_fwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim,
is_training, attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
}
}
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_O,
const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED,
"qkv_layout must be NVTE_QKV_INTERLEAVED.");
// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;
auto stride = num_head * head_dim;
void *devPtrQ = devPtrQKV;
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
void* devPtrO = input_O->data.dptr;
void *devPtrdO = input_dO->data.dptr;
// dQKV shape is [b, s, 3, h, d]
void *devPtrdQKV = output_dQKV->data.dptr;
void *devPtrdQ = devPtrdQKV;
void *devPtrdK = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + stride);
void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + 2 * stride);
void *devPtrSoftmaxStats = output_S->data.dptr;
void* devPtrDropoutSeed = rng_state->data.dptr;
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
const auto qkv_type = input_QKV->data.dtype;
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_bwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim,
attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
devPtrdQ, devPtrdK, devPtrdV, devPtrdO,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(qkv_type),
workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
}
}
} // namespace transformer_engine
#endif // CUDNN_VERSION >= 8900
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file fused_attn_arbitrary_seqlen.h
* \brief Functions for fused attention with seqlen > 512
*/
#ifndef TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_
#define TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_
#include "transformer_engine/fused_attn.h"
#include <cudnn.h>
#include "common/common.h"
namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
size_t head_size, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV,
const Tensor *input_O,
const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
#endif // CUDNN_VERSION >= 8900
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_
......@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "fused_attn_fp16_bf16_max_seqlen_512.h"
#include "fused_attn_f16_max512_seqlen.h"
#include <cuda_bf16.h>
#include <cuda_fp16.h>
......@@ -1239,7 +1239,7 @@ void fused_attn_max_512_fwd_qkvpacked(
size_t batch, size_t max_seqlen, size_t num_head, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_Output_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
......@@ -1260,14 +1260,14 @@ void fused_attn_max_512_fwd_qkvpacked(
void *devPtrS = nullptr;
if (Aux_Output_Tensors->size == 0) {
Aux_Output_Tensors->size = 1;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_Output_Tensors->tensors[0]);
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 1;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, max_seqlen, max_seqlen};
output_S->data.dtype = input_QKV->data.dtype;
} else if (Aux_Output_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_Output_Tensors->tensors[0]);
} else if (Aux_CTX_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
}
......@@ -1307,7 +1307,7 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_Output_Tensors, const Tensor *q_cu_seqlens,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens,
const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
......@@ -1336,14 +1336,14 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
const DType kv_type = input_KV->data.dtype;
NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV.");
if (Aux_Output_Tensors->size == 0) {
Aux_Output_Tensors->size = 1;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_Output_Tensors->tensors[0]);
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 1;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen};
output_S->data.dtype = q_type;
} else if (Aux_Output_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_Output_Tensors->tensors[0]);
} else if (Aux_CTX_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
}
......@@ -1381,7 +1381,7 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV,
const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors,
const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
......@@ -1408,12 +1408,8 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
void *devPtrdBias = output_dBias->data.dptr;
NVTE_CHECK(Aux_CTX_Tensors->size == 1);
void *devPtrS = nullptr;
if (Aux_CTX_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
}
void *devPtrS = output_S->data.dptr;
// devPtrdS reuses the memory of devPtrS
void *devPtrdS = devPtrS;
......@@ -1446,7 +1442,7 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors,
const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
......@@ -1472,12 +1468,8 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
void *devPtrdBias = output_dBias->data.dptr;
NVTE_CHECK(Aux_CTX_Tensors->size == 1);
void *devPtrS = nullptr;
if (Aux_CTX_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
}
void *devPtrS = output_S->data.dptr;
// devPtrdS reuses the memory of devPtrS
void *devPtrdS = devPtrS;
......
......@@ -24,7 +24,7 @@ void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_Output_Tensors,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
......@@ -34,7 +34,7 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_Output_Tensors, const Tensor *q_cu_seqlens,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens,
const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
......@@ -42,7 +42,7 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV,
const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors,
const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
......@@ -52,7 +52,7 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors,
const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
......
......@@ -991,7 +991,7 @@ static cudnn_frontend::Tensor createdSQBMM(
}
// fused attention FWD FP8
void fa_fwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
void fused_attn_fp8_fwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
bool isTraining, float attnScale,
float dropoutProbability, NVTE_QKV_Layout layout,
void* devPtrQ, void* devPtrK, void* devPtrV,
......@@ -1303,7 +1303,7 @@ void fa_fwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
}
// fused attention BWD FP8
void fa_bwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
void fused_attn_fp8_bwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
float attnScale, float dropoutProbability, NVTE_QKV_Layout layout,
void* devPtrQ, void* devPtrK, void* devPtrV,
void* devPtrM, void* devPtrZInv,
......@@ -1858,7 +1858,7 @@ void fa_bwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
#if (CUDNN_VERSION >= 8900)
// fused attention FWD FP8 with packed QKV
void fused_attn_fwd_fp8_qkvpacked(
void fused_attn_fp8_fwd_qkvpacked(
size_t b, size_t max_seqlen,
size_t h, size_t d,
bool is_training, float attn_scale,
......@@ -1866,7 +1866,7 @@ void fused_attn_fwd_fp8_qkvpacked(
const Tensor *input_QKV,
Tensor *input_output_S,
Tensor *output_O,
NVTETensorPack* Aux_Output_Tensors,
NVTETensorPack* Aux_CTX_Tensors,
const Tensor *cu_seqlens,
const Tensor *rng_state,
Tensor *workspace,
......@@ -1888,23 +1888,29 @@ void fused_attn_fwd_fp8_qkvpacked(
void* devPtrM = nullptr;
void* devPtrZInv = nullptr;
if (Aux_Output_Tensors->size == 0) {
if (Aux_CTX_Tensors->size == 0) {
if (is_training) {
Aux_Output_Tensors->size = 2;
Tensor *output_M = reinterpret_cast<Tensor*>(Aux_Output_Tensors->tensors[0]);
Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_Output_Tensors->tensors[1]);
Aux_CTX_Tensors->size = 3;
Tensor *output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
output_M->data.dptr = nullptr;
output_M->data.shape = {b, h, max_seqlen, 1};
output_M->data.dtype = DType::kFloat32;
output_ZInv->data.dptr = nullptr;
output_ZInv->data.shape = {b, h, max_seqlen, 1};
output_ZInv->data.dtype = DType::kFloat32;
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
}
} else if (Aux_Output_Tensors->size == 2) {
Tensor *output_M = reinterpret_cast<Tensor*>(Aux_Output_Tensors->tensors[0]);
Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_Output_Tensors->tensors[1]);
} else if (Aux_CTX_Tensors->size == 3) {
Tensor *output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
devPtrM = output_M->data.dptr;
devPtrZInv = output_ZInv->data.dptr;
output_rng_state->data.dptr = rng_state->data.dptr;
}
void* devPtrAmaxS = input_output_S->amax.dptr;
......@@ -1921,7 +1927,7 @@ void fused_attn_fwd_fp8_qkvpacked(
const DType QKV_type = input_QKV->data.dtype;
size_t workspace_size = 0;
fused_attn::fa_fwd_fp8(
fused_attn::fused_attn_fp8_fwd_impl(
b, max_seqlen, max_seqlen, h, d,
is_training, attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
......@@ -1948,7 +1954,7 @@ void fused_attn_fwd_fp8_qkvpacked(
}
}
// fused attention BWD FP8 with packed QKV
void fused_attn_bwd_fp8_qkvpacked(
void fused_attn_fp8_bwd_qkvpacked(
size_t b, size_t max_seqlen,
size_t h, size_t d,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
......@@ -2011,7 +2017,7 @@ void fused_attn_bwd_fp8_qkvpacked(
const DType QKV_type = input_QKV->data.dtype;
size_t workspace_size = 0;
fused_attn::fa_bwd_fp8(
fused_attn::fused_attn_fp8_bwd_impl(
b, max_seqlen, max_seqlen, h, d,
attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
......
......@@ -13,7 +13,7 @@
namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
// fused attention FWD FP8 with packed QKV
void fused_attn_fwd_fp8_qkvpacked(
void fused_attn_fp8_fwd_qkvpacked(
size_t b, size_t max_seqlen,
size_t h, size_t d,
bool is_training, float attn_scale,
......@@ -21,7 +21,7 @@ void fused_attn_fwd_fp8_qkvpacked(
const Tensor *input_QKV,
Tensor *input_output_S,
Tensor *output_O,
NVTETensorPack* Aux_Output_Tensors,
NVTETensorPack* Aux_CTX_Tensors,
const Tensor *cu_seqlens,
const Tensor *rng_state,
Tensor *workspace,
......@@ -29,7 +29,7 @@ void fused_attn_fwd_fp8_qkvpacked(
cudnnHandle_t handle);
// fused attention BWD FP8 with packed QKV
void fused_attn_bwd_fp8_qkvpacked(
void fused_attn_fp8_bwd_qkvpacked(
size_t b, size_t max_seqlen,
size_t h, size_t d,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
......
......@@ -249,7 +249,6 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b,
kv_seqlens[tid] = kv_cu_seqlens[tid + 1] - kv_cu_seqlens[tid];
}
}
} // namespace fused_attn
// get cuDNN data type
......
......@@ -94,6 +94,38 @@ enum NVTE_Mask_Type {
NVTE_CAUSAL_MASK = 2,
};
enum NVTE_Fused_Attn_Backend {
/*! No supported backend */
NVTE_No_Backend = -1,
/*! cuDNN-based FP16/BF16 fused attention for <= 512 sequence length */
NVTE_F16_max512_seqlen = 0,
/*! cuDNN-based FP16/BF16 fused attention for any sequence length */
NVTE_F16_arbitrary_seqlen = 1,
/*! cuDNN-based FP8 fused attention for <= 512 sequence length */
NVTE_FP8 = 2,
};
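/* The numeric values above correspond to the "backend" column in the support matrices below. */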
/*! \brief Get fused attention backend based on input parameters.
*
* \param[in] q_dtype The data type of Tensor Q.
* \param[in] kv_dtype The data type of Tensors K, V.
* \param[in] qkv_layout The layout of Tensors Q, K, V.
* \param[in] bias_type The attention bias type.
* \param[in] attn_mask_type The attention mask type.
* \param[in] dropout The dropout probability.
* \param[in] max_seqlen_q The sequence length of Q.
* \param[in] max_seqlen_kv The sequence length of K, V.
* \param[in] head_dim The head dimension of Q, K, V.
*/
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTEDType q_dtype,
NVTEDType kv_dtype,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
float dropout, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim);
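/*
 * Usage sketch (illustrative only; the NVTEDType value names are assumed here):
 *
 *   NVTE_Fused_Attn_Backend backend = nvte_get_fused_attn_backend(
 *       kNVTEFloat16, kNVTEFloat16,            // Q and KV dtypes (assumed enum names)
 *       NVTE_QKV_INTERLEAVED, NVTE_NO_BIAS,    // layout and bias type
 *       NVTE_CAUSAL_MASK,                      // mask type
 *       0.1f,                                  // dropout probability
 *       2048, 2048, 64);                       // max_seqlen_q, max_seqlen_kv, head_dim
 *   // For these inputs, backend 1 (NVTE_F16_arbitrary_seqlen) would be expected per the
 *   // support matrices documented below, assuming the hardware supports it.
 */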
/*! \brief Compute dot product attention with packed QKV input.
*
* Computes:
......@@ -104,9 +136,10 @@ enum NVTE_Mask_Type {
*
* Support Matrix:
\verbatim
| precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | Yes | <= 512 | 64 |
| FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | CAUSAL | Yes | > 512 | 64, 128 |
| 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] QKV The QKV tensor in packed format,
......@@ -114,11 +147,12 @@ enum NVTE_Mask_Type {
* \param[in] Bias The Bias tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_Output_Tensors Auxiliary output tensors when training, e.g. M, ZInv.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
* e.g. M, ZInv, rng_state.
* \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] max_seqlen Max sequence length used for computing.
* It may be >= max(cu_seqlens).
* \param[in] max_seqlen Max sequence length used for computing,
* it may be >= max(cu_seqlens).
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
......@@ -133,7 +167,7 @@ void nvte_fused_attn_fwd_qkvpacked(
const NVTETensor Bias,
NVTETensor S,
NVTETensor O,
NVTETensorPack* Aux_Output_Tensors,
NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens,
const NVTETensor rng_state,
size_t max_seqlen,
......@@ -147,9 +181,10 @@ void nvte_fused_attn_fwd_qkvpacked(
*
* Support Matrix:
\verbatim
| precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | Yes | <= 512 | 64 |
| FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | CAUSAL | Yes | > 512 | 64, 128 |
| 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] QKV The QKV tensor in packed format,
......@@ -158,12 +193,13 @@ void nvte_fused_attn_fwd_qkvpacked(
* \param[in] dO The gradient of the O tensor.
* \param[in] S The S tensor.
* \param[in,out] dP The gradient of the P tensor.
* \param[in] Aux_CTX_Tensors Auxiliary tensors from forward when in training mode.
* \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode,
* e.g. M, ZInv, rng_state.
* \param[out] dQKV The gradient of the QKV tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1].
* \param[in] max_seqlen Max sequence length used for computing.
* It may be >= max(cu_seqlens).
* \param[in] max_seqlen Max sequence length used for computing,
* it may be >= max(cu_seqlens).
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
......@@ -199,8 +235,8 @@ void nvte_fused_attn_bwd_qkvpacked(
*
* Support Matrix:
\verbatim
| precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
......@@ -208,14 +244,15 @@ void nvte_fused_attn_bwd_qkvpacked(
* \param[in] Bias The Bias tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_Output_Tensors Auxiliary output tensors when training, e.g. M, ZInv.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
* e.g. M, ZInv, rng_state.
* \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] max_seqlen_q Max sequence length used for computing
* for Q. It may be >= max(cu_seqlens_q).
* \param[in] max_seqlen_kv Max sequence length used for computing
* for KV. It may be >= max(cu_seqlens_kv).
* \param[in] max_seqlen_q Max sequence length used for computing for Q.
* it may be >= max(cu_seqlens_q).
* \param[in] max_seqlen_kv Max sequence length used for computing for KV.
* it may be >= max(cu_seqlens_kv).
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
......@@ -231,7 +268,7 @@ void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Bias,
NVTETensor S,
NVTETensor O,
NVTETensorPack* Aux_Output_Tensors,
NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
const NVTETensor rng_state,
......@@ -246,8 +283,8 @@ void nvte_fused_attn_fwd_kvpacked(
*
* Support Matrix:
\verbatim
| precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
......@@ -256,16 +293,17 @@ void nvte_fused_attn_fwd_kvpacked(
* \param[in] dO The gradient of the O tensor.
* \param[in] S The S tensor.
* \param[in,out] dP The gradient of the P tensor.
* \param[in] Aux_CTX_Tensors Auxiliary tensors from forward when in training mode.
* \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode,
* e.g. M, ZInv, rng_state.
* \param[out] dQ The gradient of the Q tensor.
* \param[out] dKV The gradient of the KV tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1].
* \param[in] max_seqlen_q Max sequence length used for computing
* for Q. It may be >= max(cu_seqlens_q).
* \param[in] max_seqlen_kv Max sequence length used for computing
* for KV. It may be >= max(cu_seqlens_kv).
* \param[in] max_seqlen_q Max sequence length used for computing for Q.
* it may be >= max(cu_seqlens_q).
* \param[in] max_seqlen_kv Max sequence length used for computing for KV.
* it may be >= max(cu_seqlens_kv).
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
......
......@@ -15,6 +15,16 @@ import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
import transformer_engine_extensions as tex
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
fused_attn_fwd_qkvpacked,
fused_attn_bwd_qkvpacked,
fused_attn_fwd_kvpacked,
fused_attn_bwd_kvpacked,
QKVLayout,
AttnBiasType,
AttnMaskType,
FusedAttnBackend,
)
from transformer_engine.pytorch.module import LayerNormLinear, Linear
from transformer_engine.pytorch.utils import (
divide,
......@@ -26,6 +36,7 @@ from transformer_engine.pytorch.constants import (
AttnMaskTypes,
AttnTypes,
dist_group_type,
TE_DType,
)
from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.distributed import (
......@@ -272,9 +283,9 @@ class _PrepareQKVForFA(torch.autograd.Function):
return dq, dk, dv
def _check_if_interleaved(q, k, v):
data_ptr = q.storage().data_ptr()
check_ptrs = all(x.storage().data_ptr() == data_ptr for x in [q, k, v])
def _check_if_interleaved_qkv(q, k, v):
data_ptr = q.untyped_storage().data_ptr()
check_ptrs = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
if not check_ptrs:
return False
......@@ -293,9 +304,32 @@ def _check_if_interleaved(q, k, v):
for i, x in enumerate([q, k, v]))
return check_offsets
def _check_if_interleaved_kv(k, v):
data_ptr = k.untyped_storage().data_ptr()
check_ptrs = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])
if not check_ptrs:
return False
stride = k.stride()
check_strides = all(stride == x.stride() for x in [k, v])
if not check_strides:
return False
shape = k.shape
check_shapes = all(shape == x.shape for x in [k, v])
if not check_shapes:
return False
last_dim_size = shape[-1]
check_offsets = all(i * last_dim_size == x.storage_offset()
for i, x in enumerate([k, v]))
return check_offsets
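# The two helpers above check whether q/k/v (or k/v) are views into a single interleaved
# storage: same underlying storage pointer, identical strides and shapes, and storage
# offsets of i * last_dim. The attention modules below use this to pick the packing path
# that matches how the tensors are already laid out in memory.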
class FlashAttention(torch.nn.Module):
"""Dot product attention implementation by using the flash-attn package.
"""Dot product attention, using HazyResearch flash-attn package:
https://github.com/HazyResearch/flash-attention
"""
def __init__(
......@@ -326,9 +360,9 @@ class FlashAttention(torch.nn.Module):
"""flash-attn fprop"""
assert (
(query_layer.dtype in [torch.float16, torch.bfloat16])
and (key_layer.dtype in [torch.float16, torch.bfloat16])
and (value_layer.dtype in [torch.float16, torch.bfloat16])
query_layer.dtype in [torch.float16, torch.bfloat16]
and key_layer.dtype in [torch.float16, torch.bfloat16]
and value_layer.dtype in [torch.float16, torch.bfloat16]
), 'FlashAttention currently only supports FP16 and BF16.'
assert (
query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
......@@ -338,7 +372,7 @@ class FlashAttention(torch.nn.Module):
if (query_layer.shape[-1] == 128 and
query_layer.shape[0] * query_layer.shape[1] >= 512 and
_check_if_interleaved(query_layer, key_layer, value_layer)):
_check_if_interleaved_qkv(query_layer, key_layer, value_layer)):
query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(query_layer,
key_layer,
value_layer)
......@@ -374,6 +408,286 @@ class FlashAttention(torch.nn.Module):
return output.view(batch_size, seqlen, -1).transpose(0, 1).contiguous()
class FusedAttnFunc_qkvpacked(torch.autograd.Function):
"""Function for FusedAttention with packed QKV input"""
@staticmethod
def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias, attn_scale,
dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen, fused_attention_backend):
out, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype,
fused_attention_backend, attn_bias,
None, None, None, None, None,
attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen)
ctx.save_for_backward(qkv, out, cu_seqlens)
ctx.aux_ctx_tensors = aux_ctx_tensors
ctx.max_seqlen = max_seqlen
ctx.qkv_dtype = qkv_dtype
ctx.attn_scale = attn_scale
ctx.dropout_p = dropout_p
ctx.fast_zero_fill = fast_zero_fill
ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type
ctx.fused_attention_backend = fused_attention_backend
return out
@staticmethod
def backward(ctx, d_out):
qkv, out, cu_seqlens = ctx.saved_tensors
dqkv, *rest = fused_attn_bwd_qkvpacked(
ctx.max_seqlen, cu_seqlens, qkv, out, d_out,
ctx.qkv_dtype, ctx.aux_ctx_tensors,
ctx.fused_attention_backend,
None, None, None, None, None, None, None, None, None,
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
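# backward() must return one gradient per forward() argument; the non-tensor
# arguments (flags, dtypes, layout strings, backend) get None.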
# if no_bias, return dqkv
if ctx.attn_bias_type == "no_bias":
return (None, None, None, dqkv, None, None, None,
None, None, None, None, None, None,
None, None, None, None, None, None)
# else, return (dqkv, dbias)
return (None, None, None, dqkv, None, rest[0], None,
None, None, None, None, None, None,
None, None, None, None, None, None)
class FusedAttnFunc_kvpacked(torch.autograd.Function):
"""Function for FusedAttention with packed KV input"""
@staticmethod
def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, kv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
qkv_layout, attn_bias_type, attn_mask_type,
rng_gen, fused_attention_backend):
out, aux_ctx_tensors = fused_attn_fwd_kvpacked(
is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, kv, qkv_dtype, fused_attention_backend, attn_bias,
None, None, None, None, None,
attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen)
ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv)
ctx.aux_ctx_tensors = aux_ctx_tensors
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_kv = max_seqlen_kv
ctx.qkv_dtype = qkv_dtype
ctx.attn_scale = attn_scale
ctx.dropout_p = dropout_p
ctx.fast_zero_fill = fast_zero_fill
ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type
ctx.fused_attention_backend = fused_attention_backend
return out
@staticmethod
def backward(ctx, d_out):
q, kv, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
dq, dkv, *rest = fused_attn_bwd_kvpacked(
ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, kv, out, d_out,
ctx.qkv_dtype, ctx.aux_ctx_tensors,
ctx.fused_attention_backend,
None, None, None, None, None, None, None, None, None,
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
# if no_bias, return (dq, dkv)
if ctx.attn_bias_type == "no_bias":
return (None, None, None, None, None, dq, dkv, None, None, None,
None, None, None, None, None, None,
None, None, None, None, None, None)
# else, return (dq, dkv, dbias)
return (None, None, None, None, None, dq, dkv, None, rest[0], None,
None, None, None, None, None, None,
None, None, None, None, None, None)
class FusedAttention(torch.nn.Module):
"""Dot product attention, with multiple backends:
1. FusedAttnBackend["F16_max512_seqlen"]
cuDNN based fused attention for FP16/BF16 and <=512 sequence length.
2. FusedAttnBackend["F16_arbitrary_seqlen"]
cuDNN based fused attention for FP16/BF16 and any sequence length.
Support matrix:
| backend | 1 | 2 |
| flash based | no | yes |
| cuDNN based | yes | yes |
| qkv dtype | fp16/bf16 | fp16/bf16 |
| attn_type | self/cross | self |
| qkv_layout | | |
| - qkv | qkv_interleaved | qkv_interleaved |
| - (q,kv) | kv_interleaved | |
| mask_type | causal/no_mask | causal |
| bias_type | no_bias/post_scale_bias | no_bias |
| dropout | yes | yes |
| max_seqlen | <=512 | any |
| head_dim | 64 | 64,128 |
| output dtype | fp16/bf16 | fp16/bf16 |
"""
def __init__(
self,
norm_factor: float,
attention_dropout: float = 0.0,
attention_dropout_ctx: Optional[Callable] = nullcontext,
attn_mask_type: str = "causal",
attention_type: str = "self",
) -> None:
super().__init__()
self.norm_factor = norm_factor
self.attention_dropout = attention_dropout
self.attention_dropout_ctx = attention_dropout_ctx
self.attn_mask_type = attn_mask_type
self.attention_type = attention_type
def forward(
self,
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True,
) -> torch.Tensor:
"""fused attention fprop"""
assert (
(query_layer.dtype in [torch.float16, torch.bfloat16])
and (key_layer.dtype in [torch.float16, torch.bfloat16])
and (value_layer.dtype in [torch.float16, torch.bfloat16])
), 'FusedAttention only supports FP16 and BF16 data types.'
assert (
query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
), 'FusedAttention only supports CUDA tensors.'
qkv_dtype = TE_DType[query_layer.dtype]
seqlen_q, batch_size = query_layer.shape[0], query_layer.shape[1]
seqlen_kv = key_layer.shape[0]
max_seqlen_q = seqlen_q
max_seqlen_kv = seqlen_kv
if self.attention_type == "self":
if _check_if_interleaved_qkv(query_layer, key_layer, value_layer):
query_layer = query_layer.unsqueeze(3)
key_layer = key_layer.unsqueeze(3)
value_layer = value_layer.unsqueeze(3)
# [s, b, h, 3, d]
mixed_layer = torch.cat([query_layer, key_layer, value_layer], dim = 3)
# [b, s, 3, h, d]
mixed_layer = mixed_layer.transpose(2, 3).transpose(0, 1).contiguous()
else:
query_layer = query_layer.unsqueeze(2)
key_layer = key_layer.unsqueeze(2)
value_layer = value_layer.unsqueeze(2)
# [s, b, 3, h, d]
mixed_layer = torch.cat([query_layer, key_layer, value_layer], dim = 2)
# [b, s, 3, h, d]
mixed_layer = mixed_layer.transpose(0, 1).contiguous()
# [total_seqs, 3, h, d]
mixed_layer = mixed_layer.view(
mixed_layer.shape[0] * mixed_layer.shape[1], *mixed_layer.shape[2:]).contiguous()
qkv_layout = "qkv_interleaved"
max_seqlen = seqlen_q
cu_seqlens = torch.arange(
0,
(batch_size + 1) * seqlen_q,
step=seqlen_q,
dtype=torch.int32,
device=query_layer.device)
with self.attention_dropout_ctx():
output = FusedAttnFunc_qkvpacked.apply(
self.training,
max_seqlen,
cu_seqlens,
mixed_layer,
qkv_dtype,
core_attention_bias,
1.0/self.norm_factor,
self.attention_dropout if self.training else 0.0,
fast_zero_fill,
qkv_layout,
core_attention_bias_type,
self.attn_mask_type,
None, # rng_gen
fused_attention_backend,
)
output = output.view(batch_size, seqlen_q, -1).transpose(0, 1).contiguous()
if self.attention_type == "cross":
if _check_if_interleaved_kv(key_layer, value_layer):
# [s, b, h, 2, d]
key_layer = key_layer.unsqueeze(3)
value_layer = value_layer.unsqueeze(3)
key_value = torch.cat([key_layer, value_layer], dim = 3)
# [b, s, 2, h, d]
key_value = key_value.transpose(2, 3).transpose(0, 1).contiguous()
else:
# [s, b, 2, h, d]
key_layer = key_layer.unsqueeze(2)
value_layer = value_layer.unsqueeze(2)
key_value = torch.cat([key_layer, value_layer], dim = 2)
# [b, s, 2, h, d]
key_value = key_value.transpose(0, 1).contiguous()
# [total_seqs, 2, h, d]
query_layer = query_layer.transpose(0, 1).contiguous()
query_layer = query_layer.view(
query_layer.shape[0] * query_layer.shape[1], *query_layer.shape[2:])
key_value = key_value.view(
key_value.shape[0] * key_value.shape[1], *key_value.shape[2:]).contiguous()
qkv_layout = "kv_interleaved"
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * seqlen_q,
step=seqlen_q,
dtype=torch.int32,
device=query_layer.device)
cu_seqlens_kv = torch.arange(
0,
(batch_size + 1) * seqlen_kv,
step=seqlen_kv,
dtype=torch.int32,
device=key_layer.device)
with self.attention_dropout_ctx():
outputs = FusedAttnFunc_kvpacked.apply(
self.training,
max_seqlen_q, max_seqlen_kv,
cu_seqlens_q, cu_seqlens_kv,
query_layer, key_value,
qkv_dtype,
core_attention_bias,
1.0/self.norm_factor,
self.attention_dropout if self.training else 0.0,
fast_zero_fill,
qkv_layout,
core_attention_bias_type,
self.attn_mask_type,
None, # rng_gen
fused_attention_backend,
)
output = (outputs[0].view(batch_size, seqlen_q, -1).transpose(0, 1).contiguous(),
outputs[1].view(batch_size, seqlen_q, -1).transpose(0, 1).contiguous())
return output
class DotProductAttention(torch.nn.Module):
"""Allows the model to jointly attend to information from different
representation subspaces as described in the paper:
......@@ -427,15 +741,16 @@ class DotProductAttention(torch.nn.Module):
get_rng_state_tracker: Optional[Callable] = None,
tp_group: Optional[dist_group_type] = None,
layer_number: Optional[int] = None,
attention_type: str = "self",
) -> None:
super().__init__()
tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
self.tp_group = tp_group
self.get_rng_state_tracker = get_rng_state_tracker
projection_size = kv_channels * num_attention_heads
self.hidden_size_per_partition = divide(projection_size, tp_size)
self.hidden_size_per_partition = divide(projection_size, self.tp_size)
self.hidden_size_per_attention_head = divide(
projection_size, num_attention_heads
)
......@@ -452,18 +767,28 @@ class DotProductAttention(torch.nn.Module):
int(os.getenv("NVTE_FLASH_ATTN", "1"))
and self.device_compute_capability >= 8.0
)
self.use_fused_attention = (
int(os.getenv("NVTE_FUSED_ATTN", "1"))
and self.device_compute_capability >= 8.0
)
attn_kwargs = {
"attention_dropout": attention_dropout,
"attention_dropout_ctx": attention_dropout_ctx,
"attn_mask_type": attn_mask_type,
}
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
self.attention_dropout = attention_dropout
if self.use_flash_attention:
self.flash_attention = FlashAttention(norm_factor, **attn_kwargs)
# Instantiating both types since use of flash-attn
# Instantiating three types since use of flash-attn and FusedAttention
# might be ruled out due to forward inputs.
if self.use_fused_attention:
self.fused_attention = FusedAttention(
norm_factor, **attn_kwargs,
attention_type = attention_type)
self.unfused_attention = UnfusedDotProductAttention(
norm_factor, **attn_kwargs, layer_number=layer_number)
......@@ -494,6 +819,9 @@ class DotProductAttention(torch.nn.Module):
value_layer: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
checkpoint_core_attention: bool = False,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True,
) -> torch.Tensor:
"""
Dot Product Attention Layer.
......@@ -511,6 +839,17 @@ class DotProductAttention(torch.nn.Module):
(:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`
* :attr:`kv_channels`) is returned.
.. note::
`DotProductAttention` supports three backends: 1) `FlashAttention` which calls
HazyResearch's FlashAttention PyTorch API, 2) `FusedAttention` which has multiple
fused attention implementations as its backends (see `FusedAttention` for
more details), and 3) `UnfusedDotProductAttention` which is the native PyTorch
implementation with fused scaled masked softmax. Users can use environment variables
`NVTE_FLASH_ATTN`, `NVTE_FUSED_ATTN`, and `NVTE_FUSED_ATTN_BACKEND` to control
which DotProductAttention backend, and FusedAttention backend if applicable, to use.
The default DotProductAttention backend is 1 (`FlashAttention`).
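A minimal sketch of backend selection (the environment variables are read when the
module is constructed; the constructor arguments and tensor shapes below are
illustrative assumptions, not a prescribed configuration):

    import os
    import torch
    os.environ["NVTE_FLASH_ATTN"] = "0"  # rule out FlashAttention
    os.environ["NVTE_FUSED_ATTN"] = "1"  # allow FusedAttention if the inputs qualify
    dpa = DotProductAttention(num_attention_heads=16, kv_channels=64)
    q = k = v = torch.randn(512, 2, 16, 64, dtype=torch.float16, device="cuda")
    out = dpa(q, k, v)  # [512, 2, 16 * 64]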
Parameters
----------
query_layer : torch.Tensor
......@@ -526,9 +865,17 @@ class DotProductAttention(torch.nn.Module):
during the backward pass in order to save memory that would
otherwise be occupied to store the forward activations until
backprop.
core_attention_bias_type: str, default = `no_bias`
Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`}
core_attention_bias: Optional[torch.Tensor], default = `None`
Bias tensor for Q * K.T
fast_zero_fill: bool, default = `True`
Whether to use the fast path to set output tensors to 0 or not.
"""
use_flash_attention = self.use_flash_attention
use_fused_attention = self.use_fused_attention
if (query_layer.dtype not in [torch.bfloat16, torch.float16]
or key_layer.dtype not in [torch.bfloat16, torch.float16]
or value_layer.dtype not in [torch.bfloat16, torch.float16]
......@@ -538,9 +885,26 @@ class DotProductAttention(torch.nn.Module):
if self.attn_mask_type == "padding" and attention_mask is not None:
use_flash_attention = False
use_fused_attention = False
if is_in_onnx_export_mode():
use_flash_attention = False
use_fused_attention = False
qkv_layout = "qkv_interleaved" if self.attention_type == "self" else "kv_interleaved"
fused_attention_backend = tex.get_fused_attn_backend(
TE_DType[query_layer.dtype],
TE_DType[key_layer.dtype],
QKVLayout[qkv_layout],
AttnBiasType[core_attention_bias_type],
AttnMaskType[self.attn_mask_type],
self.attention_dropout,
query_layer.shape[0], key_layer.shape[0],
query_layer.shape[-1])
# DPA does not support FP8; for FP8, use cpp_extensions modules directly
is_backend_avail = (fused_attention_backend in
[FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]])
use_fused_attention = use_fused_attention and is_backend_avail
if use_flash_attention:
if checkpoint_core_attention:
......@@ -550,6 +914,22 @@ class DotProductAttention(torch.nn.Module):
value_layer)
return self.flash_attention(query_layer, key_layer, value_layer)
if use_fused_attention:
if checkpoint_core_attention:
return self._checkpointed_attention_forward(self.fused_attention,
query_layer,
key_layer,
value_layer,
fused_attention_backend,
core_attention_bias_type,
core_attention_bias,
fast_zero_fill)
return self.fused_attention(query_layer, key_layer, value_layer,
fused_attention_backend,
core_attention_bias_type,
core_attention_bias,
fast_zero_fill)
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
self.unfused_attention,
......@@ -752,6 +1132,9 @@ class MultiHeadAttention(torch.nn.Module):
checkpoint_core_attention: bool = False,
inference_params: Optional[Any] = None,
rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""MultiHeadAttention FWD"""
# hidden_states: [sq, b, h]
......@@ -952,7 +1335,10 @@ class MultiHeadAttention(torch.nn.Module):
key_layer,
value_layer,
attention_mask,
checkpoint_core_attention=checkpoint_core_attention,
checkpoint_core_attention = checkpoint_core_attention,
core_attention_bias_type = core_attention_bias_type,
core_attention_bias = core_attention_bias,
fast_zero_fill = fast_zero_fill,
)
# =================
......
......@@ -22,7 +22,7 @@ TE_DType = {
torch.bfloat16: tex.DType.kBFloat16,
}
AttnMaskTypes = ("causal", "padding")
AttnMaskTypes = ("causal", "padding", "no_mask")
AttnTypes = ("self", "cross")
......
......@@ -7,6 +7,12 @@ import math
from typing import Tuple, List, Union
import torch
import transformer_engine_extensions as tex
from transformer_engine_extensions import (
NVTE_QKV_Layout,
NVTE_Bias_Type,
NVTE_Mask_Type,
NVTE_Fused_Attn_Backend
)
__all__ = ['fused_attn_fwd_qkvpacked',
......@@ -24,6 +30,34 @@ TORCH_DType = {
tex.DType.kInt32: torch.int32,
}
QKVLayout = {
"not_interleaved": NVTE_QKV_Layout.NVTE_NOT_INTERLEAVED,
"qkv_interleaved": NVTE_QKV_Layout.NVTE_QKV_INTERLEAVED,
"kv_interleaved": NVTE_QKV_Layout.NVTE_KV_INTERLEAVED,
}
AttnBiasType = {
"no_bias": NVTE_Bias_Type.NVTE_NO_BIAS,
"pre_scale_bias": NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS,
"post_scale_bias": NVTE_Bias_Type.NVTE_POST_SCALE_BIAS,
}
AttnMaskType = {
"no_mask": NVTE_Mask_Type.NVTE_NO_MASK,
"padding": NVTE_Mask_Type.NVTE_PADDING_MASK,
"causal": NVTE_Mask_Type.NVTE_CAUSAL_MASK,
}
FusedAttnBackend = {
"F16_max512_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen,
"F16_arbitrary_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
"FP8": NVTE_Fused_Attn_Backend.NVTE_FP8,
"No_Backend": NVTE_Fused_Attn_Backend.NVTE_No_Backend,
}
BACKEND_F16m512_FP8_THREADS_PER_CTA = 128
BACKEND_F16arb_ELTS_PER_THREADS = 16
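The maps above translate the Python-level string options into the NVTE enums expected by
the extension. A hedged sketch of a backend query (the dtype, dropout probability,
sequence lengths, and head_dim are arbitrary example values; the call signature follows
the `get_fused_attn_backend` binding added in this change):

    backend = tex.get_fused_attn_backend(
        tex.DType.kBFloat16, tex.DType.kBFloat16,
        QKVLayout["qkv_interleaved"],
        AttnBiasType["no_bias"],
        AttnMaskType["causal"],
        0.1,          # dropout probability
        2048, 2048,   # max_seqlen_q, max_seqlen_kv
        64)           # head_dim
    assert backend != FusedAttnBackend["No_Backend"], "no fused backend for this input combination"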
def check_tensor(x: torch.Tensor):
"""Check tensor properties."""
......@@ -109,7 +143,8 @@ def fused_attn_fwd_qkvpacked(
cu_seqlens: torch.Tensor,
qkv: torch.Tensor,
qkv_dtype: tex.DType,
bias: torch.Tensor = None,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
q_scale_o: torch.Tensor = None,
......@@ -117,9 +152,9 @@ def fused_attn_fwd_qkvpacked(
amax_o: torch.Tensor = None,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
fast_zero_fill: bool = True,
qkv_layout: str = "qkv_interleaved",
bias_type: str = "no_bias",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
rng_gen: torch.Generator = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
......@@ -139,8 +174,10 @@ def fused_attn_fwd_qkvpacked(
shape [total_seqs, 3, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
qkv_dtype: tex.DType
data type of QKV; in tex.DType, not torch.dtype
bias: torch.Tensor, default = None
input tensor Bias when bias_type is "pre_scale_bias" or "post_scale_bias";
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends.
attn_bias: torch.Tensor, default = None
input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
......@@ -158,12 +195,12 @@ def fused_attn_fwd_qkvpacked(
dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False
set_zero: bool, default = True
if True, initializes the output tensor O to zero using the mha_fill method;
if False, doesn't initialize O after its allocation
fast_zero_fill: bool, default = True
if True, initializes the output tensor O to zero using the fast filling method;
if False, uses PyTorch's .fill_() method
qkv_layout: str, default = "qkv_interleaved"
layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
bias_type: str, default = "no_bias"
attn_bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "no_mask"}
......@@ -178,15 +215,26 @@ def fused_attn_fwd_qkvpacked(
shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
aux_ctx_tensors: List[torch.Tensor]
auxiliary output tensors used for the backward;
if is_training is True, aux_ctx_tensors = [M, ZInv, rng_state]
if is_training is False, aux_ctx_tensors = [rng_state]
if is_training is True, aux_ctx_tensors = [softmax-related tensors, rng_state]
if is_training is False, aux_ctx_tensors = None
softmax-related tensors:
1. if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
softmax: torch.Tensor
Softmax(Q*K.T)
shape [batch_size, num_heads, max_seqlen, max_seqlen], dtype float32
2. if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
softmaxStats: torch.Tensor
log(sum(e^(x - max(x)))), where x=Q*K.T
shape [batch_size, num_heads, max_seqlen, 1], dtype float32
3. if fused_attention_backend == FusedAttnBackend["FP8"]
M: torch.Tensor
max(Q*K.T)
shape [batch_size, num_heads, max_seqlen, 1], dtype float32
ZInv: torch.Tensor
1/sum(e^(x - max(x))), where x=Q*K.T
shape [batch_size, num_heads, max_seqlen, 1], dtype float32
rng_state: torch.Tensor
rng_state: torch.Tensor, optional (present when the backend is not F16_max512_seqlen)
state of the random number generator;
[seed, offset], dtype uint64
"""
......@@ -203,60 +251,58 @@ def fused_attn_fwd_qkvpacked(
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
if bias_type != "no_bias":
assert bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
assert (bias.shape == [1, h, max_seqlen, max_seqlen]
), "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
assert (bias.dtype == qkv.dtype
), "bias tensor must be in the same dtype as qkv."
# FP8 fused attention API
if (qkv_type is torch.uint8) and (max_seqlen <= 512) and (d == 64):
assert (qkv_layout == "qkv_interleaved"
and bias_type == "no_bias"
and attn_mask_type == "padding"
), """The FP8 fused attention API currently only supports qkv_interleaved layout,
no_bias type, and padding attention mask type."""
assert (d_scale_qkv is not None), "d_scale_qkv is required for the FP8 API."
assert (q_scale_s is not None), "q_scale_s is required for the FP8 API."
assert (q_scale_o is not None), "q_scale_o is required for the FP8 API."
assert (amax_s is not None), "amax_s is required for the FP8 API."
assert (amax_o is not None), "amax_o is required for the FP8 API."
if attn_bias_type != "no_bias":
assert (attn_bias is not None
), "attn_bias tensor cannot be None when attn_bias_type is not no_bias."
assert (attn_bias.shape == torch.Size([1, h, max_seqlen, max_seqlen])
), "attn_bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
assert (attn_bias.dtype == qkv.dtype
), "attn_bias tensor must be in the same dtype as qkv."
assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
# BF16/FP16 fused attention API from fmha_v1 apex
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
rng_elts_per_thread = (max_seqlen * max_seqlen
+ BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA
# BF16/FP16 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
# FP8 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["FP8"]:
rng_elts_per_thread = (max_seqlen * max_seqlen
+ BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA
assert (d_scale_qkv is not None
), "d_scale_qkv is required as an input for FP8 fused attention."
assert (q_scale_s is not None
), "q_scale_s is required as an input for FP8 fused attention."
assert (q_scale_o is not None
), "q_scale_o is required as an input for FP8 fused attention."
assert (amax_s is not None
), "amax_s is required as an input for FP8 fused attention."
assert (amax_o is not None
), "amax_o is required as an input for FP8 fused attention."
check_scalar(d_scale_qkv)
check_scalar(q_scale_s)
check_scalar(q_scale_o)
check_scalar(amax_s)
check_scalar(amax_o)
# BF16/FP16 fused attention API from fmha_v2
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen > 512):
# add BF/FP16 support for >512 sequence length
assert False, "The BF16/FP16 support for >512 sequence length is coming!"
# BF16/FP16 fused attention API from fmha_v1 apex
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen <= 512):
# add BF/FP16 support for <=512 sequence length
assert False, "The BF16/FP16 support for <=512 sequence length is coming!"
else:
assert False, "No support for this dtype and max_seqlen combination."
# execute kernel
output_tensors = tex.fused_attn_fwd_qkvpacked(
b, max_seqlen, total_seqs, h, d,
is_training, attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type,
cu_seqlens,
qkv,
qkv_dtype,
d_scale_qkv,
q_scale_s,
q_scale_o,
amax_s,
amax_o,
bias,
rng_gen,
is_training, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens, qkv, qkv_dtype,
d_scale_qkv, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias,
rng_gen, rng_elts_per_thread,
)
# out, aux_ctx_tensors
return output_tensors[0], output_tensors[1:]
......@@ -267,7 +313,8 @@ def fused_attn_bwd_qkvpacked(
o: torch.Tensor,
d_o: torch.Tensor,
qkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor] = None,
aux_ctx_tensors: List[torch.Tensor],
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None,
......@@ -279,9 +326,9 @@ def fused_attn_bwd_qkvpacked(
amax_dqkv: torch.Tensor = None,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
fast_zero_fill: bool = True,
qkv_layout: str = "qkv_interleaved",
bias_type: str = "no_bias",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention BWD for packed QKV input.
......@@ -306,6 +353,8 @@ def fused_attn_bwd_qkvpacked(
aux_ctx_tensors: List[torch.Tensor]
auxiliary output tensors from the forward pass when is_training is True,
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends.
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None
......@@ -330,12 +379,12 @@ def fused_attn_bwd_qkvpacked(
dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False
set_zero: bool, default = True
if True, initializes the output tensor O to zero using the mha_fill method;
if False, doesn't initialize O after its allocation
fast_zero_fill: bool, default = True
if True, initializes the output tensor O to zero using the fast filling method;
if False, uses PyTorch's .fill_() method
qkv_layout: str, default = "qkv_interleaved"
layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
bias_type: str, default = "no_bias"
attn_bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "no_mask"}
......@@ -345,8 +394,8 @@ def fused_attn_bwd_qkvpacked(
d_qkv: torch.Tensor
gradient tensor of QKV; same data type and shape as QKV
d_bias: torch.Tensor, optional
gradient tensor of Bias when bias_type is "pre_scale_bias" or "post_scale_bias";
same data type and shape as Bias
gradient tensor of Bias when attn_bias_type is "pre_scale_bias"
or "post_scale_bias"; same data type and shape as Bias
"""
check_cu_seqlens(cu_seqlens)
......@@ -363,29 +412,27 @@ def fused_attn_bwd_qkvpacked(
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]:
assert (len(aux_ctx_tensors) >= 1
), "aux_ctx_tensors must contain rng_state as its last element."
rng_state = aux_ctx_tensors[-1]
check_rng_state(rng_state)
# FP8 fused attention API
if (qkv_type is torch.uint8) and (max_seqlen <= 512) and d == 64:
assert (qkv_layout == "qkv_interleaved"
and bias_type == "no_bias"
and attn_mask_type == "padding"
), """The FP8 fused attention API currently only supports qkv_interleaved layout,
no_bias type, and padding attention mask type."""
assert (d_scale_qkv is not None), "d_scale_qkv is required for the FP8 API."
assert (d_scale_s is not None), "d_scale_s is required for the FP8 API."
assert (d_scale_o is not None), "d_scale_o is required for the FP8 API."
assert (d_scale_do is not None), "d_scale_do is required for the FP8 API."
assert (q_scale_s is not None), "q_scale_s is required for the FP8 API."
assert (q_scale_dp is not None), "q_scale_dp is required for the FP8 API."
assert (q_scale_dqkv is not None), "q_scale_dqkv is required for the FP8 API."
assert (amax_dp is not None), "amax_dp is required for the FP8 API."
assert (amax_dqkv is not None), "amax_dqkv is required for the FP8 API."
if fused_attention_backend == FusedAttnBackend["FP8"]:
assert (d_scale_qkv is not None), "d_scale_qkv is required for FP8 fused attention."
assert (d_scale_s is not None), "d_scale_s is required for FP8 fused attention."
assert (d_scale_o is not None), "d_scale_o is required for FP8 fused attention."
assert (d_scale_do is not None), "d_scale_do is required for FP8 fused attention."
assert (q_scale_s is not None), "q_scale_s is required for FP8 fused attention."
assert (q_scale_dp is not None), "q_scale_dp is required for FP8 fused attention."
assert (q_scale_dqkv is not None), "q_scale_dqkv is required for FP8 fused attention."
assert (amax_dp is not None), "amax_dp is required for FP8 fused attention."
assert (amax_dqkv is not None), "amax_dqkv is required for FP8 fused attention."
assert (len(aux_ctx_tensors) == 3
), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for the FP8 API."
), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention."
check_scalar(d_scale_qkv)
check_scalar(d_scale_s)
check_scalar(d_scale_o)
......@@ -399,37 +446,21 @@ def fused_attn_bwd_qkvpacked(
check_stats(m, b, h, max_seqlen)
check_stats(z_inv, b, h, max_seqlen)
# BF16/FP16 fused attention API from fmha_v2
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen > 512):
# add BF/FP16 support for >512 sequence length
assert False, "The BF16/FP16 support for >512 sequence length is coming!"
# BF16/FP16 fused attention API from fmha_v1 apex
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen <= 512):
# add BF/FP16 support for <=512 sequence length
assert False, "The BF16/FP16 support for <=512 sequence length is coming!"
else:
assert False, "No support for this dtype and max_seqlen combination."
# execute kernel
output_tensors = tex.fused_attn_bwd_qkvpacked(
b, max_seqlen, total_seqs, h, d,
attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type,
cu_seqlens,
qkv, o, d_o,
qkv_dtype,
aux_ctx_tensors,
attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens, qkv, o, d_o, qkv_dtype, aux_ctx_tensors,
d_scale_qkv, d_scale_s, d_scale_o, d_scale_do,
q_scale_s, q_scale_dp, q_scale_dqkv,
amax_dp, amax_dqkv,
q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
)
if bias_type == "no_bias":
# return d_qkv when bias_type is no_bias
return output_tensors[0]
# otherwise return (d_qkv, d_bias)
if attn_bias_type == "no_bias":
# return d_qkv when attn_bias_type is no_bias
return output_tensors
# otherwise return (d_qkv, d_bias)
return output_tensors[0], output_tensors[1]
def fused_attn_fwd_kvpacked(
......@@ -441,7 +472,8 @@ def fused_attn_fwd_kvpacked(
q: torch.Tensor,
kv: torch.Tensor,
qkv_dtype: tex.DType,
bias: torch.Tensor = None,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
q_scale_o: torch.Tensor = None,
......@@ -449,9 +481,9 @@ def fused_attn_fwd_kvpacked(
amax_o: torch.Tensor = None,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
fast_zero_fill: bool = True,
qkv_layout: str = "qkv_interleaved",
bias_type: str = "no_bias",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
rng_gen: torch.Generator = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
......@@ -479,8 +511,10 @@ def fused_attn_fwd_kvpacked(
where total_seqs_kv = cu_seqlens_kv[-1]
qkv_dtype: tex.DType
data type of Q and KV; in tex.DType, not torch.dtype
bias: torch.Tensor, default = None
input tensor Bias when bias_type is "pre_scale_bias" or "post_scale_bias";
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends.
attn_bias: torch.Tensor, default = None
input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
......@@ -498,12 +532,12 @@ def fused_attn_fwd_kvpacked(
dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False
set_zero: bool, default = True
if True, initializes the output tensor O to zero using the mha_fill method;
if False, doesn't initialize O after its allocation
fast_zero_fill: bool, default = True
if True, initializes the output tensor O to zero using the fast filling method;
if False, uses PyTorch's .fill_() method
qkv_layout: str, default = "qkv_interleaved"
layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
bias_type: str, default = "no_bias"
attn_bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "no_mask"}
......@@ -518,15 +552,26 @@ def fused_attn_fwd_kvpacked(
shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
aux_ctx_tensors: List[torch.Tensor]
auxiliary output tensors used for the backward;
if is_training is True, aux_ctx_tensors = [M, ZInv, rng_state]
if is_training is False, aux_ctx_tensors = [rng_state]
if is_training is True, aux_ctx_tensors = [softmax-related tensors, rng_state]
if is_training is False, aux_ctx_tensors = None
softmax-related tensors:
1. if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
softmax: torch.Tensor
Softmax(Q*K.T)
shape [batch_size, num_heads, max_seqlen_q, max_seqlen_kv], dtype float32
2. if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
softmaxStats: torch.Tensor
log(sum(e^(x - max(x)))), where x=Q*K.T
shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32
3. if fused_attention_backend == FusedAttnBackend["FP8"]
M: torch.Tensor
max(Q*K.T)
shape [batch_size, num_heads, max_seqlen, 1], dtype float32
shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32
ZInv: torch.Tensor
1/sum(e^(x - max(x))), where x=Q*K.T
shape [batch_size, num_heads, max_seqlen, 1], dtype float32
rng_state: torch.Tensor
shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32
rng_state: torch.Tensor, optional (present when the backend is not F16_max512_seqlen)
state of the random number generator;
[seed, offset], dtype uint64
"""
......@@ -551,49 +596,42 @@ def fused_attn_fwd_kvpacked(
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
if bias_type != "no_bias":
assert bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
assert (bias.shape == [1, h, max_seqlen_q, max_seqlen_kv]
), "bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
assert (bias.dtype == q.dtype
), "bias tensor must be in the same dtype as q and kv."
if attn_bias_type != "no_bias":
assert (attn_bias is not None
), "attn_bias tensor cannot be None when attn_bias_type is not no_bias."
assert (attn_bias.shape == torch.Size([1, h, max_seqlen_q, max_seqlen_kv])
), "attn_bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
assert (attn_bias.dtype == q.dtype
), "attn_bias tensor must be in the same dtype as q and kv."
# FP8 fused attention API
if (qkv_type is torch.uint8) and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512) \
and (d == 64):
assert False, "The FP8 fused attention API currently only supports packed QKV input."
# BF16/FP16 fused attention API from fmha_v2
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \
and (max_seqlen_q > 512) and (max_seqlen_kv > 512):
# add BF/FP16 support for >512 sequence length
assert False, "The BF16/FP16 support for >512 sequence length is coming!"
assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
# BF16/FP16 fused attention API from fmha_v1 apex
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \
and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512):
# add BF/FP16 support for <=512 sequence length
assert False, "The BF16/FP16 support for <=512 sequence length is coming!"
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
rng_elts_per_thread = (max_seqlen_q * max_seqlen_kv
+ BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA
else:
assert False, "No support for this dtype and max_seqlen combination."
# BF16/FP16 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
# FP8 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["FP8"]:
rng_elts_per_thread = (max_seqlen_q * max_seqlen_q
+ BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA
# execute kernel
output_tensors = tex.fused_attn_fwd_kvpacked(
b, max_seqlen_q, max_seqlen_kv, total_seqs_q, total_seqs_kv, h, d,
is_training, attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type,
cu_seqlens_q, cu_seqlens_kv,
q, kv,
qkv_dtype,
d_scale_qkv,
q_scale_s,
q_scale_o,
amax_s,
amax_o,
bias,
rng_gen,
is_training, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens_q, cu_seqlens_kv, q, kv, qkv_dtype,
d_scale_qkv, q_scale_s, q_scale_o, amax_s, amax_o,
attn_bias, rng_gen, rng_elts_per_thread,
)
# out, aux_ctx_tensors
return output_tensors[0], output_tensors[1:]
......@@ -607,7 +645,8 @@ def fused_attn_bwd_kvpacked(
o: torch.Tensor,
d_o: torch.Tensor,
qkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor] = None,
aux_ctx_tensors: List[torch.Tensor],
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None,
......@@ -619,9 +658,9 @@ def fused_attn_bwd_kvpacked(
amax_dqkv: torch.Tensor = None,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
fast_zero_fill: bool = True,
qkv_layout: str = "qkv_interleaved",
bias_type: str = "no_bias",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention BWD for packed KV input.
......@@ -654,6 +693,8 @@ def fused_attn_bwd_kvpacked(
aux_ctx_tensors: List[torch.Tensor]
auxiliary output tensors from the forward pass when is_training is True,
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends.
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None
......@@ -679,12 +720,12 @@ def fused_attn_bwd_kvpacked(
dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False
set_zero: bool, default = True
if True, initializes the output tensor O to zero using the mha_fill method;
if False, doesn't initialize O after its allocation
fast_zero_fill: bool, default = True
if True, initializes the output tensor O to zero using the fast filling method;
if False, uses PyTorch's .fill_() method
qkv_layout: str, default = "qkv_interleaved"
layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
bias_type: str, default = "no_bias"
attn_bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "no_mask"}
......@@ -696,8 +737,8 @@ def fused_attn_bwd_kvpacked(
d_kv: torch.Tensor
gradient tensor of KV; same data type and shape as KV
d_bias: torch.Tensor, optional
gradient tensor of Bias when bias_type is "pre_scale_bias" or "post_scale_bias";
same data type and shape as Bias
gradient tensor of Bias when attn_bias_type is "pre_scale_bias"
or "post_scale_bias"; same data type and shape as Bias
"""
check_cu_seqlens(cu_seqlens_q)
......@@ -722,45 +763,52 @@ def fused_attn_bwd_kvpacked(
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]:
assert (len(aux_ctx_tensors) >= 1
), "aux_ctx_tensors must contain rng_state as its last element."
rng_state = aux_ctx_tensors[-1]
check_rng_state(rng_state)
# FP8 fused attention API
if (qkv_type is torch.uint8) and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512) \
and d == 64:
assert False, "The FP8 fused attention API currently only supports packed QKV input."
############### BF16/FP16 fused attention API from fmha_v2 ################
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \
and (max_seqlen_q > 512) and (max_seqlen_kv > 512):
# add BF/FP16 support for >512 sequence length
assert False, "The BF16/FP16 support for >512 sequence length is coming!"
############### BF16/FP16 fused attention API from fmha_v1 apex ################
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \
and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512):
# add BF/FP16 support for <=512 sequence length
assert False, "The BF16/FP16 support for <=512 sequence length is coming!"
else:
assert False, "No support for this dtype and max_seqlen combination."
if fused_attention_backend == FusedAttnBackend["FP8"]:
assert (d_scale_qkv is not None), "d_scale_qkv is required for FP8 fused attention."
assert (d_scale_s is not None), "d_scale_s is required for FP8 fused attention."
assert (d_scale_o is not None), "d_scale_o is required for FP8 fused attention."
assert (d_scale_do is not None), "d_scale_do is required for FP8 fused attention."
assert (q_scale_s is not None), "q_scale_s is required for FP8 fused attention."
assert (q_scale_dp is not None), "q_scale_dp is required for FP8 fused attention."
assert (q_scale_dqkv is not None), "q_scale_dqkv is required for FP8 fused attention."
assert (amax_dp is not None), "amax_dp is required for FP8 fused attention."
assert (amax_dqkv is not None), "amax_dqkv is required for FP8 fused attention."
assert (len(aux_ctx_tensors) == 3
), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention."
check_scalar(d_scale_qkv)
check_scalar(d_scale_s)
check_scalar(d_scale_o)
check_scalar(d_scale_do)
check_scalar(q_scale_s)
check_scalar(q_scale_dp)
check_scalar(q_scale_dqkv)
check_scalar(amax_dp)
check_scalar(amax_dqkv)
m, z_inv = aux_ctx_tensors[:2]
check_stats(m, b, h, max_seqlen_q)
check_stats(z_inv, b, h, max_seqlen_q)
# execute kernel
output_tensors = tex.fused_attn_bwd_kvpacked(
b, max_seqlen_q, max_seqlen_kv, total_seqs_q, total_seqs_kv, h, d,
attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type,
cu_seqlens_q, cu_seqlens_kv,
q, kv, o, d_o,
qkv_dtype,
aux_ctx_tensors,
attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens_q, cu_seqlens_kv, q, kv, o, d_o, qkv_dtype, aux_ctx_tensors,
d_scale_qkv, d_scale_s, d_scale_o, d_scale_do,
q_scale_s, q_scale_dp, q_scale_dqkv,
amax_dp, amax_dqkv,
q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
)
# returns (d_q, d_kv) when bias_type is no_bias; otherwise returns (d_q, d_kv, d_bias)
if bias_type == "no_bias":
return output_tensors[:2]
if attn_bias_type == "no_bias":
# return (d_q, d_kv) when attn_bias_type is no_bias
return output_tensors
# otherwise return (d_q, d_kv), d_bias
return output_tensors[:2], output_tensors[2]
......@@ -58,7 +58,10 @@ enum FP8FwdTensors {
GEMM1_OUTPUT = 2,
GEMM2_INPUT = 3,
GEMM2_WEIGHT = 4,
GEMM2_OUTPUT = 5
GEMM2_OUTPUT = 5,
GEMM3_INPUT = 6,
GEMM3_WEIGHT = 7,
GEMM3_OUTPUT = 8
};
// Used as named indices on the `scale`, `scale_inv`,
......@@ -67,7 +70,9 @@ enum FP8BwdTensors {
GRAD_OUTPUT1 = 0,
GRAD_INPUT1 = 1,
GRAD_OUTPUT2 = 2,
GRAD_INPUT2 = 3
GRAD_INPUT2 = 3,
GRAD_OUTPUT3 = 4,
GRAD_INPUT3 = 5
};
......@@ -81,6 +86,9 @@ transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
inline at::ScalarType GetATenDType(transformer_engine::DType t) {
switch (t) {
case transformer_engine::DType::kInt32:
return torch::kInt32;
case transformer_engine::DType::kInt64:
return torch::kInt64;
case transformer_engine::DType::kFloat32:
return at::kFloat;
case transformer_engine::DType::kFloat16:
......
......@@ -12,43 +12,21 @@
constexpr int block_size = 512;
constexpr int ctas_per_sm = 4;
// convert QKV layout to enum
NVTE_QKV_Layout get_nvte_qkv_layout(const std::string qkv_layout) {
if (qkv_layout == "not_interleaved") {
return NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED;
} else if (qkv_layout == "qkv_interleaved") {
return NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED;
} else if (qkv_layout == "kv_interleaved") {
return NVTE_QKV_Layout::NVTE_KV_INTERLEAVED;
} else {
NVTE_ERROR("Invalid QKV layout. \n");
}
}
// convert bias type to enum
NVTE_Bias_Type get_nvte_bias_type(const std::string bias_type) {
if (bias_type == "no_bias") {
return NVTE_Bias_Type::NVTE_NO_BIAS;
} else if (bias_type == "pre_scale_bias") {
return NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS;
} else if (bias_type == "post_scale_bias") {
return NVTE_Bias_Type::NVTE_POST_SCALE_BIAS;
} else {
NVTE_ERROR("Invalid bias type. \n");
}
}
// convert attn mask type to enum
NVTE_Mask_Type get_nvte_mask_type(const std::string mask_type) {
if (mask_type == "padding") {
return NVTE_Mask_Type::NVTE_PADDING_MASK;
} else if (mask_type == "causal") {
return NVTE_Mask_Type::NVTE_CAUSAL_MASK;
} else if (mask_type == "no_mask") {
return NVTE_Mask_Type::NVTE_NO_MASK;
} else {
NVTE_ERROR("Invalid attention mask type. \n");
}
// get the fused attention backend
NVTE_Fused_Attn_Backend get_fused_attn_backend(
const transformer_engine::DType q_dtype,
const transformer_engine::DType kv_dtype,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
float p_dropout, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim) {
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype),
qkv_layout, bias_type, attn_mask_type,
p_dropout, max_seqlen_q, max_seqlen_kv, head_dim);
return fused_attention_backend;
}
// fast zero-fills of tensors
......@@ -103,10 +81,8 @@ __global__ void unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) {
// extract PhiloxCudaState from CUDA random number generator
at::PhiloxCudaState init_philox_state(
at::CUDAGeneratorImpl* gen,
size_t max_seq_len,
size_t threads_per_cta) {
size_t elts_per_thread) {
at::PhiloxCudaState philox_args;
size_t elts_per_thread = (max_seq_len * max_seq_len + threads_per_cta - 1)/threads_per_cta;
std::lock_guard<std::mutex> lock(gen->mutex_);
philox_args = gen->philox_cuda_state(elts_per_thread);
return philox_args;
......@@ -117,7 +93,7 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
size_t b, size_t max_seqlen, size_t total_seqs,
size_t h, size_t d,
bool is_training, float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens,
const at::Tensor QKV,
const transformer_engine::DType qkv_type,
......@@ -127,15 +103,18 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen) {
const c10::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread) {
using namespace transformer_engine;
// create output tensor O
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
auto O = torch::empty({static_cast<int64_t>(total_seqs),
static_cast<int64_t>(h), static_cast<int64_t>(d)}, options);
if (set_zero) {
if (set_zero && (h * d % block_size == 0)) {
mha_fill(O, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)}));
} else {
O.fill_(0);
}
// construct NVTE tensors
......@@ -166,7 +145,7 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
if ((bias_type != "no_bias") && (Bias.has_value())) {
if ((bias_type != NVTE_NO_BIAS) && (Bias.has_value())) {
auto bias_shape = Bias.value().sizes().vec();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape,
......@@ -175,23 +154,16 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), {b+1},
DType::kInt32, nullptr, nullptr, nullptr);
// convert strings to enums
NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
// extract random number generator seed and offset
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
size_t threads_per_cta = 128;
at::PhiloxCudaState philox_args = init_philox_state(gen, max_seqlen, threads_per_cta);
at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
philox_args, static_cast<int64_t*>(rng_state.data_ptr()));
auto te_rng_state = makeTransformerEngineTensor(rng_state);
// create auxiliary output tensors
// if training, tensors are [M, ZInv]
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
......@@ -209,7 +181,7 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
te_rng_state.data(),
max_seqlen,
is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
qkv_layout, bias_type, attn_mask_type,
workspace.data(),
at::cuda::getCurrentCUDAStream());
......@@ -219,10 +191,9 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
workspace_data.data_ptr(),
workspace.shape(), workspace.dtype());
// output_tensors = [O, nvte_aux_tensor_pack.tensors, rng_state]
// output_tensors = [O, nvte_aux_tensor_pack.tensors]
std::vector<at::Tensor> output_tensors;
output_tensors.push_back(O);
// nvte_aux_tensor_pack.size is 0 if inference
for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
auto tensor = reinterpret_cast<transformer_engine::Tensor*>(nvte_aux_tensor_pack.tensors[i]);
// allocate memory for nvte_aux_tensor_pack.tensors
......@@ -230,9 +201,6 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
output_tensors.push_back(output_tensor);
tensor->data.dptr = output_tensor.data_ptr();
}
if (is_training) {
output_tensors.push_back(rng_state);
}
// execute the kernel
nvte_fused_attn_fwd_qkvpacked(
......@@ -245,14 +213,14 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
te_rng_state.data(),
max_seqlen,
is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
qkv_layout, bias_type, attn_mask_type,
workspace.data(),
at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
// if training, [O, M, ZInv, rng_state]; if inference, [O]
// if training, [O, softmax-related tensors, rng_state]; if inference, [O]
return output_tensors;
}
......@@ -261,7 +229,7 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
size_t b, size_t max_seqlen, size_t total_seqs,
size_t h, size_t d,
float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens,
const at::Tensor QKV,
const at::Tensor O,
......@@ -281,13 +249,18 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
// create output tensor dQKV
at::Tensor dQKV = torch::empty_like(QKV);
if (set_zero) {
auto max_tokens = dQKV.size(0);
auto self_2d = dQKV.view({max_tokens, -1});
auto fcd_size = self_2d.size(1);
if (set_zero && (fcd_size % block_size == 0)) {
mha_fill(dQKV, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)}));
} else {
dQKV.fill_(0);
}
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
at::Tensor dBias;
TensorWrapper te_dBias;
if (bias_type != "no_bias") {
if (bias_type != NVTE_NO_BIAS) {
dBias = torch::zeros({1, static_cast<int64_t>(h),
static_cast<int64_t>(max_seqlen),
static_cast<int64_t>(max_seqlen)}, options);
......@@ -341,13 +314,7 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
// convert strings to enums
NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
// convert auxiliary tensors from forward into NVTETensors
// aux_ctx_tensors are [M, ZInv, rng_state]
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size();
......@@ -380,7 +347,7 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
te_cu_seqlens.data(),
max_seqlen,
attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
qkv_layout, bias_type, attn_mask_type,
workspace.data(),
at::cuda::getCurrentCUDAStream());
......@@ -403,7 +370,7 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
te_cu_seqlens.data(),
max_seqlen,
attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
qkv_layout, bias_type, attn_mask_type,
workspace.data(),
at::cuda::getCurrentCUDAStream());
......@@ -419,7 +386,7 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
size_t total_seqs_q, size_t total_seqs_kv,
size_t h, size_t d,
bool is_training, float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv,
const at::Tensor Q,
......@@ -431,15 +398,18 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen) {
const c10::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread) {
using namespace transformer_engine;
// create output tensor O
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
auto O = torch::empty({static_cast<int64_t>(total_seqs_q),
static_cast<int64_t>(h), static_cast<int64_t>(d)}, options);
if (set_zero) {
if (set_zero && (h * d % block_size == 0)) {
mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
} else {
O.fill_(0);
}
// construct NVTE tensors
......@@ -474,7 +444,7 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
if ((bias_type != "no_bias") && (Bias.has_value())) {
if ((bias_type != NVTE_NO_BIAS) && (Bias.has_value())) {
auto bias_shape = Bias.value().sizes().vec();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape,
......@@ -485,24 +455,16 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), {b+1},
DType::kInt32, nullptr, nullptr, nullptr);
// convert strings to enums
NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
// extract rng seed and offset
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
size_t threads_per_cta = 128;
at::PhiloxCudaState philox_args = init_philox_state(
gen, max(max_seqlen_q, max_seqlen_kv), threads_per_cta);
at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
philox_args, static_cast<int64_t*>(rng_state.data_ptr()));
auto te_rng_state = makeTransformerEngineTensor(rng_state);
// create auxiliary output tensors
// if training, tensors are [M, ZInv]
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
......@@ -522,7 +484,7 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
te_rng_state.data(),
max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
qkv_layout, bias_type, attn_mask_type,
workspace.data(),
at::cuda::getCurrentCUDAStream());
......@@ -532,10 +494,9 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
workspace_data.data_ptr(),
workspace.shape(), workspace.dtype());
// output_tensors = [O, nvte_aux_tensor_pack.tensors, rng_state]
// output_tensors = [O, nvte_aux_tensor_pack.tensors]
std::vector<at::Tensor> output_tensors;
output_tensors.push_back(O);
// nvte_aux_tensor_pack.size is 0 if inference
for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
auto tensor = reinterpret_cast<transformer_engine::Tensor*>(nvte_aux_tensor_pack.tensors[i]);
// allocate memory for nvte_aux_tensor_pack.tensors
......@@ -543,9 +504,6 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
output_tensors.push_back(output_tensor);
tensor->data.dptr = output_tensor.data_ptr();
}
if (is_training) {
output_tensors.push_back(rng_state);
}
// execute the kernel
nvte_fused_attn_fwd_kvpacked(
......@@ -560,14 +518,14 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
te_rng_state.data(),
max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
qkv_layout, bias_type, attn_mask_type,
workspace.data(),
at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
// if training, [O, M, ZInv, rng_state]; if inference, [O]
// if training, [O, softmax-related tensors, rng_state]; if inference, [O]
return output_tensors;
}
......@@ -577,7 +535,7 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
size_t total_seqs_q, size_t total_seqs_kv,
size_t h, size_t d,
float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv,
const at::Tensor Q,
......@@ -600,14 +558,23 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
// create output tensors dQ and dKV
at::Tensor dQ = torch::empty_like(Q);
at::Tensor dKV = torch::empty_like(KV);
if (set_zero) {
auto max_tokens_q = dQ.size(0);
auto self_2d_q = dQ.view({max_tokens_q, -1});
auto fcd_size_q = self_2d_q.size(1);
auto max_tokens_kv = dKV.size(0);
auto self_2d_kv = dKV.view({max_tokens_kv, -1});
auto fcd_size_kv = self_2d_kv.size(1);
if (set_zero && (fcd_size_q % block_size == 0) && (fcd_size_kv % block_size == 0)) {
mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
mha_fill(dKV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
} else {
dQ.fill_(0);
dKV.fill_(0);
}
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
at::Tensor dBias;
TensorWrapper te_dBias;
if (bias_type != "no_bias") {
if (bias_type != NVTE_NO_BIAS) {
dBias = torch::zeros({1, static_cast<int64_t>(h),
static_cast<int64_t>(max_seqlen_q),
static_cast<int64_t>(max_seqlen_kv)}, options);
......@@ -674,13 +641,7 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), {b+1},
DType::kInt32, nullptr, nullptr, nullptr);
// convert strings to enums
NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
// convert auxiliary tensors from forward to NVTETensors
// aux_ctx_tensors are [M, ZInv, rng_state]
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size();
......@@ -711,7 +672,7 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
te_cu_seqlens_kv.data(),
max_seqlen_q, max_seqlen_kv,
attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
qkv_layout, bias_type, attn_mask_type,
workspace.data(),
at::cuda::getCurrentCUDAStream());
......@@ -737,7 +698,7 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
te_cu_seqlens_kv.data(),
max_seqlen_q, max_seqlen_kv,
attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
qkv_layout, bias_type, attn_mask_type,
workspace.data(),
at::cuda::getCurrentCUDAStream());
......@@ -2227,6 +2188,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("dswiglu", &dswiglu, "Backward of SwiGLU");
m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention");
m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention");
m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend");
// Misc
m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version");
......@@ -2279,11 +2241,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.value("GEMM1_OUTPUT", transformer_engine::FP8FwdTensors::GEMM1_OUTPUT)
.value("GEMM2_INPUT", transformer_engine::FP8FwdTensors::GEMM2_INPUT)
.value("GEMM2_WEIGHT", transformer_engine::FP8FwdTensors::GEMM2_WEIGHT)
.value("GEMM2_OUTPUT", transformer_engine::FP8FwdTensors::GEMM2_OUTPUT);
.value("GEMM2_OUTPUT", transformer_engine::FP8FwdTensors::GEMM2_OUTPUT)
.value("GEMM3_INPUT", transformer_engine::FP8FwdTensors::GEMM3_INPUT)
.value("GEMM3_WEIGHT", transformer_engine::FP8FwdTensors::GEMM3_WEIGHT)
.value("GEMM3_OUTPUT", transformer_engine::FP8FwdTensors::GEMM3_OUTPUT);
py::enum_<transformer_engine::FP8BwdTensors>(m, "FP8BwdTensors")
.value("GRAD_OUTPUT1", transformer_engine::FP8BwdTensors::GRAD_OUTPUT1)
.value("GRAD_INPUT1", transformer_engine::FP8BwdTensors::GRAD_INPUT1)
.value("GRAD_OUTPUT2", transformer_engine::FP8BwdTensors::GRAD_OUTPUT2)
.value("GRAD_INPUT2", transformer_engine::FP8BwdTensors::GRAD_INPUT2);
.value("GRAD_INPUT2", transformer_engine::FP8BwdTensors::GRAD_INPUT2)
.value("GRAD_OUTPUT3", transformer_engine::FP8BwdTensors::GRAD_OUTPUT3)
.value("GRAD_INPUT3", transformer_engine::FP8BwdTensors::GRAD_INPUT3);
py::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type")
.value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS)
.value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS)
.value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
py::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type")
.value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
.value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
.value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK);
py::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout")
.value("NVTE_NOT_INTERLEAVED", NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED)
.value("NVTE_QKV_INTERLEAVED", NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
.value("NVTE_KV_INTERLEAVED", NVTE_QKV_Layout::NVTE_KV_INTERLEAVED);
py::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend")
.value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)
.value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)
.value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8)
.value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend);
}
......@@ -7,17 +7,22 @@
#include "common.h"
#include "../common.h"
NVTE_QKV_Layout get_nvte_qkv_layout(const std::string qkv_layout);
NVTE_Bias_Type get_nvte_bias_type(const std::string bias_type);
NVTE_Mask_Type get_nvte_mask_type(const std::string mask_type);
NVTE_Fused_Attn_Backend get_fused_attn_backend(
const transformer_engine::DType q_dtype,
const transformer_engine::DType kv_dtype,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
float p_dropout, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim);
std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
size_t b, size_t max_seqlen, size_t total_seqs,
size_t h, size_t d,
bool is_training, float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
size_t h, size_t d, bool is_training,
float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens,
const at::Tensor QKV,
const transformer_engine::DType qkv_type,
......@@ -27,13 +32,16 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen);
const c10::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread);
std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
size_t b, size_t max_seqlen, size_t total_seqs,
size_t h, size_t d,
float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
size_t h, size_t d, float attn_scale,
float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens,
const at::Tensor QKV,
const at::Tensor O,
@@ -53,9 +61,11 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
std::vector<at::Tensor> fused_attn_fwd_kvpacked(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t total_seqs_q, size_t total_seqs_kv,
size_t h, size_t d,
bool is_training, float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
size_t h, size_t d, bool is_training,
float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv,
const at::Tensor Q,
@@ -67,14 +77,17 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen);
const c10::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread);
std::vector<at::Tensor> fused_attn_bwd_kvpacked(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t total_seqs_q, size_t total_seqs_kv,
size_t h, size_t d,
float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
size_t h, size_t d, float attn_scale,
float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv,
const at::Tensor Q,
......
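The forward declarations above now take a trailing rng_elts_per_thread argument, letting the caller reserve how many dropout-mask random numbers each thread will draw so the Philox RNG offset can be advanced consistently across launches. The sketch below only illustrates that kind of bookkeeping; the divisor and rounding are assumed values, not the kernels' actual per-thread accounting.

# Rough illustration of per-thread RNG accounting; threads_per_cta and the
# rounding scheme are assumptions, not the fused-attention kernels' real math.
def estimated_rng_elts_per_thread(max_seqlen: int, threads_per_cta: int = 128) -> int:
    # Each thread covers a slice of the (max_seqlen x max_seqlen) dropout mask;
    # rounding up keeps the reserved Philox draw count conservative.
    mask_elts = max_seqlen * max_seqlen
    return (mask_elts + threads_per_cta - 1) // threads_per_cta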
@@ -403,6 +403,9 @@ class TransformerLayer(torch.nn.Module):
checkpoint_core_attention: bool = False,
inference_params: Optional[Any] = None,
rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True,
) -> torch.Tensor:
"""
Transformer Layer: attention block and a feedforward network (MLP)
@@ -445,6 +448,12 @@
rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
Embeddings for query and key tensors for applying rotary position
embedding. By default no input embedding is applied.
core_attention_bias_type: str, default = `no_bias`
Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`}
core_attention_bias: Optional[torch.Tensor], default = `None`
Bias tensor for Q * K.T
fast_zero_fill: bool, default = `True`
Whether to set output tensors to 0 before use.
"""
hidden_states = hidden_states.contiguous()
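For reference, a minimal sketch of calling the layer with the new keyword arguments; the layer sizes, dtype, and the assumed bias shape [1, num_heads, seq_len_q, seq_len_kv] are illustrative assumptions chosen only to keep the example self-contained, and support depends on the attention backend selected at runtime.

# Usage sketch; sizes, dtype, and bias shape are illustrative assumptions.
import torch
from transformer_engine.pytorch import TransformerLayer

seq_len, batch, hidden, heads = 2048, 2, 1024, 16
layer = TransformerLayer(hidden, 4 * hidden, heads).cuda()
x = torch.randn(seq_len, batch, hidden, device="cuda")
bias = torch.randn(1, heads, seq_len, seq_len, device="cuda")

out = layer(
    x,
    attention_mask=None,
    core_attention_bias_type="post_scale_bias",
    core_attention_bias=bias,
    fast_zero_fill=True,
)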
@@ -473,6 +482,9 @@
is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention,
rotary_pos_emb=rotary_pos_emb,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
fast_zero_fill=fast_zero_fill,
)
if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
@@ -516,6 +528,9 @@
encoder_output=encoder_output,
is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
fast_zero_fill=fast_zero_fill,
)
if self.apply_residual_connection_post_layernorm:
attention_output, attention_bias, residual = inter_attention_outputs
......