Commit 1bbeab1c authored by kwyss-nvidia, committed by GitHub

Blockwise float8 quantizer and quantized tensor class (#1513)



* Blockwise float8 quantizer and quantized tensor class.

The classes are configurable for 128x128 and 1x128 block sizes
by setting block_scaling_dim to 2 or 1, respectively.

Scale tensors are stored in a format amenable to matrix multiplication;
however, the integration of matmul is deferred as a separate story.

Fusions of quantization with DBIAS or activation functions are not yet
implemented, and dequantization is currently implemented in torch.

Tests for quantization are included at the C++ and PyTorch layers, with
exact comparison to reference quantizer behavior, as well as an attempt
to hit interesting branches through the API, such as tensor creation
in PyTorch and C++ and dequantization of rowwise and columnwise usage.

Two CUDA kernels for quantization are included; they are direct ports
of equivalents in the kitchen repository, where a subchannel recipe
has been used for end-to-end training. (A usage sketch follows the
commit metadata below.)
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Apply linting changes.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Alignment for 1D scaling for GEMM edge case.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* MR feedback.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Change API name.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix merge conflict with name change.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Use common tensor map API.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Change API to use two scaling mode enums.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix typo.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update some call sites.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Tests for torch tensor API surface.

Since the quantized tensor is a tensor subclass, these tests exercise torch hooks.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Reuse scale calculation between quantizer refs.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Save memory by dropping reference to saved tensors.

Issues previously observed are solved.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Remove constexpr parameters from kernel.

Code size is reduced with fewer constexpr params.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Merge conflict from rebase.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add shape implementations for block scaling.

nvte_shape was added upstream. Logic added for block-scaled fp8.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Move benchmark to te_playground
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Remove amax_epsilon and pow_2_scales from tensor.

Hardcodes the default values.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Lint changes.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix up MR changes that broke.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Safer ifdef in kernel.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Documentation prose.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Reuse compute_scale function from Current Scaling.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Bugfix on inf_value scale refactor.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Remove qopt calls from test.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update pytest list.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add copyright to reference scale calc.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Use ptx.cuh functions instead of cde.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update shape logic with allocation and reuse shape.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Usage defaults MR feedback.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Copyright and header guard.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Updating torch dispatch code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix exception type.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Use TypeInfo
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* MR feedback.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update CS scale update test to use updated ref impl
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update JAX scaling mode enum
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Skip tests on Lovelace
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Keith Wyss <kwyss@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 3e305f72
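A minimal usage sketch of the new quantizer API, based on the tests added in this PR (sizes and dtypes here are illustrative only):

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer

# block_scaling_dim=2 selects 128x128 blocks; block_scaling_dim=1 selects 1x128 blocks.
quantizer = Float8BlockQuantizer(
    fp8_dtype=tex.DType.kFloat8E4M3,
    rowwise=True,
    columnwise=True,  # also build transposed data/scales
    block_scaling_dim=2,
)
x = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
x_fp8 = quantizer.quantize(x)                    # Float8BlockwiseQTensor
x_back = x_fp8.dequantize(dtype=torch.bfloat16)  # dequantization runs in torch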
@@ -30,6 +30,8 @@ PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
......
@@ -11,6 +11,7 @@ add_executable(test_operator
test_cast_mxfp8_gated_swiglu.cu
test_qdq.cu
test_cast_mxfp8.cu
test_cast_float8blockwise.cu
test_dequantize_mxfp8.cu
test_transpose.cu
test_cast_transpose.cu
......
This diff is collapsed.
@@ -10,6 +10,7 @@
#include <algorithm>
#include <memory>
#include <random>
#include <iostream>
#include <cassert>
#include <cmath>
#include <string>
@@ -134,27 +135,19 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
scale_inv_meta ret_rowwise, ret_colwise;
auto block_alignment = std::vector<size_t>{128ul, 4ul};
{
auto alignment = block_alignment[0];
auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(1)), alignment) * alignment;
alignment = block_alignment[1];
auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(32)), alignment) * alignment;
ret_rowwise.shape = {scale_dim_0, scale_dim_1};
}
{
auto alignment = block_alignment[1];
auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(32)), alignment) * alignment;
alignment = block_alignment[0];
auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(1)), alignment) * alignment;
ret_colwise.shape = {scale_dim_0, scale_dim_1};
}
ret_rowwise.type = DType::kFloat8E8M0;
@@ -164,6 +157,58 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
return {ret_rowwise, ret_colwise};
}
if (scaling_mode == NVTE_BLOCK_SCALING_2D) {
std::vector<size_t> shape_vec;
for (size_t i = 0; i < shape.ndim; ++i) {
shape_vec.push_back(shape.data[i]);
}
size_t first_dim = first_dimension(shape_vec);
size_t last_dim = last_dimension(shape_vec);
scale_inv_meta ret_rowwise, ret_colwise;
{
auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(128));
auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(128)), 4) * 4;
ret_rowwise.shape = {scale_dim_0, scale_dim_1};
}
{
auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(128));
auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast<size_t>(128)), 4) * 4;
ret_colwise.shape = {scale_dim_0, scale_dim_1};
}
ret_rowwise.type = DType::kFloat32;
ret_colwise.type = DType::kFloat32;
ret_rowwise.type_size = sizeof(float);
ret_colwise.type_size = sizeof(float);
return {ret_rowwise, ret_colwise};
}
if (scaling_mode == NVTE_BLOCK_SCALING_1D) {
std::vector<size_t> shape_vec;
for (size_t i = 0; i < shape.ndim; ++i) {
shape_vec.push_back(shape.data[i]);
}
size_t first_dim = first_dimension(shape_vec);
size_t last_dim = last_dimension(shape_vec);
scale_inv_meta ret_rowwise, ret_colwise;
{
auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(128));
auto scale_dim_1 = DIVUP(first_dim, 4) * 4;
ret_rowwise.shape = {scale_dim_0, scale_dim_1};
}
{
auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(128));
auto scale_dim_1 = DIVUP(last_dim, 4) * 4;
ret_colwise.shape = {scale_dim_0, scale_dim_1};
}
ret_rowwise.type = DType::kFloat32;
ret_colwise.type = DType::kFloat32;
ret_rowwise.type_size = sizeof(float);
ret_colwise.type_size = sizeof(float);
return {ret_rowwise, ret_colwise};
}
NVTE_ERROR("Invalid scaling mode!"); NVTE_ERROR("Invalid scaling mode!");
} }
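// Worked example of the scale-shape logic above (hypothetical 256x512 input):
//   NVTE_BLOCK_SCALING_2D: rowwise scales (DIVUP(256,128), DIVUP(DIVUP(512,128),4)*4) = (2, 4)
//                          colwise scales (DIVUP(512,128), DIVUP(DIVUP(256,128),4)*4) = (4, 4)
//   NVTE_BLOCK_SCALING_1D: rowwise scales (DIVUP(512,128), DIVUP(256,4)*4) = (4, 256)
//                          colwise scales (DIVUP(256,128), DIVUP(512,4)*4) = (2, 512)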
@@ -171,7 +216,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
Tensor::Tensor(const std::string& name,
const NVTEShape &shape, const DType type,
const bool rowwise, const bool columnwise,
const NVTEScalingMode &scaling_mode,
const QuantizationOptions* q_opts) {
name_ = name;
const size_t seed = create_seed_from_tensor_name(name);
gen_.seed(seed);
@@ -198,7 +244,7 @@ Tensor::Tensor(const std::string& name,
NVTEShape columnwise_shape{nullptr, 0};
std::vector<size_t> columnwise_shape_vec;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) {
// Transpose when tensor scaling
columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]);
for (size_t i = 0; i < shape.ndim - 1; ++i) {
@@ -259,27 +305,33 @@ Tensor::Tensor(const std::string& name,
std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
}
} else {
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(normalized_shape, tensor_.scaling_mode());
auto rowwise_scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
auto columnwise_scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size;
auto scale_shape = rowwise_scale_meta.shape;
auto columnwise_scale_shape = colwise_scale_meta.shape;
if (rowwise) {
cudaMalloc((void **)&rowwise_scale_inv, rowwise_scale_size);  // NOLINT(*)
cudaMemset(rowwise_scale_inv, 0, rowwise_scale_size);
rowwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(rowwise_scale_size);
std::fill_n(rowwise_scale_inv_cpu_data_.get(), rowwise_scale_size, 0);
auto scale_dtype = rowwise_scale_meta.type;
tensor_.set_rowwise_scale_inv(rowwise_scale_inv, scale_dtype, scale_shape);
}
if (columnwise) {
cudaMalloc((void**)&columnwise_scale_inv, columnwise_scale_size);  // NOLINT(*)
cudaMemset(columnwise_scale_inv, 0, columnwise_scale_size);
columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(columnwise_scale_size);
std::fill_n(columnwise_scale_inv_cpu_data_.get(), columnwise_scale_size, 0);
auto scale_dtype = colwise_scale_meta.type;
tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape);
}
}
if (q_opts != nullptr) {
NVTE_CHECK(q_opts->force_pow_2_scales, "Pow2 scales is required for current implementation.");
NVTE_CHECK(q_opts->amax_epsilon == 0.0, "Amax epsilon must be zero for current implementation.");
}
}
}
@@ -311,7 +363,8 @@ void Tensor::to_cpu() const {
sizeof(float),
cudaMemcpyDeviceToHost);
}
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(s, tensor_.scaling_mode());
if (rowwise_) {
auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
cudaMemcpy(rowwise_scale_inv_cpu_data_.get(),
@@ -349,7 +402,8 @@ void Tensor::from_cpu() const {
cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float),
cudaMemcpyHostToDevice);
}
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(s, tensor_.scaling_mode());
if (rowwise_) {
auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr,
@@ -383,27 +437,29 @@ void Tensor::set_scale_inv(float scale_inv) {
if (columnwise_) {
NVTE_CHECK(columnwise_scale_inv_cpu_data_);
}
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(tensor_.shape(), tensor_.scaling_mode());
if (rowwise_) {
auto num_scales = product(rowwise_scale_meta.shape);
if (num_scales == 1) {
rowwise_cpu_scale_inv_ptr<float>()[0] = scale_inv;
} else {
std::uniform_int_distribution<uint8_t> dis(0, 127);
auto *scale_inv_ptr = rowwise_cpu_scale_inv_ptr<uint8_t>();
for (size_t i = 0; i < num_scales; i++) {
scale_inv_ptr[i] = dis(gen_);
}
}
}
if (columnwise_) {
auto num_scales = product(colwise_scale_meta.shape);
if (num_scales == 1) {
columnwise_cpu_scale_inv_ptr<float>()[0] = scale_inv;
} else {
std::uniform_int_distribution<uint8_t> dis(0, 127);
auto *scale_inv_ptr = columnwise_cpu_scale_inv_ptr<uint8_t>();
for (size_t i = 0; i < num_scales; i++) {
scale_inv_ptr[i] = dis(gen_);
}
}
@@ -413,23 +469,20 @@ void Tensor::set_scale_inv(float scale_inv) {
}
void Tensor::shareFP8Meta(const Tensor &other) {
if (isFp8Type(dtype()) && isFp8Type(other.dtype())) {
auto new_tensor = TensorWrapper(other.tensor_.scaling_mode());
auto my_rowwise_data = tensor_.get_rowwise_data();
new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast<DType>(my_rowwise_data.dtype),
my_rowwise_data.shape);
auto my_columnwise_data = tensor_.get_columnwise_data();
new_tensor.set_columnwise_data(my_columnwise_data.data_ptr,
static_cast<DType>(my_columnwise_data.dtype),
my_columnwise_data.shape);
auto other_amax = other.tensor_.get_amax();
new_tensor.set_amax(other_amax.data_ptr, static_cast<DType>(other_amax.dtype),
other_amax.shape);
auto other_scale = other.tensor_.get_scale();
new_tensor.set_scale(other_scale.data_ptr, static_cast<DType>(other_scale.dtype),
other_scale.shape);
auto other_row_scale_inv = other.tensor_.get_rowwise_scale_inv();
new_tensor.set_rowwise_scale_inv(other_row_scale_inv.data_ptr,
@@ -460,9 +513,7 @@ std::string to_string(const std::vector<T> &v) {
std::vector<size_t> unravel(const size_t i, const NVTEShape &shape) {
std::vector<size_t> ret;
size_t current_i = i;
for (size_t current = shape.ndim - 1; current > 0; --current) {
ret.push_back(current_i % shape.data[current]);
current_i /= shape.data[current];
}
@@ -767,8 +818,7 @@ bool isFp8Type(DType type) {
return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0;
}
int32_t getDeviceComputeCapability() {
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, 0);
return 10 * deviceProp.major + deviceProp.minor;
......
@@ -95,21 +95,29 @@ struct TypeInfo{
constexpr static size_t size = sizeof(T);
};
struct QuantizationOptions {
bool force_pow_2_scales = false;
float amax_epsilon = 0.0;
size_t block_scaling_dim = 2u;
};
class Tensor {
public:
Tensor(const std::string& name,
const NVTEShape &shape, const DType type,
const bool rowwise = true,
const bool columnwise = false,
const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING,
const QuantizationOptions* q_opts = nullptr);
Tensor(const std::string& name,
const std::vector<size_t> &shape,
const DType type,
const bool rowwise = true,
const bool columnwise = false,
const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING,
const QuantizationOptions* q_opts = nullptr) :
Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode, q_opts) {}
Tensor() {}
@@ -136,25 +144,19 @@ class Tensor {
if (scale_inv != nullptr) {
cudaFree(scale_inv);
}
if (columnwise_data_ptr != nullptr) {
cudaFree(columnwise_data_ptr);
}
if (columnwise_scale_inv != nullptr) {
cudaFree(columnwise_scale_inv);
}
}
NVTETensor data() const noexcept { return tensor_.data(); }
NVTEShape rowwise_shape() const noexcept { return tensor_.get_rowwise_data().shape; }
NVTEShape columnwise_shape() const noexcept { return tensor_.get_columnwise_data().shape; }
NVTEShape rowwise_scale_inv_shape() const {
NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!");
@@ -221,6 +223,8 @@ class Tensor {
T *rowwise_cpu_scale_inv_ptr(){
if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
}
@@ -232,6 +236,8 @@ class Tensor {
T *columnwise_cpu_scale_inv_ptr(){
if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
}
@@ -459,6 +465,7 @@ extern std::vector<DType> all_fp_types;
bool isFp8Type(DType type);
int32_t getDeviceComputeCapability();
constexpr int32_t hopperComputeCapability = 90;
constexpr int32_t blackwellComputeCapability = 100;
} // namespace test
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import dataclasses
import math
import torch
from typing import Optional, Protocol, Tuple
from references.quantize_scale_calc import scale_from_amax_tensor
@dataclasses.dataclass()
class QuantizeResult:
data: torch.Tensor
scale: torch.Tensor
data_t: Optional[torch.Tensor]
scale_t: Optional[torch.Tensor]
@dataclasses.dataclass()
class CuBLASScaleMunger:
def munge_scale_shapes_for_backend(
self,
unmunged: QuantizeResult,
tile_shape: Tuple[int, int],
) -> QuantizeResult:
"""
cuBLAS GEMMs requires 1x128 quantized tensors to be have scales transposed
so that for an (M, N) tensor, the scales are (RoundUpDiv(N, 128), RoundUp(M, 4))
For 128x128 quantized tensors, the GEMM expects (M, PadToAlign(RoundUpDivide(N, 128), 4))
format. If RoundUpDivide(N, 128) is not divisible by 4, a transformation is required
"""
def _pad_inner_to_align(s: torch.Tensor, transpose: bool) -> torch.Tensor:
if transpose:
s = s.transpose(-1, -2).contiguous()
M, K = s.shape
if K % 4 == 0:
return s
k_pad = 4 - (K % 4)
return torch.nn.functional.pad(s, (0, k_pad), mode="constant", value=0).contiguous()
s = _pad_inner_to_align(unmunged.scale, transpose=tile_shape[0] == 1)
if unmunged.scale_t is None:
s_t = None
else:
s_t = _pad_inner_to_align(unmunged.scale_t, transpose=tile_shape[0] == 1)
return QuantizeResult(unmunged.data, s, unmunged.data_t, s_t)
def demunge_scale_shape_from_backend(
cls,
qtensor_shape: Tuple[int, int],
scales: torch.Tensor,
tile_shape: Tuple[int, int],
) -> torch.Tensor:
"""
Inverse operation of munge_scale_shapes_for_backend
"""
if tile_shape[0] != 1:
# 2D block quantized tensor may need padding stripped off
derived_scale_k_shape = math.ceil(qtensor_shape[1] / tile_shape[1])
else:
derived_scale_k_shape = qtensor_shape[0]
M, K = scales.shape
if derived_scale_k_shape != K:
scales = scales[:, :derived_scale_k_shape].contiguous()
if tile_shape[0] == 1:
return scales.transpose(-1, -2).contiguous()
else:
return scales
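A small shape check of the munge/demunge round trip under the 1x128 layout described above (hypothetical sizes, using the classes in this file):

munger = CuBLASScaleMunger()
q = QuantizeResult(
    data=torch.empty(256, 384, dtype=torch.uint8),
    scale=torch.empty(256, 3),  # natural (M, RoundUpDiv(K, 128)) layout
    data_t=None,
    scale_t=None,
)
munged = munger.munge_scale_shapes_for_backend(q, tile_shape=(1, 128))
assert munged.scale.shape == (3, 256)  # transposed; inner dim 256 is already 4-aligned
restored = munger.demunge_scale_shape_from_backend((256, 384), munged.scale, (1, 128))
assert restored.shape == (256, 3)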
@dataclasses.dataclass()
class BlockwiseQuantizerReference:
"""
A reference QuantizeOp for subchannel/block hybrid quantization.
Defers to ref GEMMs and quantizization formatting based on the backend.
"""
def __init__(self) -> None:
self.scale_munger = CuBLASScaleMunger()
@classmethod
def _quantize_square_block_tiling(
cls,
x: torch.Tensor,
quant_dtype: torch.dtype,
tile_len: int,
*,
return_transpose: bool,
pow_2_scales: bool,
eps: float,
) -> QuantizeResult:
M, K = x.shape
pad_m_k = [0, 0]
if K % tile_len != 0:
pad_m_k[1] = tile_len - (K % tile_len)
if M % tile_len != 0:
pad_m_k[0] = tile_len - (M % tile_len)
unpadded_m, unpadded_k = M, K
if pad_m_k[0] != 0 or pad_m_k[1] != 0:
x = torch.nn.functional.pad(
x, (0, pad_m_k[1], 0, pad_m_k[0]), mode="constant", value=0
).contiguous()
M, K = x.shape
x_tiled = x.reshape(M // tile_len, tile_len, K // tile_len, tile_len)
amax_grid = (
torch.abs(x_tiled.transpose(-3, -2))
.reshape(M // tile_len, K // tile_len, tile_len**2)
.amax(dim=-1)
).float()
dtype_max = torch.finfo(quant_dtype).max
scale, scale_inv, _ = scale_from_amax_tensor(
x_dtype=x.dtype,
amax=amax_grid,
quant_dtype=quant_dtype,
pow_2_scales=pow_2_scales,
eps=eps,
)
qx = x_tiled * scale.reshape(M // tile_len, 1, K // tile_len, 1)
qx = torch.clamp(qx, min=-dtype_max, max=dtype_max)
qx = qx.to(dtype=quant_dtype)
qx = qx.reshape(M, K)
if unpadded_k != K or unpadded_m != M:
qx = qx[:unpadded_m, :unpadded_k].contiguous()
if return_transpose:
# Valid because of square block sizes
qx_t = qx.transpose(-1, -2).contiguous()
scale_inv_t = scale_inv.transpose(-1, -2).contiguous()
else:
qx_t = None
scale_inv_t = None
return QuantizeResult(data=qx, scale=scale_inv, data_t=qx_t, scale_t=scale_inv_t)
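A quick shape check of the 2D tiling path above (hypothetical sizes that exercise the padding branch):

x = torch.randn(303, 300)
res = BlockwiseQuantizerReference._quantize_square_block_tiling(
    x,
    torch.float8_e4m3fn,
    tile_len=128,
    return_transpose=False,
    pow_2_scales=True,
    eps=0.0,
)
# Padded internally to (384, 384), i.e. a 3x3 tile grid, then cropped back.
assert res.data.shape == (303, 300) and res.scale.shape == (3, 3)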
@classmethod
def _quantize_vectorwise_reference(
cls,
x: torch.Tensor,
quant_dtype: torch.dtype,
tile_len: int,
*,
pow_2_scales: bool,
eps: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
M, K = x.shape
dtype_max = torch.finfo(quant_dtype).max
x_tiled = x.reshape(M, K // tile_len, tile_len)
amax_grid = torch.abs(x_tiled).amax(dim=-1).float()
scale, scale_inv, _ = scale_from_amax_tensor(
x_dtype=x.dtype,
amax=amax_grid,
quant_dtype=quant_dtype,
pow_2_scales=pow_2_scales,
eps=eps,
)
qx = x_tiled * scale.reshape(M, K // tile_len, 1)
qx = torch.clamp(qx, min=-dtype_max, max=dtype_max)
qx = qx.to(dtype=quant_dtype)
qx = qx.reshape(M, K)
return qx, scale_inv
@classmethod
def _quantize_vector_tiling(
cls,
x: torch.Tensor,
quant_dtype: torch.dtype,
tile_len: int,
*,
return_transpose: bool,
pow_2_scales: bool,
eps: float,
) -> QuantizeResult:
M, K = x.shape
if K % tile_len == 0:
qref_input = x
else:
pad_amount = tile_len - (K % tile_len)
pad = (0, pad_amount)
qref_input = torch.nn.functional.pad(x, pad, mode="constant", value=0)
qout_padded, scale_inv = cls._quantize_vectorwise_reference(
qref_input,
quant_dtype,
tile_len=tile_len,
pow_2_scales=pow_2_scales,
eps=eps,
)
if K % tile_len == 0:
qout = qout_padded
else:
qout = qout_padded[:, :K].contiguous()
if return_transpose:
if M % tile_len == 0:
qref_input = x.transpose(-1, -2).contiguous()
else:
amount_to_pad = tile_len - (M % tile_len)
pad = (0, amount_to_pad)
qref_input = torch.nn.functional.pad(
x.transpose(-1, -2), pad, mode="constant", value=0
).contiguous()
qout_t_padded, scale_inv_t = cls._quantize_vectorwise_reference(
qref_input,
quant_dtype,
tile_len=tile_len,
pow_2_scales=pow_2_scales,
eps=eps,
)
if M % tile_len == 0:
qout_t = qout_t_padded
else:
qout_t = qout_t_padded[:, :M].contiguous()
else:
qout_t, scale_inv_t = None, None
return QuantizeResult(data=qout, scale=scale_inv, data_t=qout_t, scale_t=scale_inv_t)
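The corresponding 1D (1x128) path, exercising the transpose branch (hypothetical sizes):

x = torch.randn(250, 500)
res = BlockwiseQuantizerReference._quantize_vector_tiling(
    x,
    torch.float8_e4m3fn,
    tile_len=128,
    return_transpose=True,
    pow_2_scales=False,
    eps=0.0,
)
assert res.data.shape == (250, 500)  # K padded to 512 internally, then cropped
assert res.scale.shape == (250, 4)   # natural layout, before cuBLAS munging
assert res.data_t.shape == (500, 250)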
def ref_dequantize_rowwise(
self,
q: torch.Tensor,
quant_tile_shape: Tuple[int, int],
s: torch.Tensor,
dtype: torch.dtype,
) -> torch.Tensor:
assert q.dim() == 2
q_M, q_K = q.shape
s = self.scale_munger.demunge_scale_shape_from_backend((q_M, q_K), s, quant_tile_shape)
assert len(s.shape) == 2
m_tiles, k_tiles = s.shape
M, K = q.shape
unpadded_m, unpadded_k = M, K
if M % quant_tile_shape[0] != 0 or K % quant_tile_shape[1] != 0:
m_pad_amount = (quant_tile_shape[0] - (M % quant_tile_shape[0])) % quant_tile_shape[0]
k_pad_amount = (quant_tile_shape[1] - (K % quant_tile_shape[1])) % quant_tile_shape[1]
q = torch.nn.functional.pad(
q, (0, k_pad_amount, 0, m_pad_amount), mode="constant", value=0
).contiguous()
M, K = q.shape
q_tiled = q.reshape(m_tiles, quant_tile_shape[0], k_tiles, quant_tile_shape[1])
result = q_tiled.to(dtype) * s.reshape(m_tiles, 1, k_tiles, 1)
result = result.view(M, K).to(dtype)
if M != unpadded_m or K != unpadded_k:
result = result[:unpadded_m, :unpadded_k].contiguous()
return result
def quantize(
self,
x: torch.Tensor,
quant_dtype: torch.dtype,
return_transpose: bool = False,
eps: float = 0.0,
pow_2_scales: bool = False,
quant_tile_shape: Tuple[int, int] = (128, 128),
) -> QuantizeResult:
# sanity checks
assert x.dim() == 2
assert x.dtype in (
torch.float,
torch.float16,
torch.bfloat16,
torch.float32,
), "Unsupported input dtype."
assert quant_dtype in (
torch.float8_e4m3fn,
torch.float8_e5m2,
), "Unsupported quant dtype."
assert quant_tile_shape in ((1, 128), (128, 128))
if quant_tile_shape[0] == 1:
# Quantize row-wise
return self.scale_munger.munge_scale_shapes_for_backend(
self._quantize_vector_tiling(
x,
quant_dtype,
tile_len=quant_tile_shape[1],
return_transpose=return_transpose,
pow_2_scales=pow_2_scales,
eps=eps,
),
quant_tile_shape,
)
else:
# Quantize block-wise
return self.scale_munger.munge_scale_shapes_for_backend(
self._quantize_square_block_tiling(
x,
quant_dtype,
tile_len=quant_tile_shape[0],
return_transpose=return_transpose,
pow_2_scales=pow_2_scales,
eps=eps,
),
quant_tile_shape,
)
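End-to-end sketch with the reference quantizer (illustrative sizes; the dequantize here is the torch-side path mentioned in the commit message):

ref = BlockwiseQuantizerReference()
x = torch.randn(256, 512)
qres = ref.quantize(
    x, quant_dtype=torch.float8_e4m3fn, return_transpose=True, quant_tile_shape=(128, 128)
)
assert qres.data.shape == (256, 512) and qres.scale.shape == (2, 4)
x_back = ref.ref_dequantize_rowwise(qres.data, (128, 128), qres.scale, torch.float32)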
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Tuple
import torch
def scale_from_amax_tensor(
x_dtype: torch.dtype,
amax: torch.Tensor,
quant_dtype: torch.dtype,
*,
eps: float,
pow_2_scales: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Derives quantization and dequantization from amax and options.
Reference implementation for scale calculation.
Returns:
- scale: quantization scales
- scale_inv: dequantization scales
- amax: Amax tensor with updates made for extrema values.
"""
assert amax.dtype == torch.float, "amax must be a float tensor."
fp8_max = torch.finfo(quant_dtype).max
# Clamping amax to avoid division by small numbers
amax = torch.max(amax, torch.tensor(eps))
# Compute scale factor
scale = torch.div(fp8_max, amax)
# Note frexp doesn't give back inf for exponent with an inf input
# We take care of inf before pow_2_scales
scale = torch.where(scale == torch.inf, torch.finfo(x_dtype).max, scale)
if pow_2_scales:
# Calculate rounded down exponent
_, exp = torch.frexp(scale)
# Positive numbers are always returned as mant, exp with
# a mantissa in [0.5, 1.0). Because a normal float has a mantissa with
# hidden bit in [1.0, 2.0), the exponent will be off by exactly one because
# of the shift. Subnormal and zero cases need not be considered because
# the smallest possible result of fp8_max / amax is still normal.
exp = exp - 1
# No subnormals and zero.
assert (exp > -127).all()
unity = torch.tensor([1.0], device=exp.device)
torch.ldexp(unity, exp, out=scale)
# Case where amax is inf. The frexp, ldexp logic changes 0.0 scales
# Return 0.0 for 0.0 scale for consistency with non-pow2 scale
# calculation.
scale = torch.where(amax == float("inf"), 0.0, scale)
# Handle overflow cases for amax zero causing NaN
scale = torch.where(amax == 0, 1.0, scale)
# Compute scale_inv
scale_inv = torch.reciprocal(scale)
return scale, scale_inv, amax
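For example, with pow-2 scales, an amax of 3.0, and e4m3 output (fp8 max 448):

amax = torch.tensor([3.0])
scale, scale_inv, _ = scale_from_amax_tensor(
    x_dtype=torch.float32,
    amax=amax,
    quant_dtype=torch.float8_e4m3fn,
    eps=0.0,
    pow_2_scales=True,
)
# 448 / 3 ~= 149.33, rounded down to the nearest power of two: 2**7 = 128.
assert scale.item() == 128.0 and scale_inv.item() == 1 / 128.0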
@@ -6,63 +6,16 @@ import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType_To_Torch
from references.quantize_scale_calc import scale_from_amax_tensor
# Compute scale and scale_inv from amax
def _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales):
# Clamping amax to avoid division by small numbers
amax = torch.max(amax, torch.tensor(eps))
# Compute scale factor
scale = torch.div(fp8_max, amax)
# Note frexp doesn't give back inf for exponent with an inf input
# We take care of inf before pow_2_scales
# option1: set scale to fp32 max when scale is inf
scale = torch.where(scale == torch.inf, torch.finfo(torch.float32).max, scale)
# option2: when scale is inf, set scale to 1
scale = torch.where(scale == torch.inf, 1.0, scale)
if pow_2_scales:
# Calculate rounded down exponent
_, exp = torch.frexp(scale)
# Positive numbers are always returned as mant, exp with
# a mantissa in [0.5, 1.0). Because a normal float has a mantissa with
# hidden bit in [1.0, 2.0), the exponent will be off by exactly one because
# of the shift. Subnormal and zero cases need not be considered because
# the smallest possible result of fp8_max / amax is still normal.
exp = exp - 1
# No subnormals and zero.
assert (exp > -127).all()
# TODO: If/when adding a URM option an option is to cap to 126
# rather than allowing the full range of FP32 (2 - 2^23) x 2^127
# addresses cases where adding a mantissa overflows into inf scales.
# Not necessary currently without additional scale smudging options.
unity = torch.tensor([1.0], device=exp.device)
torch.ldexp(unity, exp, out=scale)
# Case where amax is inf. The frexp, ldexp logic changes 0.0 scales
# Return 0.0 for 0.0 scale for consistency with non-pow2 scale
# calculation.
scale = torch.where(amax == float("inf"), 0.0, scale)
# Handle overflow cases for amax zero causing NaN
scale = torch.where(amax == 0, 1.0, scale)
# Compute scale_inv
scale_inv = torch.reciprocal(scale)
return scale, scale_inv
# compute amax and scale
def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales):
x_fp32 = x.to(torch.float32)
amax = torch.amax(torch.abs(x_fp32)).view(1)
return scale_from_amax_tensor(
torch.float32, amax, quant_dtype, eps=eps, pow_2_scales=pow_2_scales
)
def _multi_dim_transpose(tensor):
@@ -113,7 +66,3 @@ def ref_per_tensor_cs_cast(
qx_t = _multi_dim_transpose(qx)
sx_t = sx
return qx, sx, qx_t, sx_t
def ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales):
return _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Tuple
import math
import pytest
import torch
import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockQuantizer,
Float8BlockwiseQTensor,
)
from references.blockwise_quantizer_reference import (
BlockwiseQuantizerReference,
QuantizeResult,
)
# TODO replace with call to fp8.py when recipe added.
recipe_available = get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8
reason_for_no_recipe = "Quantize kernels require TMA and are only relevant with GEMMs."
def initialize_for_many_scales(
x_shape_2d: Tuple[int, int], tile_shape: Tuple[int, int], *, dtype: torch.dtype, device: str
) -> torch.Tensor:
"""
Put separate distributions into each quantization tile
to avoid many tiles having similar scale values and
causing false passes.
"""
tile_grid_shape = (
math.ceil(x_shape_2d[0] / tile_shape[0]),
math.ceil(x_shape_2d[1] / tile_shape[1]),
)
# Arbitrary size
max_val = 8192.0
# Make a uniform distribution of [-max_val, max_val]
tile_extrema = torch.rand(*tile_grid_shape, dtype=dtype) * max_val * 2 - max_val
result = torch.empty(x_shape_2d, dtype=dtype, device=device)
tile_elements = tile_shape[0] * tile_shape[1]
for i in range(tile_grid_shape[0]):
for j in range(tile_grid_shape[1]):
target = tile_extrema[i, j].item()
step = target / (tile_elements)
if target == 0:
tile = torch.zeros(tile_shape, dtype=dtype, device=device)
else:
tile = torch.arange(0.0, target, step=step, dtype=dtype, device=device)
tile = tile.reshape(*tile_shape)
min_dst_vals = (i * tile_shape[0], j * tile_shape[1])
max_dst_vals = (
min((i + 1) * tile_shape[0], x_shape_2d[0]),
min((j + 1) * tile_shape[1], x_shape_2d[1]),
)
max_src_vals = (
max_dst_vals[0] - min_dst_vals[0],
max_dst_vals[1] - min_dst_vals[1],
)
result[min_dst_vals[0] : max_dst_vals[0], min_dst_vals[1] : max_dst_vals[1]] = tile[
: max_src_vals[0], : max_src_vals[1]
]
return result
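For instance (illustrative sizes), this gives each 128x128 tile its own extremum, so neighboring tiles get distinct scales:

x = initialize_for_many_scales((256, 384), (128, 128), dtype=torch.float32, device="cpu")
# 2x3 tile grid; each tile ramps from 0 toward its own random extremum.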
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
# full tile cases
(128, 128),
(256, 256),
(256, 1024),
(1024, 256),
# Padding required cases
(256, 272),
(303, 300),
(305, 256),
# Some larger tiles.
(2000, 2000),
(2048, 2000),
(2000, 1024),
(2048, 1024),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("eps", [0], ids=["eps_0"])
@pytest.mark.parametrize(
"return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"]
)
@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"])
@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"])
def test_quantization_block_tiling_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
quant_dtype: torch.dtype,
eps: float,
return_transpose: bool,
pow_2_scales: bool,
tile_size: Tuple[int, int],
) -> None:
te_dtype = TE_DType[quant_dtype]
if tile_size == (1, 128):
block_scaling_dim = 1
elif tile_size == (128, 128):
block_scaling_dim = 2
else:
raise ValueError("Non support tile size")
# This test compares the reference class against the class that uses CUDA
# kernels to quantize. They should quantize identically, except for entries
# that are don't-care (padding) values in the scale factor layout.
ref_quantizer = BlockwiseQuantizerReference()
sut_quantizer = Float8BlockQuantizer(
fp8_dtype=te_dtype,
rowwise=True,
columnwise=return_transpose,
amax_epsilon=eps,
force_pow_2_scales=pow_2_scales,
block_scaling_dim=block_scaling_dim,
)
# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Input
x = initialize_for_many_scales((M, N), tile_size, dtype=x_dtype, device=device)
x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False)
x_fp8_sut = sut_quantizer.update_quantized(x, x_fp8_sut)
assert x_fp8_sut._rowwise_data is not None
qx: torch.Tensor = x_fp8_sut._rowwise_data.view(dtype=quant_dtype)
assert x_fp8_sut._rowwise_scale_inv is not None
sx: torch.Tensor = x_fp8_sut._rowwise_scale_inv
qx_t = x_fp8_sut._columnwise_data
sx_t = x_fp8_sut._columnwise_scale_inv
qresult_ref = ref_quantizer.quantize(
x,
quant_dtype=quant_dtype,
return_transpose=return_transpose,
eps=eps,
pow_2_scales=pow_2_scales,
quant_tile_shape=tile_size,
)
qx_ref, sx_ref, qx_t_ref, sx_t_ref = (
qresult_ref.data,
qresult_ref.scale,
qresult_ref.data_t,
qresult_ref.scale_t,
)
# Check
torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0)
# Zero out don't-care values: the scale format includes padding.
scale_mask = torch.ones(
(math.ceil(M / tile_size[0]), math.ceil(N / tile_size[1])), device=sx.device
)
scale_mask = ref_quantizer.scale_munger.munge_scale_shapes_for_backend(
QuantizeResult(qx, scale_mask, None, None), tile_size
).scale
sx = sx * scale_mask
torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0)
if return_transpose:
assert qx_t is not None
qx_t = qx_t.view(dtype=quant_dtype)
assert qx_t_ref is not None
assert sx_t is not None
assert sx_t_ref is not None
scale_mask = torch.ones(
(math.ceil(N / tile_size[0]), math.ceil(M / tile_size[1])),
device=sx_t.device,
)
scale_mask = ref_quantizer.scale_munger.munge_scale_shapes_for_backend(
QuantizeResult(qx_t, scale_mask, None, None), tile_size
).scale
sx_t = sx_t * scale_mask
torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0, rtol=0.0)
torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0, rtol=0.0)
else:
# should be None
assert qx_t is None and qx_t_ref is None
assert sx_t is None and sx_t_ref is None
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
# full tile cases
(128, 128),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("eps", [0], ids=["eps_0"])
@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"])
@pytest.mark.parametrize("tile_size", [(128, 128)])
@pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"])
def test_quantization_block_tiling_extrema_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
quant_dtype: torch.dtype,
eps: float,
pow_2_scales: bool,
tile_size: Tuple[int, int],
extrema_high: bool,
) -> None:
# This test runs a single tile through a quantizer as a way to test
# branch coverage of scale computation.
te_dtype = TE_DType[quant_dtype]
if tile_size == (1, 128):
block_scaling_dim = 1
elif tile_size == (128, 128):
block_scaling_dim = 2
else:
raise ValueError("Non support tile size")
ref_quantizer = BlockwiseQuantizerReference()
sut_quantizer = Float8BlockQuantizer(
fp8_dtype=te_dtype,
rowwise=True,
columnwise=False,
amax_epsilon=eps,
force_pow_2_scales=pow_2_scales,
block_scaling_dim=block_scaling_dim,
)
# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
return_transpose = False
# Input
if extrema_high:
x = torch.full((M, N), torch.finfo(x_dtype).max, dtype=x_dtype, device=device)
else:
x = torch.zeros((M, N), dtype=x_dtype, device=device)
# Run cast and transpose kernel
# Internal call ops.quantize_tensorwise
x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False)
x_fp8_sut = sut_quantizer.update_quantized(x, x_fp8_sut)
qx = x_fp8_sut._rowwise_data.view(dtype=quant_dtype)
sx = x_fp8_sut._rowwise_scale_inv
qresult_ref = ref_quantizer.quantize(
x,
quant_dtype=quant_dtype,
return_transpose=return_transpose,
eps=eps,
pow_2_scales=pow_2_scales,
quant_tile_shape=tile_size,
)
qx_ref, sx_ref = (
qresult_ref.data,
qresult_ref.scale,
)
# Check
torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0)
torch.testing.assert_close(sx.flatten()[0], sx_ref.flatten()[0], atol=0.0, rtol=0.0)
if extrema_high:
expected_value = torch.finfo(quant_dtype).max / torch.finfo(x_dtype).max
if pow_2_scales:
expected_value = math.floor(math.log2(expected_value))
expected_value = math.pow(2.0, expected_value)
expected_value = 1 / expected_value
elif not extrema_high and eps == 0:
expected_value = 1.0
else:
assert not extrema_high
# eps is small enough to trigger inf in quant_dtype_max / eps
if pow_2_scales:
expected_value = math.pow(2.0, -127)
else:
expected_value = 1 / torch.finfo(x_dtype).max
torch.testing.assert_close(
sx.flatten()[0],
torch.tensor(expected_value, device=sx.device),
atol=0.0,
rtol=0.0,
)
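Worked numbers for the maxes branch above (float32 input, e4m3 output), checking the expected_value arithmetic:

import math
import torch
expected = torch.finfo(torch.float8_e4m3fn).max / torch.finfo(torch.float32).max
# 448 / 3.4028e38 ~= 1.32e-36; floor(log2(.)) == -120, so the pow-2 scale is
# 2**-120 and the stored dequantization scale_inv is 2**120.
assert math.floor(math.log2(expected)) == -120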
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from collections.abc import Iterable
import io
import math
from typing import Any, Dict, List, Tuple, Union
import pytest
import torch
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockQuantizer,
Float8BlockwiseQTensor,
)
from transformer_engine.pytorch.utils import get_device_compute_capability
import transformer_engine_torch as tex
# PyTorch tensor dtypes
_dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16]
# TE FP8 dtypes
_fp8_dtypes: List[tex.DType] = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]
# Numerical tolerances with FP8 types
_tols: Dict[tex.DType, Dict[str, float]] = {
tex.DType.kFloat8E4M3: dict(rtol=0.125, atol=0.08),
tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125),
}
def _to_list(x: Union[Iterable, Any]) -> List:
"""Convert to list if iterable, otherwise put in singleton list"""
if isinstance(x, Iterable):
return list(x)
else:
return [x]
# Types that can be interpreted as tensor dims
DimsType = Union[Iterable[int], int]
# TODO replace with call to fp8.py when recipe added.
recipe_available = get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8
reason_for_no_recipe = "Quantize kernels require TMA and are only relevant with GEMMs."
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
class TestFloat8BlockwiseTensor:
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
def test_constructor(
self,
dims: DimsType = 1,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
dtype: torch.dtype = torch.float32,
is_2D_scaled: bool = True,
) -> None:
"""Call constructor and perform sanity checks"""
dims = _to_list(dims)
rowwise = True
columnwise = True
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=rowwise,
columnwise=columnwise,
block_scaling_dim=2 if is_2D_scaled else 1,
)
scale_dims = quantizer.get_scale_shape(dims, columnwise=False)
columnwise_scale_dims = quantizer.get_scale_shape(dims, columnwise=True)
columnwise_dims = quantizer.get_columnwise_shape(dims)
tensor = Float8BlockwiseQTensor(
shape=dims,
dtype=dtype,
rowwise_data=torch.zeros(dims, device="cuda", dtype=torch.uint8),
rowwise_scale_inv=torch.zeros(scale_dims, device="cuda", dtype=torch.float32),
columnwise_data=torch.zeros(columnwise_dims, device="cuda", dtype=torch.uint8),
columnwise_scale_inv=torch.zeros(
columnwise_scale_dims, device="cuda", dtype=torch.float32
),
fp8_dtype=fp8_dtype,
is_2D_scaled=is_2D_scaled,
quantizer=quantizer,
)
assert list(tensor.size()) == dims, "Incorrect dims"
assert tensor.dtype == dtype, "Incorrect nominal dtype"
assert tensor.is_cuda, "Incorrect device"
def _test_quantize_dequantize(
self,
quantizer: Float8BlockQuantizer,
dtype: torch.dtype = torch.float32,
dims: DimsType = (23, 128),
rtol: float = 0.0,
atol: float = 0.0,
dequant_columnwise: bool = False,
use_cpp_allocation: bool = False,
) -> None:
"""Check numerical error when casting to FP8 and back"""
dims = _to_list(dims)
# Initialize random data
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
x_ref_cuda = x_ref.to("cuda")
# Cast to FP8 and back
if not use_cpp_allocation:
x_fp8 = quantizer.make_empty(shape=dims, device="cuda")
quantizer.update_quantized(x_ref_cuda, x_fp8)
else:
# This codepath allows the CPP binding to allocate the output
# tensor
x_fp8 = tex.quantize(x_ref_cuda, quantizer, None, None)
if dequant_columnwise:
# Strip out rowwise data to verify dequantization of
# columnwise data.
x_fp8.update_usage(rowwise_usage=False, columnwise_usage=True)
x_fp8 = x_fp8.dequantize(dtype=dtype).cpu()
# Check results
torch.testing.assert_close(x_fp8, x_ref, rtol=rtol, atol=atol)
# Make sure we are not trivially passing the test
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8, -x_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
def test_quantize_dequantize_dtypes(
self, fp8_dtype: tex.DType, dtype: torch.dtype, block_scaling_dim: int
) -> None:
atol = _tols[fp8_dtype]["atol"]
rtol = _tols[fp8_dtype]["rtol"]
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=True,
columnwise=False,
block_scaling_dim=block_scaling_dim,
)
self._test_quantize_dequantize(quantizer=quantizer, dtype=dtype, atol=atol, rtol=rtol)
@pytest.mark.parametrize(
"dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]]
)
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
@pytest.mark.parametrize("dq_columnwise", [True, False])
def test_quantize_dequantize_dims(
self, dims: DimsType, block_scaling_dim: int, dq_columnwise: bool
) -> None:
atol = _tols[tex.DType.kFloat8E4M3]["atol"]
rtol = _tols[tex.DType.kFloat8E4M3]["rtol"]
quantizer = Float8BlockQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
rowwise=True,
columnwise=dq_columnwise,
block_scaling_dim=block_scaling_dim,
)
self._test_quantize_dequantize(
quantizer=quantizer,
dims=dims,
atol=atol,
rtol=rtol,
dequant_columnwise=dq_columnwise,
)
@pytest.mark.parametrize(
"dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]]
)
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
@pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
@pytest.mark.parametrize("dq_columnwise", [True, False])
def test_quantize_dequantize_dims_cpp_allocate_output(
self, dims: DimsType, block_scaling_dim: int, fp8_dtype: tex.DType, dq_columnwise: bool
) -> None:
atol = _tols[fp8_dtype]["atol"]
rtol = _tols[fp8_dtype]["rtol"]
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=True,
columnwise=dq_columnwise,
block_scaling_dim=block_scaling_dim,
)
self._test_quantize_dequantize(
quantizer=quantizer,
dims=dims,
atol=atol,
rtol=rtol,
dequant_columnwise=dq_columnwise,
use_cpp_allocation=True,
)
@pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
def test_data_accessors(self, dims: DimsType, block_scaling_dim: int) -> None:
"""Test data accessors of Float8BlockwiseQTensor"""
device = "cuda"
dtype = torch.bfloat16
x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
y_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
fp8_dtype = tex.DType.kFloat8E4M3
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=True,
columnwise=True,
block_scaling_dim=block_scaling_dim,
)
# Create FP8 tensor
x_fp8 = quantizer.quantize(x_hp)
x_recovered = x_fp8.data
torch.testing.assert_close(x_recovered, x_hp, **_tols[fp8_dtype])
x_fp8.data = y_hp
y_recovered = x_fp8.data
torch.testing.assert_close(y_recovered, y_hp, **_tols[fp8_dtype])
@pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
def test_serialization(self, dims: DimsType, block_scaling_dim: int) -> None:
"""Test serialization of Float8BlockwiseQTensor"""
device = "cuda"
dtype = torch.bfloat16
x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
quantizer = Float8BlockQuantizer(
fp8_dtype=tex.DType.kFloat8E5M2,
rowwise=True,
columnwise=True,
block_scaling_dim=block_scaling_dim,
)
# Create FP8 tensor
x_fp8 = quantizer.quantize(x_hp)
# Save tensor
buffer = io.BytesIO()
torch.save(x_fp8, buffer)
# Load tensor
buffer.seek(0)
x_fp8_loaded = torch.load(buffer, weights_only=False)
# Test that loaded tensor matches original
assert isinstance(x_fp8_loaded, Float8BlockwiseQTensor)
torch.testing.assert_close(x_fp8_loaded._rowwise_data, x_fp8._rowwise_data)
torch.testing.assert_close(x_fp8_loaded._columnwise_data, x_fp8._columnwise_data)
torch.testing.assert_close(x_fp8_loaded._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
torch.testing.assert_close(x_fp8_loaded._columnwise_scale_inv, x_fp8._columnwise_scale_inv)
torch.testing.assert_close(x_fp8_loaded.data, x_fp8.data)
assert x_fp8_loaded._is_2D_scaled == x_fp8._is_2D_scaled
assert x_fp8_loaded.dtype == x_fp8.dtype
assert x_fp8_loaded._fp8_dtype == x_fp8._fp8_dtype
# Test that dequantized values match
x_fp8_dequant = x_fp8.dequantize()
x_fp8_loaded_dequant = x_fp8_loaded.dequantize()
torch.testing.assert_close(x_fp8_loaded_dequant, x_fp8_dequant)

@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
def test_inplace_ops(
self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int
) -> None:
"""Test in-place operations"""
device = "cuda"
x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
y_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=True,
columnwise=True,
block_scaling_dim=block_scaling_dim,
)
# Test in-place add
x_fp8 = quantizer.quantize(x_hp.clone())
y_fp8 = quantizer.quantize(y_hp.clone())
x_fp8.add_(y_fp8)
torch.testing.assert_close(x_fp8.dequantize(), x_hp + y_hp, **_tols[fp8_dtype])
# Test in-place subtract
x_fp8 = quantizer.quantize(x_hp.clone())
y_fp8 = quantizer.quantize(y_hp.clone())
x_fp8.sub_(y_fp8)
torch.testing.assert_close(x_fp8.dequantize(), x_hp - y_hp, **_tols[fp8_dtype])
# Test in-place multiply
x_fp8 = quantizer.quantize(x_hp.clone())
y_fp8 = quantizer.quantize(y_hp.clone())
x_fp8.mul_(y_fp8)
torch.testing.assert_close(x_fp8.dequantize(), x_hp * y_hp, **_tols[fp8_dtype])

@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
def test_out_of_place_ops(
self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int
) -> None:
"""Test out-of-place operations"""
device = "cuda"
x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
y_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=True,
columnwise=True,
block_scaling_dim=block_scaling_dim,
)
x_fp8 = quantizer.quantize(x_hp.clone())
y_fp8 = quantizer.quantize(y_hp.clone())
# Test exact operations
torch.testing.assert_close(-x_fp8, -x_hp, **_tols[fp8_dtype])
torch.testing.assert_close(x_fp8.abs(), x_hp.abs(), **_tols[fp8_dtype])
# Test elementwise operations
torch.testing.assert_close(x_fp8 + y_fp8, x_hp + y_hp, **_tols[fp8_dtype])
torch.testing.assert_close(x_fp8 - y_fp8, x_hp - y_hp, **_tols[fp8_dtype])
torch.testing.assert_close(x_fp8 * y_fp8, x_hp * y_hp, **_tols[fp8_dtype])
torch.testing.assert_close(torch.sin(x_fp8), torch.sin(x_hp), **_tols[fp8_dtype])
# Make sure we are not trivially passing tests
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8 + y_fp8, x_hp - y_hp, **_tols[fp8_dtype])

@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
def test_view_same_shape(
self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int
) -> None:
"""Test view operations that preserve tensor shape"""
device = "cuda"
x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=True,
columnwise=True,
block_scaling_dim=block_scaling_dim,
)
x_fp8 = quantizer.make_empty(x_hp.shape, dtype=dtype, device=device)
quantizer.update_quantized(x_hp.clone(), x_fp8)
# Test view with same shape
x_view = x_fp8.view(*dims)
torch.testing.assert_close(x_view.dequantize(), x_hp, **_tols[fp8_dtype])
assert x_view.shape == x_fp8.shape, "Shape changed after view with same dims"
# Make sure we are not trivially passing tests
with pytest.raises(AssertionError):
torch.testing.assert_close(x_view.dequantize(), -x_hp, **_tols[fp8_dtype])

@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
def test_reshape_same_shape(
self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int
) -> None:
"""Test reshape operations that preserve tensor shape"""
device = "cuda"
x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=True,
columnwise=True,
block_scaling_dim=block_scaling_dim,
)
x_fp8 = quantizer.make_empty(x_hp.shape, dtype=dtype, device=device)
quantizer.update_quantized(x_hp.clone(), x_fp8)
# Test reshape with same shape
x_reshape = x_fp8.reshape(*dims)
torch.testing.assert_close(x_reshape.dequantize(), x_hp, **_tols[fp8_dtype])
assert x_reshape.shape == x_fp8.shape, "Shape changed after reshape with same dims"
# Test reshape with -1 canonicalization
new_dims = [-1, dims[1]]
x_reshape = x_fp8.reshape(*new_dims)
torch.testing.assert_close(x_reshape.dequantize(), x_hp, **_tols[fp8_dtype])
assert x_reshape.shape == x_fp8.shape, "Shape changed after reshape with -1"
# Make sure we are not trivially passing tests
with pytest.raises(AssertionError):
torch.testing.assert_close(x_reshape.dequantize(), -x_hp, **_tols[fp8_dtype])

@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
def test_clone_detach(
self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int
) -> None:
"""Test clone and detach operations"""
device = "cuda"
x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=True,
columnwise=True,
block_scaling_dim=block_scaling_dim,
)
x_fp8 = quantizer.quantize(x_hp.clone())
# Test clone
x_clone = x_fp8.clone()
torch.testing.assert_close(x_clone.dequantize(), x_hp, **_tols[fp8_dtype])
assert x_clone.shape == x_fp8.shape, "Shape changed after clone"
# Test detach
x_detach = x_fp8.detach()
torch.testing.assert_close(x_detach.dequantize(), x_hp, **_tols[fp8_dtype])
assert x_detach.shape == x_fp8.shape, "Shape changed after detach"
# Make sure we are not trivially passing tests
with pytest.raises(AssertionError):
torch.testing.assert_close(x_clone.dequantize(), -x_hp, **_tols[fp8_dtype])
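For orientation, the round trip these tests exercise can be distilled into the short sketch below. The import path for Float8BlockQuantizer and the tolerances are assumptions for illustration, not part of this diff:

```python
# Minimal round-trip sketch distilled from the tests above.
# NOTE: the import path is an assumption; adjust to where this PR
# actually places the quantizer class.
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer

x = torch.rand(256, 512, dtype=torch.bfloat16, device="cuda")
quantizer = Float8BlockQuantizer(
    fp8_dtype=tex.DType.kFloat8E4M3,
    rowwise=True,
    columnwise=True,
    block_scaling_dim=1,  # 1 -> 1x128 blocks, 2 -> 128x128 blocks
)
x_fp8 = quantizer.quantize(x)  # Float8BlockwiseQTensor
x_back = x_fp8.dequantize()    # bfloat16 approximation of x
torch.testing.assert_close(x_back, x, atol=0.1, rtol=0.1)  # loose, illustrative tolerances
```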
@@ -9,7 +9,7 @@ import transformer_engine.pytorch as te
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.optimizers import MultiTensorApply
-from references.ref_per_tensor_cs import ref_compute_scale_and_scale_inv_from_amax
+from references.quantize_scale_calc import scale_from_amax_tensor

 input_size_pairs = [
@@ -224,17 +224,18 @@ def test_multi_tensor_unscale_l2norm(input_size_pair, applier, repeat, in_type,
 @pytest.mark.parametrize("input_size_pair", input_size_pairs + [(1, 1)])
 @pytest.mark.parametrize("applier", appliers)
 @pytest.mark.parametrize("repeat", [1, 55])
-@pytest.mark.parametrize("max_fp8", [448.0, 57344.0])
+@pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
 @pytest.mark.parametrize("pow_2_scales", [False, True])
 @pytest.mark.parametrize("epsilon", [0.0, 100.0])
 def test_multi_tensor_compute_scale_and_scale_inv(
-    input_size_pair, applier, repeat, max_fp8, pow_2_scales, epsilon
+    input_size_pair, applier, repeat, fp8_dtype, pow_2_scales, epsilon
 ):
     sizea, sizeb = input_size_pair
     device = torch.device("cuda")
     overflow_buf = torch.zeros(1, dtype=torch.int32, device=device)
     a = torch.randn([sizea], dtype=torch.float32, device=device).abs()
     b = torch.randn([sizeb], dtype=torch.float32, device=device).abs()
+    max_fp8 = torch.finfo(fp8_dtype).max
     amax_list = []
     for i in range(repeat):
@@ -253,8 +254,8 @@ def test_multi_tensor_compute_scale_and_scale_inv(
         )
     for amax, scale, scale_inv in zip(amax_list, scale_list, scale_inv_list):
-        scale_ref, scale_inv_ref = ref_compute_scale_and_scale_inv_from_amax(
-            amax, max_fp8, epsilon, pow_2_scales
+        scale_ref, scale_inv_ref, _ = scale_from_amax_tensor(
+            torch.float32, amax, fp8_dtype, eps=epsilon, pow_2_scales=pow_2_scales
         )
         torch.testing.assert_close(scale, scale_ref, rtol=0, atol=0)
         torch.testing.assert_close(scale_inv, scale_inv_ref, rtol=0, atol=0)
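The reference helper `scale_from_amax_tensor` itself is not shown in this diff. A rough Python restatement of the scale math it is compared against, inferred from the `compute_scale_from_amax` device code later in this change, might look like the following sketch (not the actual reference implementation):

```python
import torch

def scale_from_amax_sketch(amax, fp8_dtype, eps=0.0, pow_2_scales=False):
    """Hypothetical restatement of the reference scale computation."""
    max_fp8 = torch.finfo(fp8_dtype).max
    # NaN amax survives the clamp, matching the CUDA note below.
    amax = amax.float().clamp(min=eps)
    # Degenerate amax (0, inf, NaN) falls back to a scale of 1.
    scale = torch.where(
        torch.isnan(amax) | torch.isinf(amax) | (amax == 0.0),
        torch.ones_like(amax),
        max_fp8 / amax,
    )
    # An unrepresentable (inf) scale is replaced by the largest finite float32.
    scale = torch.where(
        torch.isinf(scale), torch.full_like(scale, torch.finfo(torch.float32).max), scale
    )
    if pow_2_scales:
        # Round down to a power of two, mirroring the mantissa mask
        # in compute_scale_from_amax.
        scale = torch.exp2(torch.floor(torch.log2(scale)))
    return scale, 1.0 / scale
```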
@@ -58,6 +58,8 @@ list(APPEND transformer_engine_SOURCES
     transpose/cast_transpose_fusion.cu
     transpose/transpose_fusion.cu
     transpose/multi_cast_transpose.cu
+    transpose/quantize_transpose_square_blockwise.cu
+    transpose/quantize_transpose_vector_blockwise.cu
     activation/gelu.cu
     fused_attn/fused_attn_f16_max512_seqlen.cu
     fused_attn/fused_attn_f16_arbitrary_seqlen.cu
...
@@ -99,6 +99,12 @@ struct Tensor {
   SimpleTensor scale_inv;
   SimpleTensor columnwise_scale_inv;

+ private:
+  // Used as an allocation for nvte_tensor_shape
+  // if the shape has to be inferred from columnwise data.
+  mutable std::vector<size_t> rowwise_shape_cache;
+
+ public:
   NVTEScalingMode scaling_mode;

   Tensor()
@@ -160,12 +166,39 @@ struct Tensor {
         return data.shape;
       }
       break;
+      case NVTE_BLOCK_SCALING_1D:
+      case NVTE_BLOCK_SCALING_2D: {
+        if (!has_data() && has_columnwise_data()) {
+          std::vector<size_t> shape;
+          size_t ndim = columnwise_data.shape.size();
+          shape.reserve(ndim);
+          for (size_t i = 0; i + 1 < ndim; ++i) {
+            shape.push_back(columnwise_data.shape[i + 1]);
+          }
+          if (ndim > 0) {
+            shape.push_back(columnwise_data.shape[0]);
+          }
+          return shape;
+        } else {
+          // NOTE: We may have removed the data pointer from
+          // data by setting usage. In that case, we return
+          // the non-null shape. It is our best guess at the most
+          // recent shape.
+          return data.shape;
+        }
+        break;
+      }
       default:
         NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\"");
         return {};
     }
   }

+  const std::vector<size_t> &rowwise_shape_ref() const {
+    rowwise_shape_cache = shape();
+    return rowwise_shape_cache;
+  }
+
   /*! Matrix height after tensor is flattened to 2D
    *
    * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
@@ -247,6 +280,36 @@ TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0)
 #endif
 #undef TRANSFORMER_ENGINE_TYPE_NAME

+template <typename T>
+struct TypeExtrema;
+
+template <>
+struct TypeExtrema<fp8e4m3> {
+  static constexpr float max = 448.0f;
+};
+
+template <>
+struct TypeExtrema<fp8e5m2> {
+  static constexpr float max = 57344.0f;
+};
+
+template <>
+struct TypeExtrema<bf16> {
+  // Hex float format of 1.(7 bits of 1) * 2 ^ 127
+  static constexpr float max = 0x1.FEp127;
+};
+
+template <>
+struct TypeExtrema<fp16> {
+  // Hex float format of 1.(10 bits of 1) * 2 ^ 15
+  static constexpr float max = 0x1.FFCp15;
+};
+
+template <typename T>
+struct TypeExtrema {
+  static constexpr float max = std::numeric_limits<T>::max();
+};
+
 }  // namespace detail

 template <typename T>
@@ -277,6 +340,7 @@ struct TypeInfo {
   constexpr static DType dtype = getType<T>();
   constexpr static size_t size = sizeof(T);
+  constexpr static float max_finite_value = detail::TypeExtrema<T>::max;
   constexpr static const char *name = detail::type_name<T>();
 };
...
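The hex-float constants added to `TypeExtrema` can be sanity-checked against `torch.finfo`. This throwaway snippet (an illustration, not part of the change) confirms they match the usual max-finite values:

```python
import torch

# 0x1.FEp127 is 1.1111111(binary) * 2^127, the largest finite bfloat16 value.
assert float.fromhex("0x1.FEp127") == torch.finfo(torch.bfloat16).max
# 0x1.FFCp15 is 1.1111111111(binary) * 2^15 = 65504.0, the largest finite float16 value.
assert float.fromhex("0x1.FFCp15") == torch.finfo(torch.float16).max
# The FP8 maxima match the hardcoded 448.0f and 57344.0f.
assert torch.finfo(torch.float8_e4m3fn).max == 448.0
assert torch.finfo(torch.float8_e5m2).max == 57344.0
print("TypeExtrema constants check out")
```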
@@ -81,6 +81,9 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
                                 const transformer_engine::Tensor &B, const cublasOperation_t transB,
                                 const int k, const int lda, const int ldb) {
   using namespace transformer_engine;
+  // FIXME(kwyss): 1x128 by 128x128 GEMM is part of the subchannel design.
+  // Must either force them both into a common block scaling mode or loosen this
+  // restriction.
   NVTE_CHECK(A.scaling_mode == B.scaling_mode,
              "Inputs A and B to GEMM need to have the same scaling mode!");
   NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!");
@@ -90,6 +93,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
   ret.lda = lda;
   ret.ldb = ldb;
+  // FIXME(kwyss): 128x128 by 128x128 GEMMs and 1x128 by 128x128 GEMMs need cases
+  // or need to be treated as `is_tensor_scaling`.
   if (is_tensor_scaling(A.scaling_mode)) {
     ret.A = A.data.dptr;
     ret.A_scale_inv = A.scale_inv.dptr;
@@ -244,6 +249,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
   NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
                                                    &fastAccuMode, sizeof(fastAccuMode)));
+  // FIXME(kwyss): Add binding code for 128x128 block quantized and 1x128 block
+  // quantized GEMM types.
   // Scaling factors.
 #if CUDA_VERSION >= 12080
   cublasLtMatmulMatrixScale_t scaling_mode;
...
@@ -17,22 +17,31 @@
 extern "C" {
 #endif

-/* Cast the tensor to FP8 (or microscaling FP8 if the compute capability of the device is 10.0 or newer)
- * The implementation is per the microscaling format MXFP8 defined by the OCP specification:
+/* Quantize the tensor
+ *
+ * The type of quantized tensor in the output depends on the scaling mode of the output
+ * tensor.
+ *
+ * Supported formats are:
+ *
+ * 1) MXFP8 scaling (for compute capability 10.0 or newer)
+ *
+ * The MXFP8 implementation is per the microscaling format MXFP8 defined by the OCP specification:
  * https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
  *
- * Supported modes of scaling (live scaling):
- * 1) Rowwise scaling (along the dim=0) computes one set of the output data, which includes:
+ * Supported modes of MXFP8 scaling (live scaling) for scaling mode NVTE_MXFP8_1D_SCALING:
+ * a) Rowwise scaling (along the dim=0) computes one set of the output data, which includes:
  *    - the scaled output tensor
  *    - the corresponding scaling factors
  *    The scaling factors are computed for blocks of the shape [1,32]
  *    (i.e., each scaling factor spans 32 contiguous elements along rows).
  *
- * 2) Columwise scaling (along the dim=1) computes one set of the output data.
+ * b) Columnwise scaling (along the dim=1) computes one set of the output data.
  *    The scaling factors are computed for blocks of the shape [32,1]
  *    (i.e., each scaling factor spans 32 contiguous elements along columns).
  *
- * 3) Both rowwise AND columnwise scaling (along the dim=0 and the dim=1)
- *    computes two sets of the output data: both 1) and 2).
+ * c) Both rowwise AND columnwise scaling (along the dim=0 and the dim=1)
+ *    computes two sets of the output data: both a) and b).
  *
  * The shape of the MX block must be specified in the 'output' argument,
@@ -40,25 +49,53 @@ extern "C" {
  *
  * To cast the input tensor to the MXFP8, the scaling_mode.delayed_scaling parameter
  * of the output tensor should be set to 0.
+ *
+ * 2) NVTE_DELAYED_TENSOR_SCALING, which quantizes the entire tensor
+ *    using a single scaling factor. The absolute maximum value of the tensor should
+ *    be precalculated either online (current scaling) or from a tensor history
+ *    (delayed scaling). Calls to nvte_quantize then scale based on that value.
+ *    Note that the NVTE_DELAYED_TENSOR_SCALING NVTEScalingMode is reused for online
+ *    per-tensor scaling.
+ *
+ * 3) FP8 block scaling formats NVTE_BLOCK_SCALING_1D and NVTE_BLOCK_SCALING_2D,
+ *    for compute capability of at least 9.0. These modes quantize the tensor by blocks
+ *    of size 1x128 (with a columnwise mode of 128x1) and 128x128 respectively.
+ *
+ *    The supported modes are:
+ *    a) Rowwise scaling yields output data:
+ *       - the scaled output tensor in fp8 coefficients, with a shape identical to the
+ *         input tensor.
+ *       - scale factors computed for either 1D 1x128 or 2D 128x128 blocks.
+ *    b) Columnwise scaling yields output data:
+ *       - the scaled output tensor in fp8 coefficients, with a shape equivalent to
+ *         the transpose of the input tensor.
+ *       - scale factors calculated for either 1D 128x1 or 2D 128x128 blocks
+ *         of the input tensor.
+ *    c) Both: both output tensors and both sets of scale factors are calculated.
+ *
+ *    This quantization mode includes both the calculation of the scaling factors
+ *    per tile and the quantization of the rowwise and/or columnwise tiles. No
+ *    precalculated absolute max is required. The scaling factors are also rounded
+ *    to powers of 2.
  */

-/*! \brief Casts input tensor to FP8/MXFP8.
- *  If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
- *  the block quantization (MXFP8) of the specified shape of the block will be used.
+/*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8.
+ *  The type of quantized tensor in the output depends on the scaling mode of the output
+ *  tensor. See file level comments.
  *
  *  \param[in]     input     Input tensor to be cast.
- *  \param[in,out] output    Output FP8/MXFP8 tensor.
+ *  \param[in,out] output    Output FP8/MXFP8/BlockwiseFP8 tensor.
  *  \param[in]     stream    CUDA stream used for the operation.
  */
 void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream);

 /*! \brief Casts input tensor to FP8/MXFP8, providing the option to immediately exit the kernel
  *  based on the value of the 'noop' tensor.
- *  If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
- *  the block quantization (MXFP8) of the specified shape of the block will be used.
+ *  The type of quantized tensor in the output depends on the scaling mode of the output
+ *  tensor. See file level comments.
  *
  *  \param[in]     input     Input tensor to be cast.
- *  \param[in,out] output    Output FP8/MXFP8 tensor.
+ *  \param[in,out] output    Output quantized tensor.
  *  \param[out]    noop      Noop tensor.
  *  \param[in]     stream    CUDA stream used for the operation.
  */
...
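To make the block-scaled output description above concrete, here is the shape arithmetic for the 1x128 and 128x128 modes as a small sketch. It ignores any padding or alignment the real scale tensors may carry and is not a library API:

```python
import math

def blockwise_shapes(shape, block_scaling_dim, columnwise=False):
    """Illustrative data/scale shapes for NVTE_BLOCK_SCALING_{1D,2D}."""
    rows, cols = shape
    if columnwise:
        # Columnwise output holds the transpose of the input.
        rows, cols = cols, rows
    if block_scaling_dim == 1:
        # One scale per 1x128 run of elements.
        scale_shape = (rows, math.ceil(cols / 128))
    else:
        # One scale per 128x128 tile.
        scale_shape = (math.ceil(rows / 128), math.ceil(cols / 128))
    return (rows, cols), scale_shape

print(blockwise_shapes((250, 500), 1))                   # ((250, 500), (250, 4))
print(blockwise_shapes((250, 500), 1, columnwise=True))  # ((500, 250), (500, 2))
print(blockwise_shapes((250, 500), 2))                   # ((250, 500), (2, 4))
```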
@@ -80,8 +80,14 @@ enum NVTEScalingMode {
   /*! Single scale per block of 32 elements consecutive in either
       rowwise or columnwise direction */
   NVTE_MXFP8_1D_SCALING = 1,
-  NVTE_INVALID_SCALING = 2,
-  NVTE_NO_SCALING = 3
+  /*! Tensor is split into NxN quantization tiles or 1xN quantization tiles,
+      each of which yields a scale. The block_scaling_dim property of the quantizer
+      selects the granularity.
+   */
+  NVTE_BLOCK_SCALING_1D = 2,
+  NVTE_BLOCK_SCALING_2D = 3,
+  NVTE_INVALID_SCALING = 4,
+  NVTE_NO_SCALING = 5
 };

 /*! \brief TE Tensor type
...
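Presumably the Python-side `block_scaling_dim` maps onto the two new enum values one-to-one. A hypothetical mapping for illustration (the numeric values come from the enum above; the helper name is made up):

```python
# Values from the NVTEScalingMode enum in this diff.
NVTE_BLOCK_SCALING_1D = 2  # 1x128 rowwise / 128x1 columnwise tiles
NVTE_BLOCK_SCALING_2D = 3  # 128x128 tiles

def scaling_mode_for(block_scaling_dim: int) -> int:
    # block_scaling_dim=1 selects 1D (vector) tiles, 2 selects 2D (square) tiles.
    return {1: NVTE_BLOCK_SCALING_1D, 2: NVTE_BLOCK_SCALING_2D}[block_scaling_dim]

assert scaling_mode_for(1) == NVTE_BLOCK_SCALING_1D
```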
@@ -152,7 +152,8 @@ namespace {
 __global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr,
                                                const float max_fp8, const bool force_pow_2_scales,
                                                const float epsilon) {
-  *scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon);
+  *scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon,
+                                       std::numeric_limits<float>::max());
 }

 }  // namespace
...
@@ -7,19 +7,21 @@
 #ifndef TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_
 #define TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_

-#include <limits>
+#include "common/common.h"

 namespace transformer_engine {

 __device__ __forceinline__ float compute_scale_from_amax(float amax, float max_fp8,
-                                                         bool force_pow_2_scales, float epsilon) {
+                                                         bool force_pow_2_scales, float epsilon,
+                                                         float value_for_inf) {
+  // NOTE: NAN amax evaluates false for <, handled further down.
   if (amax < epsilon) {
     amax = epsilon;
   }

   float scale = 1.f;

-  if (isinf(amax) || amax == 0.f) {
+  if (isinf(amax) || amax == 0.f || isnan(amax)) {
     return scale;
   }
@@ -32,18 +34,13 @@ __device__ __forceinline__ float compute_scale_from_amax(float amax, float max_f
   // the scale is not representable in FP32.
   if (isinf(scale)) {
     // use fp32 max to represent the scale
-    scale = std::numeric_limits<float>::max();
+    scale = value_for_inf;
   }

-  if (isnan(scale)) {
-    scale = 1.f;
-  }
-
   if (force_pow_2_scales) {
     uint32_t scale_bits = *reinterpret_cast<uint32_t *>(&scale);
     scale_bits &= 0xFF800000;
     // If the exponent was zero, we have a logic error.
-    __builtin_assume(scale_bits != 0);
+    __builtin_assume(scale_bits != 0 || scale == 0.0);
     __builtin_assume(scale_bits != 0x80000000);
     scale = *reinterpret_cast<float *>(&scale_bits);
   }
@@ -51,6 +48,26 @@ __device__ __forceinline__ float compute_scale_from_amax(float amax, float max_f
   return scale;
 }

+// Calculate the quantization scale for an individual data element
+// given the amax(abs(tile)) value for a given quantization tile.
+//
+// Arguments:
+//   IType: data type of the tensor being quantized (float or bf16)
+//   OType: quantized data type (e4m3 or e5m2)
+//   amax: The evaluation of amax(abs(tile)) for the quantization tile.
+//   eps: An epsilon used as a floor for amax.
+//   pow_2_scaling: Whether to force the scale to be a power of 2.
+template <typename IType, typename OType>
+__device__ __forceinline__ float compute_scale_from_types(const float amax, const float eps,
+                                                          const float pow_2_scaling) {
+  constexpr float fp8_max = TypeInfo<OType>::max_finite_value;
+  // NOTE: We're relying on compute_scale_from_amax to have behavior where it
+  // clips the mantissa of the max_finite_value if power of 2 scaling applies.
+  constexpr float value_for_inf = TypeInfo<IType>::max_finite_value;
+  return compute_scale_from_amax(amax, fp8_max, pow_2_scaling, eps, value_for_inf);
+}
+
 }  // namespace transformer_engine

 #endif  // TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_
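The power-of-two rounding above works by clearing the float32 mantissa bits. A quick host-side demonstration of the same bit trick, as a Python stand-in for the CUDA code:

```python
import struct

def pow2_floor_f32(scale: float) -> float:
    # Clearing the 23 mantissa bits of an IEEE-754 float32 keeps only the
    # sign and exponent, i.e. rounds a positive value down to a power of two.
    (bits,) = struct.unpack("<I", struct.pack("<f", scale))
    bits &= 0xFF800000
    (out,) = struct.unpack("<f", struct.pack("<I", bits))
    return out

assert pow2_floor_f32(448.0) == 256.0  # 1.75 * 2^8  -> 2^8
assert pow2_floor_f32(1.0) == 1.0      # already a power of two
assert pow2_floor_f32(0.1) == 0.0625   # 1.6 * 2^-4  -> 2^-4
```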
@@ -215,48 +215,14 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
   if (tensor == nullptr) {
     NVTE_ERROR("Invalid tensor");
   }
-  NVTEShape ret;

   // Determine tensor shape depending on tensor format
   const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
-  switch (t.scaling_mode) {
-    case NVTE_DELAYED_TENSOR_SCALING: {
-      if (!t.has_data() && t.has_columnwise_data()) {
-        // We can infer tensor shape if FP8 tensor only has FP8 data
-        // transpose. However, NVTEShape only contains a pointer and
-        // cannot store temporary data. We hack around this by caching
-        // the tensor shape within the empty FP8 data.
-        auto &shape_cache = const_cast<std::vector<size_t> &>(t.data.shape);
-        shape_cache.clear();
-        if (!t.columnwise_data.shape.empty()) {
-          for (size_t i = 1; i < t.columnwise_data.shape.size(); i++) {
-            shape_cache.push_back(t.columnwise_data.shape[i]);
-          }
-          shape_cache.push_back(t.columnwise_data.shape.front());
-        }
-        ret.data = shape_cache.data();
-        ret.ndim = shape_cache.size();
-      } else {
-        ret.data = t.data.shape.data();
-        ret.ndim = t.data.shape.size();
-      }
-      break;
-    }
-    case NVTE_MXFP8_1D_SCALING: {
-      if (!t.has_data() && t.has_columnwise_data()) {
-        ret.data = t.columnwise_data.shape.data();
-        ret.ndim = t.columnwise_data.shape.size();
-      } else {
-        ret.data = t.data.shape.data();
-        ret.ndim = t.data.shape.size();
-      }
-      break;
-    }
-    default:
-      NVTE_ERROR("Cannot parse tensor shape with scaling mode \"",
-                 transformer_engine::to_string(t.scaling_mode), "\"");
-  }
+  const std::vector<size_t> &rowwise_shape = t.rowwise_shape_ref();

+  NVTEShape ret;
+  ret.data = rowwise_shape.data();
+  ret.ndim = rowwise_shape.size();
   return ret;
 }
...
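The simplification above delegates shape inference to `Tensor::shape()`, where a columnwise-only tensor's rowwise shape is recovered by rotating the transposed dimensions. The rotation itself, restated in Python for clarity (illustrative, not library code):

```python
def rowwise_shape_from_columnwise(columnwise_shape):
    # Columnwise data is stored transposed as (D_n, D_1, ..., D_n-1);
    # rotating left by one position recovers the rowwise (D_1, ..., D_n).
    if not columnwise_shape:
        return []
    return list(columnwise_shape[1:]) + [columnwise_shape[0]]

assert rowwise_shape_from_columnwise([512, 256]) == [256, 512]
assert rowwise_shape_from_columnwise([5, 2, 3]) == [2, 3, 5]
```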
@@ -23,6 +23,18 @@ template <typename ComputeType, typename ParamOP, ComputeType (*OP1)(ComputeType
 void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_input, Tensor *output,
                                cudaStream_t stream);

+void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor &scale_inv,
+                                         SimpleTensor &scale_inv_t, SimpleTensor &output,
+                                         SimpleTensor &output_t, const float epsilon,
+                                         const bool return_transpose, const bool pow_2_scale,
+                                         cudaStream_t stream);
+
+void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor &scale_inv,
+                                         SimpleTensor &scale_inv_t, SimpleTensor &output,
+                                         SimpleTensor &output_t, const float epsilon,
+                                         const bool return_transpose, const bool pow_2_scale,
+                                         cudaStream_t stream);
+
 }  // namespace transformer_engine::detail

 #endif  // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_