Commit 544dd14b authored by Przemek Tredak
Browse files

Update main branch with TE 2.0 code, update version to 2.1.0.dev0


Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
parent e5369541
......@@ -15,7 +15,7 @@
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/transpose.h>
#include <transformer_engine/cast.h>
#include "../test_common.h"
using namespace transformer_engine;
......@@ -64,26 +64,23 @@ void performTest(const size_t N, const size_t H) {
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
DType ctype = TypeInfo<CType>::dtype;
Tensor input({N, H}, itype);
Tensor input("input", {N, H}, itype);
Tensor output_c({N, H}, otype);
Tensor output_t({ H, N}, otype);
Tensor output("output", {N, H}, otype, true, true);
// dbias has the same data type with "output grad"
Tensor dbias({H}, itype);
Tensor dbias("dbias", {H}, itype);
fillUniform(&input);
setRandomScale(&output_c);
output_t.shareFP8Meta(output_c);
setRandomScale(&output);
std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N*H);
std::unique_ptr<OType[]> ref_output_t = std::make_unique<OType[]>(N*H);
std::unique_ptr<IType[]> ref_output_dbias = std::make_unique<IType[]>(H);
CType ref_amax;
compute_ref_cast_transpose_dbias(input.cpu_dptr<IType>(),
output_c.scale(),
compute_ref_cast_transpose_dbias(input.rowwise_cpu_dptr<IType>(),
output.scale(),
ref_output_c.get(),
ref_output_t.get(),
&ref_amax,
......@@ -92,22 +89,20 @@ void performTest(const size_t N, const size_t H) {
Tensor workspace;
nvte_cast_transpose_dbias(input.data(),
output_c.data(),
output_t.data(),
dbias.data(),
workspace.data(),
0);
nvte_quantize_dbias(input.data(),
output.data(),
dbias.data(),
workspace.data(),
0);
workspace = Tensor(workspace.shape(), workspace.dtype());
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_cast_transpose_dbias(input.data(),
output_c.data(),
output_t.data(),
dbias.data(),
workspace.data(),
0);
nvte_quantize_dbias(input.data(),
output.data(),
dbias.data(),
workspace.data(),
0);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
......@@ -115,17 +110,17 @@ void performTest(const size_t N, const size_t H) {
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output.scale();
compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
compareResults("output_t", output_t, ref_output_t.get(), atol, rtol);
compareResults("output_c", output, ref_output_c.get(), true, atol, rtol);
compareResults("output_t", output, ref_output_t.get(), false, atol, rtol);
auto [atol_dbias, rtol_dbias] = getTolerances(itype);
rtol_dbias *= 4;
compareResults("output_dbias", dbias, ref_output_dbias.get(), atol_dbias, rtol_dbias);
compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
}
std::vector<std::pair<size_t, size_t>> test_cases = {{64, 400},
......
......@@ -75,29 +75,26 @@ void performTest(const size_t N, const size_t H) {
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
DType ctype = TypeInfo<CType>::dtype;
Tensor input({N, H}, itype);
Tensor gelu_input({N, H}, itype);
Tensor input("input", {N, H}, itype);
Tensor gelu_input("gelu_input", {N, H}, itype);
Tensor output_c({N, H}, otype);
Tensor output_t({ H, N}, otype);
Tensor output("output", {N, H}, otype, true, true);
// dbias has the same data type with "output grad"
Tensor dbias({H}, itype);
Tensor dbias("dbias", {H}, itype);
fillUniform(&input);
fillUniform(&gelu_input);
setRandomScale(&output_c);
output_t.shareFP8Meta(output_c);
setRandomScale(&output);
std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N*H);
std::unique_ptr<OType[]> ref_output_t = std::make_unique<OType[]>(N*H);
std::unique_ptr<IType[]> ref_output_dbias = std::make_unique<IType[]>(H);
CType ref_amax;
compute_ref_cast_transpose_dbias_dgelu(input.cpu_dptr<IType>(),
gelu_input.cpu_dptr<IType>(),
output_c.scale(),
compute_ref_cast_transpose_dbias_dgelu(input.rowwise_cpu_dptr<IType>(),
gelu_input.rowwise_cpu_dptr<IType>(),
output.scale(),
ref_output_c.get(),
ref_output_t.get(),
&ref_amax,
......@@ -108,19 +105,17 @@ void performTest(const size_t N, const size_t H) {
nvte_cast_transpose_dbias_dgelu(input.data(),
gelu_input.data(),
output_c.data(),
output_t.data(),
output.data(),
dbias.data(),
workspace.data(),
0);
workspace = Tensor(workspace.shape(), workspace.dtype());
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_cast_transpose_dbias_dgelu(input.data(),
gelu_input.data(),
output_c.data(),
output_t.data(),
output.data(),
dbias.data(),
workspace.data(),
0);
......@@ -131,18 +126,18 @@ void performTest(const size_t N, const size_t H) {
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output.scale();
compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
compareResults("output_t", output_t, ref_output_t.get(), atol, rtol);
compareResults("output_c", output, ref_output_c.get(), true, atol, rtol);
compareResults("output_t", output, ref_output_t.get(), false, atol, rtol);
auto [atol_dbias, rtol_dbias] = getTolerances(itype);
rtol_dbias *= 4;
compareResults("output_dbias", dbias, ref_output_dbias.get(), atol_dbias, rtol_dbias);
compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
}
std::vector<std::pair<size_t, size_t>> test_cases = {{64, 400},
......
......@@ -74,24 +74,22 @@ void performTest(const size_t N, const size_t H) {
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
Tensor grad({N, H}, itype);
Tensor input({N, H * 2}, itype);
Tensor output_c({N, H * 2}, otype);
Tensor output_t({H * 2, N}, otype);
Tensor grad("grad", {N, H}, itype);
Tensor input("input", {N, H * 2}, itype);
Tensor output("output", {N, H * 2}, otype, true, true);
fillUniform(&grad);
fillUniform(&input);
setRandomScale(&output_c);
output_t.shareFP8Meta(output_c);
setRandomScale(&output);
std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N * H * 2);
std::unique_ptr<OType[]> ref_output_t = std::make_unique<OType[]>(N * H * 2);
nvte_dgeglu_cast_transpose(grad.data(), input.data(), output_c.data(), output_t.data(), 0);
nvte_dgeglu_cast_transpose(grad.data(), input.data(), output.data(), 0);
CType ref_amax;
compute_ref_cast_transpose_dgated_gelu(grad.cpu_dptr<IType>(), input.cpu_dptr<IType>(),
output_c.scale(), ref_output_c.get(), ref_output_t.get(),
compute_ref_cast_transpose_dgated_gelu(grad.rowwise_cpu_dptr<IType>(), input.rowwise_cpu_dptr<IType>(),
output.scale(), ref_output_c.get(), ref_output_t.get(),
&ref_amax, N, H);
cudaDeviceSynchronize();
......@@ -100,14 +98,14 @@ void performTest(const size_t N, const size_t H) {
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output.scale();
compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
compareResults("output_t", output_t, ref_output_t.get(), atol, rtol);
compareResults("output_c", output, ref_output_c.get(), true, atol, rtol);
compareResults("output_t", output, ref_output_t.get(), false, atol, rtol);
}
std::vector<std::pair<size_t, size_t>> test_cases = {{64, 400}, {4096, 2048}, {768, 2816},
......
......@@ -153,11 +153,11 @@ void performTest(
DType itype = TypeInfo<Type>::dtype;
Tensor data_in({ batches, heads, rows, cols }, itype);
Tensor softmax_out({ batches, heads, rows, cols }, itype);
Tensor softmax_in({ batches, heads, rows, cols }, itype);
Tensor grads_in({ batches, heads, rows, cols }, itype);
Tensor grads_out({ batches, heads, rows, cols }, itype);
Tensor data_in("data_in", { batches, heads, rows, cols }, itype);
Tensor softmax_out("softmax_out", { batches, heads, rows, cols }, itype);
Tensor softmax_in("softmax_in", { batches, heads, rows, cols }, itype);
Tensor grads_in("grads_in", { batches, heads, rows, cols }, itype);
Tensor grads_out("grads_out", { batches, heads, rows, cols }, itype);
const size_t elements_total = batches * heads * rows * cols;
std::unique_ptr<Type[]> softmax_out_ref = std::make_unique<Type[]>(elements_total);
......@@ -175,9 +175,9 @@ void performTest(
// Reference implementations
compute_fwd_ref(softmax_out_ref.get(), data_in.cpu_dptr<Type>(),
compute_fwd_ref(softmax_out_ref.get(), data_in.rowwise_cpu_dptr<Type>(),
compute_buffer.get(), scaling_factor, batches, heads, rows, cols);
compute_bwd_ref(grads_out_ref.get(), grads_in.cpu_dptr<Type>(), softmax_in.cpu_dptr<Type>(),
compute_bwd_ref(grads_out_ref.get(), grads_in.rowwise_cpu_dptr<Type>(), softmax_in.rowwise_cpu_dptr<Type>(),
compute_buffer.get(), scaling_factor, batches, heads, rows, cols);
cudaDeviceSynchronize();
......@@ -187,8 +187,8 @@ void performTest(
if(itype == DType::kBFloat16) {
atol = 1e-3;
}
compareResults("softmax_fwd", softmax_out, softmax_out_ref.get(), atol, rtol);
compareResults("softmax_bwd", grads_out, grads_out_ref.get(), atol, rtol);
compareResults("softmax_fwd", softmax_out, softmax_out_ref.get(), true, atol, rtol);
compareResults("softmax_bwd", grads_out, grads_out_ref.get(), true, atol, rtol);
}
// [Batches, Attention Heads, Query Sequence Length, Key Sequence Length, Scaling Factor]
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <limits>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/activation.h>
#include "../test_common.h"
#include "transformer_engine/transformer_engine.h"
using namespace transformer_engine;
using namespace test;
namespace {
// Dequantize one scaling block: multiply every element in the
// [i_min, i_max) x [j_min, j_max) window of a row-major rows x cols buffer
// by the power-of-two scale encoded at scales[scale_idx].
template <typename InputType, typename OutputType>
void dequantize_block(const InputType* input,
                      OutputType* output,
                      fp8e8m0* scales,
                      const size_t scale_idx,
                      const size_t i_min,
                      const size_t i_max,
                      const size_t j_min,
                      const size_t j_max,
                      const size_t cols)
{
  // Decode the E8M0 biased exponent into a float scale factor (2^(e - bias)).
  const float elem_scale =
      exp2f(static_cast<float>(scales[scale_idx]) - FP32_EXPONENT_BIAS);

  // Scale every element covered by this block.
  for (size_t row = i_min; row < i_max; ++row) {
    const size_t row_base = row * cols;
    for (size_t col = j_min; col < j_max; ++col) {
      const float value = static_cast<float>(input[row_base + col]);
      output[row_base + col] = static_cast<OutputType>(value * elem_scale);
    }
  }
}
// CPU reference: dequantize the whole rows x cols tensor using one scale per
// block_size_Y x block_size_X block. `scales` is laid out as a 2D grid with
// row stride `scales_stride` (padded layout).
template <typename InputType, typename OutputType>
void compute_ref_x1(const InputType* input,
                    OutputType* output,
                    fp8e8m0* scales,
                    const size_t rows,
                    const size_t cols,
                    const size_t block_size_Y,
                    const size_t block_size_X,
                    const size_t scales_stride)
{
  // Ceil-divide so partial edge blocks are still processed.
  const size_t num_block_rows = (rows + block_size_Y - 1) / block_size_Y;
  const size_t num_block_cols = (cols + block_size_X - 1) / block_size_X;

  for (size_t block_row = 0; block_row < num_block_rows; ++block_row) {
    // Clamp the window so the last block never runs past the tensor edge.
    const size_t row_begin = block_row * block_size_Y;
    const size_t row_end = std::min(row_begin + block_size_Y, rows);
    for (size_t block_col = 0; block_col < num_block_cols; ++block_col) {
      const size_t col_begin = block_col * block_size_X;
      const size_t col_end = std::min(col_begin + block_size_X, cols);
      dequantize_block<InputType, OutputType>(
          input, output, scales, block_row * scales_stride + block_col,
          row_begin, row_end, col_begin, col_end, cols);
    }
  }
}
// CPU reference for dual-direction dequantization: runs the 1D reference once
// per direction, writing two independent outputs from the same input.
template <typename InputType, typename OutputType>
void compute_ref_x2(const InputType* input,
                    OutputType* output_rowwise,
                    OutputType* output_colwise,
                    fp8e8m0* scales_rowwise,
                    fp8e8m0* scales_colwise,
                    const size_t rows,
                    const size_t cols,
                    const size_t block_size_Y,
                    const size_t block_size_X,
                    const size_t scales_stride_rowwise,
                    const size_t scales_stride_colwise)
{
  // Rowwise pass: blocks of shape 1 x block_size_X.
  compute_ref_x1<InputType, OutputType>(
      input, output_rowwise, scales_rowwise,
      rows, cols, 1, block_size_X, scales_stride_rowwise);
  // Colwise pass: blocks of shape block_size_Y x 1.
  compute_ref_x1<InputType, OutputType>(
      input, output_colwise, scales_colwise,
      rows, cols, block_size_Y, 1, scales_stride_colwise);
}
// Draw one random biased exponent per block and mirror it into both the
// reference buffer and the tensor-side buffer so GPU and CPU see identical
// scales.
// NOTE(review): std::uniform_int_distribution is formally undefined for
// char-sized integer types; if fp8e8m0 aliases uint8_t, drawing ints and
// narrowing would be safer — TODO confirm the alias.
void generate_scales(fp8e8m0 * const scales_ref,
                     fp8e8m0 * const scales,
                     const size_t blocks_num,
                     std::mt19937& gen,
                     std::uniform_int_distribution<fp8e8m0> dis)
{
  for (size_t block = 0; block < blocks_num; ++block) {
    const fp8e8m0 sampled = dis(gen);
    scales_ref[block] = sampled;
    scales[block] = sampled;
  }
}
// Fill a rows x cols row-major buffer with random magnitudes drawn from `dis`,
// flipping the sign of roughly half the samples via `dis_sign`.
template<typename InputType>
void generate_data(InputType * const data,
                   const size_t rows,
                   const size_t cols,
                   std::mt19937& gen,
                   std::uniform_real_distribution<>& dis,
                   std::uniform_real_distribution<>& dis_sign)
{
  const size_t total = rows * cols;
  for (size_t idx = 0; idx < total; ++idx) {
    // Draw the sign first, then the magnitude — this matches the historical
    // RNG consumption order, keeping generated values reproducible.
    const double sign = (dis_sign(gen) < 0.0) ? -1.0 : 1.0;
    data[idx] = static_cast<InputType>(sign * dis(gen));
  }
}
// Populate a Tensor's host-side payload and scale-inverse buffers with random
// data for the requested scaling direction(s), mirror the scales into the
// caller's reference buffers, and push everything to the device.
template<typename InputType>
void fill_tensor_data(Tensor& input,
                      fp8e8m0 * const scales_rowwise,
                      fp8e8m0 * const scales_colwise,
                      const bool is_rowwise_scaling,
                      const bool is_colwise_scaling,
                      const size_t rows,
                      const size_t cols,
                      const size_t blocks_num_rowwise,
                      const size_t blocks_num_colwise)
{
  // Sample magnitudes spanning the full normal range of InputType.
  const double lo = Numeric_Traits<InputType>::minNorm;
  const double hi = Numeric_Traits<InputType>::maxNorm;

  // Static generator: successive calls produce fresh data, while the fixed
  // seed keeps the whole test run reproducible.
  static std::mt19937 gen(12345);
  std::uniform_real_distribution<> dis(lo, hi);
  std::uniform_real_distribution<> dis_sign(-1.0, 1.0);
  std::uniform_int_distribution<fp8e8m0> int_dis(0, 255);

  // Rowwise before colwise — the RNG draw order is significant.
  if (is_rowwise_scaling) {
    generate_scales(scales_rowwise, input.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
                    blocks_num_rowwise, gen, int_dis);
    generate_data(input.rowwise_cpu_dptr<InputType>(), rows, cols, gen, dis, dis_sign);
  }
  if (is_colwise_scaling) {
    generate_scales(scales_colwise, input.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
                    blocks_num_colwise, gen, int_dis);
    generate_data(input.columnwise_cpu_dptr<InputType>(), rows, cols, gen, dis, dis_sign);
  }

  // Upload the freshly generated host buffers to the device.
  input.from_cpu();
}
// Dequantize along a single dimension (either row- or columnwise) and
// validate the GPU kernel result against the CPU reference.
// Fix: removed the unused `using EncodingType = fp32;` alias — data here
// comes from fill_tensor_data, not fillCase, so the alias was dead.
template <typename InputType, typename OutputType>
void performTest_x1(const size_t rows,
                    const size_t cols,
                    const bool rowwise,
                    const bool colwise)
{
  using namespace test;

  DType itype = TypeInfo<InputType>::dtype;
  DType otype = TypeInfo<OutputType>::dtype;

  // MXFP8 1D scaling: 1x32-element blocks for rowwise, 32x1 for colwise.
  const size_t block_size_rows = rowwise ? 1 : 32;
  const size_t block_size_cols = colwise ? 1 : 32;

  // Scale-grid dimensions before alignment padding.
  const size_t unpadded_blocks_Y_rowwise = rows;
  const size_t unpadded_blocks_X_rowwise = divide_round_up(cols, block_size_cols);
  const size_t unpadded_blocks_Y_colwise = divide_round_up(rows, block_size_rows);
  const size_t unpadded_blocks_X_colwise = cols;

  // The scaling-factor tensor is padded up to the required alignment.
  const size_t blocks_Y_rowwise = round_up_to_nearest_multiple(unpadded_blocks_Y_rowwise,
                                                               scale_tensor_alignment_Y_rowwise);
  const size_t blocks_X_rowwise = round_up_to_nearest_multiple(unpadded_blocks_X_rowwise,
                                                               scale_tensor_alignment_X_rowwise);
  const size_t blocks_Y_colwise = round_up_to_nearest_multiple(unpadded_blocks_Y_colwise,
                                                               scale_tensor_alignment_Y_colwise);
  const size_t blocks_X_colwise = round_up_to_nearest_multiple(unpadded_blocks_X_colwise,
                                                               scale_tensor_alignment_X_colwise);

  const size_t blocks_num_rowwise = blocks_Y_rowwise * blocks_X_rowwise;
  const size_t blocks_num_colwise = blocks_Y_colwise * blocks_X_colwise;

  // Only one direction is active per call; pick its block count and stride.
  const size_t blocks_num = rowwise ? blocks_num_rowwise : blocks_num_colwise;
  const size_t scales_stride = rowwise ? blocks_X_rowwise : blocks_X_colwise;

  Tensor input("input", { rows, cols }, itype, rowwise, colwise, NVTE_MXFP8_1D_SCALING);
  // Output data are written to the rowwise ptr regardless of the scaling direction
  Tensor output("output", { rows, cols }, otype, true, false);

  std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(rows * cols);
  std::unique_ptr<fp8e8m0[]> scales = std::make_unique<fp8e8m0[]>(blocks_num);

  // The same host scale buffer is passed for both directions; only the active
  // direction's branch inside fill_tensor_data touches it.
  fill_tensor_data<InputType>(input, scales.get(), scales.get(), rowwise, colwise, rows, cols,
                              blocks_num_rowwise, blocks_num_colwise);

  nvte_dequantize(input.data(), output.data(), 0);
  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  // The reference reads the same direction's host data the kernel consumed.
  InputType * data_ptr = rowwise
                         ? input.rowwise_cpu_dptr<InputType>()
                         : input.columnwise_cpu_dptr<InputType>();
  compute_ref_x1<InputType, OutputType>(data_ptr,
                                        ref_output.get(),
                                        scales.get(),
                                        rows,
                                        cols,
                                        block_size_rows,
                                        block_size_cols,
                                        scales_stride);

  auto [atol, rtol] = getTolerances(otype);
  compareResults("output", output, ref_output.get(), true, atol, rtol);
}
// Dequantize along single dimension (either row- or columnwise)
template <typename InputType, typename IntermediateType>
void performTest_quantize_then_dequantize(const size_t rows,
const size_t cols,
const bool rowwise,
const bool colwise)
{
using namespace test;
using EncodingType = fp32;
DType in_type = TypeInfo<InputType>::dtype;
DType intermed_type = TypeInfo<IntermediateType>::dtype;
DType out_type = TypeInfo<InputType>::dtype;
std::unique_ptr<InputType[]> input_cpu = std::make_unique<InputType[]>(rows * cols);
std::unique_ptr<IntermediateType[]> quantized_cpu = std::make_unique<IntermediateType[]>(rows * cols);
std::unique_ptr<InputType[]> output_cpu = std::make_unique<InputType[]>(rows * cols);
// input --> quantized --> output (dequantized)
// input == output
Tensor input("input", { rows, cols }, in_type);
Tensor quantized("quantized", { rows, cols }, intermed_type, rowwise, colwise, NVTE_MXFP8_1D_SCALING);
// Output data are written to the rowwise ptr regardless of the scaling direction
Tensor output("output", { rows, cols }, out_type, true, false);
// fillCase<EncodingType>(&input, InputsFillCase::minNorm_to_maxNorm);
fillCase<EncodingType>(&input, InputsFillCase::uniform);
const size_t copy_size = sizeof(InputType) * rows * cols;
cudaMemcpy(input_cpu.get(), input.rowwise_dptr(), copy_size, cudaMemcpyDeviceToHost);
nvte_quantize(input.data(), quantized.data(), 0);
cudaDeviceSynchronize();
const size_t copy_size_quantized = sizeof(IntermediateType) * rows * cols;
if (rowwise) {
cudaMemcpy(quantized_cpu.get(), quantized.rowwise_dptr(), copy_size_quantized, cudaMemcpyDeviceToHost);
}
if (colwise) {
cudaMemcpy(quantized_cpu.get(), quantized.columnwise_dptr(), copy_size_quantized, cudaMemcpyDeviceToHost);
}
nvte_dequantize(quantized.data(), output.data(), 0);
cudaDeviceSynchronize();
cudaMemcpy(output_cpu.get(), output.rowwise_dptr(), copy_size, cudaMemcpyDeviceToHost);
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol, rtol] = getTolerances(intermed_type);
compareResults("Quantize-Dequantize", input, output_cpu.get(), true, atol, rtol);
}
// Dequantize along both dimensions (row- and columnwise) and compare each
// direction against its CPU reference.
// Fix: removed the unused `using EncodingType = fp32;` alias (data comes
// from fill_tensor_data, not fillCase).
template <typename InputType, typename OutputType>
void performTest_x2(const size_t rows,
                    const size_t cols,
                    const size_t block_size_rows,
                    const size_t block_size_cols)
{
  using namespace test;

  DType itype = TypeInfo<InputType>::dtype;
  DType otype = TypeInfo<OutputType>::dtype;

  // Scale-grid dimensions before alignment padding, per direction.
  const size_t unpadded_blocks_Y_rowwise = rows;
  const size_t unpadded_blocks_X_rowwise = divide_round_up(cols, block_size_cols);
  const size_t unpadded_blocks_Y_colwise = divide_round_up(rows, block_size_rows);
  const size_t unpadded_blocks_X_colwise = cols;

  // The scaling-factor tensors are padded up to the required alignment.
  const size_t blocks_Y_rowwise = round_up_to_nearest_multiple(unpadded_blocks_Y_rowwise,
                                                               scale_tensor_alignment_Y_rowwise);
  const size_t blocks_X_rowwise = round_up_to_nearest_multiple(unpadded_blocks_X_rowwise,
                                                               scale_tensor_alignment_X_rowwise);
  const size_t blocks_Y_colwise = round_up_to_nearest_multiple(unpadded_blocks_Y_colwise,
                                                               scale_tensor_alignment_Y_colwise);
  const size_t blocks_X_colwise = round_up_to_nearest_multiple(unpadded_blocks_X_colwise,
                                                               scale_tensor_alignment_X_colwise);

  const size_t scales_stride_rowwise = blocks_X_rowwise;
  const size_t scales_stride_colwise = blocks_X_colwise;
  const size_t blocks_num_rowwise = blocks_Y_rowwise * blocks_X_rowwise;
  const size_t blocks_num_colwise = blocks_Y_colwise * blocks_X_colwise;

  Tensor input("input", { rows, cols }, itype, true, true, NVTE_MXFP8_1D_SCALING);
  Tensor output("output", { rows, cols }, otype);

  std::unique_ptr<OutputType[]> ref_output_rowwise = std::make_unique<OutputType[]>(rows * cols);
  std::unique_ptr<OutputType[]> ref_output_colwise = std::make_unique<OutputType[]>(rows * cols);
  std::unique_ptr<fp8e8m0[]> ref_scales_rowwise = std::make_unique<fp8e8m0[]>(blocks_num_rowwise);
  std::unique_ptr<fp8e8m0[]> ref_scales_colwise = std::make_unique<fp8e8m0[]>(blocks_num_colwise);

  // Both directions are populated for this test.
  constexpr bool rowwise = true;
  constexpr bool colwise = true;
  fill_tensor_data<InputType>(input, ref_scales_rowwise.get(), ref_scales_colwise.get(),
                              rowwise, colwise, rows, cols, blocks_num_rowwise, blocks_num_colwise);

  nvte_dequantize(input.data(), output.data(), 0);
  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  // CPU reference computes both directions in one pass.
  compute_ref_x2<InputType, OutputType>(input.rowwise_cpu_dptr<InputType>(),
                                        ref_output_rowwise.get(),
                                        ref_output_colwise.get(),
                                        ref_scales_rowwise.get(),
                                        ref_scales_colwise.get(),
                                        rows,
                                        cols,
                                        block_size_rows,
                                        block_size_cols,
                                        scales_stride_rowwise,
                                        scales_stride_colwise);

  auto [atol, rtol] = getTolerances(otype);
  compareResults("output_rowwise", output, ref_output_rowwise.get(), true, atol, rtol);
  compareResults("output_colwise", output, ref_output_colwise.get(), false, atol, rtol);
}
// Tensor shapes exercised by the tests (rows, cols). Includes sizes that are
// not multiples of the 32-element scaling block (e.g. 65x96, 993x512) so the
// partial-block edge handling is covered. Larger shapes are kept disabled to
// keep the suite fast.
std::vector<std::pair<size_t, size_t>> tensor_dims = {
{1, 16},
{16, 48},
{65, 96},
{128, 128},
{256, 256},
{993, 512},
{768, 1024},
// {2048, 12288},
// {65536, 128},
// {16384, 1632},
// {16384, 6144},
};
// Scaling-block shapes: {1, 32} selects rowwise scaling, {32, 1} colwise.
// The 2D {32, 32} case is currently disabled.
std::vector<std::pair<size_t, size_t>> block_sizes = {
{1, 32},
{32, 1},
// {32, 32},
};
} // namespace
// Parameterized suite. Tuple fields: tensor dims (rows, cols), scaling-block
// dims, input dtype, output dtype, and whether to run the
// quantize-then-dequantize round trip instead of dequantize-only.
class DequantizeMXFP8TestSuite : public ::testing::TestWithParam
<std::tuple<std::pair<size_t, size_t>,
std::pair<size_t, size_t>,
transformer_engine::DType,
transformer_engine::DType,
bool>> {};
// Dispatch a single parameter combination to the matching performTest_*
// helper, skipping configurations the hardware or kernels cannot run.
TEST_P(DequantizeMXFP8TestSuite, TestDequantizeMXFP8)
{
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
using namespace transformer_engine;
using namespace test;
const auto tensor_size = std::get<0>(GetParam());
const auto block_size = std::get<1>(GetParam());
const DType input_type = std::get<2>(GetParam());
const DType output_type = std::get<3>(GetParam());
const bool quantize_then_dequantize = std::get<4>(GetParam());
// A block whose inner (second) dim is not 1 scales along rows; a block whose
// outer (first) dim is not 1 scales along columns.
const bool rowwise = block_size.second != 1;
const bool colwise = block_size.first != 1;
// Skip tests for dequantization along both dimensions
if (rowwise && colwise) {
GTEST_SKIP();
}
// Skip cases with invalid alignment
// (the scaled dimension must be a whole number of 32-element blocks).
if (rowwise && tensor_size.second % 32 != 0) {
GTEST_SKIP();
}
if (colwise && tensor_size.first % 32 != 0) {
GTEST_SKIP();
}
// Input is FP8 for dequantize; output is a wide type (FP32/BF16/FP16).
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType,
if (quantize_then_dequantize) {
// Mind the order of the Output/Input template parameters
performTest_quantize_then_dequantize<OutputType, InputType>(
tensor_size.first, tensor_size.second, rowwise, colwise);
} else {
if (block_size.first == 1 || block_size.second == 1) {
performTest_x1<InputType, OutputType>(tensor_size.first, tensor_size.second,
rowwise, colwise);
} else {
performTest_x2<InputType, OutputType>(tensor_size.first, tensor_size.second,
block_size.first, block_size.second);
}
}
);
);
}
// Instantiate the cross product of shapes, block layouts and dtypes.
// Note: only `false` is passed for the round-trip flag, so the
// quantize-then-dequantize ("QD") variant is currently not generated.
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
DequantizeMXFP8TestSuite,
::testing::Combine(
::testing::ValuesIn(tensor_dims),
::testing::ValuesIn(block_sizes),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(false)),
// Human-readable test name: rowsXcolsXblockYXblockXXitypeXotypeXmode.
[](const testing::TestParamInfo<DequantizeMXFP8TestSuite::ParamType>& info)
{
std::string name = std::to_string(std::get<0>(info.param).first) + "X" +
std::to_string(std::get<0>(info.param).second) + "X" +
std::to_string(std::get<1>(info.param).first) + "X" +
std::to_string(std::get<1>(info.param).second) + "X" +
test::typeName(std::get<2>(info.param)) + "X" +
test::typeName(std::get<3>(info.param)) + "X" +
(std::get<4>(info.param) ? "QD" : "D");
return name;
}
);
......@@ -69,7 +69,7 @@ void performTest() {
const size_t num_tensors = tensor_dims.size();
// Buffers for Transformer Engine implementation
std::vector<Tensor> input_list, output_c_list, output_t_list;
std::vector<Tensor> input_list, output_list;
// Buffers for reference implementation
std::vector<std::vector<InputType>> ref_input_list;
......@@ -81,25 +81,23 @@ void performTest() {
for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
const size_t height = tensor_dims[tensor_id].first;
const size_t width = tensor_dims[tensor_id].second;
input_list.emplace_back(Tensor({ height, width }, itype));
output_c_list.emplace_back(Tensor({ height, width }, otype));
output_t_list.emplace_back(Tensor({ width, height }, otype));
input_list.emplace_back(Tensor("input_" + std::to_string(tensor_id), { height, width }, itype));
output_list.emplace_back(Tensor("output_" + std::to_string(tensor_id),
{ height, width }, otype, true, true));
auto& input = input_list.back();
auto& output_c = output_c_list.back();
auto& output_t = output_t_list.back();
auto& output = output_list.back();
fillUniform(&input);
setRandomScale(&output_c);
output_t.shareFP8Meta(output_c);
setRandomScale(&output);
ref_input_list.emplace_back(height*width);
ref_output_c_list.emplace_back(height*width);
ref_output_t_list.emplace_back(width*height);
std::copy(input.cpu_dptr<InputType>(),
input.cpu_dptr<InputType>() + height * width,
std::copy(input.rowwise_cpu_dptr<InputType>(),
input.rowwise_cpu_dptr<InputType>() + height * width,
ref_input_list.back().begin());
ref_scale_list[tensor_id] = output_c.scale();
ref_scale_list[tensor_id] = output.scale();
ref_height_list[tensor_id] = height;
ref_width_list[tensor_id] = width;
}
......@@ -115,8 +113,7 @@ void performTest() {
};
nvte_multi_cast_transpose(num_tensors,
make_nvte_vector(input_list).data(),
make_nvte_vector(output_c_list).data(),
make_nvte_vector(output_t_list).data(),
make_nvte_vector(output_list).data(),
0);
// Reference implementation
......@@ -136,23 +133,23 @@ void performTest() {
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax",
output_c_list[tensor_id].amax(),
output_list[tensor_id].amax(),
ref_amax_list[tensor_id],
atol_amax, rtol_amax);
compareResults("scale_inv",
output_c_list[tensor_id].scale_inv(),
1.f / output_c_list[tensor_id].scale(),
output_list[tensor_id].rowwise_scale_inv(),
1.f / output_list[tensor_id].scale(),
atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c",
output_c_list[tensor_id],
output_list[tensor_id],
ref_output_c_list[tensor_id].data(),
atol, rtol);
true, atol, rtol);
compareResults("output_t",
output_t_list[tensor_id],
output_list[tensor_id],
ref_output_t_list[tensor_id].data(),
atol, rtol);
false, atol, rtol);
}
}
......
......@@ -9,6 +9,7 @@
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <vector>
#include <cstdio>
......@@ -84,8 +85,8 @@ void performTest() {
const size_t height = tensor_dims[tensor_id].first;
const size_t width = tensor_dims[tensor_id].second;
const size_t padded_height = (height + align - 1) / align * align;
input_list.emplace_back(Tensor({ height, width }, itype));
output_list.emplace_back(Tensor({ padded_height, width }, otype));
input_list.emplace_back(Tensor("input_" + std::to_string(tensor_id), { height, width }, itype));
output_list.emplace_back(Tensor("output_" + std::to_string(tensor_id), { padded_height, width }, otype));
auto& input = input_list.back();
auto& output = output_list.back();
......@@ -95,8 +96,8 @@ void performTest() {
ref_input_list.emplace_back(height*width);
ref_output_list.emplace_back(padded_height*width);
std::copy(input.cpu_dptr<InputType>(),
input.cpu_dptr<InputType>() + height * width,
std::copy(input.rowwise_cpu_dptr<InputType>(),
input.rowwise_cpu_dptr<InputType>() + height * width,
ref_input_list.back().begin());
ref_height_list[tensor_id] = height;
ref_width_list[tensor_id] = width;
......@@ -134,6 +135,7 @@ void performTest() {
compareResults("output",
output_list[tensor_id],
ref_output_list[tensor_id].data(),
true,
atol, rtol);
}
}
......
......@@ -10,7 +10,6 @@
#include <iomanip>
#include <iostream>
#include <random>
#include <stdlib.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
......@@ -176,6 +175,11 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
GTEST_SKIP() << "LN kernel does not support OutputType > InputType";
return;
}
if (getDeviceComputeCapability() < blackwellComputeCapability && use_cudnn) {
GTEST_SKIP() << "cuDNN normalizations not supported on pre-Blackwell GPUs yet!";
}
using WeightType = InputType;
DType itype = TypeInfo<InputType>::dtype;
DType wtype = TypeInfo<WeightType>::dtype;
......@@ -187,16 +191,16 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
return;
}
Tensor input({ N, H }, itype);
Tensor z({ N, H }, otype);
Tensor gamma({ H }, wtype);
Tensor beta({ H }, wtype);
Tensor mu({ N }, DType::kFloat32);
Tensor rsigma({ N }, DType::kFloat32);
Tensor dz({ N, H }, wtype);
Tensor dx({ N, H }, itype);
Tensor dgamma({ H }, wtype);
Tensor dbeta({ H }, wtype);
Tensor input("input", { N, H }, itype);
Tensor z("z", { N, H }, otype);
Tensor gamma("gamma", { H }, wtype);
Tensor beta("beta", { H }, wtype);
Tensor mu("mu", { N }, DType::kFloat32);
Tensor rsigma("rsigma", { N }, DType::kFloat32);
Tensor dz("dz", { N, H }, wtype);
Tensor dx("dx", { N, H }, itype);
Tensor dgamma("dgamma", { H }, wtype);
Tensor dbeta("dbeta", { H }, wtype);
Tensor workspace_fwd, workspace_bwd;
fillUniform(&input);
......@@ -226,7 +230,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon,
z.data(), mu.data(), rsigma.data(), workspace_fwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
workspace_fwd = Tensor(workspace_fwd.shape(), workspace_fwd.dtype());
workspace_fwd = Tensor("workspace", workspace_fwd.rowwise_shape(), workspace_fwd.dtype());
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon,
z.data(), mu.data(), rsigma.data(), workspace_fwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
......@@ -236,7 +240,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
dx.data(), dgamma.data(), dbeta.data(),
workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
workspace_bwd = Tensor(workspace_bwd.shape(), workspace_bwd.dtype());
workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype());
nvte_layernorm_bwd(dz.data(), input.data(),
mu.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), dbeta.data(),
......@@ -246,7 +250,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon,
z.data(), rsigma.data(), workspace_fwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
workspace_fwd = Tensor(workspace_fwd.shape(), workspace_fwd.dtype());
workspace_fwd = Tensor("workspace", workspace_fwd.rowwise_shape(), workspace_fwd.dtype());
nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon,
z.data(), rsigma.data(), workspace_fwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
......@@ -255,7 +259,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
dx.data(), dgamma.data(),
workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
workspace_bwd = Tensor(workspace_bwd.shape(), workspace_bwd.dtype());
workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype());
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(),
workspace_bwd.data(),
......@@ -272,23 +276,24 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
mu.to_cpu();
rsigma.to_cpu();
float ref_amax;
compute_ref_stats(norm_type, input.cpu_dptr<InputType>(), ref_mu.get(),
compute_ref_stats(norm_type, input.rowwise_cpu_dptr<InputType>(), ref_mu.get(),
ref_rsigma.get(), N, H, epsilon);
float ref_scale = isFp8Type(otype) ? z.scale() : 1.f;
compute_ref_output(norm_type, input.cpu_dptr<InputType>(),
gamma.cpu_dptr<WeightType>(),
beta.cpu_dptr<WeightType>(),
compute_ref_output(norm_type, input.rowwise_cpu_dptr<InputType>(),
gamma.rowwise_cpu_dptr<WeightType>(),
beta.rowwise_cpu_dptr<WeightType>(),
ref_output.get(),
mu.cpu_dptr<float>(),
rsigma.cpu_dptr<float>(),
mu.rowwise_cpu_dptr<float>(),
rsigma.rowwise_cpu_dptr<float>(),
N, H,
&ref_amax,
ref_scale,
zero_centered_gamma,
use_cudnn);
compute_ref_backward(norm_type, dz.cpu_dptr<WeightType>(), input.cpu_dptr<InputType>(),
mu.cpu_dptr<float>(), rsigma.cpu_dptr<float>(),
gamma.cpu_dptr<WeightType>(),
compute_ref_backward(norm_type, dz.rowwise_cpu_dptr<WeightType>(),
input.rowwise_cpu_dptr<InputType>(),
mu.rowwise_cpu_dptr<float>(), rsigma.rowwise_cpu_dptr<float>(),
gamma.rowwise_cpu_dptr<WeightType>(),
ref_dx.get(), ref_dgamma.get(), ref_dbeta.get(),
N, H, zero_centered_gamma,
use_cudnn);
......@@ -301,25 +306,25 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
if (isFp8Type(otype)) {
compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / z.scale();
compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
compareResults("scale_inv", z.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
rtol_stats = 5e-5;
compareResults("mu", mu, ref_mu.get(), atol_stats, rtol_stats);
compareResults("rsigma", rsigma, ref_rsigma.get(), atol_stats, rtol_stats);
compareResults("mu", mu, ref_mu.get(), true, atol_stats, rtol_stats);
compareResults("rsigma", rsigma, ref_rsigma.get(), true, atol_stats, rtol_stats);
auto [atol, rtol] = getTolerances(otype);
if (otype == DType::kFloat32) {
atol = 5e-7;
}
compareResults("output", z, ref_output.get(), atol, rtol);
compareResults("output", z, ref_output.get(), true, atol, rtol);
double atol_bwd = 5e-4;
double rtol_bwd = 5e-4;
compareResults("dx", dx, ref_dx.get(), atol_bwd, rtol_bwd);
compareResults("dgamma", dgamma, ref_dgamma.get(), atol_bwd, rtol_bwd);
compareResults("dbeta", dbeta, ref_dbeta.get(), atol_bwd, rtol_bwd);
compareResults("dx", dx, ref_dx.get(), true, atol_bwd, rtol_bwd);
compareResults("dgamma", dgamma, ref_dgamma.get(), true, atol_bwd, rtol_bwd);
compareResults("dbeta", dbeta, ref_dbeta.get(), true, atol_bwd, rtol_bwd);
}
std::vector<std::pair<size_t, size_t>> test_cases = {
......@@ -357,24 +362,24 @@ TEST_P(NormTestSuite, TestNorm) {
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
NormTestSuite,
::testing::Combine(
::testing::Values(false), //TODO: enabling tests for cudnn backend
::testing::Values(NormType::LayerNorm, NormType::RMSNorm),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3),
::testing::ValuesIn(test_cases),
::testing::Values(false, true)),
[](const testing::TestParamInfo<NormTestSuite::ParamType>& info) {
OperatorTest,
NormTestSuite,
::testing::Combine(
::testing::Values(true, false),
::testing::Values(NormType::LayerNorm, NormType::RMSNorm),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3),
::testing::ValuesIn(test_cases),
::testing::Values(false, true)),
[](const testing::TestParamInfo<NormTestSuite::ParamType>& info) {
auto backend = std::get<0>(info.param) == false ? "Te" : "Cudnn";
std::string name =
backend +
normToString.at(std::get<1>(info.param)) + "_" +
test::typeName(std::get<2>(info.param)) + "X" +
test::typeName(std::get<3>(info.param)) + "X" +
std::to_string(std::get<4>(info.param).first) + "X" +
std::to_string(std::get<4>(info.param).second) + "X" +
std::to_string(std::get<5>(info.param));
return name;
});
std::string name =
backend +
normToString.at(std::get<1>(info.param)) + "_" +
test::typeName(std::get<2>(info.param)) + "X" +
test::typeName(std::get<3>(info.param)) + "X" +
std::to_string(std::get<4>(info.param).first) + "X" +
std::to_string(std::get<4>(info.param).second) + "X" +
std::to_string(std::get<5>(info.param));
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <memory>
#include <map>
#include <iomanip>
#include <iostream>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"
using namespace transformer_engine;
using namespace test;
namespace {
// MXFP8 scale factors are stored as E8M0 bytes (a biased power-of-two
// exponent, decoded in dequantize_1x_kernel below).
using fp8e8m0 = byte;

// Which normalization kernel a test exercises.
enum NormType {
  LayerNorm,
  RMSNorm
};

// Human-readable norm names, used when generating gtest case names.
std::map<NormType, std::string> normToString = {
  {NormType::LayerNorm, "LayerNorm"},
  {NormType::RMSNorm, "RMSNorm"}
};
// CPU reference dequantization of one representation (rowwise or columnwise)
// of a block-scaled tensor.  Every (scaling_mode_x x scaling_mode_y) block of
// elements shares one scale factor, decoded below as 2^(stored - bias).
// The tile loops exist only to give OpenMP evenly sized parallel work items;
// they do not change which scale applies to which element.
//
// input_ptr:      quantized elements, row-major, rows x cols
// scale_ptr:      one biased exponent per scaling block, row-major over blocks
// output_ptr:     dequantized elements, row-major, rows x cols
// scaling_mode_x: block extent along the matrix row dimension (Y)
// scaling_mode_y: block extent along the matrix column dimension (X)
template <typename InputType, typename ScaleType, typename OutputType>
void dequantize_1x_kernel(InputType* input_ptr, ScaleType* scale_ptr, OutputType* output_ptr,
                          size_t rows, size_t cols, size_t scaling_mode_x, size_t scaling_mode_y){
  const size_t block_size_Y = scaling_mode_x;  // mind the mapping Y <-- x
  const size_t block_size_X = scaling_mode_y;  // and X <-- y
  // Tiles group scaling blocks; each tile spans at least 32x64 elements.
  const size_t tile_size_Y = std::max(32lu, block_size_Y);
  const size_t tile_size_X = std::max(64lu, block_size_X);
  const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y;
  const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X;
  const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y;
  const size_t blocks_per_tile_X = tile_size_X / block_size_X;
  const size_t blocks_per_row = (cols + block_size_X - 1) / block_size_X;

  // One loop iteration per tile, flattened for even OpenMP scheduling.
  #pragma omp parallel for proc_bind(spread) schedule(static)
  for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) {
    const size_t tile_Y = t / tiles_num_X;
    const size_t tile_X = t % tiles_num_X;
    const size_t tile_offset_Y = tile_Y * tile_size_Y;
    const size_t tile_offset_X = tile_X * tile_size_X;
    for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) {
      const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii;
      const size_t block_offset_Y = ii * block_size_Y;
      const size_t i_min = tile_offset_Y + block_offset_Y;
      // Clamp partial blocks at the bottom/right matrix edges.
      const size_t i_max = std::min(i_min + block_size_Y, rows);
      for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) {
        const size_t block_idx_X = tile_X * blocks_per_tile_X + jj;
        const size_t block_offset_X = jj * block_size_X;
        const size_t j_min = tile_offset_X + block_offset_X;
        const size_t j_max = std::min(j_min + block_size_X, cols);
        const size_t mx_scale_idx = block_idx_Y * blocks_per_row + block_idx_X;
        // TODO: padded SFs i.e. (4,128)
        // Decode the biased exponent into a power-of-two scale.
        const float scale_inv = exp2f(static_cast<float>(scale_ptr[mx_scale_idx]) - FP32_EXPONENT_BIAS);
        // Apply the block's scale to every element it covers.
        for (size_t i = i_min; i < i_max; ++i) {
          for (size_t j = j_min; j < j_max; ++j) {
            const size_t idx = i * cols + j;
            const float elem = static_cast<float>(input_ptr[idx]);
            output_ptr[idx] = static_cast<float>(elem * scale_inv);
          }
        }
      }
    }
  }
}
// Dequantizes a block-scaled tensor to fp32 on the CPU.  The rowwise
// representation is always dequantized; the columnwise one only when
// is_training is set (matching callers that allocate the columnwise copy
// only for training).
// NOTE(review): assumes `output` has CPU-accessible fp32 buffers for every
// representation that is written -- confirm against callers.
template <typename InputType, typename ScaleType>
void dequantize_2x(Tensor& input, Tensor& output, bool is_training)
{
  input.to_cpu();
  assert(input.rowwise_shape().ndim == 2);
  // Rowwise data is scaled in 1x32 blocks.
  dequantize_1x_kernel(input.rowwise_cpu_dptr<InputType>(),
                       input.rowwise_cpu_scale_inv_ptr<ScaleType>(),
                       output.rowwise_cpu_dptr<float>(),
                       input.rowwise_shape().data[0], input.rowwise_shape().data[1],
                       1, 32);
  if (is_training) {
    // Columnwise data is scaled in 32x1 blocks.  Only validate its shape
    // when the columnwise representation is actually produced.
    assert(input.columnwise_shape().ndim == 2);
    dequantize_1x_kernel(input.columnwise_cpu_dptr<InputType>(),
                         input.columnwise_cpu_scale_inv_ptr<ScaleType>(),
                         output.columnwise_cpu_dptr<float>(),
                         input.columnwise_shape().data[0], input.columnwise_shape().data[1],
                         32, 1);
  }
}
// CPU reference for the per-row statistics, accumulated in fp32.
// LayerNorm: writes the row mean into mu and the inverse stddev into rsigma.
// RMSNorm:   centers around zero and leaves mu untouched.
template <typename InputType>
void compute_ref_stats(NormType norm_type,
                       const InputType *data, float *mu, float *rsigma,
                       const size_t N, const size_t H, const double epsilon){
  using compute_t = float;
  #pragma omp parallel for proc_bind(spread)
  for (size_t row = 0; row < N; ++row) {
    const InputType *row_data = data + row * H;
    // First pass: row sum.
    compute_t acc = 0;
    for (size_t col = 0; col < H; ++col) {
      acc += static_cast<compute_t>(row_data[col]);
    }
    // Center is the mean for LayerNorm, zero for RMSNorm.
    compute_t center = 0;
    if (norm_type == LayerNorm) {
      mu[row] = acc / H;
      center = mu[row];
    }
    // Second pass: sum of squared deviations from the center.
    compute_t sq_acc = 0;
    for (size_t col = 0; col < H; ++col) {
      const compute_t val = static_cast<compute_t>(row_data[col]);
      sq_acc += (val - center) * (val - center);
    }
    rsigma[row] = rsqrtf((sq_acc / H) + epsilon);
  }
}
// CPU reference for the normalization forward pass, computed in fp32 and
// narrowed to OutputType on store.  With zero_centered_gamma the effective
// weight is (gamma + 1).  beta is only applied for LayerNorm.
template <typename InputType, typename OutputType>
void compute_ref_output(NormType norm_type,
                        const InputType *data, const InputType *gamma, const InputType *beta,
                        const float *mu, const float *rsigma,
                        const size_t N, const size_t H,
                        OutputType* output,
                        const bool zero_centered_gamma){
  using compute_t = float;
  #pragma omp parallel for proc_bind(spread)
  for (size_t row = 0; row < N; ++row) {
    for (size_t col = 0; col < H; ++col) {
      const compute_t val = static_cast<compute_t>(data[row * H + col]);
      compute_t weight = static_cast<compute_t>(gamma[col]);
      if (zero_centered_gamma) {
        weight += 1.0;
      }
      compute_t normed;
      if (norm_type == LayerNorm) {
        normed = (val - mu[row]) * rsigma[row] * weight + static_cast<compute_t>(beta[col]);
      } else {  // RMSNorm: no centering, no bias term.
        normed = val * rsigma[row] * weight;
      }
      output[row * H + col] = normed;
    }
  }
}
// End-to-end test of the MXFP8 (NVTE_MXFP8_1D_SCALING) LayerNorm / RMSNorm
// forward kernels.  The fp8 output (rowwise, plus columnwise when
// is_training) is dequantized on the CPU and compared against an fp32
// reference implementation.
//
// N, H:                input is an N x H matrix, normalized along H
// zero_centered_gamma: kernel applies (gamma + 1) instead of gamma
// norm_type:           LayerNorm (mean + variance) or RMSNorm (variance only)
// is_training:         also request/check the columnwise quantized output
template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, NormType norm_type, bool is_training) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);

  // MXFP8 scaling requires Blackwell-class hardware; skip elsewhere.
  if (getDeviceComputeCapability() < blackwellComputeCapability) {
    GTEST_SKIP();
  }

  using WeightType = InputType;
  DType itype = TypeInfo<InputType>::dtype;
  DType wtype = TypeInfo<WeightType>::dtype;
  DType otype = TypeInfo<OutputType>::dtype;

  Tensor input("input", { N, H }, itype);
  // z: quantized output; the columnwise copy is allocated only when training.
  Tensor z("z", { N, H }, otype, true, is_training, NVTE_MXFP8_1D_SCALING);
  Tensor gamma("gamma", { H }, wtype);
  Tensor beta("beta", { H }, wtype);
  Tensor mu("mu", { N }, DType::kFloat32);
  Tensor rsigma("rsigma", { N }, DType::kFloat32);
  Tensor workspace;

  fillUniform(&input);
  fillUniform(&gamma);
  fillUniform(&beta);

  // Forward kernel: the first call with an empty workspace queries the
  // required workspace shape; the second call does the actual computation.
  float epsilon = 1e-5;
  if (norm_type == NormType::LayerNorm){
    nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon,
                       z.data(), mu.data(), rsigma.data(), workspace.data(),
                       prop.multiProcessorCount, zero_centered_gamma,
                       0);
    workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
    nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon,
                       z.data(), mu.data(), rsigma.data(), workspace.data(),
                       prop.multiProcessorCount, zero_centered_gamma,
                       0);
  } else {
    nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon,
                     z.data(), rsigma.data(), workspace.data(),
                     prop.multiProcessorCount, zero_centered_gamma,
                     0);
    workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
    nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon,
                     z.data(), rsigma.data(), workspace.data(),
                     prop.multiProcessorCount, zero_centered_gamma,
                     0);
  }

  // Undo the fp8 quantization on the CPU so we can compare in fp32.
  Tensor dequantized_output("dequantized_output", { N, H }, DType::kFloat32, true, true);
  dequantize_2x<OutputType, fp8e8m0>(z, dequantized_output, is_training);

  // Reference implementations.
  std::unique_ptr<float[]> ref_mu = std::make_unique<float[]>(N);
  std::unique_ptr<float[]> ref_rsigma = std::make_unique<float[]>(N);
  std::unique_ptr<float[]> ref_output = std::make_unique<float[]>(N * H);
  compute_ref_stats(norm_type, input.rowwise_cpu_dptr<InputType>(), ref_mu.get(),
                    ref_rsigma.get(), N, H, epsilon);

  // use the GPU stats to tighten the tolerances
  float *ref_mu_ptr, *ref_rsigma_ptr;
  if (is_training){
    mu.to_cpu();
    rsigma.to_cpu();
    ref_mu_ptr = mu.rowwise_cpu_dptr<float>();
    ref_rsigma_ptr = rsigma.rowwise_cpu_dptr<float>();
  } else {
    // Stats are not produced by the kernel; fall back to the CPU reference.
    ref_mu_ptr = ref_mu.get();
    ref_rsigma_ptr = ref_rsigma.get();
  }
  compute_ref_output(norm_type, input.rowwise_cpu_dptr<InputType>(),
                     gamma.rowwise_cpu_dptr<WeightType>(),
                     beta.rowwise_cpu_dptr<WeightType>(),
                     ref_mu_ptr,
                     ref_rsigma_ptr,
                     N, H,
                     ref_output.get(),
                     zero_centered_gamma);

  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
  rtol_stats = 5e-5;
  if (is_training){
    compareResults("mu", mu, ref_mu.get(), true, atol_stats, rtol_stats);
    compareResults("rsigma", rsigma, ref_rsigma.get(), true, atol_stats, rtol_stats);
  }

  // Output tolerances per fp8 flavor.  Initialize from the generic per-dtype
  // tolerances so atol/rtol are never read uninitialized if otype is neither
  // E5M2 nor E4M3 (previously undefined behavior in that case).
  auto [atol_default, rtol_default] = getTolerances(otype);
  float atol = atol_default;
  float rtol = rtol_default;
  if (otype == DType::kFloat8E5M2){
    atol = 1.25e-1;
    rtol = 1.25e-1;
  } else if (otype == DType::kFloat8E4M3){
    if (itype == DType::kBFloat16){
      // bf16 inputs need slightly looser bounds than fp32/fp16 inputs.
      atol = 7e-2;
      rtol = 7e-2;
    } else {
      atol = 6.25e-2;
      rtol = 6.25e-2;
    }
  }
  compareResults("output_rowwise", dequantized_output, ref_output.get(), true, atol, rtol, false);
  if (is_training) {
    compareResults("output_colwise", dequantized_output, ref_output.get(), false, atol, rtol, false);
  }
}
// (N, H) problem sizes: a tiny sanity case plus two transformer-like shapes.
std::vector<std::pair<size_t, size_t>> test_cases = {
  {32, 32},
  {768, 2304},
  {2048, 12288},
};

// NOTE(review): `norms` appears unused in this file -- the suite below
// enumerates norm types directly via ::testing::Values; confirm before
// removing.
std::vector<NormType> norms = {
  NormType::LayerNorm,
  NormType::RMSNorm
};
} // namespace
// Parameter tuple: (norm type, input dtype, output dtype, (N, H) size,
// zero_centered_gamma, is_training).
class MxNormTestSuite : public ::testing::TestWithParam< std::tuple<NormType,
transformer_engine::DType,
transformer_engine::DType,
std::pair<size_t, size_t>,
bool, bool>> {};
TEST_P(MxNormTestSuite, TestMxNorm) {
  using namespace transformer_engine;
  using namespace test;

  // Unpack the parameter tuple (see MxNormTestSuite).
  const NormType norm_type = std::get<0>(GetParam());
  const DType input_type = std::get<1>(GetParam());
  const DType output_type = std::get<2>(GetParam());
  const auto size = std::get<3>(GetParam());
  const bool zero_centered_gamma = std::get<4>(GetParam());
  const bool is_training = std::get<5>(GetParam());

  // Dispatch runtime dtypes to the statically typed test body: inputs are
  // restricted to fp32/fp16/bf16 and outputs to fp8 types by the macros.
  TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
      performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma, norm_type, is_training);
    );
  );
}
// Full cross product of norm type, input dtype, fp8 output dtype, problem
// size, zero-centered-gamma and training mode.  The lambda builds the test
// name from those parameters; the trailing "out1x"/"out2x" encodes how many
// quantized outputs are produced (1 = rowwise only, 2 = rowwise+columnwise).
INSTANTIATE_TEST_SUITE_P(
  OperatorTest,
  MxNormTestSuite,
  ::testing::Combine(
      ::testing::Values(NormType::LayerNorm, NormType::RMSNorm),
      ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
      ::testing::Values(DType::kFloat8E5M2, DType::kFloat8E4M3),
      ::testing::ValuesIn(test_cases),
      ::testing::Values(true, false),
      ::testing::Values(true, false)),
  [](const testing::TestParamInfo<MxNormTestSuite::ParamType>& info) {
    std::string name = normToString.at(std::get<0>(info.param)) + "_" +
                       test::typeName(std::get<1>(info.param)) + "X" +
                       test::typeName(std::get<2>(info.param)) + "X" +
                       std::to_string(std::get<3>(info.param).first) + "X" +
                       std::to_string(std::get<3>(info.param).second) + "X" +
                       std::to_string(std::get<4>(info.param)) + "out" +
                       std::to_string(int(std::get<5>(info.param)) + 1) + "x";
    return name;
  });
