Unverified Commit 77fa1e59 authored by Zhongbo Zhu, committed by GitHub

[PyTorch] Enabling Per-Tensor Current Scaling Recipe (#1471)



* check in per-tensor current scaling full recipe
Signed-off-by: zhongboz <zhongboz@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: zhongboz <zhongboz@nvidia.com>

setup basics of current scaling quantizer in python level
Signed-off-by: zhongboz <zhongboz@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: zhongboz <zhongboz@nvidia.com>

add test case for current scaling dequantize
Signed-off-by: zhongboz <zhongboz@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: zhongboz <zhongboz@nvidia.com>

finish linear layer fwd bwd test, determined error with bf16
Signed-off-by: zhongboz <zhongboz@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: zhongboz <zhongboz@nvidia.com>

achieved zero tolerance for Linear by specifying gemm use_split_accumulator config
Signed-off-by: zhongboz <zhongboz@nvidia.com>

enable layernormlinear with current scaling, pass bitwise test
Signed-off-by: zhongboz <zhongboz@nvidia.com>

refactor test case code
Signed-off-by: zhongboz <zhongboz@nvidia.com>

make current scaling quantizers distributed, pass distributed linear & layernormlinear tests
Signed-off-by: zhongboz <zhongboz@nvidia.com>

bug fix: use cached fp8 recipe in backward
Signed-off-by: zhongboz <zhongboz@nvidia.com>

bug fix: use cached fp8 recipe in backward
Signed-off-by: zhongboz <zhongboz@nvidia.com>

support detailed numerical settings from recipe to quantization kernel
Signed-off-by: zhongboz <zhongboz@nvidia.com>

support detailed numerical settings from recipe to quantization kernel
Signed-off-by: zhongboz <zhongboz@nvidia.com>

recipe naming
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* resolve mr comments, remove IS_CURRENT_SCALING template from kernels
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* resolve mr comments, make current scaling c++ test cases
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* add current scaling to test_numerics.py, skip act recomp and grouped linear
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add benchmark for quantizer
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add benchmarks for linear layer
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* bug fix, typo
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* resolve more mr comments
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* avoid potential race condition by not using from_blob to construct amax tensor in C++
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* resolve more comments
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Debug linter warnings and license check
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Debug import error in FP8 tensor test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug compilation error with CUDA 12.1 for Turing
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* resolve mr comments, fix activation cast fusion
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* resolve comments, add NVTEQuantizationParams for compute scale
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* remove is_current_scaling check totally from common folder
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* remove benchmarks, will contribute in another repo
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* adjust cs default recipe config
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* adjust comments in test
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Remove current scaling mode from core lib
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Refactor current-scaling-specific logic in core C++ lib

Move amax and scale update functions out of casting functions, and put into dedicated current-scaling source file. Add general API for accessing quantization config object.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add missing header in C++ tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable test config with FP8 transpose on Blackwell
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix compilation error in C++ test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: zhongboz <zhongboz@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent 2a95efd3
......@@ -39,6 +39,27 @@ class Format(Enum):
HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd)
@dataclass(frozen=True)
class MMParams:
"""for pytorch as an example, _scaled_mm use_fast_accum = (not use_split_accumulator)
apply split accumulator or not, turning it on will increase accuracy but impact gemm performance,
so only turn it on for certain gemms
"""
use_split_accumulator: bool = True
@dataclass(frozen=True)
class QParams:
"""Quantization parameters.
power_2_scale: use power of 2 scale parameter
amax_epsilon: optional minimum value of abs max
"""
power_2_scale: bool = False
amax_epsilon: float = 0.0
class Recipe:
"""
Base recipe class.
......@@ -52,6 +73,10 @@ class Recipe:
"""Whether the given recipe is delayed scaling."""
return isinstance(self, DelayedScaling)
def float8_current_scaling(self):
"""Whether the given recipe is (per-tensor) current scaling."""
return isinstance(self, Float8CurrentScaling)
@dataclass()
class DelayedScaling(Recipe):
......@@ -161,6 +186,75 @@ class DelayedScaling(Recipe):
)
@dataclass()
class Float8CurrentScaling(Recipe):
"""
Use the per-tensor current scaling factor strategy.
Parameters
----------
fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID
Controls the FP8 data format used during forward and backward
pass.
fp8_quant_fwd_inp: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0}
used for quantization of input tensor x
fp8_quant_fwd_weight: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0}
used for quantization of weight tensor w
fp8_quant_bwd_grad: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0}
used for quantization of gradient tensor dY
fp8_gemm_fprop: MMParams, default MMParams.use_split_accumulator=False
used for calculating output y in forward pass
fp8_gemm_dgrad: MMParams, default MMParams.use_split_accumulator=True
used for calculating dgrad in backward pass
fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True
used for calculating wgrad in backward pass
fp8_dpa: bool, default = `False`
Whether to enable FP8 dot product attention (DPA). When the model is placed in an
`fp8_autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the
inputs from higher precision to FP8, performs attention in FP8, and casts tensors
back to higher precision as outputs. FP8 DPA currently is only supported in the
`FusedAttention` backend.
fp8_mha: bool, default = `False`
Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting
operations mentioned above at the DPA boundaries. Currently only standard MHA modules
i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When
`fp8_mha = False, fp8_dpa = True`, a typical MHA module works as
`LayerNormLinear (BF16 output) -> (cast to FP8) FP8 DPA (cast to BF16) -> Linear`.
When `fp8_mha = True, fp8_dpa = True`, it becomes
`LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`.
Notes
-----
* `fp8_dpa` and `fp8_mha` are Beta features, and their API and functionality are
subject to change in future Transformer Engine releases.
"""
fp8_format: Format = Format.HYBRID
fp8_quant_fwd_inp = QParams(power_2_scale=False, amax_epsilon=0.0)
fp8_quant_fwd_weight = QParams(power_2_scale=False, amax_epsilon=0.0)
fp8_quant_bwd_grad = QParams(power_2_scale=False, amax_epsilon=0.0)
fp8_gemm_fprop: MMParams = MMParams(use_split_accumulator=False)
fp8_gemm_dgrad: MMParams = MMParams(use_split_accumulator=True)
fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True)
fp8_dpa: bool = False
fp8_mha: bool = False
def __post_init__(self) -> None:
assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
def __repr__(self) -> str:
return (
f"format={str(self.fp8_format).split('.')[1]}, "
f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, "
f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, "
f"fp8_quant_bwd_grad={self.fp8_quant_bwd_grad}, "
f"fp8_gemm_fprop={self.fp8_gemm_fprop}, "
f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, "
f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, "
f"fp8_dpa={self.fp8_dpa}, "
f"fp8_mha={self.fp8_mha}"
)
@dataclass()
class MXFP8BlockScaling(Recipe):
"""
......
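As context for the recipe additions above, here is a minimal usage sketch (assuming the standard fp8_autocast entry point accepts this recipe the same way it accepts DelayedScaling, which is what this PR enables; layer sizes and data below are placeholders):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling, Format

# Per-tensor current scaling: the scale is recomputed from each tensor's current amax
# at every quantization, so no amax history or scale update step is needed.
recipe = Float8CurrentScaling(fp8_format=Format.HYBRID)

layer = te.Linear(768, 768, params_dtype=torch.bfloat16, device="cuda")
x = torch.randn(32, 768, device="cuda", dtype=torch.bfloat16)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)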
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/recipe.h>
#include <algorithm>
#include <limits>
#include <type_traits>
#include "../common.h"
#include "../util/logging.h"
#include "../util/vectorized_pointwise.h"
namespace transformer_engine {
namespace {
constexpr int amax_kernel_threads = 512;
template <int nvec, bool aligned, typename InputType>
__launch_bounds__(amax_kernel_threads) __global__
void amax_kernel(const InputType *input, float *amax, const size_t N,
const size_t num_aligned_elements) {
VectorizedLoader<InputType, nvec, aligned> loader(input, N);
InputType max = 0.f;
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const size_t M = num_aligned_elements;
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) {
loader.load(tid, N);
#pragma unroll
for (int i = 0; i < nvec; ++i) {
const InputType val = static_cast<InputType>(loader.separate()[i]);
__builtin_assume(max >= InputType{0.f});
if constexpr (std::is_same_v<InputType, __nv_bfloat16>) {
#if __CUDA_ARCH__ >= 800
max = __hmax(__habs(val), max);
#else // Turing
max = static_cast<__nv_bfloat16>(
fmaxf(fabsf(static_cast<float>(val)), static_cast<float>(max)));
#endif
} else if constexpr (std::is_same_v<InputType, __half>) {
max = __hmax(__habs(val), max);
} else {
max = fmaxf(fabsf(val), max);
}
}
}
// Reduce amax over block
max = reduce_max<amax_kernel_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
atomicMaxFloat(amax, max);
}
}
template <int nvec, typename InputType>
void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
// Zero out amax so we can update with atomic max
cudaMemsetAsync(amax, 0, sizeof(float), stream);
// Return immediately if tensor is empty
if (N == 0) {
return;
}
// Figure out alignment
auto align = CheckAlignment(N, nvec, input);
size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType));
// Figure out CUDA blocks
constexpr size_t threads = amax_kernel_threads;
size_t num_blocks = DIVUP(num_aligned_elements, threads);
constexpr size_t max_blocks = 65535;
num_blocks = std::min(num_blocks, max_blocks);
// Launch kernel
switch (align) {
case Alignment::SAME_ALIGNED:
amax_kernel<nvec, true, InputType>
<<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
break;
case Alignment::SAME_UNALIGNED:
amax_kernel<nvec, false, InputType>
<<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
break;
case Alignment::DIFFERENT: {
// This case is a logic error, since there is only one pointer (input)
// in the alignment check. Still safe to process without vectorization.
amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, N, N);
break;
}
}
// Check results
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace
} // namespace transformer_engine
void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
NVTE_API_CALL(nvte_compute_amax);
using namespace transformer_engine;
// Check input tensor
NVTE_CHECK(input_ != nullptr, "Invalid input tensor (got NULL)");
const auto &input = *reinterpret_cast<const Tensor *>(input_);
NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Input tensor for amax computation must unquantized, "
"but got scaling_mode=",
to_string(input.scaling_mode));
NVTE_CHECK(!is_fp8_dtype(input.data.dtype),
"Input tensor for amax computation must be unquantized, but got dtype=",
to_string(input.data.dtype));
NVTE_CHECK(input.data.dptr != nullptr, "Input tensor for amax computation has no data");
CheckInputTensor(input, "input_compute_amax");
// Check output tensor
NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)");
auto &output = *reinterpret_cast<Tensor *>(output_);
NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Output tensor for amax computation must be FP8 tensor with per-tensor scaling, "
"but got scaling_mode=",
to_string(output.scaling_mode));
NVTE_CHECK(output.amax.numel() == 1,
"Output tensor for amax computation has invalid amax tensor "
"(expected 1 entry, got shape=",
output.amax.shape, ")");
NVTE_CHECK(output.amax.dptr != nullptr,
"Output tensor for amax computation has amax tensor without data");
NVTE_CHECK(output.amax.dtype == DType::kFloat32,
"Output tensor for amax computation has invalid amax tensor "
"(expected FP32, got dtype=",
to_string(output.amax.dtype), ")");
CheckOutputTensor(output, "output_compute_amax");
// Compute amax
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<float *>(output.amax.dptr), input.data.numel(),
stream);); // NOLINT(*)
}
namespace transformer_engine {
namespace {
__global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr,
const float max_fp8, const bool force_pow_2_scales,
const float epsilon) {
float amax = *amax_ptr;
if (amax < epsilon) {
amax = epsilon;
}
float scale = 1.f;
if (isinf(amax) || amax == 0.f) {
*scale_ptr = scale;
return;
}
scale = max_fp8 / amax;
// The amax is so small that the scale becomes infinite in FP32. In other words,
// the scale is not representable in FP32.
if (isinf(scale)) {
// use fp32 max to represent the scale
scale = std::numeric_limits<float>::max();
}
if (isnan(scale)) {
scale = 1.f;
}
if (force_pow_2_scales) {
uint32_t scale_bits = *reinterpret_cast<uint32_t *>(&scale);
scale_bits &= 0xFF800000;
// If the exponent was zero, we have a logic error.
__builtin_assume(scale_bits != 0);
__builtin_assume(scale_bits != 0x80000000);
scale = *reinterpret_cast<float *>(&scale_bits);
}
*scale_ptr = scale;
}
} // namespace
} // namespace transformer_engine
void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConfig config_,
cudaStream_t stream) {
NVTE_API_CALL(nvte_compute_scale_from_amax);
using namespace transformer_engine;
// Check output tensor
NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)");
auto &output = *reinterpret_cast<Tensor *>(output_);
NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Tensor must be FP8 tensor with per-tensor scaling, "
"but got scaling_mode=",
to_string(output.scaling_mode));
NVTE_CHECK(is_fp8_dtype(output.data.dtype),
"Tensor must be FP8, but got dtype=", to_string(output.data.dtype));
NVTE_CHECK(output.amax.numel() == 1,
"Tensor has invalid amax tensor (expected 1 entry, got shape=", output.amax.shape,
")");
NVTE_CHECK(output.amax.dptr != nullptr, "Tensor has amax tensor without data");
NVTE_CHECK(output.amax.dtype == DType::kFloat32,
"Tensor has invalid amax tensor (expected FP32, got dtype=",
to_string(output.amax.dtype), ")");
NVTE_CHECK(output.scale.numel() == 1,
"Tensor has invalid scale tensor (expected 1 entry, got shape=", output.scale.shape,
")");
NVTE_CHECK(output.scale.dptr != nullptr, "Tensor has scale tensor without data");
NVTE_CHECK(output.scale.dtype == DType::kFloat32,
"Tensor has invalid scale tensor (expected FP32, got dtype=",
to_string(output.scale.dtype), ")");
// Check config
NVTE_CHECK(config_ != nullptr, "Invalid config (got NULL)");
const auto &config = *reinterpret_cast<const QuantizationConfig *>(config_);
// Maximum FP8 value
float max_fp8 = 0.f;
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output.data.dtype, DType,
max_fp8 = Quantized_Limits<DType>::max_norm;);
// Update scale
compute_scale_from_amax_kernel<<<1, 1, 0, stream>>>(reinterpret_cast<const float *>(output.amax.dptr),
reinterpret_cast<float *>(output.scale.dptr), max_fp8,
config.force_pow_2_scales, config.amax_epsilon);
NVTE_CHECK_CUDA(cudaGetLastError());
}
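For readability, here is a host-side Python/NumPy rendition of what compute_scale_from_amax_kernel above does (an illustrative sketch only, not part of the library):

import math
import numpy as np

def scale_from_amax(amax: float, max_fp8: float,
                    force_pow_2_scales: bool = False,
                    amax_epsilon: float = 0.0) -> float:
    amax = max(amax, amax_epsilon)  # clamp amax from below, mirroring the epsilon handling
    if math.isinf(amax) or amax == 0.0:
        return 1.0
    scale = np.float32(max_fp8) / np.float32(amax)
    if math.isinf(scale):
        # amax is so small that max_fp8 / amax overflows FP32; fall back to the FP32 max
        scale = np.float32(np.finfo(np.float32).max)
    if math.isnan(scale):
        scale = np.float32(1.0)
    if force_pow_2_scales:
        # keep only the sign and exponent bits, i.e. truncate the scale down to a power of 2
        bits = scale.view(np.uint32) & np.uint32(0xFF800000)
        scale = bits.view(np.float32)
    return float(scale)

# E4M3 has a max-normal value of 448, so amax = 2.0 yields scale = 224
assert scale_from_amax(2.0, 448.0) == 224.0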
......@@ -6,6 +6,7 @@
#include <transformer_engine/transformer_engine.h>
#include <cstring>
#include <iostream>
#include "common.h"
......@@ -150,8 +151,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
const DType type = t.dtype();
if (is_fp8_dtype(type)) {
// FP8 output needs to have scale, scale_inv and (if delayed scaling) amax
if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
NVTE_CHECK(t.amax.dptr != nullptr, "FP8 output ", name, " must have amax tensor");
if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING && t.amax.dptr != nullptr) {
NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Invalid amax dtype (expected ",
to_string(DType::kFloat32), ", got ", to_string(t.amax.dtype), ")");
NVTE_CHECK(product(t.amax.shape) == 1, "Invalid shape of amax in output ", name,
......@@ -410,3 +410,79 @@ void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) {
cudaMemsetAsync(t.amax.dptr, 0, sizeof(float), stream);
}
}
NVTEQuantizationConfig nvte_create_quantization_config() {
return new transformer_engine::QuantizationConfig;
}
void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
NVTEQuantizationConfigAttribute attr, void *buf,
size_t size_in_bytes, size_t *size_written) {
// Write attribute size
NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes,
"Invalid NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)");
const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr];
*size_written = attr_size;
// Return immediately if buffer is not provided
if (buf == nullptr) {
return;
}
// Check buffer size
NVTE_CHECK(size_in_bytes >= attr_size,
"Buffer is too small for quantization config attribute "
"(attribute ",
static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
" bytes)");
// Write to buffer
NVTE_CHECK(config != nullptr, "Invalid NVTEQuantizationConfig (got NULL)");
const auto &config_ = *reinterpret_cast<const transformer_engine::QuantizationConfig *>(config);
switch (attr) {
case kNVTEQuantizationConfigForcePow2Scales:
std::memcpy(buf, &config_.force_pow_2_scales, attr_size);
break;
case kNVTEQuantizationConfigAmaxEpsilon:
std::memcpy(buf, &config_.amax_epsilon, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
}
}
void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
NVTEQuantizationConfigAttribute attr, const void *buf,
size_t size_in_bytes) {
// Check attribute and buffer
NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes,
"Invalid NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr];
NVTE_CHECK(size_in_bytes >= attr_size,
"Buffer is too small for quantization config attribute "
"(attribute ",
static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
" bytes)");
NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)");
// Read from buffer
NVTE_CHECK(config != nullptr, "Invalid NVTEQuantizationConfig (got NULL)");
auto &config_ = *reinterpret_cast<transformer_engine::QuantizationConfig *>(config);
switch (attr) {
case kNVTEQuantizationConfigForcePow2Scales:
std::memcpy(&config_.force_pow_2_scales, buf, attr_size);
break;
case kNVTEQuantizationConfigAmaxEpsilon:
std::memcpy(&config_.amax_epsilon, buf, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
}
}
void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
if (config != nullptr) {
delete reinterpret_cast<transformer_engine::QuantizationConfig *>(config);
}
}
......@@ -249,7 +249,9 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cu
input.dtype(), InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output.dtype(), OutputType,
if (is_delayed_tensor_scaling(output.scaling_mode)) {
if (is_tensor_scaling(output.scaling_mode)) {
// delayed scaling and current scaling are two variants of per-tensor scaling
constexpr const char *itype_name = TypeInfo<InputType>::name;
constexpr const char *otype_name = TypeInfo<OutputType>::name;
constexpr size_t itype_size = sizeof(InputType);
......@@ -323,6 +325,7 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cu
constexpr size_t col_tile_size = store_size / otype_size * THREADS_PER_WARP;
const int num_blocks =
(DIVUP(row_length, row_tile_size) * DIVUP(num_rows, col_tile_size));
cast_transpose_general_kernel<load_size, store_size, InputType, OutputType>
<<<num_blocks, block_size, 0, stream>>>(
static_cast<const InputType *>(input.data.dptr),
......
......@@ -1054,8 +1054,7 @@ void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop,
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->data.dtype, OType,
if (!is_fp8_dtype(output->data.dtype) ||
is_delayed_tensor_scaling(output->scaling_mode)) {
if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) {
constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryKernelLauncher<nvec, ParamOP, UnaryOP>(
reinterpret_cast<const IType *>(input.data.dptr),
......@@ -1079,8 +1078,7 @@ void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *inp
input->data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->data.dtype, OType,
if (!is_fp8_dtype(output->data.dtype) ||
is_delayed_tensor_scaling(output->scaling_mode)) {
if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) {
constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryGradKernelLauncher<nvec, ParamOP, UnaryOP>(
reinterpret_cast<const IType *>(grad.data.dptr),
......@@ -1164,15 +1162,23 @@ template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
void fp8_quantize_arch_l_100(const Tensor &input, const Tensor *act_input, const Tensor *noop,
Tensor *output, Tensor *dbias, Tensor *workspace,
cudaStream_t stream) {
if (!is_delayed_tensor_scaling(output->scaling_mode) || IS_DBIAS) {
NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) +
if (!is_tensor_scaling(output->scaling_mode) || IS_DBIAS) {
// zhongboz: should we just ignore IS_ACT here?
NVTE_ERROR("Not implemented scaling mode or fusion: " + to_string(output->scaling_mode) +
" on GPU with compute capability < 10.0.");
}
switch (output->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
if (!IS_DACT) {
CastVectorizedUnaryKernelLauncher<ParamOP, OP>(input, noop, output, stream);
} else {
CastVectorizedUnaryGradKernelLauncher<ParamOP, OP>(input, act_input, output, stream);
}
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
}
}
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
......
......@@ -844,7 +844,7 @@ __device__ __forceinline__ compute_t reduce_max(const compute_t m, const int war
staging[warpid] = my_warp_max;
}
__syncthreads();
compute_t result = 0;
compute_t result = 0.f;
if (warpid == 0) {
const float my_max = threadIdx.x < num_warps ? staging[threadIdx.x] : 0;
result = warp_reduce_max<num_warps>(my_max);
......
......@@ -24,6 +24,16 @@ TE_DType = {
torch.bfloat16: tex.DType.kBFloat16,
}
TE_DType_To_Torch = {
tex.DType.kByte: torch.uint8,
tex.DType.kFloat8E4M3: torch.float8_e4m3fn,
tex.DType.kFloat8E5M2: torch.float8_e5m2,
tex.DType.kInt32: torch.int32,
tex.DType.kFloat32: torch.float32,
tex.DType.kFloat16: torch.half,
tex.DType.kBFloat16: torch.bfloat16,
}
AttnMaskTypes = (
"no_mask",
"padding",
......
......@@ -46,15 +46,22 @@ transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantizer) {
NVTE_CHECK(!tensor.is_none(), "Tensor is not allocated!");
std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
// check for both quantizer & tensor type:
// mxfp8 tensor -> mxfp8 quantizer
// float8 tensor -> delayed scaling quantizer OR current scaling quantizer
// also during dequantize, the quantizer param is unknown -> so quantizer is NoneQuantizer
for (auto [check_type, check_quantizer_type, create_tensor, _] :
detail::custom_types_converters) {
if (check_type(tensor.ptr())) {
NVTE_CHECK(quantizer.is_none() || check_quantizer_type(quantizer.ptr()),
"Unexpected quantization params type.");
if (!(quantizer.is_none() || check_quantizer_type(quantizer.ptr()))) {
continue;
}
auto x = create_tensor(tensor, my_quantizer.get());
return x;
}
}
NVTE_CHECK(dynamic_cast<NoneQuantizer*>(my_quantizer.get()) != nullptr,
"Unexpected quantization params type.");
// Regular pyTorch tensor
at::Tensor torch_tensor = tensor.cast<at::Tensor>();
......
......@@ -50,6 +50,9 @@
namespace transformer_engine::pytorch {
// in python we have: dist_group_type = torch.distributed.ProcessGroup
using dist_group_type = c10d::ProcessGroup;
// Each tensor here is shape (N, ) holding all scaling
// data for a single FP8 block, e.g. LayerNormLinear
class FP8TensorMeta {
......@@ -136,6 +139,29 @@ class Float8Quantizer : public Quantizer {
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};
class Float8CurrentScalingQuantizer : public Quantizer {
public:
at::Tensor scale;
at::Tensor scale_inv;
at::Tensor amax;
DType dtype;
bool with_amax_reduction;
c10::intrusive_ptr<dist_group_type> amax_reduction_group;
int amax_reduction_size;
bool force_pow_2_scales = false;
float amax_epsilon = 0.0;
explicit Float8CurrentScalingQuantizer(const py::handle& quantizer);
NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; }
void set_quantization_params(TensorWrapper* tensor) const override;
std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};
class MXFP8Quantizer : public Quantizer {
public:
DType dtype;
......
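The C++ Float8CurrentScalingQuantizer declared above mirrors attributes of the Python quantizer object (the constructor that reads them appears further down in this diff). A hypothetical Python stand-in with the same fields, shown only to make the expected interface explicit:

import torch
import transformer_engine_torch as tex

class _Float8CurrentScalingQuantizerFields:
    """Hypothetical mirror of the attributes the C++ binding reads; the real class is
    transformer_engine.pytorch.tensor.float8_tensor.Float8CurrentScalingQuantizer."""

    def __init__(self, dtype=tex.DType.kFloat8E4M3, device="cuda"):
        self.scale = torch.ones(1, dtype=torch.float32, device=device)   # written per tensor at runtime
        self.amax = torch.zeros(1, dtype=torch.float32, device=device)   # scratch buffer for the amax kernel
        self.dtype = dtype
        self.with_amax_reduction = False
        self.amax_reduction_group = None   # torch.distributed.ProcessGroup or None
        self.amax_reduction_size = 1
        self.force_pow_2_scales = False
        self.amax_epsilon = 0.0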
......@@ -4,6 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "common.h"
#include "extensions.h"
#include "pybind.h"
......@@ -24,7 +25,35 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int
auto [te_output, out] =
my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));
// for current scaling, we need to compute amax first and then quantize,
// because the entire tensor cannot stay in cache while computing amax and quantizing in one pass
// the quantizer should not need amax reduction here, so no process group is required
if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// the activation function may change the data range of the input, so we first run the activation,
// then compute the amax and scale of its output, and only then quantize
// get a NoneQuantizer to calculate amax of activation output
auto my_quantizer_none = std::make_unique<NoneQuantizer>(py::none());
auto [te_output_act, out_act] =
my_quantizer_none->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));
act_func(te_input.data(), te_output_act.data(), at::cuda::getCurrentCUDAStream());
// use te_output_act as input to compute the amax of the activated tensor
nvte_compute_amax(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
if (my_quantizer_cs->with_amax_reduction) {
NVTE_ERROR(
"per-tensor current scaling amax reduction is not supported in activation functions.");
}
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream());
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
nvte_quantize(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
} else {
act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
}
return out;
}
......
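In Python terms, the current-scaling branch above amounts to the following (an illustrative sketch with a hypothetical helper name; GELU stands in for whichever activation act_func is, and torch.float8_e4m3fn requires a PyTorch build with FP8 dtypes):

import torch

def gelu_then_quantize_current_scaling(x: torch.Tensor, fp8_max: float = 448.0):
    y = torch.nn.functional.gelu(x)            # run the activation in high precision first
    amax = y.abs().amax().float()              # amax of the *activated* tensor, not of the input
    scale = torch.where(amax > 0, fp8_max / amax, torch.ones_like(amax))
    y_fp8 = (y.float() * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return y_fp8, 1.0 / scale                  # quantized data plus scale_inv for dequantization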
......@@ -45,6 +45,29 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
}
if (te_output.numel() == 0) return out;
if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
nvte_compute_amax(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
// check if we need to do amax reduction (depending on model parallel configs)
if (my_quantizer_cs->with_amax_reduction) {
c10::intrusive_ptr<dist_group_type> process_group_ptr = my_quantizer_cs->amax_reduction_group;
// construct torch tensor from NVTEBasicTensor without reallocating memory
at::Tensor& amax_tensor_torch = my_quantizer_cs->amax;
std::vector<at::Tensor> tensors = {amax_tensor_torch};
// allreduce amax tensor
c10d::AllreduceOptions allreduce_opts;
allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
}
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream());
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
}
nvte_quantize_noop(te_input.data(), te_output.data(), te_noop.data(),
at::cuda::getCurrentCUDAStream());
......
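When with_amax_reduction is set, the local amax is max-reduced over the quantizer's process group before the scale is computed, so every rank derives the same scale and scale_inv. A sketch of that step using the standard torch.distributed API (the group is whatever parallel group the quantizer was configured with):

import torch
import torch.distributed as dist

def reduce_amax_(amax: torch.Tensor, group=None) -> torch.Tensor:
    # In-place MAX all-reduce: afterwards every rank holds the global amax and
    # therefore computes an identical scale from it.
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    return amax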
......@@ -24,6 +24,7 @@ namespace transformer_engine::pytorch {
PyTypeObject *Float8TensorPythonClass = nullptr; /// TODO Remove
PyTypeObject *Float8TensorBasePythonClass = nullptr;
PyTypeObject *Float8QuantizerClass = nullptr;
PyTypeObject *Float8CurrentScalingQuantizerClass = nullptr;
PyTypeObject *MXFP8TensorPythonClass = nullptr; /// TODO Remove
PyTypeObject *MXFP8TensorBasePythonClass = nullptr;
PyTypeObject *MXFP8QuantizerClass = nullptr;
......@@ -33,6 +34,8 @@ void init_float8_extension() {
auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor");
Float8QuantizerClass =
reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "Float8Quantizer"));
Float8CurrentScalingQuantizerClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_module.ptr(), "Float8CurrentScalingQuantizer"));
Float8TensorPythonClass =
reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "Float8Tensor"));
auto fp8_base_module =
......
......@@ -140,6 +140,123 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
return {std::move(tensor), std::move(ret)};
}
Float8CurrentScalingQuantizer::Float8CurrentScalingQuantizer(const py::handle& quantizer)
: Quantizer(quantizer) {
const at::Tensor& scale = quantizer.attr("scale").cast<at::Tensor>();
const at::Tensor& amax = quantizer.attr("amax").cast<at::Tensor>();
const DType type = quantizer.attr("dtype").cast<DType>();
// For current scaling, need several other components:
// 1. with_amax_reduction: bool
// 2. amax_reduction_group: torch.distributed.ProcessGroup or None
// 3. amax_reduction_size: int
const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast<bool>();
const py::object amax_reduction_group_obj = quantizer.attr("amax_reduction_group");
const c10::intrusive_ptr<dist_group_type> amax_reduction_group =
amax_reduction_group_obj.is_none()
? nullptr
: amax_reduction_group_obj.cast<c10::intrusive_ptr<dist_group_type>>();
const int amax_reduction_size = quantizer.attr("amax_reduction_size").cast<int>();
this->amax = amax;
this->scale = scale;
this->dtype = type;
this->with_amax_reduction = with_amax_reduction;
this->amax_reduction_group = amax_reduction_group;
this->amax_reduction_size = amax_reduction_size;
// fp8 current scaling specific quantization params
this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast<bool>();
this->amax_epsilon = quantizer.attr("amax_epsilon").cast<float>();
}
void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tensor) const {
// transfer amax and scale pointers from quantizer to output tensor (only as GPU buffers; they hold no meaningful data yet)
tensor->set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()),
getTensorShape(scale));
at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA);
tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),
getTensorShape(amax));
// quantize output and its transpose
auto rowwise_data = tensor->get_rowwise_data();
rowwise_data.dtype = static_cast<NVTEDType>(dtype);
auto columnwise_data = tensor->get_columnwise_data();
columnwise_data.dtype = static_cast<NVTEDType>(dtype);
tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
rowwise_data.shape);
tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
columnwise_data.shape);
}
std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
using namespace pybind11::literals;
std::vector<int64_t> rowwise_torch_shape;
std::vector<int64_t> columnwise_torch_shape;
std::vector<int64_t> scale_inv_torch_shape = {1}; // Shape of 1 element for scale_inv
if (!shape.empty()) {
columnwise_torch_shape.emplace_back(static_cast<int64_t>(shape.back()));
}
for (size_t i = 0; i < shape.size(); ++i) {
if (i < shape.size() - 1) {
columnwise_torch_shape.emplace_back(static_cast<int64_t>(shape[i]));
}
rowwise_torch_shape.emplace_back(static_cast<int64_t>(shape[i]));
}
at::TensorOptions opts;
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
at::Tensor data;
if (rowwise_usage) {
if (rowwise_data.has_value()) {
data = std::move(*rowwise_data);
} else {
data = at::empty(rowwise_torch_shape, opts);
}
}
const py::object py_data = rowwise_usage ? py::cast(data) : py::none();
at::Tensor columnwise_data;
bool create_transpose = columnwise_usage && !non_tn_fp8_gemm_supported();
if (create_transpose) {
columnwise_data = at::empty(columnwise_torch_shape, opts);
}
const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none();
// unlike delayed scaling, in current scaling the scale is not known ahead of time, so scale_inv starts as an empty buffer
opts = opts.dtype(torch::kFloat32).device(torch::kCUDA);
at::Tensor scale_inv = at::empty(scale_inv_torch_shape, opts);
py::object ret;
if (internal) {
py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorBasePythonClass));
ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data,
"quantizer"_a = this->quantizer);
} else {
py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorPythonClass));
ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype),
"data"_a = py_data, "fp8_scale_inv"_a = scale_inv,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data,
"quantizer"_a = this->quantizer);
}
TensorWrapper tensor(this->get_scaling_mode());
if (rowwise_usage) {
tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape);
tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector<size_t>{1});
}
if (create_transpose) {
std::vector<size_t> transposed_shape;
for (auto s : columnwise_torch_shape) {
transposed_shape.emplace_back(static_cast<size_t>(s));
}
tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape);
tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector<size_t>{1});
}
this->set_quantization_params(&tensor);
return {std::move(tensor), std::move(ret)};
}
MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) {
this->dtype = quantizer.attr("dtype").cast<DType>();
}
......
......@@ -12,7 +12,7 @@ void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool roww
if (input.scaling_mode() == NVTE_INVALID_SCALING) {
NVTE_ERROR("Invalid scaling mode for swizzle.");
} else if (input.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
} else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING) {
return;
}
......
......@@ -23,7 +23,8 @@ TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer
if (transpose_valid) {
transpose = tensor.attr("_transpose").cast<std::optional<at::Tensor>>();
}
// In the case of being called under tex.dequantize, the quantizer will be NoneQuantizer
// whose scaling mode is defaulted to NVTE_DELAYED_TENSOR_SCALING
auto ret = TensorWrapper(quantizer->get_scaling_mode());
ret.set_rowwise_data(data.data_ptr(), dtype, shape);
......
......@@ -21,6 +21,7 @@ namespace transformer_engine::pytorch {
extern PyTypeObject *Float8TensorPythonClass;
extern PyTypeObject *Float8TensorBasePythonClass;
extern PyTypeObject *Float8QuantizerClass;
extern PyTypeObject *Float8CurrentScalingQuantizerClass;
extern PyTypeObject *MXFP8TensorPythonClass;
extern PyTypeObject *MXFP8TensorBasePythonClass;
extern PyTypeObject *MXFP8QuantizerClass;
......@@ -33,13 +34,17 @@ void init_mxfp8_extension();
namespace detail {
inline bool IsFloat8QParams(PyObject *obj) { return Py_TYPE(obj) == Float8QuantizerClass; }
inline bool IsFloat8Quantizers(PyObject *obj) { return Py_TYPE(obj) == Float8QuantizerClass; }
inline bool IsFloat8CurrentScalingQuantizers(PyObject *obj) {
return Py_TYPE(obj) == Float8CurrentScalingQuantizerClass;
}
inline bool IsFloat8Tensor(PyObject *obj) {
return Py_TYPE(obj) == Float8TensorPythonClass || Py_TYPE(obj) == Float8TensorBasePythonClass;
}
inline bool IsMXFP8QParams(PyObject *obj) { return Py_TYPE(obj) == MXFP8QuantizerClass; }
inline bool IsMXFP8Quantizers(PyObject *obj) { return Py_TYPE(obj) == MXFP8QuantizerClass; }
inline bool IsMXFP8Tensor(PyObject *obj) {
return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorBasePythonClass;
......@@ -61,9 +66,11 @@ inline bool IsFloatingPointType(at::ScalarType type) {
}
constexpr std::array custom_types_converters = {
std::make_tuple(IsFloat8Tensor, IsFloat8QParams, NVTETensorFromFloat8Tensor,
std::make_tuple(IsFloat8Tensor, IsFloat8Quantizers, NVTETensorFromFloat8Tensor,
CreateQuantizer<Float8Quantizer>),
std::make_tuple(IsMXFP8Tensor, IsMXFP8QParams, NVTETensorFromMXFP8Tensor,
std::make_tuple(IsFloat8Tensor, IsFloat8CurrentScalingQuantizers, NVTETensorFromFloat8Tensor,
CreateQuantizer<Float8CurrentScalingQuantizer>),
std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromMXFP8Tensor,
CreateQuantizer<MXFP8Quantizer>)};
} // namespace detail
......
......@@ -21,7 +21,7 @@ from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_module
from .utils import safely_set_viewless_tensor_data
from .constants import dist_group_type
from .fp8 import FP8GlobalStateManager
from .tensor.float8_tensor import Float8Quantizer, Float8Tensor
from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
from .tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor
from .tensor.quantized_tensor import QuantizedTensor, Quantizer
from .tensor._internal.float8_tensor_base import Float8TensorBase
......@@ -859,7 +859,10 @@ def _all_gather_fp8(
# Quantize input tensor if needed
if not isinstance(input_, Float8TensorBase):
assert isinstance(quantizer, Float8Quantizer)
assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer))
# we cannot directly gather the transposed fp8 tensor
# so we need to disable columnwise usage for the quantizer
# and then set it back to the original value after quantizing
init_columnwise_usage = quantizer.columnwise_usage
quantizer.set_usage(columnwise=False)
input_ = quantizer(input_)
......@@ -867,7 +870,7 @@ def _all_gather_fp8(
# Construct output tensor
out: Float8TensorBase
if isinstance(quantizer, Float8Quantizer):
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
dtype = torch.float32
device = "cuda"
if isinstance(input_, Float8Tensor):
......@@ -885,6 +888,9 @@ def _all_gather_fp8(
out._transpose_invalid = True
else:
raise RuntimeError("FP8TensorBase is not supported yet without Quantizer")
# For delayed scaling, scale_inv comes from the amax history, so we can pass it from input_ to out.
# For current scaling, scale_inv comes from the amax reduction done in C++, so every rank holds the
# same scale_inv and we can likewise pass it from input_ to out.
out._scale_inv = input_._scale_inv
# Perform communication
......@@ -999,8 +1005,10 @@ def gather_along_first_dim(
out_shape = list(input_.size())
out_shape[0] *= world_size
# FP8 case
if isinstance(input_, Float8TensorBase) or isinstance(quantizer, Float8Quantizer):
# FP8 case: delayed scaling or current scaling
if isinstance(input_, Float8TensorBase) or isinstance(
quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
return _all_gather_fp8(
input_,
process_group,
......
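The gather path above boils down to gathering only the row-wise FP8 data and reusing the (already identical) scale_inv on every rank. An illustrative sketch of that pattern, not the library's implementation, which goes through the Float8Tensor machinery:

import torch
import torch.distributed as dist

def all_gather_fp8_rowwise(x_fp8_local: torch.Tensor, scale_inv: torch.Tensor, group=None):
    world_size = dist.get_world_size(group)
    local_bytes = x_fp8_local.contiguous().view(torch.uint8)   # gather the raw FP8 bytes
    out_bytes = torch.empty((world_size * local_bytes.shape[0], *local_bytes.shape[1:]),
                            dtype=torch.uint8, device=local_bytes.device)
    dist.all_gather_into_tensor(out_bytes, local_bytes, group=group)
    # every rank already holds the same scale_inv (the amax was max-reduced before the
    # scale was computed), so the gathered tensor can reuse it unchanged
    return out_bytes.view(x_fp8_local.dtype), scale_inv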
......@@ -13,7 +13,13 @@ from typing import Callable, List, Optional, Dict, Any, Tuple, Union
import torch
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe, DelayedScaling, Format, MXFP8BlockScaling
from transformer_engine.common.recipe import (
Recipe,
DelayedScaling,
Format,
MXFP8BlockScaling,
Float8CurrentScaling,
)
from .constants import dist_group_type
from .utils import get_device_compute_capability
......@@ -198,6 +204,8 @@ class FP8GlobalStateManager:
fp8_meta: Dict[str, Any],
) -> None:
"""
Delayed scaling only.
The amax reduction process happens completely outside the FP8 modules.
To participate in the reduction, the only role played by a module is
to call this function in order to append its FP8 tensor into a global
......@@ -211,7 +219,8 @@ class FP8GlobalStateManager:
wrapper. For non CG case, it's called from within the module.
"""
if fp8_meta["recipe"].mxfp8():
# delayed scaling only function, noop for any other recipe
if not fp8_meta["recipe"].delayed():
return
# Every module must call this function exactly once since
......@@ -326,7 +335,8 @@ class FP8GlobalStateManager:
cls,
forward: bool = True,
) -> None:
"""Concatenate, reduce, and split amaxes in the global buffer."""
"""Delayed scaling only. Concatenate, reduce, and split amaxes in the global buffer."""
# global_amax_buffer should only be non-empty for fp8 delayed scaling
for buffer_key, amax_buffer in cls.global_amax_buffer.items():
# Check for forward or backward reduction.
fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key)
......@@ -426,6 +436,8 @@ class FP8GlobalStateManager:
# FP8 weight modules are reduced at the end of the optimizer
# step after the weight amax is populated.
if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
# delayed scaling only function; for other recipes (e.g. current scaling at any granularity)
# this is a no-op because cls.global_amax_buffer is an empty list
cls.reduce_and_update_fp8_tensors(forward=True)
@classmethod
......@@ -434,7 +446,8 @@ class FP8GlobalStateManager:
to ensure both forward steps are numerically same.
"""
if fp8_meta["recipe"].mxfp8():
# delayed scaling only function, noop for any other recipe
if not fp8_meta["recipe"].delayed():
return
buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
......@@ -459,8 +472,8 @@ class FP8GlobalStateManager:
"""Switch to the copied scaling factors and amaxes from phase
1 forward for identical numerical outputs.
"""
if fp8_meta["recipe"].mxfp8():
# delayed scaling only function, noop for any other recipe
if not fp8_meta["recipe"].delayed():
return
# Store updated amaxes and scales from phase 1 post forward.
......@@ -478,8 +491,8 @@ class FP8GlobalStateManager:
@staticmethod
def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
"""Restore latest scaling factors and amaxes after recompute forward run."""
if fp8_meta["recipe"].mxfp8():
# delayed scaling only function, noop for any other recipe
if not fp8_meta["recipe"].delayed():
return
fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"])
......@@ -743,6 +756,8 @@ class RecipeState(abc.ABC):
cls = DelayedScalingRecipeState
elif recipe.mxfp8():
cls = MXFP8BlockScalingRecipeState
elif recipe.float8_current_scaling():
cls = Float8CurrentScalingRecipeState
else:
raise ValueError("{recipe.__class__.__name__} is not supported")
return cls(
......@@ -813,6 +828,45 @@ class DelayedScalingRecipeState(RecipeState):
]
class Float8CurrentScalingRecipeState(RecipeState):
"""Configuration for Per-tensor current scaling quantization.
Per-tensor current quantization does not require state.
"""
recipe: Float8CurrentScaling
mode: str
dtype: tex.DType
device: torch.device
def __init__(
self,
recipe: Float8CurrentScaling,
*,
mode: str,
num_quantizers: int = 1,
device: Optional[torch.device] = None,
) -> None:
self.recipe = recipe
self.mode = mode
self.num_quantizers = num_quantizers
self.dtype = get_fp8_te_dtype(recipe, mode == "forward")
# Allocate buffers
if device is None:
device = torch.device("cuda")
self.device = device
def make_quantizers(self) -> list:
from .tensor.float8_tensor import Float8CurrentScalingQuantizer
return [
Float8CurrentScalingQuantizer(self.dtype, device=self.device)
for i in range(self.num_quantizers)
]
class MXFP8BlockScalingRecipeState(RecipeState):
"""Configuration for MXFP8 quantization.
......
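Tying the Python pieces together, a current-scaling quantizer can also be constructed and applied directly. A short sketch following the constructor call in make_quantizers above (shapes here are placeholders; quantizers are callable, as used in _all_gather_fp8 earlier in this diff):

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer

quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda"))
x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
x_fp8 = quantizer(x)   # computes amax, derives the scale, and casts to FP8 in one call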
......@@ -21,6 +21,7 @@ from ._common import _ParameterInitMeta
from ..fp8 import (
MXFP8BlockScalingRecipeState,
DelayedScalingRecipeState,
Float8CurrentScalingRecipeState,
FP8GlobalStateManager,
RecipeState,
)
......@@ -34,6 +35,7 @@ from ..constants import dist_group_type
from ..tensor import QuantizedTensor, Quantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
__all__ = ["initialize_ub", "destroy_ub"]
......@@ -430,7 +432,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
super().__setattr__(name, value)
def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
"""Increase or decrease size of amax history based on given `length`.
"""
Delayed scaling only.
Increase or decrease size of amax history based on given `length`.
.. warning::
This changes the underlying amax memory location.
......@@ -489,6 +494,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
return
if recipe.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState):
return
if recipe.float8_current_scaling() and isinstance(
recipe_state, Float8CurrentScalingRecipeState
):
return
# Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
# 2 (grad_output and grad_input) for bwd
......@@ -851,6 +860,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if ctx.use_bias:
if isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)):
grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0)
elif isinstance(quantizer, Float8CurrentScalingQuantizer):
# FP8 current scaling does not support fused cast + dbias
grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
else:
grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer)
if not isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)):
......
......@@ -88,6 +88,9 @@ class _GroupedLinear(torch.autograd.Function):
# TODO Support MXFP8 # pylint: disable=fixme
if fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8():
raise NotImplementedError("GroupedLinear does not yet support MXFP8")
# TODO Support Float8 Current Scaling # pylint: disable=fixme
if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling():
raise NotImplementedError("GroupedLinear does not yet support Float8 Current Scaling")
# Make sure input dimensions are compatible
in_features = weights[0].shape[-1]
......