Unverified Commit 83a4c219 authored by cyanguwa, committed by GitHub

[C/PyTorch] Add FP8 DPA and MHA (#768)



* WIP: fp8 v1 fprop integration
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add debug info
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add more debug info
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fprop working for h1; w/ debug info
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: add bprop
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* cleanup; bprop running but has mismatches
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add gitlab frontend as submodule
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up and add back v0.9.2 FE support; fprop/bprop passing with 5e-2 tols
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix after merge; add bias_b/h to caching descriptor
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* distinguish fwd/bwd tensor types for bprop
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix for F16 cases; include added dqkv_type and d_scale_dp
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* adjust out shape for bwd in test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add casting from/to FP8 to DPA module
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: bshd_bshd_bshd layout
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: support all sbhd/bshd layouts
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add qkvpacked and kvpacked support in both FusedAttnFunc and C levels
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove qkvpacked/kvpacked calls in DPA module (used for testing)
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove tp setup; add allow_non_contiguous; update FE; revert to sbh3d in tests; clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add NVTE_FP8_DPA_BWD to control whether to use FP8 bwd or F16 bwd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix MQA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix MQA/GQA in FP8 v1 API
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to 705d8e3, with API change
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test causal mask
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* restrict mha_fill for THD format
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fused attn with CP and comment out is_alibi code
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up FE0.9 vs FE1.0 FP8 implementations, and related unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change NVTE_FP8_DPA_BWD default to 1, and fix its use in qkvpacked/kvpacked APIs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint and self.tp_size/group in FusedAttention()
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to 6902c94
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add FP8 MHA support
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update to FE v1.3.0
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes for FP8 MHA with different configs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* emit stats regardless of is_training
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix linear when input is not Float8Tensor
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix d_out type for F16 bprop
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix user buffer for layernorm_linear/linear and revert two FP8 casts in MHA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add docstring for fp8_dpa/mha in recipe
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix backend selection to avoid FA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace transpose with transpose_2d
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* use RMSE for FP8 unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace two more transpose with transpose_2d
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add FP8 initialization to FusedAttention
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* rm docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Revert "add FP8 initialization to FusedAttention"

This reverts commit 15fffd825d6f23f31ea709b16ba01dfd61efabf8.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change order of ctxs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add back docs and mark as beta
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes for tests and docs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f69e45be
-Subproject commit a86ad708db725e4d29919bb6fadf8e6cdfa5dc06
+Subproject commit 1b0b5eac540b7f8fd19b18f1e6b8427c95503348
@@ -6,7 +6,7 @@ set -e
 : ${TE_PATH:=/opt/transformerengine}

-pip install pytest==6.2.5 onnxruntime==1.13.1
+pip install pytest==7.2 onnxruntime==1.13.1
 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
 pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
...
@@ -1091,7 +1091,7 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere
     torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)

     # Check output.
-    atol = {torch.float32 : 2e-4,
+    atol = {torch.float32 : 2.5e-4,
             torch.half    : 2e-3,
             torch.bfloat16: 2e-2,
            }
...
@@ -85,15 +85,25 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
   NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type.");
   NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
   auto cudnn_runtime_version = cudnnGetVersion();
-  if ((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2)
+  if (((q_dtype == NVTEDType::kNVTEFloat8E4M3)
+       || (q_dtype == NVTEDType::kNVTEFloat8E5M2))
       && (sm_arch_ >= 90)
-      && (max_seqlen_q == max_seqlen_kv)
-      && (max_seqlen_q <= 512)
-      && (head_dim == 64)
       && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
-      && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
-      && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD)) {
+      && (((cudnn_runtime_version >= 8900)
+           && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD)
+           && (max_seqlen_q == max_seqlen_kv)
+           && (num_attn_heads == num_gqa_groups)
+           && (max_seqlen_q <= 512)
+           && (head_dim == 64)
+           && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK))
+          || ((cudnn_runtime_version >= 90100)
+              && (max_seqlen_q % 128 == 0)
+              && (max_seqlen_kv % 128 == 0)
+              && (head_dim == 128)
+              && ((qkv_format == NVTE_QKV_Format::NVTE_BSHD)
+                  || (qkv_format == NVTE_QKV_Format::NVTE_SBHD))
+              && ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
+                  || (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))))) {
     if (cudnn_runtime_version >= 8900) {
       backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
     } else {
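For reference, the new eligibility check distills to the following plain-Python predicate (names are illustrative stand-ins for the C enums, not the actual API): the pre-existing cuDNN 8.9 path (T3HD layout, equal sequence lengths up to 512, head_dim 64, padding mask, no GQA) is joined by a cuDNN >= 9.1 path for BSHD/SBHD formats with head_dim 128 and causal/no-mask.

def fp8_backend_eligible(q_dtype, sm_arch, cudnn_version, qkv_layout, qkv_format,
                         max_seqlen_q, max_seqlen_kv, num_attn_heads, num_gqa_groups,
                         head_dim, bias_type, attn_mask_type):
    # Common requirements: FP8 inputs, Hopper or newer, no bias.
    if q_dtype not in ("fp8_e4m3", "fp8_e5m2") or sm_arch < 90 or bias_type != "no_bias":
        return False
    # Pre-existing cuDNN 8.9 FP8 path (seqlen <= 512 kernels).
    legacy = (cudnn_version >= 8900
              and qkv_layout == "t3hd"
              and max_seqlen_q == max_seqlen_kv
              and num_attn_heads == num_gqa_groups
              and max_seqlen_q <= 512
              and head_dim == 64
              and attn_mask_type == "padding")
    # New cuDNN 9.1 FP8 path enabled by this PR.
    arbitrary = (cudnn_version >= 90100
                 and max_seqlen_q % 128 == 0
                 and max_seqlen_kv % 128 == 0
                 and head_dim == 128
                 and qkv_format in ("bshd", "sbhd")
                 and attn_mask_type in ("causal", "no_mask"))
    return legacy or arbitrary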
@@ -269,7 +279,7 @@ void nvte_fused_attn_fwd_qkvpacked(
 #if (CUDNN_VERSION >= 8900)
     fused_attn_fp8_fwd_qkvpacked(
         b, h, max_seqlen, d,
-        is_training, attn_scale, dropout, qkv_layout,
+        is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
         input_QKV, input_output_S, output_O,
         Aux_CTX_Tensors,
         input_cu_seqlens,
@@ -379,7 +389,7 @@ void nvte_fused_attn_bwd_qkvpacked(
     const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
     fused_attn_fp8_bwd_qkvpacked(
         b, h, max_seqlen, d,
-        attn_scale, dropout, qkv_layout,
+        attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
         input_QKV, input_O, input_dO,
         input_M, input_ZInv,
         input_S, input_output_dP,
@@ -476,7 +486,18 @@ void nvte_fused_attn_fwd_kvpacked(
       "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
 #endif
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
-    NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
+#if (CUDNN_VERSION >= 8900)
+    fused_attn_fp8_fwd_kvpacked(
+        b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
+        is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
+        input_Q, input_KV, input_output_S, output_O,
+        Aux_CTX_Tensors,
+        input_cu_seqlens_q, input_cu_seqlens_kv,
+        input_rng_state,
+        wkspace, stream, handle);
+#else
+    NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
+#endif
   } else {
     NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
   }
@@ -580,7 +601,23 @@ void nvte_fused_attn_bwd_kvpacked(
       NVTE_ERROR(err_msg);
 #endif
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
-    NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
+#if (CUDNN_VERSION >= 8900)
+    const Tensor *input_M = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[0]);
+    const Tensor *input_ZInv = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
+    const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
+    fused_attn_fp8_bwd_kvpacked(
+        b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
+        attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
+        input_Q, input_KV, input_O, input_dO,
+        input_M, input_ZInv,
+        input_S, input_output_dP,
+        output_dQ, output_dKV,
+        input_cu_seqlens_q, input_cu_seqlens_kv,
+        input_rng_state,
+        wkspace, stream, handle);
+#else
+    NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
+#endif
   } else {
     NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
   }
@@ -662,8 +699,8 @@ void nvte_fused_attn_fwd(
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
 #if (CUDNN_VERSION >= 8900)
     fused_attn_fp8_fwd(
-        b, h_q, max_seqlen_q, max_seqlen_kv, d,
-        is_training, attn_scale, dropout, qkv_layout,
+        b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
+        is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
         input_Q, input_K, input_V, input_output_S, output_O,
         Aux_CTX_Tensors,
         input_cu_seqlens_q, input_cu_seqlens_kv,
@@ -775,8 +812,8 @@ void nvte_fused_attn_bwd(
     const Tensor *input_ZInv = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
     const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
     fused_attn_fp8_bwd(
-        b, h_q, max_seqlen_q, max_seqlen_kv, d,
-        attn_scale, dropout, qkv_layout,
+        b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
+        attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
         input_Q, input_K, input_V, input_O, input_dO,
         input_M, input_ZInv,
         input_S, input_output_dP,
...
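The h_q/h_kv split now threaded through these FP8 entry points is the usual grouped-query bookkeeping (see the MQA/GQA fixes in the commit log). A minimal sketch of the head mapping it implies (illustrative, not TE code):

# Each query head reads the KV head at index q_head // (h_q // h_kv);
# MQA is the h_kv == 1 special case.
h_q, h_kv = 16, 4
group_size = h_q // h_kv
kv_head_for_q = [q // group_size for q in range(h_q)]
assert kv_head_for_q == [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]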
@@ -76,7 +76,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
         scaling_factor, is_training,
         dropout_probability, layout,
         bias_type, mask_type,
-        tensorType};
+        tensorType, tensorType};

   namespace fe = cudnn_frontend;
   using graph_and_tensors = std::tuple<std::shared_ptr<fe::graph::Graph>,
@@ -147,7 +147,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
   fe::graph::SDPA_attributes sdpa_options;
   sdpa_options = fe::graph::SDPA_attributes()
       .set_name("flash_attention")
-      .set_is_inference(!is_training)
+      .set_is_inference(false)
       .set_causal_mask(is_causal)
       .set_attn_scale(attn_scale);
@@ -199,11 +199,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
       layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
   O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride);

-  if (is_training) {
-    Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT)
-        .set_dim({b, h, s_q, 1})
-        .set_stride({h * s_q, s_q, 1, 1});
-  }
+  Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT)
+      .set_dim({b, h, s_q, 1})
+      .set_stride({h * s_q, s_q, 1, 1});

   std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>,  // Q
              std::shared_ptr<fe::graph::Tensor_attributes>,  // K
@@ -211,7 +209,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
              std::shared_ptr<fe::graph::Tensor_attributes>,  // attn_scale
              std::shared_ptr<fe::graph::Tensor_attributes> >  // O
       key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O);
-  auto Stats_tuple = is_training ? std::make_tuple(Stats) : std::make_tuple(nullptr);
+  auto Stats_tuple = std::make_tuple(Stats);
   auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr);
   auto padding_tuple = is_padding ?
       std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
@@ -258,11 +256,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
       {K, devPtrK},
       {V, devPtrV},
       {attn_scale, &scaling_factor},
-      {O, devPtrO}};
-
-  if (is_training) {
-    variant_pack[Stats] = devPtrSoftmaxStats;
-  }
+      {O, devPtrO},
+      {Stats, devPtrSoftmaxStats}};

   if (is_bias) {
     variant_pack[bias] = devPtrBias;
@@ -321,7 +316,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
         scaling_factor, true,
         dropout_probability, layout,
         bias_type, mask_type,
-        tensorType};
+        tensorType, tensorType};

   namespace fe = cudnn_frontend;
   using graph_and_tensors = std::tuple<std::shared_ptr<fe::graph::Graph>,
...
@@ -19,7 +19,7 @@ namespace transformer_engine {
 #if (CUDNN_VERSION >= 8900)
 void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
     size_t batch, size_t num_attn_heads, size_t max_seqlen,
-    size_t head_size, bool is_training, float attn_scale,
+    size_t head_dim, bool is_training, float attn_scale,
     float p_dropout, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
     const Tensor *input_QKV, const Tensor *input_Bias,
...
@@ -14,9 +14,10 @@ namespace transformer_engine {
 #if (CUDNN_VERSION >= 8900)
 // fused attention FWD FP8 with packed QKV
 void fused_attn_fp8_fwd_qkvpacked(
-    size_t b, size_t h, size_t max_seqlen, size_t d,
+    size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim,
     bool is_training, float attn_scale,
     float p_dropout, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
     const Tensor *input_QKV,
     Tensor *input_output_S,
     Tensor *output_O,
@@ -29,8 +30,9 @@ void fused_attn_fp8_fwd_qkvpacked(
 // fused attention BWD FP8 with packed QKV
 void fused_attn_fp8_bwd_qkvpacked(
-    size_t b, size_t h, size_t max_seqlen, size_t d,
+    size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim,
     float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
     const Tensor *input_QKV,
     const Tensor *input_O,
     const Tensor *input_dO,
@@ -45,11 +47,55 @@ void fused_attn_fp8_bwd_qkvpacked(
     cudaStream_t stream,
     cudnnHandle_t handle);

+// fused attention FWD FP8 with packed KV
+void fused_attn_fp8_fwd_kvpacked(
+    size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
+    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
+    bool is_training, float attn_scale,
+    float p_dropout, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    const Tensor *input_Q,
+    const Tensor *input_KV,
+    Tensor *input_output_S,
+    Tensor *output_O,
+    NVTETensorPack* Aux_CTX_Tensors,
+    const Tensor *cu_seqlens_q,
+    const Tensor *cu_seqlens_kv,
+    const Tensor *rng_state,
+    Tensor *workspace,
+    cudaStream_t stream,
+    cudnnHandle_t handle);
+
+// fused attention BWD FP8 with packed KV
+void fused_attn_fp8_bwd_kvpacked(
+    size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
+    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
+    float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    const Tensor *input_Q,
+    const Tensor *input_KV,
+    const Tensor *input_O,
+    const Tensor *input_dO,
+    const Tensor *input_M,
+    const Tensor *input_ZInv,
+    const Tensor *input_S,
+    Tensor *input_output_dP,
+    const Tensor *output_dQ,
+    const Tensor *output_dKV,
+    const Tensor *cu_seqlens_q,
+    const Tensor *cu_seqlens_kv,
+    const Tensor *rng_state,
+    Tensor *workspace,
+    cudaStream_t stream,
+    cudnnHandle_t handle);
+
 // fused attention FWD FP8 with separate Q, K, V
 void fused_attn_fp8_fwd(
-    size_t b, size_t h, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
+    size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
+    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
     bool is_training, float attn_scale,
     float p_dropout, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
     const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
     Tensor *input_output_S,
     Tensor *output_O,
@@ -63,8 +109,10 @@ void fused_attn_fp8_fwd(
 // fused attention BWD FP8 with separate Q, K, V
 void fused_attn_fp8_bwd(
-    size_t b, size_t h, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
+    size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
+    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
     float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
     const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
     const Tensor *input_O,
     const Tensor *input_dO,
...
@@ -111,19 +111,20 @@ struct FADescriptor_v1 {
   NVTE_QKV_Layout layout;
   NVTE_Bias_Type bias_type;
   NVTE_Mask_Type mask_type;
-  cudnn_frontend::DataType_t tensor_type;
+  cudnn_frontend::DataType_t fwd_tensor_type;
+  cudnn_frontend::DataType_t bwd_tensor_type;

   bool operator<(const FADescriptor_v1 &rhs) const {
     return std::tie(b, h, hg, s_q, s_kv, d, bias_b, bias_h,
                     attnScale, isTraining, dropoutProbability,
-                    layout, mask_type, bias_type, tensor_type)
+                    layout, mask_type, bias_type, fwd_tensor_type, bwd_tensor_type)
            < std::tie(
                rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d,
                rhs.bias_b, rhs.bias_h,
                rhs.attnScale, rhs.isTraining,
                rhs.dropoutProbability, rhs.layout,
                rhs.mask_type, rhs.bias_type,
-               rhs.tensor_type);
+               rhs.fwd_tensor_type, rhs.bwd_tensor_type);
   }
 };
...
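The reason the cache key needs both dtypes: with the new F16-backward option, the forward graph can be built for FP8 while the backward graph is built for BF16/FP16, so two descriptors differing only in bwd_tensor_type must map to different cuDNN plans. A toy Python analog of the keyed cache (illustrative only):

from typing import NamedTuple

class FAKey(NamedTuple):
    b: int
    h: int
    s_q: int
    s_kv: int
    d: int
    fwd_tensor_type: str
    bwd_tensor_type: str

plan_cache = {}
k_fp8 = FAKey(2, 16, 512, 512, 64, "fp8_e4m3", "fp8_e4m3")
k_mix = FAKey(2, 16, 512, 512, 64, "fp8_e4m3", "bf16")
plan_cache[k_fp8] = "graph A"
plan_cache[k_mix] = "graph B"   # would have collided under the old single-dtype key
assert len(plan_cache) == 2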
@@ -96,7 +96,7 @@ class DelayedScaling:
                          where `Tensor` is a framework tensor type.
     override_linear_precision: Tuple(bool, bool, bool), default=(False, False, False)
-                              Whether or not the execute the `fprop`, `dgrad`, and `wgrad`
+                              Whether or not to execute the `fprop`, `dgrad`, and `wgrad`
                               GEMMs (respectively) in higher precision when using FP8.
     reduce_amax: bool, default = `True`
                 By default, if `torch.distributed` is initialized, the `amax` value for FP8
@@ -106,6 +106,20 @@ class DelayedScaling:
                 GPU maintains local amaxes and scaling factors. To ensure results are
                 numerically identical across checkpointing boundaries in this case, all
                 ranks must checkpoint in order to store the local tensors.
+    fp8_dpa: bool, default = `False`
+            Whether to enable FP8 dot product attention (DPA). When the model is placed in an
+            `fp8_autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the
+            inputs from higher precision to FP8, performs attention in FP8, and casts tensors
+            back to higher precision as outputs. FP8 DPA is currently only supported in the
+            `FusedAttention` backend.
+    fp8_mha: bool, default = `False`
+            Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting
+            operations mentioned above at the DPA boundaries. Currently only standard MHA modules,
+            i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When
+            `fp8_mha = False, fp8_dpa = True`, a typical MHA module works as
+            `LayerNormLinear (BF16 output) -> (cast to FP8) FP8 DPA (cast to BF16) -> Linear`.
+            When `fp8_mha = True, fp8_dpa = True`, it becomes
+            `LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`.

     Notes
     -----
@@ -116,6 +130,9 @@ class DelayedScaling:
               FP8_MAX = maximum_representable_value(fp8_format)
               new_scaling_factor = (FP8_MAX / amax) / (2 ^ margin)

+    * `fp8_dpa` and `fp8_mha` are Beta features, and their API and functionality are
+      subject to change in future Transformer Engine releases.
     """

     margin: int = 0
@@ -126,6 +143,8 @@ class DelayedScaling:
     override_linear_precision: _OverrideLinearPrecision = _OverrideLinearPrecision()
     scaling_factor_compute_algo: Optional[Callable] = None
     reduce_amax: bool = True
+    fp8_dpa: bool = False
+    fp8_mha: bool = False

     def __post_init__(self) -> None:
         assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
...
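A minimal usage sketch for the two new recipe knobs (beta), assuming a standard TE module; shapes and hyperparameters here are illustrative:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# fp8_dpa=True casts to/from FP8 around attention only; fp8_mha=True would
# also keep the surrounding LayerNormLinear/Linear boundaries in FP8.
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, fp8_dpa=True, fp8_mha=False)

mha = te.MultiheadAttention(
    hidden_size=1024, num_attention_heads=16, params_dtype=torch.bfloat16,
).cuda()
x = torch.randn(512, 2, 1024, device="cuda", dtype=torch.bfloat16)  # [s, b, h]

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = mha(x)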
@@ -84,6 +84,7 @@ def fused_attn_fwd_qkvpacked(
     fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
     attn_bias: torch.Tensor = None,
     d_scale_qkv: torch.Tensor = None,
+    d_scale_s: torch.Tensor = None,
     q_scale_s: torch.Tensor = None,
     q_scale_o: torch.Tensor = None,
     amax_s: torch.Tensor = None,
@@ -119,6 +120,8 @@ def fused_attn_fwd_qkvpacked(
                shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv
     d_scale_qkv: torch.Tensor, default = None
                 input tensor for the dequantization of QKV in FP8 computations
+    d_scale_s: torch.Tensor, default = None
+                input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T)
     q_scale_s: torch.Tensor, default = None
                 input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T)
     q_scale_o: torch.Tensor, default = None
@@ -206,6 +209,8 @@ def fused_attn_fwd_qkvpacked(
         assert (d_scale_qkv is not None
                 ), "d_scale_qkv is required as an input for FP8 fused attention."
+        assert (d_scale_s is not None
+                ), "d_scale_s is required as an input for FP8 fused attention."
         assert (q_scale_s is not None
                 ), "q_scale_s is required as an input for FP8 fused attention."
         assert (q_scale_o is not None
@@ -220,7 +225,7 @@ def fused_attn_fwd_qkvpacked(
         max_seqlen, is_training, attn_scale, dropout, fast_zero_fill,
         QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
         cu_seqlens, qkv, qkv_dtype,
-        d_scale_qkv, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias,
+        d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias,
         rng_gen, rng_elts_per_thread,
     )
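For orientation, the FP8 forward path now expects the full set of per-tensor scaling metadata asserted above. A hedged sketch of constructing the required single-element float32 tensors (the helper and values are illustrative; real values come from the module's fp8_meta state):

import torch

def make_fp8_fwd_attn_meta(device="cuda"):
    one = lambda: torch.ones(1, dtype=torch.float32, device=device)
    zero = lambda: torch.zeros(1, dtype=torch.float32, device=device)
    return {
        "d_scale_qkv": one(),  # dequantize QKV
        "d_scale_s": one(),    # dequantize S = Softmax(Q @ K.T)  (new in this PR)
        "q_scale_s": one(),    # quantize S
        "q_scale_o": one(),    # quantize O
        "amax_s": zero(),      # amax slot for S
        "amax_o": zero(),      # amax slot for O
    }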
@@ -235,12 +240,14 @@ def fused_attn_bwd_qkvpacked(
     o: torch.Tensor,
     d_o: torch.Tensor,
     qkv_dtype: tex.DType,
+    dqkv_dtype: tex.DType,
     aux_ctx_tensors: List[torch.Tensor],
     fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
     d_scale_qkv: torch.Tensor = None,
     d_scale_s: torch.Tensor = None,
     d_scale_o: torch.Tensor = None,
     d_scale_do: torch.Tensor = None,
+    d_scale_dp: torch.Tensor = None,
     q_scale_s: torch.Tensor = None,
     q_scale_dp: torch.Tensor = None,
     q_scale_dqkv: torch.Tensor = None,
@@ -272,6 +279,8 @@ def fused_attn_bwd_qkvpacked(
                same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details)
     qkv_dtype: tex.DType
                data type of QKV; in tex.DType, not torch.dtype
+    dqkv_dtype: tex.DType
+               data type of dQKV; in tex.DType, not torch.dtype
     aux_ctx_tensors: List[torch.Tensor]
                auxiliary output tensors of the forward pass when its is_training is True,
                e.g. aux_ctx_tensors = [M, ZInv, rng_state]
@@ -285,6 +294,8 @@ def fused_attn_bwd_qkvpacked(
                input tensor for the dequantization of O in FP8 computations
     d_scale_do: torch.Tensor, default = None
                input tensor for the dequantization of dO in FP8 computations
+    d_scale_dp: torch.Tensor, default = None
+               input tensor for the dequantization of dP in FP8 computations
     q_scale_s: torch.Tensor, default = None
                input tensor for the quantization of S in FP8 computations
     q_scale_dp: torch.Tensor, default = None
@@ -336,6 +347,7 @@ def fused_attn_bwd_qkvpacked(
         assert (d_scale_s is not None), "d_scale_s is required for FP8 fused attention."
         assert (d_scale_o is not None), "d_scale_o is required for FP8 fused attention."
         assert (d_scale_do is not None), "d_scale_do is required for FP8 fused attention."
+        assert (d_scale_dp is not None), "d_scale_dp is required for FP8 fused attention."
         assert (q_scale_s is not None), "q_scale_s is required for FP8 fused attention."
         assert (q_scale_dp is not None), "q_scale_dp is required for FP8 fused attention."
         assert (q_scale_dqkv is not None), "q_scale_dqkv is required for FP8 fused attention."
@@ -348,8 +360,8 @@ def fused_attn_bwd_qkvpacked(
     output_tensors = tex.fused_attn_bwd_qkvpacked(
         max_seqlen, attn_scale, dropout, fast_zero_fill,
         QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
-        cu_seqlens, qkv, o, d_o, qkv_dtype, aux_ctx_tensors,
-        d_scale_qkv, d_scale_s, d_scale_o, d_scale_do,
+        cu_seqlens, qkv, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors,
+        d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp,
         q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
     )
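The new dqkv_dtype argument is what lets the backward run in a different precision than the forward; per the commit log, NVTE_FP8_DPA_BWD (default 1) selects FP8 versus F16 backward. A hedged sketch of that dispatch (helper name is illustrative, not TE code):

import os
import torch

def choose_dqkv_dtype(fp8_dtype, f16_dtype=torch.bfloat16):
    # NVTE_FP8_DPA_BWD=1 (the default per this PR): FP8 backward, dQKV in FP8.
    # NVTE_FP8_DPA_BWD=0: F16 backward, dQKV produced in higher precision
    # even though the forward ran in FP8.
    use_fp8_bwd = bool(int(os.getenv("NVTE_FP8_DPA_BWD", "1")))
    return fp8_dtype if use_fp8_bwd else f16_dtype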
@@ -368,6 +380,7 @@ def fused_attn_fwd_kvpacked(
     fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
     attn_bias: torch.Tensor = None,
     d_scale_qkv: torch.Tensor = None,
+    d_scale_s: torch.Tensor = None,
     q_scale_s: torch.Tensor = None,
     q_scale_o: torch.Tensor = None,
     amax_s: torch.Tensor = None,
@@ -410,6 +423,8 @@ def fused_attn_fwd_kvpacked(
                shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv
     d_scale_qkv: torch.Tensor, default = None
                 input tensor for the dequantization of QKV in FP8 computations
+    d_scale_s: torch.Tensor, default = None
+                input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T)
     q_scale_s: torch.Tensor, default = None
                 input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T)
     q_scale_o: torch.Tensor, default = None
@@ -496,12 +511,25 @@ def fused_attn_fwd_kvpacked(
         rng_elts_per_thread = (max_seqlen_q * max_seqlen_q
                + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA

+        assert (d_scale_qkv is not None
+                ), "d_scale_qkv is required as an input for FP8 fused attention."
+        assert (d_scale_s is not None
+                ), "d_scale_s is required as an input for FP8 fused attention."
+        assert (q_scale_s is not None
+                ), "q_scale_s is required as an input for FP8 fused attention."
+        assert (q_scale_o is not None
+                ), "q_scale_o is required as an input for FP8 fused attention."
+        assert (amax_s is not None
+                ), "amax_s is required as an input for FP8 fused attention."
+        assert (amax_o is not None
+                ), "amax_o is required as an input for FP8 fused attention."
+
     # execute kernel
     output_tensors = tex.fused_attn_fwd_kvpacked(
         max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill,
         QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
         cu_seqlens_q, cu_seqlens_kv, q, kv, qkv_dtype,
-        d_scale_qkv, q_scale_s, q_scale_o, amax_s, amax_o,
+        d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o,
         attn_bias, rng_gen, rng_elts_per_thread,
     )
@@ -519,12 +547,14 @@ def fused_attn_bwd_kvpacked(
     o: torch.Tensor,
     d_o: torch.Tensor,
     qkv_dtype: tex.DType,
+    dqkv_dtype: tex.DType,
     aux_ctx_tensors: List[torch.Tensor],
     fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
     d_scale_qkv: torch.Tensor = None,
     d_scale_s: torch.Tensor = None,
     d_scale_o: torch.Tensor = None,
     d_scale_do: torch.Tensor = None,
+    d_scale_dp: torch.Tensor = None,
     q_scale_s: torch.Tensor = None,
     q_scale_dp: torch.Tensor = None,
     q_scale_dqkv: torch.Tensor = None,
@@ -562,7 +592,9 @@ def fused_attn_bwd_kvpacked(
                input tensor dO (gradient of O);
                same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details)
     qkv_dtype: tex.DType
-               data type of QKV; in tex.DType, not torch.dtype
+               data type of Q and KV; in tex.DType, not torch.dtype
+    dqkv_dtype: tex.DType
+               data type of dQ and dKV; in tex.DType, not torch.dtype
     aux_ctx_tensors: List[torch.Tensor]
                auxiliary output tensors of the forward pass when its is_training is True,
                e.g. aux_ctx_tensors = [M, ZInv, rng_state]
@@ -576,6 +608,8 @@ def fused_attn_bwd_kvpacked(
                input tensor for the dequantization of O in FP8 computations
     d_scale_do: torch.Tensor, default = None
                input tensor for the dequantization of dO in FP8 computations
+    d_scale_dp: torch.Tensor, default = None
+               input tensor for the dequantization of dP in FP8 computations
     q_scale_s: torch.Tensor, default = None
                input tensor for the quantization of S in FP8 computations
     q_scale_dp: torch.Tensor, default = None
@@ -631,6 +665,7 @@ def fused_attn_bwd_kvpacked(
         assert (d_scale_s is not None), "d_scale_s is required for FP8 fused attention."
         assert (d_scale_o is not None), "d_scale_o is required for FP8 fused attention."
         assert (d_scale_do is not None), "d_scale_do is required for FP8 fused attention."
+        assert (d_scale_dp is not None), "d_scale_dp is required for FP8 fused attention."
         assert (q_scale_s is not None), "q_scale_s is required for FP8 fused attention."
         assert (q_scale_dp is not None), "q_scale_dp is required for FP8 fused attention."
         assert (q_scale_dqkv is not None), "q_scale_dqkv is required for FP8 fused attention."
@@ -643,8 +678,8 @@ def fused_attn_bwd_kvpacked(
     output_tensors = tex.fused_attn_bwd_kvpacked(
         max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill,
         QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
-        cu_seqlens_q, cu_seqlens_kv, q, kv, o, d_o, qkv_dtype, aux_ctx_tensors,
-        d_scale_qkv, d_scale_s, d_scale_o, d_scale_do,
+        cu_seqlens_q, cu_seqlens_kv, q, kv, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors,
+        d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp,
         q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
     )
@@ -664,6 +699,7 @@ def fused_attn_fwd(
     fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
     attn_bias: torch.Tensor = None,
     d_scale_qkv: torch.Tensor = None,
+    d_scale_s: torch.Tensor = None,
     q_scale_s: torch.Tensor = None,
     q_scale_o: torch.Tensor = None,
     amax_s: torch.Tensor = None,
@@ -710,6 +746,8 @@ def fused_attn_fwd(
                shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q, k and v
     d_scale_qkv: torch.Tensor, default = None
                 input tensor for the dequantization of Q, K and V in FP8 computations
+    d_scale_s: torch.Tensor, default = None
+                input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T)
     q_scale_s: torch.Tensor, default = None
                 input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T)
     q_scale_o: torch.Tensor, default = None
@@ -798,12 +836,25 @@ def fused_attn_fwd(
         rng_elts_per_thread = (max_seqlen_q * max_seqlen_q
                + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA

+        assert (d_scale_qkv is not None
+                ), "d_scale_qkv is required as an input for FP8 fused attention."
+        assert (d_scale_s is not None
+                ), "d_scale_s is required as an input for FP8 fused attention."
+        assert (q_scale_s is not None
+                ), "q_scale_s is required as an input for FP8 fused attention."
+        assert (q_scale_o is not None
+                ), "q_scale_o is required as an input for FP8 fused attention."
+        assert (amax_s is not None
+                ), "amax_s is required as an input for FP8 fused attention."
+        assert (amax_o is not None
+                ), "amax_o is required as an input for FP8 fused attention."
+
     # execute kernel
     output_tensors = tex.fused_attn_fwd(
         max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill,
         QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
         cu_seqlens_q, cu_seqlens_kv, q, k, v, qkv_dtype,
-        d_scale_qkv, q_scale_s, q_scale_o, amax_s, amax_o,
+        d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o,
         attn_bias, rng_gen, rng_elts_per_thread,
     )
@@ -822,12 +873,14 @@ def fused_attn_bwd(
     o: torch.Tensor,
     d_o: torch.Tensor,
     qkv_dtype: tex.DType,
+    dqkv_dtype: tex.DType,
     aux_ctx_tensors: List[torch.Tensor],
     fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
     d_scale_qkv: torch.Tensor = None,
     d_scale_s: torch.Tensor = None,
     d_scale_o: torch.Tensor = None,
     d_scale_do: torch.Tensor = None,
+    d_scale_dp: torch.Tensor = None,
     q_scale_s: torch.Tensor = None,
     q_scale_dp: torch.Tensor = None,
     q_scale_dqkv: torch.Tensor = None,
@@ -869,6 +922,8 @@ def fused_attn_bwd(
                same shape as Q
     qkv_dtype: tex.DType
                data type of Q, K and V; in tex.DType, not torch.dtype
+    dqkv_dtype: tex.DType
+               data type of dQ, dK and dV; in tex.DType, not torch.dtype
     aux_ctx_tensors: List[torch.Tensor]
                auxiliary output tensors of the forward pass when its is_training is True,
                e.g. aux_ctx_tensors = [M, ZInv, rng_state]
@@ -882,6 +937,8 @@ def fused_attn_bwd(
                input tensor for the dequantization of O in FP8 computations
     d_scale_do: torch.Tensor, default = None
                input tensor for the dequantization of dO in FP8 computations
+    d_scale_dp: torch.Tensor, default = None
+               input tensor for the dequantization of dP in FP8 computations
     q_scale_s: torch.Tensor, default = None
                input tensor for the quantization of S in FP8 computations
     q_scale_dp: torch.Tensor, default = None
@@ -941,6 +998,7 @@ def fused_attn_bwd(
         assert (d_scale_s is not None), "d_scale_s is required for FP8 fused attention."
         assert (d_scale_o is not None), "d_scale_o is required for FP8 fused attention."
         assert (d_scale_do is not None), "d_scale_do is required for FP8 fused attention."
+        assert (d_scale_dp is not None), "d_scale_dp is required for FP8 fused attention."
         assert (q_scale_s is not None), "q_scale_s is required for FP8 fused attention."
         assert (q_scale_dp is not None), "q_scale_dp is required for FP8 fused attention."
         assert (q_scale_dqkv is not None), "q_scale_dqkv is required for FP8 fused attention."
@@ -953,8 +1011,8 @@ def fused_attn_bwd(
     output_tensors = tex.fused_attn_bwd(
         max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill,
         QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
-        cu_seqlens_q, cu_seqlens_kv, q, k, v, o, d_o, qkv_dtype, aux_ctx_tensors,
-        d_scale_qkv, d_scale_s, d_scale_o, d_scale_do,
+        cu_seqlens_q, cu_seqlens_kv, q, k, v, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors,
+        d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp,
         q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
     )
...
@@ -786,9 +786,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
     // Get communication and GEMM output chunk sizes
     const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
     const bool do_gelu = pre_gelu_out.numel() > 0;
-    const int output_chunk_bytes = (do_gelu
-                                    ? (n_chunk * m) * D.element_size()
-                                    : (n_chunk * m) * HALF_BYTES);
+    const int output_chunk_bytes = (n_chunk * m) * D.element_size();
    const int aux_chunk_bytes = do_gelu ? (n_chunk * m) * pre_gelu_out.element_size() : 0;

     // Get output and workspace data pointers
...
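The arithmetic behind this user-buffer fix: the old code hard-coded 2-byte halves for the non-GELU branch, but once the GEMM output D can be 1-byte FP8 the chunk offsets must come from the actual element size. Illustrative numbers:

n_chunk, m = 1024, 4096
for name, elem_size in (("fp8", 1), ("fp16/bf16", 2), ("fp32", 4)):
    # chunk bytes scale with dtype width; hard-coding 2 mis-sizes FP8 chunks
    print(name, (n_chunk * m) * elem_size)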
...@@ -32,6 +32,7 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked( ...@@ -32,6 +32,7 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
const at::Tensor QKV, const at::Tensor QKV,
const transformer_engine::DType qkv_type, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_S,
    const c10::optional<at::Tensor> scale_O,
    c10::optional<at::Tensor> amax_S,
@@ -51,11 +52,13 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
    const at::Tensor O,
    const at::Tensor dO,
    const transformer_engine::DType qkv_type,
    const transformer_engine::DType dqkv_type,
    const std::vector<at::Tensor> Aux_CTX_Tensors,
    const c10::optional<at::Tensor> descale_QKV,
    const c10::optional<at::Tensor> descale_S,
    const c10::optional<at::Tensor> descale_O,
    const c10::optional<at::Tensor> descale_dO,
    const c10::optional<at::Tensor> descale_dP,
    const c10::optional<at::Tensor> scale_S,
    const c10::optional<at::Tensor> scale_dP,
    const c10::optional<at::Tensor> scale_dQKV,
@@ -74,6 +77,7 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
    const at::Tensor KV,
    const transformer_engine::DType qkv_type,
    const c10::optional<at::Tensor> descale_QKV,
    const c10::optional<at::Tensor> descale_S,
    const c10::optional<at::Tensor> scale_S,
    const c10::optional<at::Tensor> scale_O,
    c10::optional<at::Tensor> amax_S,
@@ -95,11 +99,13 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
    const at::Tensor O,
    const at::Tensor dO,
    const transformer_engine::DType qkv_type,
    const transformer_engine::DType dqkv_type,
    const std::vector<at::Tensor> Aux_CTX_Tensors,
    const c10::optional<at::Tensor> descale_QKV,
    const c10::optional<at::Tensor> descale_S,
    const c10::optional<at::Tensor> descale_O,
    const c10::optional<at::Tensor> descale_dO,
    const c10::optional<at::Tensor> descale_dP,
    const c10::optional<at::Tensor> scale_S,
    const c10::optional<at::Tensor> scale_dP,
    const c10::optional<at::Tensor> scale_dQKV,
@@ -119,6 +125,7 @@ std::vector<at::Tensor> fused_attn_fwd(
    const at::Tensor V,
    const transformer_engine::DType qkv_type,
    const c10::optional<at::Tensor> descale_QKV,
    const c10::optional<at::Tensor> descale_S,
    const c10::optional<at::Tensor> scale_S,
    const c10::optional<at::Tensor> scale_O,
    c10::optional<at::Tensor> amax_S,
@@ -141,11 +148,13 @@ std::vector<at::Tensor> fused_attn_bwd(
    const at::Tensor O,
    const at::Tensor dO,
    const transformer_engine::DType qkv_type,
    const transformer_engine::DType dqkv_type,
    const std::vector<at::Tensor> Aux_CTX_Tensors,
    const c10::optional<at::Tensor> descale_QKV,
    const c10::optional<at::Tensor> descale_S,
    const c10::optional<at::Tensor> descale_O,
    const c10::optional<at::Tensor> descale_dO,
    const c10::optional<at::Tensor> descale_dP,
    const c10::optional<at::Tensor> scale_S,
    const c10::optional<at::Tensor> scale_dP,
    const c10::optional<at::Tensor> scale_dQKV,
...
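For orientation, a minimal Python-side sketch of the extra FP8 metadata these widened signatures consume. The extension module name and the choice of E5M2 for gradients are assumptions, and the tensors shown are illustrative placeholders, not the library's defaults:

import torch
import transformer_engine_extensions as tex  # TE's PyTorch C++ extension (name assumed)

# Scaling factors are single-element float32 CUDA tensors, one per FP8 tensor.
descale_S = torch.ones(1, dtype=torch.float32, device="cuda")   # new forward argument
descale_dP = torch.ones(1, dtype=torch.float32, device="cuda")  # new backward argument

# Gradients may carry their own FP8 dtype, distinct from qkv_type.
dqkv_type = tex.DType.kFloat8E5M2  # E5M2 for gradients is an assumption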
@@ -4,7 +4,7 @@
"""Tensor class with FP8 data"""
from __future__ import annotations
from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch.utils._pytree import tree_map

@@ -233,6 +233,87 @@ class _IdentityFunc(torch.autograd.Function):
    def backward(ctx, grad):
        return grad.to(ctx.input_dtype), None

class _ViewFunc(torch.autograd.Function):
    """View function

    View the Float8Tensor using the provided shape.

    """

    @staticmethod
    def forward(
        ctx,
        tensor: torch.Tensor,
        shape: Optional[Tuple[int, ...]] = None,
    ) -> torch.Tensor:

        # Return input tensor if shape is not provided
        ctx.shape = tensor.shape
        if shape is None:
            return tensor

        # Construct new tensor if shape is provided
        if isinstance(tensor, Float8Tensor):
            return Float8Tensor.make_like(
                tensor,
                data=tensor._data.view(*shape),
            )
        return tensor.view(*shape)

    @staticmethod
    def backward(
        ctx,
        grad: torch.Tensor,
    ) -> Tuple[Union[torch.Tensor, None], ...]:

        if isinstance(grad, Float8Tensor):
            dgrad = Float8Tensor.make_like(
                grad,
                data=grad._data.view(ctx.shape),
            )
            return dgrad, None
        return grad.view(ctx.shape), None

class _ReshapeFunc(torch.autograd.Function):
    """Reshape function

    Reshape the Float8Tensor using the provided shape.

    """

    @staticmethod
    def forward(
        ctx,
        tensor: torch.Tensor,
        shape: Optional[Tuple[int, ...]] = None,
    ) -> torch.Tensor:

        # Return input tensor if shape is not provided
        ctx.shape = tensor.shape
        if shape is None:
            return tensor

        # Construct new tensor if shape is provided
        if isinstance(tensor, Float8Tensor):
            return Float8Tensor.make_like(
                tensor,
                data=tensor._data.reshape(*shape),
            )
        return tensor.reshape(*shape)

    @staticmethod
    def backward(
        ctx,
        grad: torch.Tensor,
    ) -> Tuple[Union[torch.Tensor, None], ...]:

        if isinstance(grad, Float8Tensor):
            dgrad = Float8Tensor.make_like(
                grad,
                data=grad._data.reshape(ctx.shape),
            )
            return dgrad, None
        return grad.reshape(ctx.shape), None

class Float8Tensor(torch.Tensor):
    """Experimental tensor class with FP8 data

@@ -453,6 +534,12 @@ class Float8Tensor(torch.Tensor):
    def clone(self) -> Float8Tensor:
        return _IdentityFunc.apply(self, {"data": self._data.detach().clone()})

    def view(self, *shape: int) -> Float8Tensor:
        return _ViewFunc.apply(self, shape)

    def reshape(self, *shape: int) -> Float8Tensor:
        return _ReshapeFunc.apply(self, shape)

    def expand_as(self, other: torch.Tensor):
        if other is self:
            # Note: expand_as is hackily used to create dummy autograd nodes
...
@@ -202,6 +202,11 @@ class FP8GlobalStateManager:
        # `fp8_param_to_autocast`. This is used for keeping track of FP8 weights
        # in an autocasted region and cross reference them in `float8_tensor.py`
        # to perform the forward amax reduction.
        fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
        if fp8_meta_tensor_key not in fp8_meta:
            # Handles non-parameter FP8 modules, e.g. DPA.
            continue
        if forward and fp8_weights is not None:
            autocast_key = cls.get_unique_autocast_key(
                fp8_meta["recipe"], fp8_meta["fp8_group"])
@@ -217,7 +222,6 @@ class FP8GlobalStateManager:
        key = cls.get_key_in_buffer(
            forward, fp8_weights is not None, fp8_meta["recipe"], fp8_meta["fp8_group"])
        if key not in cls.global_amax_buffer:
            cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
...
@@ -268,6 +268,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",)
        for meta_key in fp8_meta_tensor_keys:
            if meta_key not in self.fp8_meta:
                # Handles non-parameter FP8 modules, e.g. DPA.
                continue
            curr_len = self.fp8_meta[meta_key].amax_history.shape[0]
            if length == curr_len:
                continue
@@ -568,6 +571,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        inp: torch.Tensor,
        is_first_microbatch: Union[bool, None],
        num_gemms: int = 1,
        allow_non_contiguous: bool = False,
    ) -> Generator[torch.Tensor, None, None]:
        """Checks and prep for FWD.
        The context manager is needed because there isn't a way for a module to know
@@ -610,7 +614,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
        with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
            if not allow_non_contiguous:
                yield inp.contiguous()
            else:
                yield inp
        if self.fp8 and in_fp8_activation_recompute_phase():
            FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
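Why the flag exists, in a hedged sketch: a module whose input may already be a Float8Tensor (e.g. DPA receiving FP8 activations) wants the wrapper passed through untouched rather than forced through inp.contiguous(); the _attn_impl helper below is hypothetical.

# Sketch for a TransformerEngineBaseModule subclass; `inp` may be a Float8Tensor.
with self.prepare_forward(
    inp, is_first_microbatch, allow_non_contiguous=isinstance(inp, Float8Tensor)
) as inp_prep:
    out = self._attn_impl(inp_prep)  # hypothetical forward helper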
@@ -645,8 +652,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            R4: bias gradient on R1.
        """
        if isinstance(grad_output, Float8Tensor):
            grad_output._data = grad_output._data.contiguous()
        else:
            grad_output = grad_output.contiguous()
        grad_output_mat = grad_output.view(-1, grad_output.shape[-1])
        gather_grad_output = row_parallel_mode and ctx.sequence_parallel
        # No-FP8 case: bgrad is fused with wgrad for this case.
@@ -684,6 +694,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0)
            else:
                grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8)
            if not isinstance(grad_output_mat, Float8Tensor):
                cast_to_fp8(
                    grad_output_mat,
                    ctx.fp8_meta["scaling_bwd"],
@@ -691,9 +702,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                    fp8_dtype_backward,
                    out=grad_output_c,
                )
            else:
                grad_output_c = grad_output_mat
            if not ctx.ub_overlap_ag:
                grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group)
                if not isinstance(grad_output_c, Float8Tensor):
                    grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
                else:
                    grad_output_t = grad_output_c.transpose_2d()
        else:
            grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(1)
            grad_output_t = None
@@ -702,14 +718,21 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        # FP8 case without gather: cast, transpose, bgrad fused
        if ctx.use_bias:
            grad_output_mat_no_fp8 = grad_output_mat
            if isinstance(grad_output_mat, Float8Tensor):
                grad_output_mat_no_fp8 = grad_output_mat.from_float8(grad_output_mat.dtype)
            grad_bias, grad_output_c, grad_output_t = fp8_cast_transpose_bgrad_fused(
                grad_output_mat_no_fp8,
                ctx.fp8_meta["scaling_bwd"],
                tex.FP8BwdTensors.GRAD_OUTPUT1,
                fp8_dtype_backward,
            )
        else:
            if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                if isinstance(grad_output_mat, Float8Tensor):
                    grad_output_c = grad_output_mat
                    grad_output_t = grad_output_c.transpose_2d()
                else:
                    grad_output_c, grad_output_t = fp8_cast_transpose_fused(
                        grad_output_mat,
                        ctx.fp8_meta["scaling_bwd"],
@@ -718,12 +741,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                    )
            else:
                grad_output_t = None
                if not isinstance(grad_output_mat, Float8Tensor):
                    grad_output_c = cast_to_fp8(
                        grad_output_mat,
                        ctx.fp8_meta["scaling_bwd"],
                        tex.FP8BwdTensors.GRAD_OUTPUT1,
                        fp8_dtype_backward,
                    )
                else:
                    grad_output_c = grad_output_mat
            grad_bias = None
        return grad_output_mat, grad_output_c, grad_output_t, grad_bias
...
@@ -43,6 +43,7 @@ from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ._common import _apply_normalization, _noop_cat
from ..float8_tensor import Float8Tensor

_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
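Because _NVTE_DEBUG is read once at module import, the variable must be set before transformer_engine is imported, e.g.:

import os
os.environ["NVTE_DEBUG"] = "1"  # set before the transformer_engine import

import transformer_engine.pytorch as te  # forward/backward path prints now enabled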
__all__ = ["LayerNormLinear"]

@@ -190,6 +191,9 @@ class _LayerNormLinear(torch.autograd.Function):
            ln_out = ln_out_total

        if fp8:
            if _NVTE_DEBUG:
                print('[LayerNormLinear]: using FP8 forward')
            bias_dtype = (
                torch.bfloat16
                if activation_dtype == torch.float32
@@ -230,6 +234,15 @@ class _LayerNormLinear(torch.autograd.Function):
            )
            weight_t_fp8 = None

            if fp8_meta["recipe"].fp8_mha:
                out_index, meta_tensor, output_te_dtype, output_dtype = (
                    tex.FP8FwdTensors.GEMM1_OUTPUT,
                    fp8_meta["scaling_fwd"],
                    fp8_dtype_forward,
                    torch.uint8)
            else:
                out_index, meta_tensor, output_te_dtype, output_dtype = (
                    None, None, None, activation_dtype)

            out, _ = tex.fp8_gemm(
                weight_fp8._data,
                fp8_meta["scaling_fwd"].scale_inv,
@@ -239,7 +252,7 @@ class _LayerNormLinear(torch.autograd.Function):
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_INPUT,
                fp8_dtype_forward,
                output_dtype,
                get_workspace(),
                bias=bias,
                use_bias=use_bias,
@@ -247,8 +260,22 @@ class _LayerNormLinear(torch.autograd.Function):
                ub_algo=ub_algo if ub_overlap_ag else None,
                ub=ub_obj_lnout if ub_overlap_ag else None,
                extra_output_tensor=ln_out if ub_overlap_ag else None,
                out_index=out_index,
                fp8_meta_tensor=meta_tensor,
                D_dtype=output_te_dtype,
            )
            if output_dtype == torch.uint8:
                out = Float8Tensor(data=out,
                                   fp8_meta=fp8_meta,
                                   fp8_meta_forward=True,
                                   fp8_meta_index=tex.FP8FwdTensors.GEMM1_OUTPUT,
                                   fp8_dtype=fp8_dtype_forward,
                                   dtype=activation_dtype,
                                   )
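A hedged end-to-end sketch of the path this enables: with the recipe's fp8_mha flag set, the GEMM writes its output in FP8 (via out_index/D_dtype above) and the layer returns a Float8Tensor instead of dequantizing. Attaching the flag to DelayedScaling is an assumption inferred from the fp8_meta["recipe"].fp8_mha check above; layer sizes are illustrative.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.HYBRID)
fp8_recipe.fp8_mha = True  # assumed toggle, read via fp8_meta["recipe"].fp8_mha

qkv_proj = te.LayerNormLinear(1024, 3 * 1024).cuda()
x = torch.randn(8, 2, 1024, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    qkv = qkv_proj(x)  # a Float8Tensor when the flag is set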
        else:
            if _NVTE_DEBUG:
                print('[LayerNormLinear]: using non-FP8 forward')
            # Cast for native AMP
            weight = cast_if_needed(weight, activation_dtype)
            bias = cast_if_needed(bias, activation_dtype) if use_bias else bias
@@ -338,7 +365,6 @@ class _LayerNormLinear(torch.autograd.Function):
        # [*, in_features] -> [*, out_features] except first dimension changes for SP
        out = out.view(-1, *inp.shape[1:-1], out.shape[-1])
        if return_layernorm_output:
            if return_layernorm_output_gathered:
                shape = list(inp.shape)
@@ -352,6 +378,10 @@ class _LayerNormLinear(torch.autograd.Function):
    def backward(
        ctx, *grad_outputs: Tuple[torch.Tensor, ...]
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        if isinstance(grad_outputs[0], Float8Tensor):
            ctx.fp8_meta["scaling_bwd"].scale_inv[
                tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_outputs[0]._scale_inv
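This handoff matters because an FP8 payload only decodes correctly with its producer's scale; a toy calculation with assumed values:

g_true = 0.01                  # true gradient value
scale = 448.0                  # producer's FP8 scale for dO
q = round(g_true * scale)      # encoded payload magnitude: 4
g_decoded = q * (1.0 / scale)  # ~0.0089; requires the producer's scale_inv, not a stale one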
        with torch.cuda.nvtx.range("_LayerNormLinear_backward"):
            (
                inputmat,
@@ -465,6 +495,9 @@ class _LayerNormLinear(torch.autograd.Function):
            ub_obj = None
            if ctx.fp8:
                if _NVTE_DEBUG:
                    print('[LayerNormLinear]: using FP8 backward')
                fp8_dtype_forward = get_fp8_te_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=True
                )
@@ -486,7 +519,8 @@ class _LayerNormLinear(torch.autograd.Function):
                    fwd_scale_inverses,
                    tex.FP8FwdTensors.GEMM1_WEIGHT,
                    fp8_dtype_forward,
                    grad_output_c._data
                    if isinstance(grad_output_c, Float8Tensor) else grad_output_c,
                    ctx.fp8_meta["scaling_bwd"].scale_inv,
                    tex.FP8BwdTensors.GRAD_OUTPUT1,
                    fp8_dtype_backward,
@@ -503,6 +537,9 @@ class _LayerNormLinear(torch.autograd.Function):
                )
                clear_tensor_data(grad_output_c)
            else:
                if _NVTE_DEBUG:
                    print('[LayerNormLinear]: using non-FP8 backward')
                # DGRAD: Evaluated unconditionally to feed into Linear backward
                _, _, _ = tex.gemm(
                    weight,
@@ -551,7 +588,8 @@ class _LayerNormLinear(torch.autograd.Function):
                    fwd_scale_inverses,
                    tex.FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                    grad_output_t._data
                    if isinstance(grad_output_t, Float8Tensor) else grad_output_t,
                    ctx.fp8_meta["scaling_bwd"].scale_inv,
                    tex.FP8BwdTensors.GRAD_OUTPUT1,
                    fp8_dtype_backward,
...