Unverified Commit a8f0fe03 authored by kwyss-nvidia, committed by GitHub

Blockwise scaling linear quantization recipe (#1559)



* Add GEMM logic for blockwise quantized tensors.

GEMM test cases included in pytorch integration.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update NVTE_BLOCK_SCALING for GEMM.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Gate feature on CUDA 12.9
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Gemm typo.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Remove unnecessary type converter change.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Reflect epilogue availability and test supported epilogues.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* GEMM simplifications from recipe branch.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Format py code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update GEMM DGelu tests to match support depending on output dtype.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Force pow2Scales in GEMM
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add GEMM test to pytorch test suite.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add copyright to GEMM test.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update import for GEMM test.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add license.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update test gemm supported predicate.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Use sgemm like interfaces and naming.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Rewrite GEMM comment.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* MR Feedback.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Recipe setup for Linear modules.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Use 12.9 feature test.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Run against tensor dumps from internal library.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update FIXME to TODO with linked issue.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update full recompute feature to save recipe.

The recompute context uses the same recipe
and fp8 settings as the original fwd pass.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* MR Feedback. Avoid reusing quantizer objects.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update logic in module.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Format py.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update for PP bug.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update test numerics.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update force_power_of_2 scales in the recipe.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update usage method to satisfy upstream changes.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* fix subchannel recipe in distributed test with bf16 gather
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Edit and cleanup BF16 gather code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update test import.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* support columnwise only mode to 1D quantize kernel
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Format and move enum
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Skip alloc.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* try async bf16 gather
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Format python code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Document and type code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update pytorch lint errors.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Don't set high precision dtype.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add test for sanity and CG; fix CG for sequential?
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Keep make_quantizers API stable

Update num_quantizers instead to pass cuda_graph tests.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix import name.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Rename recipe method.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Skip grouped linear sanity test.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Set usage before BF16 gather.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* refactor for nvte_quantize_v2
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Format code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Cleanup nvte_quantize_v2
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Test fp32 scales.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Disable CUDA graph.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Simplify layernorm linear
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Cleanup layernorm linear.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* LayerNorm linear bwd gather logic.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Communication updates.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update transformer_engine/pytorch/ops/op.py

Apply MR comment change.
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: kwyss-nvidia <kwyss@nvidia.com>

* Lint fix.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* MR feedback.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Enable cuda graph tests.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Reduce chance of spurious failure and reword.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Review suggestions from @timmoon10
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update CPP tests.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update common.h
Signed-off-by: Xin Yao <yaox12@outlook.com>

* Update test_float8blockwisetensor.py
Signed-off-by: Xin Yao <yaox12@outlook.com>

---------
Signed-off-by: Keith Wyss <kwyss@nvidia.com>
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: kwyss-nvidia <kwyss@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Xin Yao <yaox12@outlook.com>
Co-authored-by: zhongboz <zhongboz@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Xin Yao <yaox12@outlook.com>
parent 0da60449
......@@ -16,11 +16,15 @@
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/transpose/cast_transpose.h"
#include "common/utils.cuh"
namespace transformer_engine {
namespace {
using transformer_engine::detail::FP8BlockwiseColumnwiseOption;
using transformer_engine::detail::FP8BlockwiseRowwiseOption;
// clang-format off
/*
......@@ -138,15 +142,17 @@ static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kT
static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp");
template <bool kAligned, typename CType, typename IType, typename OType>
__global__ void __launch_bounds__(kThreadsPerBlock)
block_scaled_1d_cast_transpose_kernel(const IType* const input, OType* const output_c,
OType* const output_t, CType* const tile_scales_inv_c,
CType* const tile_scales_inv_t, const size_t row_length,
const size_t num_rows, const size_t scale_stride_x,
const size_t scale_stride_y,
const size_t scale_t_stride_x,
const size_t scale_t_stride_y, const float epsilon,
bool return_transpose, bool pow_2_scaling) {
__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel(
const IType* const input, OType* const output_c, OType* const output_t,
CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length,
const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y,
const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option,
const bool pow_2_scaling) {
bool return_rowwise = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE;
bool return_columnwise_transpose =
columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE;
using SMemVec = Vec<IType, kNVecSMem>;
using OVec = Vec<OType, kNVecOut>;
union IVec {
......@@ -203,7 +209,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
__syncthreads();
// Step 2: Cast and store to output_c
{
if (return_rowwise) {
constexpr int r_stride =
kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory
constexpr int num_iterations = kTileDim / r_stride;
......@@ -294,7 +300,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
}
// Step 3: Transpose, cast and store to output_t
if (return_transpose) {
if (return_columnwise_transpose) {
constexpr int c_stride =
kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory
constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem);
......@@ -389,10 +395,15 @@ namespace transformer_engine::detail {
void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor& scale_inv,
SimpleTensor& scale_inv_t, SimpleTensor& output,
SimpleTensor& output_t, const float epsilon,
const bool return_transpose, const bool pow2_scale,
cudaStream_t stream) {
FP8BlockwiseRowwiseOption rowwise_option,
FP8BlockwiseColumnwiseOption columnwise_option,
const bool pow2_scale, cudaStream_t stream) {
NVTE_API_CALL(quantize_transpose_vector_blockwise);
NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
// assert that rowwise_option and columnwise_option are not both NONE
NVTE_CHECK(rowwise_option != FP8BlockwiseRowwiseOption::NONE ||
columnwise_option != FP8BlockwiseColumnwiseOption::NONE,
"rowwise_option and columnwise_option cannot both be NONE");
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
size_t num_elements = row_length;
......@@ -408,21 +419,24 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
}
// Options for scale layout of cuBLAS GEMM kernel.
NVTE_CHECK(input.shape.size() == output.shape.size(),
"Input and output must have the same shape.");
size_t scale_stride_x = 0;
size_t scale_stride_y = 0;
NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2.");
size_t scale_k = scale_inv.shape[1];
scale_stride_x = scale_k;
scale_stride_y = 1;
size_t scale_t_stride_x = 0;
size_t scale_t_stride_y = 0;
if (return_transpose) {
if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) {
NVTE_CHECK(rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE,
"Unexpected rowwise enum value");
NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2.");
size_t scale_k = scale_inv.shape[1];
scale_stride_x = scale_k;
scale_stride_y = 1;
}
if (columnwise_option != FP8BlockwiseColumnwiseOption::NONE) {
NVTE_CHECK(columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE,
"Unexpected columnwise enum value");
NVTE_CHECK(output_t.shape.size() == input.shape.size(),
"output_t must have same number of dimensions as input.");
if (output_t.shape.size() > 0) {
......@@ -469,10 +483,10 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, scale_stride_x,
scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, return_transpose,
pow2_scale);) // kAligned
) // OutputType
) // InputType
scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option,
columnwise_option, pow2_scale);) // kAligned
) // OutputType
) // InputType
NVTE_CHECK_CUDA(cudaGetLastError());
}
......
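
For orientation, the block below is a minimal PyTorch reference sketch (not part of the change) of what the rowwise 1x128 blockwise cast computes, under the assumptions that the block length is 128, the target format is E4M3 with a max-normal value of 448, and the last dimension is divisible by 128. The real kernel additionally produces the columnwise (transposed) output when COLUMNWISE_TRANSPOSE is requested and handles unaligned shapes.

```python
import torch

E4M3_MAX = 448.0   # assumed max-normal magnitude of torch.float8_e4m3fn
BLOCK = 128        # assumed 1D block length used by the recipe

def quantize_1x128_rowwise(x: torch.Tensor, epsilon: float = 0.0, pow_2_scales: bool = False):
    """Reference for the rowwise output: x ~= q * scale_inv within each 1x128 block."""
    rows, cols = x.shape
    blocks = x.reshape(rows, cols // BLOCK, BLOCK).float()
    # Per-block amax with an optional epsilon floor (tiny floor avoids division by zero).
    amax = blocks.abs().amax(dim=-1, keepdim=True)
    amax = amax.clamp_min(max(epsilon, torch.finfo(torch.float32).tiny))
    scale = E4M3_MAX / amax
    if pow_2_scales:
        scale = torch.exp2(torch.floor(torch.log2(scale)))  # round scale down to a power of 2
    q = (blocks * scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
    scale_inv = (1.0 / scale).squeeze(-1)                    # descaling factors consumed by the GEMM
    return q.reshape(rows, cols), scale_inv

x = torch.randn(256, 512)
q, s_inv = quantize_1x128_rowwise(x, pow_2_scales=True)
print(q.shape, s_inv.shape)  # torch.Size([256, 512]) torch.Size([256, 4])
```

The columnwise path applies the same per-128-element scaling along the other dimension and stores the result transposed, which is what the COLUMNWISE_TRANSPOSE option above requests.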
......@@ -35,8 +35,8 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, nullptr, output,
dbias, workspace, stream);
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, output, dbias,
workspace, nullptr, stream);
}
void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
......@@ -44,6 +44,18 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no
NVTE_API_CALL(nvte_quantize_noop);
using namespace transformer_engine;
// Create config with noop tensor
QuantizationConfig quant_config;
quant_config.noop_tensor = noop;
nvte_quantize_v2(input, output, reinterpret_cast<NVTEQuantizationConfig>(&quant_config), stream);
}
void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_v2);
using namespace transformer_engine;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = false;
......@@ -51,8 +63,8 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, noop, output,
dbias, workspace, stream);
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
input, grad, output, dbias, workspace, quant_config, stream);
}
void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
......@@ -66,7 +78,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d
constexpr const NVTETensor activation_input = nullptr;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
activation_input, input, nullptr, output, dbias, workspace, stream);
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
......@@ -80,7 +92,7 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dgelu<fp32, fp32>>(
activation_input, input, nullptr, output, dbias, workspace, stream);
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
......@@ -94,7 +106,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsilu<fp32, fp32>>(
activation_input, input, nullptr, output, dbias, workspace, stream);
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
......@@ -108,7 +120,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, drelu<fp32, fp32>>(
activation_input, input, nullptr, output, dbias, workspace, stream);
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
......@@ -122,7 +134,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dqgelu<fp32, fp32>>(
activation_input, input, nullptr, output, dbias, workspace, stream);
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
......@@ -136,7 +148,7 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsrelu<fp32, fp32>>(
activation_input, input, nullptr, output, dbias, workspace, stream);
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
......
......@@ -1215,9 +1215,9 @@ namespace detail {
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &)>
void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETensor noop,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor output,
NVTETensor dbias, NVTETensor workspace,
const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
const Tensor *input_tensor;
const Tensor *activation_input_tensor;
if constexpr (IS_DBIAS || IS_DACT) {
......@@ -1232,6 +1232,12 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe
auto output_tensor = reinterpret_cast<Tensor *>(output);
auto dbias_tensor = reinterpret_cast<Tensor *>(dbias);
auto workspace_tensor = reinterpret_cast<Tensor *>(workspace);
const QuantizationConfig *quant_config_cpp =
reinterpret_cast<const QuantizationConfig *>(quant_config);
// extract noop tensor from quant_config_cpp if it's not null
const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr;
const auto noop_tensor = noop != nullptr ? *(reinterpret_cast<const Tensor *>(noop)) : Tensor();
switch (output_tensor->scaling_mode) {
......@@ -1263,11 +1269,11 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe
// TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D");
constexpr bool force_pow_2_scales = true;
bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : true;
float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f;
quantize_transpose_square_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data,
/*epsilon=*/0.0,
output_tensor->data, output_tensor->columnwise_data, epsilon,
/*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream);
break;
}
......@@ -1275,12 +1281,18 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe
// TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D");
constexpr bool force_pow_2_scales = true;
quantize_transpose_vector_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data,
/*epsilon=*/0.0,
/*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream);
bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false;
float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f;
FP8BlockwiseRowwiseOption rowwise_option = output_tensor->has_data()
? FP8BlockwiseRowwiseOption::ROWWISE
: FP8BlockwiseRowwiseOption::NONE;
FP8BlockwiseColumnwiseOption columnwise_option =
output_tensor->has_columnwise_data() ? FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE
: FP8BlockwiseColumnwiseOption::NONE;
quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv,
output_tensor->columnwise_scale_inv, output_tensor->data,
output_tensor->columnwise_data, epsilon, rowwise_option,
columnwise_option, force_pow_2_scales, stream);
break;
}
default:
......
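
To make the new dispatch easier to follow, here is an illustrative Python rendering (an assumption-laden sketch, not generated code) of how the NVTE_BLOCK_SCALING_1D branch above picks its kernel options: the rowwise output is requested only when the tensor has rowwise data, the columnwise output is always the transposed layout, and the numerical options fall back to defaults when no config is supplied.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional

class Rowwise(Enum):
    NONE = auto()
    ROWWISE = auto()

class Columnwise(Enum):
    NONE = auto()
    COLUMNWISE_TRANSPOSE = auto()

@dataclass
class QuantConfig:                    # stand-in for the optional NVTEQuantizationConfig
    force_pow_2_scales: bool = False  # default for the 1D path when no config is given
    amax_epsilon: float = 0.0

def select_block_scaling_1d_options(has_data: bool, has_columnwise_data: bool,
                                    quant_config: Optional[QuantConfig] = None):
    cfg = quant_config or QuantConfig()
    rowwise = Rowwise.ROWWISE if has_data else Rowwise.NONE
    columnwise = Columnwise.COLUMNWISE_TRANSPOSE if has_columnwise_data else Columnwise.NONE
    # The kernel rejects the degenerate case where neither output is requested.
    assert rowwise is not Rowwise.NONE or columnwise is not Columnwise.NONE
    return rowwise, columnwise, cfg.force_pow_2_scales, cfg.amax_epsilon

print(select_block_scaling_1d_options(has_data=False, has_columnwise_data=True))
```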
......@@ -14,6 +14,7 @@ from ..utils import assert_dim_for_fp8_exec, get_sm_count
from ..tensor.quantized_tensor import Quantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
__all__ = [
"general_gemm",
......@@ -112,6 +113,10 @@ def general_gemm(
# Use bfloat16 as default bias_dtype
bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]
if isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase):
# There is no use_split_accumulator == False implementation
# for Float8BlockwiseQTensorBase GEMM.
use_split_accumulator = True
args = (
A,
transa, # transa
......
......@@ -167,13 +167,13 @@ class Float8BlockQuantizer : public Quantizer {
public:
// Which float8 type is used for q data.
DType dtype;
private:
// Options about how to quantize the tensor
// Quantization scales are rounded down to powers of 2.
bool force_pow_2_scales = false;
// Amax within quantization tile has a floor of epsilon.
float amax_epsilon = 0.0;
private:
int block_scaling_dim = 2;
public:
......
......@@ -50,7 +50,12 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int
nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream());
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
nvte_quantize(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
nvte_quantize_v2(te_output_act.data(), te_output.data(), quant_config,
at::cuda::getCurrentCUDAStream());
} else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
// Sanity check: activation fusion is not supported for blockwise quantization yet,
// so raise an error here instead of silently going into act_func with wrong numerics.
NVTE_ERROR("Activation fusion is not supported for blockwise quantization yet.");
} else {
act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
}
......
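
Because the fused activation path errors out for blockwise quantizers, callers fall back to running the activation in high precision and quantizing separately. A minimal sketch of that fallback follows; the module name transformer_engine_torch is an assumption, while the tex.gelu/tex.quantize call shapes are taken from the usage elsewhere in this diff.

```python
import torch
import transformer_engine_torch as tex  # assumed module name for the pybind extension

def gelu_then_blockwise_quantize(x: torch.Tensor, blockwise_quantizer):
    act = tex.gelu(x, None)                        # activation in high precision, no fused quantizer
    return tex.quantize(act, blockwise_quantizer)  # quantize afterwards (see layernorm_mlp.py below)
```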
......@@ -46,6 +46,9 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
if (te_output.numel() == 0) return out;
QuantizationConfigWrapper quant_config;
quant_config.set_noop_tensor(te_noop.data());
if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
......@@ -61,15 +64,21 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
}
QuantizationConfigWrapper quant_config;
// This config is used for the current-scaling scale-factor computation.
// Because the scale computation cannot be fused with the quantize kernel,
// the quant config is not used again inside nvte_quantize_v2 for current scaling.
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream());
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
} else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
auto my_quantizer_bw = static_cast<Float8BlockQuantizer*>(my_quantizer.get());
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
}
nvte_quantize_noop(te_input.data(), te_output.data(), te_noop.data(),
at::cuda::getCurrentCUDAStream());
nvte_quantize_v2(te_input.data(), te_output.data(), quant_config,
at::cuda::getCurrentCUDAStream());
return out;
}
......
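
With the constructor restrictions lifted (see the quantizer change further down), a Float8BlockQuantizer can now carry force_pow_2_scales and amax_epsilon through the quantization config. A hedged usage sketch follows; the fp8_dtype enum member name and the transformer_engine_torch module name are assumptions, while the keyword arguments match the recipe code in this diff.

```python
import torch
import transformer_engine_torch as tex  # assumed pybind module name
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer

quantizer = Float8BlockQuantizer(
    fp8_dtype=tex.DType.kFloat8E4M3,  # assumed enum member name
    rowwise=True,
    columnwise=True,
    amax_epsilon=0.0,
    force_pow_2_scales=True,
    block_scaling_dim=1,              # 1x128 blocks; 2 selects 128x128 blocks
)
x = torch.randn(256, 1024, dtype=torch.bfloat16, device="cuda")
qx = quantizer(x)  # routes through nvte_quantize_v2 with the config built above
```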
......@@ -150,6 +150,7 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
// Quantize output if using unfused kernel
if (force_unfused_kernel) {
QuantizationConfigWrapper quant_config;
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
......@@ -166,15 +167,18 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
}
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
} else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(my_quantizer.get());
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
}
nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr,
at::cuda::getCurrentCUDAStream());
nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
at::cuda::getCurrentCUDAStream());
}
return {out, py::cast(mu), py::cast(rsigma)};
......@@ -293,6 +297,7 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
// Quantize output if using unfused kernel
if (force_unfused_kernel) {
QuantizationConfigWrapper quant_config;
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
......@@ -309,15 +314,18 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
}
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
} else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(my_quantizer.get());
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
}
nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr,
at::cuda::getCurrentCUDAStream());
nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
at::cuda::getCurrentCUDAStream());
}
return {out, py::none(), py::cast(rsigma)};
......
......@@ -257,12 +257,8 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) {
this->dtype = quantizer.attr("dtype").cast<DType>();
this->block_scaling_dim = quantizer.attr("block_scaling_dim").cast<int>();
NVTE_CHECK(quantizer.attr("force_pow_2_scales").cast<bool>(),
"Pending additional parameters to the nvte_quantize API, "
"float8 block quantization requires pow2 scales");
NVTE_CHECK(quantizer.attr("amax_epsilon").cast<float>() == 0.0,
"Pending additional parameters to the nvte_quantize API, "
"float8 block quantization requires amax_epsilon==0");
this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast<bool>();
this->amax_epsilon = quantizer.attr("amax_epsilon").cast<float>();
NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2,
"Unsupported block scaling dim.");
}
......
......@@ -69,6 +69,7 @@ std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,
nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream());
} else {
for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) {
// TODO: switch to nvte_quantize_v2 with advanced numerical options
nvte_quantize(nvte_tensor_input_list[i], nvte_tensor_output_list[i],
at::cuda::getCurrentCUDAStream());
}
......
......@@ -24,10 +24,11 @@ from .constants import dist_group_type
from .fp8 import FP8GlobalStateManager, fp8_autocast
from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
from .tensor.mxfp8_tensor import MXFP8Quantizer
from .tensor.float8_blockwise_tensor import Float8BlockQuantizer
from .tensor.quantized_tensor import QuantizedTensor, Quantizer
from .tensor._internal.float8_tensor_base import Float8TensorBase
from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
__all__ = ["checkpoint", "CudaRNGStatesTracker"]
......@@ -937,6 +938,74 @@ def _all_gather_fp8(
return out, handle
def _all_gather_fp8_blockwise(
inp: torch.Tensor,
process_group: dist_group_type,
*,
async_op: bool = False, # pylint: disable=unused-argument
quantizer: Optional[Quantizer] = None,
out_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]:
"""
All-gather FP8 tensor along first dimension for blockwise quantization.
Returns: quantizer(gather(inp))
NOTE: The implementation is not sophisticated enough to honor async_op=True.
In some cases it falls back to synchronous gather and invokes the quantizer.
"""
# Input tensor attributes
device: torch.device
dtype: torch.dtype
if isinstance(inp, torch.Tensor):
device = inp.device
dtype = inp.dtype
elif isinstance(inp, Float8BlockwiseQTensorBase):
if inp._rowwise_data is not None:
device = inp._rowwise_data.device
elif inp._columnwise_data is not None:
device = inp._columnwise_data.device
else:
raise ValueError("Got Float8BlockwiseQTensorBase input tensor without any data")
dtype = torch.bfloat16 # Only has fp8 dtype. Guess BF16 for dequant.
else:
raise ValueError(
"Invalid type for input tensor (expected torch.Tensor or Float8BlockwiseQTensorBase, "
f"found {inp.__class__.__name__})"
)
world_size = get_distributed_world_size(process_group)
# Check that quantizer is valid
if quantizer is not None and not isinstance(quantizer, Float8BlockQuantizer):
raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})")
if not (quantizer.block_scaling_dim == 1 and quantizer.block_len == 128):
raise NotImplementedError("Only 1D blockwise quantization is supported for allgather")
# Output tensor dims
if out_shape is None:
out_shape = list(inp.size())
out_shape[0] *= world_size
# Doing BF16 gather for now as baseline because it's simpler
if not isinstance(inp, Float8BlockwiseQTensorBase) and quantizer is not None:
out = torch.empty(
out_shape,
dtype=dtype,
device=device,
memory_format=torch.contiguous_format,
)
torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
out = quantizer(out)
return out, None
# Implementation of fp8 gather needs to account for:
# * Getting columnwise data as a transpose of how it is stored for GEMMS.
# * Gathering non GEMM swizzled scales.
# * Refer to scaffold code when implementing at:
# https://github.com/kwyss-nvidia/TransformerEngine/commit/6659ee9dc84fb515d1d47699d8bfd20a72b76477
raise NotImplementedError("fp8 blockwise allgather not yet implemented")
def _all_gather_mxfp8(
inp: torch.Tensor,
process_group: dist_group_type,
......@@ -1075,7 +1144,9 @@ def gather_along_first_dim(
async_op: bool = False,
quantizer: Optional[Quantizer] = None,
) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]:
"""All-gather tensors and concatenate along first dimension."""
"""
All-gather tensors and concatenate along first dimension.
"""
# Return immediately if no communication is required
world_size = get_distributed_world_size(process_group)
......@@ -1100,6 +1171,16 @@ def gather_along_first_dim(
out_shape=out_shape,
)
# FP8 block scaling case, block length = 128
if isinstance(inp, Float8BlockwiseQTensorBase) or isinstance(quantizer, Float8BlockQuantizer):
return _all_gather_fp8_blockwise(
inp,
process_group,
async_op=async_op,
quantizer=quantizer,
out_shape=out_shape,
)
# MXFP8 case
if isinstance(inp, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer):
assert isinstance(quantizer, MXFP8Quantizer)
......
......@@ -6,6 +6,7 @@
from __future__ import annotations
import abc
import itertools
import os
from contextlib import contextmanager
from collections import deque
......@@ -19,6 +20,7 @@ from transformer_engine.common.recipe import (
Format,
MXFP8BlockScaling,
Float8CurrentScaling,
Float8BlockScaling,
)
from .constants import dist_group_type
......@@ -49,6 +51,17 @@ def check_mxfp8_support() -> Tuple[bool, str]:
return False, "Device compute capability 10.0 or higher required for MXFP8 execution."
def check_fp8_block_scaling_support() -> Tuple[bool, str]:
"""Return if fp8 block scaling support is available"""
if (
get_device_compute_capability() >= (9, 0)
and get_device_compute_capability() < (10, 0)
and float(torch.version.cuda) >= 12.9
):
return True, ""
return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9."
def get_default_fp8_recipe() -> Recipe:
"""FP8 recipe with default args."""
if get_device_compute_capability() >= (10, 0): # blackwell and above
......@@ -109,6 +122,8 @@ class FP8GlobalStateManager:
skip_fp8_weight_update_tensor = None
mxfp8_available = None
reason_for_no_mxfp8 = ""
fp8_block_scaling_available = None
reason_for_no_fp8_block_scaling = None
@classmethod
def reset(cls) -> None:
......@@ -134,6 +149,8 @@ class FP8GlobalStateManager:
cls.skip_fp8_weight_update_tensor = None
cls.mxfp8_available = None
cls.reason_for_no_mxfp8 = ""
cls.fp8_block_scaling_available = None
cls.reason_for_no_fp8_block_scaling = ""
@classmethod
def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None:
......@@ -161,6 +178,15 @@ class FP8GlobalStateManager:
cls.mxfp8_available, cls.reason_for_no_mxfp8 = check_mxfp8_support()
return cls.mxfp8_available, cls.reason_for_no_mxfp8
@classmethod
def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]:
"""Return if Float8 block scaling support is available."""
if cls.fp8_block_scaling_available is None:
cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling = (
check_fp8_block_scaling_support()
)
return cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling
@staticmethod
def get_meta_tensor_key(forward: bool = True) -> str:
"""Returns scaling key in `fp8_meta`."""
......@@ -434,6 +460,9 @@ class FP8GlobalStateManager:
if isinstance(fp8_recipe, MXFP8BlockScaling):
mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available()
assert mxfp8_available, reason_for_no_mxfp8
if isinstance(fp8_recipe, Float8BlockScaling):
fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available()
assert fp8_block_available, reason_for_no_fp8_block
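
Callers can query the new capability flag before entering autocast; a small sketch using the classmethod added above (module path assumed from the file being edited):

```python
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

available, reason = FP8GlobalStateManager.is_fp8_block_scaling_available()
if not available:
    print(f"Float8 block scaling unavailable: {reason}")
```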
@classmethod
def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None:
......@@ -786,8 +815,10 @@ class RecipeState(abc.ABC):
cls = MXFP8BlockScalingRecipeState
elif recipe.float8_current_scaling():
cls = Float8CurrentScalingRecipeState
elif recipe.float8_block_scaling():
cls = Float8BlockScalingRecipeState
else:
raise ValueError("{recipe.__class__.__name__} is not supported")
raise ValueError(f"{recipe.__class__.__name__} is not supported")
return cls(
recipe,
mode=mode,
......@@ -928,3 +959,108 @@ class MXFP8BlockScalingRecipeState(RecipeState):
from .tensor.mxfp8_tensor import MXFP8Quantizer
return [MXFP8Quantizer(self.dtype) for i in range(self.num_quantizers)]
class Float8BlockScalingRecipeState(RecipeState):
"""Configuration for Float8BlockScaling quantization.
Float8BlockScaling quantization does not require state,
but different quantizers use different modes.
"""
recipe: Float8BlockScaling
mode: str
qx_dtype: tex.DType
qw_dtype: tex.DType
qgrad_dtype: tex.DType
def __init__(
self,
recipe: Float8BlockScaling,
*,
mode: str,
num_quantizers: int = 1,
device: Optional[torch.device] = None,
) -> None:
self.recipe = recipe
self.mode = mode
self.num_quantizers = num_quantizers
self.qx_dtype = get_fp8_te_dtype(recipe, True)
self.qw_dtype = get_fp8_te_dtype(recipe, True)
self.qgrad_dtype = get_fp8_te_dtype(recipe, False)
# Allocate buffers
if device is None:
device = torch.device("cuda")
self.device = device
def make_quantizers(self) -> list:
# TODO(ksivamani): Find a better design for this; the import is done here to avoid a circular import.
from .tensor.float8_blockwise_tensor import Float8BlockQuantizer
if self.mode == "forward":
# The index convention (coming from base.py set_meta_tensor)
# is somewhat awkward, and doesn't play nicely with QuantizeOp,
# which is not associated with a GEMM.
assert self.num_quantizers % 3 == 0 # x, w, output per gemm
return list(
itertools.chain.from_iterable(
[
[
Float8BlockQuantizer(
fp8_dtype=self.qx_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale,
block_scaling_dim=self.recipe.x_block_scaling_dim,
),
Float8BlockQuantizer(
fp8_dtype=self.qw_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale,
block_scaling_dim=self.recipe.w_block_scaling_dim,
),
Float8BlockQuantizer(
fp8_dtype=self.qx_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale,
block_scaling_dim=self.recipe.x_block_scaling_dim,
),
]
for _ in range(self.num_quantizers // 3)
]
)
)
assert self.mode == "backward", f"Unexpected mode {self.mode}"
assert self.num_quantizers % 2 == 0 # grad_output and grad_input per gemm
return list(
itertools.chain.from_iterable(
[
[
Float8BlockQuantizer(
fp8_dtype=self.qgrad_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale,
block_scaling_dim=self.recipe.grad_block_scaling_dim,
),
Float8BlockQuantizer(
fp8_dtype=self.qgrad_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale,
block_scaling_dim=self.recipe.grad_block_scaling_dim,
),
]
for _ in range(self.num_quantizers // 2)
]
)
)
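
End to end, the new recipe state is exercised the same way as the other recipes; a hedged sketch (recipe defaults assumed from the Float8BlockScaling dataclass, Hopper plus CUDA 12.9 required per the support check above):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8BlockScaling

recipe = Float8BlockScaling()  # default block_scaling_dim settings assumed from the recipe dataclass
layer = te.Linear(1024, 1024, bias=True).cuda()
x = torch.randn(2048, 1024, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)
y.sum().backward()
```

Per the make_quantizers logic above, each GEMM gets three forward quantizers (input, weight, output) and two backward quantizers (grad_output, grad_input).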
......@@ -23,6 +23,7 @@ from ..fp8 import (
MXFP8BlockScalingRecipeState,
DelayedScalingRecipeState,
Float8CurrentScalingRecipeState,
Float8BlockScalingRecipeState,
FP8GlobalStateManager,
RecipeState,
)
......@@ -34,8 +35,10 @@ from ..distributed import (
)
from ..constants import dist_group_type
from ..tensor import QuantizedTensor, Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
__all__ = ["initialize_ub", "destroy_ub"]
......@@ -516,6 +519,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
recipe_state, Float8CurrentScalingRecipeState
):
return
if recipe.float8_block_scaling() and isinstance(
recipe_state, Float8BlockScalingRecipeState
):
return
# Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
# 2 (grad_output and grad_input) for bwd
......@@ -858,7 +865,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if ctx.ub_overlap_ag:
# Quantize the gradient if needed
if not isinstance(
grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)
grad_output,
(
QuantizedTensor,
Float8TensorBase,
MXFP8TensorBase,
Float8BlockwiseQTensorBase,
),
):
grad_output = quantizer(grad_output)
......@@ -876,11 +889,21 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# FP8 without all-gather: fused bgrad + cast + transpose
grad_bias = None
if ctx.use_bias:
if isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)):
if isinstance(
grad_output,
(QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase),
):
grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0)
else:
grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer)
if not isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)):
if isinstance(quantizer, Float8BlockQuantizer):
# unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer.
grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
else:
grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer)
if not isinstance(
grad_output,
(QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase),
):
grad_output = quantizer(grad_output)
return grad_output, grad_bias
......
......@@ -91,6 +91,8 @@ class _GroupedLinear(torch.autograd.Function):
# TODO Support Float8 Current Scaling # pylint: disable=fixme
if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling():
raise NotImplementedError("GroupedLinear does not yet support Float8 Current Scaling")
if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling():
raise NotImplementedError("GroupedLinear does not yet support Float8Blockwise scaling")
# Make sure input dimensions are compatible
in_features = weights[0].shape[-1]
......
......@@ -57,9 +57,11 @@ from ..tensor.quantized_tensor import (
restore_from_saved,
)
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
from ..cpp_extensions import (
general_gemm,
)
......@@ -138,11 +140,6 @@ class _LayerNormLinear(torch.autograd.Function):
ln_bias = cast_if_needed(ln_bias, activation_dtype)
nvtx_range_pop(f"{nvtx_label}.norm_input_cast")
# Avoid quantized norm kernel if norm output will be returned
with_quantized_norm = (
fp8 and not return_layernorm_output and not return_layernorm_output_gathered
)
tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_ag_fprop = (
ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output
......@@ -175,6 +172,18 @@ class _LayerNormLinear(torch.autograd.Function):
columnwise_usage = False
input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
# Avoid quantized norm kernel if norm output will be returned
# or if a gather of ln_out must be in high precision.
force_hp_blockwise_ln_out_gather = (
fp8 and with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer)
) # Perform TP communication in high precision.
with_quantized_norm = (
fp8
and not return_layernorm_output
and not return_layernorm_output_gathered
and not force_hp_blockwise_ln_out_gather
)
# Apply normalization
nvtx_range_push(f"{nvtx_label}.norm")
ln_out, mu, rsigma = apply_normalization(
......@@ -211,7 +220,7 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total = input_quantizer(ln_out_total)
else:
if fp8:
if not with_quantized_norm:
if not with_quantized_norm and not force_hp_blockwise_ln_out_gather:
ln_out = input_quantizer(ln_out)
input_quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag_fprop:
......@@ -317,6 +326,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.ln_out_needs_gather = (
weight.requires_grad and parallel_mode == "column" and sequence_parallel
)
ctx.force_hp_blockwise_ln_out_gather = force_hp_blockwise_ln_out_gather
# Input with column-wise usage is needed for wgrad GEMM.
if backward_needs_input:
......@@ -327,6 +337,10 @@ class _LayerNormLinear(torch.autograd.Function):
if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather:
ln_out.update_usage(rowwise_usage=False)
# For force_hp_blockwise_ln_out_gather, we should
# be saving the unquantized ln_out to ctx.
assert not force_hp_blockwise_ln_out_gather
# Weight with column-wise usage is needed for dgrad GEMM.
if isinstance(weightmat, QuantizedTensor):
weightmat.update_usage(columnwise_usage=True)
......@@ -605,11 +619,14 @@ class _LayerNormLinear(torch.autograd.Function):
# wgrad GEMM requires input with column-wise usage
quantizer.set_usage(rowwise=False, columnwise=True)
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
# async_op is not compatible with high precision gather since
# gather_along_first_dim does not offer callback chaining.
gather_quantizer = None if ctx.force_hp_blockwise_ln_out_gather else quantizer
ln_out_total, ln_out_total_work = gather_along_first_dim(
ln_out,
ctx.tp_group,
async_op=True,
quantizer=quantizer,
quantizer=gather_quantizer,
)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
else:
......@@ -690,6 +707,13 @@ class _LayerNormLinear(torch.autograd.Function):
if ln_out_total_work is not None:
ln_out_total_work.wait()
ln_out_total_work = None
if ctx.input_quantizer is not None and not isinstance(
ln_out_total, QuantizedTensor
):
# The async gather may have been done in BF16,
# so apply the quantizer after the gather completes.
ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
ln_out_total = ctx.input_quantizer(ln_out_total)
# Make sure GEMM inputs have required data
if isinstance(ln_out_total, QuantizedTensor):
......
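
The backward-pass flow enabled by force_hp_blockwise_ln_out_gather amounts to a gather-then-quantize step; the sketch below assumes gather_along_first_dim returns a plain high-precision tensor when no quantizer is passed, as in the distributed.py change above.

```python
def gather_then_quantize_for_wgrad(ln_out, tp_group, input_quantizer, gather_along_first_dim):
    # Gather the layernorm output in high precision (async), then quantize column-wise
    # for the wgrad GEMM once the communication has finished.
    ln_out_total, work = gather_along_first_dim(ln_out, tp_group, async_op=True, quantizer=None)
    if work is not None:
        work.wait()
    input_quantizer.set_usage(rowwise=False, columnwise=True)
    return input_quantizer(ln_out_total)
```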
......@@ -52,7 +52,6 @@ from ..distributed import (
in_fp8_activation_recompute_phase,
_fsdp_scatter_tensors,
)
from ..constants import dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
......@@ -62,6 +61,7 @@ from ..tensor.float8_tensor import (
Float8Tensor,
)
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ._common import apply_normalization, _fix_gathered_fp8_transpose
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
from ..tensor.quantized_tensor import (
......@@ -104,17 +104,19 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
"srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu),
}
# no activation fusion written yet
# Per-tensor current scaling: []
return {
"gelu": (tex.gelu, tex.dgelu, None),
"relu": (tex.relu, tex.drelu, None),
"geglu": (tex.geglu, tex.dgeglu, None),
"reglu": (tex.reglu, tex.dreglu, None),
"swiglu": (tex.swiglu, tex.dswiglu, None),
"qgelu": (tex.qgelu, tex.dqgelu, None),
"qgeglu": (tex.qgeglu, tex.dqgeglu, None),
"srelu": (tex.srelu, tex.dsrelu, None),
}
# Per-tensor current scaling or fp8 blockwise scaling: []
if recipe.float8_current_scaling() or recipe.float8_block_scaling():
return {
"gelu": (tex.gelu, tex.dgelu, None),
"relu": (tex.relu, tex.drelu, None),
"geglu": (tex.geglu, tex.dgeglu, None),
"reglu": (tex.reglu, tex.dreglu, None),
"swiglu": (tex.swiglu, tex.dswiglu, None),
"qgelu": (tex.qgelu, tex.dqgelu, None),
"qgeglu": (tex.qgeglu, tex.dqgeglu, None),
"srelu": (tex.srelu, tex.dsrelu, None),
}
raise NotImplementedError(f"Unhandled recipe type {recipe}")
def _act_func(activation: str, recipe: Optional[Recipe] = None):
......@@ -122,7 +124,7 @@ def _act_func(activation: str, recipe: Optional[Recipe] = None):
# bf16 (recipe is None): [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
# Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
# MXFP8: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
# Per-tensor current scaling: []
# Per-tensor current scaling or fp8 blockwise scaling: []
funcs = _get_act_func_supported_list(recipe)
if activation not in funcs:
raise NotImplementedError("Activation type " + activation + " is not supported!")
......@@ -214,12 +216,20 @@ class _LayerNormMLP(torch.autograd.Function):
with_quantized_norm = (
fp8 and not return_layernorm_output and not return_layernorm_output_gathered
)
if isinstance(fc1_input_quantizer, Float8BlockQuantizer):
# Kernels not available for norm fusion.
with_quantized_norm = False
tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered
ub_overlap_rs = ub_overlap_rs and is_grad_enabled
backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad
# TODO(kwyss): Support FP8 allgather of Float8BlockQuantizer recipe
force_hp_fc1_input_gather = (
fp8 and sequence_parallel and isinstance(fc1_input_quantizer, Float8BlockQuantizer)
) # Perform TP communication in high precision.
# Configure quantizer for norm output
if fp8:
if fc1_input_quantizer is None:
......@@ -261,12 +271,13 @@ class _LayerNormMLP(torch.autograd.Function):
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
ln_out_return = ln_out_total
if fp8:
ln_out = fc1_input_quantizer(ln_out)
if not force_hp_fc1_input_gather:
ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
ln_out_total = fc1_input_quantizer(ln_out_total)
else:
if fp8:
if not with_quantized_norm:
if not with_quantized_norm and not force_hp_fc1_input_gather:
ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag:
......@@ -282,7 +293,10 @@ class _LayerNormMLP(torch.autograd.Function):
quantizer=(fc1_input_quantizer if fp8 else None),
)
else:
if fp8 and not with_quantized_norm:
# NOTE: checking force_hp_fc1_input_gather is redundant in this else branch, but
# it is kept for clarity. We should not quantize ln_out if the backward pass needs
# to gather it in high precision.
if fp8 and not with_quantized_norm and not force_hp_fc1_input_gather:
ln_out = fc1_input_quantizer(ln_out)
ln_out_total = ln_out
......@@ -336,6 +350,7 @@ class _LayerNormMLP(torch.autograd.Function):
# - bias_gelu_fusion - only for full precision.
# If both gemm_gelu_fusion and bias_gelu_fusion are enabled, only bias_gelu_fusion will be performed
if activation != "gelu":
# blockwise scaled gemms don't support gemm_gelu_fusion in fwd.
gemm_gelu_fusion = bias_gelu_fusion = False
else:
if fp8:
......@@ -376,7 +391,12 @@ class _LayerNormMLP(torch.autograd.Function):
act_out, _, fc1_out, _ = fc1_outputs
else:
fc1_out, *_ = fc1_outputs
act_out = activation_func(fc1_out, fc2_input_quantizer)
if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling():
# tex.quantize does not support GELU fusion for blockwise.
act_out = activation_func(fc1_out, None)
act_out = tex.quantize(act_out, fc2_input_quantizer)
else:
act_out = activation_func(fc1_out, fc2_input_quantizer)
if not is_grad_enabled:
clear_tensor_data(fc1_out)
......@@ -462,6 +482,8 @@ class _LayerNormMLP(torch.autograd.Function):
if not return_layernorm_output:
clear_tensor_data(ln_out)
ln_out = None
elif force_hp_fc1_input_gather:
assert not isinstance(ln_out, QuantizedTensor)
if not fc2_weight.requires_grad:
clear_tensor_data(act_out)
act_out = None
......@@ -490,6 +512,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.tensor_objects = tensor_objects
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.force_hp_fc1_input_gather = force_hp_fc1_input_gather
ctx.grad_fc1_output_quantizer = grad_fc1_output_quantizer
ctx.grad_fc2_output_quantizer = grad_fc2_output_quantizer
ctx.grad_input_quantizer = grad_input_quantizer
......@@ -505,6 +528,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.activation_dtype = activation_dtype
ctx.activation = activation
ctx.fp8 = fp8
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch
......@@ -696,11 +720,12 @@ class _LayerNormMLP(torch.autograd.Function):
else:
# wgrad GEMM requires input with column-wise usage
quantizer.set_usage(rowwise=False, columnwise=True)
gather_quantizer = None if ctx.force_hp_fc1_input_gather else quantizer
ln_out_total, ln_out_total_work = gather_along_first_dim(
ln_out,
ctx.tp_group,
async_op=True,
quantizer=quantizer,
quantizer=gather_quantizer,
)
else:
ln_out_total = ln_out
......@@ -712,12 +737,13 @@ class _LayerNormMLP(torch.autograd.Function):
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
# There are 5 possible fusion paths
# There are 6 possible fusion paths
# 1 high-precision bias_gelu_fusion: gemm, FC1_bias + gelu,
# 2 high-precision fc2_dgrad_gemm_gelu_fusion: gemm + gelu, FC1_bias + quantize
# 3 fp8 activation+bias+quantize fusion: gemm, activation + FC1_bias + quantize
# 4 fp8 bias+quantize fusion: gemm, activation, FC1_bias + quantize
# 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm
# 6 fp8 unfused: gemm, activation, FC1_bias + FC1_gemm
fc2_dgrad_gemm_gelu_fusion = (
not ctx.fp8 and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion)
)
......@@ -753,6 +779,9 @@ class _LayerNormMLP(torch.autograd.Function):
if isinstance(grad_output, QuantizedTensor):
grad_output.update_usage(rowwise_usage=True, columnwise_usage=True)
grad_arg = True
if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling():
grad_arg = False
fc2_wgrad, fc2_bias_grad_, *_ = general_gemm(
act_out,
grad_output,
......@@ -764,14 +793,18 @@ class _LayerNormMLP(torch.autograd.Function):
),
quantization_params=None, # wgrad in high precision
layout="NT",
grad=True,
bias=fc2_bias if fc2_bias_grad is None else None,
grad=grad_arg,
bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
accumulate=accumulate_wgrad_into_param_main_grad,
use_split_accumulator=_2X_ACC_WGRAD,
out=origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
if fc2_bias_grad is None:
if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling() and fc2_bias is not None:
# BGRAD not fused with GEMM for float8 blockwise gemm.
fc2_bias_grad_ = act_out.view(-1, act_out.shape[-1]).sum(dim=0)
fc2_bias_grad = fc2_bias_grad_
del fc2_bias_grad_
clear_tensor_data(act_out)
# bias computation
......@@ -808,7 +841,14 @@ class _LayerNormMLP(torch.autograd.Function):
) # activation in high precision
if ctx.fp8:
fc1_bias_grad, dact = tex.bgrad_quantize(dact, ctx.grad_fc1_output_quantizer)
# TODO float8 blockwise current scaling has no bgrad fusion for now
if isinstance(ctx.grad_fc1_output_quantizer, Float8BlockQuantizer):
fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0)
dact = ctx.grad_fc1_output_quantizer(dact)
else:
fc1_bias_grad, dact = tex.bgrad_quantize(
dact, ctx.grad_fc1_output_quantizer
)
else:
fuse_gemm_and_bias_fc1_wgrad = (
True # fc1_bias_grad is computed later, fused with wgrad gemm for the FC1
......@@ -904,6 +944,13 @@ class _LayerNormMLP(torch.autograd.Function):
if ln_out_total_work is not None:
ln_out_total_work.wait()
ln_out_total_work = None
if ctx.fc1_input_quantizer is not None and not isinstance(
ln_out_total, QuantizedTensor
):
# An async BF16 gather does not automatically apply the quantizer,
# so quantize after the gather completes.
ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True)
ln_out_total = ctx.fc1_input_quantizer(ln_out_total)
# Make sure GEMM inputs have required data
if isinstance(ln_out_total, QuantizedTensor):
......@@ -1556,7 +1603,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc1_weight_quantizer.internal = True
fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
fc2_input_quantizer.set_usage(
rowwise=True, columnwise=isinstance(fc2_input_quantizer, MXFP8Quantizer)
rowwise=True,
columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)),
)
fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
fc2_weight_quantizer.internal = True
......
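
The unfused bias-gradient path used for Float8BlockQuantizer in the FC1/FC2 backward above reduces the unquantized activation gradient directly and quantizes it separately; as a standalone sketch of the same two steps:

```python
import torch

def unfused_bgrad_then_quantize(dact: torch.Tensor, blockwise_quantizer):
    # No bgrad + quantize fusion exists yet for blockwise scaling, so compute the bias
    # gradient as a plain reduction and quantize the activation gradient afterwards.
    grad_bias = dact.view(-1, dact.shape[-1]).sum(dim=0)
    return grad_bias, blockwise_quantizer(dact)
```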
......@@ -60,9 +60,10 @@ from ..tensor.quantized_tensor import (
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
__all__ = ["Linear"]
......@@ -130,6 +131,10 @@ class _Linear(torch.autograd.Function):
parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
)
own_quantized_input = False
# TODO(kwyss): Support FP8 allgather for FP8 block quantization.
force_hp_input_gather = (
fp8 and with_input_all_gather_nccl and isinstance(input_quantizer, Float8BlockQuantizer)
) # Perform TP communication in high precision.
if fp8:
assert_dim_for_fp8_exec(inputmat, weight)
if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not (
......@@ -143,19 +148,27 @@ class _Linear(torch.autograd.Function):
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
if with_input_all_gather_nccl:
if force_hp_input_gather:
input_quantizer.set_usage(rowwise=True, columnwise=False)
inputmat_total, _ = gather_along_first_dim(
inputmat, tp_group, quantizer=input_quantizer
)
else:
if not isinstance(inputmat, QuantizedTensor):
columnwise_usage = backward_needs_input and isinstance(
input_quantizer, MXFP8Quantizer
)
# force_hp_input_gather should enforce this
assert not isinstance(input_quantizer, Float8BlockQuantizer)
input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
inputmat = input_quantizer(inputmat)
own_quantized_input = True
input_quantizer.set_usage(rowwise=True, columnwise=False)
inputmat_total, _ = gather_along_first_dim(
inputmat,
tp_group,
quantizer=input_quantizer,
)
else:
if (
FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
......@@ -277,6 +290,8 @@ class _Linear(torch.autograd.Function):
# can be allgathered.
if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather:
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
if force_hp_input_gather:
assert not isinstance(inputmat, QuantizedTensor)
saved_inputmat = inputmat
# Weight with column-wise usage is needed for dgrad GEMM.
......@@ -323,8 +338,9 @@ class _Linear(torch.autograd.Function):
ctx.tensor_objects = tensor_objects
ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.force_hp_input_gather = force_hp_input_gather
ctx.input_quantizer = input_quantizer
ctx.grad_output_quantizer = grad_output_quantizer
ctx.grad_input_quantizer = grad_input_quantizer
......@@ -520,11 +536,12 @@ class _Linear(torch.autograd.Function):
# wgrad GEMM requires input with column-wise usage
quantizer.set_usage(rowwise=False, columnwise=True)
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
gather_quantizer = None if ctx.force_hp_input_gather else quantizer
inputmat_total, inputmat_total_work = gather_along_first_dim(
inputmat,
ctx.tp_group,
async_op=True,
quantizer=gather_quantizer,
)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
else:
......@@ -610,6 +627,13 @@ class _Linear(torch.autograd.Function):
if inputmat_total_work is not None:
inputmat_total_work.wait()
inputmat_total_work = None
if ctx.input_quantizer is not None and not isinstance(
inputmat_total, QuantizedTensor
):
# The async all-gather was done in BF16 without applying the quantizer,
# so quantize the gathered tensor (columnwise only) for the wgrad GEMM.
ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
inputmat_total = ctx.input_quantizer(inputmat_total)
# Make sure GEMM inputs have required data
if isinstance(inputmat_total, QuantizedTensor):
......
......@@ -23,6 +23,7 @@ from ...fp8 import FP8GlobalStateManager
from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
from ...tensor import Quantizer, QuantizedTensor
from ...tensor.float8_tensor import Float8Quantizer
from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ..op import BasicOperation, OperationContext
......@@ -483,6 +484,12 @@ class BasicLinear(BasicOperation):
"Attempting to generate MXFP8 output tensor, "
"but GEMM with MXFP8 output is not supported"
)
if isinstance(output_quantizer, Float8BlockQuantizer):
raise RuntimeError(
"Attempting to generate Float8BlockQuantized output tensor, "
"but GEMM with Float8BlockQuantized output is not supported"
)
if output_quantizer is not None:
output_quantizer.set_usage(rowwise=True, columnwise=False)
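Because the GEMM cannot write a Float8BlockQuantized output directly, a caller that wants a blockwise-quantized result has to run the GEMM in higher precision and quantize in a separate step. A hedged, generic sketch of that two-step pattern; `quantizer` here is a stand-in callable, not a specific TE API:

```python
import torch

def gemm_then_quantize(a: torch.Tensor, b: torch.Tensor, quantizer=None) -> torch.Tensor:
    """Run the matmul in high precision, then quantize the result afterwards."""
    out = a @ b  # GEMM output stays in BF16/FP32
    if quantizer is not None:
        out = quantizer(out)  # blockwise quantization as a distinct step
    return out

# Example with a trivial stand-in quantizer:
y = gemm_then_quantize(torch.randn(4, 8), torch.randn(8, 2), quantizer=lambda t: t.half())
```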
......
......@@ -17,6 +17,7 @@ from transformer_engine.common.recipe import Recipe
from ..fp8 import (
MXFP8BlockScalingRecipeState,
DelayedScalingRecipeState,
Float8BlockScalingRecipeState,
FP8GlobalStateManager,
RecipeState,
fp8_autocast,
......@@ -219,6 +220,11 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
if num_quantizers == 0:
continue
if recipe.float8_block_scaling():
raise NotImplementedError(
"Fusible operations do not support FP8 block scaling recipe"
)
# Construct quantization recipe state
recipe_state = RecipeState.create(
recipe,
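A hedged usage sketch of the guard above: running a fusible operation under an FP8 autocast with the blockwise recipe is expected to raise rather than silently fall back. This assumes the blockwise recipe class is exposed as `recipe.Float8BlockScaling` and that `te.ops.BasicLinear` is importable under that name; it also needs a GPU with FP8 support:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

op = te.ops.BasicLinear(64, 64)        # fusible op (assumed public name)
x = torch.randn(8, 64, device="cuda")
try:
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe.Float8BlockScaling()):
        y = op(x)
except NotImplementedError as exc:
    print(f"Fusible ops reject the blockwise recipe: {exc}")
```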
......@@ -260,8 +266,13 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
continue
recipe_state = self._fp8_metas[mode][fp8_meta_key]
need_to_reset_recipe_state = (
(recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState))
or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState))
or (
recipe.float8_block_scaling()
and not isinstance(recipe_state, Float8BlockScalingRecipeState)
)
)
if need_to_reset_recipe_state:
self._reset_quantization_recipe_state(recipe=recipe)
return
......
......@@ -36,8 +36,8 @@ class Float8BlockwiseQTensorBase:
def __new__(
cls,
*args,
rowwise_data: Optional[torch.Tensor],
rowwise_scale_inv: Optional[torch.Tensor],
columnwise_data: Optional[torch.Tensor],
columnwise_scale_inv: Optional[torch.Tensor],
fp8_dtype: TE_DType,
......@@ -71,10 +71,16 @@ class Float8BlockwiseQTensorBase:
def prepare_for_saving(
self,
) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]:
"""Prepare the tensor base for saving for backward"""
"""
Prepare the tensor base for saving for backward
This does not clear the tensors currently, because with PP config
that clears the weight cache between micro-batches. If the rowwise
data is not required for backward, this is a possible memory
pessimization, but is consistent with the other quantized tensor
classes.
"""
tensors = [self._rowwise_data, self._columnwise_data]
return tensors, self
def restore_from_saved(
......
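The save/restore contract described in the docstring above can be illustrated with a toy stand-in class; the real Float8BlockwiseQTensorBase follows the same kind of protocol, but the exact signatures here are assumptions for illustration:

```python
from typing import List, Optional, Tuple
import torch

class ToyQuantizedBase:
    """Toy stand-in showing the prepare/restore protocol, not the real class."""

    def __init__(self, rowwise: Optional[torch.Tensor], columnwise: Optional[torch.Tensor]):
        self._rowwise_data = rowwise
        self._columnwise_data = columnwise

    def prepare_for_saving(self) -> Tuple[List[Optional[torch.Tensor]], "ToyQuantizedBase"]:
        # Intentionally do not clear the attributes: a pipeline-parallel weight
        # cache may keep using this object between micro-batches.
        return [self._rowwise_data, self._columnwise_data], self

    def restore_from_saved(self, tensors: List[Optional[torch.Tensor]]) -> List[Optional[torch.Tensor]]:
        self._rowwise_data, self._columnwise_data = tensors[0], tensors[1]
        return tensors[2:]

t = ToyQuantizedBase(torch.zeros(2, 2), torch.zeros(2, 2))
saved, obj = t.prepare_for_saving()
remaining = obj.restore_from_saved(saved)
assert remaining == []
```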