......@@ -58,18 +58,18 @@ void performTestQ(const size_t N) {
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
Tensor input({ N }, itype);
Tensor output({ N }, otype);
Tensor input("input", { N }, itype);
Tensor output("output", { N }, otype);
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N);
fillUniform(&input);
setRandomScale(&output);
nvte_fp8_quantize(input.data(), output.data(), 0);
nvte_quantize(input.data(), output.data(), 0);
float ref_amax;
compute_ref_q<InputType, OutputType>(input.cpu_dptr<InputType>(), ref_output.get(),
compute_ref_q<InputType, OutputType>(input.rowwise_cpu_dptr<InputType>(), ref_output.get(),
N, &ref_amax, output.scale());
cudaDeviceSynchronize();
......@@ -79,7 +79,7 @@ void performTestQ(const size_t N) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
auto [atol, rtol] = getTolerances(otype);
compareResults("output_q", output, ref_output.get(), atol, rtol);
compareResults("output_q", output, ref_output.get(), true, atol, rtol);
}
template <typename InputType, typename OutputType>
......@@ -89,24 +89,24 @@ void performTestDQ(const size_t N) {
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
Tensor input({ N }, itype);
Tensor output({ N }, otype);
Tensor input("input", { N }, itype);
Tensor output("output", { N }, otype);
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N);
fillUniform(&input);
nvte_fp8_dequantize(input.data(), output.data(), 0);
nvte_dequantize(input.data(), output.data(), 0);
compute_ref_dq<InputType, OutputType>(input.cpu_dptr<InputType>(), ref_output.get(),
N, input.scale_inv());
compute_ref_dq<InputType, OutputType>(input.rowwise_cpu_dptr<InputType>(), ref_output.get(),
N, input.rowwise_scale_inv());
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol, rtol] = getTolerances(otype);
compareResults("output_dq", output, ref_output.get(), atol, rtol);
compareResults("output_dq", output, ref_output.get(), true, atol, rtol);
}
std::vector<size_t> qdq_test_cases = {2048* 12288,
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstdint>
#include <cstring>
#include <memory>
#include <iomanip>
#include <iostream>
#include <random>
#include <type_traits>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/swizzle.h>
#include "../test_common.h"
#include "transformer_engine/transformer_engine.h"
using namespace transformer_engine;
constexpr int MAT_TILE_DIM_M = 128;
constexpr int MAT_TILE_DIM_K = 128;
// CPU reference for the scaling-factor swizzle.  Each SF_TILE_DIM_M x
// SF_TILE_DIM_K tile of the M x K scale matrix is rearranged in place of
// itself into a (SF_TILE_DIM_M/4) x (SF_TILE_DIM_K*4) layout.  With
// row_scaling the input is read row-major (m*K + k), otherwise column-major
// (k*M + m).
template <int SF_TILE_DIM_M, int SF_TILE_DIM_K, bool row_scaling>
void compute_ref_swizzle(const uint8_t *h_input, uint8_t *h_output,
                         const size_t M, const size_t K) {
  constexpr int kNewTileM = SF_TILE_DIM_M / 4;      // rows after reshape
  constexpr int kNewTileK = SF_TILE_DIM_K * 4;      // cols after reshape
  constexpr int kTileElems = SF_TILE_DIM_M * SF_TILE_DIM_K;

  for (int m = 0; m < M; m++) {
    const int tile_m = m / SF_TILE_DIM_M;
    const int local_m = m % SF_TILE_DIM_M;
    const int new_row = local_m % kNewTileM;
    for (int k = 0; k < K; k++) {
      const int tile_k = k / SF_TILE_DIM_K;
      const int local_k = k % SF_TILE_DIM_K;
      const int new_col = local_m / kNewTileM * SF_TILE_DIM_K + local_k;
      // Tiles are laid out row-major; each tile occupies kTileElems bytes.
      const int tile_base = tile_m * SF_TILE_DIM_M * K + tile_k * kTileElems;
      const int dst = tile_base + new_row * kNewTileK + new_col;
      const int src = row_scaling ? (m * K + k) : (k * M + m);
      h_output[dst] = h_input[src];
    }
  }
}
// Checks nvte_swizzle_scaling_factors against the CPU reference for a 1D
// (MXFP8) scaling mode.  Exactly one of rowwise/columnwise must be set;
// any other combination is skipped.
//
// num_tiles_M/K: scale-factor grid size in 128x128-element data tiles
// transa:        swaps the data shape between {M, K} and {K, M}
void performTestSwizzle1D(const int num_tiles_M, const int num_tiles_K, bool rowwise, bool columnwise, const bool transa) {
  using namespace test;

  // Initialize the modes so the skip message below never reads
  // uninitialized values (previously UB when neither flag was set).
  int SF_MODE_X = 0, SF_MODE_Y = 0;
  if (rowwise) {
    SF_MODE_X = 1;
    SF_MODE_Y = 32;
  }
  if (columnwise) {
    SF_MODE_X = 32;
    SF_MODE_Y = 1;
  }
  if ((rowwise && columnwise) || !(rowwise || columnwise)){
    GTEST_SKIP() << "TEST SKIPPED, The scaling mode " + std::to_string(SF_MODE_X) + "x" +
                    std::to_string(SF_MODE_Y) + " is not implemented.";
  }

  DType dtype = DType::kFloat8E4M3;
  const size_t M = num_tiles_M * MAT_TILE_DIM_M;
  const size_t K = num_tiles_K * MAT_TILE_DIM_K;
  const auto data_shape = transa ? std::vector<size_t>{M, K} : std::vector<size_t>{K, M};
  // One scale factor per SF_MODE_X x SF_MODE_Y block of elements.
  const auto scale_shape = std::vector<size_t>{data_shape[0] / SF_MODE_X, data_shape[1] / SF_MODE_Y};

  Tensor input("input", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING);
  Tensor output("output", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING);
  fillUniform(&input);
  std::unique_ptr<uint8_t[]> ref_output = std::make_unique<uint8_t[]>(scale_shape[0] * scale_shape[1]);

  nvte_swizzle_scaling_factors(input.data(), output.data(), 0);

  // The columnwise scale matrix is transposed relative to the rowwise one,
  // hence the swapped shape arguments.
  if (rowwise) {
    compute_ref_swizzle<128, 4, true>(input.rowwise_cpu_scale_inv_ptr<uint8_t>(), ref_output.get(), scale_shape[0], scale_shape[1]);
  } else {
    compute_ref_swizzle<128, 4, false>(input.columnwise_cpu_scale_inv_ptr<uint8_t>(), ref_output.get(), scale_shape[1], scale_shape[0]);
  }

  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  output.to_cpu();
  if (rowwise) {
    compareResults("output_swizzle", output.rowwise_cpu_scale_inv_ptr<uint8_t>(), ref_output.get(), scale_shape[0] * scale_shape[1]);
  } else {
    compareResults("output_swizzle", output.columnwise_cpu_scale_inv_ptr<uint8_t>(), ref_output.get(), scale_shape[0] * scale_shape[1]);
  }
}
// Parameter tuple: ((tiles_M, tiles_K), (rowwise, columnwise), transa).
class SwizzleTestSuite : public ::testing::TestWithParam<std::tuple<std::pair<int, int>, std::pair<bool, bool>, bool>> {};
// Forwards one (tile-count, scaling-direction, transpose) combination to the
// 1D swizzle check.
TEST_P(SwizzleTestSuite, TestSwizzle) {
  using namespace transformer_engine;
  using namespace test;

  const auto& [tiles, sf_dirs, trans] = GetParam();
  performTestSwizzle1D(tiles.first, tiles.second,
                       sf_dirs.first, sf_dirs.second,
                       trans);
}
namespace {

// Scale-factor grid sizes, in units of 128x128-element data tiles
// (M tiles x K tiles); includes non-round K values around 256.
std::vector<std::pair<int, int>> num_tiles = {
  {1, 1},
  {1, 132},
  {132, 1},
  {65, 256},
  {65, 257},
  {65, 258},
  {65, 259},
};

// (rowwise, columnwise): exactly one direction per run; other combinations
// are skipped inside performTestSwizzle1D.
std::vector<std::pair<bool, bool>> scaling_mode = {
  {true, false},
  {false, true}
};

std::vector<bool> transa = {true, false};

}  // namespace
// Cross product of tile counts, scaling direction and transpose flag.  The
// lambda builds names like "ntiles1X132smode1X0trans0" from the parameters.
INSTANTIATE_TEST_SUITE_P(
  OperatorTest,
  SwizzleTestSuite,
  ::testing::Combine(
      ::testing::ValuesIn(num_tiles),
      ::testing::ValuesIn(scaling_mode),
      ::testing::ValuesIn(transa)
  ),
  [](const testing::TestParamInfo<SwizzleTestSuite::ParamType>& info) {
    std::string name = "ntiles" +
                       std::to_string(std::get<0>(info.param).first) + "X" +
                       std::to_string(std::get<0>(info.param).second) + "smode" +
                       std::to_string(std::get<1>(info.param).first) + "X"+
                       std::to_string(std::get<1>(info.param).second) + "trans" +
                       std::to_string(std::get<2>(info.param));
    return name;
  });
......@@ -37,8 +37,8 @@ void performTest(const size_t N, const size_t H) {
DType dtype = TypeInfo<Type>::dtype;
Tensor input({ N, H }, dtype);
Tensor output({ H, N }, dtype);
Tensor input("input", { N, H }, dtype);
Tensor output("output", { H, N }, dtype);
std::unique_ptr<Type[]> ref_output = std::make_unique<Type[]>(N * H);
......@@ -46,13 +46,13 @@ void performTest(const size_t N, const size_t H) {
nvte_transpose(input.data(), output.data(), 0);
compute_ref<Type>(input.cpu_dptr<Type>(), ref_output.get(), N, H);
compute_ref<Type>(input.rowwise_cpu_dptr<Type>(), ref_output.get(), N, H);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol, rtol] = getTolerances(dtype);
compareResults("output", output, ref_output.get(), atol, rtol);
compareResults("output", output, ref_output.get(), true, atol, rtol);
}
std::vector<std::pair<size_t, size_t>> test_cases = {{2048, 12288},
......
This diff is collapsed.
This diff is collapsed.
......@@ -8,8 +8,9 @@ add_executable(test_util
../test_common.cu)
target_link_libraries(test_util PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn)
target_compile_options(test_util PRIVATE -O2)
find_package(OpenMP REQUIRED)
target_link_libraries(test_util PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn OpenMP::OpenMP_CXX)
target_compile_options(test_util PRIVATE -O2 -fopenmp)
include(GoogleTest)
gtest_discover_tests(test_util)
gtest_discover_tests(test_util DISCOVERY_TIMEOUT 600)
......@@ -27,9 +27,6 @@ def enable_fused_attn_after_hopper():
"""
if get_device_compute_capability(0) >= 90:
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
yield
if "NVTE_FUSED_ATTN" in os.environ:
del os.environ["NVTE_FUSED_ATTN"]
if "NVTE_ALLOW_NONDETERMINISTIC_ALGO" in os.environ:
del os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"]
......@@ -4,14 +4,19 @@
"""Test transformer_engine.jax.flax.TransformerLayer"""
import os
from functools import partial
from typing import Dict, Tuple
from typing import Dict, Tuple, Optional
import flax
import jax
import jax.numpy as jnp
import pytest
from utils import assert_allclose, assert_tree_like_allclose, sync_params_values
from utils import (
assert_allclose,
assert_tree_like_allclose,
dtype_tols,
sync_params_values,
)
from utils import DecoderLayer as RefDecoderLayer
from utils import EncoderLayer as RefEncoderLayer
......@@ -250,7 +255,13 @@ class BaseRunner:
target = sync_params_values(target, ref, self.transformations)
return ref, target
def test_forward(self, data_shape, dtype, rtol=1e-05, atol=1e-08):
def test_forward(
self,
data_shape: Tuple[int],
dtype: jnp.dtype,
rtol: Optional[float] = None,
atol: Optional[float] = None,
) -> None:
"""Test only the forward"""
inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)
......@@ -264,9 +275,16 @@ class BaseRunner:
ref_out = self._loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer)
test_out = self._loss_fn(inputs, test_masks, test_params, test_others, test_layer)
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
tols = dtype_tols(dtype, rtol=rtol, atol=atol)
assert_allclose(ref_out, test_out, **tols)
def test_backward(self, data_shape, dtype, rtol=1e-05, atol=1e-08):
def test_backward(
self,
data_shape: Tuple[int],
dtype: jnp.dtype,
rtol: Optional[float] = None,
atol: Optional[float] = None,
) -> None:
"""Test forward and backward through value_and_grad()"""
inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)
......@@ -302,11 +320,12 @@ class BaseRunner:
inputs, test_masks, test_params, test_others, test_layer
)
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
assert_tree_like_allclose(ref_dgrads, test_dgrads, rtol=rtol, atol=atol)
tols = dtype_tols(dtype, rtol=rtol, atol=atol)
assert_allclose(ref_out, test_out, **tols)
assert_tree_like_allclose(ref_dgrads, test_dgrads, **tols)
_, restructed_ref_wgrads = self._sync_params(ref_wgrads, test_wgrads)
assert_tree_like_allclose(restructed_ref_wgrads, test_wgrads, rtol=rtol, atol=atol)
assert_tree_like_allclose(restructed_ref_wgrads, test_wgrads, **tols)
class EncoderRunner(BaseRunner):
......@@ -418,12 +437,12 @@ class BaseTester:
def test_forward(self, data_shape, dtype, attrs):
"""Test normal datatype forward"""
FP8Helper.finalize() # Ensure FP8 disabled.
self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-5, atol=7e-5)
self.runner(attrs).test_forward(data_shape, dtype)
def test_backward(self, data_shape, dtype, attrs):
"""Test normal datatype backward"""
FP8Helper.finalize() # Ensure FP8 disabled.
self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-5, atol=7e-5)
self.runner(attrs).test_backward(data_shape, dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fp8_format", FP8_FORMATS)
......
......@@ -1387,18 +1387,26 @@ def assert_tree_like_allclose(expected, actual, rtol=1e-05, atol=1e-08):
def dtype_tols(
dtype: Union[DType, TEDType, np.dtype],
reference_value: float = 1.0,
rtol: Optional[float] = None,
atol: Optional[float] = None,
) -> Dict[str, float]:
"""Expected numerical tolerance for a data type.
Args:
dtype: data type.
reference_value: reference value (default: 1).
rtol: override for relative tolerance estimate
atol: override for absolute tolerance estimate
Returns:
Dictionary with "rtol" and "atol" as keys
"""
# Return immediately if tolerances are fully specified
if rtol is not None and atol is not None:
return {"rtol": rtol, "atol": atol}
# Convert to JAX dtype if needed
if isinstance(dtype, TEDType):
dtype = {
......@@ -1416,7 +1424,11 @@ def dtype_tols(
# Expect bit-wise accuracy for integer dtypes
if not jnp.issubdtype(dtype, jnp.floating):
return dict(rtol=0, atol=0)
if rtol is None:
rtol = 0.0
if atol is None:
atol = 0.0
return {"rtol": rtol, "atol": atol}
# Estimate floating-point error
finfo = jnp.finfo(dtype)
......@@ -1429,10 +1441,11 @@ def dtype_tols(
spacing_high = jnp.nextafter(reference_value, finfo.max) - reference_value
spacing_low = reference_value - jnp.nextafter(reference_value, finfo.min)
ulp = max(spacing_high.item(), spacing_low.item())
return dict(
rtol=eps_relaxed,
atol=max(ulp, eps_relaxed),
)
if rtol is None:
rtol = eps_relaxed
if atol is None:
atol = max(ulp, eps_relaxed)
return {"rtol": rtol, "atol": atol}
def sync_params_values(dst, src, transformations, sep="/"):
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment