Unverified Commit 989a53a0 authored by cyanguwa, committed by GitHub

Add FP8 fused attention (#155)

* Add FP8 fused attention to TE for PyTorch
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add license for cudnn-frontend, modify installation requirements, and refactor some headers for aesthetics
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add c api docs for fused attention
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add exception for unsupported precision/sequence length combinations
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix installation requirement for non fused attn use cases
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix docs for fused-attn
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* prefix enums with NVTE_ and replace old MHA_Matrix with NVTE_QKV_Matrix
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* minor fixes based on PR comments
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix description for kvpacked fwd
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix description of Bias in C api
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* minor fixes for cudnn requirement and description for QKV tensors
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix QKV layout description and support matrix for C api
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add asserts to cpp_extensions for qkv layout/bias type/attn mask type
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix typo precision
Signed-off-by: Charlene Yang <charleney@nvidia.com>

---------
Signed-off-by: Charlene Yang <charleney@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c3407300
......@@ -24,11 +24,12 @@ extern "C" {
enum NVTEDType {
kNVTEByte = 0, /*!< Byte */
kNVTEInt32 = 1, /*!< 32-bit integer */
- kNVTEFloat32 = 2, /*!< 32-bit float */
- kNVTEFloat16 = 3, /*!< 16-bit float (E5M10) */
- kNVTEBFloat16 = 4, /*!< 16-bit bfloat (E8M7) */
- kNVTEFloat8E4M3 = 5, /*!< 8-bit float (E4M3) */
- kNVTEFloat8E5M2 = 6, /*!< 8-bit float (E5M2) */
+ kNVTEInt64 = 2, /*!< 64-bit integer */
+ kNVTEFloat32 = 3, /*!< 32-bit float */
+ kNVTEFloat16 = 4, /*!< 16-bit float (E5M10) */
+ kNVTEBFloat16 = 5, /*!< 16-bit bfloat (E8M7) */
+ kNVTEFloat8E4M3 = 6, /*!< 8-bit float (E4M3) */
+ kNVTEFloat8E5M2 = 7, /*!< 8-bit float (E5M2) */
kNVTENumTypes /*!< Number of supported types */
};
......@@ -129,6 +130,19 @@ float *nvte_tensor_scale(const NVTETensor tensor);
*/
float *nvte_tensor_scale_inv(const NVTETensor tensor);
struct NVTETensorPack {
static const int MAX_SIZE = 10; /*!< we expect <10 matrices in auxiliary outputs */
NVTETensor tensors[MAX_SIZE]; /*!< wrappers to tensors, do not hold memory */
size_t size = 0; /*!< actual size of the tensor pack, 0 <= size <= MAX_SIZE */
};
/*! \brief Create NVTETensors in NVTETensorPack.
*/
void nvte_tensor_pack_create(NVTETensorPack* pack);
/*! \brief Destroy NVTETensors in NVTETensorPack.
*/
void nvte_tensor_pack_destroy(NVTETensorPack* pack);
#ifdef __cplusplus
} // extern "C"
......@@ -146,11 +160,12 @@ namespace transformer_engine {
enum class DType {
kByte = 0,
kInt32 = 1,
- kFloat32 = 2,
- kFloat16 = 3,
- kBFloat16 = 4,
- kFloat8E4M3 = 5,
- kFloat8E5M2 = 6,
+ kInt64 = 2,
+ kFloat32 = 3,
+ kFloat16 = 4,
+ kBFloat16 = 5,
+ kFloat8E4M3 = 6,
+ kFloat8E5M2 = 7,
kNumTypes
};
......
......@@ -133,3 +133,16 @@ float *nvte_tensor_scale_inv(const NVTETensor tensor) {
"Tensor's inverse of scale must have Float32 type!");
return reinterpret_cast<float*>(t.scale_inv.dptr);
}
void nvte_tensor_pack_create(NVTETensorPack* pack) {
for (int i = 0; i < pack->MAX_SIZE; i++) {
pack->tensors[i] = reinterpret_cast<NVTETensor>(new transformer_engine::Tensor);
}
}
void nvte_tensor_pack_destroy(NVTETensorPack* pack) {
for (int i = 0; i < pack->MAX_SIZE; i++) {
auto *t = reinterpret_cast<transformer_engine::Tensor*>(pack->tensors[i]);
delete t;
}
}
......@@ -14,7 +14,7 @@ extension. Has one to one mapping
with enum in transformer_engine.h
"""
TE_DType = {
- torch.int8: tex.DType.kByte,
+ torch.uint8: tex.DType.kByte,
torch.int32: tex.DType.kInt32,
torch.float32: tex.DType.kFloat32,
torch.half: tex.DType.kFloat16,
......
......@@ -3,11 +3,735 @@
# See LICENSE for license information.
"""TE FP8 extensions and GEMMs"""
- from typing import Optional, Tuple, Union
+ import math
+ from typing import Optional, Tuple, List, Union
import torch
import transformer_engine_extensions as tex
from .constants import TE_DType
TORCH_DType = {
tex.DType.kFloat8E4M3: torch.uint8,
tex.DType.kFloat8E5M2: torch.uint8,
tex.DType.kFloat16: torch.half,
tex.DType.kBFloat16: torch.bfloat16,
tex.DType.kFloat32: torch.float32,
tex.DType.kInt32: torch.int32,
}
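# Illustrative sketch (not part of the original change): TE_DType maps torch
# dtypes to tex.DType enums and TORCH_DType maps back. Both FP8 enums map to
# torch.uint8 because PyTorch has no native FP8 storage dtype here, so FP8
# tensors are carried as raw uint8 bytes plus separate scaling metadata.
#
#     storage_dtype = TORCH_DType[tex.DType.kFloat8E4M3]   # torch.uint8
#     assert TE_DType[torch.float32] is tex.DType.kFloat32
#     qkv_fp8 = torch.empty(128, 3, 16, 64, dtype=storage_dtype, device="cuda")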
def check_tensor(x: torch.Tensor):
"""Check tensor properties."""
assert (x.is_cuda and x.is_contiguous()
), "Tensor should be a GPU tensor and contiguous."
def check_qkv(qkv: torch.Tensor, dtype: torch.dtype):
"""Check tensor properties."""
check_tensor(qkv)
assert (qkv.dtype is dtype
and qkv.dim() == 4
and qkv.shape[1] == 3
), f"""QKV should be in [total_seqs, 3, num_heads, head_dim] shape
and {dtype} dtype."""
def check_q(q: torch.Tensor, dtype: torch.dtype):
"""Check tensor properties."""
check_tensor(q)
assert (q.dtype is dtype
and q.dim() == 3
), f"""Q should be in [total_seqs, num_heads, head_dim] shape
and {dtype} dtype."""
def check_kv(kv: torch.Tensor, dtype: torch.dtype):
"""Check tensor properties."""
check_tensor(kv)
assert (kv.dtype is dtype
and kv.dim() == 4
and kv.shape[1] == 2
), f"""KV should be in [total_seqs, 2, num_heads, head_dim] shape
and {dtype} dtype."""
def check_o(o: torch.Tensor, dtype: torch.dtype):
"""Check tensor properties."""
check_tensor(o)
assert (o.dtype is dtype
and o.dim() == 3
), f"""O and dO should be in [total_seqs, num_heads, head_dim] shape
and {dtype} dtype."""
def check_stats(stats: torch.Tensor, b: int, h: int, s: int):
"""Check tensor properties."""
check_tensor(stats)
assert (stats.dtype is torch.float32
and stats.dim() == 4
and stats.shape == torch.Size([b, h, s, 1])
), """M and ZInv should be in [batch_size, num_heads, max_seqlen_q, 1]
shape and float32 dtype."""
def check_cu_seqlens(cu_seqlens: torch.Tensor):
"""Check tensor properties."""
check_tensor(cu_seqlens)
assert (cu_seqlens.dtype is torch.int32
and cu_seqlens.dim() == 1
), """cu_seqlens should be in [batch_size +1] shape and int32 dtype."""
def check_scalar(scalar: torch.Tensor):
"""Check tensor properties."""
check_tensor(scalar)
assert (scalar.dtype is torch.float32
and scalar.dim() <= 1
and scalar.numel() == 1
), "amax/scale/descale tensors should be scalars in float32 dtype."
def check_rng_state(rng_state: torch.Tensor):
"""Check tensor properties."""
check_tensor(rng_state)
assert (rng_state.dtype is torch.int64
and rng_state.numel() == 2
), "rng_state should be [seed, offset] and in int64 dtype."
def fused_attn_fwd_qkvpacked(
is_training: bool,
max_seqlen: int,
cu_seqlens: torch.Tensor,
qkv: torch.Tensor,
qkv_dtype: tex.DType,
bias: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
q_scale_o: torch.Tensor = None,
amax_s: torch.Tensor = None,
amax_o: torch.Tensor = None,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "qkv_interleaved",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
rng_gen: torch.Generator = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention FWD for packed QKV input.
Parameters
----------
is_training: bool
if True, runs training and produces auxiliary tensors aux_ctx_tensors
for the backward; if False, runs inference and doesn't produce aux_ctx_tensors
max_seqlen: int
max sequence length for QKV, used for padding; may be larger than the largest sequence length in the batch
cu_seqlens: torch.Tensor
cumulative sequence lengths for QKV; shape [batch_size + 1]
qkv: torch.Tensor
input tensor QKV;
shape [total_seqs, 3, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
qkv_dtype: tex.DType
data type of QKV; in tex.DType, not torch.dtype
bias: torch.Tensor, default = None
input tensor Bias;
shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
q_scale_s: torch.Tensor, default = None
input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T)
q_scale_o: torch.Tensor, default = None
input tensor for the quantization of O in FP8 computations
amax_s: torch.Tensor, default = None
output tensor, amax of S, used by the next iteration in FP8 computations
amax_o: torch.Tensor, default = None
output tensor, amax of O, used by the next iteration in FP8 computations
attn_scale: float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim) as the default
dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False
set_zero: bool, default = True
if True, initializes the output tensor O to zero using the mha_fill method;
if False, doesn't initialize O after its allocation
qkv_layout: str, default = "qkv_interleaved"
layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "no_mask"}
rng_gen: torch.Generator, default = None
random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
Returns
----------
o: torch.Tensor
output tensor O, of the attention calculation; same data type as QKV;
shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
aux_ctx_tensors: List[torch.Tensor]
auxiliary output tensors used for the backward;
if is_training is True, aux_ctx_tensors = [M, ZInv, rng_state]
if is_training is False, aux_ctx_tensors = [rng_state]
M: torch.Tensor
max(Q*K.T)
shape [batch_size, num_heads, max_seqlen, 1], dtype float32
ZInv: torch.Tensor
1/sum(e^(x - max(x))), where x=Q*K.T
shape [batch_size, num_heads, max_seqlen, 1], dtype float32
rng_state: torch.Tensor
state of the random number generator;
[seed, offset], dtype int64
"""
check_cu_seqlens(cu_seqlens)
b = cu_seqlens.numel() - 1
qkv_type = TORCH_DType[qkv_dtype]
check_qkv(qkv, qkv_type)
total_seqs = qkv.size(0)
h = qkv.size(2)
d = qkv.size(3)
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
# FP8 fused attention API
if (qkv_type is torch.uint8) and (max_seqlen <= 512) and (d == 64):
assert (qkv_layout == "qkv_interleaved"
and bias_type == "no_bias"
and attn_mask_type == "padding"
), """The FP8 fused attention API currently only supports qkv_interleaved layout,
no_bias type, and padding attention mask type."""
assert (d_scale_qkv is not None), "d_scale_qkv is required for the FP8 API."
assert (q_scale_s is not None), "q_scale_s is required for the FP8 API."
assert (q_scale_o is not None), "q_scale_o is required for the FP8 API."
assert (amax_s is not None), "amax_s is required for the FP8 API."
assert (amax_o is not None), "amax_o is required for the FP8 API."
check_scalar(d_scale_qkv)
check_scalar(q_scale_s)
check_scalar(q_scale_o)
check_scalar(amax_s)
check_scalar(amax_o)
# BF16/FP16 fused attention API from fmha_v2
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen > 512):
# add BF/FP16 support for >512 sequence length
assert False, "The BF16/FP16 support for >512 sequence length is coming!"
# BF16/FP16 fused attention API from fmha_v1 apex
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen <= 512):
# add BF/FP16 support for <=512 sequence length
assert False, "The BF16/FP16 support for <=512 sequence length is coming!"
else:
assert False, "No support for this dtype and max_seqlen combination."
# execute kernel
output_tensors = tex.fused_attn_fwd_qkvpacked(
b, max_seqlen, total_seqs, h, d,
is_training, attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type,
cu_seqlens,
qkv,
qkv_dtype,
d_scale_qkv,
q_scale_s,
q_scale_o,
amax_s,
amax_o,
bias,
rng_gen,
)
return output_tensors[0], output_tensors[1:]
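# Hedged usage sketch for the FP8 path (names reuse the placeholders above;
# this is illustrative, not a test from the original change). The FP8 path
# requires max_seqlen <= 512, head_dim == 64, qkv_interleaved layout,
# no_bias, and a padding mask:
#
#     o, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
#         is_training=True,
#         max_seqlen=512,
#         cu_seqlens=cu_seqlens,        # int32, [batch_size + 1]
#         qkv=qkv_fp8,                  # uint8, [total_seqs, 3, num_heads, 64]
#         qkv_dtype=tex.DType.kFloat8E4M3,
#         d_scale_qkv=d_scale_qkv,
#         q_scale_s=q_scale_s, q_scale_o=q_scale_o,
#         amax_s=amax_s, amax_o=amax_o,
#         dropout=0.1,
#     )
#     # aux_ctx_tensors is [M, ZInv, rng_state] since is_training is True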
def fused_attn_bwd_qkvpacked(
max_seqlen: int,
cu_seqlens: torch.Tensor,
qkv: torch.Tensor,
o: torch.Tensor,
d_o: torch.Tensor,
qkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor] = None,
d_bias: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None,
d_scale_do: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
q_scale_dp: torch.Tensor = None,
q_scale_dqkv: torch.Tensor = None,
amax_dp: torch.Tensor = None,
amax_dqkv: torch.Tensor = None,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "qkv_interleaved",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention BWD for packed QKV input.
Parameters
----------
max_seqlen: int
max sequence length for QKV, used for padding; may be larger than the largest sequence length in the batch
cu_seqlens: torch.Tensor
cumulative sequence lengths for QKV; shape [batch_size + 1]
qkv: torch.Tensor
input tensor QKV;
shape [total_seqs, 3, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
o: torch.Tensor
input tensor O (output of forward);
shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
d_o: torch.Tensor
input tensor dO (gradient of O);
shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
qkv_dtype: tex.DType
data type of QKV; in tex.DType, not torch.dtype
aux_ctx_tensors: List[torch.Tensor]
auxiliary output tensors of the forward pass when is_training is True,
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
d_bias: torch.Tensor, default = None
gradient tensor of Bias;
shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None
input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T)
d_scale_o: torch.Tensor, default = None
input tensor for the dequantization of O in FP8 computations
d_scale_do: torch.Tensor, default = None
input tensor for the dequantization of dO in FP8 computations
q_scale_s: torch.Tensor, default = None
input tensor for the quantization of S in FP8 computations
q_scale_dp: torch.Tensor, default = None
input tensor for the quantization of dP in FP8 computations, P = Q * K.T
q_scale_dqkv: torch.Tensor, default = None
input tensor for the quantization of dQKV in FP8 computations
amax_dp: torch.Tensor, default = None
output tensor, amax of dP, used by the next iteration in FP8 computations
amax_dqkv: torch.Tensor, default = None
output tensor, amax of dQKV, used by the next iteration in FP8 computations
attn_scale: float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim) as the default
dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False
set_zero: bool, default = True
if True, initializes the output tensor dQKV to zero using the mha_fill method;
if False, doesn't initialize dQKV after its allocation
qkv_layout: str, default = "qkv_interleaved"
layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "no_mask"}
Returns
----------
d_qkv: torch.Tensor
gradient tensor of QKV; same data type and shape as QKV
"""
check_cu_seqlens(cu_seqlens)
b = cu_seqlens.numel() - 1
qkv_type = TORCH_DType[qkv_dtype]
check_qkv(qkv, qkv_type)
check_o(o, qkv_type)
check_o(d_o, qkv_type)
total_seqs = qkv.size(0)
h = qkv.size(2)
d = qkv.size(3)
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
assert (len(aux_ctx_tensors) >= 1
), "aux_ctx_tensors must contain rng_state as its last element."
rng_state = aux_ctx_tensors[-1]
check_rng_state(rng_state)
# FP8 fused attention API
if (qkv_type is torch.uint8) and (max_seqlen <= 512) and d == 64:
assert (qkv_layout == "qkv_interleaved"
and bias_type == "no_bias"
and attn_mask_type == "padding"
), """The FP8 fused attention API currently only supports qkv_interleaved layout,
no_bias type, and padding attention mask type."""
assert (d_scale_qkv is not None), "d_scale_qkv is required for the FP8 API."
assert (d_scale_s is not None), "d_scale_s is required for the FP8 API."
assert (d_scale_o is not None), "d_scale_o is required for the FP8 API."
assert (d_scale_do is not None), "d_scale_do is required for the FP8 API."
assert (q_scale_s is not None), "q_scale_s is required for the FP8 API."
assert (q_scale_dp is not None), "q_scale_dp is required for the FP8 API."
assert (q_scale_dqkv is not None), "q_scale_dqkv is required for the FP8 API."
assert (amax_dp is not None), "amax_dp is required for the FP8 API."
assert (amax_dqkv is not None), "amax_dqkv is required for the FP8 API."
assert (len(aux_ctx_tensors) == 3
), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for the FP8 API."
check_scalar(d_scale_qkv)
check_scalar(d_scale_s)
check_scalar(d_scale_o)
check_scalar(d_scale_do)
check_scalar(q_scale_s)
check_scalar(q_scale_dp)
check_scalar(q_scale_dqkv)
check_scalar(amax_dp)
check_scalar(amax_dqkv)
m, z_inv = aux_ctx_tensors[:2]
check_stats(m, b, h, max_seqlen)
check_stats(z_inv, b, h, max_seqlen)
# BF16/FP16 fused attention API from fmha_v2
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen > 512):
# add BF/FP16 support for >512 sequence length
assert False, "The BF16/FP16 support for >512 sequence length is coming!"
# BF16/FP16 fused attention API from fmha_v1 apex
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen <= 512):
# add BF/FP16 support for <=512 sequence length
assert False, "The BF16/FP16 support for <=512 sequence length is coming!"
else:
assert False, "No support for this dtype and max_seqlen combination."
# execute kernel
output_tensors = tex.fused_attn_bwd_qkvpacked(
b, max_seqlen, total_seqs, h, d,
attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type,
cu_seqlens,
qkv, o, d_o,
qkv_dtype,
aux_ctx_tensors,
d_scale_qkv, d_scale_s, d_scale_o, d_scale_do,
q_scale_s, q_scale_dp, q_scale_dqkv,
amax_dp, amax_dqkv,
d_bias,
)
return output_tensors[0]
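# Matching backward sketch (illustrative): consumes the forward's
# aux_ctx_tensors ([M, ZInv, rng_state]) plus the extra FP8 scaling scalars
# required by the backward pass (all single-element float32 placeholders):
#
#     d_qkv = fused_attn_bwd_qkvpacked(
#         max_seqlen=512,
#         cu_seqlens=cu_seqlens,
#         qkv=qkv_fp8, o=o, d_o=d_o_fp8,    # d_o_fp8: uint8 gradient of O
#         qkv_dtype=tex.DType.kFloat8E4M3,
#         aux_ctx_tensors=aux_ctx_tensors,
#         d_scale_qkv=d_scale_qkv, d_scale_s=d_scale_s,
#         d_scale_o=d_scale_o, d_scale_do=d_scale_do,
#         q_scale_s=q_scale_s, q_scale_dp=q_scale_dp, q_scale_dqkv=q_scale_dqkv,
#         amax_dp=amax_dp, amax_dqkv=amax_dqkv,
#         dropout=0.1,
#     )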
def fused_attn_fwd_kvpacked(
is_training: bool,
max_seqlen_q: int,
max_seqlen_kv: int,
cu_seqlens_q: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
q: torch.Tensor,
kv: torch.Tensor,
qkv_dtype: tex.DType,
bias: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
q_scale_o: torch.Tensor = None,
amax_s: torch.Tensor = None,
amax_o: torch.Tensor = None,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "qkv_interleaved",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
rng_gen: torch.Generator = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention FWD for packed KV input.
Parameters
----------
is_training: bool
if True, runs training and produces auxiliary tensors aux_ctx_tensors
for the backward; if False, runs inference and doesn't produce aux_ctx_tensors
max_seqlen_q: int
max sequence length for Q, used for padding; may be larger than the largest sequence length in the batch
max_seqlen_kv: int
max sequence length for KV, used for padding; may be larger than the largest sequence length in the batch
cu_seqlens_q: torch.Tensor
cumulative sequence lengths for Q; shape [batch_size + 1]
cu_seqlens_kv: torch.Tensor
cumulative sequence lengths for KV; shape [batch_size + 1]
q: torch.Tensor
input tensor Q;
shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
kv: torch.Tensor
packed input tensor KV;
shape [total_seqs_kv, 2, num_heads, head_dim],
where total_seqs_kv = cu_seqlens_kv[-1]
qkv_dtype: tex.DType
data type of QKV; in tex.DType, not torch.dtype
bias: torch.Tensor, default = None
input tensor Bias;
shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
q_scale_s: torch.Tensor, default = None
input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T)
q_scale_o: torch.Tensor, default = None
input tensor for the quantization of O in FP8 computations
amax_s: torch.Tensor, default = None
output tensor, amax of S, used by the next iteration in FP8 computations
amax_o: torch.Tensor, default = None
output tensor, amax of O, used by the next iteration in FP8 computations
attn_scale: float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim) as the default
dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False
set_zero: bool, default = True
if True, initializes the output tensor O to zero using the mha_fill method;
if False, doesn't initialize O after its allocation
qkv_layout: str, default = "qkv_interleaved"
layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "no_mask"}
rng_gen: torch.Generator, default = None
random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
Returns
----------
o: torch.Tensor
output tensor O, of the attention calculation; same data type as QKV;
shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
aux_ctx_tensors: List[torch.Tensor]
auxiliary output tensors used for the backward;
if is_training is True, aux_ctx_tensors = [M, ZInv, rng_state]
if is_training is False, aux_ctx_tensors = [rng_state]
M: torch.Tensor
max(Q*K.T)
shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32
ZInv: torch.Tensor
1/sum(e^(x - max(x))), where x=Q*K.T
shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32
rng_state: torch.Tensor
state of the random number generator;
[seed, offset], dtype int64
"""
check_cu_seqlens(cu_seqlens_q)
check_cu_seqlens(cu_seqlens_kv)
assert (cu_seqlens_q.numel() == cu_seqlens_kv.numel()
), "cu_seqlens_q and cu_seqlens_kv must have the same length."
b = cu_seqlens_q.numel() - 1
qkv_type = TORCH_DType[qkv_dtype]
check_q(q, qkv_type)
check_kv(kv, qkv_type)
assert (q.size(1) == kv.size(2)
and q.size(2) == kv.size(3)
), "Q and KV must have the same num_heads and head_dim."
total_seqs_q = q.size(0)
total_seqs_kv = kv.size(0)
h = q.size(1)
d = q.size(2)
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
# FP8 fused attention API
if (qkv_type is torch.uint8) and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512) \
and (d == 64):
assert False, "The FP8 fused attention API currently only supports packed QKV input."
# BF16/FP16 fused attention API from fmha_v2
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \
and (max_seqlen_q > 512) and (max_seqlen_kv > 512):
# add BF/FP16 support for >512 sequence length
assert False, "The BF16/FP16 support for >512 sequence length is coming!"
# BF16/FP16 fused attention API from fmha_v1 apex
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \
and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512):
# add BF/FP16 support for <=512 sequence length
assert False, "The BF16/FP16 support for <=512 sequence length is coming!"
else:
assert False, "No support for this dtype and max_seqlen combination."
# execute kernel
output_tensors = tex.fused_attn_fwd_kvpacked(
b, max_seqlen_q, max_seqlen_kv, total_seqs_q, total_seqs_kv, h, d,
is_training, attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type,
cu_seqlens_q, cu_seqlens_kv,
q, kv,
qkv_dtype,
d_scale_qkv,
q_scale_s,
q_scale_o,
amax_s,
amax_o,
bias,
rng_gen,
)
return output_tensors[0], output_tensors[1:]
def fused_attn_bwd_kvpacked(
max_seqlen_q: int,
max_seqlen_kv: int,
cu_seqlens_q: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
q: torch.Tensor,
kv: torch.Tensor,
o: torch.Tensor,
d_o: torch.Tensor,
qkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor] = None,
d_bias: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None,
d_scale_do: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
q_scale_dp: torch.Tensor = None,
q_scale_dqkv: torch.Tensor = None,
amax_dp: torch.Tensor = None,
amax_dqkv: torch.Tensor = None,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "qkv_interleaved",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention BWD for packed KV input.
Parameters
----------
max_seqlen_q: int
max sequence length for Q, used for padding; may be larger than the largest sequence length in the batch
max_seqlen_kv: int
max sequence length for KV, used for padding; may be larger than the largest sequence length in the batch
cu_seqlens_q: torch.Tensor
cumulative sequence lengths for Q; shape [batch_size + 1]
cu_seqlens_kv: torch.Tensor
cumulative sequence lengths for KV; shape [batch_size + 1]
q: torch.Tensor
input tensor Q;
shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
kv: torch.Tensor
packed input tensor KV;
shape [total_seqs_kv, 2, num_heads, head_dim],
where total_seqs_kv = cu_seqlens_kv[-1]
o: torch.Tensor
input tensor O (output of forward);
shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
d_o: torch.Tensor
input tensor dO (gradient of O);
shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
qkv_dtype: tex.DType
data type of QKV; in tex.DType, not torch.dtype
aux_ctx_tensors: List[torch.Tensor]
auxiliary output tensors of the forward pass when is_training is True,
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
d_bias: torch.Tensor, default = None
gradient tensor of Bias;
shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None
input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T)
d_scale_o: torch.Tensor, default = None
input tensor for the dequantization of O in FP8 computations
d_scale_do: torch.Tensor, default = None
input tensor for the dequantization of dO in FP8 computations
q_scale_s: torch.Tensor, default = None
input tensor for the quantization of S in FP8 computations
q_scale_dp: torch.Tensor, default = None
input tensor for the quantization of dP in FP8 computations, P = Q * K.T
q_scale_dqkv: torch.Tensor, default = None
input tensor for the quantization of dQKV in FP8 computations
amax_dp: torch.Tensor, default = None
output tensor, amax of dP, used by the next iteration in FP8 computations,
P = Q * K.T
amax_dqkv: torch.Tensor, default = None
output tensor, amax of dQKV, used by the next iteration in FP8 computations
attn_scale: float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim) as the default
dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False
set_zero: bool, default = True
if True, initializes the output tensors dQ and dKV to zero using the mha_fill method;
if False, doesn't initialize dQ and dKV after their allocation
qkv_layout: str, default = "qkv_interleaved"
layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "no_mask"}
Returns
----------
d_q: torch.Tensor
gradient tensor of Q; same data type and shape as Q
d_kv: torch.Tensor
gradient tensor of KV; same data type and shape as KV
"""
check_cu_seqlens(cu_seqlens_q)
check_cu_seqlens(cu_seqlens_kv)
assert (cu_seqlens_q.numel() == cu_seqlens_kv.numel()
), "cu_seqlens_q and cu_seqlens_kv must have the same length."
b = cu_seqlens_q.numel() - 1
qkv_type = TORCH_DType[qkv_dtype]
check_q(q, qkv_type)
check_kv(kv, qkv_type)
check_o(o, qkv_type)
check_o(d_o, qkv_type)
assert (q.size(1) == kv.size(2)
and q.size(2) == kv.size(3)
), "Q and KV must have the same num_heads and head_dim."
total_seqs_q = q.size(0)
total_seqs_kv = kv.size(0)
h = q.size(1)
d = q.size(2)
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
assert (len(aux_ctx_tensors) >= 1
), "aux_ctx_tensors must contain rng_state as its last element."
rng_state = aux_ctx_tensors[-1]
check_rng_state(rng_state)
# FP8 fused attention API
if (qkv_type is torch.uint8) and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512) \
and d == 64:
assert False, "The FP8 fused attention API currently only supports packed QKV input."
# BF16/FP16 fused attention API from fmha_v2
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \
and (max_seqlen_q > 512) and (max_seqlen_kv > 512):
# add BF/FP16 support for >512 sequence length
assert False, "The BF16/FP16 support for >512 sequence length is coming!"
# BF16/FP16 fused attention API from fmha_v1 apex
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \
and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512):
# add BF/FP16 support for <=512 sequence length
assert False, "The BF16/FP16 support for <=512 sequence length is coming!"
else:
assert False, "No support for this dtype and max_seqlen combination."
# execute kernel
output_tensors = tex.fused_attn_bwd_kvpacked(
b, max_seqlen_q, max_seqlen_kv, total_seqs_q, total_seqs_kv, h, d,
attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type,
cu_seqlens_q, cu_seqlens_kv,
q, kv, o, d_o,
qkv_dtype,
aux_ctx_tensors,
d_scale_qkv, d_scale_s, d_scale_o, d_scale_do,
q_scale_s, q_scale_dp, q_scale_dqkv,
amax_dp, amax_dqkv,
d_bias,
)
return output_tensors
def fp8_gemm(
A: torch.Tensor,
......@@ -233,9 +957,9 @@ def fp8_cast_transpose_fused(
return_outputs = False
if cast_out is None or transpose_out is None:
- cast_out = torch.empty_like(inp, dtype=torch.int8)
+ cast_out = torch.empty_like(inp, dtype=torch.uint8)
transpose_out = torch.empty(
- inp.shape[1], inp.shape[0], device="cuda", dtype=torch.int8
+ inp.shape[1], inp.shape[0], device="cuda", dtype=torch.uint8
)
return_outputs = True
......
......@@ -88,6 +88,19 @@ size_t product(const std::vector<size_t> &shape) {
}
at::Tensor allocateSpace(const std::vector<size_t>& shape,
const transformer_engine::DType type,
bool init_to_zeros) {
std::vector<int64_t> shape_int64(shape.begin(), shape.end());
c10::IntArrayRef ar_shape(shape_int64);
if (init_to_zeros) {
return at::zeros(ar_shape, at::CUDA(GetATenDType(type)));
} else {
return at::empty(ar_shape, at::CUDA(GetATenDType(type)));
}
}
at::Tensor allocateSpace(const NVTEShape &shape,
const transformer_engine::DType type,
bool init_to_zeros) {
......
......@@ -15,9 +15,15 @@
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/softmax.h>
#include <transformer_engine/fused_attn.h>
#include <ATen/ATen.h>
#include <ATen/cudnn/Handle.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/macros/Macros.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <torch/extension.h>
#include <torch/torch.h>
#include <cuda.h>
......@@ -101,6 +107,12 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) {
return transformer_engine::DType::kBFloat16;
case at::kBool:
return transformer_engine::DType::kByte;
case torch::kByte:
return transformer_engine::DType::kByte;
case torch::kInt32:
return transformer_engine::DType::kInt32;
case torch::kInt64:
return transformer_engine::DType::kInt64;
default:
NVTE_ERROR("Invalid type");
}
......@@ -141,6 +153,9 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor,
size_t product(const std::vector<size_t> &shape);
at::Tensor allocateSpace(const std::vector<size_t>& shape,
const transformer_engine::DType type,
bool init_to_zeros);
at::Tensor allocateSpace(const NVTEShape &shape,
const transformer_engine::DType type,
......
......@@ -9,6 +9,742 @@
#include "comm_gemm_overlap.h"
#endif // NVTE_WITH_USERBUFFERS
constexpr int block_size = 512;
constexpr int ctas_per_sm = 4;
// convert QKV layout to enum
NVTE_QKV_Layout get_nvte_qkv_layout(const std::string qkv_layout) {
if (qkv_layout == "not_interleaved") {
return NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED;
} else if (qkv_layout == "qkv_interleaved") {
return NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED;
} else if (qkv_layout == "kv_interleaved") {
return NVTE_QKV_Layout::NVTE_KV_INTERLEAVED;
} else {
NVTE_ERROR("Invalid QKV layout. \n");
}
}
// convert bias type to enum
NVTE_Bias_Type get_nvte_bias_type(const std::string bias_type) {
if (bias_type == "no_bias") {
return NVTE_Bias_Type::NVTE_NO_BIAS;
} else if (bias_type == "pre_scale_bias") {
return NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS;
} else if (bias_type == "post_scale_bias") {
return NVTE_Bias_Type::NVTE_POST_SCALE_BIAS;
} else {
NVTE_ERROR("Invalid bias type. \n");
}
}
// convert attn mask type to enum
NVTE_Mask_Type get_nvte_mask_type(const std::string mask_type) {
if (mask_type == "padding") {
return NVTE_Mask_Type::NVTE_PADDING_MASK;
} else if (mask_type == "causal") {
return NVTE_Mask_Type::NVTE_CAUSAL_MASK;
} else if (mask_type == "no_mask") {
return NVTE_Mask_Type::NVTE_NO_MASK;
} else {
NVTE_ERROR("Invalid attention mask type. \n");
}
}
// fast zero-fills of tensors
template <typename scalar_t>
__global__ void __launch_bounds__(block_size) mha_fill_kernel(scalar_t* out_tensor,
const int32_t* const start_row,
const size_t num_rows) {
size_t row_stride = gridDim.y * blockDim.x;
size_t row_index = blockIdx.x + static_cast<size_t>(start_row[0]);
size_t col_index = blockIdx.y * blockDim.x + threadIdx.x;
while (row_index < num_rows) {
out_tensor[row_index*row_stride + col_index] = 0;
row_index += gridDim.x;
}
}
// fast zero-fills of tensors
void mha_fill(const at::Tensor &self, const at::Tensor &start_index) {
auto max_tokens = self.size(0);
auto self_2d = self.view({max_tokens, -1});
auto fcd_size = self_2d.size(1);
TORCH_CHECK(self.is_contiguous(), "input not contiguous");
TORCH_CHECK(fcd_size % block_size == 0, "input size not aligned to block size");
const int num_mp = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
uint64_t num_blk_y = (uint64_t)(fcd_size / block_size);
uint64_t num_blk_x = (uint64_t)((num_mp * ctas_per_sm + num_blk_y - 1) / num_blk_y);
dim3 dim_grid(num_blk_x, num_blk_y);
dim3 dim_block(block_size);
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16,
self_2d.scalar_type(), "mha_fill", [&]() {
mha_fill_kernel<<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
self_2d.data_ptr<scalar_t>(),
static_cast<int32_t*>(start_index.data_ptr()),
max_tokens);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
// extract seed and offset from PhiloxCudaState
__global__ void unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) {
if (arg.captured_) {
rng_state_ptr[0] = static_cast<int64_t>(*arg.seed_.ptr);
rng_state_ptr[1] = static_cast<int64_t>(
*(arg.offset_.ptr) + static_cast<int64_t>(arg.offset_intragraph_));
} else {
rng_state_ptr[0] = static_cast<int64_t>(arg.seed_.val);
rng_state_ptr[1] = static_cast<int64_t>(arg.offset_.val);
}
}
// extract PhiloxCudaState from CUDA random number generator
at::PhiloxCudaState init_philox_state(
at::CUDAGeneratorImpl* gen,
size_t max_seq_len,
size_t threads_per_cta) {
at::PhiloxCudaState philox_args;
size_t elts_per_thread = (max_seq_len * max_seq_len + threads_per_cta - 1)/threads_per_cta;
std::lock_guard<std::mutex> lock(gen->mutex_);
philox_args = gen->philox_cuda_state(elts_per_thread);
return philox_args;
}
// fused attention FWD with packed QKV
std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
size_t b, size_t max_seqlen, size_t total_seqs,
size_t h, size_t d,
bool is_training, float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
const at::Tensor cu_seqlens,
const at::Tensor QKV,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen) {
using namespace transformer_engine;
// create output tensor O
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
auto O = torch::empty({static_cast<int64_t>(total_seqs),
static_cast<int64_t>(h), static_cast<int64_t>(d)}, options);
if (set_zero) {
mha_fill(O, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)}));
}
// construct NVTE tensors
TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
if ((!descale_QKV.has_value()) || (!scale_S.has_value()) || (!scale_O.has_value())
|| (!amax_S.has_value()) || (!amax_O.has_value())) {
std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O";
NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
}
te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d},
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
at::Tensor descale_S = torch::empty_like(scale_S.value());
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, amax_S.value().data_ptr(),
scale_S.value().data_ptr(), descale_S.data_ptr());
te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d},
qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
// BF16 or FP16
te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d},
qkv_type, nullptr, nullptr, nullptr);
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d},
qkv_type, nullptr, nullptr, nullptr);
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
if (Bias.has_value()) {
auto bias_shape = Bias.value().sizes().vec();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape,
DType::kFloat32, nullptr, nullptr, nullptr);
}
te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), {b+1},
DType::kInt32, nullptr, nullptr, nullptr);
// convert strings to enums
NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
// extract random number generator seed and offset
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
size_t threads_per_cta = 128;
at::PhiloxCudaState philox_args = init_philox_state(gen, max_seqlen, threads_per_cta);
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
philox_args, static_cast<int64_t*>(rng_state.data_ptr()));
auto te_rng_state = makeTransformerEngineTensor(rng_state);
// create auxiliary output tensors
// if training, tensors are [M, ZInv]
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
// create workspace
TensorWrapper workspace;
// first call populates the workspace and auxiliary tensors with their required shapes and dtypes
nvte_fused_attn_fwd_qkvpacked(
te_QKV.data(),
te_Bias.data(),
te_S.data(),
te_O.data(),
&nvte_aux_tensor_pack,
te_cu_seqlens.data(),
te_rng_state.data(),
max_seqlen,
is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(),
at::cuda::getCurrentCUDAStream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(
workspace_data.data_ptr(),
workspace.shape(), workspace.dtype());
// output_tensors = [O, nvte_aux_tensor_pack.tensors, rng_state]
std::vector<at::Tensor> output_tensors;
output_tensors.push_back(O);
// nvte_aux_tensor_pack.size is 0 if inference
for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
auto tensor = reinterpret_cast<transformer_engine::Tensor*>(nvte_aux_tensor_pack.tensors[i]);
// allocate memory for nvte_aux_tensor_pack.tensors
auto output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false);
output_tensors.push_back(output_tensor);
tensor->data.dptr = output_tensor.data_ptr();
}
if (is_training) {
output_tensors.push_back(rng_state);
}
// execute the kernel
nvte_fused_attn_fwd_qkvpacked(
te_QKV.data(),
te_Bias.data(),
te_S.data(),
te_O.data(),
&nvte_aux_tensor_pack,
te_cu_seqlens.data(),
te_rng_state.data(),
max_seqlen,
is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(),
at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
// if training, [O, M, ZInv, rng_state]; if inference, [O]
return output_tensors;
}
// fused attention BWD with packed QKV
std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
size_t b, size_t max_seqlen, size_t total_seqs,
size_t h, size_t d,
float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
const at::Tensor cu_seqlens,
const at::Tensor QKV,
const at::Tensor O,
const at::Tensor dO,
const transformer_engine::DType qkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dP,
c10::optional<at::Tensor> amax_dQKV,
const c10::optional<at::Tensor> dBias) {
using namespace transformer_engine;
// create output tensor dQKV
at::Tensor dQKV = torch::empty_like(QKV);
if (set_zero) {
mha_fill(dQKV, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)}));
}
// construct NVTE tensors
TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV, te_dBias;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
if ((!descale_QKV.has_value()) || (!descale_S.has_value())
|| (!descale_O.has_value()) || (!descale_dO.has_value())
|| (!scale_S.has_value()) || (!scale_dP.has_value())
|| (!scale_dQKV.has_value())
|| (!amax_dP.has_value()) || (!amax_dQKV.has_value())) {
std::string err_tensors = "descale_QKV, descale_S, descale_O, scale_S, scale_dP, ";
err_tensors = err_tensors + std::string("scale_dQKV, amax_dP and amax_dQKV");
NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
}
te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d},
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d},
qkv_type, nullptr, nullptr, descale_O.value().data_ptr());
te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs, h, d},
qkv_type, nullptr, nullptr, descale_dO.value().data_ptr());
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32,
nullptr, scale_S.value().data_ptr(), descale_S.value().data_ptr());
at::Tensor descale_dP = torch::empty_like(scale_dP.value());
te_dP = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, amax_dP.value().data_ptr(), scale_dP.value().data_ptr(),
descale_dP.data_ptr());
te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), {total_seqs, 3, h, d},
qkv_type,
amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
// BF16 or FP16
te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d},
qkv_type, nullptr, nullptr, nullptr);
te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d},
qkv_type, nullptr, nullptr, nullptr);
te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs, h, d},
qkv_type, nullptr, nullptr, nullptr);
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_dP = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), {total_seqs, 3, h, d},
qkv_type, nullptr, nullptr, nullptr);
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
if (dBias.has_value()) {
auto bias_shape = dBias.value().sizes().vec();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_dBias = makeTransformerEngineTensor(
dBias.value().data_ptr(), shape, DType::kFloat32,
nullptr, nullptr, nullptr);
}
// convert strings to enums
NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
// convert auxiliary tensors from forward into NVTETensors
// aux_ctx_tensors are [M, ZInv, rng_state]
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size();
for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
auto tensor = reinterpret_cast<transformer_engine::Tensor*>(nvte_aux_tensor_pack.tensors[i]);
tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr();
std::vector<int64_t> tmp(Aux_CTX_Tensors[i].sizes().vec());
tensor->data.shape = std::vector<size_t>(tmp.begin(), tmp.end());
tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type());
}
// create cu_seqlens tensorwrappers
TensorWrapper te_cu_seqlens;
te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), {b+1},
DType::kInt32, nullptr, nullptr, nullptr);
// create workspace
TensorWrapper workspace;
// first call populates the workspace tensor with its required shape and dtype
nvte_fused_attn_bwd_qkvpacked(
te_QKV.data(),
te_dBias.data(),
te_O.data(),
te_dO.data(),
te_S.data(),
te_dP.data(),
&nvte_aux_tensor_pack,
te_dQKV.data(),
te_cu_seqlens.data(),
max_seqlen,
attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(),
at::cuda::getCurrentCUDAStream());
// allocate memory for workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(
workspace_data.data_ptr(),
workspace.shape(), workspace.dtype());
// execute kernel
nvte_fused_attn_bwd_qkvpacked(
te_QKV.data(),
te_dBias.data(),
te_O.data(),
te_dO.data(),
te_S.data(),
te_dP.data(),
&nvte_aux_tensor_pack,
te_dQKV.data(),
te_cu_seqlens.data(),
max_seqlen,
attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(),
at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
return {dQKV};
}
// fused attention FWD with packed KV
std::vector<at::Tensor> fused_attn_fwd_kvpacked(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t total_seqs_q, size_t total_seqs_kv,
size_t h, size_t d,
bool is_training, float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv,
const at::Tensor Q,
const at::Tensor KV,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen) {
using namespace transformer_engine;
// create output tensor O
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
auto O = torch::empty({static_cast<int64_t>(total_seqs_q),
static_cast<int64_t>(h), static_cast<int64_t>(d)}, options);
if (set_zero) {
mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
}
// construct NVTE tensors
TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
if ((!descale_QKV.has_value()) || (!scale_S.has_value()) || (!scale_O.has_value())
|| (!amax_S.has_value()) || (!amax_O.has_value())) {
std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O";
NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
}
te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d},
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d},
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
at::Tensor descale_S = torch::empty_like(scale_S.value());
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, amax_S.value().data_ptr(),
scale_S.value().data_ptr(), descale_S.data_ptr());
te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d},
qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
// BF16 or FP16
te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d},
qkv_type, nullptr, nullptr, nullptr);
te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d},
qkv_type, nullptr, nullptr, nullptr);
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d},
qkv_type, nullptr, nullptr, nullptr);
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
if (Bias.has_value()) {
auto bias_shape = Bias.value().sizes().vec();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape,
DType::kFloat32, nullptr, nullptr, nullptr);
}
te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), {b+1},
DType::kInt32, nullptr, nullptr, nullptr);
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), {b+1},
DType::kInt32, nullptr, nullptr, nullptr);
// convert strings to enums
NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
// extract rng seed and offset
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
size_t threads_per_cta = 128;
at::PhiloxCudaState philox_args = init_philox_state(
gen, std::max(max_seqlen_q, max_seqlen_kv), threads_per_cta);
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
philox_args, static_cast<int64_t*>(rng_state.data_ptr()));
auto te_rng_state = makeTransformerEngineTensor(rng_state);
// create auxiliary output tensors
// if training, tensors are [M, ZInv]
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
// create workspace
TensorWrapper workspace;
// first call populates the workspace and auxiliary tensors with their required shapes and dtypes
nvte_fused_attn_fwd_kvpacked(
te_Q.data(),
te_KV.data(),
te_Bias.data(),
te_S.data(),
te_O.data(),
&nvte_aux_tensor_pack,
te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(),
te_rng_state.data(),
max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(),
at::cuda::getCurrentCUDAStream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(
workspace_data.data_ptr(),
workspace.shape(), workspace.dtype());
// output_tensors = [O, nvte_aux_tensor_pack.tensors, rng_state]
std::vector<at::Tensor> output_tensors;
output_tensors.push_back(O);
// nvte_aux_tensor_pack.size is 0 if inference
for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
auto tensor = reinterpret_cast<transformer_engine::Tensor*>(nvte_aux_tensor_pack.tensors[i]);
// allocate memory for nvte_aux_tensor_pack.tensors
auto output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false);
output_tensors.push_back(output_tensor);
tensor->data.dptr = output_tensor.data_ptr();
}
if (is_training) {
output_tensors.push_back(rng_state);
}
// execute the kernel
nvte_fused_attn_fwd_kvpacked(
te_Q.data(),
te_KV.data(),
te_Bias.data(),
te_S.data(),
te_O.data(),
&nvte_aux_tensor_pack,
te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(),
te_rng_state.data(),
max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(),
at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
// if training, [O, M, ZInv, rng_state]; if inference, [O]
return output_tensors;
}
// fused attention BWD with packed KV
std::vector<at::Tensor> fused_attn_bwd_kvpacked(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t total_seqs_q, size_t total_seqs_kv,
size_t h, size_t d,
float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv,
const at::Tensor Q,
const at::Tensor KV,
const at::Tensor O,
const at::Tensor dO,
const transformer_engine::DType qkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dP,
c10::optional<at::Tensor> amax_dQKV,
const c10::optional<at::Tensor> dBias) {
using namespace transformer_engine;
// create output tensors dQ and dKV
at::Tensor dQ = torch::empty_like(Q);
at::Tensor dKV = torch::empty_like(KV);
if (set_zero) {
mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
mha_fill(dKV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
}
// construct NVTE tensors
TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV, te_dBias;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
if ((!descale_QKV.has_value()) || (!descale_S.has_value())
|| (!descale_O.has_value()) || (!descale_dO.has_value())
|| (!scale_S.has_value()) || (!scale_dP.has_value())
|| (!scale_dQKV.has_value())
|| (!amax_dP.has_value()) || (!amax_dQKV.has_value())) {
std::string err_tensors = "descale_QKV, descale_S, descale_O, scale_S, scale_dP, ";
err_tensors = err_tensors + std::string("scale_dQKV, amax_dP and amax_dQKV");
NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
}
te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d},
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d},
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d},
qkv_type, nullptr, nullptr, descale_O.value().data_ptr());
te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs_q, h, d},
qkv_type, nullptr, nullptr, descale_dO.value().data_ptr());
te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr,
scale_S.value().data_ptr(), descale_S.value().data_ptr());
at::Tensor descale_dP = torch::empty_like(scale_dP.value());
te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32,
amax_dP.value().data_ptr(), scale_dP.value().data_ptr(),
descale_dP.data_ptr());
te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), {total_seqs_q, h, d}, qkv_type,
amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), {total_seqs_kv, 2, h, d}, qkv_type,
amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
// BF16 or FP16
te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d},
qkv_type, nullptr, nullptr, nullptr);
te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d},
qkv_type, nullptr, nullptr, nullptr);
te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d},
qkv_type, nullptr, nullptr, nullptr);
te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs_q, h, d},
qkv_type, nullptr, nullptr, nullptr);
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_dP = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), {total_seqs_q, h, d},
qkv_type, nullptr, nullptr, nullptr);
te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), {total_seqs_kv, 2, h, d},
qkv_type, nullptr, nullptr, nullptr);
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
if (dBias.has_value()) {
auto bias_shape = dBias.value().sizes().vec();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_dBias = makeTransformerEngineTensor(
dBias.value().data_ptr(), shape, DType::kFloat32,
nullptr, nullptr, nullptr);
}
// create cu_seqlens tensorwrappers
TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), {b+1},
DType::kInt32, nullptr, nullptr, nullptr);
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), {b+1},
DType::kInt32, nullptr, nullptr, nullptr);
// convert strings to enums
NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
// convert auxiliary tensors from forward to NVTETensors
// aux_ctx_tensors are [M, ZInv, rng_state]
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size();
for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
auto tensor = reinterpret_cast<transformer_engine::Tensor*>(nvte_aux_tensor_pack.tensors[i]);
tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr();
std::vector<int64_t> tmp(Aux_CTX_Tensors[i].sizes().vec());
tensor->data.shape = std::vector<size_t>(tmp.begin(), tmp.end());
tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type());
}
// create workspace
TensorWrapper workspace;
  // query the kernel for its workspace requirements: this first call only
  // populates the workspace tensor's shape and dtype, it does not execute
nvte_fused_attn_bwd_kvpacked(
te_Q.data(),
te_KV.data(),
te_dBias.data(),
te_O.data(),
te_dO.data(),
te_S.data(),
te_dP.data(),
&nvte_aux_tensor_pack,
te_dQ.data(),
te_dKV.data(),
te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(),
max_seqlen_q, max_seqlen_kv,
attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(),
at::cuda::getCurrentCUDAStream());
// allocate memory for workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(
workspace_data.data_ptr(),
workspace.shape(), workspace.dtype());
// execute kernel
nvte_fused_attn_bwd_kvpacked(
te_Q.data(),
te_KV.data(),
te_dBias.data(),
te_O.data(),
te_dO.data(),
te_S.data(),
te_dP.data(),
&nvte_aux_tensor_pack,
te_dQ.data(),
te_dKV.data(),
te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(),
max_seqlen_q, max_seqlen_kv,
attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(),
at::cuda::getCurrentCUDAStream());
  // destroy tensor wrappers, but not their underlying memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
return {dQ, dKV};
}
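// A condensed sketch of the query-then-execute workspace idiom used above.
// `some_nvte_kernel` is a hypothetical stand-in for nvte_fused_attn_bwd_kvpacked
// and friends: called with an empty workspace it only reports the required
// shape and dtype; called again with memory attached it actually executes.
static void query_then_execute_sketch(cudaStream_t stream) {
  transformer_engine::TensorWrapper workspace;   // empty: no device pointer yet
  some_nvte_kernel(workspace.data(), stream);    // pass 1: fills shape + dtype
  auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
  workspace = makeTransformerEngineTensor(
      workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
  some_nvte_kernel(workspace.data(), stream);    // pass 2: runs the kernel
}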
void te_gemm(at::Tensor A,
at::Tensor A_scale_inverse,
transformer_engine::DType A_type,
......@@ -749,13 +1485,13 @@ at::Tensor cast_to_fp8(const at::Tensor &input,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1));
auto input_shape = input.sizes().vec();
std::vector<size_t> shape{input_shape.begin(), input_shape.end()};
auto output = at::empty_like(input, at::CUDA(GetATenDType(otype)));
auto input_cu = makeTransformerEngineTensor(input);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype,
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
......@@ -795,12 +1531,12 @@ at::Tensor cast_from_fp8(const at::Tensor &input,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1));
auto input_shape = input.sizes().vec();
std::vector<size_t> shape{input_shape.begin(), input_shape.end()};
auto output = at::empty_like(input, at::CUDA(GetATenDType(otype)));
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {N, H}, itype,
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), shape, itype,
nullptr, nullptr, scale_inv.data_ptr());
auto output_cu = makeTransformerEngineTensor(output);
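// Both casts above now derive the NVTE shape from input.sizes() instead of
// assuming a 2D {N, H} tensor, so inputs of any rank cast correctly. The
// rank-preserving conversion, as a standalone sketch:
static std::vector<size_t> nvte_shape_sketch(const at::Tensor& input) {
  auto input_shape = input.sizes().vec();        // std::vector<int64_t>
  return std::vector<size_t>(input_shape.begin(), input_shape.end());
}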
......@@ -1066,6 +1802,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8");
m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8");
m.def("te_gemm", &te_gemm, "CublasLt GEMM");
m.def("fused_attn_fwd_qkvpacked", &fused_attn_fwd_qkvpacked,
"Fused Attention FP8/BF16/FP16 FWD with packed QKV");
m.def("fused_attn_bwd_qkvpacked", &fused_attn_bwd_qkvpacked,
"Fused Attention FP8/BF16/FP16 BWD with packed QKV");
m.def("fused_attn_fwd_kvpacked", &fused_attn_fwd_kvpacked,
"Fused Attention FP8/BF16/FP16 FWD with packed KV");
m.def("fused_attn_bwd_kvpacked", &fused_attn_bwd_kvpacked,
"Fused Attention FP8/BF16/FP16 BWD with packed KV");
m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O");
m.def("fp8_gelu", &fp8_gelu, "GeLU with FP8 output");
......
......@@ -5,7 +5,95 @@
************************************************************************/
#include "common.h"
#include "../common.h"
NVTE_QKV_Layout get_nvte_qkv_layout(const std::string qkv_layout);
NVTE_Bias_Type get_nvte_bias_type(const std::string bias_type);
NVTE_Mask_Type get_nvte_mask_type(const std::string mask_type);
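// A sketch of what the converters above are assumed to look like; the exact
// set of accepted strings and enum spellings lives in the implementation
// file, so the values below are illustrative:
//
//   NVTE_QKV_Layout get_nvte_qkv_layout(const std::string qkv_layout) {
//     if (qkv_layout == "not_interleaved") return NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED;
//     if (qkv_layout == "qkv_interleaved") return NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED;
//     if (qkv_layout == "kv_interleaved")  return NVTE_QKV_Layout::NVTE_KV_INTERLEAVED;
//     NVTE_ERROR("Invalid QKV layout. \n");
//   }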
std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
size_t b, size_t max_seqlen, size_t total_seqs,
size_t h, size_t d,
bool is_training, float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
const at::Tensor cu_seqlens,
const at::Tensor QKV,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen);
std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
size_t b, size_t max_seqlen, size_t total_seqs,
size_t h, size_t d,
float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
const at::Tensor cu_seqlens,
const at::Tensor QKV,
const at::Tensor O,
const at::Tensor dO,
const transformer_engine::DType qkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dP,
c10::optional<at::Tensor> amax_dQKV,
const c10::optional<at::Tensor> dBias);
std::vector<at::Tensor> fused_attn_fwd_kvpacked(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t total_seqs_q, size_t total_seqs_kv,
size_t h, size_t d,
bool is_training, float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv,
const at::Tensor Q,
const at::Tensor KV,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen);
std::vector<at::Tensor> fused_attn_bwd_kvpacked(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t total_seqs_q, size_t total_seqs_kv,
size_t h, size_t d,
float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv,
const at::Tensor Q,
const at::Tensor KV,
const at::Tensor O,
const at::Tensor dO,
const transformer_engine::DType qkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dP,
c10::optional<at::Tensor> amax_dQKV,
const c10::optional<at::Tensor> dBias);
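// For the BF16/FP16 path the FP8 scaling tensors are simply omitted. A
// hypothetical call of the packed-KV forward declared above, assuming Q, KV,
// and the cu_seqlens tensors are already built with matching shapes; the
// layout/bias/mask strings here are illustrative, not an exhaustive list:
//
//   std::vector<at::Tensor> out = fused_attn_fwd_kvpacked(
//       b, max_seqlen_q, max_seqlen_kv,
//       total_seqs_q, total_seqs_kv, h, d,
//       /*is_training=*/true, attn_scale, p_dropout, /*set_zero=*/true,
//       "kv_interleaved", "no_bias", "padding",
//       cu_seqlens_q, cu_seqlens_kv, Q, KV,
//       transformer_engine::DType::kBFloat16,
//       c10::nullopt, c10::nullopt, c10::nullopt,   // descale_QKV, scale_S, scale_O
//       c10::nullopt, c10::nullopt,                 // amax_S, amax_O
//       c10::nullopt, c10::nullopt);                // Bias, rng_gen
//
// Per the comment in the implementation, `out` is [O, M, ZInv, rng_state]
// when training and [O] for inference.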
void te_gemm(at::Tensor A,
at::Tensor A_scale_inverse,
......
......@@ -102,7 +102,7 @@ def get_workspace() -> torch.Tensor:
global _cublas_workspace
if _cublas_workspace is None:
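# raw byte workspace for cuBLAS; uint8 is the conventional dtype for untyped buffers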
_cublas_workspace = torch.empty(
get_cublas_workspace_size_bytes(), dtype=torch.int8, device="cuda"
get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda"
)
return _cublas_workspace
......@@ -520,7 +520,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
torch.empty(
shape,
device=torch.cuda.current_device(),
dtype=torch.int8,
dtype=torch.uint8,
),
)
setattr(
......@@ -530,7 +530,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
shape[1],
shape[0],
device=torch.cuda.current_device(),
dtype=torch.int8,
dtype=torch.uint8,
),
)
......