Unverified commit 989a53a0 authored by cyanguwa, committed by GitHub

Add FP8 fused attention (#155)



* Add FP8 fused attention to TE for PyTorch
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add license for cudnn-frontend, modify installation requirements, and refactor some headers for aesthetics
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add c api docs for fused attention
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add exception for unsupported precision/sequence length combinations
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix installation requirement for non fused attn use cases
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix docs for fused-attn
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* prefix enums with NVTE_ and replace old MHA_Matrix with NVTE_QKV_Matrix
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* minor fixes based on PR comments
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix description for kvpacked fwd
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix description of Bias in C api
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* minor fixes for cudnn requirement and description for QKV tensors
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix QKV layout description and support matrix for C api
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add asserts to cpp_extensions for qkv layout/bias type/attn mask type
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix typo precision
Signed-off-by: Charlene Yang <charleney@nvidia.com>

---------
Signed-off-by: Charlene Yang <charleney@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c3407300
@@ -24,11 +24,12 @@ extern "C" {
enum NVTEDType {
  kNVTEByte = 0,       /*!< Byte */
  kNVTEInt32 = 1,      /*!< 32-bit integer */
  kNVTEInt64 = 2,      /*!< 64-bit integer */
  kNVTEFloat32 = 3,    /*!< 32-bit float */
  kNVTEFloat16 = 4,    /*!< 16-bit float (E5M10) */
  kNVTEBFloat16 = 5,   /*!< 16-bit bfloat (E8M7) */
  kNVTEFloat8E4M3 = 6, /*!< 8-bit float (E4M3) */
  kNVTEFloat8E5M2 = 7, /*!< 8-bit float (E5M2) */
  kNVTENumTypes        /*!< Number of supported types */
};
@@ -129,6 +130,19 @@ float *nvte_tensor_scale(const NVTETensor tensor);
 */
float *nvte_tensor_scale_inv(const NVTETensor tensor);
struct NVTETensorPack {
static const int MAX_SIZE = 10; /*!< we expect <10 matrices in auxiliary outputs */
NVTETensor tensors[MAX_SIZE]; /*!< wrappers to tensors, do not hold memory */
size_t size = 0; /*!< actual size of the tensor pack, 0 <= size <= MAX_SIZE */
};
/*! \brief Create NVTETensors in NVTETensorPack.
*/
void nvte_tensor_pack_create(NVTETensorPack* pack);
/*! \brief Destroy NVTETensors in NVTETensorPack.
*/
void nvte_tensor_pack_destroy(NVTETensorPack* pack);
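/* A minimal usage sketch (illustrative): a caller creates the pack, passes it to a
 * fused attention call that fills in `size` and describes `tensors[0..size-1]`
 * (shapes and dtypes only, since the wrappers hold no memory), then destroys the pack:
 *
 *   NVTETensorPack aux;
 *   nvte_tensor_pack_create(&aux);
 *   // ... pass &aux to e.g. nvte_fused_attn_fwd_qkvpacked() ...
 *   nvte_tensor_pack_destroy(&aux);
 */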
#ifdef __cplusplus
}  // extern "C"
@@ -146,11 +160,12 @@ namespace transformer_engine {
enum class DType {
  kByte = 0,
  kInt32 = 1,
  kInt64 = 2,
  kFloat32 = 3,
  kFloat16 = 4,
  kBFloat16 = 5,
  kFloat8E4M3 = 6,
  kFloat8E5M2 = 7,
  kNumTypes
};
......
@@ -133,3 +133,16 @@ float *nvte_tensor_scale_inv(const NVTETensor tensor) {
              "Tensor's inverse of scale must have Float32 type!");
  return reinterpret_cast<float*>(t.scale_inv.dptr);
}
void nvte_tensor_pack_create(NVTETensorPack* pack) {
for (int i = 0; i < pack->MAX_SIZE; i++) {
pack->tensors[i] = reinterpret_cast<NVTETensor>(new transformer_engine::Tensor);
}
}
void nvte_tensor_pack_destroy(NVTETensorPack* pack) {
for (int i = 0; i < pack->MAX_SIZE; i++) {
auto *t = reinterpret_cast<transformer_engine::Tensor*>(pack->tensors[i]);
delete t;
}
}
@@ -14,7 +14,7 @@ extension. Has one to one mapping
with enum in transformer_engine.h
"""
TE_DType = {
    torch.uint8: tex.DType.kByte,
    torch.int32: tex.DType.kInt32,
    torch.float32: tex.DType.kFloat32,
    torch.half: tex.DType.kFloat16,
......
@@ -3,11 +3,735 @@
# See LICENSE for license information.
"""TE FP8 extensions and GEMMs"""
import math
from typing import Optional, Tuple, List, Union
import torch
import transformer_engine_extensions as tex
from .constants import TE_DType
TORCH_DType = {
tex.DType.kFloat8E4M3: torch.uint8,
tex.DType.kFloat8E5M2: torch.uint8,
tex.DType.kFloat16: torch.half,
tex.DType.kBFloat16: torch.bfloat16,
tex.DType.kFloat32: torch.float32,
tex.DType.kInt32: torch.int32,
}
def check_tensor(x: torch.Tensor):
"""Check tensor properties."""
assert (x.is_cuda and x.is_contiguous()
), "Tensor should be a GPU tensor and contiguous."
def check_qkv(qkv: torch.Tensor, dtype: torch.dtype):
"""Check tensor properties."""
check_tensor(qkv)
assert (qkv.dtype is dtype
and qkv.dim() == 4
and qkv.shape[1] == 3
), """QKV should be in [total_seqs, 3, num_heads, head_dim] shape
and {dtype} dtype."""
def check_q(q: torch.Tensor, dtype: torch.dtype):
"""Check tensor properties."""
check_tensor(q)
assert (q.dtype is dtype
and q.dim() == 3
), """Q should be in [total_seqs, num_heads, head_dim] shape
and {dtype} dtype."""
def check_kv(kv: torch.Tensor, dtype: torch.dtype):
"""Check tensor properties."""
check_tensor(kv)
assert (kv.dtype is dtype
and kv.dim() == 4
and kv.shape[1] == 2
), """KV should be in [total_seqs, 2, num_heads, head_dim] shape
and {dtype} dtype."""
def check_o(o: torch.Tensor, dtype: torch.dtype):
"""Check tensor properties."""
check_tensor(o)
assert (o.dtype is dtype
and o.dim() == 3
), """O and dO should be in [total_seqs, num_heads, head_dim] shape
and {dtype} dtype."""
def check_stats(stats: torch.Tensor, b: int, h: int, s: int):
"""Check tensor properties."""
check_tensor(stats)
assert (stats.dtype is torch.float32
and stats.dim() == 4
and stats.shape == torch.Size([b, h, s, 1])
), """M and ZInv should be in [batch_size, num_heads, max_seqlen_q, 1]
shape and float32 dtype."""
def check_cu_seqlens(cu_seqlens: torch.Tensor):
"""Check tensor properties."""
check_tensor(cu_seqlens)
assert (cu_seqlens.dtype is torch.int32
and cu_seqlens.dim() == 1
), """cu_seqlens should be in [batch_size +1] shape and int32 dtype."""
def check_scalar(scalar: torch.Tensor):
"""Check tensor properties."""
check_tensor(scalar)
assert (scalar.dtype is torch.float32
and scalar.dim() <= 1
and scalar.numel() == 1
), "amax/scale/descale tensors should be scalars in float32 dtype."
def check_rng_state(rng_state: torch.Tensor):
"""Check tensor properties."""
check_tensor(rng_state)
assert (rng_state.dtype is torch.int64
and rng_state.numel() == 2
), "rng_state should be [seed, offset] and in int64 dtype."
def fused_attn_fwd_qkvpacked(
is_training: bool,
max_seqlen: int,
cu_seqlens: torch.Tensor,
qkv: torch.Tensor,
qkv_dtype: tex.DType,
bias: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
q_scale_o: torch.Tensor = None,
amax_s: torch.Tensor = None,
amax_o: torch.Tensor = None,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "qkv_interleaved",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
rng_gen: torch.Generator = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention FWD for packed QKV input.
Parameters
----------
is_training: bool
if True, runs training and produces auxiliary tensors aux_ctx_tensors
for the backward; if False, runs inference and doesn't produce aux_ctx_tensors
max_seqlen: int
max sequence length for QKV, used for padding; may be larger than max(cu_seqlens)
cu_seqlens: torch.Tensor
accumulative sequence lengths for QKV; shape [batch_size + 1]
qkv: torch.Tensor
input tensor QKV;
shape [total_seqs, 3, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
qkv_dtype: tex.DType
data type of QKV; in tex.DType, not torch.dtype
bias: torch.Tensor, default = None
input tensor Bias;
shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
q_scale_s: torch.Tensor, default = None
input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T)
q_scale_o: torch.Tensor, default = None
input tensor for the quantization of O in FP8 computations
amax_s: torch.Tensor, default = None
output tensor, amax of S, used by the next iteration in FP8 computations
amax_o: torch.Tensor, default = None
output tensor, amax of O, used by the next iteration in FP8 computations
attn_scale: float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim) as the default
dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False
set_zero: bool, default = True
if True, initializes the output tensor O to zero using the mha_fill method;
if False, doesn't initialize O after its allocation
qkv_layout: str, default = "qkv_interleaved"
layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "no_mask"}
rng_gen: torch.Generator, default = None
random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
Returns
----------
o: torch.Tensor
output tensor O, of the attention calculation; same data type as QKV;
shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
aux_ctx_tensors: List[torch.Tensor]
auxiliary output tensors used for the backward;
if is_training is True, aux_ctx_tensors = [M, ZInv, rng_state]
if is_training is False, aux_ctx_tensors = [rng_state]
M: torch.Tensor
max(Q*K.T)
shape [batch_size, num_heads, max_seqlen, 1], dtype float32
ZInv: torch.Tensor
1/sum(e^(x - max(x))), where x=Q*K.T
shape [batch_size, num_heads, max_seqlen, 1], dtype float32
rng_state: torch.Tensor
state of the random number generator;
[seed, offset], dtype uint64
"""
check_cu_seqlens(cu_seqlens)
b = cu_seqlens.numel() - 1
qkv_type = TORCH_DType[qkv_dtype]
check_qkv(qkv, qkv_type)
total_seqs = qkv.size(0)
h = qkv.size(2)
d = qkv.size(3)
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
# FP8 fused attention API
if (qkv_type is torch.uint8) and (max_seqlen <= 512) and (d == 64):
assert (qkv_layout == "qkv_interleaved"
and bias_type == "no_bias"
and attn_mask_type == "padding"
), """The FP8 fused attention API currently only supports qkv_interleaved layout,
no_bias type, and padding attention mask type."""
assert (d_scale_qkv is not None), "d_scale_qkv is required for the FP8 API."
assert (q_scale_s is not None), "q_scale_s is required for the FP8 API."
assert (q_scale_o is not None), "q_scale_o is required for the FP8 API."
assert (amax_s is not None), "amax_s is required for the FP8 API."
assert (amax_o is not None), "amax_o is required for the FP8 API."
check_scalar(d_scale_qkv)
check_scalar(q_scale_s)
check_scalar(q_scale_o)
check_scalar(amax_s)
check_scalar(amax_o)
# BF16/FP16 fused attention API from fmha_v2
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen > 512):
# add BF/FP16 support for >512 sequence length
assert False, "The BF16/FP16 support for >512 sequence length is coming!"
# BF16/FP16 fused attention API from fmha_v1 apex
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen <= 512):
# add BF/FP16 support for <=512 sequence length
assert False, "The BF16/FP16 support for <=512 sequence length is coming!"
else:
assert False, "No support for this dtype and max_seqlen combination."
# execute kernel
output_tensors = tex.fused_attn_fwd_qkvpacked(
b, max_seqlen, total_seqs, h, d,
is_training, attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type,
cu_seqlens,
qkv,
qkv_dtype,
d_scale_qkv,
q_scale_s,
q_scale_o,
amax_s,
amax_o,
bias,
rng_gen,
)
return output_tensors[0], output_tensors[1:]
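# A minimal FP8 forward sketch (illustrative only; the shapes, unit scales and
# zero-initialized amax buffers below are assumptions, not part of this API):
#
#   b, s, h, d = 2, 512, 16, 64
#   cu_seqlens = torch.arange(0, (b + 1) * s, s, dtype=torch.int32, device="cuda")
#   qkv = torch.zeros(int(cu_seqlens[-1]), 3, h, d, dtype=torch.uint8, device="cuda")
#   one = torch.ones(1, dtype=torch.float32, device="cuda")
#   amax_s = torch.zeros(1, dtype=torch.float32, device="cuda")
#   amax_o = torch.zeros(1, dtype=torch.float32, device="cuda")
#   o, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
#       True, s, cu_seqlens, qkv, tex.DType.kFloat8E4M3,
#       d_scale_qkv=one, q_scale_s=one, q_scale_o=one, amax_s=amax_s, amax_o=amax_o)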
def fused_attn_bwd_qkvpacked(
max_seqlen: int,
cu_seqlens: torch.Tensor,
qkv: torch.Tensor,
o: torch.Tensor,
d_o: torch.Tensor,
qkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor] = None,
d_bias: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None,
d_scale_do: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
q_scale_dp: torch.Tensor = None,
q_scale_dqkv: torch.Tensor = None,
amax_dp: torch.Tensor = None,
amax_dqkv: torch.Tensor = None,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "qkv_interleaved",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention BWD for packed QKV input.
Parameters
----------
max_seqlen: int
max sequence length for QKV, used for padding; may be larger than max(cu_seqlens_q)
cu_seqlens: torch.Tensor
accumulative sequence lengths for QKV; shape [batch_size + 1]
qkv: torch.Tensor
input tensor QKV;
shape [total_seqs, 3, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
o: torch.Tensor
input tensor O (output of forward);
shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
d_o: torch.Tensor
input tensor dO (gradient of O);
shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
qkv_dtype: tex.DType
data type of QKV; in tex.DType, not torch.dtype
aux_ctx_tensors: List[torch.Tensor]
        auxiliary output tensors of the forward pass when is_training is True,
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
d_bias: torch.Tensor, default = None
input tensor Bias;
shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None
input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T)
d_scale_o: torch.Tensor, default = None
input tensor for the dequantization of O in FP8 computations
d_scale_do: torch.Tensor, default = None
input tensor for the dequantization of dO in FP8 computations
q_scale_s: torch.Tensor, default = None
input tensor for the quantization of S in FP8 computations
q_scale_dp: torch.Tensor, default = None
input tensor for the quantization of dP in FP8 computations, P = Q * K.T
q_scale_dqkv: torch.Tensor, default = None
input tensor for the quantization of dQKV in FP8 computations
amax_dp: torch.Tensor, default = None
output tensor, amax of dP, used by the next iteration in FP8 computations
amax_dqkv: torch.Tensor, default = None
output tensor, amax of dQKV, used by the next iteration in FP8 computations
attn_scale: float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim) as the default
dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False
set_zero: bool, default = True
if True, initializes the output tensor O to zero using the mha_fill method;
if False, doesn't initialize O after its allocation
qkv_layout: str, default = "qkv_interleaved"
layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "no_mask"}
Returns
----------
d_qkv: torch.Tensor
gradient tensor of QKV; same data type and shape as QKV
"""
check_cu_seqlens(cu_seqlens)
b = cu_seqlens.numel() - 1
qkv_type = TORCH_DType[qkv_dtype]
check_qkv(qkv, qkv_type)
check_o(o, qkv_type)
check_o(d_o, qkv_type)
total_seqs = qkv.size(0)
h = qkv.size(2)
d = qkv.size(3)
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
assert (len(aux_ctx_tensors) >= 1
), "aux_ctx_tensors must contain rng_state as its last element."
rng_state = aux_ctx_tensors[-1]
check_rng_state(rng_state)
# FP8 fused attention API
if (qkv_type is torch.uint8) and (max_seqlen <= 512) and d == 64:
assert (qkv_layout == "qkv_interleaved"
and bias_type == "no_bias"
and attn_mask_type == "padding"
), """The FP8 fused attention API currently only supports qkv_interleaved layout,
no_bias type, and padding attention mask type."""
assert (d_scale_qkv is not None), "d_scale_qkv is required for the FP8 API."
assert (d_scale_s is not None), "d_scale_s is required for the FP8 API."
assert (d_scale_o is not None), "d_scale_o is required for the FP8 API."
assert (d_scale_do is not None), "d_scale_do is required for the FP8 API."
assert (q_scale_s is not None), "q_scale_s is required for the FP8 API."
assert (q_scale_dp is not None), "q_scale_dp is required for the FP8 API."
assert (q_scale_dqkv is not None), "q_scale_dqkv is required for the FP8 API."
assert (amax_dp is not None), "amax_dp is required for the FP8 API."
assert (amax_dqkv is not None), "amax_dqkv is required for the FP8 API."
assert (len(aux_ctx_tensors) == 3
), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for the FP8 API."
check_scalar(d_scale_qkv)
check_scalar(d_scale_s)
check_scalar(d_scale_o)
check_scalar(d_scale_do)
check_scalar(q_scale_s)
check_scalar(q_scale_dp)
check_scalar(q_scale_dqkv)
check_scalar(amax_dp)
check_scalar(amax_dqkv)
m, z_inv = aux_ctx_tensors[:2]
check_stats(m, b, h, max_seqlen)
check_stats(z_inv, b, h, max_seqlen)
# BF16/FP16 fused attention API from fmha_v2
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen > 512):
# add BF/FP16 support for >512 sequence length
assert False, "The BF16/FP16 support for >512 sequence length is coming!"
# BF16/FP16 fused attention API from fmha_v1 apex
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen <= 512):
# add BF/FP16 support for <=512 sequence length
assert False, "The BF16/FP16 support for <=512 sequence length is coming!"
else:
assert False, "No support for this dtype and max_seqlen combination."
# execute kernel
output_tensors = tex.fused_attn_bwd_qkvpacked(
b, max_seqlen, total_seqs, h, d,
attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type,
cu_seqlens,
qkv, o, d_o,
qkv_dtype,
aux_ctx_tensors,
d_scale_qkv, d_scale_s, d_scale_o, d_scale_do,
q_scale_s, q_scale_dp, q_scale_dqkv,
amax_dp, amax_dqkv,
d_bias,
)
return output_tensors[0]
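# A minimal FP8 backward sketch, continuing the forward sketch above (illustrative;
# unit scales and zero-initialized amax buffers are assumptions):
#
#   d_o = torch.zeros_like(o)
#   amax_dp = torch.zeros(1, dtype=torch.float32, device="cuda")
#   amax_dqkv = torch.zeros(1, dtype=torch.float32, device="cuda")
#   d_qkv = fused_attn_bwd_qkvpacked(
#       s, cu_seqlens, qkv, o, d_o, tex.DType.kFloat8E4M3,
#       aux_ctx_tensors=aux_ctx_tensors,  # [M, ZInv, rng_state] from the forward
#       d_scale_qkv=one, d_scale_s=one, d_scale_o=one, d_scale_do=one,
#       q_scale_s=one, q_scale_dp=one, q_scale_dqkv=one,
#       amax_dp=amax_dp, amax_dqkv=amax_dqkv)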
def fused_attn_fwd_kvpacked(
is_training: bool,
max_seqlen_q: int,
max_seqlen_kv: int,
cu_seqlens_q: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
q: torch.Tensor,
kv: torch.Tensor,
qkv_dtype: tex.DType,
bias: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
q_scale_o: torch.Tensor = None,
amax_s: torch.Tensor = None,
amax_o: torch.Tensor = None,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "qkv_interleaved",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
rng_gen: torch.Generator = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention FWD for packed KV input.
Parameters
----------
is_training: bool
if True, runs training and produces auxiliary tensors aux_ctx_tensors
for the backward; if False, runs inference and doesn't produce aux_ctx_tensors
max_seqlen_q: int
max sequence length for Q, used for padding; may be larger than max(cu_seqlens_q)
max_seqlen_kv: int
max sequence length for KV, used for padding; may be larger than max(cu_seqlens_kv)
cu_seqlens_q: torch.Tensor
accumulative sequence lengths for Q; shape [batch_size + 1]
cu_seqlens_kv: torch.Tensor
accumulative sequence lengths for KV; shape [batch_size + 1]
q: torch.Tensor
input tensor Q;
shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
kv: torch.Tensor
packed input tensor KV;
shape [total_seqs_kv, 2, num_heads, head_dim],
where total_seqs_kv = cu_seqlens_kv[-1]
qkv_dtype: tex.DType
data type of QKV; in tex.DType, not torch.dtype
bias: torch.Tensor, default = None
input tensor Bias;
shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
q_scale_s: torch.Tensor, default = None
input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T)
q_scale_o: torch.Tensor, default = None
input tensor for the quantization of O in FP8 computations
amax_s: torch.Tensor, default = None
output tensor, amax of S, used by the next iteration in FP8 computations
amax_o: torch.Tensor, default = None
output tensor, amax of O, used by the next iteration in FP8 computations
attn_scale: float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim) as the default
dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False
set_zero: bool, default = True
if True, initializes the output tensor O to zero using the mha_fill method;
if False, doesn't initialize O after its allocation
qkv_layout: str, default = "qkv_interleaved"
layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "no_mask"}
rng_gen: torch.Generator, default = None
random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
Returns
----------
o: torch.Tensor
output tensor O, of the attention calculation; same data type as QKV;
        shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
aux_ctx_tensors: List[torch.Tensor]
auxiliary output tensors used for the backward;
if is_training is True, aux_ctx_tensors = [M, ZInv, rng_state]
if is_training is False, aux_ctx_tensors = [rng_state]
M: torch.Tensor
max(Q*K.T)
shape [batch_size, num_heads, max_seqlen, 1], dtype float32
ZInv: torch.Tensor
1/sum(e^(x - max(x))), where x=Q*K.T
shape [batch_size, num_heads, max_seqlen, 1], dtype float32
rng_state: torch.Tensor
state of the random number generator;
[seed, offset], dtype uint64
"""
check_cu_seqlens(cu_seqlens_q)
check_cu_seqlens(cu_seqlens_kv)
assert (cu_seqlens_q.numel() == cu_seqlens_kv.numel()
), "cu_seqlens_q and cu_seqlens_kv must have the same length."
b = cu_seqlens_q.numel() - 1
qkv_type = TORCH_DType[qkv_dtype]
check_q(q, qkv_type)
check_kv(kv, qkv_type)
assert (q.size(1) == kv.size(2)
and q.size(2) == kv.size(3)
), "Q and KV must have the same num_heads and head_dim."
total_seqs_q = q.size(0)
total_seqs_kv = kv.size(0)
h = q.size(1)
d = q.size(2)
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
# FP8 fused attention API
if (qkv_type is torch.uint8) and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512) \
and (d == 64):
assert False, "The FP8 fused attention API currently only supports packed QKV input."
# BF16/FP16 fused attention API from fmha_v2
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \
and (max_seqlen_q > 512) and (max_seqlen_kv > 512):
# add BF/FP16 support for >512 sequence length
assert False, "The BF16/FP16 support for >512 sequence length is coming!"
# BF16/FP16 fused attention API from fmha_v1 apex
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \
and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512):
# add BF/FP16 support for <=512 sequence length
assert False, "The BF16/FP16 support for <=512 sequence length is coming!"
else:
assert False, "No support for this dtype and max_seqlen combination."
# execute kernel
output_tensors = tex.fused_attn_fwd_kvpacked(
b, max_seqlen_q, max_seqlen_kv, total_seqs_q, total_seqs_kv, h, d,
is_training, attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type,
cu_seqlens_q, cu_seqlens_kv,
q, kv,
qkv_dtype,
d_scale_qkv,
q_scale_s,
q_scale_o,
amax_s,
amax_o,
bias,
rng_gen,
)
return output_tensors[0], output_tensors[1:]
def fused_attn_bwd_kvpacked(
max_seqlen_q: int,
max_seqlen_kv: int,
cu_seqlens_q: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
q: torch.Tensor,
kv: torch.Tensor,
o: torch.Tensor,
d_o: torch.Tensor,
qkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor] = None,
d_bias: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None,
d_scale_do: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
q_scale_dp: torch.Tensor = None,
q_scale_dqkv: torch.Tensor = None,
amax_dp: torch.Tensor = None,
amax_dqkv: torch.Tensor = None,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "qkv_interleaved",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention BWD for packed KV input.
Parameters
----------
max_seqlen_q: int
max sequence length for Q, used for padding; may be larger than max(cu_seqlens_q)
max_seqlen_kv: int
max sequence length for KV, used for padding; may be larger than max(cu_seqlens_kv)
cu_seqlens_q: torch.Tensor
accumulative sequence lengths for Q; shape [batch_size + 1]
cu_seqlens_kv: torch.Tensor
accumulative sequence lengths for KV; shape [batch_size + 1]
q: torch.Tensor
input tensor Q;
shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
kv: torch.Tensor
packed input tensor KV;
shape [total_seqs_kv, 2, num_heads, head_dim],
where total_seqs_kv = cu_seqlens_kv[-1]
o: torch.Tensor
input tensor O (output of forward);
shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
d_o: torch.Tensor
input tensor dO (gradient of O);
shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
qkv_dtype: tex.DType
data type of QKV; in tex.DType, not torch.dtype
aux_ctx_tensors: List[torch.Tensor]
        auxiliary output tensors of the forward pass when is_training is True,
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
    d_bias: torch.Tensor, default = None
input tensor Bias;
shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None
input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T)
d_scale_o: torch.Tensor, default = None
input tensor for the dequantization of O in FP8 computations
d_scale_do: torch.Tensor, default = None
input tensor for the dequantization of dO in FP8 computations
q_scale_s: torch.Tensor, default = None
input tensor for the quantization of S in FP8 computations
q_scale_dp: torch.Tensor, default = None
input tensor for the quantization of dP in FP8 computations, P = Q * K.T
q_scale_dqkv: torch.Tensor, default = None
input tensor for the quantization of dQKV in FP8 computations
amax_dp: torch.Tensor, default = None
output tensor, amax of dP, used by the next iteration in FP8 computations,
P = Q * K.T
amax_dqkv: torch.Tensor, default = None
output tensor, amax of dQKV, used by the next iteration in FP8 computations
attn_scale: float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim) as the default
dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False
set_zero: bool, default = True
if True, initializes the output tensor O to zero using the mha_fill method;
if False, doesn't initialize O after its allocation
qkv_layout: str, default = "qkv_interleaved"
layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "no_mask"}
Returns
----------
d_q: torch.Tensor
gradient tensor of Q; same data type and shape as Q
d_kv: torch.Tensor
gradient tensor of KV; same data type and shape as KV
"""
check_cu_seqlens(cu_seqlens_q)
check_cu_seqlens(cu_seqlens_kv)
assert (cu_seqlens_q.numel() == cu_seqlens_kv.numel()
), "cu_seqlens_q and cu_seqlens_kv must have the same length."
b = cu_seqlens_q.numel() - 1
qkv_type = TORCH_DType[qkv_dtype]
check_q(q, qkv_type)
check_kv(kv, qkv_type)
check_o(o, qkv_type)
check_o(d_o, qkv_type)
assert (q.size(1) == kv.size(2)
and q.size(2) == kv.size(3)
), "Q and KV must have the same num_heads and head_dim."
total_seqs_q = q.size(0)
    total_seqs_kv = kv.size(0)
h = q.size(1)
d = q.size(2)
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
assert (len(aux_ctx_tensors) >= 1
), "aux_ctx_tensors must contain rng_state as its last element."
rng_state = aux_ctx_tensors[-1]
check_rng_state(rng_state)
# FP8 fused attention API
if (qkv_type is torch.uint8) and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512) \
and d == 64:
assert False, "The FP8 fused attention API currently only supports packed QKV input."
############### BF16/FP16 fused attention API from fmha_v2 ################
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \
and (max_seqlen_q > 512) and (max_seqlen_kv > 512):
# add BF/FP16 support for >512 sequence length
assert False, "The BF16/FP16 support for >512 sequence length is coming!"
############### BF16/FP16 fused attention API from fmha_v1 apex ################
elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \
and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512):
# add BF/FP16 support for <=512 sequence length
assert False, "The BF16/FP16 support for <=512 sequence length is coming!"
else:
assert False, "No support for this dtype and max_seqlen combination."
# execute kernel
output_tensors = tex.fused_attn_bwd_kvpacked(
b, max_seqlen_q, max_seqlen_kv, total_seqs_q, total_seqs_kv, h, d,
attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type,
cu_seqlens_q, cu_seqlens_kv,
q, kv, o, d_o,
qkv_dtype,
aux_ctx_tensors,
d_scale_qkv, d_scale_s, d_scale_o, d_scale_do,
q_scale_s, q_scale_dp, q_scale_dqkv,
amax_dp, amax_dqkv,
d_bias,
)
return output_tensors
def fp8_gemm(
    A: torch.Tensor,
@@ -233,9 +957,9 @@ def fp8_cast_transpose_fused(
    return_outputs = False
    if cast_out is None or transpose_out is None:
        cast_out = torch.empty_like(inp, dtype=torch.uint8)
        transpose_out = torch.empty(
            inp.shape[1], inp.shape[0], device="cuda", dtype=torch.uint8
        )
        return_outputs = True
......
@@ -88,6 +88,19 @@ size_t product(const std::vector<size_t> &shape) {
}
at::Tensor allocateSpace(const std::vector<size_t>& shape,
const transformer_engine::DType type,
bool init_to_zeros) {
std::vector<int64_t> shape_int64(shape.begin(), shape.end());
c10::IntArrayRef ar_shape(shape_int64);
if (init_to_zeros) {
return at::zeros(ar_shape, at::CUDA(GetATenDType(type)));
} else {
return at::empty(ar_shape, at::CUDA(GetATenDType(type)));
}
}
at::Tensor allocateSpace(const NVTEShape &shape,
                         const transformer_engine::DType type,
                         bool init_to_zeros) {
......
@@ -15,9 +15,15 @@
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/softmax.h>
#include <transformer_engine/fused_attn.h>
#include <ATen/ATen.h>
#include <ATen/cudnn/Handle.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/macros/Macros.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <torch/extension.h>
#include <torch/torch.h>
#include <cuda.h>
@@ -101,6 +107,12 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) {
      return transformer_engine::DType::kBFloat16;
    case at::kBool:
      return transformer_engine::DType::kByte;
case torch::kByte:
return transformer_engine::DType::kByte;
case torch::kInt32:
return transformer_engine::DType::kInt32;
case torch::kInt64:
return transformer_engine::DType::kInt64;
    default:
      NVTE_ERROR("Invalid type");
  }
@@ -141,6 +153,9 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor,
size_t product(const std::vector<size_t> &shape);
at::Tensor allocateSpace(const std::vector<size_t>& shape,
const transformer_engine::DType type,
bool init_to_zeros);
at::Tensor allocateSpace(const NVTEShape &shape,
                         const transformer_engine::DType type,
......
@@ -9,6 +9,742 @@
#include "comm_gemm_overlap.h"
#endif  // NVTE_WITH_USERBUFFERS
constexpr int block_size = 512;
constexpr int ctas_per_sm = 4;
// convert QKV layout to enum
NVTE_QKV_Layout get_nvte_qkv_layout(const std::string qkv_layout) {
if (qkv_layout == "not_interleaved") {
return NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED;
} else if (qkv_layout == "qkv_interleaved") {
return NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED;
} else if (qkv_layout == "kv_interleaved") {
return NVTE_QKV_Layout::NVTE_KV_INTERLEAVED;
} else {
NVTE_ERROR("Invalid QKV layout. \n");
}
}
// convert bias type to enum
NVTE_Bias_Type get_nvte_bias_type(const std::string bias_type) {
if (bias_type == "no_bias") {
return NVTE_Bias_Type::NVTE_NO_BIAS;
} else if (bias_type == "pre_scale_bias") {
return NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS;
} else if (bias_type == "post_scale_bias") {
return NVTE_Bias_Type::NVTE_POST_SCALE_BIAS;
} else {
NVTE_ERROR("Invalid bias type. \n");
}
}
// convert attn mask type to enum
NVTE_Mask_Type get_nvte_mask_type(const std::string mask_type) {
if (mask_type == "padding") {
return NVTE_Mask_Type::NVTE_PADDING_MASK;
} else if (mask_type == "causal") {
return NVTE_Mask_Type::NVTE_CAUSAL_MASK;
} else if (mask_type == "no_mask") {
return NVTE_Mask_Type::NVTE_NO_MASK;
} else {
NVTE_ERROR("Invalid attention mask type. \n");
}
}
// fast zero-fills of tensors
template <typename scalar_t>
__global__ void __launch_bounds__(block_size) mha_fill_kernel(scalar_t* out_tensor,
const int32_t* const start_row,
const size_t num_rows) {
size_t row_stride = gridDim.y * blockDim.x;
size_t row_index = blockIdx.x + static_cast<size_t>(start_row[0]);
size_t col_index = blockIdx.y * blockDim.x + threadIdx.x;
while (row_index < num_rows) {
out_tensor[row_index*row_stride + col_index] = 0;
row_index += gridDim.x;
}
}
// fast zero-fills of tensors
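// only rows at or beyond start_index (i.e. the padding region past cu_seqlens[-1])
// are zeroed, so set_zero avoids a memset over the whole output buffer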
void mha_fill(const at::Tensor &self, const at::Tensor &start_index) {
auto max_tokens = self.size(0);
auto self_2d = self.view({max_tokens, -1});
auto fcd_size = self_2d.size(1);
TORCH_CHECK(self.is_contiguous(), "input not contiguous");
TORCH_CHECK(fcd_size % block_size == 0, "input size not aligned to block size");
const int num_mp = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
uint64_t num_blk_y = (uint64_t)(fcd_size / block_size);
uint64_t num_blk_x = (uint64_t)((num_mp * ctas_per_sm + num_blk_y - 1) / num_blk_y);
dim3 dim_grid(num_blk_x, num_blk_y);
dim3 dim_block(block_size);
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16,
self_2d.scalar_type(), "mha_fill", [&]() {
mha_fill_kernel<<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
self_2d.data_ptr<scalar_t>(),
static_cast<int32_t*>(start_index.data_ptr()),
max_tokens);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
// extract seed and offset from PhiloxCudaState
__global__ void unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) {
if (arg.captured_) {
rng_state_ptr[0] = static_cast<int64_t>(*arg.seed_.ptr);
rng_state_ptr[1] = static_cast<int64_t>(
*(arg.offset_.ptr) + static_cast<int64_t>(arg.offset_intragraph_));
} else {
rng_state_ptr[0] = static_cast<int64_t>(arg.seed_.val);
rng_state_ptr[1] = static_cast<int64_t>(arg.offset_.val);
}
}
// extract PhiloxCudaState from CUDA random number generator
at::PhiloxCudaState init_philox_state(
at::CUDAGeneratorImpl* gen,
size_t max_seq_len,
size_t threads_per_cta) {
at::PhiloxCudaState philox_args;
size_t elts_per_thread = (max_seq_len * max_seq_len + threads_per_cta - 1)/threads_per_cta;
std::lock_guard<std::mutex> lock(gen->mutex_);
philox_args = gen->philox_cuda_state(elts_per_thread);
return philox_args;
}
// fused attention FWD with packed QKV
std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
size_t b, size_t max_seqlen, size_t total_seqs,
size_t h, size_t d,
bool is_training, float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
const at::Tensor cu_seqlens,
const at::Tensor QKV,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen) {
using namespace transformer_engine;
// create output tensor O
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
auto O = torch::empty({static_cast<int64_t>(total_seqs),
static_cast<int64_t>(h), static_cast<int64_t>(d)}, options);
if (set_zero) {
mha_fill(O, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)}));
}
// construct NVTE tensors
TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
if ((!descale_QKV.has_value()) || (!scale_S.has_value()) || (!scale_O.has_value())
|| (!amax_S.has_value()) || (!amax_O.has_value())) {
std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O";
NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
}
te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d},
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
at::Tensor descale_S = torch::empty_like(scale_S.value());
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, amax_S.value().data_ptr(),
scale_S.value().data_ptr(), descale_S.data_ptr());
te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d},
qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
// BF16 or FP16
te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d},
qkv_type, nullptr, nullptr, nullptr);
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d},
qkv_type, nullptr, nullptr, nullptr);
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
if (Bias.has_value()) {
auto bias_shape = Bias.value().sizes().vec();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape,
DType::kFloat32, nullptr, nullptr, nullptr);
}
te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), {b+1},
DType::kInt32, nullptr, nullptr, nullptr);
// convert strings to enums
NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
// extract random number generator seed and offset
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
size_t threads_per_cta = 128;
at::PhiloxCudaState philox_args = init_philox_state(gen, max_seqlen, threads_per_cta);
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
philox_args, static_cast<int64_t*>(rng_state.data_ptr()));
auto te_rng_state = makeTransformerEngineTensor(rng_state);
// create auxiliary output tensors
// if training, tensors are [M, ZInv]
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
// create workspace
TensorWrapper workspace;
// populate tensors with appropriate shapes and dtypes
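  // (this first call runs with an unallocated workspace: it only fills in the required
  //  workspace and auxiliary tensor shapes/dtypes; the kernel executes in the second call)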
nvte_fused_attn_fwd_qkvpacked(
te_QKV.data(),
te_Bias.data(),
te_S.data(),
te_O.data(),
&nvte_aux_tensor_pack,
te_cu_seqlens.data(),
te_rng_state.data(),
max_seqlen,
is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(),
at::cuda::getCurrentCUDAStream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(
workspace_data.data_ptr(),
workspace.shape(), workspace.dtype());
// output_tensors = [O, nvte_aux_tensor_pack.tensors, rng_state]
std::vector<at::Tensor> output_tensors;
output_tensors.push_back(O);
// nvte_aux_tensor_pack.size is 0 if inference
for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
auto tensor = reinterpret_cast<transformer_engine::Tensor*>(nvte_aux_tensor_pack.tensors[i]);
// allocate memory for nvte_aux_tensor_pack.tensors
auto output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false);
output_tensors.push_back(output_tensor);
tensor->data.dptr = output_tensor.data_ptr();
}
if (is_training) {
output_tensors.push_back(rng_state);
}
// execute the kernel
nvte_fused_attn_fwd_qkvpacked(
te_QKV.data(),
te_Bias.data(),
te_S.data(),
te_O.data(),
&nvte_aux_tensor_pack,
te_cu_seqlens.data(),
te_rng_state.data(),
max_seqlen,
is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(),
at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
// if training, [O, M, ZInv, rng_state]; if inference, [O]
return output_tensors;
}
// fused attention BWD with packed QKV
std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
size_t b, size_t max_seqlen, size_t total_seqs,
size_t h, size_t d,
float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
const at::Tensor cu_seqlens,
const at::Tensor QKV,
const at::Tensor O,
const at::Tensor dO,
const transformer_engine::DType qkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dP,
c10::optional<at::Tensor> amax_dQKV,
const c10::optional<at::Tensor> dBias) {
using namespace transformer_engine;
// create output tensor dQKV
at::Tensor dQKV = torch::empty_like(QKV);
if (set_zero) {
mha_fill(dQKV, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)}));
}
// construct NVTE tensors
TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV, te_dBias;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
if ((!descale_QKV.has_value()) || (!descale_S.has_value())
|| (!descale_O.has_value()) || (!descale_dO.has_value())
|| (!scale_S.has_value()) || (!scale_dP.has_value())
|| (!scale_dQKV.has_value())
|| (!amax_dP.has_value()) || (!amax_dQKV.has_value())) {
std::string err_tensors = "descale_QKV, descale_S, descale_O, scale_S, scale_dP, ";
err_tensors = err_tensors + std::string("scale_dQKV, amax_dP and amax_dQKV");
NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
}
te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d},
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d},
qkv_type, nullptr, nullptr, descale_O.value().data_ptr());
te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs, h, d},
qkv_type, nullptr, nullptr, descale_dO.value().data_ptr());
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32,
nullptr, scale_S.value().data_ptr(), descale_S.value().data_ptr());
at::Tensor descale_dP = torch::empty_like(scale_dP.value());
te_dP = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, amax_dP.value().data_ptr(), scale_dP.value().data_ptr(),
descale_dP.data_ptr());
te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), {total_seqs, 3, h, d},
qkv_type,
amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
// BF16 or FP16
te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d},
qkv_type, nullptr, nullptr, nullptr);
te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d},
qkv_type, nullptr, nullptr, nullptr);
te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs, h, d},
qkv_type, nullptr, nullptr, nullptr);
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_dP = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), {total_seqs, 3, h, d},
qkv_type, nullptr, nullptr, nullptr);
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
if (dBias.has_value()) {
auto bias_shape = dBias.value().sizes().vec();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_dBias = makeTransformerEngineTensor(
dBias.value().data_ptr(), shape, DType::kFloat32,
nullptr, nullptr, nullptr);
}
// convert strings to enums
NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
// convert auxiliary tensors from forward into NVTETensors
// aux_ctx_tensors are [M, ZInv, rng_state]
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size();
for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
auto tensor = reinterpret_cast<transformer_engine::Tensor*>(nvte_aux_tensor_pack.tensors[i]);
tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr();
std::vector<int64_t> tmp(Aux_CTX_Tensors[i].sizes().vec());
tensor->data.shape = std::vector<size_t>(tmp.begin(), tmp.end());
tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type());
}
// create cu_seqlens tensorwrappers
TensorWrapper te_cu_seqlens;
te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), {b+1},
DType::kInt32, nullptr, nullptr, nullptr);
// create workspace
TensorWrapper workspace;
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd_qkvpacked(
te_QKV.data(),
te_dBias.data(),
te_O.data(),
te_dO.data(),
te_S.data(),
te_dP.data(),
&nvte_aux_tensor_pack,
te_dQKV.data(),
te_cu_seqlens.data(),
max_seqlen,
attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(),
at::cuda::getCurrentCUDAStream());
// allocate memory for workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(
workspace_data.data_ptr(),
workspace.shape(), workspace.dtype());
// execute kernel
nvte_fused_attn_bwd_qkvpacked(
te_QKV.data(),
te_dBias.data(),
te_O.data(),
te_dO.data(),
te_S.data(),
te_dP.data(),
&nvte_aux_tensor_pack,
te_dQKV.data(),
te_cu_seqlens.data(),
max_seqlen,
attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(),
at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
return {dQKV};
}
// fused attention FWD with packed KV
std::vector<at::Tensor> fused_attn_fwd_kvpacked(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t total_seqs_q, size_t total_seqs_kv,
size_t h, size_t d,
bool is_training, float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv,
const at::Tensor Q,
const at::Tensor KV,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen) {
using namespace transformer_engine;
// create output tensor O
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
auto O = torch::empty({static_cast<int64_t>(total_seqs_q),
static_cast<int64_t>(h), static_cast<int64_t>(d)}, options);
if (set_zero) {
mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
}
// construct NVTE tensors
TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
if ((!descale_QKV.has_value()) || (!scale_S.has_value()) || (!scale_O.has_value())
|| (!amax_S.has_value()) || (!amax_O.has_value())) {
std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O";
NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
}
te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d},
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d},
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
at::Tensor descale_S = torch::empty_like(scale_S.value());
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, amax_S.value().data_ptr(),
scale_S.value().data_ptr(), descale_S.data_ptr());
te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d},
qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
// BF16 or FP16
te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d},
qkv_type, nullptr, nullptr, nullptr);
te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d},
qkv_type, nullptr, nullptr, nullptr);
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d},
qkv_type, nullptr, nullptr, nullptr);
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
if (Bias.has_value()) {
auto bias_shape = Bias.value().sizes().vec();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape,
DType::kFloat32, nullptr, nullptr, nullptr);
}
te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), {b+1},
DType::kInt32, nullptr, nullptr, nullptr);
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), {b+1},
DType::kInt32, nullptr, nullptr, nullptr);
// convert strings to enums
NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
// extract rng seed and offset
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
size_t threads_per_cta = 128;
at::PhiloxCudaState philox_args = init_philox_state(
gen, max(max_seqlen_q, max_seqlen_kv), threads_per_cta);
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
philox_args, static_cast<int64_t*>(rng_state.data_ptr()));
auto te_rng_state = makeTransformerEngineTensor(rng_state);
// create auxiliary output tensors
// if training, tensors are [M, ZInv]
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
// create workspace
TensorWrapper workspace;
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd_kvpacked(
te_Q.data(),
te_KV.data(),
te_Bias.data(),
te_S.data(),
te_O.data(),
&nvte_aux_tensor_pack,
te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(),
te_rng_state.data(),
max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(),
at::cuda::getCurrentCUDAStream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(
workspace_data.data_ptr(),
workspace.shape(), workspace.dtype());
// output_tensors = [O, nvte_aux_tensor_pack.tensors, rng_state]
std::vector<at::Tensor> output_tensors;
output_tensors.push_back(O);
// nvte_aux_tensor_pack.size is 0 if inference
for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
auto tensor = reinterpret_cast<transformer_engine::Tensor*>(nvte_aux_tensor_pack.tensors[i]);
// allocate memory for nvte_aux_tensor_pack.tensors
auto output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false);
output_tensors.push_back(output_tensor);
tensor->data.dptr = output_tensor.data_ptr();
}
if (is_training) {
output_tensors.push_back(rng_state);
}
// execute the kernel
nvte_fused_attn_fwd_kvpacked(
te_Q.data(),
te_KV.data(),
te_Bias.data(),
te_S.data(),
te_O.data(),
&nvte_aux_tensor_pack,
te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(),
te_rng_state.data(),
max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(),
at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
// if training, [O, M, ZInv, rng_state]; if inference, [O]
return output_tensors;
}
// fused attention BWD with packed KV
std::vector<at::Tensor> fused_attn_bwd_kvpacked(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t total_seqs_q, size_t total_seqs_kv,
size_t h, size_t d,
float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv,
const at::Tensor Q,
const at::Tensor KV,
const at::Tensor O,
const at::Tensor dO,
const transformer_engine::DType qkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dP,
c10::optional<at::Tensor> amax_dQKV,
const c10::optional<at::Tensor> dBias) {
using namespace transformer_engine;
// create output tensors dQ and dKV
at::Tensor dQ = torch::empty_like(Q);
at::Tensor dKV = torch::empty_like(KV);
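// when set_zero is requested, clear the padding region of dQ and dKV beyond
// the total number of valid tokens (the last entry of cu_seqlens)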
if (set_zero) {
mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
mha_fill(dKV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
}
// construct NVTE tensors
TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV, te_dBias;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
if ((!descale_QKV.has_value()) || (!descale_S.has_value())
|| (!descale_O.has_value()) || (!descale_dO.has_value())
|| (!scale_S.has_value()) || (!scale_dP.has_value())
|| (!scale_dQKV.has_value())
|| (!amax_dP.has_value()) || (!amax_dQKV.has_value())) {
std::string err_tensors = "descale_QKV, descale_S, descale_O, descale_dO, scale_S, ";
err_tensors = err_tensors + std::string("scale_dP, scale_dQKV, amax_dP and amax_dQKV ");
NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
}
te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d},
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d},
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d},
qkv_type, nullptr, nullptr, descale_O.value().data_ptr());
te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs_q, h, d},
qkv_type, nullptr, nullptr, descale_dO.value().data_ptr());
te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr,
scale_S.value().data_ptr(), descale_S.value().data_ptr());
at::Tensor descale_dP = torch::empty_like(scale_dP.value());
te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32,
amax_dP.value().data_ptr(), scale_dP.value().data_ptr(),
descale_dP.data_ptr());
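// S and dP are not materialized here (null data pointers); these wrappers only
// carry the FP8 scaling metadata (amax/scale/descale) needed by the fused kernel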
te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), {total_seqs_q, h, d}, qkv_type,
amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), {total_seqs_kv, 2, h, d}, qkv_type,
amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
// BF16 or FP16
te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d},
qkv_type, nullptr, nullptr, nullptr);
te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d},
qkv_type, nullptr, nullptr, nullptr);
te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d},
qkv_type, nullptr, nullptr, nullptr);
te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs_q, h, d},
qkv_type, nullptr, nullptr, nullptr);
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_dP = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), {total_seqs_q, h, d},
qkv_type, nullptr, nullptr, nullptr);
te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), {total_seqs_kv, 2, h, d},
qkv_type, nullptr, nullptr, nullptr);
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
if (dBias.has_value()) {
auto bias_shape = dBias.value().sizes().vec();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_dBias = makeTransformerEngineTensor(
dBias.value().data_ptr(), shape, DType::kFloat32,
nullptr, nullptr, nullptr);
}
// create cu_seqlens tensorwrappers
TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), {b+1},
DType::kInt32, nullptr, nullptr, nullptr);
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), {b+1},
DType::kInt32, nullptr, nullptr, nullptr);
// convert strings to enums
NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
// convert auxiliary tensors from forward to NVTETensors
// aux_ctx_tensors are [M, ZInv, rng_state]
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size();
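// each wrapper points directly at the storage of the corresponding tensor saved
// from the forward pass; no data is copied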
for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
auto tensor = reinterpret_cast<transformer_engine::Tensor*>(nvte_aux_tensor_pack.tensors[i]);
tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr();
std::vector<int64_t> tmp(Aux_CTX_Tensors[i].sizes().vec());
tensor->data.shape = std::vector<size_t>(tmp.begin(), tmp.end());
tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type());
}
// create workspace
TensorWrapper workspace;
// populate tensors with appropriate shapes and dtypes
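// as in the forward pass, this first call only queries the workspace size;
// the kernel itself runs in the second call below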
nvte_fused_attn_bwd_kvpacked(
te_Q.data(),
te_KV.data(),
te_dBias.data(),
te_O.data(),
te_dO.data(),
te_S.data(),
te_dP.data(),
&nvte_aux_tensor_pack,
te_dQ.data(),
te_dKV.data(),
te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(),
max_seqlen_q, max_seqlen_kv,
attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(),
at::cuda::getCurrentCUDAStream());
// allocate memory for workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(
workspace_data.data_ptr(),
workspace.shape(), workspace.dtype());
// execute kernel
nvte_fused_attn_bwd_kvpacked(
te_Q.data(),
te_KV.data(),
te_dBias.data(),
te_O.data(),
te_dO.data(),
te_S.data(),
te_dP.data(),
&nvte_aux_tensor_pack,
te_dQ.data(),
te_dKV.data(),
te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(),
max_seqlen_q, max_seqlen_kv,
attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(),
at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
return {dQ, dKV};
}
void te_gemm(at::Tensor A,
             at::Tensor A_scale_inverse,
             transformer_engine::DType A_type,
@@ -749,13 +1485,13 @@ at::Tensor cast_to_fp8(const at::Tensor &input,
                       transformer_engine::DType otype
) {
  using namespace transformer_engine;
-  size_t N = static_cast<size_t>(input.size(0));
-  size_t H = static_cast<size_t>(input.size(1));
+  auto input_shape = input.sizes().vec();
+  std::vector<size_t> shape{input_shape.begin(), input_shape.end()};
  auto output = at::empty_like(input, at::CUDA(GetATenDType(otype)));
  auto input_cu = makeTransformerEngineTensor(input);
-  auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype,
+  auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype,
                                               amax.data_ptr(), scale.data_ptr(),
                                               scale_inv.data_ptr());
@@ -795,12 +1531,12 @@ at::Tensor cast_from_fp8(const at::Tensor &input,
                         transformer_engine::DType otype
) {
  using namespace transformer_engine;
-  size_t N = static_cast<size_t>(input.size(0));
-  size_t H = static_cast<size_t>(input.size(1));
+  auto input_shape = input.sizes().vec();
+  std::vector<size_t> shape{input_shape.begin(), input_shape.end()};
  auto output = at::empty_like(input, at::CUDA(GetATenDType(otype)));
-  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {N, H}, itype,
+  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), shape, itype,
                                              nullptr, nullptr, scale_inv.data_ptr());
  auto output_cu = makeTransformerEngineTensor(output);
@@ -1066,6 +1802,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8");
  m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8");
  m.def("te_gemm", &te_gemm, "CublasLt GEMM");
+  m.def("fused_attn_fwd_qkvpacked", &fused_attn_fwd_qkvpacked,
+        "Fused Attention FP8/BF16/FP16 FWD with packed QKV");
+  m.def("fused_attn_bwd_qkvpacked", &fused_attn_bwd_qkvpacked,
+        "Fused Attention FP8/BF16/FP16 BWD with packed QKV");
+  m.def("fused_attn_fwd_kvpacked", &fused_attn_fwd_kvpacked,
+        "Fused Attention FP8/BF16/FP16 FWD with packed KV");
+  m.def("fused_attn_bwd_kvpacked", &fused_attn_bwd_kvpacked,
+        "Fused Attention FP8/BF16/FP16 BWD with packed KV");
  m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O");
  m.def("fp8_gelu", &fp8_gelu, "GeLU with FP8 output");
...
@@ -5,7 +5,95 @@
 ************************************************************************/
#include "common.h"
#include "../common.h"
NVTE_QKV_Layout get_nvte_qkv_layout(const std::string qkv_layout);
NVTE_Bias_Type get_nvte_bias_type(const std::string bias_type);
NVTE_Mask_Type get_nvte_mask_type(const std::string mask_type);
std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
size_t b, size_t max_seqlen, size_t total_seqs,
size_t h, size_t d,
bool is_training, float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
const at::Tensor cu_seqlens,
const at::Tensor QKV,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen);
std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
size_t b, size_t max_seqlen, size_t total_seqs,
size_t h, size_t d,
float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
const at::Tensor cu_seqlens,
const at::Tensor QKV,
const at::Tensor O,
const at::Tensor dO,
const transformer_engine::DType qkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dP,
c10::optional<at::Tensor> amax_dQKV,
const c10::optional<at::Tensor> dBias);
std::vector<at::Tensor> fused_attn_fwd_kvpacked(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t total_seqs_q, size_t total_seqs_kv,
size_t h, size_t d,
bool is_training, float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv,
const at::Tensor Q,
const at::Tensor KV,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen);
std::vector<at::Tensor> fused_attn_bwd_kvpacked(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t total_seqs_q, size_t total_seqs_kv,
size_t h, size_t d,
float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv,
const at::Tensor Q,
const at::Tensor KV,
const at::Tensor O,
const at::Tensor dO,
const transformer_engine::DType qkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dP,
c10::optional<at::Tensor> amax_dQKV,
const c10::optional<at::Tensor> dBias);
void te_gemm(at::Tensor A,
             at::Tensor A_scale_inverse,
...
@@ -102,7 +102,7 @@ def get_workspace() -> torch.Tensor:
    global _cublas_workspace
    if _cublas_workspace is None:
        _cublas_workspace = torch.empty(
-            get_cublas_workspace_size_bytes(), dtype=torch.int8, device="cuda"
+            get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda"
        )
    return _cublas_workspace
@@ -520,7 +520,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                torch.empty(
                    shape,
                    device=torch.cuda.current_device(),
-                    dtype=torch.int8,
+                    dtype=torch.uint8,
                ),
            )
            setattr(
@@ -530,7 +530,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                    shape[1],
                    shape[0],
                    device=torch.cuda.current_device(),
-                    dtype=torch.int8,
+                    dtype=torch.uint8,
                ),
            )
...