Commit a8f0fe03 (unverified signature) authored by kwyss-nvidia, committed by GitHub

Blockwise scaling linear quantization recipe (#1559)



* Add GEMM logic for blockwise quantized tensors.

GEMM test cases included in pytorch integration.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update NVTE_BLOCK_SCALING for GEMM.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Gate feature on CUDA 12.9
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix GEMM typo.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Remove unnecessary type converter change.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Reflect epilogue availability and test supported epilogues.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* GEMM simplifications from recipe branch.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Format py code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update GEMM DGelu tests to match support depending on output dtype.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Force pow2Scales in GEMM
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add GEMM test to pytorch test suite.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add copyright to GEMM test.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update import for GEMM test.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add license.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update test gemm supported predicate.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Use sgemm-like interfaces and naming.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Rewrite GEMM comment.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* MR feedback.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Recipe setup for Linear modules.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Use 12.9 feature test.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Run against tensor dumps from internal library.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update FIXME to TODO with linked issue.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update full recompute feature to save recipe.

The recompute context uses the same recipe
and fp8 settings as the original fwd pass.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* MR feedback. Avoid reusing quantizer objects.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update logic in module.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Format py.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update for PP bug.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update test numerics.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update force_power_of_2 scales in the recipe.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update usage method to satisfy upstream changes.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix subchannel recipe in distributed test with BF16 gather
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Edit and clean up BF16 gather code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update test import.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Support columnwise-only mode in the 1D quantize kernel
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Format and move enum
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Skip alloc.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Try async BF16 gather
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Format python code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Document and type code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix pytorch lint errors.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Don't set high-precision dtype.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add test for sanity and CG; fix CG for sequential?
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Keep make_quantizers API stable

Update num_quantizers instead to pass cuda_graph tests.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix import name.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Rename recipe method.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Skip grouped linear sanity test.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Set usage before BF16 gather.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Refactor for nvte_quantize_v2
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Format code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Clean up nvte_quantize_v2
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Test fp32 scales.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Disable CUDA graph.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Simplify layernorm linear
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Clean up layernorm linear.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* LayerNorm linear bwd gather logic.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Communication updates.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update transformer_engine/pytorch/ops/op.py

Apply MR comment change.
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: kwyss-nvidia <kwyss@nvidia.com>

* Lint fix.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* MR feedback.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Enable CUDA graph tests.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Reduce chance of spurious failure and reword.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Review suggestions from @timmoon10
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update CPP tests.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update common.h
Signed-off-by: Xin Yao <yaox12@outlook.com>

* Update test_float8blockwisetensor.py
Signed-off-by: Xin Yao <yaox12@outlook.com>

---------
Signed-off-by: Keith Wyss <kwyss@nvidia.com>
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: kwyss-nvidia <kwyss@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Xin Yao <yaox12@outlook.com>
Co-authored-by: zhongboz <zhongboz@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Xin Yao <yaox12@outlook.com>
parent 0da60449
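To make the recipe concrete before diving into the diff: the 1D blockwise scheme below is a minimal NumPy sketch of quantizing a tensor in fixed-size blocks with one inverse scale per block, optionally rounding scales down to powers of two (`force_pow_2_scales`) and flooring the per-block amax at `amax_epsilon`, as the recipe options in this PR describe. The 128-element block size, the E4M3 max of 448, and all function names here are illustrative assumptions, not the Transformer Engine kernels themselves.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # max representable magnitude of float8_e4m3


def quantize_1d_blockwise(x, block=128, force_pow_2_scales=True, amax_epsilon=0.0):
    """Quantize a flat array in `block`-sized chunks, one scale per chunk.

    Returns (q, scale_inv) such that x ~= q * scale_inv broadcast per block.
    """
    x = np.asarray(x, dtype=np.float32)
    n_blocks = (x.size + block - 1) // block
    padded = np.zeros(n_blocks * block, dtype=np.float32)
    padded[: x.size] = x
    chunks = padded.reshape(n_blocks, block)

    # Per-block amax, floored at epsilon; guard all-zero blocks.
    amax = np.maximum(np.abs(chunks).max(axis=1), amax_epsilon)
    amax = np.where(amax == 0.0, 1.0, amax)

    scale = FP8_E4M3_MAX / amax
    if force_pow_2_scales:
        # Round scales down to a power of two so (de)scaling is exact in binary.
        scale = np.exp2(np.floor(np.log2(scale)))

    q = np.clip(chunks * scale[:, None], -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q, 1.0 / scale
```

With power-of-two scales forced, dequantization (`q * scale_inv`) introduces no rounding beyond the FP8 cast itself, which is why the GEMM path in this PR forces them.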
@@ -16,11 +16,15 @@
 #include "common/common.h"
 #include "common/recipe/recipe_common.cuh"
+#include "common/transpose/cast_transpose.h"
 #include "common/utils.cuh"
 
 namespace transformer_engine {
 
 namespace {
 
+using transformer_engine::detail::FP8BlockwiseColumnwiseOption;
+using transformer_engine::detail::FP8BlockwiseRowwiseOption;
+
 // clang-format off
 /*
@@ -138,15 +142,17 @@
 static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp");
 static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp");
 
 template <bool kAligned, typename CType, typename IType, typename OType>
-__global__ void __launch_bounds__(kThreadsPerBlock)
-    block_scaled_1d_cast_transpose_kernel(const IType* const input, OType* const output_c,
-                                          OType* const output_t, CType* const tile_scales_inv_c,
-                                          CType* const tile_scales_inv_t, const size_t row_length,
-                                          const size_t num_rows, const size_t scale_stride_x,
-                                          const size_t scale_stride_y,
-                                          const size_t scale_t_stride_x,
-                                          const size_t scale_t_stride_y, const float epsilon,
-                                          bool return_transpose, bool pow_2_scaling) {
+__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel(
+    const IType* const input, OType* const output_c, OType* const output_t,
+    CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length,
+    const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y,
+    const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
+    FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option,
+    const bool pow_2_scaling) {
+  bool return_rowwise = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE;
+  bool return_columnwise_transpose =
+      columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE;
   using SMemVec = Vec<IType, kNVecSMem>;
   using OVec = Vec<OType, kNVecOut>;
   union IVec {
@@ -203,7 +209,7 @@
   __syncthreads();
 
   // Step 2: Cast and store to output_c
-  {
+  if (return_rowwise) {
     constexpr int r_stride =
         kThreadsPerBlock / kNumThreadsStore;  // stride in rows of shared memory
     constexpr int num_iterations = kTileDim / r_stride;
@@ -294,7 +300,7 @@
   }
 
   // Step 3: Transpose, cast and store to output_t
-  if (return_transpose) {
+  if (return_columnwise_transpose) {
     constexpr int c_stride =
         kThreadsPerBlock / kNumThreadsStore;  // Stride in columns of shared memory
     constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem);
@@ -389,10 +395,15 @@ namespace transformer_engine::detail {
 void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor& scale_inv,
                                          SimpleTensor& scale_inv_t, SimpleTensor& output,
                                          SimpleTensor& output_t, const float epsilon,
-                                         const bool return_transpose, const bool pow2_scale,
-                                         cudaStream_t stream) {
+                                         FP8BlockwiseRowwiseOption rowwise_option,
+                                         FP8BlockwiseColumnwiseOption columnwise_option,
+                                         const bool pow2_scale, cudaStream_t stream) {
   NVTE_API_CALL(quantize_transpose_vector_blockwise);
-  NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
+  // assert that rowwise_option and columnwise_option are not both NONE
+  NVTE_CHECK(rowwise_option != FP8BlockwiseRowwiseOption::NONE ||
+                 columnwise_option != FP8BlockwiseColumnwiseOption::NONE,
+             "rowwise_option and columnwise_option cannot both be NONE");
 
   const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
   size_t num_elements = row_length;
@@ -408,21 +419,24 @@
   }
 
   // Options for scale layout of cuBLAS GEMM kernel.
-  NVTE_CHECK(input.shape.size() == output.shape.size(),
-             "Input and output must have the same shape.");
   size_t scale_stride_x = 0;
   size_t scale_stride_y = 0;
-  NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2.");
-  size_t scale_k = scale_inv.shape[1];
-  scale_stride_x = scale_k;
-  scale_stride_y = 1;
-
   size_t scale_t_stride_x = 0;
   size_t scale_t_stride_y = 0;
-  if (return_transpose) {
+  if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) {
+    NVTE_CHECK(rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE,
+               "Unexpected rowwise enum value");
+    NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
+    NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2.");
+    size_t scale_k = scale_inv.shape[1];
+    scale_stride_x = scale_k;
+    scale_stride_y = 1;
+  }
+  if (columnwise_option != FP8BlockwiseColumnwiseOption::NONE) {
+    NVTE_CHECK(columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE,
+               "Unexpected columnwise enum value");
    NVTE_CHECK(output_t.shape.size() == input.shape.size(),
               "output_t must have same number of dimensions as input.");
    if (output_t.shape.size() > 0) {
@@ -469,10 +483,10 @@
                     reinterpret_cast<OutputType*>(output_t.dptr),
                     reinterpret_cast<float*>(scale_inv.dptr),
                     reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, scale_stride_x,
-                    scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, return_transpose,
-                    pow2_scale);)  // kAligned
+                    scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option,
+                    columnwise_option, pow2_scale);)  // kAligned
             )  // OutputType
         )  // InputType
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
......
@@ -35,8 +35,8 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea
   constexpr NVTETensor workspace = nullptr;
   constexpr const NVTETensor grad = nullptr;
 
-  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, nullptr, output,
-                                                                     dbias, workspace, stream);
+  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, output, dbias,
+                                                                     workspace, nullptr, stream);
 }
 
 void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
@@ -44,6 +44,18 @@
   NVTE_API_CALL(nvte_quantize_noop);
   using namespace transformer_engine;
 
+  // Create config with noop tensor
+  QuantizationConfig quant_config;
+  quant_config.noop_tensor = noop;
+
+  nvte_quantize_v2(input, output, reinterpret_cast<NVTEQuantizationConfig>(&quant_config), stream);
+}
+
+void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
+                      const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_quantize_v2);
+  using namespace transformer_engine;
+
   constexpr bool IS_DBIAS = false;
   constexpr bool IS_DACT = false;
   constexpr bool IS_ACT = false;
@@ -51,8 +63,8 @@
   constexpr NVTETensor workspace = nullptr;
   constexpr const NVTETensor grad = nullptr;
 
-  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, noop, output,
-                                                                     dbias, workspace, stream);
+  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
+      input, grad, output, dbias, workspace, quant_config, stream);
 }
 
 void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
@@ -66,7 +78,7 @@
   constexpr const NVTETensor activation_input = nullptr;
 
   detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
-      activation_input, input, nullptr, output, dbias, workspace, stream);
+      activation_input, input, output, dbias, workspace, nullptr, stream);
 }
 
 void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
@@ -80,7 +92,7 @@
   constexpr bool IS_ACT = false;
 
   detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dgelu<fp32, fp32>>(
-      activation_input, input, nullptr, output, dbias, workspace, stream);
+      activation_input, input, output, dbias, workspace, nullptr, stream);
 }
 
 void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
@@ -94,7 +106,7 @@
   constexpr bool IS_ACT = false;
 
   detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsilu<fp32, fp32>>(
-      activation_input, input, nullptr, output, dbias, workspace, stream);
+      activation_input, input, output, dbias, workspace, nullptr, stream);
 }
 
 void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
@@ -108,7 +120,7 @@
   constexpr bool IS_ACT = false;
 
   detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, drelu<fp32, fp32>>(
-      activation_input, input, nullptr, output, dbias, workspace, stream);
+      activation_input, input, output, dbias, workspace, nullptr, stream);
 }
 
 void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
@@ -122,7 +134,7 @@
   constexpr bool IS_ACT = false;
 
   detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dqgelu<fp32, fp32>>(
-      activation_input, input, nullptr, output, dbias, workspace, stream);
+      activation_input, input, output, dbias, workspace, nullptr, stream);
 }
 
 void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
@@ -136,7 +148,7 @@
   constexpr bool IS_ACT = false;
 
   detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsrelu<fp32, fp32>>(
-      activation_input, input, nullptr, output, dbias, workspace, stream);
+      activation_input, input, output, dbias, workspace, nullptr, stream);
 }
 
 void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
......
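The API change above folds every optional knob into one config argument, so legacy entry points become thin wrappers. A minimal Python sketch of that wrapper pattern (the dataclass fields mirror the config members in the diff; the function bodies are illustrative stand-ins, not the C API):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class QuantizationConfig:
    # Mirrors the C-side QuantizationConfig members seen in the diff.
    noop_tensor: Optional[object] = None
    force_pow_2_scales: bool = False
    amax_epsilon: float = 0.0


def quantize_v2(inp, out, config=None):
    """v2 entry point: all optional behaviour rides on one config; None = defaults."""
    cfg = config or QuantizationConfig()
    # Stand-in for the real kernel dispatch: just report what would run.
    return ("quantize", inp, out, cfg.noop_tensor, cfg.force_pow_2_scales, cfg.amax_epsilon)


def quantize_noop(inp, out, noop):
    """Legacy entry point reduces to building a config and delegating."""
    return quantize_v2(inp, out, QuantizationConfig(noop_tensor=noop))
```

The payoff of this shape is that future options extend the config struct without touching any public function signature again.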
...@@ -1215,9 +1215,9 @@ namespace detail { ...@@ -1215,9 +1215,9 @@ namespace detail {
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP, template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &)> float (*OP)(float, const ParamOP &)>
void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETensor noop, void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor output,
NVTETensor output, NVTETensor dbias, NVTETensor workspace, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) { const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
const Tensor *input_tensor; const Tensor *input_tensor;
const Tensor *activation_input_tensor; const Tensor *activation_input_tensor;
if constexpr (IS_DBIAS || IS_DACT) { if constexpr (IS_DBIAS || IS_DACT) {
...@@ -1232,6 +1232,12 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe ...@@ -1232,6 +1232,12 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe
auto output_tensor = reinterpret_cast<Tensor *>(output); auto output_tensor = reinterpret_cast<Tensor *>(output);
auto dbias_tensor = reinterpret_cast<Tensor *>(dbias); auto dbias_tensor = reinterpret_cast<Tensor *>(dbias);
auto workspace_tensor = reinterpret_cast<Tensor *>(workspace); auto workspace_tensor = reinterpret_cast<Tensor *>(workspace);
const QuantizationConfig *quant_config_cpp =
reinterpret_cast<const QuantizationConfig *>(quant_config);
// extract noop tensor from quant_config_cpp if it's not null
const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr;
const auto noop_tensor = noop != nullptr ? *(reinterpret_cast<const Tensor *>(noop)) : Tensor(); const auto noop_tensor = noop != nullptr ? *(reinterpret_cast<const Tensor *>(noop)) : Tensor();
switch (output_tensor->scaling_mode) { switch (output_tensor->scaling_mode) {
...@@ -1263,11 +1269,11 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe ...@@ -1263,11 +1269,11 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe
// TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D"); "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D");
constexpr bool force_pow_2_scales = true; bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : true;
float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f;
quantize_transpose_square_blockwise( quantize_transpose_square_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data, output_tensor->data, output_tensor->columnwise_data, epsilon,
/*epsilon=*/0.0,
/*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream); /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream);
break; break;
} }
...@@ -1275,12 +1281,18 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe ...@@ -1275,12 +1281,18 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe
// TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"); "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D");
constexpr bool force_pow_2_scales = true; bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false;
quantize_transpose_vector_blockwise( float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f;
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, FP8BlockwiseRowwiseOption rowwise_option = output_tensor->has_data()
output_tensor->data, output_tensor->columnwise_data, ? FP8BlockwiseRowwiseOption::ROWWISE
/*epsilon=*/0.0, : FP8BlockwiseRowwiseOption::NONE;
/*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream); FP8BlockwiseColumnwiseOption columnwise_option =
output_tensor->has_columnwise_data() ? FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE
: FP8BlockwiseColumnwiseOption::NONE;
quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv,
output_tensor->columnwise_scale_inv, output_tensor->data,
output_tensor->columnwise_data, epsilon, rowwise_option,
columnwise_option, force_pow_2_scales, stream);
break; break;
} }
default: default:
......
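The dispatch in the 1D branch above maps which outputs the tensor actually holds onto the two enum options, and rejects the case where neither is requested. A small Python mirror of that selection logic (enum and function names are illustrative; the real enums live in transformer_engine::detail):

```python
from enum import Enum


class Rowwise(Enum):
    NONE = 0
    ROWWISE = 1


class Columnwise(Enum):
    NONE = 0
    COLUMNWISE_TRANSPOSE = 1


def select_options(has_rowwise_data: bool, has_columnwise_data: bool):
    """Each output the tensor holds maps to one option; both NONE is invalid,
    matching the NVTE_CHECK in quantize_transpose_vector_blockwise."""
    rowwise = Rowwise.ROWWISE if has_rowwise_data else Rowwise.NONE
    columnwise = (
        Columnwise.COLUMNWISE_TRANSPOSE if has_columnwise_data else Columnwise.NONE
    )
    if rowwise is Rowwise.NONE and columnwise is Columnwise.NONE:
        raise ValueError("rowwise_option and columnwise_option cannot both be NONE")
    return rowwise, columnwise
```

This is what "support columnwise only mode" in the commit log buys: a tensor with only columnwise data no longer forces the rowwise output to be produced.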
@@ -14,6 +14,7 @@ from ..utils import assert_dim_for_fp8_exec, get_sm_count
 from ..tensor.quantized_tensor import Quantizer
 from ..tensor._internal.float8_tensor_base import Float8TensorBase
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
+from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 
 __all__ = [
     "general_gemm",
@@ -112,6 +113,10 @@ def general_gemm(
     # Use bfloat16 as default bias_dtype
     bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]
 
+    if isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase):
+        # There is no use_split_accumulator == False
+        # implementation for Float8BlockwiseQTensorBase GEMM
+        use_split_accumulator = True
     args = (
         A,
         transa,  # transa
......
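The forced `use_split_accumulator = True` above means blockwise FP8 operands are always accumulated in higher precision. A NumPy reference for what such a GEMM computes, dequantizing each 1D-blockwise operand and accumulating in float32 (shapes, names, and the `block` layout are illustrative assumptions for a sketch, not the cuBLAS path):

```python
import numpy as np


def blockwise_gemm_reference(a_q, a_scale_inv, b_q, b_scale_inv, block=128):
    """Reference GEMM for 1D-blockwise-quantized operands.

    a_q: (M, K) quantized values; a_scale_inv: (M, K // block) per-block
    inverse scales (same layout for b_q / b_scale_inv, shape (N, K)).
    Dequantizes both operands and accumulates entirely in float32.
    """
    # Broadcast each per-block scale across its block of K columns.
    a = a_q.astype(np.float32) * np.repeat(a_scale_inv, block, axis=1)
    b = b_q.astype(np.float32) * np.repeat(b_scale_inv, block, axis=1)
    return a @ b.T  # float32 accumulation throughout
```

A fast kernel instead multiplies FP8 blocks and applies the two scales to each partial sum, but it must match this float32-accumulated reference, which is why a low-precision (non-split) accumulator variant is not offered.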
...@@ -167,13 +167,13 @@ class Float8BlockQuantizer : public Quantizer { ...@@ -167,13 +167,13 @@ class Float8BlockQuantizer : public Quantizer {
public: public:
// Which float8 type is used for q data. // Which float8 type is used for q data.
DType dtype; DType dtype;
private:
// Options about how to quantize the tensor // Options about how to quantize the tensor
// Quantization scales are rounded down to powers of 2. // Quantization scales are rounded down to powers of 2.
bool force_pow_2_scales = false; bool force_pow_2_scales = false;
// Amax within quantization tile has a floor of epsilon. // Amax within quantization tile has a floor of epsilon.
float amax_epsilon = 0.0; float amax_epsilon = 0.0;
private:
int block_scaling_dim = 2; int block_scaling_dim = 2;
public: public:
......
@@ -50,7 +50,12 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int
     nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream());
     // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
     te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
-    nvte_quantize(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
+    nvte_quantize_v2(te_output_act.data(), te_output.data(), quant_config,
+                     at::cuda::getCurrentCUDAStream());
+  } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
+    // sanity check, since activation fusion is not supported for blockwise quantization yet
+    // need to raise an error here instead of silently going into act_func with wrong numerics
+    NVTE_ERROR("Activation fusion is not supported for blockwise quantization yet.");
   } else {
     act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
   }
......
...@@ -46,6 +46,9 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob ...@@ -46,6 +46,9 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
if (te_output.numel() == 0) return out; if (te_output.numel() == 0) return out;
QuantizationConfigWrapper quant_config;
quant_config.set_noop_tensor(te_noop.data());
if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer // my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get()); auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
...@@ -61,15 +64,21 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob ...@@ -61,15 +64,21 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
allreduce_opts.reduceOp = c10d::ReduceOp::MAX; allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
} }
// This config is used here for the current-scaling scale-factor computation;
// the scale computation cannot be fused with the quantize kernel, so under
// current scaling the config is not used again inside nvte_quantize_v2.
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream());
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
} else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
auto my_quantizer_bw = static_cast<Float8BlockQuantizer*>(my_quantizer.get());
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
}
nvte_quantize_v2(te_input.data(), te_output.data(), quant_config,
at::cuda::getCurrentCUDAStream());
return out;
}
...
@@ -150,6 +150,7 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
// Quantize output if using unfused kernel
if (force_unfused_kernel) {
QuantizationConfigWrapper quant_config;
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
@@ -166,15 +167,18 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
}
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
// set amax ptr to null in out_cu TensorWrapper to avoid atomic amax updates in kernel
out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
} else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(my_quantizer.get());
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
}
nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
at::cuda::getCurrentCUDAStream());
}
return {out, py::cast(mu), py::cast(rsigma)};
@@ -293,6 +297,7 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
// Quantize output if using unfused kernel
if (force_unfused_kernel) {
QuantizationConfigWrapper quant_config;
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
@@ -309,15 +314,18 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
}
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
// set amax ptr to null in out_cu TensorWrapper to avoid atomic amax updates in kernel
out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
} else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(my_quantizer.get());
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
}
nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
at::cuda::getCurrentCUDAStream());
}
return {out, py::none(), py::cast(rsigma)};
...
@@ -257,12 +257,8 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) {
this->dtype = quantizer.attr("dtype").cast<DType>();
this->block_scaling_dim = quantizer.attr("block_scaling_dim").cast<int>();
this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast<bool>();
this->amax_epsilon = quantizer.attr("amax_epsilon").cast<float>();
NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2,
"Unsupported block scaling dim.");
}
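With the old checks removed, the constructor now forwards `force_pow_2_scales` and `amax_epsilon` into the quantization config instead of rejecting them. A rough sketch of what these two knobs mean for FP8 E4M3 scale selection; the function name, the epsilon-floor placement, and the pow-2 rounding rule here are illustrative assumptions, not the TE kernel:

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite magnitude representable in E4M3


def scale_for_amax(amax: float, force_pow_2_scales: bool, amax_epsilon: float) -> float:
    """Illustrative per-block scale selection (assumed semantics, not TE's kernel)."""
    amax = max(amax, amax_epsilon)  # epsilon floors tiny or zero amax values
    scale = FP8_E4M3_MAX / amax     # map [-amax, amax] onto the FP8 range
    if force_pow_2_scales:
        # Round the scale down to a power of two so it is exactly representable
        # and cheap to invert during dequantization.
        scale = 2.0 ** math.floor(math.log2(scale))
    return scale
```

Forcing pow-2 scales gives up a small amount of quantization range in exchange for exactly representable scale factors.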
...
@@ -69,6 +69,7 @@ std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,
nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream());
} else {
for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) {
// TODO: switch to nvte_quantize_v2 with advanced numerical options
nvte_quantize(nvte_tensor_input_list[i], nvte_tensor_output_list[i],
at::cuda::getCurrentCUDAStream());
}
...
@@ -24,10 +24,11 @@ from .constants import dist_group_type
from .fp8 import FP8GlobalStateManager, fp8_autocast
from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
from .tensor.mxfp8_tensor import MXFP8Quantizer
from .tensor.float8_blockwise_tensor import Float8BlockQuantizer
from .tensor.quantized_tensor import QuantizedTensor, Quantizer
from .tensor._internal.float8_tensor_base import Float8TensorBase
from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
__all__ = ["checkpoint", "CudaRNGStatesTracker"]
@@ -937,6 +938,74 @@ def _all_gather_fp8(
return out, handle
def _all_gather_fp8_blockwise(
inp: torch.Tensor,
process_group: dist_group_type,
*,
async_op: bool = False, # pylint: disable=unused-argument
quantizer: Optional[Quantizer] = None,
out_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]:
"""
All-gather FP8 tensor along first dimension for blockwise quantization.
Returns: quantizer(gather(inp))
NOTE: The implementation is not sophisticated enough to honor async_op=True.
In some cases it falls back to synchronous gather and invokes the quantizer.
"""
# Input tensor attributes
device: torch.device
dtype: torch.dtype
if isinstance(inp, torch.Tensor):
device = inp.device
dtype = inp.dtype
elif isinstance(inp, Float8BlockwiseQTensorBase):
if inp._rowwise_data is not None:
device = inp._rowwise_data.device
elif inp._columnwise_data is not None:
device = inp._columnwise_data.device
else:
raise ValueError("Got Float8BlockwiseQTensorBase input tensor without any data")
dtype = torch.bfloat16 # Only has fp8 dtype. Guess BF16 for dequant.
else:
raise ValueError(
"Invalid type for input tensor (expected torch.Tensor or Float8BlockwiseQTensorBase, "
f"found {inp.__class__.__name__})"
)
world_size = get_distributed_world_size(process_group)
# Check that quantizer is valid
if quantizer is not None and not isinstance(quantizer, Float8BlockQuantizer):
raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})")
if quantizer is not None and not (quantizer.block_scaling_dim == 1 and quantizer.block_len == 128):
raise NotImplementedError("Only 1D blockwise quantization with block_len == 128 is supported for allgather")
# Output tensor dims
if out_shape is None:
out_shape = list(inp.size())
out_shape[0] *= world_size
# Doing BF16 gather for now as baseline because it's simpler
if not isinstance(inp, Float8BlockwiseQTensorBase) and quantizer is not None:
out = torch.empty(
out_shape,
dtype=dtype,
device=device,
memory_format=torch.contiguous_format,
)
torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
out = quantizer(out)
return out, None
# Implementation of fp8 gather needs to account for:
# * Getting columnwise data as a transpose of how it is stored for GEMMS.
# * Gathering non GEMM swizzled scales.
# * Refer to scaffold code when implementing at:
# https://github.com/kwyss-nvidia/TransformerEngine/commit/6659ee9dc84fb515d1d47699d8bfd20a72b76477
raise NotImplementedError("fp8 blockwise allgather not yet implemented")
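The reason a quantized gather is nontrivial: with blockwise scaling, block boundaries (and hence scales) depend on the gathered layout, so concatenating independently quantized shards is not the same as quantizing the gathered tensor. A toy 1D quantizer makes the point; this is an illustrative int8-style sketch, not TE's FP8 kernel:

```python
def quantize_blocks(values, block_len=4):
    """Toy per-block symmetric quantization: each block of block_len values
    shares one scale derived from its amax. Returns (scale, ints) per block."""
    blocks = []
    for i in range(0, len(values), block_len):
        block = values[i:i + block_len]
        amax = max(abs(v) for v in block) or 1.0
        scale = 127.0 / amax  # map the block onto an int8-like range
        blocks.append((scale, [round(v * scale) for v in block]))
    return blocks


def dequantize_blocks(blocks):
    """Invert the toy quantizer back to floats."""
    return [q / scale for scale, ints in blocks for q in ints]
```

When a shard's length is not a multiple of `block_len`, the per-shard block boundaries differ from those of the gathered tensor, which is why the baseline gathers in high precision and quantizes once afterwards.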
def _all_gather_mxfp8(
inp: torch.Tensor,
process_group: dist_group_type,
@@ -1075,7 +1144,9 @@ def gather_along_first_dim(
async_op: bool = False,
quantizer: Optional[Quantizer] = None,
) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]:
"""
All-gather tensors and concatenate along first dimension.
"""
# Return immediately if no communication is required
world_size = get_distributed_world_size(process_group)
@@ -1100,6 +1171,16 @@ def gather_along_first_dim(
out_shape=out_shape,
)
# FP8 block scaling case, block length = 128
if isinstance(inp, Float8BlockwiseQTensorBase) or isinstance(quantizer, Float8BlockQuantizer):
return _all_gather_fp8_blockwise(
inp,
process_group,
async_op=async_op,
quantizer=quantizer,
out_shape=out_shape,
)
# MXFP8 case
if isinstance(inp, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer):
assert isinstance(quantizer, MXFP8Quantizer)
...
@@ -6,6 +6,7 @@
from __future__ import annotations
import abc
import itertools
import os
from contextlib import contextmanager
from collections import deque
@@ -19,6 +20,7 @@ from transformer_engine.common.recipe import (
Format,
MXFP8BlockScaling,
Float8CurrentScaling,
Float8BlockScaling,
)
from .constants import dist_group_type
@@ -49,6 +51,17 @@ def check_mxfp8_support() -> Tuple[bool, str]:
return False, "Device compute capability 10.0 or higher required for MXFP8 execution."
def check_fp8_block_scaling_support() -> Tuple[bool, str]:
"""Return if fp8 block scaling support is available."""
cuda_version = tuple(int(part) for part in torch.version.cuda.split(".")[:2])
if (9, 0) <= get_device_compute_capability() < (10, 0) and cuda_version >= (12, 9):
return True, ""
return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9."
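One caveat with this version gate: CUDA versions compare reliably as integer tuples but not as floats, since `float("12.10")` parses as 12.1 and sorts before 12.9. A small helper sketch (the name `cuda_version_tuple` is hypothetical):

```python
def cuda_version_tuple(version: str) -> tuple:
    """Parse a CUDA version string into comparable (major, minor) integers."""
    return tuple(int(part) for part in version.split(".")[:2])
```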
def get_default_fp8_recipe() -> Recipe:
"""FP8 recipe with default args."""
if get_device_compute_capability() >= (10, 0):  # blackwell and above
@@ -109,6 +122,8 @@ class FP8GlobalStateManager:
skip_fp8_weight_update_tensor = None
mxfp8_available = None
reason_for_no_mxfp8 = ""
fp8_block_scaling_available = None
reason_for_no_fp8_block_scaling = ""
@classmethod
def reset(cls) -> None:
@@ -134,6 +149,8 @@ class FP8GlobalStateManager:
cls.skip_fp8_weight_update_tensor = None
cls.mxfp8_available = None
cls.reason_for_no_mxfp8 = ""
cls.fp8_block_scaling_available = None
cls.reason_for_no_fp8_block_scaling = ""
@classmethod
def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None:
@@ -161,6 +178,15 @@ class FP8GlobalStateManager:
cls.mxfp8_available, cls.reason_for_no_mxfp8 = check_mxfp8_support()
return cls.mxfp8_available, cls.reason_for_no_mxfp8
@classmethod
def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]:
"""Return if Float8 block scaling support is available."""
if cls.fp8_block_scaling_available is None:
cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling = (
check_fp8_block_scaling_support()
)
return cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling
@staticmethod
def get_meta_tensor_key(forward: bool = True) -> str:
"""Returns scaling key in `fp8_meta`."""
@@ -434,6 +460,9 @@ class FP8GlobalStateManager:
if isinstance(fp8_recipe, MXFP8BlockScaling):
mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available()
assert mxfp8_available, reason_for_no_mxfp8
if isinstance(fp8_recipe, Float8BlockScaling):
fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available()
assert fp8_block_available, reason_for_no_fp8_block
@classmethod
def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None:
@@ -786,8 +815,10 @@ class RecipeState(abc.ABC):
cls = MXFP8BlockScalingRecipeState
elif recipe.float8_current_scaling():
cls = Float8CurrentScalingRecipeState
elif recipe.float8_block_scaling():
cls = Float8BlockScalingRecipeState
else:
raise ValueError(f"{recipe.__class__.__name__} is not supported")
return cls(
recipe,
mode=mode,
@@ -928,3 +959,108 @@ class MXFP8BlockScalingRecipeState(RecipeState):
from .tensor.mxfp8_tensor import MXFP8Quantizer
return [MXFP8Quantizer(self.dtype) for i in range(self.num_quantizers)]
class Float8BlockScalingRecipeState(RecipeState):
"""Configuration for Float8BlockScaling quantization.
Float8BlockScaling quantization does not require persistent state,
but the forward and backward modes construct differently configured quantizers.
"""
recipe: Float8BlockScaling
mode: str
qx_dtype: tex.DType
qw_dtype: tex.DType
qgrad_dtype: tex.DType
def __init__(
self,
recipe: Float8BlockScaling,
*,
mode: str,
num_quantizers: int = 1,
device: Optional[torch.device] = None,
) -> None:
self.recipe = recipe
self.mode = mode
self.num_quantizers = num_quantizers
self.qx_dtype = get_fp8_te_dtype(recipe, True)
self.qw_dtype = get_fp8_te_dtype(recipe, True)
self.qgrad_dtype = get_fp8_te_dtype(recipe, False)
# No buffers to allocate; just remember the device for quantizer construction.
if device is None:
device = torch.device("cuda")
self.device = device
def make_quantizers(self) -> list:
"""Construct blockwise quantizers for this recipe state."""
# TODO(ksivamani): Find a better design for this; imported here to avoid a circular import.
from .tensor.float8_blockwise_tensor import Float8BlockQuantizer
if self.mode == "forward":
# The index convention (coming from base.py set_meta_tensor)
# is somewhat awkward, and doesn't play nicely with QuantizeOp,
# which is not associated with a GEMM.
assert self.num_quantizers % 3 == 0 # x, w, output per gemm
return list(
itertools.chain.from_iterable(
[
[
Float8BlockQuantizer(
fp8_dtype=self.qx_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale,
block_scaling_dim=self.recipe.x_block_scaling_dim,
),
Float8BlockQuantizer(
fp8_dtype=self.qw_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale,
block_scaling_dim=self.recipe.w_block_scaling_dim,
),
Float8BlockQuantizer(
fp8_dtype=self.qx_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale,
block_scaling_dim=self.recipe.x_block_scaling_dim,
),
]
for _ in range(self.num_quantizers // 3)
]
)
)
assert self.mode == "backward", f"Unexpected mode {self.mode}"
assert self.num_quantizers % 2 == 0 # grad_output and grad_input per gemm
return list(
itertools.chain.from_iterable(
[
[
Float8BlockQuantizer(
fp8_dtype=self.qgrad_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale,
block_scaling_dim=self.recipe.grad_block_scaling_dim,
),
Float8BlockQuantizer(
fp8_dtype=self.qgrad_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale,
block_scaling_dim=self.recipe.grad_block_scaling_dim,
),
]
for _ in range(self.num_quantizers // 2)
]
)
)
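The interleaving above yields quantizers grouped as (x, weight, output) per GEMM in forward mode and (grad_output, grad_input) per GEMM in backward mode. The same `itertools.chain.from_iterable` pattern, sketched with labels in place of quantizer objects:

```python
import itertools


def make_labels(mode: str, num_quantizers: int) -> list:
    """Sketch of the quantizer layout: same flattening pattern, labels instead of quantizers."""
    if mode == "forward":
        assert num_quantizers % 3 == 0  # x, w, output per GEMM
        group = ["x", "w", "out"]
        return list(itertools.chain.from_iterable(group for _ in range(num_quantizers // 3)))
    assert mode == "backward", f"Unexpected mode {mode}"
    assert num_quantizers % 2 == 0  # grad_output and grad_input per GEMM
    group = ["grad_out", "grad_in"]
    return list(itertools.chain.from_iterable(group for _ in range(num_quantizers // 2)))
```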
@@ -23,6 +23,7 @@ from ..fp8 import (
MXFP8BlockScalingRecipeState,
DelayedScalingRecipeState,
Float8CurrentScalingRecipeState,
Float8BlockScalingRecipeState,
FP8GlobalStateManager,
RecipeState,
)
@@ -34,8 +35,10 @@ from ..distributed import (
)
from ..constants import dist_group_type
from ..tensor import QuantizedTensor, Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
__all__ = ["initialize_ub", "destroy_ub"]
@@ -516,6 +519,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
recipe_state, Float8CurrentScalingRecipeState
):
return
if recipe.float8_block_scaling() and isinstance(
recipe_state, Float8BlockScalingRecipeState
):
return
# Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
# 2 (grad_output and grad_input) for bwd
@@ -858,7 +865,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if ctx.ub_overlap_ag:
# Quantize the gradient if needed
if not isinstance(
grad_output,
(
QuantizedTensor,
Float8TensorBase,
MXFP8TensorBase,
Float8BlockwiseQTensorBase,
),
):
grad_output = quantizer(grad_output)
@@ -876,11 +889,21 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# FP8 without all-gather: fused bgrad + cast + transpose
grad_bias = None
if ctx.use_bias:
if isinstance(
grad_output,
(QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase),
):
grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0)
else:
if isinstance(quantizer, Float8BlockQuantizer):
# Unfuse bgrad for now, until fused cast_transpose + bgrad is ready for Float8BlockQuantizer.
grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
else:
grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer)
if not isinstance(
grad_output,
(QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase),
):
grad_output = quantizer(grad_output)
return grad_output, grad_bias
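The unfused fallback computes the bias gradient by flattening every leading dimension and summing rows, i.e. `grad_output.view(-1, D).sum(dim=0)`. The same reduction in plain Python for an (N, D) nested list, as an illustrative sketch:

```python
def bias_grad(grad_output):
    """Column-wise sum over all rows: the plain-Python analogue of
    grad_output.view(-1, D).sum(dim=0) for a nested-list 'tensor'."""
    n_cols = len(grad_output[0])
    return [sum(row[c] for row in grad_output) for c in range(n_cols)]
```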
...
@@ -91,6 +91,8 @@ class _GroupedLinear(torch.autograd.Function):
# TODO Support Float8 Current Scaling  # pylint: disable=fixme
if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling():
raise NotImplementedError("GroupedLinear does not yet support Float8 Current Scaling")
if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling():
raise NotImplementedError("GroupedLinear does not yet support Float8BlockScaling")
# Make sure input dimensions are compatible
in_features = weights[0].shape[-1]
...
@@ -57,9 +57,11 @@ from ..tensor.quantized_tensor import (
restore_from_saved,
)
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
from ..cpp_extensions import (
general_gemm,
)
@@ -138,11 +140,6 @@ class _LayerNormLinear(torch.autograd.Function):
ln_bias = cast_if_needed(ln_bias, activation_dtype)
nvtx_range_pop(f"{nvtx_label}.norm_input_cast")
tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_ag_fprop = (
ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output
@@ -175,6 +172,18 @@ class _LayerNormLinear(torch.autograd.Function):
columnwise_usage = False
input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
# Avoid quantized norm kernel if norm output will be returned
# or if a gather of ln_out must be in high precision.
force_hp_blockwise_ln_out_gather = (
fp8 and with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer)
) # Perform TP communication in high precision.
with_quantized_norm = (
fp8
and not return_layernorm_output
and not return_layernorm_output_gathered
and not force_hp_blockwise_ln_out_gather
)
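The fused quantized-norm kernel is now skipped in three cases: the norm output is returned (in either form), or a blockwise quantizer forces the TP gather into high precision. The condition, extracted as a standalone sketch (function and parameter names are illustrative):

```python
def use_quantized_norm(fp8, return_ln_out, return_ln_out_gathered, force_hp_gather):
    """Mirror of the with_quantized_norm condition: fuse quantization into the
    norm kernel only when no high-precision copy of ln_out is needed."""
    return fp8 and not return_ln_out and not return_ln_out_gathered and not force_hp_gather
```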
# Apply normalization
nvtx_range_push(f"{nvtx_label}.norm")
ln_out, mu, rsigma = apply_normalization(
@@ -211,7 +220,7 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total = input_quantizer(ln_out_total)
else:
if fp8:
if not with_quantized_norm and not force_hp_blockwise_ln_out_gather:
ln_out = input_quantizer(ln_out)
input_quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag_fprop:
@@ -317,6 +326,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.ln_out_needs_gather = (
weight.requires_grad and parallel_mode == "column" and sequence_parallel
)
ctx.force_hp_blockwise_ln_out_gather = force_hp_blockwise_ln_out_gather
# Input with column-wise usage is needed for wgrad GEMM.
if backward_needs_input:
@@ -327,6 +337,10 @@ class _LayerNormLinear(torch.autograd.Function):
if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather:
ln_out.update_usage(rowwise_usage=False)
# When force_hp_blockwise_ln_out_gather is set, the unquantized ln_out
# must be saved to ctx instead, so this path should be unreachable.
assert not force_hp_blockwise_ln_out_gather
# Weight with column-wise usage is needed for dgrad GEMM.
if isinstance(weightmat, QuantizedTensor):
weightmat.update_usage(columnwise_usage=True)
@@ -605,11 +619,14 @@ class _LayerNormLinear(torch.autograd.Function):
# wgrad GEMM requires input with column-wise usage
quantizer.set_usage(rowwise=False, columnwise=True)
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
# async_op is not compatible with high precision gather since
# gather_along_first_dim does not offer callback chaining.
gather_quantizer = None if ctx.force_hp_blockwise_ln_out_gather else quantizer
ln_out_total, ln_out_total_work = gather_along_first_dim( ln_out_total, ln_out_total_work = gather_along_first_dim(
ln_out, ln_out,
ctx.tp_group, ctx.tp_group,
async_op=True, async_op=True,
quantizer=quantizer, quantizer=gather_quantizer,
) )
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input") nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
else: else:
...@@ -690,6 +707,13 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -690,6 +707,13 @@ class _LayerNormLinear(torch.autograd.Function):
if ln_out_total_work is not None: if ln_out_total_work is not None:
ln_out_total_work.wait() ln_out_total_work.wait()
ln_out_total_work = None ln_out_total_work = None
if ctx.input_quantizer is not None and not isinstance(
ln_out_total, QuantizedTensor
):
# Async gather may have been done in BF16
# call quantizer after gather.
ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
ln_out_total = ctx.input_quantizer(ln_out_total)
# Make sure GEMM inputs have required data # Make sure GEMM inputs have required data
if isinstance(ln_out_total, QuantizedTensor): if isinstance(ln_out_total, QuantizedTensor):
......
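The `force_hp_blockwise_ln_out_gather` path above defers quantization until after the tensor-parallel gather. For intuition, here is a minimal pure-Python sketch of blockwise quantization with power-of-two scales, in the spirit of this recipe. It is a hedged simplification: the real recipe uses 128-element blocks (`BLOCK` is shrunk to 4 here), the actual FP8 cast is omitted, and 448.0 is E4M3's maximum representable magnitude.

```python
import math

FP8_MAX = 448.0   # E4M3 max representable magnitude
BLOCK = 4         # 128 in the real recipe; small here for illustration


def quantize_blockwise(x):
    """Split x into BLOCK-sized chunks, scaling each by a power of two."""
    assert len(x) % BLOCK == 0
    data, scale_inv = [], []
    for i in range(0, len(x), BLOCK):
        block = x[i:i + BLOCK]
        amax = max(abs(v) for v in block) or 1.0
        # Largest power-of-two scale such that amax * scale <= FP8_MAX.
        scale = 2.0 ** math.floor(math.log2(FP8_MAX / amax))
        data.extend(v * scale for v in block)  # fp8 cast omitted
        scale_inv.append(1.0 / scale)
    return data, scale_inv


def dequantize_blockwise(data, scale_inv):
    """Undo the per-block scaling."""
    return [v * scale_inv[i // BLOCK] for i, v in enumerate(data)]
```

Because scaling by a power of two is exact in binary floating point, the round trip here is lossless; the real recipe additionally rounds to the FP8 grid, which is where quantization error enters.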
@@ -52,7 +52,6 @@ from ..distributed import (
    in_fp8_activation_recompute_phase,
    _fsdp_scatter_tensors,
)
from ..constants import dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
@@ -62,6 +61,7 @@ from ..tensor.float8_tensor import (
    Float8Tensor,
)
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ._common import apply_normalization, _fix_gathered_fp8_transpose
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
from ..tensor.quantized_tensor import (
@@ -104,17 +104,19 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
            "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu),
        }
    # no activation fusion written yet
    # Per-tensor current scaling or fp8 blockwise scaling: []
    if recipe.float8_current_scaling() or recipe.float8_block_scaling():
        return {
            "gelu": (tex.gelu, tex.dgelu, None),
            "relu": (tex.relu, tex.drelu, None),
            "geglu": (tex.geglu, tex.dgeglu, None),
            "reglu": (tex.reglu, tex.dreglu, None),
            "swiglu": (tex.swiglu, tex.dswiglu, None),
            "qgelu": (tex.qgelu, tex.dqgelu, None),
            "qgeglu": (tex.qgeglu, tex.dqgeglu, None),
            "srelu": (tex.srelu, tex.dsrelu, None),
        }
    raise NotImplementedError(f"Unhandled recipe type {recipe}")


def _act_func(activation: str, recipe: Optional[Recipe] = None):
@@ -122,7 +124,7 @@ def _act_func(activation: str, recipe: Optional[Recipe] = None):
    # bf16 (recipe is None): [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
    # Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
    # MXFP8: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
    # Per-tensor current scaling or fp8 blockwise scaling: []
    funcs = _get_act_func_supported_list(recipe)
    if activation not in funcs:
        raise NotImplementedError("Activation type " + activation + " is not supported!")
@@ -214,12 +216,20 @@ class _LayerNormMLP(torch.autograd.Function):
        with_quantized_norm = (
            fp8 and not return_layernorm_output and not return_layernorm_output_gathered
        )
        if isinstance(fc1_input_quantizer, Float8BlockQuantizer):
            # Kernels not available for norm fusion.
            with_quantized_norm = False

        tp_world_size = get_distributed_world_size(tp_group)
        ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered
        ub_overlap_rs = ub_overlap_rs and is_grad_enabled
        backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad
        # TODO(kwyss): Support FP8 allgather of Float8BlockQuantizer recipe
        force_hp_fc1_input_gather = (
            fp8 and sequence_parallel and isinstance(fc1_input_quantizer, Float8BlockQuantizer)
        )  # Perform TP communication in high precision.

        # Configure quantizer for norm output
        if fp8:
            if fc1_input_quantizer is None:
@@ -261,12 +271,13 @@ class _LayerNormMLP(torch.autograd.Function):
                ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
                ln_out_return = ln_out_total
                if fp8:
                    if not force_hp_fc1_input_gather:
                        ln_out = fc1_input_quantizer(ln_out)
                    fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
                    ln_out_total = fc1_input_quantizer(ln_out_total)
            else:
                if fp8:
                    if not with_quantized_norm and not force_hp_fc1_input_gather:
                        ln_out = fc1_input_quantizer(ln_out)
                    fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
                if ub_overlap_ag:
@@ -282,7 +293,10 @@ class _LayerNormMLP(torch.autograd.Function):
                        quantizer=(fc1_input_quantizer if fp8 else None),
                    )
                else:
                    # NOTE: force_hp_fc1_input_gather is redundant with else, but
                    # here for clarity. We should not quantize ln_out if bwd needs
                    # to gather in hp.
                    if fp8 and not with_quantized_norm and not force_hp_fc1_input_gather:
                        ln_out = fc1_input_quantizer(ln_out)
                    ln_out_total = ln_out
@@ -336,6 +350,7 @@ class _LayerNormMLP(torch.autograd.Function):
        # - bias_gelu_fusion - only for full precision.
        # If both gemm_gelu_fusion and bias_gelu_fusion are enabled, only bias_gelu_fusion will be performed
        if activation != "gelu":
            # blockwise scaled gemms don't support gemm_gelu_fusion in fwd.
            gemm_gelu_fusion = bias_gelu_fusion = False
        else:
            if fp8:
@@ -376,7 +391,12 @@ class _LayerNormMLP(torch.autograd.Function):
                act_out, _, fc1_out, _ = fc1_outputs
            else:
                fc1_out, *_ = fc1_outputs
                if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling():
                    # tex.quantize does not support GELU fusion for blockwise.
                    act_out = activation_func(fc1_out, None)
                    act_out = tex.quantize(act_out, fc2_input_quantizer)
                else:
                    act_out = activation_func(fc1_out, fc2_input_quantizer)

            if not is_grad_enabled:
                clear_tensor_data(fc1_out)
@@ -462,6 +482,8 @@ class _LayerNormMLP(torch.autograd.Function):
                if not return_layernorm_output:
                    clear_tensor_data(ln_out)
                    ln_out = None
                elif force_hp_fc1_input_gather:
                    assert not isinstance(ln_out, QuantizedTensor)
                if not fc2_weight.requires_grad:
                    clear_tensor_data(act_out)
                    act_out = None
@@ -490,6 +512,7 @@ class _LayerNormMLP(torch.autograd.Function):
            ctx.tensor_objects = tensor_objects

            ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
            ctx.force_hp_fc1_input_gather = force_hp_fc1_input_gather
            ctx.grad_fc1_output_quantizer = grad_fc1_output_quantizer
            ctx.grad_fc2_output_quantizer = grad_fc2_output_quantizer
            ctx.grad_input_quantizer = grad_input_quantizer
@@ -505,6 +528,7 @@ class _LayerNormMLP(torch.autograd.Function):
            ctx.activation_dtype = activation_dtype
            ctx.activation = activation
            ctx.fp8 = fp8
            ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            ctx.cpu_offloading = cpu_offloading
            ctx.is_first_microbatch = is_first_microbatch
@@ -696,11 +720,12 @@ class _LayerNormMLP(torch.autograd.Function):
                else:
                    # wgrad GEMM requires input with column-wise usage
                    quantizer.set_usage(rowwise=False, columnwise=True)
                gather_quantizer = None if ctx.force_hp_fc1_input_gather else quantizer
                ln_out_total, ln_out_total_work = gather_along_first_dim(
                    ln_out,
                    ctx.tp_group,
                    async_op=True,
                    quantizer=gather_quantizer,
                )
            else:
                ln_out_total = ln_out
@@ -712,12 +737,13 @@ class _LayerNormMLP(torch.autograd.Function):
            )
        else:
            accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

        # There are 6 possible fusion paths
        # 1 high-precision bias_gelu_fusion: gemm, FC1_bias + gelu,
        # 2 high-precision fc2_dgrad_gemm_gelu_fusion: gemm + gelu, FC1_bias + quantize
        # 3 fp8 activation+bias+quantize fusion: gemm, activation + FC1_bias + quantize
        # 4 fp8 bias+quantize fusion: gemm, activation, FC1_bias + quantize
        # 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm
        # 6 fp8 unfused: gemm, activation, FC1_bias + FC1_gemm
        fc2_dgrad_gemm_gelu_fusion = (
            not ctx.fp8 and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion)
        )
@@ -753,6 +779,9 @@ class _LayerNormMLP(torch.autograd.Function):
            if isinstance(grad_output, QuantizedTensor):
                grad_output.update_usage(rowwise_usage=True, columnwise_usage=True)
            grad_arg = True
            if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling():
                grad_arg = False
            fc2_wgrad, fc2_bias_grad_, *_ = general_gemm(
                act_out,
                grad_output,
@@ -764,14 +793,18 @@ class _LayerNormMLP(torch.autograd.Function):
                ),
                quantization_params=None,  # wgrad in high precision
                layout="NT",
                grad=grad_arg,
                bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
                accumulate=accumulate_wgrad_into_param_main_grad,
                use_split_accumulator=_2X_ACC_WGRAD,
                out=origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
            )

            if fc2_bias_grad is None:
                if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling() and fc2_bias is not None:
                    # BGRAD not fused with GEMM for float8 blockwise gemm.
                    fc2_bias_grad_ = act_out.view(-1, act_out.shape[-1]).sum(dim=0)
                fc2_bias_grad = fc2_bias_grad_
            del fc2_bias_grad_
            clear_tensor_data(act_out)

        # bias computation
@@ -808,7 +841,14 @@ class _LayerNormMLP(torch.autograd.Function):
                )  # activation in high precision

                if ctx.fp8:
                    # TODO: float8 blockwise current scaling has no bgrad fusion for now
                    if isinstance(ctx.grad_fc1_output_quantizer, Float8BlockQuantizer):
                        fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0)
                        dact = ctx.grad_fc1_output_quantizer(dact)
                    else:
                        fc1_bias_grad, dact = tex.bgrad_quantize(
                            dact, ctx.grad_fc1_output_quantizer
                        )
                else:
                    fuse_gemm_and_bias_fc1_wgrad = (
                        True  # fc1_bias_grad is computed later, fused with wgrad gemm for the FC1
@@ -904,6 +944,13 @@ class _LayerNormMLP(torch.autograd.Function):
            if ln_out_total_work is not None:
                ln_out_total_work.wait()
                ln_out_total_work = None
            if ctx.fc1_input_quantizer is not None and not isinstance(
                ln_out_total, QuantizedTensor
            ):
                # The async gather was done in high precision, so
                # quantize after the gather completes.
                ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True)
                ln_out_total = ctx.fc1_input_quantizer(ln_out_total)

            # Make sure GEMM inputs have required data
            if isinstance(ln_out_total, QuantizedTensor):
@@ -1556,7 +1603,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
        fc1_weight_quantizer.internal = True
        fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
        fc2_input_quantizer.set_usage(
            rowwise=True,
            columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)),
        )
        fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
        fc2_weight_quantizer.internal = True
......
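Since BGRAD is not fused into the blockwise GEMM, the bias gradient above falls back to a plain reduction over the flattened leading dimensions, mirroring `view(-1, hidden).sum(dim=0)`. A pure-Python sketch of that reduction on a flat buffer (function name and shapes are illustrative, not from the PR):

```python
def bias_grad(grad_output, hidden):
    """Flatten grad_output into rows of length `hidden`, then sum over rows.

    This is the unfused equivalent of view(-1, hidden).sum(dim=0).
    """
    assert len(grad_output) % hidden == 0
    rows = [grad_output[i:i + hidden] for i in range(0, len(grad_output), hidden)]
    return [sum(col) for col in zip(*rows)]
```

The trade-off is one extra reduction kernel launch versus the fused epilogue available to the other recipes.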
@@ -60,9 +60,10 @@ from ..tensor.quantized_tensor import (
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param

__all__ = ["Linear"]
@@ -130,6 +131,10 @@ class _Linear(torch.autograd.Function):
            parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
        )
        own_quantized_input = False
        # TODO(kwyss): Support FP8 allgather for FP8 block quantization.
        force_hp_input_gather = (
            fp8 and with_input_all_gather_nccl and isinstance(input_quantizer, Float8BlockQuantizer)
        )  # Perform TP communication in high precision.
        if fp8:
            assert_dim_for_fp8_exec(inputmat, weight)
            if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not (
@@ -143,19 +148,27 @@ class _Linear(torch.autograd.Function):
            if input_quantizer is None:
                raise ValueError("Missing quantizer for input tensor")
            if with_input_all_gather_nccl:
                if force_hp_input_gather:
                    input_quantizer.set_usage(rowwise=True, columnwise=False)
                    inputmat_total, _ = gather_along_first_dim(
                        inputmat, tp_group, quantizer=input_quantizer
                    )
                else:
                    if not isinstance(inputmat, QuantizedTensor):
                        columnwise_usage = backward_needs_input and isinstance(
                            input_quantizer, MXFP8Quantizer
                        )
                        # force_hp_input_gather should enforce this
                        assert not isinstance(input_quantizer, Float8BlockQuantizer)
                        input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
                        inputmat = input_quantizer(inputmat)
                        own_quantized_input = True
                    input_quantizer.set_usage(rowwise=True, columnwise=False)
                    inputmat_total, _ = gather_along_first_dim(
                        inputmat,
                        tp_group,
                        quantizer=input_quantizer,
                    )
            else:
                if (
                    FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
@@ -277,6 +290,8 @@ class _Linear(torch.autograd.Function):
            # can be allgathered.
            if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather:
                inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
            if force_hp_input_gather:
                assert not isinstance(inputmat, QuantizedTensor)
            saved_inputmat = inputmat

        # Weight with column-wise usage is needed for dgrad GEMM.
@@ -323,8 +338,9 @@ class _Linear(torch.autograd.Function):
        ctx.tensor_objects = tensor_objects

        ctx.activation_dtype = activation_dtype
        ctx.fp8 = fp8
        ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
        ctx.force_hp_input_gather = force_hp_input_gather
        ctx.input_quantizer = input_quantizer
        ctx.grad_output_quantizer = grad_output_quantizer
        ctx.grad_input_quantizer = grad_input_quantizer
@@ -520,11 +536,12 @@ class _Linear(torch.autograd.Function):
                # wgrad GEMM requires input with column-wise usage
                quantizer.set_usage(rowwise=False, columnwise=True)
            nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
            gather_quantizer = None if ctx.force_hp_input_gather else quantizer
            inputmat_total, inputmat_total_work = gather_along_first_dim(
                inputmat,
                ctx.tp_group,
                async_op=True,
                quantizer=gather_quantizer,
            )
            nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
        else:
@@ -610,6 +627,13 @@ class _Linear(torch.autograd.Function):
            if inputmat_total_work is not None:
                inputmat_total_work.wait()
                inputmat_total_work = None
            if ctx.input_quantizer is not None and not isinstance(
                inputmat_total, QuantizedTensor
            ):
                # The async gather was done in BF16, so
                # quantize after the gather completes.
                ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
                inputmat_total = ctx.input_quantizer(inputmat_total)

            # Make sure GEMM inputs have required data
            if isinstance(inputmat_total, QuantizedTensor):
......
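The reason `_Linear` gathers in high precision for `Float8BlockQuantizer` can be seen with a toy example: per-block scales computed on independently quantized shards are generally not the scales of the gathered tensor, so quantization has to happen after the allgather. A hedged sketch (the helper and the sample values are illustrative only; 448.0 is E4M3's maximum magnitude):

```python
import math


def pow2_block_scales(x, block):
    """Per-block power-of-two scales for a 1-D tensor (FP8 max 448)."""
    scales = []
    for i in range(0, len(x), block):
        amax = max(abs(v) for v in x[i:i + block]) or 1.0
        scales.append(2.0 ** math.floor(math.log2(448.0 / amax)))
    return scales


shard_a = [0.1, 0.2, 8.0, 0.3]   # held by TP rank 0
shard_b = [0.4, 16.0, 0.5, 0.6]  # held by TP rank 1
gathered = shard_a + shard_b

# A scale block spanning the shard boundary sees the amax of both shards,
# so per-shard scales cannot simply be concatenated.
assert pow2_block_scales(gathered, 8) != (
    pow2_block_scales(shard_a, 8) + pow2_block_scales(shard_b, 8)
)
```

Supporting a true FP8 allgather for this recipe (the TODO in the diff) would require block boundaries to align with shard boundaries, as they do for MXFP8's 32-element blocks.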
@@ -23,6 +23,7 @@ from ...fp8 import FP8GlobalStateManager
from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
from ...tensor import Quantizer, QuantizedTensor
from ...tensor.float8_tensor import Float8Quantizer
from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ..op import BasicOperation, OperationContext
@@ -483,6 +484,12 @@ class BasicLinear(BasicOperation):
                    "Attempting to generate MXFP8 output tensor, "
                    "but GEMM with MXFP8 output is not supported"
                )
            if isinstance(output_quantizer, Float8BlockQuantizer):
                raise RuntimeError(
                    "Attempting to generate Float8BlockQuantized output tensor, "
                    "but GEMM with Float8BlockQuantized output is not supported"
                )
            if output_quantizer is not None:
                output_quantizer.set_usage(rowwise=True, columnwise=False)
......
@@ -17,6 +17,7 @@ from transformer_engine.common.recipe import Recipe
from ..fp8 import (
    MXFP8BlockScalingRecipeState,
    DelayedScalingRecipeState,
    Float8BlockScalingRecipeState,
    FP8GlobalStateManager,
    RecipeState,
    fp8_autocast,
@@ -219,6 +220,11 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
            if num_quantizers == 0:
                continue
            if recipe.float8_block_scaling():
                raise NotImplementedError(
                    "Fusible operations do not support FP8 block scaling recipe"
                )

            # Construct quantization recipe state
            recipe_state = RecipeState.create(
                recipe,
@@ -260,8 +266,13 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
                continue
            recipe_state = self._fp8_metas[mode][fp8_meta_key]
            need_to_reset_recipe_state = (
                (recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState))
                or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState))
                or (
                    recipe.float8_block_scaling()
                    and not isinstance(recipe_state, Float8BlockScalingRecipeState)
                )
            )
            if need_to_reset_recipe_state:
                self._reset_quantization_recipe_state(recipe=recipe)
                return
......
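The reset predicate above generalizes to: each recipe kind expects a matching cached state type, and a mismatch forces a rebuild of the quantization state. A simplified stand-in illustrating the pattern (class names and the string keys are hypothetical, not Transformer Engine API):

```python
class DelayedState:
    pass


class MXFP8State:
    pass


class BlockScalingState:
    pass


def need_to_reset(recipe_kind, state):
    """True when the cached recipe state no longer matches the active recipe."""
    expected = {
        "delayed": DelayedState,
        "mxfp8": MXFP8State,
        "float8_block_scaling": BlockScalingState,
    }[recipe_kind]
    return not isinstance(state, expected)
```

Keeping the check as a type test means switching recipes inside `fp8_autocast` transparently discards stale amax histories or scale buffers.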
@@ -36,8 +36,8 @@ class Float8BlockwiseQTensorBase:
    def __new__(
        cls,
        *args,
        rowwise_data: Optional[torch.Tensor],
        rowwise_scale_inv: Optional[torch.Tensor],
        columnwise_data: Optional[torch.Tensor],
        columnwise_scale_inv: Optional[torch.Tensor],
        fp8_dtype: TE_DType,
@@ -71,10 +71,16 @@ class Float8BlockwiseQTensorBase:
    def prepare_for_saving(
        self,
    ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]:
        """
        Prepare the tensor base for saving for backward.

        This does not clear the tensors currently, because a PP config
        exists that clears the weight cache between micro-batches. If the
        rowwise data is not required for backward, this is a possible
        memory pessimization, but it is consistent with the other
        quantized tensor classes.
        """
        tensors = [self._rowwise_data, self._columnwise_data]
        return tensors, self

    def restore_from_saved(
......
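The `prepare_for_saving`/`restore_from_saved` pair follows a simple protocol: return the raw tensors for autograd's save list together with the shell object, and on restore consume the same number of entries from the front of the saved list. A stripped-down stand-in showing the protocol and the "keep references" choice discussed in the docstring (this is an illustrative class, not the real `Float8BlockwiseQTensorBase`):

```python
class BlockwiseTensorShell:
    """Illustrative stand-in for a quantized tensor base class."""

    def __init__(self, rowwise_data, columnwise_data):
        self._rowwise_data = rowwise_data
        self._columnwise_data = columnwise_data

    def prepare_for_saving(self):
        # References are kept (not set to None) so cached weights
        # survive between micro-batches under pipeline parallelism.
        return [self._rowwise_data, self._columnwise_data], self

    def restore_from_saved(self, tensors):
        # Consume this tensor's entries; return the rest for the caller.
        self._rowwise_data, self._columnwise_data = tensors[0], tensors[1]
        return tensors[2:]
```

Returning the remainder of the list lets several quantized tensors chain their restores over one flat saved-tensor list.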