Unverified Commit 186cfaf3 authored by cyanguwa, committed by GitHub

Move dbias from fused attention bwd's input list to its output list (#185)



* move dbias from input list to output list for bwd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* split asserts into three for bias checks
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Update transformer_engine/pytorch/cpp_extensions.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>

* fix asserts for bias checks
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* another fix for asserts for bias checks
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 201279fa
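
In short, the forward helpers now validate the bias tensor whenever bias_type is not "no_bias" (shape [1, num_heads, max_seqlen, max_seqlen], same dtype as QKV), and the backward helpers stop taking d_bias as an input and return it instead. A minimal caller-side sketch with made-up sizes; the backward call is abbreviated to a comment because the real argument list is much longer than shown here:

```python
import torch

# Hypothetical sizes, for illustration only.
num_heads, max_seqlen = 16, 128

# Forward: with bias_type "pre_scale_bias" or "post_scale_bias", the new asserts require
# bias.shape == [1, num_heads, max_seqlen, max_seqlen] and bias.dtype == qkv.dtype.
bias = torch.zeros(1, num_heads, max_seqlen, max_seqlen, dtype=torch.bfloat16)

# Backward: d_bias is no longer an argument; the helper returns it instead.
# Abbreviated pseudo-call (see the real signature in cpp_extensions.py):
#   d_qkv, d_bias = fused_attn_bwd_qkvpacked(..., bias_type="post_scale_bias")
#   d_qkv         = fused_attn_bwd_qkvpacked(..., bias_type="no_bias")
```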
@@ -68,13 +68,13 @@ void nvte_fused_attn_fwd_qkvpacked(
 // NVTE fused attention BWD FP8 with packed QKV
 void nvte_fused_attn_bwd_qkvpacked(
             const NVTETensor QKV,
-            const NVTETensor dBias,
             const NVTETensor O,
             const NVTETensor dO,
             const NVTETensor S,
             NVTETensor dP,
             const NVTETensorPack* Aux_CTX_Tensors,
             NVTETensor dQKV,
+            NVTETensor dBias,
             const NVTETensor cu_seqlens,
             size_t max_seqlen,
             float attn_scale, float dropout,
@@ -86,12 +86,12 @@ void nvte_fused_attn_bwd_qkvpacked(
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor*>(cu_seqlens);
   const Tensor *input_QKV = reinterpret_cast<const Tensor*>(QKV);
-  const Tensor *input_dBias = reinterpret_cast<const Tensor*>(dBias);
   const Tensor *input_O = reinterpret_cast<const Tensor*>(O);
   const Tensor *input_dO = reinterpret_cast<const Tensor*>(dO);
   const Tensor *input_S = reinterpret_cast<const Tensor*>(S);
   Tensor *input_output_dP = reinterpret_cast<Tensor*>(dP);
   Tensor *output_dQKV = reinterpret_cast<Tensor*>(dQKV);
+  Tensor *output_dBias = reinterpret_cast<Tensor*>(dBias);
   Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
   // QKV shape is [total_seqs, 3, h, d]
@@ -182,7 +182,6 @@ void nvte_fused_attn_fwd_kvpacked(
 void nvte_fused_attn_bwd_kvpacked(
             const NVTETensor Q,
             const NVTETensor KV,
-            const NVTETensor dBias,
             const NVTETensor O,
             const NVTETensor dO,
             const NVTETensor S,
@@ -190,6 +189,7 @@ void nvte_fused_attn_bwd_kvpacked(
             const NVTETensorPack* Aux_CTX_Tensors,
             NVTETensor dQ,
             NVTETensor dKV,
+            NVTETensor dBias,
             const NVTETensor cu_seqlens_q,
             const NVTETensor cu_seqlens_kv,
             size_t max_seqlen_q, size_t max_seqlen_kv,
@@ -204,13 +204,13 @@ void nvte_fused_attn_bwd_kvpacked(
   const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv);
   const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q);
   const Tensor *input_KV = reinterpret_cast<const Tensor*>(KV);
-  const Tensor *input_dBias = reinterpret_cast<const Tensor*>(dBias);
   const Tensor *input_O = reinterpret_cast<const Tensor*>(O);
   const Tensor *input_dO = reinterpret_cast<const Tensor*>(dO);
   const Tensor *input_S = reinterpret_cast<const Tensor*>(S);
   Tensor *input_output_dP = reinterpret_cast<Tensor*>(dP);
   Tensor *output_dQ = reinterpret_cast<Tensor*>(dQ);
   Tensor *output_dKV = reinterpret_cast<Tensor*>(dKV);
+  Tensor *output_dBias = reinterpret_cast<Tensor*>(dBias);
   Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
   // Q shape is [total_seqs, h, d]
......
@@ -125,13 +125,13 @@ void nvte_fused_attn_fwd_qkvpacked(
  *
  *  \param[in]     QKV              The QKV tensor in packed format,
  *                                  [total_seqs, 3, num_heads, head_dim].
- *  \param[in]     dBias            The gradient of the Bias tensor.
  *  \param[in]     O                The O tensor from forward.
  *  \param[in]     dO               The gradient of the O tensor.
  *  \param[in]     S                The S tensor.
  *  \param[in,out] dP               The gradient of the P tensor.
  *  \param[in]     Aux_CTX_Tensors  Auxiliary tensors from forward when in training mode.
  *  \param[out]    dQKV             The gradient of the QKV tensor.
+ *  \param[out]    dBias            The gradient of the Bias tensor.
  *  \param[in]     cu_seqlens       Accumulative sequence lengths, [batch_size + 1].
  *  \param[in]     max_seqlen       Max sequence length used for computing,
  *                                  it may be >= max(cu_seqlens).
@@ -145,13 +145,13 @@ void nvte_fused_attn_fwd_qkvpacked(
  */
 void nvte_fused_attn_bwd_qkvpacked(
             const NVTETensor QKV,
-            const NVTETensor dBias,
             const NVTETensor O,
             const NVTETensor dO,
             const NVTETensor S,
             NVTETensor dP,
             const NVTETensorPack* Aux_CTX_Tensors,
             NVTETensor dQKV,
+            NVTETensor dBias,
             const NVTETensor cu_seqlens,
             size_t max_seqlen,
             float attn_scale, float dropout,
@@ -211,7 +211,6 @@ void nvte_fused_attn_fwd_kvpacked(
  *
  *  \param[in]     Q                The Q tensor, [total_seqs_q, num_heads, head_dim].
  *  \param[in]     KV               The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].
- *  \param[in]     dBias            The gradient of the Bias tensor.
  *  \param[in]     O                The O tensor from forward.
  *  \param[in]     dO               The gradient of the O tensor.
  *  \param[in]     S                The S tensor.
@@ -219,6 +218,7 @@ void nvte_fused_attn_fwd_kvpacked(
  *  \param[in]     Aux_CTX_Tensors  Auxiliary tensors from forward when in training mode.
  *  \param[out]    dQ               The gradient of the Q tensor.
  *  \param[out]    dKV              The gradient of the KV tensor.
+ *  \param[out]    dBias            The gradient of the Bias tensor.
  *  \param[in]     cu_seqlens_q     Accumulative sequence lengths for Q, [batch_size + 1].
  *  \param[in]     cu_seqlens_kv    Accumulative sequence lengths for KV, [batch_size + 1].
  *  \param[in]     max_seqlen_q     Max sequence length used for computing for Q.
@@ -236,7 +236,6 @@ void nvte_fused_attn_fwd_kvpacked(
 void nvte_fused_attn_bwd_kvpacked(
             const NVTETensor Q,
             const NVTETensor KV,
-            const NVTETensor dBias,
             const NVTETensor O,
             const NVTETensor dO,
             const NVTETensor S,
@@ -244,6 +243,7 @@ void nvte_fused_attn_bwd_kvpacked(
             const NVTETensorPack* Aux_CTX_Tensors,
             NVTETensor dQ,
             NVTETensor dKV,
+            NVTETensor dBias,
             const NVTETensor cu_seqlens_q,
             const NVTETensor cu_seqlens_kv,
             size_t max_seqlen_q, size_t max_seqlen_kv,
......
@@ -125,8 +125,8 @@ def fused_attn_fwd_qkvpacked(
     qkv_dtype: tex.DType
                 data type of QKV; in tex.DType, not torch.dtype
     bias: torch.Tensor, default = None
-                input tensor Bias;
-                shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
+                input tensor Bias when bias_type is "pre_scale_bias" or "post_scale_bias";
+                shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv
     d_scale_qkv: torch.Tensor, default = None
                 input tensor for the dequantization of QKV in FP8 computations
     q_scale_s: torch.Tensor, default = None
@@ -188,6 +188,13 @@ def fused_attn_fwd_qkvpacked(
     if attn_scale is None:
         attn_scale = 1.0 / math.sqrt(d)

+    if bias_type != "no_bias":
+        assert bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
+        assert (bias.shape == [1, h, max_seqlen, max_seqlen]
+                ), "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
+        assert (bias.dtype == qkv.dtype
+                ), "bias tensor must be in the same dtype as qkv."
+
     # FP8 fused attention API
     if (qkv_type is torch.uint8) and (max_seqlen <= 512) and (d == 64):
         assert (qkv_layout == "qkv_interleaved"
@@ -246,7 +253,6 @@ def fused_attn_bwd_qkvpacked(
     d_o: torch.Tensor,
     qkv_dtype: tex.DType,
     aux_ctx_tensors: List[torch.Tensor] = None,
-    d_bias: torch.Tensor = None,
     d_scale_qkv: torch.Tensor = None,
     d_scale_s: torch.Tensor = None,
     d_scale_o: torch.Tensor = None,
@@ -285,9 +291,6 @@ def fused_attn_bwd_qkvpacked(
     aux_ctx_tensors: List[torch.Tensor]
                 auxiliary output tensors of the forward pass when its is_training is True,
                 e.g. aux_ctx_tensors = [M, ZInv, rng_state]
-    d_bias: torch.Tensor, default = None
-                input tensor Bias;
-                shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
     d_scale_qkv: torch.Tensor, default = None
                 input tensor for the dequantization of QKV in FP8 computations
     d_scale_s: torch.Tensor, default = None
@@ -326,6 +329,9 @@ def fused_attn_bwd_qkvpacked(
     ----------
     d_qkv: torch.Tensor
                 gradient tensor of QKV; same data type and shape as QKV
+    d_bias: torch.Tensor, optional
+                gradient tensor of Bias when bias_type is "pre_scale_bias" or "post_scale_bias";
+                same data type and shape as Bias
     """

     check_cu_seqlens(cu_seqlens)
@@ -402,10 +408,13 @@ def fused_attn_bwd_qkvpacked(
             d_scale_qkv, d_scale_s, d_scale_o, d_scale_do,
             q_scale_s, q_scale_dp, q_scale_dqkv,
             amax_dp, amax_dqkv,
-            d_bias,
     )

-    return output_tensors[0]
+    if bias_type == "no_bias":
+        # return d_qkv when bias_type is no_bias
+        return output_tensors[0]
+    # otherwise return (d_qkv, d_bias)
+    return output_tensors


 def fused_attn_fwd_kvpacked(
@@ -454,10 +463,10 @@ def fused_attn_fwd_kvpacked(
                 shape [total_seqs_kv, 2, num_heads, head_dim],
                 where total_seqs_kv = cu_seqlens_kv[-1]
     qkv_dtype: tex.DType
-                data type of QKV; in tex.DType, not torch.dtype
+                data type of Q and KV; in tex.DType, not torch.dtype
     bias: torch.Tensor, default = None
-                input tensor Bias;
-                shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
+                input tensor Bias when bias_type is "pre_scale_bias" or "post_scale_bias";
+                shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv
     d_scale_qkv: torch.Tensor, default = None
                 input tensor for the dequantization of QKV in FP8 computations
     q_scale_s: torch.Tensor, default = None
@@ -527,6 +536,13 @@ def fused_attn_fwd_kvpacked(
     if attn_scale is None:
         attn_scale = 1.0 / math.sqrt(d)

+    if bias_type != "no_bias":
+        assert bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
+        assert (bias.shape == [1, h, max_seqlen_q, max_seqlen_kv]
+                ), "bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
+        assert (bias.dtype == q.dtype
+                ), "bias tensor must be in the same dtype as q and kv."
+
     # FP8 fused attention API
     if (qkv_type is torch.uint8) and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512) \
         and (d == 64):
@@ -577,7 +593,6 @@ def fused_attn_bwd_kvpacked(
     d_o: torch.Tensor,
     qkv_dtype: tex.DType,
     aux_ctx_tensors: List[torch.Tensor] = None,
-    d_bias: torch.Tensor = None,
     d_scale_qkv: torch.Tensor = None,
     d_scale_s: torch.Tensor = None,
     d_scale_o: torch.Tensor = None,
@@ -624,9 +639,6 @@ def fused_attn_bwd_kvpacked(
     aux_ctx_tensors: List[torch.Tensor]
                 auxiliary output tensors of the forward pass when its is_training is True,
                 e.g. aux_ctx_tensors = [M, ZInv, rng_state]
-    bias: torch.Tensor, default = None
-                input tensor Bias;
-                shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
     d_scale_qkv: torch.Tensor, default = None
                 input tensor for the dequantization of QKV in FP8 computations
     d_scale_s: torch.Tensor, default = None
@@ -668,6 +680,9 @@ def fused_attn_bwd_kvpacked(
                 gradient tensor of Q; same data type and shape as Q
     d_kv: torch.Tensor
                 gradient tensor of KV; same data type and shape as KV
+    d_bias: torch.Tensor, optional
+                gradient tensor of Bias when bias_type is "pre_scale_bias" or "post_scale_bias";
+                same data type and shape as Bias
     """

     check_cu_seqlens(cu_seqlens_q)
@@ -728,9 +743,11 @@ def fused_attn_bwd_kvpacked(
             d_scale_qkv, d_scale_s, d_scale_o, d_scale_do,
             q_scale_s, q_scale_dp, q_scale_dqkv,
             amax_dp, amax_dqkv,
-            d_bias,
     )

+    # returns (d_q, d_kv) when bias_type is no_bias; otherwise returns (d_q, d_kv, d_bias)
+    if bias_type == "no_bias":
+        return output_tensors[:2]
     return output_tensors


 def fp8_gemm(
......
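
Taken together, the Python-side change means callers of the backward helpers select outputs by bias_type instead of supplying a d_bias buffer. Below is a small, self-contained sketch that mirrors the selection logic appended to fused_attn_bwd_kvpacked; `_select_kvpacked_outputs` is a hypothetical helper for illustration, not a function in the codebase, and plain strings stand in for the CUDA tensors.

```python
# Toy model of the new return convention in fused_attn_bwd_kvpacked.
# The C++ extension now returns [dQ, dKV, dBias] (dBias is only allocated when
# bias_type is not "no_bias"); the Python wrapper drops it in the "no_bias" case.

def _select_kvpacked_outputs(output_tensors, bias_type):
    # returns (d_q, d_kv) when bias_type is no_bias; otherwise returns (d_q, d_kv, d_bias)
    if bias_type == "no_bias":
        return output_tensors[:2]
    return output_tensors

outputs = ["d_q", "d_kv", "d_bias"]  # stand-ins for real gradient tensors
print(_select_kvpacked_outputs(outputs, "no_bias"))          # ['d_q', 'd_kv']
print(_select_kvpacked_outputs(outputs, "post_scale_bias"))  # ['d_q', 'd_kv', 'd_bias']
```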
@@ -166,7 +166,7 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
   } else {
     NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
   }
-  if (Bias.has_value()) {
+  if ((bias_type != "no_bias") && (Bias.has_value())) {
     auto bias_shape = Bias.value().sizes().vec();
     std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
     te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape,
@@ -276,8 +276,7 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
                 const c10::optional<at::Tensor> scale_dP,
                 const c10::optional<at::Tensor> scale_dQKV,
                 c10::optional<at::Tensor> amax_dP,
-                c10::optional<at::Tensor> amax_dQKV,
-                const c10::optional<at::Tensor> dBias) {
+                c10::optional<at::Tensor> amax_dQKV) {
   using namespace transformer_engine;

   // create output tensor dQKV
@@ -285,9 +284,18 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
   if (set_zero) {
     mha_fill(dQKV, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)}));
   }
+  auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
+  at::Tensor dBias;
+  TensorWrapper te_dBias;
+  if (bias_type != "no_bias") {
+    dBias = torch::zeros({1, static_cast<int64_t>(h),
+                    static_cast<int64_t>(max_seqlen),
+                    static_cast<int64_t>(max_seqlen)}, options);
+    te_dBias = makeTransformerEngineTensor(dBias);
+  }

   // construct NVTE tensors
-  TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV, te_dBias;
+  TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV;
   if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
     // FP8
     if ((!descale_QKV.has_value()) || (!descale_S.has_value())
@@ -332,13 +340,6 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
   } else {
     NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
   }
-  if (dBias.has_value()) {
-    auto bias_shape = dBias.value().sizes().vec();
-    std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
-    te_dBias = makeTransformerEngineTensor(
-                    dBias.value().data_ptr(), shape, DType::kFloat32,
-                    nullptr, nullptr, nullptr);
-  }

   // convert strings to enums
   NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
@@ -369,13 +370,13 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
   // populate tensors with appropriate shapes and dtypes
   nvte_fused_attn_bwd_qkvpacked(
                 te_QKV.data(),
-                te_dBias.data(),
                 te_O.data(),
                 te_dO.data(),
                 te_S.data(),
                 te_dP.data(),
                 &nvte_aux_tensor_pack,
                 te_dQKV.data(),
+                te_dBias.data(),
                 te_cu_seqlens.data(),
                 max_seqlen,
                 attn_scale, p_dropout,
@@ -392,13 +393,13 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
   // execute kernel
   nvte_fused_attn_bwd_qkvpacked(
                 te_QKV.data(),
-                te_dBias.data(),
                 te_O.data(),
                 te_dO.data(),
                 te_S.data(),
                 te_dP.data(),
                 &nvte_aux_tensor_pack,
                 te_dQKV.data(),
+                te_dBias.data(),
                 te_cu_seqlens.data(),
                 max_seqlen,
                 attn_scale, p_dropout,
@@ -409,7 +410,7 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
   // destroy tensor wrappers
   nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);

-  return {dQKV};
+  return {dQKV, dBias};
 }

 // fused attention FWD with packed KV
@@ -473,7 +474,7 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
   } else {
     NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
   }
-  if (Bias.has_value()) {
+  if ((bias_type != "no_bias") && (Bias.has_value())) {
     auto bias_shape = Bias.value().sizes().vec();
     std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
     te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape,
@@ -593,8 +594,7 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
                 const c10::optional<at::Tensor> scale_dP,
                 const c10::optional<at::Tensor> scale_dQKV,
                 c10::optional<at::Tensor> amax_dP,
-                c10::optional<at::Tensor> amax_dQKV,
-                const c10::optional<at::Tensor> dBias) {
+                c10::optional<at::Tensor> amax_dQKV) {
   using namespace transformer_engine;

   // create output tensors dQ and dKV
@@ -604,9 +604,18 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
     mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
     mha_fill(dKV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
   }
+  auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
+  at::Tensor dBias;
+  TensorWrapper te_dBias;
+  if (bias_type != "no_bias") {
+    dBias = torch::zeros({1, static_cast<int64_t>(h),
+                    static_cast<int64_t>(max_seqlen_q),
+                    static_cast<int64_t>(max_seqlen_kv)}, options);
+    te_dBias = makeTransformerEngineTensor(dBias);
+  }

   // construct NVTE tensors
-  TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV, te_dBias;
+  TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV;
   if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
     // FP8
     if ((!descale_QKV.has_value()) || (!descale_S.has_value())
@@ -657,13 +666,6 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
   } else {
     NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
   }
-  if (dBias.has_value()) {
-    auto bias_shape = dBias.value().sizes().vec();
-    std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
-    te_dBias = makeTransformerEngineTensor(
-                    dBias.value().data_ptr(), shape, DType::kFloat32,
-                    nullptr, nullptr, nullptr);
-  }

   // create cu_seqlens tensorwrappers
   TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
@@ -697,7 +699,6 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
   nvte_fused_attn_bwd_kvpacked(
                 te_Q.data(),
                 te_KV.data(),
-                te_dBias.data(),
                 te_O.data(),
                 te_dO.data(),
                 te_S.data(),
@@ -705,6 +706,7 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
                 &nvte_aux_tensor_pack,
                 te_dQ.data(),
                 te_dKV.data(),
+                te_dBias.data(),
                 te_cu_seqlens_q.data(),
                 te_cu_seqlens_kv.data(),
                 max_seqlen_q, max_seqlen_kv,
@@ -723,7 +725,6 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
   nvte_fused_attn_bwd_kvpacked(
                 te_Q.data(),
                 te_KV.data(),
-                te_dBias.data(),
                 te_O.data(),
                 te_dO.data(),
                 te_S.data(),
@@ -731,6 +732,7 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
                 &nvte_aux_tensor_pack,
                 te_dQ.data(),
                 te_dKV.data(),
+                te_dBias.data(),
                 te_cu_seqlens_q.data(),
                 te_cu_seqlens_kv.data(),
                 max_seqlen_q, max_seqlen_kv,
@@ -742,7 +744,7 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
   // destroy tensor wrappers
   nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);

-  return {dQ, dKV};
+  return {dQ, dKV, dBias};
 }

 void te_gemm(at::Tensor A,
......
@@ -48,8 +48,7 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
                 const c10::optional<at::Tensor> scale_dP,
                 const c10::optional<at::Tensor> scale_dQKV,
                 c10::optional<at::Tensor> amax_dP,
-                c10::optional<at::Tensor> amax_dQKV,
-                const c10::optional<at::Tensor> dBias);
+                c10::optional<at::Tensor> amax_dQKV);

 std::vector<at::Tensor> fused_attn_fwd_kvpacked(
                 size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
@@ -92,8 +91,7 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
                 const c10::optional<at::Tensor> scale_dP,
                 const c10::optional<at::Tensor> scale_dQKV,
                 c10::optional<at::Tensor> amax_dP,
-                c10::optional<at::Tensor> amax_dQKV,
-                const c10::optional<at::Tensor> dBias);
+                c10::optional<at::Tensor> amax_dQKV);

 void te_gemm(at::Tensor A,
                 at::Tensor A_scale_inverse,
......