Unverified commit 77fa1e59 authored by Zhongbo Zhu, committed by GitHub

[PyTorch] Enabling Per-Tensor Current Scaling Recipe (#1471)



* check in per-tensor current scaling full recipe
Signed-off-by: zhongboz <zhongboz@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: zhongboz <zhongboz@nvidia.com>

set up basics of the current scaling quantizer at the Python level
Signed-off-by: zhongboz <zhongboz@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: zhongboz <zhongboz@nvidia.com>

add test case for current scaling dequantize
Signed-off-by: zhongboz <zhongboz@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: zhongboz <zhongboz@nvidia.com>

finish linear layer fwd/bwd test, determined error with bf16
Signed-off-by: zhongboz <zhongboz@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: zhongboz <zhongboz@nvidia.com>

achieved zero tolerance for Linear by specifying the GEMM use_split_accumulator config
Signed-off-by: zhongboz <zhongboz@nvidia.com>

enable LayerNormLinear with current scaling, pass bitwise test
Signed-off-by: zhongboz <zhongboz@nvidia.com>

refactor test case code
Signed-off-by: zhongboz <zhongboz@nvidia.com>

make current scaling quantizers distributed, pass distributed Linear & LayerNormLinear tests
Signed-off-by: zhongboz <zhongboz@nvidia.com>

bug fix: use cached fp8 recipe in backward
Signed-off-by: zhongboz <zhongboz@nvidia.com>

fix layernorm_mlp with current scaling, fix activation_helper with current scaling
Signed-off-by: zhongboz <zhongboz@nvidia.com>

support detailed numerical settings from recipe to quantization kernel
Signed-off-by: zhongboz <zhongboz@nvidia.com>

resolving MR comments
Signed-off-by: zhongboz <zhongboz@nvidia.com>

recipe naming
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* resolve mr comments, remove IS_CURRENT_SCALING template from kernels
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* resolve mr comments, make current scaling c++ test cases
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* add current scaling to test_numerics.py, skip act recomp and grouped linear
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add benchmark for quantizer
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add benchmarks for linear layer
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* bug fix, typo
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* resolve more mr comments
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* avoid potential race condition by not using from_blob to construct amax tensor in C++
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* resolve more comments
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Debug linter warnings and license check
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Debug import error in FP8 tensor test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug compilation error with CUDA 12.1 for Turing
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* resolve mr comments, fix activation cast fusion
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* resolve comments, add NVTEQuantizationParams for compute scale
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* remove is_current_scaling check entirely from common folder
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* remove benchmarks, will contribute them in another repo
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* adjust cs default recipe config
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* adjust comments in test
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Remove current scaling mode from core lib
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Refactor current-scaling-specific logic in core C++ lib

Move the amax and scale update functions out of the casting functions and into a dedicated current-scaling source file. Add a general API for accessing the quantization config object.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add missing header in C++ tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable test config with FP8 transpose on Blackwell
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix compilation error in C++ test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: zhongboz <zhongboz@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
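For context, the recipe added by this PR is enabled through the usual fp8_autocast path. A minimal usage sketch (the layer sizes, shapes, and dtype below are arbitrary placeholders, not values from this PR):

# Hedged sketch: enable the per-tensor current scaling recipe on a TE Linear layer.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling

recipe = Float8CurrentScaling()  # default configuration, as used in the tests below
layer = te.Linear(256, 128, params_dtype=torch.bfloat16).cuda()
x = torch.randn(16, 256, dtype=torch.bfloat16, device="cuda", requires_grad=True)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)  # forward pass casts with scales derived from the current amax
y.backward(torch.randn_like(y))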
parent 2a95efd3
...@@ -38,3 +38,4 @@ downloads/
.pytest_cache/
compile_commands.json
.nfs
tensor_dumps/
\ No newline at end of file
File mode changed from 100644 to 100755
...@@ -4,6 +4,7 @@
add_executable(test_operator
test_cast.cu
test_cast_current_scaling.cu
test_cast_dbias.cu
test_cast_dbias_dgelu.cu
test_cast_gated_swiglu.cu
...@@ -13,6 +14,7 @@ add_executable(test_operator
test_dequantize_mxfp8.cu
test_transpose.cu
test_cast_transpose.cu
test_cast_transpose_current_scaling.cu
test_cast_transpose_dbias.cu
test_cast_transpose_dbias_dgelu.cu
test_cast_transpose_dgeglu.cu
......
...@@ -35,6 +35,8 @@ void compute_ref(const InputType *data, OutputType *output_c,
*amax = current_max;
}
// delayed tensor scaling test
template <typename InputType, typename OutputType>
void performTest(const std::vector<size_t>& shape) {
using namespace test;
...@@ -55,6 +57,7 @@ void performTest(const std::vector<size_t>& shape) {
nvte_quantize(input.data(), output_c.data(), 0);
float ref_amax;
compute_ref<InputType, OutputType>(input.rowwise_cpu_dptr<InputType>(), ref_output_c.get(),
full_size, &ref_amax, output_c.scale());
...@@ -105,6 +108,7 @@ TEST_P(CastTestSuite, TestCast) {
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
// delayed tensor scaling
performTest<InputType, OutputType>(size);
);
);
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/recipe.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename InputType, typename OutputType>
void compute_ref(const InputType *data, OutputType *output_c,
const size_t size,
float *amax, float scale) {
using compute_t = float;
compute_t current_max = -1e100;
for (size_t i = 0; i < size; ++i) {
compute_t current = static_cast<compute_t>(data[i]);
current_max = fmaxf(current_max, fabsf(current));
output_c[i] = OutputType(scale * current);
}
}
template <typename InputType, typename OutputType>
void compute_amax_scale_ref(const InputType *data,
const size_t size,
float *amax_ptr, float *scale_ptr, float* scale_inv_ptr,
float max_fp8, float epsilon) {
using compute_t = float;
compute_t current_max = -1e100;
for (size_t i = 0; i < size; ++i) {
compute_t current = static_cast<compute_t>(data[i]);
current_max = fmaxf(current_max, fabsf(current));
}
*amax_ptr = current_max;
// compute scale from amax
float clamp_amax = current_max;
if (current_max <= epsilon){
clamp_amax = epsilon;
}
float scale = 1.f;
float scale_inv = 1.f;
if (isinf(clamp_amax) || clamp_amax == 0.f) {
*scale_ptr = scale;
*scale_inv_ptr = scale_inv;
return;
}
// use ieee_div in CPU
scale = max_fp8 / clamp_amax;
// The amax is so small that the scale becomes infinite in FP32, i.e. the
// scale is not representable in FP32.
if (isinf(scale)) {
scale = std::numeric_limits<float>::max();
}
if (isnan(scale)) {
scale = 1.f;
}
scale_inv = 1.0f / scale;
*scale_ptr = scale;
*scale_inv_ptr = scale_inv;
}
// current tensor scaling test
template <typename InputType, typename OutputType>
void performTest(const std::vector<size_t>& shape) {
using namespace test;
const size_t full_size = product(shape);
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
bool is_out_fp8 = isFp8Type(otype);
// find out max fp8 value
float max_fp8;
if (is_out_fp8){
switch (otype) {
case DType::kFloat8E5M2: {
max_fp8 = Quantized_Limits<fp8e5m2>::max();
} break;
case DType::kFloat8E4M3: {
max_fp8 = Quantized_Limits<fp8e4m3>::max();
} break;
default:
NVTE_ERROR("Invalid type.");
}
}
Tensor input("input", shape, itype);
Tensor output_c("output_c", shape, otype, true, false);
std::unique_ptr<OutputType[]> ref_output_c = std::make_unique<OutputType[]>(full_size);
fillUniform(&input);
// compute amax
float amax_to_check = 0.0f;
if (is_out_fp8){
nvte_compute_amax(input.data(), output_c.data(), 0);
QuantizationConfigWrapper config;
nvte_compute_scale_from_amax(output_c.data(), config, 0);
// Read back the amax now, then null out the amax pointer so the CUDA cast kernel
// skips its atomic amax update (the amax has already been computed above for
// per-tensor current scaling).
amax_to_check = output_c.amax();
output_c.set_tensor_amax_nullptr();
}
nvte_quantize(input.data(), output_c.data(), 0);
float ref_amax;
float ref_scale;
float ref_scale_inv;
if (is_out_fp8){
compute_amax_scale_ref<InputType, OutputType>(input.rowwise_cpu_dptr<InputType>(),
full_size, &ref_amax, &ref_scale, &ref_scale_inv, max_fp8, 0.0f);
}
compute_ref<InputType, OutputType>(input.rowwise_cpu_dptr<InputType>(), ref_output_c.get(),
full_size, nullptr, is_out_fp8 ? output_c.scale() : 1.0f );
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (isFp8Type(otype)) {
auto [atol_fp32, rtol_fp32] = getTolerances(DType::kFloat32);
compareResults("amax", amax_to_check, ref_amax, 0.0f, rtol_fp32);
compareResults("scale", output_c.scale(), ref_scale, 0.0f, rtol_fp32);
compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, 0.0f, rtol_fp32);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), true, 0.0f, rtol);
}
std::vector<std::vector<size_t>> test_cases = {
{16},
{16000},
{128, 128},
{256, 256},
{768, 1024},
{256, 65536},
{2048, 12288},
{65536, 128},
{65536, 160},
{16384, 1616},
{1, 128},
{1, 1296},
{1, 16},
{5, 160},
{5, 4, 3, 160},
{217, 256},
};
} // namespace
class CastCSTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
transformer_engine::DType,
std::vector<size_t>>> {};
TEST_P(CastCSTestSuite, TestCastCS) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
// current tensor scaling
performTest<InputType, OutputType>(size);
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
CastCSTestSuite,
::testing::Combine(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<CastCSTestSuite::ParamType>& info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param));
const auto& shape = std::get<2>(info.param);
for ( const auto& s: shape) {
name += "X" + std::to_string(s);
}
return name;
});
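At the PyTorch level, the same amax -> scale -> cast flow exercised by this C++ test is exposed through Float8CurrentScalingQuantizer (used by the distributed tests further below). A minimal single-GPU sketch, assuming only the constructor arguments and attributes shown in those tests:

# Hedged sketch of the Python-level current scaling quantizer (single GPU).
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer

x = torch.randn(256, 256, dtype=torch.bfloat16, device="cuda")
quantizer = Float8CurrentScalingQuantizer(
    fp8_dtype=tex.DType.kFloat8E4M3,
    device=x.device,
    with_amax_reduction=False,  # no distributed amax reduction on a single GPU
)
x_fp8 = quantizer(x)           # computes the current amax, derives scale/scale_inv, then casts
scale_inv = x_fp8._scale_inv   # per-tensor scale_inv, as checked in the distributed quantizer test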
...@@ -38,6 +38,8 @@ void compute_ref(const InputType *data, OutputType *output_c, OutputType *output
*amax = current_max;
}
// delayed tensor scaling test
template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H) {
using namespace test;
...@@ -75,6 +77,7 @@ void performTest(const size_t N, const size_t H) {
compareResults("output_t", output, ref_output_t.get(), false, atol, rtol);
}
std::vector<std::pair<size_t, size_t>> test_cases = {{2048, 12288},
{768, 1024},
{256, 65536},
...@@ -101,6 +104,7 @@ TEST_P(CTTestSuite, TestCastTranspose) {
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
// delayed tensor scaling
performTest<InputType, OutputType>(size.first, size.second);
);
);
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/recipe.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename InputType, typename OutputType>
void compute_ref(const InputType *data, OutputType *output_c, OutputType *output_t,
const size_t N, const size_t H,
float *amax, float scale) {
using compute_t = float;
compute_t current_max = -1e100;
for (size_t i = 0; i < N; ++i) {
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
current_max = fmaxf(current_max, fabsf(current));
output_c[i * H + j] = OutputType(scale * current);
output_t[j * N + i] = OutputType(scale * current);
}
}
}
template <typename InputType, typename OutputType>
void compute_amax_scale_ref(const InputType *data,
const size_t N, const size_t H,
float *amax_ptr, float *scale_ptr, float* scale_inv_ptr,
float max_fp8, float epsilon) {
using compute_t = float;
compute_t current_max = -1e100;
for (size_t i = 0; i < N; ++i) {
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
current_max = fmaxf(current_max, fabsf(current));
}
}
*amax_ptr = current_max;
// compute scale from amax
float clamp_amax = current_max;
if (current_max <= epsilon){
clamp_amax = epsilon;
}
float scale = 1.f;
float scale_inv = 1.f;
if (isinf(clamp_amax) || clamp_amax == 0.f) {
*scale_ptr = scale;
*scale_inv_ptr = scale_inv;
return;
}
// use ieee_div in CPU
scale = max_fp8 / clamp_amax;
// The amax is so small that the scale becomes infinite in FP32, i.e. the
// scale is not representable in FP32.
if (isinf(scale)) {
scale = std::numeric_limits<float>::max();
}
if (isnan(scale)) {
scale = 1.f;
}
scale_inv = 1.0f / scale;
*scale_ptr = scale;
*scale_inv_ptr = scale_inv;
}
// current tensor scaling test
template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H) {
using namespace test;
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
bool is_out_fp8 = isFp8Type(otype);
// find out max fp8 value
float max_fp8;
if (is_out_fp8){
switch (otype) {
case DType::kFloat8E5M2: {
max_fp8 = Quantized_Limits<fp8e5m2>::max();
} break;
case DType::kFloat8E4M3: {
max_fp8 = Quantized_Limits<fp8e4m3>::max();
} break;
default:
NVTE_ERROR("Invalid type.");
}
}
Tensor input("input", { N, H }, itype);
Tensor output("output", { N, H }, otype, true, true);
std::unique_ptr<OutputType[]> ref_output_c = std::make_unique<OutputType[]>(N * H);
std::unique_ptr<OutputType[]> ref_output_t = std::make_unique<OutputType[]>(N * H);
fillUniform(&input);
// compute amax
float amax_to_check = 0.0f;
if (is_out_fp8){
nvte_compute_amax(input.data(), output.data(), 0);
QuantizationConfigWrapper config;
nvte_compute_scale_from_amax(output.data(), config, 0);
// Read back the amax now, then null out the amax pointer so the CUDA cast kernel
// skips its atomic amax update (the amax has already been computed above for
// per-tensor current scaling).
amax_to_check = output.amax();
output.set_tensor_amax_nullptr();
}
nvte_quantize(input.data(), output.data(), 0);
float ref_amax;
float ref_scale;
float ref_scale_inv;
if (is_out_fp8){
compute_amax_scale_ref<InputType, OutputType>(input.rowwise_cpu_dptr<InputType>(),
N, H, &ref_amax, &ref_scale, &ref_scale_inv, max_fp8, 0.0f);
}
compute_ref<InputType, OutputType>(input.rowwise_cpu_dptr<InputType>(), ref_output_c.get(),
ref_output_t.get(), N, H, nullptr,
is_out_fp8 ? output.scale() : 1.0f );
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (isFp8Type(otype)) {
auto [atol_fp32, rtol_fp32] = getTolerances(DType::kFloat32);
compareResults("amax", amax_to_check, ref_amax, 0.0f, rtol_fp32);
compareResults("scale", output.scale(), ref_scale, 0.0f, rtol_fp32);
compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, 0.0f, rtol_fp32);
compareResults("scale_inv_columnwise", output.columnwise_cpu_scale_inv_ptr<float>()[0], ref_scale_inv, 0.0f, rtol_fp32);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output, ref_output_c.get(), true, 0.0f, rtol);
compareResults("output_t", output, ref_output_t.get(), false, 0.0f, rtol);
}
std::vector<std::pair<size_t, size_t>> test_cases = {{2048, 12288},
{768, 1024},
{256, 65536},
{65536, 128},
{256, 256},
{120, 2080},
{8, 8},
{1, 3221}, // Prime 456
{2333, 1}, // Prime 345
{1481, 677}}; // Primes 234, 123
} // namespace
class CTCSTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
transformer_engine::DType,
std::pair<size_t, size_t>>> {};
TEST_P(CTCSTestSuite, TestCastTransposeCS) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
// current tensor scaling
performTest<InputType, OutputType>(size.first, size.second);
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
CTCSTestSuite,
::testing::Combine(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::ValuesIn(test::all_fp_types),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<CTCSTestSuite::ParamType>& info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param)) + "X" +
std::to_string(std::get<2>(info.param).first) + "X" +
std::to_string(std::get<2>(info.param).second);
return name;
});
...@@ -103,10 +103,6 @@ size_t DIVUP(const size_t &x, const size_t &y){
return (((x) + ((y)-1)) / (y));
}
inline bool is_tensor_scaling(const NVTEScalingMode &mode) {
return mode == NVTE_DELAYED_TENSOR_SCALING;
}
struct scale_inv_meta {
std::vector<size_t> shape;
DType type;
...@@ -233,7 +229,7 @@ Tensor::Tensor(const std::string& name,
tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape);
if (isFp8Type(type)) {
if (is_tensor_scaling(scaling_mode)) { if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*)
cudaMemset(amax, 0, sizeof(float));
cudaMalloc((void**)&scale, sizeof(float)); // NOLINT(*)
...@@ -296,11 +292,13 @@ void Tensor::to_cpu() const {
cudaMemcpyDeviceToHost);
}
if (isFp8Type(dtype())) {
if (is_tensor_scaling(tensor_.scaling_mode())) { if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
if (tensor_.amax() != nullptr){
cudaMemcpy(amax_cpu_data_.get(),
tensor_.amax(),
sizeof(float),
cudaMemcpyDeviceToHost);
}
cudaMemcpy(scale_cpu_data_.get(),
tensor_.scale(),
sizeof(float),
...@@ -336,9 +334,11 @@ void Tensor::from_cpu() const {
cpu_data_columnwise_.get(), size, cudaMemcpyHostToDevice);
}
if (isFp8Type(dtype())) {
if (is_tensor_scaling(tensor_.scaling_mode())) { if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
if (tensor_.amax() != nullptr){
cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float),
cudaMemcpyHostToDevice);
}
cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float),
cudaMemcpyHostToDevice);
}
...@@ -361,7 +361,7 @@ void Tensor::from_cpu() const {
void Tensor::set_scale(float scale) {
if (isFp8Type(dtype())) {
NVTE_CHECK(scale_cpu_data_);
if (is_tensor_scaling(tensor_.scaling_mode())) { if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
*scale_cpu_data_ = scale;
from_cpu();
}
......
...@@ -256,6 +256,10 @@ class Tensor {
return columnwise_;
}
void set_tensor_amax_nullptr(){
tensor_.set_amax(nullptr, DType::kFloat32, tensor_.defaultShape);
}
void to_cpu() const;
void from_cpu() const;
void set_scale(float scale);
......
...@@ -14,13 +14,15 @@ import transformer_engine.pytorch as te
import torch
from torch import nn
import torch.distributed as dist
import transformer_engine_torch as tex
from transformer_engine.common.recipe import (
MXFP8BlockScaling,
DelayedScaling,
Float8CurrentScaling,
Format,
Recipe,
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer
from run_layer_with_overlap import _compare_tensors
SEQ_LEN, BATCH_SIZE = 16, 16
...@@ -45,6 +47,8 @@ def quantization_recipe() -> Recipe:
)
if QUANTIZATION == "mxfp8":
return MXFP8BlockScaling()
if QUANTIZATION == "fp8_cs":
return Float8CurrentScaling()
return te.fp8.get_default_fp8_recipe()
...@@ -88,6 +92,7 @@ def main(argv=None, namespace=None):
HIDDEN_SIZE = 128
test_dict = [
test_quantizer,
test_linear,
test_layernorm,
test_layernorm_linear,
...@@ -152,7 +157,12 @@ def dist_print(msg, src=None, end="\n", error=False):
def _get_tolerances(dtype):
if QUANTIZATION is not None: # looser tolerances for fp8_cs because of sequence parallelism and amax reduction:
# with row parallel + sequence parallel, each rank ends up with a different scale_inv
# when computing Y, since the all-gather happens in the backward pass
if QUANTIZATION == "fp8_cs":
return {"rtol": 0.4, "atol": 0.25}
elif QUANTIZATION is not None:
return {"rtol": 0.125, "atol": 0.0625} return {"rtol": 0.125, "atol": 0.0625}
if dtype == torch.float16: if dtype == torch.float16:
...@@ -293,6 +303,98 @@ def _alloc_main_grad(model_single_node, model_distributed): ...@@ -293,6 +303,98 @@ def _alloc_main_grad(model_single_node, model_distributed):
param.main_grad = torch.zeros_like(param, dtype=torch.float32) param.main_grad = torch.zeros_like(param, dtype=torch.float32)
###############################################
# Quantizer #
###############################################
def _construct_quantizer(quantizer_class, fp8_dtype, device, tp_group, tp_size):
"""
quantizer is the reference quantizer on a single GPU.
quantizer_dist is the distributed quantizer to be tested on multiple GPUs.
"""
if quantizer_class == Float8CurrentScalingQuantizer:
quantizer_dist = quantizer_class(
fp8_dtype=fp8_dtype,
device=device,
with_amax_reduction=True,
amax_reduction_group=tp_group,
amax_reduction_size=tp_size,
)
quantizer = quantizer_class(
fp8_dtype=fp8_dtype,
device=device,
with_amax_reduction=False,
)
return quantizer, quantizer_dist
else:
raise ValueError(f"Unsupported quantizer class: {quantizer_class}")
def _shard_tensor(x, world_size, axis):
split_size = x.size()[axis] // world_size
split_tensor = torch.split(x, split_size, axis)
out = []
for tensor in split_tensor:
out.append(tensor.detach().clone().requires_grad_(x.requires_grad).cuda())
return out
@run_distributed_test()
def _test_quantizer(input_dtype, fp8_dtype):
"""Test the quantizer under distributed settings.
Args:
input_dtype (torch.dtype): The data type of the input.
fp8_dtype (tex.DType): The data type of the fp8.
"""
M, N = WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE
# high precision input
x_hp_cpu = torch.randn((M, N), device="cpu").to(input_dtype)
# set one element of the input to a very large value, which doesn't live in rank 0 after the split
# to test the amax reduction on purpose
x_hp_cpu[M - 1, N - 1] = 1e4
# rank 0 takes the full copy and quantize with GPU 0 for verification
if WORLD_RANK == 0:
x_hp_rank0 = x_hp_cpu.clone().detach().requires_grad_(True).to("cuda")
x_hp_local_rank = _shard_tensor(x_hp_cpu, WORLD_SIZE, 0)[WORLD_RANK]
# Create quantizers
quantizer, quantizer_dist = _construct_quantizer(
Float8CurrentScalingQuantizer, fp8_dtype, x_hp_local_rank.device, NCCL_WORLD, WORLD_SIZE
)
# quantize the input
if WORLD_RANK == 0:
x_fp8_single = quantizer(x_hp_rank0)
# multi-GPU quantizer
x_fp8_dist = quantizer_dist(x_hp_local_rank)
# check scale_inv with zero tolerance
if WORLD_RANK == 0:
torch.testing.assert_close(
x_fp8_single._scale_inv, x_fp8_dist._scale_inv, rtol=0.0, atol=0.0
)
def test_quantizer():
"""
Run quantizer tests with various configurations.
Currently only check fp8_cs because it needs to do amax reduction in the quantizer.
"""
# skip this test for other quantization schemes
if QUANTIZATION != "fp8_cs":
return
input_dtypes = [torch.float32, torch.bfloat16]
fp8_dtypes = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]
for input_dtype in input_dtypes:
for fp8_dtype in fp8_dtypes:
_test_quantizer(input_dtype, fp8_dtype)
############################################
# Linear #
############################################
...@@ -339,6 +441,11 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
torch.empty((WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE)).cuda().to(params_dtype)
)
input_distributed = torch.randn((BATCH_SIZE, HIDDEN_SIZE)).cuda().to(params_dtype)
# when quantization is fp8_cs, we need to trigger corner cases to see if amax reduction is working
if QUANTIZATION == "fp8_cs":
input_distributed = torch.clamp(input_distributed, min=-10, max=10)
if WORLD_RANK == WORLD_SIZE - 1:
input_distributed[BATCH_SIZE - 1, HIDDEN_SIZE - 1] = 11
input_single_node = _gather(input_distributed, dim=0).detach()
else:
input_distributed = input_single_node.clone()
...@@ -501,6 +608,12 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
# Duplicate input for sequence parallelism
input_single_node = torch.empty((WORLD_SIZE * SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype)
input_distributed = torch.randn((SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype)
# make the last element of the input a large value to test the amax reduction on purpose
# when quantization is fp8_cs, we need to trigger corner cases to see if amax reduction is working
if QUANTIZATION == "fp8_cs":
input_distributed = torch.clamp(input_distributed, min=-10, max=10)
if WORLD_RANK == WORLD_SIZE - 1:
input_distributed[SEQ_LEN - 1, HIDDEN_SIZE - 1] = 11
input_single_node = _gather(input_distributed).detach()
else:
input_distributed = input_single_node.clone()
...@@ -599,6 +712,12 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg
# Duplicate input for sequence parallelism
input_single_node = torch.empty((WORLD_SIZE * SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype)
input_distributed = torch.randn((SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype)
# make the last element of the input a large value to test the amax reduction on purpose
# when quantization is fp8_cs, we need to trigger corner cases to see if amax reduction is working
if QUANTIZATION == "fp8_cs":
input_distributed = torch.clamp(input_distributed, min=-10, max=10)
if WORLD_RANK == WORLD_SIZE - 1:
input_distributed[SEQ_LEN - 1, HIDDEN_SIZE - 1] = 11
input_single_node = _gather(input_distributed).detach()
else:
input_distributed = input_single_node.clone()
...@@ -651,6 +770,7 @@ def test_layernorm_mlp():
{"return_bias": True},
{"return_layernorm_output": True},
]
for kwargs in kwargs_list:
for set_parallel_mode in [True]:
for sequence_parallel in [False, True]:
...@@ -745,6 +865,7 @@ def test_transformer_layer():
{"fuse_qkv_params": True},
{"activation": "relu"},
]
for kwargs in kwargs_list:
for sequence_parallel in [False, True]:
_test_transformer_layer_parallel(sequence_parallel, **kwargs)
......
...@@ -48,10 +48,12 @@ def _run_test(quantization):
all_boolean = [True, False]
@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8"]) @pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs"])
def test_distributed(quantization):
if quantization == "fp8" and not fp8_available:
pytest.skip(reason_for_no_fp8)
if quantization == "fp8_cs" and not fp8_available:
pytest.skip(reason_for_no_fp8)
if quantization == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
_run_test(quantization)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType_To_Torch
# compute amax and scale
def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales):
x_fp32 = x.to(torch.float32)
amax = torch.amax(torch.abs(x_fp32)).view(1)
assert amax.dtype == torch.float, "amax must be a float tensor."
fp8_max = torch.finfo(quant_dtype).max
# Clamping amax to avoid division by small numbers
amax = torch.max(amax, torch.tensor(eps))
# Compute scale factor
scale = torch.div(fp8_max, amax)
# Note frexp doesn't give back inf for exponent with an inf input
# We take care of inf before pow_2_scales
# option1: set scale to fp32 max when scale is inf
scale = torch.where(scale == torch.inf, torch.finfo(torch.float32).max, scale)
# option2: when scale is inf, set scale to 1
scale = torch.where(scale == torch.inf, 1.0, scale)
if pow_2_scales:
# Calculate rounded down exponent
_, exp = torch.frexp(scale)
# Positive numbers are always returned as mant, exp with
# a mantissa in [0.5, 1.0). Because a normal float has a mantissa with
# hidden bit in [1.0, 2.0), the exponent will be off by exactly one because
# of the shift. Subnormal and zero cases need not be considered because
# the smallest possible result of fp8_max / amax is still normal.
exp = exp - 1
# No subnormals and zero.
assert (exp > -127).all()
# TODO: If/when adding a URM option an option is to cap to 126
# rather than allowing the full range of FP32 (2 - 2^23) x 2^127
# addresses cases where adding a mantissa overflows into inf scales.
# Not necessary currently without additional scale smudging options.
unity = torch.tensor([1.0], device=exp.device)
torch.ldexp(unity, exp, out=scale)
# Case where amax is inf. The frexp, ldexp logic changes 0.0 scales
# Return 0.0 for 0.0 scale for consistency with non-pow2 scale
# calculation.
scale = torch.where(amax == float("inf"), 0.0, scale)
# Handle overflow cases for amax zero causing NaN
scale = torch.where(amax == 0, 1.0, scale)
# Compute scale_inv
scale_inv = torch.reciprocal(scale)
return scale, scale_inv, amax
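# Illustrative note (not part of the reference implementation): a worked example of
# the power-of-two rounding above. With amax = 3.0 and fp8_max = 448.0 (E4M3),
# scale = 448 / 3 ~= 149.33; torch.frexp gives mantissa 0.5833..., exponent 8, so
# exp - 1 = 7 and torch.ldexp(1.0, 7) = 128.0, i.e. the scale is rounded down to
# the nearest power of two that does not exceed 149.33.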
def _multi_dim_transpose(tensor):
# Get the number of dimensions
dims = list(range(len(tensor.shape)))
if len(dims) <= 1:
return tensor
# circular shift of shapes
new_order = []
new_order.append(dims[-1])
for i in range(len(dims) - 1):
new_order.append(dims[i])
# Permute the tensor according to the new order
output_tensor = tensor.permute(new_order).contiguous()
return output_tensor
# current scaling reference quantization
def ref_per_tensor_cs_cast(
tensor: torch.Tensor,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
return_transpose: bool = False,
force_pow_2_scales: bool = False,
amax_epsilon: float = 0.0,
) -> torch.Tensor:
quant_dtype_torch = TE_DType_To_Torch[fp8_dtype]
scale, scale_inv, _ = _ref_compute_amax_scale(
tensor,
quant_dtype_torch,
amax_epsilon,
force_pow_2_scales,
)
qx = (tensor.float() * scale).to(quant_dtype_torch)
sx = scale_inv
qx_t = None
sx_t = None
if tensor.shape == torch.Size([]):
qx = qx.view([])
if return_transpose:
qx_t = _multi_dim_transpose(qx)
sx_t = sx
return qx, sx, qx_t, sx_t
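A short usage sketch of the reference cast defined above (assuming the imports at the top of this file; the tensor shape and dtype are arbitrary):

# Hedged example invocation of ref_per_tensor_cs_cast (illustrative values only).
x = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")
qx, sx, qx_t, sx_t = ref_per_tensor_cs_cast(
    x,
    fp8_dtype=tex.DType.kFloat8E4M3,
    return_transpose=True,
)
# qx holds the FP8 data and sx the per-tensor scale_inv; qx_t is the transposed
# FP8 data and reuses the same scale_inv.
assert qx.shape == x.shape and qx_t.shape == (64, 128)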
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pathlib
import os
import torch
import pytest
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.common.recipe import Float8CurrentScaling
from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp8_torch_dtype
# read env variable NVTE_TEST_FLOAT8_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory
TENSOR_DUMP_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "tensor_dumps"
tensor_dump_dir_env = os.getenv("NVTE_TEST_FLOAT8_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR")
if tensor_dump_dir_env is not None:
TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env)
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
class GetRecipes:
@staticmethod
def none():
return None
@staticmethod
def fp8_per_tensor_current_scaling_default():
# return default configs
return Float8CurrentScaling()
# base class for validating current_scaling x linear layer
class TestFP8RecipeLinearBase:
@staticmethod
def _prepare_data(
batch_size, hidden_size, out_size, use_bias=True, seed=0, dtype=torch.float32
):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
x = torch.randn((batch_size, hidden_size), dtype=dtype, device="cuda")
w = torch.randn((out_size, hidden_size), dtype=dtype, device="cuda")
bias = torch.randn((out_size), dtype=dtype, device="cuda") if use_bias else None
gradient = torch.randn((batch_size, out_size), dtype=dtype, device="cuda")
return x, w, bias, gradient
@staticmethod
def _shard_tensor(x, world_size, axis):
split_size = x.size()[axis] // world_size
split_tensor = torch.split(x, split_size, axis)
out = []
for tensor in split_tensor:
out.append(tensor.detach().clone().requires_grad_(x.requires_grad))
return out
@staticmethod
def _gather_tensor(local, world_size, tp_group, concat_dim):
out_list = [torch.zeros_like(local) for _ in range(world_size)]
torch.distributed.all_gather(out_list, local, tp_group)
return torch.cat(out_list, dim=concat_dim)
@staticmethod
def _all_reduce_tensor(local, world_size, tp_group):
if world_size == 1:
return local
handle = torch.distributed.all_reduce(local, group=tp_group, async_op=False)
return local
@staticmethod
def _get_sum_abs_error(a, b):
return torch.sum(torch.abs(a - b))
@staticmethod
def _get_mean_abs_relative_error(a, b):
return torch.mean(torch.abs((a - b) / b))
@staticmethod
def _load_golden_tensor_values(a, b):
return torch.sum(torch.abs(a - b))
@staticmethod
def _check_golden_tensor_dumps(dump_dir, get_recipe, dims, input_dtype, use_bias):
recipe = get_recipe()
batch_size, hidden_size, out_size = dims
fp8_type_x = get_fp8_torch_dtype(recipe, fprop_tensor=True)
fp8_type_w = get_fp8_torch_dtype(recipe, fprop_tensor=True)
fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False)
# Expected tensor names based on the naming template
scaling_type = ( # Assuming the scaling type is PER_TENSOR for this example
"ScalingType.PER_TENSOR"
)
current_seed = torch.initial_seed() # Get the current seed
expected_tensor_names = {
"y": f"y_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt",
"dgrad": f"dgrad_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt",
"wgrad": f"wgrad_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt",
"bgrad": f"bgrad_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt",
}
if not use_bias:
expected_tensor_names.pop("bgrad")
# Check if all expected tensors are in the tensor dumps directory
tensor_map = {}
for tensor_key, tensor_name in expected_tensor_names.items():
tensor_path = dump_dir / tensor_name
if not os.path.exists(tensor_path):
print(f"Missing tensor: {tensor_name}")
return None
# Load the tensor
tensor_map[tensor_key] = torch.load(tensor_path)
return tensor_map
@classmethod
def run_linear_preprocess_parallel(
cls,
x,
w,
bias,
gradient,
parallel_mode=None,
sequence_parallel=False,
tp_size=1,
rank=0,
):
if tp_size > 1:
if parallel_mode == "column":
# split w in N dim, which should be axis 0
w = cls._shard_tensor(w, tp_size, 0)[rank]
bias = cls._shard_tensor(bias, tp_size, 0)[rank] if bias is not None else None
# split gradient in N dim, which should be axis 1
gradient = cls._shard_tensor(gradient, tp_size, 1)[rank]
if sequence_parallel:
# split x in M dim, which should be axis 0
x = cls._shard_tensor(x, tp_size, 0)[rank]
# row parallel, split x in k dim, which should be axis 1, split w in k dim, should be axis 1
if parallel_mode == "row":
# split x in K dim, which should be axis 1
x = cls._shard_tensor(x, tp_size, 1)[rank]
# split w in K dim, which should be axis 1
w = cls._shard_tensor(w, tp_size, 1)[rank]
if sequence_parallel:
# split gradient in M dim, which should be axis 0
gradient = cls._shard_tensor(gradient, tp_size, 0)[rank]
return x, w, bias, gradient
@classmethod
def run_linear_postprocess_parallel(
cls,
y_q,
dgrad,
wgrad,
bgrad,
parallel_mode,
sequence_parallel,
tp_size,
tp_group,
):
if tp_size > 1:
if parallel_mode == "column":
# gather y_q in N dim, which should be axis 1
y_q = cls._gather_tensor(y_q, tp_size, tp_group, 1)
# gather wgrad in N dim, which should be axis 0
wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 0)
# gather bgrad in N dim, which should be axis 0
bgrad = (
cls._gather_tensor(bgrad, tp_size, tp_group, 0) if bgrad is not None else None
)
if sequence_parallel:
# gather dgrad in M dim, which should be axis 0
dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 0)
if parallel_mode == "row":
# gather dgrad in K dim, which should be axis 1
dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 1)
# gather wgrad in K dim, which should be axis 1
wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 1)
if sequence_parallel:
# gather y_q in M dim, which should be axis 0
y_q = cls._gather_tensor(y_q, tp_size, tp_group, 0)
# we need to sum bias gradient when using TP + SP
bgrad = (
cls._all_reduce_tensor(bgrad, tp_size, tp_group)
if bgrad is not None
else None
)
return y_q, dgrad, wgrad, bgrad
@classmethod
def run_linear_one_step(
cls, layer, x, gradient, is_first_microbatch=None, fuse_wgrad_accumulation=False
):
# reset gradients
layer.zero_grad()
x.grad = None
# Forward pass
if isinstance(layer, te.Linear):
# TE Linear (supports the is_first_microbatch argument)
y_q = layer.forward(x, is_first_microbatch=is_first_microbatch)
else:
# the default torch.nn.Linear
y_q = layer(x)
# Backward pass
y_q.backward(gradient)
# Collect gradients
dgrad = x.grad
bgrad = (
layer._parameters["bias"].grad
if layer._parameters.get("bias", None) is not None
else None
)
assert "weight" in layer._parameters
if fuse_wgrad_accumulation:
wgrad = layer._parameters["weight"].main_grad
assert layer._parameters["weight"].grad is None
else:
wgrad = layer._parameters["weight"].grad
return y_q, dgrad, wgrad, bgrad
@classmethod
def run_linear_multiple_steps(
cls,
layer,
x,
gradient,
run_num_steps,
enable_weight_cache,
fuse_wgrad_accumulation=False,
):
"""
Run multiple steps of linear layer and collect results.
"""
y_q_list, dgrad_list, wgrad_list = [], [], []
bgrad_list = [] if layer._parameters.get("bias", None) is not None else None
for i in range(run_num_steps):
x_i = (x + i).clone().detach().requires_grad_(True)
# run_linear_one_step
y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step(
layer,
x_i,
gradient,
is_first_microbatch=(i == 0) if enable_weight_cache else None,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
)
# Collect results
y_q_list.append(y_q.detach().clone())
dgrad_list.append(dgrad.detach().clone())
wgrad_list.append(wgrad.detach().clone())
if bgrad_list is not None and bgrad is not None:
bgrad_list.append(bgrad.detach().clone())
# Return the per-step results collected above (stacked along a new leading dim)
# so the caller can unpack y_q, dgrad, wgrad, bgrad as in the single-step path.
return (
torch.stack(y_q_list),
torch.stack(dgrad_list),
torch.stack(wgrad_list),
torch.stack(bgrad_list) if bgrad_list is not None else None,
)
@classmethod
def run_linear(
cls,
x,
w,
bias,
gradient,
parallel_mode=None,
sequence_parallel=False,
tp_group=None,
tp_size=1,
rank=0,
run_num_steps=1,
enable_weight_cache=False,
fuse_wgrad_accumulation=False,
):
"""
If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with
the reference single GPU run.
"""
# clone inputs and move to current device
# w has shape [N, K], x has shape [M, K], gradient has shape [M, N]
x = x.clone().detach().requires_grad_(True).to("cuda")
w = w.clone().detach().to("cuda")
gradient = gradient.clone().detach().to("cuda")
bias = bias.clone().detach().to("cuda") if bias is not None else None
in_features = x.shape[1]
out_features = w.shape[0]
# If Model parallel: split inputs for a given rank
x, w, bias, gradient = cls.run_linear_preprocess_parallel(
x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank
)
# set data types
params_dtype = x.dtype
# Create linear layer and copy weights
layer = te.Linear(
in_features,
out_features,
bias=bias is not None,
params_dtype=params_dtype,
parallel_mode=parallel_mode,
sequence_parallel=sequence_parallel,
tp_group=tp_group,
tp_size=tp_size,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
)
layer = layer.to("cuda")
with torch.no_grad():
layer.weight.copy_(w)
if bias is not None:
layer.bias.copy_(bias)
if fuse_wgrad_accumulation:
assert (
run_num_steps > 1
), "Fused weight gradient accumulation requires run_num_steps > 1"
layer.weight.main_grad = torch.zeros_like(layer.weight)
# Run one step or multiple steps
if run_num_steps == 1:
y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient)
else:
y_q, dgrad, wgrad, bgrad = cls.run_linear_multiple_steps(
layer,
x,
gradient,
run_num_steps,
enable_weight_cache,
fuse_wgrad_accumulation,
)
# If Model parallel: gather output and gradients from all ranks
y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel(
y_q,
dgrad,
wgrad,
bgrad,
parallel_mode,
sequence_parallel,
tp_size,
tp_group,
)
return y_q, dgrad, wgrad, bgrad
def compare_recipe(
self,
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
use_bias,
seed,
dtype,
y_error=0.0,
dgrad_error=0.0,
wgrad_error=0.0,
bgrad_error=0.0,
recipe1_golden_tensors=None,
recipe2_golden_tensors=None,
):
x, w, bias, gradient = self._prepare_data(
batch_size, hidden_size, out_size, use_bias, seed=seed, dtype=dtype
)
# recipe1
using_fp8_recipe = recipe1 != GetRecipes.none
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
else:
y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
# recipe2
using_fp8_recipe = recipe2 != GetRecipes.none
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe2()):
y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient)
else:
y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient)
# Compare results (mean abs relative error)
assert (
self._get_mean_abs_relative_error(y_q, y_q_ref).item() < y_error
), "y and y_ref have too large a mean abs relative error"
assert (
self._get_mean_abs_relative_error(dgrad, dgrad_ref) < dgrad_error
), "dgrad and dgrad_ref have too large a mean abs relative error"
assert (
self._get_mean_abs_relative_error(wgrad, wgrad_ref).item() < wgrad_error
), "wgrad and wgrad_ref have too large a mean abs relative error"
if use_bias:
assert (
self._get_mean_abs_relative_error(bgrad, bgrad_ref).item() < bgrad_error
), "bgrad and bgrad_ref have too large a mean abs relative error"
# enforce zero tolerance check when we can find golden tensor value dump
if recipe2_golden_tensors is not None:
torch.testing.assert_close(
y_q.float(), recipe2_golden_tensors["y"].float(), atol=0, rtol=0.0
)
torch.testing.assert_close(dgrad, recipe2_golden_tensors["dgrad"], atol=0.0, rtol=0.0)
torch.testing.assert_close(wgrad, recipe2_golden_tensors["wgrad"], atol=0.0, rtol=0.0)
if use_bias:
torch.testing.assert_close(
bgrad, recipe2_golden_tensors["bgrad"], atol=0.0, rtol=0.0
)
class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
@staticmethod
def _check_golden_tensor_dumps(
dump_dir, get_recipe, dims, input_dtype, use_bias, normalization
):
recipe = get_recipe()
batch_size, hidden_size, out_size = dims
fp8_type_x = get_fp8_torch_dtype(recipe, fprop_tensor=True)
fp8_type_w = get_fp8_torch_dtype(recipe, fprop_tensor=True)
fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False)
# Expected tensor names based on the naming template
scaling_type = ( # Assuming the scaling type is PER_TENSOR for this example
"ScalingType.PER_TENSOR"
)
current_seed = torch.initial_seed() # Get the current seed
expected_tensor_names = {
"y": f"y_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt",
"ln_out": f"ln_out_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt",
"dgrad": f"dgrad_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt",
"wgrad": f"wgrad_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt",
"bgrad": f"bgrad_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt",
}
if not use_bias:
expected_tensor_names.pop("bgrad")
# Check if all expected tensors are in the tensor dumps directory
tensor_map = {}
for tensor_key, tensor_name in expected_tensor_names.items():
tensor_path = dump_dir / tensor_name
if not os.path.exists(tensor_path):
print(f"Missing tensor: {tensor_name}")
return None
# Load the tensor
tensor_map[tensor_key] = torch.load(tensor_path)
return tensor_map
@classmethod
def run_linear_one_step(cls, layer, x, gradient, is_first_microbatch=None):
# reset gradients
layer.zero_grad()
x.grad = None
# Forward pass
y_q, ln_out = layer.forward(x, is_first_microbatch=is_first_microbatch)
# Backward pass
y_q.backward(gradient)
# Collect gradients
dgrad = x.grad
parameters = layer._parameters
# bias and weight gradients
bgrad = parameters["bias"].grad if parameters.get("bias", None) is not None else None
assert "weight" in parameters
wgrad = parameters["weight"].grad
return y_q, ln_out, dgrad, wgrad, bgrad
@classmethod
def run_linear_multiple_steps(
cls, layer, x, gradient, run_num_steps, enable_weight_cache, fuse_wgrad_accumulation=False
):
# no test case for multiple steps yet
raise NotImplementedError("Multi-step tests are not supported for LayerNormLinear yet")
@classmethod
def run_layernorm_linear(
cls,
x,
w,
bias,
gradient,
parallel_mode=None,
sequence_parallel=False,
tp_group=None,
tp_size=1,
rank=0,
run_num_steps=1,
enable_weight_cache=False,
LayerNormLinearClass=te.LayerNormLinear,
normalization="LayerNorm",
):
"""
If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with
the reference single GPU run.
"""
# clone inputs and move to current device
# w has shape [N, K], x has shape [M, K], gradient has shape [M, N]
x = x.clone().detach().requires_grad_(True).to("cuda")
w = w.clone().detach().to("cuda")
gradient = gradient.clone().detach().to("cuda")
bias = bias.clone().detach().to("cuda") if bias is not None else None
in_features = x.shape[1]
out_features = w.shape[0]
# If Model parallel: split inputs for a given rank
x, w, bias, gradient = cls.run_linear_preprocess_parallel(
x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank
)
# set data types
params_dtype = x.dtype
# Create linear layer and copy weights
layer = LayerNormLinearClass(
in_features,
out_features,
bias=bias is not None,
params_dtype=params_dtype,
parallel_mode=parallel_mode,
sequence_parallel=sequence_parallel,
tp_group=tp_group,
tp_size=tp_size,
normalization=normalization,
return_layernorm_output=True,
)
layer = layer.to("cuda")
# Copy weights (parameter names can differ between layer implementations)
with torch.no_grad():
layer.weight.copy_(w)
if bias is not None:
layer.bias.copy_(bias)
# Run one step
y_q, ln_out, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient)
# If Model parallel: gather output and gradients from all ranks
y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel(
y_q,
dgrad,
wgrad,
bgrad,
parallel_mode,
sequence_parallel,
tp_size,
tp_group,
)
return y_q, ln_out, dgrad, wgrad, bgrad
def compare_recipe(
self,
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
use_bias,
seed,
dtype,
y_error=0.0,
ln_out_error=0.0,
dgrad_error=0.0,
wgrad_error=0.0,
bgrad_error=0.0,
normalization="LayerNorm",
LayerNormLinearClass1=te.LayerNormLinear,
LayerNormLinearClass2=te.LayerNormLinear,
recipe1_golden_tensors=None,
recipe2_golden_tensors=None,
):
x, w, bias, gradient = self._prepare_data(
batch_size, hidden_size, out_size, use_bias, seed=seed, dtype=dtype
)
# recipe1
using_fp8_recipe = recipe1 != GetRecipes.none
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear(
x,
w,
bias,
gradient,
normalization=normalization,
LayerNormLinearClass=LayerNormLinearClass1,
)
else:
y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear(
x,
w,
bias,
gradient,
normalization=normalization,
LayerNormLinearClass=LayerNormLinearClass1,
)
# recipe2
using_fp8_recipe = recipe2 != GetRecipes.none
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe2()):
y_q, ln_out, dgrad, wgrad, bgrad = self.run_layernorm_linear(
x,
w,
bias,
gradient,
normalization=normalization,
LayerNormLinearClass=LayerNormLinearClass2,
)
else:
y_q, ln_out, dgrad, wgrad, bgrad = self.run_layernorm_linear(
x,
w,
bias,
gradient,
normalization=normalization,
LayerNormLinearClass=LayerNormLinearClass2,
)
# Compare results (mean abs relative error)
assert (
self._get_mean_abs_relative_error(y_q, y_q_ref).item() < y_error
), "y and y_ref have too large a mean abs relative error"
assert (
self._get_mean_abs_relative_error(ln_out, ln_out_ref).item() < ln_out_error
), "ln_out and ln_out_ref have too large a mean abs relative error"
assert (
self._get_mean_abs_relative_error(dgrad, dgrad_ref).item() < dgrad_error
), "dgrad and dgrad_ref have too large a mean abs relative error"
assert (
self._get_mean_abs_relative_error(wgrad, wgrad_ref).item() < wgrad_error
), "wgrad and wgrad_ref have too large a mean abs relative error"
if use_bias:
assert (
self._get_mean_abs_relative_error(bgrad, bgrad_ref).item() < bgrad_error
), "bgrad and bgrad_ref have too large a mean abs relative error"
# enforce a zero-tolerance check when a golden tensor value dump is available
if recipe2_golden_tensors is not None:
torch.testing.assert_close(
y_q.float(), recipe2_golden_tensors["y"].float(), atol=0, rtol=0.0
)
torch.testing.assert_close(ln_out, recipe2_golden_tensors["ln_out"], atol=0.0, rtol=0.0)
torch.testing.assert_close(dgrad, recipe2_golden_tensors["dgrad"], atol=0.0, rtol=0.0)
torch.testing.assert_close(wgrad, recipe2_golden_tensors["wgrad"], atol=0.0, rtol=0.0)
if use_bias:
torch.testing.assert_close(
bgrad, recipe2_golden_tensors["bgrad"], atol=0.0, rtol=0.0
)
# FP8 per-tensor current scaling
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFP8CurrentScalingRecipeLinear(TestFP8RecipeLinearBase):
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize(
"batch_size, hidden_size, out_size",
[
(16, 256, 128),
],
)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
@pytest.mark.parametrize(
"recipe1, recipe2",
[
(GetRecipes.none, GetRecipes.fp8_per_tensor_current_scaling_default),
],
)
def test_fp8_current_scaling_with_linear_module(
self,
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
dtype,
use_bias=True,
):
fp8_zero_tolerance_tensor_dumps_recipe2 = None
# Check the tensor dump dir; if it exists, read the dumped files to get y, dgrad, wgrad, bgrad.
# If we cannot get all four tensors, leave the golden tensor map as None.
tensor_map = self._check_golden_tensor_dumps(
TENSOR_DUMP_DIR, recipe2, (batch_size, hidden_size, out_size), dtype, use_bias
)
if tensor_map is not None:
fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map
self.compare_recipe(
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
use_bias,
seed=torch.initial_seed(),
dtype=dtype,
y_error=0.5,
dgrad_error=1,
wgrad_error=1,
bgrad_error=0.5,
recipe1_golden_tensors=None,
recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFP8CurrentScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase):
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize(
"batch_size, hidden_size, out_size",
[
(16, 256, 128),
],
)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
@pytest.mark.parametrize(
"recipe1, recipe2",
[
(GetRecipes.none, GetRecipes.fp8_per_tensor_current_scaling_default),
],
)
def test_fp8_current_scaling_with_layernorm_linear_module(
self,
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
dtype,
use_bias=True,
):
fp8_zero_tolerance_tensor_dumps_recipe2 = None
# Check the tensor dump dir; if it exists, read the dumped files to get y, dgrad, wgrad, bgrad.
# If we cannot get all four tensors, leave the golden tensor map as None.
tensor_map = self._check_golden_tensor_dumps(
TENSOR_DUMP_DIR,
recipe2,
(batch_size, hidden_size, out_size),
dtype,
use_bias,
"LayerNorm",
)
if tensor_map is not None:
fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map
self.compare_recipe(
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
use_bias,
seed=torch.initial_seed(),
dtype=dtype,
y_error=0.5,
ln_out_error=0.5,
dgrad_error=1,
wgrad_error=1,
bgrad_error=0.5,
recipe1_golden_tensors=None,
recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
)
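# Note (illustrative, not part of the test file): the GetRecipes helper used above is
# defined earlier in this file and is not shown in this hunk. A minimal sketch of what a
# per-tensor current-scaling run looks like, assuming the Float8CurrentScaling recipe
# class added in this PR; the helper name and default arguments here are hypothetical.
def _sketch_current_scaling_linear():
    import transformer_engine.common.recipe as te_recipe

    # One scale per tensor, computed from the current tensor's amax
    # (no amax history, unlike DelayedScaling).
    current_scaling_recipe = te_recipe.Float8CurrentScaling()
    layer = te.Linear(256, 128, bias=True, params_dtype=torch.bfloat16).cuda()
    x = torch.randn(16, 256, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    with fp8_autocast(enabled=True, fp8_recipe=current_scaling_recipe):
        y = layer(x)
    y.backward(torch.ones_like(y))
    return y, x.grad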
@@ -12,9 +12,17 @@ import torch
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8Tensor,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch
from transformer_engine.pytorch.utils import non_tn_fp8_gemm_supported
import transformer_engine_torch as tex
from references.ref_per_tensor_cs import ref_per_tensor_cs_cast
# PyTorch tensor dtypes
_dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16]
# TE FP8 dtypes
@@ -42,6 +50,7 @@ DimsType = Union[Iterable[int], int]
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
# delayed scaling
def to_float8(
tensor: torch.Tensor,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
@@ -56,6 +65,29 @@ def to_float8(
return quantizer(tensor.cuda())
# current scaling
def to_float8_CS(
tensor: torch.Tensor,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
return_transpose: bool = False,
force_pow_2_scales: bool = False,
amax_epsilon: float = 0.0,
) -> Float8Tensor:
"""Cast tensor to FP8"""
tensor = tensor.cuda()
quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=fp8_dtype,
device=tensor.device,
force_pow_2_scales=force_pow_2_scales,
amax_epsilon=amax_epsilon,
)
if return_transpose:
quantizer.set_usage(rowwise=True, columnwise=True)
else:
quantizer.set_usage(rowwise=True, columnwise=False)
return quantizer(tensor)
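# Illustrative usage of the helper above (not an actual test). It assumes a CUDA device
# and that a per-tensor current-scaling Float8Tensor exposes the same private fields the
# tests below check (_data as raw uint8 payload, a single-element _scale_inv).
def _demo_to_float8_CS():
    x = torch.randn(128, 128, dtype=torch.bfloat16, device="cuda")
    x_fp8 = to_float8_CS(x)                    # row-wise usage only, default E4M3
    assert x_fp8._data.dtype == torch.uint8    # raw FP8 bytes
    assert x_fp8._scale_inv.numel() == 1       # single per-tensor scale
    x_back = x_fp8.dequantize()                # back to the original dtype
    assert x_back.shape == x.shape and x_back.dtype == x.dtype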
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFloat8Tensor:
@@ -310,3 +342,89 @@ class TestFloat8Tensor:
assert x.size() == y.size()
assert x.dtype == y.dtype
assert x.device == y.device
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestCurrentScalingFloat8Tensor:
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize(
"dims", [[], 1, 311, [7, 11], [7, 5, 3], [2, 3, 5, 3], [128, 128], [611, 782]]
)
@pytest.mark.parametrize("return_transpose", [True, False], ids=str)
@pytest.mark.parametrize("force_pow_2_scales", [True, False], ids=str)
@pytest.mark.parametrize("amax_epsilon", [0.0, 1e-6], ids=str)
def test_quantize(
self,
fp8_dtype: tex.DType,
dtype: torch.dtype,
dims: DimsType,
return_transpose: bool,
force_pow_2_scales: bool,
amax_epsilon: float,
) -> None:
"""Check numerical error when casting to FP8"""
# Skip invalid configurations
if non_tn_fp8_gemm_supported() and return_transpose:
pytest.skip("FP8 transpose is neither needed nor supported on current system")
# Initialize random high precision data
device = "cuda"
x_hp = 2 * torch.rand(_to_list(dims), dtype=dtype, device=device) - 1
# Cast to FP8 and back
x_fp8 = to_float8_CS(
x_hp,
fp8_dtype=fp8_dtype,
return_transpose=return_transpose,
force_pow_2_scales=force_pow_2_scales,
amax_epsilon=amax_epsilon,
)
# get reference implementation of current scaling
x_fp8_ref, sx_ref, x_fp8_t_ref, _ = ref_per_tensor_cs_cast(
x_hp,
fp8_dtype=fp8_dtype,
return_transpose=return_transpose,
force_pow_2_scales=force_pow_2_scales,
amax_epsilon=amax_epsilon,
)
torch.testing.assert_close(x_fp8._data, x_fp8_ref.view(torch.uint8), atol=0.0, rtol=0.0)
torch.testing.assert_close(x_fp8._scale_inv, sx_ref, atol=0.0, rtol=0.0)
if return_transpose:
torch.testing.assert_close(
x_fp8._transpose, x_fp8_t_ref.view(torch.uint8), atol=0.0, rtol=0.0
)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("dims", [[], 1, 311, [7, 11], [7, 5, 3], [2, 3, 5, 3]])
def test_quantize_dequantize(
self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType
) -> None:
"""Check numerical error when casting to FP8 and back"""
# Initialize random high precision data
device = "cuda"
x_hp = 2 * torch.rand(_to_list(dims), dtype=dtype, device=device) - 1
# Cast to FP8 and back
x_fp8 = to_float8_CS(x_hp, fp8_dtype=fp8_dtype)
x_fp8_dequantized = x_fp8.dequantize()
# Check results
torch.testing.assert_close(x_fp8_dequantized, x_hp, **_tols[fp8_dtype])
# Make sure we are not trivially passing the test
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8_dequantized, -x_hp, **_tols[fp8_dtype])
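# Note (illustrative): the reference implementation ref_per_tensor_cs_cast lives in
# references/ref_per_tensor_cs.py and is not shown in this hunk. Below is a minimal
# sketch of the math it is expected to perform, assuming amax_epsilon is added to the
# amax and force_pow_2_scales rounds the scale down to a power of two; zero/non-finite
# amax handling and saturation details are left to the real reference.
FP8_MAX = {"e4m3": 448.0, "e5m2": 57344.0}  # finite maxima of the two FP8 formats

def sketch_per_tensor_cs_cast(x, fp8_format="e4m3", force_pow_2_scales=False, amax_epsilon=0.0):
    amax = x.float().abs().amax() + amax_epsilon          # current-tensor amax
    scale = FP8_MAX[fp8_format] / amax                    # quantization scale
    if force_pow_2_scales:
        scale = torch.pow(2.0, torch.floor(torch.log2(scale)))
    fp8_dtype = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    x_fp8 = (x.float() * scale).clamp(-FP8_MAX[fp8_format], FP8_MAX[fp8_format]).to(fp8_dtype)
    scale_inv = 1.0 / scale                               # dequantization factor stored with the tensor
    return x_fp8, scale_inv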
@@ -100,6 +100,7 @@ mask_types = ["causal", "no_mask"]
fp8_recipes = [
recipe.MXFP8BlockScaling(),
recipe.DelayedScaling(),
recipe.Float8CurrentScaling(),
]
@@ -670,6 +671,8 @@ def test_gpt_full_activation_recompute(
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.float8_current_scaling():
pytest.skip("Float8 Current Scaling unsupported for full recompute.")
config = model_configs[model]
@@ -1482,6 +1485,8 @@ def test_grouped_linear_accuracy(
pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches
pytest.skip("MXFP8 unsupported for grouped linear.")
if fp8 and recipe.float8_current_scaling():
pytest.skip("Float8 Current Scaling unsupported for grouped linear.")
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
@@ -1675,6 +1680,8 @@ def test_padding_grouped_linear_accuracy(
pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches
pytest.skip("MXFP8 unsupported for grouped linear.")
if fp8 and recipe.float8_current_scaling():
pytest.skip("Float8 Current Scaling unsupported for grouped linear.")
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
...
@@ -23,6 +23,7 @@ import transformer_engine_torch as tex
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
# FP8 per tensor delayed scaling
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFP8Recipe:
...
@@ -86,6 +86,7 @@ list(APPEND transformer_engine_SOURCES
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_rope/fused_rope.cu
recipe/current_scaling.cu
recipe/delayed_scaling.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
...
@@ -29,6 +29,18 @@
namespace transformer_engine {
inline bool is_tensor_scaling(const NVTEScalingMode &mode) {
return mode == NVTE_DELAYED_TENSOR_SCALING;
}
inline bool is_block_scaling(const NVTEScalingMode &mode) { return !is_tensor_scaling(mode); }
inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) {
return mode == NVTE_DELAYED_TENSOR_SCALING;
}
inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; }
inline size_t product(const std::vector<size_t> &shape, const size_t begin, const size_t end) {
NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ",
end, " in a vector with ", shape.size(), " entries");
@@ -132,7 +144,7 @@ struct Tensor {
if (!has_data() && has_columnwise_data()) {
const auto &data_shape = columnwise_data.shape;
if (data_shape.empty()) return 1;
if (is_tensor_scaling(scaling_mode)) {
return product(data_shape, 1, data_shape.size());
} else {
return product(data_shape, 0, data_shape.size() - 1);
@@ -152,7 +164,7 @@ struct Tensor {
if (!has_data() && has_columnwise_data()) {
const auto &data_shape = columnwise_data.shape;
if (data_shape.empty()) return 1;
if (is_tensor_scaling(scaling_mode)) {
return data_shape.front();
} else {
return data_shape.back();
@@ -164,6 +176,16 @@
}
};
struct QuantizationConfig {
bool force_pow_2_scales = false;
float amax_epsilon = 0.0f;
static constexpr size_t attr_sizes[] = {
sizeof(bool), // force_pow_2_scales
sizeof(float) // amax_epsilon
};
};
template <typename T>
constexpr T DIVUP(const T &x, const T &y) {
return (((x) + ((y)-1)) / (y));
@@ -396,6 +418,15 @@ struct TypeInfo {
} \
}
#define TRANSFORMER_ENGINE_SWITCH_CONDITION(CONDITION, FLAG, ...) \
if (CONDITION) { \
constexpr bool FLAG = true; \
{ __VA_ARGS__ } \
} else { \
constexpr bool FLAG = false; \
{ __VA_ARGS__ } \
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int log2_ceil(int value) {
@@ -449,20 +480,6 @@ bool is_fp8_dtype(const DType t);
std::string to_string(const DType type);
std::string to_string(const NVTEScalingMode &type);
inline bool is_tensor_scaling(const NVTEScalingMode &mode) {
return mode == NVTE_DELAYED_TENSOR_SCALING;
}
inline bool is_block_scaling(const NVTEScalingMode &mode) {
return mode != NVTE_DELAYED_TENSOR_SCALING;
}
inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) {
return is_tensor_scaling(mode);
}
inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; }
/*! \brief Update a tensor's FP8 scale-inverse
*
* The FP8 scale-inverse (dequantization scaling factor) is updated
...
@@ -73,6 +73,29 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
std::vector<NVTETensor> scales, const char* amax_compute_algo, NVTEDType fp8_dtype,
float margin, cudaStream_t stream);
/*! \brief Compute an FP8 tensor's amax.
*
* The amax (maximum absolute value) of the input tensor is computed
* and written to the amax buffer of the output tensor.
*
* \param[in] input Input tensor. Must be unquantized.
* \param[in,out] output Output tensor. Must be an FP8 tensor with per-tensor scaling.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Update an FP8 tensor's scale based on its amax.
*
* This is only supported for FP8 tensors with per-tensor scaling.
* Options are primarily intended for FP8 current-scaling recipes.
*
* \param[in,out] output FP8 tensor with per-tensor scaling.
* \param[in] config Quantization configuration.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_compute_scale_from_amax(NVTETensor output, const NVTEQuantizationConfig config,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
...
@@ -68,11 +68,14 @@ enum NVTETensorParam {
};
/*! \enum NVTEScalingMode
* \brief Tensor data format.
*/
enum NVTEScalingMode {
/*! Either an unquantized tensor or an FP8 tensor with per-tensor scaling
*
* Not necessarily used for delayed tensor scaling. The unintuitive
* name reflects legacy usage.
*/
NVTE_DELAYED_TENSOR_SCALING = 0,
/*! Single scale per block of 32 elements consecutive in either
rowwise or columnwise direction */
@@ -266,6 +269,57 @@ void nvte_tensor_pack_create(NVTETensorPack *pack);
*/
void nvte_tensor_pack_destroy(NVTETensorPack *pack);
/*! \brief Configuration for tensor quantization. */
typedef void *NVTEQuantizationConfig;
/*! \enum NVTEQuantizationConfigAttribute
* \brief Type of option for tensor quantization.
*/
enum NVTEQuantizationConfigAttribute {
/*! Whether to force power of 2 scales */
kNVTEQuantizationConfigForcePow2Scales = 0,
/*! Small value to add to amax for numerical stability */
kNVTEQuantizationConfigAmaxEpsilon = 1,
kNVTEQuantizationConfigNumAttributes
};
/*! \brief Create a new quantization config.
* \return A new quantization config.
*/
NVTEQuantizationConfig nvte_create_quantization_config();
/*! \brief Query an option in quantization config.
*
* \param[in] config Quantization config.
* \param[in] attr Option type.
* \param[out] buf Memory address to write option value. Ignored if
* NULL.
* \param[in] size_in_bytes Size of buf.
* \param[out] size_written Number of bytes that have been written to
* buf. If buf is NULL, then the number of
* bytes that would have been written.
*/
void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
NVTEQuantizationConfigAttribute attr, void *buf,
size_t size_in_bytes, size_t *size_written);
/*! \brief Set an option in quantization config.
*
* \param[in] config Quantization config.
* \param[in] attr Option type.
* \param[in] buf Memory address to read option value from.
* \param[in] size_in_bytes Size of buf.
*/
void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
NVTEQuantizationConfigAttribute attr, const void *buf,
size_t size_in_bytes);
/*! \brief Destroy a quantization config.
*
* \param[in] config Config to be destroyed.
*/
void nvte_destroy_quantization_config(NVTEQuantizationConfig config);
#ifdef __cplusplus
} // extern "C"
@@ -610,6 +664,58 @@ class TensorWrapper {
NVTETensor tensor_ = nullptr;
};
/*! \struct QuantizationConfigWrapper
* \brief C++ wrapper for NVTEQuantizationConfig.
*/
class QuantizationConfigWrapper {
public:
QuantizationConfigWrapper() : config_{nvte_create_quantization_config()} {}
QuantizationConfigWrapper(const QuantizationConfigWrapper &) = delete;
QuantizationConfigWrapper &operator=(const QuantizationConfigWrapper &) = delete;
QuantizationConfigWrapper(QuantizationConfigWrapper &&other) : config_{other.config_} {
other.config_ = nullptr;
}
QuantizationConfigWrapper &operator=(QuantizationConfigWrapper &&other) {
if (config_ != nullptr) {
nvte_destroy_quantization_config(config_);
}
config_ = other.config_;
other.config_ = nullptr;
return *this;
}
~QuantizationConfigWrapper() {
if (config_ != nullptr) {
nvte_destroy_quantization_config(config_);
config_ = nullptr;
}
}
/*! \brief Get the underlying NVTEQuantizationConfig.
*
* \return NVTEQuantizationConfig held by this QuantizationConfigWrapper.
*/
operator NVTEQuantizationConfig() const noexcept { return config_; }
/*! \brief Set whether to force power of 2 scales */
void set_force_pow_2_scales(bool force_pow_2_scales) {
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigForcePow2Scales,
&force_pow_2_scales, sizeof(bool));
}
/*! \brief Set small value to add to amax */
void set_amax_epsilon(float amax_epsilon) {
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigAmaxEpsilon,
&amax_epsilon, sizeof(float));
}
private:
/*! \brief Wrapped NVTEQuantizationConfig. */
NVTEQuantizationConfig config_ = nullptr;
};
} // namespace transformer_engine
#endif // __cplusplus
...