Unverified commit 77fa1e59, authored by Zhongbo Zhu and committed by GitHub

[PyTorch] Enabling Per-Tensor Current Scaling Recipe (#1471)

* check in per-tensor current scaling full recipe
Signed-off-by: zhongboz <zhongboz@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: zhongboz <zhongboz@nvidia.com>

set up basics of current scaling quantizer at the Python level
Signed-off-by: zhongboz <zhongboz@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: zhongboz <zhongboz@nvidia.com>

add test case for current scaling dequantize
Signed-off-by: zhongboz <zhongboz@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: zhongboz <zhongboz@nvidia.com>

finish linear layer fwd/bwd test, determined error with bf16
Signed-off-by: zhongboz <zhongboz@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: zhongboz <zhongboz@nvidia.com>

achieved zero tolerance for Linear by specifying the gemm use_split_accumulator config
Signed-off-by: zhongboz <zhongboz@nvidia.com>

enable layernormlinear with current scaling, pass bitwise test
Signed-off-by: zhongboz <zhongboz@nvidia.com>

refactor test case code
Signed-off-by: zhongboz <zhongboz@nvidia.com>

make current scaling quantizers distributed, pass distributed linear & layernormlinear tests
Signed-off-by: zhongboz <zhongboz@nvidia.com>

bug fix: use cached fp8 recipe in backward
Signed-off-by: zhongboz <zhongboz@nvidia.com>

fix layernorm_mlp with current scaling, fix activation_helper with current scaling
Signed-off-by: zhongboz <zhongboz@nvidia.com>

support detailed numerical settings from recipe to quantization kernel
Signed-off-by: zhongboz <zhongboz@nvidia.com>

resolving MR comments
Signed-off-by: zhongboz <zhongboz@nvidia.com>

recipe naming
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* resolve MR comments, remove IS_CURRENT_SCALING template from kernels
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* resolve MR comments, make current scaling C++ test cases
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* add current scaling to test_numerics.py, skip act recomp and grouped linear
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add benchmark for quantizer
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add benchmarks for linear layer
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* bug fix, typo
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* resolve more MR comments
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* avoid potential race condition by not using from_blob to construct amax tensor in C++
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* resolve more comments
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Debug linter warnings and license check
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Debug import error in FP8 tensor test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug compilation error with CUDA 12.1 for Turing
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* resolve MR comments, fix activation cast fusion
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* resolve comments, add NVTEQuantizationParams for compute scale
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove is_current_scaling check entirely from common folder
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* remove benchmarks, will contribute in another repo
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* adjust cs default recipe config
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* adjust comments in test
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Remove current scaling mode from core lib
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Refactor current-scaling-specific logic in core C++ lib

Move amax and scale update functions out of the casting functions and put them into a dedicated current-scaling source file. Add a general API for accessing the quantization config object.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add missing header in C++ tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable test config with FP8 transpose on Blackwell
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix compilation error in C++ test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: zhongboz <zhongboz@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent 2a95efd3
@@ -39,6 +39,27 @@ class Format(Enum):
    HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd)
@dataclass(frozen=True)
class MMParams:
    """Matrix multiplication (GEMM) parameters.

    use_split_accumulator: whether to use the split accumulator in the FP8 GEMM.
    Turning it on increases accuracy but impacts GEMM performance, so it is only
    enabled for certain GEMMs. (In PyTorch, for example, `_scaled_mm` uses
    `use_fast_accum = (not use_split_accumulator)`.)
    """

    use_split_accumulator: bool = True


@dataclass(frozen=True)
class QParams:
    """Quantization parameters.

    power_2_scale: use power of 2 scale parameter
    amax_epsilon: optional minimum value of abs max
    """

    power_2_scale: bool = False
    amax_epsilon: float = 0.0
class Recipe:
    """
    Base recipe class.
@@ -52,6 +73,10 @@ class Recipe:
        """Whether the given recipe is delayed scaling."""
        return isinstance(self, DelayedScaling)
    def float8_current_scaling(self):
        """Whether the given recipe is (per-tensor) current scaling."""
        return isinstance(self, Float8CurrentScaling)
@dataclass()
class DelayedScaling(Recipe):
@@ -161,6 +186,75 @@ class DelayedScaling(Recipe):
        )
@dataclass()
class Float8CurrentScaling(Recipe):
    """
    Use the per-tensor current scaling factor strategy.

    Parameters
    ----------
    fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID
        Controls the FP8 data format used during the forward and backward
        passes.
    fp8_quant_fwd_inp: QParams, default QParams(power_2_scale=False, amax_epsilon=0.0)
        Quantization parameters for the input tensor x.
    fp8_quant_fwd_weight: QParams, default QParams(power_2_scale=False, amax_epsilon=0.0)
        Quantization parameters for the weight tensor w.
    fp8_quant_bwd_grad: QParams, default QParams(power_2_scale=False, amax_epsilon=0.0)
        Quantization parameters for the gradient tensor dY.
    fp8_gemm_fprop: MMParams, default MMParams.use_split_accumulator=False
        Used for calculating the output y in the forward pass.
    fp8_gemm_dgrad: MMParams, default MMParams.use_split_accumulator=True
        Used for calculating dgrad in the backward pass.
    fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True
        Used for calculating wgrad in the backward pass.
    fp8_dpa: bool, default = `False`
        Whether to enable FP8 dot product attention (DPA). When the model is placed in an
        `fp8_autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the
        inputs from higher precision to FP8, performs attention in FP8, and casts tensors
        back to higher precision as outputs. FP8 DPA currently is only supported in the
        `FusedAttention` backend.
    fp8_mha: bool, default = `False`
        Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting
        operations mentioned above at the DPA boundaries. Currently only standard MHA modules,
        i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When
        `fp8_mha = False, fp8_dpa = True`, a typical MHA module works as
        `LayerNormLinear (BF16 output) -> (cast to FP8) FP8 DPA (cast to BF16) -> Linear`.
        When `fp8_mha = True, fp8_dpa = True`, it becomes
        `LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`.

    Notes
    -----
    * `fp8_dpa` and `fp8_mha` are Beta features, and their API and functionality are
      subject to change in future Transformer Engine releases.
    """

    fp8_format: Format = Format.HYBRID
    fp8_quant_fwd_inp = QParams(power_2_scale=False, amax_epsilon=0.0)
    fp8_quant_fwd_weight = QParams(power_2_scale=False, amax_epsilon=0.0)
    fp8_quant_bwd_grad = QParams(power_2_scale=False, amax_epsilon=0.0)
    fp8_gemm_fprop: MMParams = MMParams(use_split_accumulator=False)
    fp8_gemm_dgrad: MMParams = MMParams(use_split_accumulator=True)
    fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True)
    fp8_dpa: bool = False
    fp8_mha: bool = False

    def __post_init__(self) -> None:
        assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."

    def __repr__(self) -> str:
        return (
            f"format={str(self.fp8_format).split('.')[1]}, "
            f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, "
            f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, "
            f"fp8_quant_bwd_grad={self.fp8_quant_bwd_grad}, "
            f"fp8_gemm_fprop={self.fp8_gemm_fprop}, "
            f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, "
            f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, "
            f"fp8_dpa={self.fp8_dpa}, "
            f"fp8_mha={self.fp8_mha}"
        )
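For reference, a minimal usage sketch of the new recipe. This is not part of the diff: it assumes the existing `fp8_autocast` and `te.Linear` entry points from Transformer Engine, and the module/tensor shapes are purely illustrative.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling, Format, MMParams

# Per-tensor current scaling: the amax is recomputed from the current tensor on every
# quantization, so no amax history / delayed-scaling state is carried between steps.
recipe = Float8CurrentScaling(
    fp8_format=Format.HYBRID,                             # E4M3 forward, E5M2 backward
    fp8_gemm_fprop=MMParams(use_split_accumulator=True),  # trade some GEMM speed for accuracy
)

model = te.Linear(1024, 1024, bias=True, params_dtype=torch.bfloat16).cuda()
x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = model(x)
y.sum().backward()
```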
@dataclass()
class MXFP8BlockScaling(Recipe):
    """
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/recipe.h>
#include <algorithm>
#include <limits>
#include <type_traits>
#include "../common.h"
#include "../util/logging.h"
#include "../util/vectorized_pointwise.h"
namespace transformer_engine {
namespace {
constexpr int amax_kernel_threads = 512;
template <int nvec, bool aligned, typename InputType>
__launch_bounds__(amax_kernel_threads) __global__
void amax_kernel(const InputType *input, float *amax, const size_t N,
const size_t num_aligned_elements) {
VectorizedLoader<InputType, nvec, aligned> loader(input, N);
InputType max = 0.f;
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const size_t M = num_aligned_elements;
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) {
loader.load(tid, N);
#pragma unroll
for (int i = 0; i < nvec; ++i) {
const InputType val = static_cast<InputType>(loader.separate()[i]);
__builtin_assume(max >= InputType{0.f});
if constexpr (std::is_same_v<InputType, __nv_bfloat16>) {
#if __CUDA_ARCH__ >= 800
max = __hmax(__habs(val), max);
#else // Turing
max = static_cast<__nv_bfloat16>(
fmaxf(fabsf(static_cast<float>(val)), static_cast<float>(max)));
#endif
} else if constexpr (std::is_same_v<InputType, __half>) {
max = __hmax(__habs(val), max);
} else {
max = fmaxf(fabsf(val), max);
}
}
}
// Reduce amax over block
max = reduce_max<amax_kernel_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
atomicMaxFloat(amax, max);
}
}
template <int nvec, typename InputType>
void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
// Zero out amax so we can update with atomic max
cudaMemsetAsync(amax, 0, sizeof(float), stream);
// Return immediately if tensor is empty
if (N == 0) {
return;
}
// Figure out alignment
auto align = CheckAlignment(N, nvec, input);
size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType));
// Figure out CUDA blocks
constexpr size_t threads = amax_kernel_threads;
size_t num_blocks = DIVUP(num_aligned_elements, threads);
constexpr size_t max_blocks = 65535;
num_blocks = std::min(num_blocks, max_blocks);
// Launch kernel
switch (align) {
case Alignment::SAME_ALIGNED:
amax_kernel<nvec, true, InputType>
<<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
break;
case Alignment::SAME_UNALIGNED:
amax_kernel<nvec, false, InputType>
<<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
break;
case Alignment::DIFFERENT: {
// This case is a logic error, since there is only one pointer (input)
// in the alignment check. Still safe to process without vectorization.
amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, N, N);
break;
}
}
// Check results
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace
} // namespace transformer_engine
void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
NVTE_API_CALL(nvte_compute_amax);
using namespace transformer_engine;
// Check input tensor
NVTE_CHECK(input_ != nullptr, "Invalid input tensor (got NULL)");
const auto &input = *reinterpret_cast<const Tensor *>(input_);
NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Input tensor for amax computation must unquantized, "
"but got scaling_mode=",
to_string(input.scaling_mode));
NVTE_CHECK(!is_fp8_dtype(input.data.dtype),
"Input tensor for amax computation must be unquantized, but got dtype=",
to_string(input.data.dtype));
NVTE_CHECK(input.data.dptr != nullptr, "Input tensor for amax computation has no data");
CheckInputTensor(input, "input_compute_amax");
// Check output tensor
NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)");
auto &output = *reinterpret_cast<Tensor *>(output_);
NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Output tensor for amax computation must be FP8 tensor with per-tensor scaling, "
"but got scaling_mode=",
to_string(output.scaling_mode));
NVTE_CHECK(output.amax.numel() == 1,
"Output tensor for amax computation has invalid amax tensor "
"(expected 1 entry, got shape=",
output.amax.shape, ")");
NVTE_CHECK(output.amax.dptr != nullptr,
"Output tensor for amax computation has amax tensor without data");
NVTE_CHECK(output.amax.dtype == DType::kFloat32,
"Output tensor for amax computation has invalid amax tensor "
"(expected FP32, got dtype=",
to_string(output.amax.dtype), ")");
CheckOutputTensor(output, "output_compute_amax");
// Compute amax
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<float *>(output.amax.dptr), input.data.numel(),
stream);); // NOLINT(*)
}
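As a numerical reference, the amax produced by this kernel is simply the maximum absolute value of the input, stored as a single FP32 scalar. A hedged PyTorch restatement (ignoring the vectorized loads, warp reduction, and atomics) might look like:

```python
import torch

def reference_amax(x: torch.Tensor) -> torch.Tensor:
    # Reference for nvte_compute_amax: one FP32 element holding max(|x|).
    # An empty tensor yields amax = 0, matching the kernel's cudaMemsetAsync path.
    if x.numel() == 0:
        return torch.zeros(1, dtype=torch.float32, device=x.device)
    return x.detach().abs().amax().to(torch.float32).reshape(1)
```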
namespace transformer_engine {
namespace {
__global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr,
const float max_fp8, const bool force_pow_2_scales,
const float epsilon) {
float amax = *amax_ptr;
if (amax < epsilon) {
amax = epsilon;
}
float scale = 1.f;
if (isinf(amax) || amax == 0.f) {
*scale_ptr = scale;
return;
}
scale = max_fp8 / amax;
  // The amax is so small that the scale becomes infinite in FP32; in other words,
  // the scale is not representable in FP32.
if (isinf(scale)) {
// use fp32 max to represent the scale
scale = std::numeric_limits<float>::max();
}
if (isnan(scale)) {
scale = 1.f;
}
if (force_pow_2_scales) {
uint32_t scale_bits = *reinterpret_cast<uint32_t *>(&scale);
scale_bits &= 0xFF800000;
// If the exponent was zero, we have a logic error.
__builtin_assume(scale_bits != 0);
__builtin_assume(scale_bits != 0x80000000);
scale = *reinterpret_cast<float *>(&scale_bits);
}
*scale_ptr = scale;
}
} // namespace
} // namespace transformer_engine
void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConfig config_,
cudaStream_t stream) {
NVTE_API_CALL(nvte_compute_scale_from_amax);
using namespace transformer_engine;
// Check output tensor
NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)");
auto &output = *reinterpret_cast<Tensor *>(output_);
NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Tensor must be FP8 tensor with per-tensor scaling, "
"but got scaling_mode=",
to_string(output.scaling_mode));
NVTE_CHECK(is_fp8_dtype(output.data.dtype),
"Tensor must be FP8, but got dtype=", to_string(output.data.dtype));
NVTE_CHECK(output.amax.numel() == 1,
"Tensor has invalid amax tensor (expected 1 entry, got shape=", output.amax.shape,
")");
NVTE_CHECK(output.amax.dptr != nullptr, "Tensor has amax tensor without data");
NVTE_CHECK(output.amax.dtype == DType::kFloat32,
"Tensor has invalid amax tensor (expected FP32, got dtype=",
to_string(output.amax.dtype), ")");
NVTE_CHECK(output.scale.numel() == 1,
"Tensor has invalid scale tensor (expected 1 entry, got shape=", output.scale.shape,
")");
NVTE_CHECK(output.scale.dptr != nullptr, "Tensor has scale tensor without data");
NVTE_CHECK(output.scale.dtype == DType::kFloat32,
"Tensor has invalid scale tensor (expected FP32, got dtype=",
to_string(output.scale.dtype), ")");
// Check config
NVTE_CHECK(config_ != nullptr, "Invalid config (got NULL)");
const auto &config = *reinterpret_cast<const QuantizationConfig *>(config_);
// Maximum FP8 value
float max_fp8 = 0.f;
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output.data.dtype, DType,
max_fp8 = Quantized_Limits<DType>::max_norm;);
// Update scale
  compute_scale_from_amax_kernel<<<1, 1, 0, stream>>>(
      reinterpret_cast<const float *>(output.amax.dptr),
      reinterpret_cast<float *>(output.scale.dptr), max_fp8,
      config.force_pow_2_scales, config.amax_epsilon);
NVTE_CHECK_CUDA(cudaGetLastError());
}
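The scale update mirrors `compute_scale_from_amax_kernel` above. A rough Python restatement of the same logic (epsilon clamp, inf/zero/NaN handling, optional power-of-two rounding); note that Python floats are doubles, so the FP32-overflow branch here is only illustrative, and the power-of-two step approximates the exponent-bit masking for normal positive values:

```python
import math

FP32_MAX = 3.4028234663852886e38  # numeric_limits<float>::max()

def scale_from_amax(amax: float, max_fp8: float,
                    force_pow_2_scales: bool = False, epsilon: float = 0.0) -> float:
    amax = max(amax, epsilon)
    if math.isinf(amax) or amax == 0.0:
        return 1.0
    scale = max_fp8 / amax
    if math.isinf(scale):      # amax so small that the FP32 scale overflows: clamp to FP32 max
        scale = FP32_MAX
    if math.isnan(scale):
        scale = 1.0
    if force_pow_2_scales:     # keep only the exponent, i.e. round the magnitude down to a power of two
        scale = 2.0 ** math.floor(math.log2(scale))
    return scale
```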
@@ -6,6 +6,7 @@
 #include <transformer_engine/transformer_engine.h>
+#include <cstring>
 #include <iostream>
 #include "common.h"
@@ -150,8 +151,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
   const DType type = t.dtype();
   if (is_fp8_dtype(type)) {
     // FP8 output needs to have scale, scale_inv and (if delayed scaling) amax
-    if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
-      NVTE_CHECK(t.amax.dptr != nullptr, "FP8 output ", name, " must have amax tensor");
+    if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING && t.amax.dptr != nullptr) {
       NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Invalid amax dtype (expected ",
                  to_string(DType::kFloat32), ", got ", to_string(t.amax.dtype), ")");
       NVTE_CHECK(product(t.amax.shape) == 1, "Invalid shape of amax in output ", name,
@@ -410,3 +410,79 @@ void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) {
     cudaMemsetAsync(t.amax.dptr, 0, sizeof(float), stream);
   }
 }
NVTEQuantizationConfig nvte_create_quantization_config() {
return new transformer_engine::QuantizationConfig;
}
void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
NVTEQuantizationConfigAttribute attr, void *buf,
size_t size_in_bytes, size_t *size_written) {
// Write attribute size
NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes,
"Invalid NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)");
const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr];
*size_written = attr_size;
// Return immediately if buffer is not provided
if (buf == nullptr) {
return;
}
// Check buffer size
NVTE_CHECK(size_in_bytes >= attr_size,
"Buffer is too small for quantization config attribute "
"(attribute ",
static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
" bytes)");
// Write to buffer
NVTE_CHECK(config != nullptr, "Invalid NVTEQuantizationConfig (got NULL)");
const auto &config_ = *reinterpret_cast<const transformer_engine::QuantizationConfig *>(config);
switch (attr) {
case kNVTEQuantizationConfigForcePow2Scales:
std::memcpy(buf, &config_.force_pow_2_scales, attr_size);
break;
case kNVTEQuantizationConfigAmaxEpsilon:
std::memcpy(buf, &config_.amax_epsilon, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
}
}
void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
NVTEQuantizationConfigAttribute attr, const void *buf,
size_t size_in_bytes) {
// Check attribute and buffer
NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes,
"Invalid NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr];
NVTE_CHECK(size_in_bytes >= attr_size,
"Buffer is too small for quantization config attribute "
"(attribute ",
static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
" bytes)");
NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)");
// Read from buffer
NVTE_CHECK(config != nullptr, "Invalid NVTEQuantizationConfig (got NULL)");
auto &config_ = *reinterpret_cast<transformer_engine::QuantizationConfig *>(config);
switch (attr) {
case kNVTEQuantizationConfigForcePow2Scales:
std::memcpy(&config_.force_pow_2_scales, buf, attr_size);
break;
case kNVTEQuantizationConfigAmaxEpsilon:
std::memcpy(&config_.amax_epsilon, buf, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
}
}
void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
if (config != nullptr) {
delete reinterpret_cast<transformer_engine::QuantizationConfig *>(config);
}
}
@@ -249,7 +249,9 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cu
      input.dtype(), InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output.dtype(), OutputType,
-          if (is_delayed_tensor_scaling(output.scaling_mode)) {
+          if (is_tensor_scaling(output.scaling_mode)) {
+            // delayed scaling and current scaling are two variants of per-tensor scaling
            constexpr const char *itype_name = TypeInfo<InputType>::name;
            constexpr const char *otype_name = TypeInfo<OutputType>::name;
            constexpr size_t itype_size = sizeof(InputType);
@@ -323,6 +325,7 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cu
            constexpr size_t col_tile_size = store_size / otype_size * THREADS_PER_WARP;
            const int num_blocks =
                (DIVUP(row_length, row_tile_size) * DIVUP(num_rows, col_tile_size));

            cast_transpose_general_kernel<load_size, store_size, InputType, OutputType>
                <<<num_blocks, block_size, 0, stream>>>(
                    static_cast<const InputType *>(input.data.dptr),
......
@@ -1054,8 +1054,7 @@ void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop,
      input.data.dtype, IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output->data.dtype, OType,
-          if (!is_fp8_dtype(output->data.dtype) ||
-              is_delayed_tensor_scaling(output->scaling_mode)) {
+          if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) {
            constexpr int nvec = 32 / sizeof(IType);
            VectorizedUnaryKernelLauncher<nvec, ParamOP, UnaryOP>(
                reinterpret_cast<const IType *>(input.data.dptr),
@@ -1079,8 +1078,7 @@ void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *inp
      input->data.dtype, IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output->data.dtype, OType,
-          if (!is_fp8_dtype(output->data.dtype) ||
-              is_delayed_tensor_scaling(output->scaling_mode)) {
+          if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) {
            constexpr int nvec = 32 / sizeof(IType);
            VectorizedUnaryGradKernelLauncher<nvec, ParamOP, UnaryOP>(
                reinterpret_cast<const IType *>(grad.data.dptr),
@@ -1164,14 +1162,22 @@ template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
void fp8_quantize_arch_l_100(const Tensor &input, const Tensor *act_input, const Tensor *noop,
                             Tensor *output, Tensor *dbias, Tensor *workspace,
                             cudaStream_t stream) {
-  if (!is_delayed_tensor_scaling(output->scaling_mode) || IS_DBIAS) {
-    NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) +
+  if (!is_tensor_scaling(output->scaling_mode) || IS_DBIAS) {
+    // zhongboz: should we just ignore IS_ACT here?
+    NVTE_ERROR("Not implemented scaling mode or fusion: " + to_string(output->scaling_mode) +
               " on GPU with compute capability < 10.0.");
  }
-  if (!IS_DACT) {
-    CastVectorizedUnaryKernelLauncher<ParamOP, OP>(input, noop, output, stream);
-  } else {
-    CastVectorizedUnaryGradKernelLauncher<ParamOP, OP>(input, act_input, output, stream);
+  switch (output->scaling_mode) {
+    case NVTE_DELAYED_TENSOR_SCALING: {
+      if (!IS_DACT) {
+        CastVectorizedUnaryKernelLauncher<ParamOP, OP>(input, noop, output, stream);
+      } else {
+        CastVectorizedUnaryGradKernelLauncher<ParamOP, OP>(input, act_input, output, stream);
+      }
+      break;
+    }
+    default:
+      NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
  }
}
......
@@ -844,7 +844,7 @@ __device__ __forceinline__ compute_t reduce_max(const compute_t m, const int war
    staging[warpid] = my_warp_max;
  }
  __syncthreads();
-  compute_t result = 0;
+  compute_t result = 0.f;
  if (warpid == 0) {
    const float my_max = threadIdx.x < num_warps ? staging[threadIdx.x] : 0;
    result = warp_reduce_max<num_warps>(my_max);
......
@@ -24,6 +24,16 @@ TE_DType = {
    torch.bfloat16: tex.DType.kBFloat16,
}

TE_DType_To_Torch = {
    tex.DType.kByte: torch.uint8,
    tex.DType.kFloat8E4M3: torch.float8_e4m3fn,
    tex.DType.kFloat8E5M2: torch.float8_e5m2,
    tex.DType.kInt32: torch.int32,
    tex.DType.kFloat32: torch.float32,
    tex.DType.kFloat16: torch.half,
    tex.DType.kBFloat16: torch.bfloat16,
}

AttnMaskTypes = (
    "no_mask",
    "padding",
......
@@ -46,15 +46,22 @@ transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantizer) {
  NVTE_CHECK(!tensor.is_none(), "Tensor is not allocated!");
  std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
  // Check both the quantizer and the tensor type:
  //   mxfp8 tensor  -> mxfp8 quantizer
  //   float8 tensor -> delayed scaling quantizer OR current scaling quantizer
  // During dequantize the quantizer params are unknown, so the quantizer is a NoneQuantizer.
  for (auto [check_type, check_quantizer_type, create_tensor, _] :
       detail::custom_types_converters) {
    if (check_type(tensor.ptr())) {
-      NVTE_CHECK(quantizer.is_none() || check_quantizer_type(quantizer.ptr()),
-                 "Unexpected quantization params type.");
+      if (!(quantizer.is_none() || check_quantizer_type(quantizer.ptr()))) {
+        continue;
+      }
      auto x = create_tensor(tensor, my_quantizer.get());
      return x;
    }
  }
  NVTE_CHECK(dynamic_cast<NoneQuantizer*>(my_quantizer.get()) != nullptr,
             "Unexpected quantization params type.");

  // Regular pyTorch tensor
  at::Tensor torch_tensor = tensor.cast<at::Tensor>();
......
@@ -50,6 +50,9 @@
namespace transformer_engine::pytorch {

// in python we have: dist_group_type = torch.distributed.ProcessGroup
using dist_group_type = c10d::ProcessGroup;

// Each tensor here is shape (N, ) holding all scaling
// data for a single FP8 block, e.g. LayerNormLinear
class FP8TensorMeta {
@@ -136,6 +139,29 @@ class Float8Quantizer : public Quantizer {
      std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};
class Float8CurrentScalingQuantizer : public Quantizer {
public:
at::Tensor scale;
at::Tensor scale_inv;
at::Tensor amax;
DType dtype;
bool with_amax_reduction;
c10::intrusive_ptr<dist_group_type> amax_reduction_group;
int amax_reduction_size;
bool force_pow_2_scales = false;
float amax_epsilon = 0.0;
explicit Float8CurrentScalingQuantizer(const py::handle& quantizer);
NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; }
void set_quantization_params(TensorWrapper* tensor) const override;
std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};
class MXFP8Quantizer : public Quantizer {
 public:
  DType dtype;
......
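The header above declares the C++ mirror of the Python `Float8CurrentScalingQuantizer` (defined in `transformer_engine/pytorch/tensor/float8_tensor.py`, which is not shown in this diff). For orientation, here is a hedged, simplified sketch of the Python-side state that the C++ constructor later reads via `quantizer.attr(...)`; the real class carries additional logic.

```python
import torch
import transformer_engine_torch as tex

# Hypothetical/simplified sketch, not the real implementation.
class Float8CurrentScalingQuantizerSketch:
    def __init__(self, fp8_dtype: tex.DType, *, device: torch.device = "cuda"):
        self.dtype = fp8_dtype
        # Per-tensor scale/amax buffers, recomputed from the current tensor on every quantize.
        self.scale = torch.ones(1, dtype=torch.float32, device=device)
        self.amax = torch.zeros(1, dtype=torch.float32, device=device)
        # Distributed amax-reduction settings (read by the C++ constructor).
        self.with_amax_reduction = False
        self.amax_reduction_group = None
        self.amax_reduction_size = 1
        # Numerical knobs forwarded to nvte_compute_scale_from_amax.
        self.force_pow_2_scales = False
        self.amax_epsilon = 0.0
```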
@@ -4,6 +4,7 @@
 * See LICENSE for license information.
 ************************************************************************/

#include "common.h"
#include "extensions.h"
#include "pybind.h"
@@ -24,7 +25,35 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int
  auto [te_output, out] =
      my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));
  // For current scaling we need to compute the amax first and then quantize,
  // because the entire tensor cannot fit in cache to compute the amax and quantize in one pass.
  // The quantizer should not need an amax reduction, so no process group is needed here.
  if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
    // The activation function may change the data range, so first call the activation,
    // then find the amax and scale of its output, and only then quantize.
    // Get a NoneQuantizer to calculate the amax of the activation output.
auto my_quantizer_none = std::make_unique<NoneQuantizer>(py::none());
auto [te_output_act, out_act] =
my_quantizer_none->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));
act_func(te_input.data(), te_output_act.data(), at::cuda::getCurrentCUDAStream());
// use te_output_act as input to the compute amax and find the amax of activated tensor
nvte_compute_amax(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
if (my_quantizer_cs->with_amax_reduction) {
NVTE_ERROR(
"per-tensor current scaling amax reduction is not supported in activation functions.");
}
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream());
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
nvte_quantize(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
} else {
act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
}
  return out;
}
......
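A hedged, high-level restatement of the current-scaling path in `activation_helper` above, as PyTorch pseudocode. `act_fn`, the tensor names, and the plain clamp-and-cast quantization are illustrative stand-ins for the fused NVTE kernels; it also assumes a PyTorch build with float8 dtypes.

```python
import torch

def current_scaling_activation_sketch(x: torch.Tensor, act_fn, max_fp8: float):
    # 1. Run the activation in high precision (its output range differs from the input's).
    y_hp = act_fn(x)
    # 2. Per-tensor current scaling: amax and scale come from the activated tensor itself.
    amax = y_hp.abs().amax().float()
    scale = torch.where(amax > 0, max_fp8 / amax, torch.ones_like(amax))
    # 3. Quantize with the freshly computed scale (stand-in for nvte_quantize).
    y_fp8 = torch.clamp(y_hp.float() * scale, -max_fp8, max_fp8).to(torch.float8_e4m3fn)
    return y_fp8, 1.0 / scale  # FP8 data and scale_inv
```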
@@ -45,6 +45,29 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
  }

  if (te_output.numel() == 0) return out;
if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
nvte_compute_amax(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
    // Check if we need to do an amax reduction (depending on model parallel configs).
if (my_quantizer_cs->with_amax_reduction) {
c10::intrusive_ptr<dist_group_type> process_group_ptr = my_quantizer_cs->amax_reduction_group;
      // construct torch tensor from NVTEBasicTensor without reallocating memory
at::Tensor& amax_tensor_torch = my_quantizer_cs->amax;
std::vector<at::Tensor> tensors = {amax_tensor_torch};
// allreduce amax tensor
c10d::AllreduceOptions allreduce_opts;
allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
}
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream());
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
}
  nvte_quantize_noop(te_input.data(), te_output.data(), te_noop.data(),
                     at::cuda::getCurrentCUDAStream());
......
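The amax reduction in `quantize` above is a plain MAX all-reduce over the one-element amax buffer, performed before the scale is derived, so that every rank quantizes with an identical scale. A hedged PyTorch-level sketch of that step:

```python
from typing import Optional

import torch
import torch.distributed as dist

def reduce_amax_sketch(amax: torch.Tensor, group: Optional[dist.ProcessGroup]) -> torch.Tensor:
    # amax is a 1-element FP32 CUDA tensor, as in Float8CurrentScalingQuantizer.
    if group is not None and dist.is_initialized():
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    return amax
```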
@@ -24,6 +24,7 @@ namespace transformer_engine::pytorch {
PyTypeObject *Float8TensorPythonClass = nullptr;  /// TODO Remove
PyTypeObject *Float8TensorBasePythonClass = nullptr;
PyTypeObject *Float8QuantizerClass = nullptr;
PyTypeObject *Float8CurrentScalingQuantizerClass = nullptr;
PyTypeObject *MXFP8TensorPythonClass = nullptr;  /// TODO Remove
PyTypeObject *MXFP8TensorBasePythonClass = nullptr;
PyTypeObject *MXFP8QuantizerClass = nullptr;
@@ -33,6 +34,8 @@ void init_float8_extension() {
  auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor");
  Float8QuantizerClass =
      reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "Float8Quantizer"));
  Float8CurrentScalingQuantizerClass = reinterpret_cast<PyTypeObject *>(
      PyObject_GetAttrString(fp8_module.ptr(), "Float8CurrentScalingQuantizer"));
  Float8TensorPythonClass =
      reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "Float8Tensor"));
  auto fp8_base_module =
......
@@ -140,6 +140,123 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
  return {std::move(tensor), std::move(ret)};
}
Float8CurrentScalingQuantizer::Float8CurrentScalingQuantizer(const py::handle& quantizer)
: Quantizer(quantizer) {
const at::Tensor& scale = quantizer.attr("scale").cast<at::Tensor>();
const at::Tensor& amax = quantizer.attr("amax").cast<at::Tensor>();
const DType type = quantizer.attr("dtype").cast<DType>();
// For current scaling, need several other components:
// 1. with_amax_reduction: bool
// 2. amax_reduction_group: torch.distributed.ProcessGroup or None
// 3. amax_reduction_size: int
const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast<bool>();
const py::object amax_reduction_group_obj = quantizer.attr("amax_reduction_group");
const c10::intrusive_ptr<dist_group_type> amax_reduction_group =
amax_reduction_group_obj.is_none()
? nullptr
: amax_reduction_group_obj.cast<c10::intrusive_ptr<dist_group_type>>();
const int amax_reduction_size = quantizer.attr("amax_reduction_size").cast<int>();
this->amax = amax;
this->scale = scale;
this->dtype = type;
this->with_amax_reduction = with_amax_reduction;
this->amax_reduction_group = amax_reduction_group;
this->amax_reduction_size = amax_reduction_size;
// fp8 current scaling specific quantization params
this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast<bool>();
this->amax_epsilon = quantizer.attr("amax_epsilon").cast<float>();
}
void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tensor) const {
// transfer amax and scale pointer from quantizer to output tensor (only as gpu buffer, no meaningful data in them)
tensor->set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()),
getTensorShape(scale));
  at::TensorOptions opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),
getTensorShape(amax));
// quantize output and its transpose
auto rowwise_data = tensor->get_rowwise_data();
rowwise_data.dtype = static_cast<NVTEDType>(dtype);
auto columnwise_data = tensor->get_columnwise_data();
columnwise_data.dtype = static_cast<NVTEDType>(dtype);
tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
rowwise_data.shape);
tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
columnwise_data.shape);
}
std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
using namespace pybind11::literals;
std::vector<int64_t> rowwise_torch_shape;
std::vector<int64_t> columnwise_torch_shape;
std::vector<int64_t> scale_inv_torch_shape = {1}; // Shape of 1 element for scale_inv
if (!shape.empty()) {
columnwise_torch_shape.emplace_back(static_cast<int64_t>(shape.back()));
}
for (size_t i = 0; i < shape.size(); ++i) {
if (i < shape.size() - 1) {
columnwise_torch_shape.emplace_back(static_cast<int64_t>(shape[i]));
}
rowwise_torch_shape.emplace_back(static_cast<int64_t>(shape[i]));
}
at::TensorOptions opts;
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
at::Tensor data;
if (rowwise_usage) {
if (rowwise_data.has_value()) {
data = std::move(*rowwise_data);
} else {
data = at::empty(rowwise_torch_shape, opts);
}
}
const py::object py_data = rowwise_usage ? py::cast(data) : py::none();
at::Tensor columnwise_data;
bool create_transpose = columnwise_usage && !non_tn_fp8_gemm_supported();
if (create_transpose) {
columnwise_data = at::empty(columnwise_torch_shape, opts);
}
const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none();
  // Unlike delayed scaling, in current scaling the scale is not known yet, so scale_inv is just an empty buffer.
opts = opts.dtype(torch::kFloat32).device(torch::kCUDA);
at::Tensor scale_inv = at::empty(scale_inv_torch_shape, opts);
py::object ret;
if (internal) {
py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorBasePythonClass));
ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data,
"quantizer"_a = this->quantizer);
} else {
py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorPythonClass));
ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype),
"data"_a = py_data, "fp8_scale_inv"_a = scale_inv,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data,
"quantizer"_a = this->quantizer);
}
TensorWrapper tensor(this->get_scaling_mode());
if (rowwise_usage) {
tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape);
tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector<size_t>{1});
}
if (create_transpose) {
std::vector<size_t> transposed_shape;
for (auto s : columnwise_torch_shape) {
transposed_shape.emplace_back(static_cast<size_t>(s));
}
tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape);
tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector<size_t>{1});
}
this->set_quantization_params(&tensor);
return {std::move(tensor), std::move(ret)};
}
MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) {
  this->dtype = quantizer.attr("dtype").cast<DType>();
}
......
@@ -12,7 +12,7 @@ void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool roww
  if (input.scaling_mode() == NVTE_INVALID_SCALING) {
    NVTE_ERROR("Invalid scaling mode for swizzle.");
-  } else if (input.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
+  } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING) {
    return;
  }
......
@@ -23,7 +23,8 @@ TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer
  if (transpose_valid) {
    transpose = tensor.attr("_transpose").cast<std::optional<at::Tensor>>();
  }

  // In the case of being called under tex.dequantize, the quantizer will be a NoneQuantizer,
  // whose scaling mode defaults to NVTE_DELAYED_TENSOR_SCALING.
  auto ret = TensorWrapper(quantizer->get_scaling_mode());
  ret.set_rowwise_data(data.data_ptr(), dtype, shape);
......
@@ -21,6 +21,7 @@ namespace transformer_engine::pytorch {
extern PyTypeObject *Float8TensorPythonClass;
extern PyTypeObject *Float8TensorBasePythonClass;
extern PyTypeObject *Float8QuantizerClass;
extern PyTypeObject *Float8CurrentScalingQuantizerClass;
extern PyTypeObject *MXFP8TensorPythonClass;
extern PyTypeObject *MXFP8TensorBasePythonClass;
extern PyTypeObject *MXFP8QuantizerClass;
@@ -33,13 +34,17 @@ void init_mxfp8_extension();

namespace detail {

-inline bool IsFloat8QParams(PyObject *obj) { return Py_TYPE(obj) == Float8QuantizerClass; }
+inline bool IsFloat8Quantizers(PyObject *obj) { return Py_TYPE(obj) == Float8QuantizerClass; }
+
+inline bool IsFloat8CurrentScalingQuantizers(PyObject *obj) {
+  return Py_TYPE(obj) == Float8CurrentScalingQuantizerClass;
+}

inline bool IsFloat8Tensor(PyObject *obj) {
  return Py_TYPE(obj) == Float8TensorPythonClass || Py_TYPE(obj) == Float8TensorBasePythonClass;
}

-inline bool IsMXFP8QParams(PyObject *obj) { return Py_TYPE(obj) == MXFP8QuantizerClass; }
+inline bool IsMXFP8Quantizers(PyObject *obj) { return Py_TYPE(obj) == MXFP8QuantizerClass; }

inline bool IsMXFP8Tensor(PyObject *obj) {
  return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorBasePythonClass;
@@ -61,9 +66,11 @@ inline bool IsFloatingPointType(at::ScalarType type) {
}

constexpr std::array custom_types_converters = {
-    std::make_tuple(IsFloat8Tensor, IsFloat8QParams, NVTETensorFromFloat8Tensor,
+    std::make_tuple(IsFloat8Tensor, IsFloat8Quantizers, NVTETensorFromFloat8Tensor,
                    CreateQuantizer<Float8Quantizer>),
+    std::make_tuple(IsFloat8Tensor, IsFloat8CurrentScalingQuantizers, NVTETensorFromFloat8Tensor,
+                    CreateQuantizer<Float8CurrentScalingQuantizer>),
-    std::make_tuple(IsMXFP8Tensor, IsMXFP8QParams, NVTETensorFromMXFP8Tensor,
+    std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromMXFP8Tensor,
                    CreateQuantizer<MXFP8Quantizer>)};

}  // namespace detail
......
@@ -21,7 +21,7 @@ from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_module
from .utils import safely_set_viewless_tensor_data
from .constants import dist_group_type
from .fp8 import FP8GlobalStateManager
-from .tensor.float8_tensor import Float8Quantizer, Float8Tensor
+from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
from .tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor
from .tensor.quantized_tensor import QuantizedTensor, Quantizer
from .tensor._internal.float8_tensor_base import Float8TensorBase
@@ -859,7 +859,10 @@ def _all_gather_fp8(
    # Quantize input tensor if needed
    if not isinstance(input_, Float8TensorBase):
-        assert isinstance(quantizer, Float8Quantizer)
+        assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer))
+        # We cannot directly gather the transposed FP8 tensor, so disable columnwise
+        # usage on the quantizer and restore the original setting after quantizing.
        init_columnwise_usage = quantizer.columnwise_usage
        quantizer.set_usage(columnwise=False)
        input_ = quantizer(input_)
@@ -867,7 +870,7 @@ def _all_gather_fp8(
    # Construct output tensor
    out: Float8TensorBase
-    if isinstance(quantizer, Float8Quantizer):
+    if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
        dtype = torch.float32
        device = "cuda"
        if isinstance(input_, Float8Tensor):
@@ -885,6 +888,9 @@ def _all_gather_fp8(
        out._transpose_invalid = True
    else:
        raise RuntimeError("FP8TensorBase is not supported yet without Quantizer")
+    # For delayed scaling, scale_inv comes from the amax history; for current scaling it comes
+    # from the amax reduction done in C++, so every rank already has the same scale_inv.
+    # In both cases it can simply be copied from input_ to out.
    out._scale_inv = input_._scale_inv

    # Perform communication
@@ -999,8 +1005,10 @@ def gather_along_first_dim(
    out_shape = list(input_.size())
    out_shape[0] *= world_size

-    # FP8 case
-    if isinstance(input_, Float8TensorBase) or isinstance(quantizer, Float8Quantizer):
+    # FP8 case: delayed scaling or current scaling
+    if isinstance(input_, Float8TensorBase) or isinstance(
+        quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+    ):
        return _all_gather_fp8(
            input_,
            process_group,
......
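For reference, the row-wise-only gather described in the comments above amounts to gathering the uint8 FP8 payload along the first dimension and reusing the shared scale_inv. A hedged sketch (it assumes the `Float8Tensor._data` attribute seen elsewhere in this diff and the standard `all_gather_into_tensor` collective; the wrapping back into a Float8Tensor is left to the caller):

```python
import torch
import torch.distributed as dist

def all_gather_fp8_sketch(x_fp8, process_group):
    # Gather only the row-wise uint8 payload; the transpose cannot be gathered directly
    # and is re-derived (or marked invalid) on the gathered output.
    world_size = dist.get_world_size(process_group)
    out_data = torch.empty(
        (world_size * x_fp8._data.shape[0], *x_fp8._data.shape[1:]),
        dtype=torch.uint8, device=x_fp8._data.device,
    )
    dist.all_gather_into_tensor(out_data, x_fp8._data.contiguous(), group=process_group)
    return out_data  # caller wraps this into a Float8Tensor with the shared scale_inv
```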
@@ -13,7 +13,13 @@ from typing import Callable, List, Optional, Dict, Any, Tuple, Union
import torch
import transformer_engine_torch as tex

-from transformer_engine.common.recipe import Recipe, DelayedScaling, Format, MXFP8BlockScaling
+from transformer_engine.common.recipe import (
+    Recipe,
+    DelayedScaling,
+    Format,
+    MXFP8BlockScaling,
+    Float8CurrentScaling,
+)
from .constants import dist_group_type
from .utils import get_device_compute_capability
@@ -198,6 +204,8 @@ class FP8GlobalStateManager:
        fp8_meta: Dict[str, Any],
    ) -> None:
        """
+        Delayed scaling only.
+
        The amax reduction process happens completely outside the FP8 modules.
        To participate in the reduction, the only role played by a module is
        to call this function in order to append its FP8 tensor into a global
@@ -211,7 +219,8 @@ class FP8GlobalStateManager:
        wrapper. For non CG case, it's called from within the module.
        """
-        if fp8_meta["recipe"].mxfp8():
+        # Delayed-scaling-only function; no-op for any other recipe.
+        if not fp8_meta["recipe"].delayed():
            return

        # Every module must call this function exactly once since
@@ -326,7 +335,8 @@ class FP8GlobalStateManager:
        cls,
        forward: bool = True,
    ) -> None:
-        """Concatenate, reduce, and split amaxes in the global buffer."""
+        """Delayed scaling only. Concatenate, reduce, and split amaxes in the global buffer."""
+        # global_amax_buffer should only be non-empty for FP8 delayed scaling.
        for buffer_key, amax_buffer in cls.global_amax_buffer.items():
            # Check for forward or backward reduction.
            fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key)
@@ -426,6 +436,8 @@ class FP8GlobalStateManager:
        # FP8 weight modules are reduced at the end of the optimizer
        # step after the weight amax is populated.
        if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
+            # Delayed-scaling-only path: for any other recipe (e.g. current scaling at any
+            # granularity) this is a no-op because cls.global_amax_buffer is an empty list.
            cls.reduce_and_update_fp8_tensors(forward=True)

    @classmethod
@@ -434,7 +446,8 @@ class FP8GlobalStateManager:
        to ensure both forward steps are numerically same.
        """
-        if fp8_meta["recipe"].mxfp8():
+        # Delayed-scaling-only function; no-op for any other recipe.
+        if not fp8_meta["recipe"].delayed():
            return

        buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
@@ -459,8 +472,8 @@ class FP8GlobalStateManager:
        """Switch to the copied scaling factors and amaxes from phase
        1 forward for identical numerical outputs.
        """
-        if fp8_meta["recipe"].mxfp8():
+        # Delayed-scaling-only function; no-op for any other recipe.
+        if not fp8_meta["recipe"].delayed():
            return

        # Store updated amaxes and scales from phase 1 post forward.
@@ -478,8 +491,8 @@ class FP8GlobalStateManager:
    @staticmethod
    def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
        """Restore latest scaling factors and amaxes after recompute forward run."""
-        if fp8_meta["recipe"].mxfp8():
+        # Delayed-scaling-only function; no-op for any other recipe.
+        if not fp8_meta["recipe"].delayed():
            return

        fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"])
@@ -743,6 +756,8 @@ class RecipeState(abc.ABC):
            cls = DelayedScalingRecipeState
        elif recipe.mxfp8():
            cls = MXFP8BlockScalingRecipeState
+        elif recipe.float8_current_scaling():
+            cls = Float8CurrentScalingRecipeState
        else:
            raise ValueError("{recipe.__class__.__name__} is not supported")
        return cls(
@@ -813,6 +828,45 @@ class DelayedScalingRecipeState(RecipeState):
    ]
class Float8CurrentScalingRecipeState(RecipeState):
    """Configuration for per-tensor current scaling quantization.

    Per-tensor current scaling quantization does not require state.
    """

    recipe: Float8CurrentScaling
    mode: str
    dtype: tex.DType
    device: torch.device

    def __init__(
        self,
        recipe: Float8CurrentScaling,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        self.recipe = recipe
        self.mode = mode
        self.num_quantizers = num_quantizers
        self.dtype = get_fp8_te_dtype(recipe, mode == "forward")

        # Record the target device; no scale/amax buffers are needed for current scaling.
        if device is None:
            device = torch.device("cuda")
        self.device = device

    def make_quantizers(self) -> list:
        from .tensor.float8_tensor import Float8CurrentScalingQuantizer

        return [
            Float8CurrentScalingQuantizer(self.dtype, device=self.device)
            for _ in range(self.num_quantizers)
        ]
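Because the recipe keeps no amax history or scale buffers, each Float8CurrentScalingQuantizer can derive its scale from the tensor it is handed. The actual cast lives in the C++/CUDA changes of this PR; purely as a conceptual sketch (PyTorch-only, with the FP8 E4M3 maximum of 448 hard-coded, and the function name invented for illustration), the per-tensor math is roughly:

```python
import torch

FP8_E4M3_MAX = 448.0  # largest representable magnitude in float8_e4m3fn

def current_scaling_cast(x: torch.Tensor):
    """Per-tensor current scaling: the scale comes from this tensor's own amax,
    so no history, reduction buffer, or persistent state is required."""
    amax = x.abs().max().float()
    scale = FP8_E4M3_MAX / torch.clamp(amax, min=torch.finfo(torch.float32).tiny)
    x_fp8 = (x.float() * scale).to(torch.float8_e4m3fn)
    return x_fp8, scale.reciprocal()  # scale_inv is kept for dequantization
```

make_quantizers() above then simply hands one such quantizer per tensor role to the module, with the FP8 dtype chosen by get_fp8_te_dtype depending on whether the state is for the forward or backward pass.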
class MXFP8BlockScalingRecipeState(RecipeState):
    """Configuration for MXFP8 quantization.
......
@@ -21,6 +21,7 @@ from ._common import _ParameterInitMeta
from ..fp8 import (
    MXFP8BlockScalingRecipeState,
    DelayedScalingRecipeState,
    Float8CurrentScalingRecipeState,
    FP8GlobalStateManager,
    RecipeState,
)
@@ -34,6 +35,7 @@ from ..constants import dist_group_type
from ..tensor import QuantizedTensor, Quantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer

__all__ = ["initialize_ub", "destroy_ub"]
@@ -430,7 +432,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        super().__setattr__(name, value)

    def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
        """
        Delayed scaling only.
        Increase or decrease size of amax history based on given `length`.

        .. warning::
            This changes the underlying amax memory location.
@@ -489,6 +494,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            return
        if recipe.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState):
            return
        if recipe.float8_current_scaling() and isinstance(
            recipe_state, Float8CurrentScalingRecipeState
        ):
            return

        # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
        # 2 (grad_output and grad_input) for bwd
@@ -851,6 +860,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        if ctx.use_bias:
            if isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)):
                grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0)
            elif isinstance(quantizer, Float8CurrentScalingQuantizer):
                # FP8 current scaling does not support fused cast + dbias
                grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
            else:
                grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer)
        if not isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)):
......
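The Float8CurrentScalingQuantizer branch above works around the missing fused cast+dbias kernel for this recipe: the bias gradient is reduced in the incoming precision first, and grad_output is quantized as a separate step further down the existing path. A rough, self-contained illustration of that split (the names here are invented for the example; `quantize_fn` stands in for however the current-scaling quantizer is invoked):

```python
import torch

def split_dbias_path(grad_output: torch.Tensor, quantize_fn):
    """Unfused fallback: row-sum for the bias gradient in the activation precision,
    then a separate quantization step for grad_output."""
    grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
    # Separate cast: with current scaling the scale is only known after
    # scanning grad_output itself, so cast and dbias are not fused here.
    grad_output_q = quantize_fn(grad_output)
    return grad_bias, grad_output_q
```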
@@ -88,6 +88,9 @@ class _GroupedLinear(torch.autograd.Function):
        # TODO Support MXFP8  # pylint: disable=fixme
        if fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8():
            raise NotImplementedError("GroupedLinear does not yet support MXFP8")
        # TODO Support Float8 Current Scaling  # pylint: disable=fixme
        if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling():
            raise NotImplementedError("GroupedLinear does not yet support Float8 Current Scaling")

        # Make sure input dimensions are compatible
        in_features = weights[0].shape[-1]
......
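To see how these pieces fit together from the user's side, a usage sketch (assuming Float8CurrentScaling is exported from transformer_engine.common.recipe alongside the existing recipes, which is how the rest of this PR wires it up):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Per-tensor current scaling: every FP8 cast computes its scale from the
# tensor being quantized, so no amax history or warm-up is involved.
fp8_recipe = recipe.Float8CurrentScaling()

model = te.Linear(1024, 1024, bias=True).cuda()
inp = torch.randn(32, 1024, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(inp)
out.sum().backward()

# te.GroupedLinear raises NotImplementedError under this recipe for now
# (see the guard in _GroupedLinear above).
```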