Unverified commit 186cfaf3 authored by cyanguwa, committed by GitHub

Move dbias from fused attention bwd's input list to its output list (#185)



* move dbias from input list to output list for bwd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* split asserts into three for bias checks
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Update transformer_engine/pytorch/cpp_extensions.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>

* fix asserts for bias checks
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* another fix for asserts for bias checks
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 201279fa
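In summary, this change moves dBias out of the backward kernels' input lists and into their output lists: the backward functions now allocate the bias gradient themselves and return it alongside dQKV (or dQ/dKV) whenever bias_type is "pre_scale_bias" or "post_scale_bias". Below is a minimal Python-level sketch of the shapes and dtypes implied by the updated docstrings; the sizes and tensor names are illustrative assumptions, not taken from this commit.

import torch

# Illustrative sizes (assumed for this sketch); real values come from the model.
h, max_seqlen = 16, 512

# Forward-pass bias per the updated docstrings: [1, num_heads, max_seqlen, max_seqlen],
# same dtype as the packed QKV tensor.
bias = torch.zeros(1, h, max_seqlen, max_seqlen, dtype=torch.bfloat16)

# The backward pass now returns d_bias with the same shape and dtype as Bias;
# callers no longer allocate or pass it in. A stand-in for that returned gradient:
d_bias = torch.empty_like(bias)
assert d_bias.shape == bias.shape and d_bias.dtype == bias.dtype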
@@ -68,13 +68,13 @@ void nvte_fused_attn_fwd_qkvpacked(
// NVTE fused attention BWD FP8 with packed QKV
void nvte_fused_attn_bwd_qkvpacked(
const NVTETensor QKV,
const NVTETensor dBias,
const NVTETensor O,
const NVTETensor dO,
const NVTETensor S,
NVTETensor dP,
const NVTETensorPack* Aux_CTX_Tensors,
NVTETensor dQKV,
NVTETensor dBias,
const NVTETensor cu_seqlens,
size_t max_seqlen,
float attn_scale, float dropout,
@@ -86,12 +86,12 @@ void nvte_fused_attn_bwd_qkvpacked(
using namespace transformer_engine;
const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor*>(cu_seqlens);
const Tensor *input_QKV = reinterpret_cast<const Tensor*>(QKV);
const Tensor *input_dBias = reinterpret_cast<const Tensor*>(dBias);
const Tensor *input_O = reinterpret_cast<const Tensor*>(O);
const Tensor *input_dO = reinterpret_cast<const Tensor*>(dO);
const Tensor *input_S = reinterpret_cast<const Tensor*>(S);
Tensor *input_output_dP = reinterpret_cast<Tensor*>(dP);
Tensor *output_dQKV = reinterpret_cast<Tensor*>(dQKV);
Tensor *output_dBias = reinterpret_cast<Tensor*>(dBias);
Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
// QKV shape is [total_seqs, 3, h, d]
@@ -182,7 +182,6 @@ void nvte_fused_attn_fwd_kvpacked(
void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q,
const NVTETensor KV,
const NVTETensor dBias,
const NVTETensor O,
const NVTETensor dO,
const NVTETensor S,
@@ -190,6 +189,7 @@ void nvte_fused_attn_bwd_kvpacked(
const NVTETensorPack* Aux_CTX_Tensors,
NVTETensor dQ,
NVTETensor dKV,
NVTETensor dBias,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
size_t max_seqlen_q, size_t max_seqlen_kv,
@@ -204,13 +204,13 @@ void nvte_fused_attn_bwd_kvpacked(
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv);
const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q);
const Tensor *input_KV = reinterpret_cast<const Tensor*>(KV);
const Tensor *input_dBias = reinterpret_cast<const Tensor*>(dBias);
const Tensor *input_O = reinterpret_cast<const Tensor*>(O);
const Tensor *input_dO = reinterpret_cast<const Tensor*>(dO);
const Tensor *input_S = reinterpret_cast<const Tensor*>(S);
Tensor *input_output_dP = reinterpret_cast<Tensor*>(dP);
Tensor *output_dQ = reinterpret_cast<Tensor*>(dQ);
Tensor *output_dKV = reinterpret_cast<Tensor*>(dKV);
Tensor *output_dBias = reinterpret_cast<Tensor*>(dBias);
Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
// Q shape is [total_seqs, h, d]
@@ -125,13 +125,13 @@ void nvte_fused_attn_fwd_qkvpacked(
*
* \param[in] QKV The QKV tensor in packed format,
* [total_seqs, 3, num_heads, head_dim].
* \param[in] dBias The gradient of the Bias tensor.
* \param[in] O The O tensor from forward.
* \param[in] dO The gradient of the O tensor.
* \param[in] S The S tensor.
* \param[in,out] dP The gradient of the P tensor.
* \param[in] Aux_CTX_Tensors Auxiliary tensors from forward when in training mode.
* \param[out] dQKV The gradient of the QKV tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1].
* \param[in] max_seqlen Max sequence length used for computing,
* it may be >= max(cu_seqlens).
@@ -145,13 +145,13 @@ void nvte_fused_attn_fwd_qkvpacked(
*/
void nvte_fused_attn_bwd_qkvpacked(
const NVTETensor QKV,
const NVTETensor dBias,
const NVTETensor O,
const NVTETensor dO,
const NVTETensor S,
NVTETensor dP,
const NVTETensorPack* Aux_CTX_Tensors,
NVTETensor dQKV,
NVTETensor dBias,
const NVTETensor cu_seqlens,
size_t max_seqlen,
float attn_scale, float dropout,
@@ -211,7 +211,6 @@ void nvte_fused_attn_fwd_kvpacked(
*
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
* \param[in] KV The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].
* \param[in] dBias The gradient of the Bias tensor.
* \param[in] O The O tensor from forward.
* \param[in] dO The gradient of the O tensor.
* \param[in] S The S tensor.
@@ -219,6 +218,7 @@ void nvte_fused_attn_fwd_kvpacked(
* \param[in] Aux_CTX_Tensors Auxiliary tensors from forward when in training mode.
* \param[out] dQ The gradient of the Q tensor.
* \param[out] dKV The gradient of the KV tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1].
* \param[in] max_seqlen_q Max sequence length used for computing for Q.
@@ -236,7 +236,6 @@ void nvte_fused_attn_fwd_kvpacked(
void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q,
const NVTETensor KV,
const NVTETensor dBias,
const NVTETensor O,
const NVTETensor dO,
const NVTETensor S,
@@ -244,6 +243,7 @@ void nvte_fused_attn_bwd_kvpacked(
const NVTETensorPack* Aux_CTX_Tensors,
NVTETensor dQ,
NVTETensor dKV,
NVTETensor dBias,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
size_t max_seqlen_q, size_t max_seqlen_kv,
@@ -125,8 +125,8 @@ def fused_attn_fwd_qkvpacked(
qkv_dtype: tex.DType
data type of QKV; in tex.DType, not torch.dtype
bias: torch.Tensor, default = None
input tensor Bias;
shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
input tensor Bias when bias_type is "pre_scale_bias" or "post_scale_bias";
shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
q_scale_s: torch.Tensor, default = None
@@ -188,6 +188,13 @@ def fused_attn_fwd_qkvpacked(
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
if bias_type != "no_bias":
assert bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
assert (bias.shape == [1, h, max_seqlen, max_seqlen]
), "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
assert (bias.dtype == qkv.dtype
), "bias tensor must be in the same dtype as qkv."
# FP8 fused attention API
if (qkv_type is torch.uint8) and (max_seqlen <= 512) and (d == 64):
assert (qkv_layout == "qkv_interleaved"
@@ -246,7 +253,6 @@ def fused_attn_bwd_qkvpacked(
d_o: torch.Tensor,
qkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor] = None,
d_bias: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None,
@@ -285,9 +291,6 @@ def fused_attn_bwd_qkvpacked(
aux_ctx_tensors: List[torch.Tensor]
auxiliary output tensors of the forward pass when its is_training is True,
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
d_bias: torch.Tensor, default = None
input tensor Bias;
shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None
@@ -326,6 +329,9 @@ def fused_attn_bwd_qkvpacked(
----------
d_qkv: torch.Tensor
gradient tensor of QKV; same data type and shape as QKV
d_bias: torch.Tensor, optional
gradient tensor of Bias when bias_type is "pre_scale_bias" or "post_scale_bias";
same data type and shape as Bias
"""
check_cu_seqlens(cu_seqlens)
@@ -402,10 +408,13 @@ def fused_attn_bwd_qkvpacked(
d_scale_qkv, d_scale_s, d_scale_o, d_scale_do,
q_scale_s, q_scale_dp, q_scale_dqkv,
amax_dp, amax_dqkv,
d_bias,
)
if bias_type == "no_bias":
# return d_qkv when bias_type is no_bias
return output_tensors[0]
# otherwise return (d_qkv, d_bias)
return output_tensors
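# Illustrative helper (a sketch, not from this commit): normalizes the return
# convention shown above into a (d_qkv, d_bias) pair; the helper name is assumed.
def _unpack_bwd_qkvpacked_outputs(outputs, bias_type):
    if bias_type == "no_bias":
        # a single d_qkv tensor is returned when there is no bias
        return outputs, None
    # otherwise the pair (d_qkv, d_bias) comes back
    d_qkv, d_bias = outputs
    return d_qkv, d_bias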
def fused_attn_fwd_kvpacked(
@@ -454,10 +463,10 @@ def fused_attn_fwd_kvpacked(
shape [total_seqs_kv, 2, num_heads, head_dim],
where total_seqs_kv = cu_seqlens_kv[-1]
qkv_dtype: tex.DType
data type of QKV; in tex.DType, not torch.dtype
data type of Q and KV; in tex.DType, not torch.dtype
bias: torch.Tensor, default = None
input tensor Bias;
shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
input tensor Bias when bias_type is "pre_scale_bias" or "post_scale_bias";
shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
q_scale_s: torch.Tensor, default = None
@@ -527,6 +536,13 @@ def fused_attn_fwd_kvpacked(
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
if bias_type != "no_bias":
assert bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
assert (bias.shape == [1, h, max_seqlen_q, max_seqlen_kv]
), "bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
assert (bias.dtype == q.dtype
), "bias tensor must be in the same dtype as q and kv."
# FP8 fused attention API
if (qkv_type is torch.uint8) and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512) \
and (d == 64):
@@ -577,7 +593,6 @@ def fused_attn_bwd_kvpacked(
d_o: torch.Tensor,
qkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor] = None,
d_bias: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None,
@@ -624,9 +639,6 @@ def fused_attn_bwd_kvpacked(
aux_ctx_tensors: List[torch.Tensor]
auxiliary output tensors of the forward pass when its is_training is True,
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
bias: torch.Tensor, default = None
input tensor Bias;
shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None
@@ -668,6 +680,9 @@ def fused_attn_bwd_kvpacked(
gradient tensor of Q; same data type and shape as Q
d_kv: torch.Tensor
gradient tensor of KV; same data type and shape as KV
d_bias: torch.Tensor, optional
gradient tensor of Bias when bias_type is "pre_scale_bias" or "post_scale_bias";
same data type and shape as Bias
"""
check_cu_seqlens(cu_seqlens_q)
@@ -728,9 +743,11 @@ def fused_attn_bwd_kvpacked(
d_scale_qkv, d_scale_s, d_scale_o, d_scale_do,
q_scale_s, q_scale_dp, q_scale_dqkv,
amax_dp, amax_dqkv,
d_bias,
)
# returns (d_q, d_kv) when bias_type is no_bias; otherwise returns (d_q, d_kv, d_bias)
if bias_type == "no_bias":
return output_tensors[:2]
return output_tensors
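# Illustrative helper (a sketch, not from this commit): mirrors the kvpacked
# return convention above; the helper name is assumed.
def _unpack_bwd_kvpacked_outputs(outputs, bias_type):
    if bias_type == "no_bias":
        # only (d_q, d_kv) are returned when there is no bias
        d_q, d_kv = outputs
        return d_q, d_kv, None
    d_q, d_kv, d_bias = outputs
    return d_q, d_kv, d_bias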
def fp8_gemm(
@@ -166,7 +166,7 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
if (Bias.has_value()) {
if ((bias_type != "no_bias") && (Bias.has_value())) {
auto bias_shape = Bias.value().sizes().vec();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape,
@@ -276,8 +276,7 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dP,
c10::optional<at::Tensor> amax_dQKV,
const c10::optional<at::Tensor> dBias) {
c10::optional<at::Tensor> amax_dQKV) {
using namespace transformer_engine;
// create output tensor dQKV
@@ -285,9 +284,18 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
if (set_zero) {
mha_fill(dQKV, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)}));
}
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
at::Tensor dBias;
TensorWrapper te_dBias;
if (bias_type != "no_bias") {
dBias = torch::zeros({1, static_cast<int64_t>(h),
static_cast<int64_t>(max_seqlen),
static_cast<int64_t>(max_seqlen)}, options);
te_dBias = makeTransformerEngineTensor(dBias);
}
// construct NVTE tensors
TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV, te_dBias;
TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
if ((!descale_QKV.has_value()) || (!descale_S.has_value())
@@ -332,13 +340,6 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
if (dBias.has_value()) {
auto bias_shape = dBias.value().sizes().vec();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_dBias = makeTransformerEngineTensor(
dBias.value().data_ptr(), shape, DType::kFloat32,
nullptr, nullptr, nullptr);
}
// convert strings to enums
NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
@@ -369,13 +370,13 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd_qkvpacked(
te_QKV.data(),
te_dBias.data(),
te_O.data(),
te_dO.data(),
te_S.data(),
te_dP.data(),
&nvte_aux_tensor_pack,
te_dQKV.data(),
te_dBias.data(),
te_cu_seqlens.data(),
max_seqlen,
attn_scale, p_dropout,
@@ -392,13 +393,13 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
// execute kernel
nvte_fused_attn_bwd_qkvpacked(
te_QKV.data(),
te_dBias.data(),
te_O.data(),
te_dO.data(),
te_S.data(),
te_dP.data(),
&nvte_aux_tensor_pack,
te_dQKV.data(),
te_dBias.data(),
te_cu_seqlens.data(),
max_seqlen,
attn_scale, p_dropout,
@@ -409,7 +410,7 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
return {dQKV};
return {dQKV, dBias};
}
// fused attention FWD with packed KV
@@ -473,7 +474,7 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
if (Bias.has_value()) {
if ((bias_type != "no_bias") && (Bias.has_value())) {
auto bias_shape = Bias.value().sizes().vec();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape,
@@ -593,8 +594,7 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dP,
c10::optional<at::Tensor> amax_dQKV,
const c10::optional<at::Tensor> dBias) {
c10::optional<at::Tensor> amax_dQKV) {
using namespace transformer_engine;
// create output tensors dQ and dKV
@@ -604,9 +604,18 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
mha_fill(dKV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
}
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
at::Tensor dBias;
TensorWrapper te_dBias;
if (bias_type != "no_bias") {
dBias = torch::zeros({1, static_cast<int64_t>(h),
static_cast<int64_t>(max_seqlen_q),
static_cast<int64_t>(max_seqlen_kv)}, options);
te_dBias = makeTransformerEngineTensor(dBias);
}
// construct NVTE tensors
TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV, te_dBias;
TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
if ((!descale_QKV.has_value()) || (!descale_S.has_value())
@@ -657,13 +666,6 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
if (dBias.has_value()) {
auto bias_shape = dBias.value().sizes().vec();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_dBias = makeTransformerEngineTensor(
dBias.value().data_ptr(), shape, DType::kFloat32,
nullptr, nullptr, nullptr);
}
// create cu_seqlens tensorwrappers
TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
@@ -697,7 +699,6 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
nvte_fused_attn_bwd_kvpacked(
te_Q.data(),
te_KV.data(),
te_dBias.data(),
te_O.data(),
te_dO.data(),
te_S.data(),
@@ -705,6 +706,7 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
&nvte_aux_tensor_pack,
te_dQ.data(),
te_dKV.data(),
te_dBias.data(),
te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(),
max_seqlen_q, max_seqlen_kv,
@@ -723,7 +725,6 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
nvte_fused_attn_bwd_kvpacked(
te_Q.data(),
te_KV.data(),
te_dBias.data(),
te_O.data(),
te_dO.data(),
te_S.data(),
@@ -731,6 +732,7 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
&nvte_aux_tensor_pack,
te_dQ.data(),
te_dKV.data(),
te_dBias.data(),
te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(),
max_seqlen_q, max_seqlen_kv,
@@ -742,7 +744,7 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
return {dQ, dKV};
return {dQ, dKV, dBias};
}
void te_gemm(at::Tensor A,
@@ -48,8 +48,7 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dP,
c10::optional<at::Tensor> amax_dQKV,
const c10::optional<at::Tensor> dBias);
c10::optional<at::Tensor> amax_dQKV);
std::vector<at::Tensor> fused_attn_fwd_kvpacked(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
@@ -92,8 +91,7 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dP,
c10::optional<at::Tensor> amax_dQKV,
const c10::optional<at::Tensor> dBias);
c10::optional<at::Tensor> amax_dQKV);
void te_gemm(at::Tensor A,
at::Tensor A_scale_inverse,