/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <limits>
#include <memory>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/recipe.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
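// Reference cast with current (per-tensor) scaling: every element is
// multiplied by one global scale and narrowed to the output type. If the
// caller passes a non-null amax pointer, the absolute maximum of the input
// is reported as well.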
template <typename InputType, typename OutputType>
void compute_ref(const InputType *data, OutputType *output_c,
                 const size_t size,
                 float *amax, float scale) {
  using compute_t = float;
  compute_t current_max = -std::numeric_limits<compute_t>::infinity();
  for (size_t i = 0; i < size; ++i) {
    compute_t current = static_cast<compute_t>(data[i]);
    current_max = fmaxf(current_max, fabsf(current));
    output_c[i] = OutputType(scale * current);
  }
  if (amax != nullptr) {
    *amax = current_max;
  }
}
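// Worked example for the scale computation below (illustrative numbers):
// with an FP8E4M3 output (max_fp8 = 448) and amax = 3.5,
//   scale     = max_fp8 / amax = 448 / 3.5 = 128
//   scale_inv = 1 / scale      = 0.0078125
// so x = 2.0 is stored as fp8(2.0 * 128) = fp8(256) and dequantized back as
// 256 * 0.0078125 = 2.0.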
template <typename InputType, typename OutputType>
void compute_amax_scale_ref(const InputType *data,
const size_t size,
float *amax_ptr, float *scale_ptr, float* scale_inv_ptr,
float max_fp8, float epsilon) {
using compute_t = float;
  compute_t current_max = -std::numeric_limits<compute_t>::infinity();
for (size_t i = 0; i < size; ++i) {
compute_t current = static_cast<compute_t>(data[i]);
current_max = fmaxf(current_max, fabsf(current));
}
*amax_ptr = current_max;
// compute scale from amax
float clamp_amax = current_max;
  if (current_max <= epsilon) {
clamp_amax = epsilon;
}
float scale = 1.f;
float scale_inv = 1.f;
if (isinf(clamp_amax) || clamp_amax == 0.f) {
*scale_ptr = scale;
*scale_inv_ptr = scale_inv;
return;
}
// use ieee_div in CPU
scale = max_fp8 / clamp_amax;
  // The amax is so small that the scale becomes infinite in FP32. In other
  // words, the scale is not representable in FP32.
if (isinf(scale)) {
scale = std::numeric_limits<float>::max();
}
if (isnan(scale)) {
scale = 1.f;
}
scale_inv = 1.0f / scale;
*scale_ptr = scale;
*scale_inv_ptr = scale_inv;
}
// current tensor scaling test
template <typename InputType, typename OutputType>
void performTest(const std::vector<size_t>& shape) {
using namespace test;
const size_t full_size = product(shape);
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
bool is_out_fp8 = isFp8Type(otype);
// find out max fp8 value
float max_fp8;
if (is_out_fp8){
switch (otype) {
case DType::kFloat8E5M2: {
max_fp8 = Quantized_Limits<fp8e5m2>::max();
} break;
case DType::kFloat8E4M3: {
max_fp8 = Quantized_Limits<fp8e4m3>::max();
} break;
default:
NVTE_ERROR("Invalid type.");
}
}
Tensor input("input", shape, itype);
Tensor output_c("output_c", shape, otype, true, false);
std::unique_ptr<OutputType[]> ref_output_c = std::make_unique<OutputType[]>(full_size);
fillUniform(&input);
// compute amax
float amax_to_check = 0.0f;
if (is_out_fp8){
nvte_compute_amax(input.data(), output_c.data(), 0);
QuantizationConfigWrapper config;
nvte_compute_scale_from_amax(output_c.data(), config, 0);
    // Read the amax, then null out the tensor's amax pointer so that the cast
    // kernel skips its atomic amax update (the amax has already been computed
    // above for current per-tensor scaling).
    amax_to_check = output_c.amax();
    output_c.set_tensor_amax_nullptr();
}
nvte_quantize(input.data(), output_c.data(), 0);
float ref_amax;
float ref_scale;
float ref_scale_inv;
if (is_out_fp8){
compute_amax_scale_ref<InputType, OutputType>(input.rowwise_cpu_dptr<InputType>(),
full_size, &ref_amax, &ref_scale, &ref_scale_inv, max_fp8, 0.0f);
}
compute_ref<InputType, OutputType>(input.rowwise_cpu_dptr<InputType>(), ref_output_c.get(),
                                     full_size, nullptr, is_out_fp8 ? output_c.scale() : 1.0f);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (isFp8Type(otype)) {
auto [atol_fp32, rtol_fp32] = getTolerances(DType::kFloat32);
compareResults("amax", amax_to_check, ref_amax, 0.0f, rtol_fp32);
compareResults("scale", output_c.scale(), ref_scale, 0.0f, rtol_fp32);
compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, 0.0f, rtol_fp32);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), true, 0.0f, rtol);
}
std::vector<std::vector<size_t>> test_cases = {
{16},
{16000},
{128, 128},
{256, 256},
{768, 1024},
{256, 65536},
{2048, 12288},
{65536, 128},
{65536, 160},
{16384, 1616},
{1, 128},
{1, 1296},
{1, 16},
{5, 160},
{5, 4, 3, 160},
{217, 256},
};
} // namespace
class CastCSTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
transformer_engine::DType,
std::vector<size_t>>> {};
TEST_P(CastCSTestSuite, TestCastCS) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
// current tensor scaling
performTest<InputType, OutputType>(size);
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
CastCSTestSuite,
::testing::Combine(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<CastCSTestSuite::ParamType>& info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param));
const auto& shape = std::get<2>(info.param);
    for (const auto& s : shape) {
name += "X" + std::to_string(s);
}
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <memory>
#include <iomanip>
#include <iostream>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
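// Reference for the fused cast + dbias operation: the input is cast to the
// FP8 output with a per-tensor scale while the bias gradient is accumulated
// in FP32 as a column-wise sum, i.e. dbias[j] = sum_i input[i * H + j].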
template <typename IT, typename OT, typename CT>
void compute_ref_cast_dbias(const IT *input_h,
const CT scale,
OT *output_c_h,
CT *amax_h,
IT *dbias_h,
const size_t N,
const size_t H) {
CT amax = 0.;
std::vector<CT> acc_dbias(H, 0.);
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT elt = static_cast<CT>(input_h[i * H + j]);
// update amax
amax = std::abs(elt) > amax ? std::abs(elt) : amax;
output_c_h[i * H + j] = static_cast<OT>(scale * elt);
// dbias
acc_dbias[j] += elt;
}
}
*amax_h = amax;
for (size_t i = 0; i < H; i++) {
dbias_h[i] = static_cast<IT>(acc_dbias[i]);
}
}
template <typename IType, typename OType>
void performTest(const std::vector<size_t>& shape) {
using namespace test;
using CType = fp32;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
const size_t N = first_dimension(shape);
const size_t H = last_dimension(shape);
Tensor input("input", shape, itype);
Tensor output_c("output_c", shape, otype);
  // dbias has the same data type as the output gradient
Tensor dbias("dbias", {H}, itype);
fillUniform(&input);
setRandomScale(&output_c);
std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N*H);
std::unique_ptr<IType[]> ref_output_dbias = std::make_unique<IType[]>(H);
CType ref_amax;
compute_ref_cast_dbias(input.rowwise_cpu_dptr<IType>(),
output_c.scale(),
ref_output_c.get(),
&ref_amax,
ref_output_dbias.get(),
N, H);
Tensor workspace;
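  // Transformer Engine workspace convention: the first call only fills in the
  // required workspace shape and dtype, the workspace is then allocated, and
  // the second call performs the actual computation.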
nvte_quantize_dbias(input.data(),
output_c.data(),
dbias.data(),
workspace.data(),
0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_quantize_dbias(input.data(),
output_c.data(),
dbias.data(),
workspace.data(),
0);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol);
auto [atol_dbias, rtol_dbias] = getTolerances(itype);
rtol_dbias *= 4;
compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
}
std::vector<std::vector<size_t>> test_cases = {
{128, 128},
{256, 256},
{768, 1024},
{256, 65536},
{2048, 12288},
{65536, 128},
{65536, 160},
{16384, 1616},
{1, 128},
{1, 1296},
{1, 16},
{5, 160},
{5, 4, 3, 160},
{217, 256},
};
}  // namespace
class CastDBiasTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
transformer_engine::DType,
std::vector<size_t>>> {};
TEST_P(CastDBiasTestSuite, TestCastDBias) {
using namespace transformer_engine;
using namespace test;
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>(size);
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
CastDBiasTestSuite,
::testing::Combine(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<CastDBiasTestSuite::ParamType>& info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param));
const auto& shape = std::get<2>(info.param);
    for (const auto& s : shape) {
name += "X" + std::to_string(s);
}
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <memory>
#include <iomanip>
#include <iostream>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include "../test_common.h"
using namespace transformer_engine;
using namespace test;
namespace {
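// Reference for the fused cast + dbias + dGeLU backward. Each element is the
// incoming gradient scaled by the GELU derivative of the forward input,
//   elt = grad[i * H + j] * dgelu(input[i * H + j]),
// which is then cast with the per-tensor scale and column-summed into dbias.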
template <typename IT, typename OT, typename CT>
void compute_ref_cast_dbias_dgelu(const IT *input,
const IT *grad,
const CT scale,
OT *output_c,
CT *amax_h,
IT *dbias,
const size_t N,
const size_t H) {
CT amax = 0.;
std::vector<CT> acc_dbias(H, 0.);
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT in_elt = static_cast<CT>(input[i * H + j]);
const CT in_grad = static_cast<CT>(grad[i * H + j]);
const CT elt = in_grad * static_cast<float>(dgelu(static_cast<float>(in_elt)));
const CT elt_abs = std::abs(elt);
// update amax
if (elt_abs > amax) {
amax = elt_abs;
}
output_c[i * H + j] = static_cast<OT>(scale * elt);
// dbias
acc_dbias[j] += elt;
}
}
*amax_h = amax;
for (size_t i = 0; i < H; i++) {
dbias[i] = static_cast<IT>(acc_dbias[i]);
}
}
template <typename IType, typename OType>
void performTest(const std::vector<size_t>& shape) {
using namespace test;
using CType = fp32;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
const size_t N = first_dimension(shape);
const size_t H = last_dimension(shape);
Tensor input("input", shape, itype);
Tensor grad("grad", shape, itype);
Tensor output_c("output_c", shape, otype);
  // dbias has the same data type as the output gradient
Tensor dbias("dbias", {H}, itype);
fillUniform(&input);
fillUniform(&grad);
setRandomScale(&output_c);
std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N*H);
std::unique_ptr<IType[]> ref_output_dbias = std::make_unique<IType[]>(H);
CType ref_amax;
compute_ref_cast_dbias_dgelu(input.rowwise_cpu_dptr<IType>(),
grad.rowwise_cpu_dptr<IType>(),
output_c.scale(),
ref_output_c.get(),
&ref_amax,
ref_output_dbias.get(),
N, H);
Tensor workspace;
nvte_quantize_dbias_dgelu(grad.data(),
input.data(),
output_c.data(),
dbias.data(),
workspace.data(),
0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_quantize_dbias_dgelu(grad.data(),
input.data(),
output_c.data(),
dbias.data(),
workspace.data(),
0);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol);
auto [atol_dbias, rtol_dbias] = getTolerances(itype);
rtol_dbias *= 4;
compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
}
std::vector<std::vector<size_t>> test_cases = {
{128, 128},
{256, 256},
{768, 1024},
{256, 65536},
{2048, 12288},
{65536, 128},
{65536, 160},
{16384, 1616},
{1, 128},
{1, 1296},
{1, 16},
{5, 160},
{5, 4, 3, 160},
{217, 256},
};
}  // namespace
class CastDBiasDGeluTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
transformer_engine::DType,
std::vector<size_t>>> {};
TEST_P(CastDBiasDGeluTestSuite, TestCastDBiasDgelu) {
using namespace transformer_engine;
using namespace test;
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>(size);
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
CastDBiasDGeluTestSuite,
::testing::Combine(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<CastDBiasDGeluTestSuite::ParamType>& info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param));
const auto& shape = std::get<2>(info.param);
    for (const auto& s : shape) {
name += "X" + std::to_string(s);
}
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <omp.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/transpose.h>
#include "../test_common.h"
using namespace transformer_engine;
using namespace test;
namespace {
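// Reference for the cast + dSwiGLU backward. The input stores the activation
// and gate halves side by side along the last dimension (row stride is
// 2 * cols):
//   d_act  = dsilu(x_act) * grad * x_gate  -> columns [0, cols)
//   d_gate = silu(x_act) * grad            -> columns [cols, 2 * cols)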
template <typename IType, typename OType>
void compute_ref_cast_dgated_swiglu(const IType * const grad,
const IType * const input,
const float scale,
OType * const output,
float * const amax_ptr,
const size_t rows,
const size_t cols) {
float amax = 0;
const size_t stride = cols * 2;
#pragma omp parallel for reduction(max: amax) proc_bind(spread)
for (size_t i = 0; i < rows; i++) {
for (size_t j = 0; j < cols; j++) {
float grad_elt = static_cast<float>(grad[i * cols + j]);
float silu_elt = static_cast<float>(input[i * stride + j]);
float gate_elt = static_cast<float>(input[i * stride + cols + j]);
float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt;
float after_dgate = grad_elt * silu(silu_elt);
      if (fabsf(after_dsilu) > amax) { amax = fabsf(after_dsilu); }
      if (fabsf(after_dgate) > amax) { amax = fabsf(after_dgate); }
output[i * stride + j] = static_cast<OType>(scale * after_dsilu);
output[i * stride + cols + j] = static_cast<OType>(scale * after_dgate);
}
}
*amax_ptr = amax;
}
template <typename IType, typename OType>
void performTest(const std::vector<size_t>& shape) {
using namespace test;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
std::vector<size_t> input_shape = shape;
input_shape[input_shape.size() - 1] *= 2;
const size_t input_size = product(input_shape);
const size_t rows = first_dimension(shape);
const size_t cols = last_dimension(shape);
Tensor grad("grad", shape, itype);
Tensor input("input", input_shape, itype);
Tensor output_c("output_c", input_shape, otype);
fillUniform(&grad);
fillUniform(&input);
setRandomScale(&output_c);
std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(input_size);
nvte_dswiglu(grad.data(), input.data(), output_c.data(), 0);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
float ref_amax;
compute_ref_cast_dgated_swiglu(grad.rowwise_cpu_dptr<IType>(),
input.rowwise_cpu_dptr<IType>(),
output_c.scale(),
ref_output_c.get(),
&ref_amax,
rows,
cols);
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol);
}
std::vector<std::vector<size_t>> test_cases = {
{128, 128},
{256, 256},
{768, 1024},
{256, 65536},
{2048, 12288},
{65536, 128},
{217, 256},
{1296},
{5, 4, 3, 160},
};
} // namespace
class CastSwiGLUTestSuite
: public ::testing::TestWithParam<std::tuple<
transformer_engine::DType, transformer_engine::DType, std::vector<size_t>>> {};
TEST_P(CastSwiGLUTestSuite, TestCastSwiGLU) {
using namespace transformer_engine;
using namespace test;
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
if (size.back() % 32 != 0) {
GTEST_SKIP();
}
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
output_type, OutputType, performTest<InputType, OutputType>(size);););
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest, CastSwiGLUTestSuite,
::testing::Combine(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<CastSwiGLUTestSuite::ParamType> &info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param));
const auto& shape = std::get<2>(info.param);
    for (const auto& s : shape) {
name += "X" + std::to_string(s);
}
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <algorithm>
#include <cmath>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"
using namespace transformer_engine;
using namespace test;
namespace {
enum ProcessingMethod {
CAST_ONLY,
CAST_DBIAS,
CAST_DBIAS_DACT,
CAST_DACT,
CAST_ACT
};
enum ActivationType {
Identity,
GeLU,
SiLU,
ReLU,
QGeLU,
SReLU
};
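// MXFP8 reference path: each block shares a single E8M0 scale, i.e. a biased
// power-of-two exponent derived from the block amax. float_to_e8m0() maps
// amax * (1 / fp8_max) to a power-of-two exponent and exp2f_rcp() returns its
// reciprocal, so elements are quantized as fp8(elt * scale_reciprocal).
// Illustrative numbers for E4M3 (fp8_max = 448): a block amax of 14 gives
// amax / fp8_max = 0.03125 = 2^-5, so the stored scale is 2^-5 and
// scale_reciprocal = 2^5 = 32.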
template <typename InputType, typename OutputType, float (*OP)(const float)>
void scale_block(const ProcessingMethod processing_method,
const InputType* input,
const InputType* grad,
OutputType* output_c,
float* dbias,
fp8e8m0* output_scales,
const size_t scale_idx,
const size_t i_min,
const size_t i_max,
const size_t j_min,
const size_t j_max,
const size_t cols) {
float amax = 0.0f;
// Find the absolute maximum value in the block
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
float elt = static_cast<float>(input[idx]);
if (processing_method == ProcessingMethod::CAST_DBIAS) {
// grad is the input
elt = static_cast<float>(grad[idx]);
}
if (processing_method != ProcessingMethod::CAST_ONLY
&& processing_method != ProcessingMethod::CAST_DBIAS) {
elt = OP(elt);
}
if (processing_method == ProcessingMethod::CAST_DACT ||
processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
elt *= static_cast<float>(grad[idx]);
}
dbias[j] += elt;
if (isinf(elt) || isnan(elt)) {
continue;
}
amax = std::max(amax, std::abs(elt));
}
}
const fp8e8m0 biased_exponent = float_to_e8m0(amax * Quantized_Limits<OutputType>::max_reciprocal());
const float scale_reciprocal = exp2f_rcp(biased_exponent);
output_scales[scale_idx] = biased_exponent;
// Quantize elements in the block
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
float elt = static_cast<float>(input[idx]);
if (processing_method == ProcessingMethod::CAST_DBIAS) {
// grad is the input
elt = static_cast<float>(grad[idx]);
}
if (processing_method != ProcessingMethod::CAST_ONLY
&& processing_method != ProcessingMethod::CAST_DBIAS) {
elt = OP(elt);
}
if (processing_method == ProcessingMethod::CAST_DACT ||
processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
elt *= static_cast<float>(grad[idx]);
}
output_c[idx] = static_cast<OutputType>(elt * scale_reciprocal);
}
}
}
template <typename InputType, typename OutputType, float (*OP)(const float)>
void compute_ref_x1(const ProcessingMethod processing_method,
const InputType* input,
const InputType* grad,
OutputType* output_c,
fp8e8m0* output_scales,
InputType* output_dbias,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X,
const size_t scales_stride)
{
std::vector<float> output_dbias_fp32(cols, 0);
const size_t blocks_Y = (rows + block_size_Y - 1) / block_size_Y;
const size_t blocks_X = (cols + block_size_X - 1) / block_size_X;
for (size_t ii = 0; ii < blocks_Y; ++ii) {
const size_t i_min = ii * block_size_Y;
const size_t i_max = std::min((ii + 1) * block_size_Y, rows);
for (size_t jj = 0; jj < blocks_X; ++jj) {
const size_t j_min = jj * block_size_X;
const size_t j_max = std::min((jj + 1) * block_size_X, cols);
const size_t scale_idx = ii * scales_stride + jj;
scale_block<InputType, OutputType, OP>(
processing_method, input, grad, output_c, output_dbias_fp32.data(),
output_scales, scale_idx, i_min, i_max, j_min, j_max, cols);
}
}
for (size_t j = 0; j < cols; ++j) {
output_dbias[j] = static_cast<InputType>(output_dbias_fp32[j]);
}
}
template <typename InputType, typename OutputType, float (*OP)(const float)>
void compute_ref_x2(const ProcessingMethod processing_method,
const InputType* input,
const InputType* grad,
OutputType* output_rowwise,
OutputType* output_colwise,
fp8e8m0* scales_rowwise,
fp8e8m0* scales_colwise,
InputType* output_dbias,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X,
const size_t scales_stride_rowwise,
const size_t scales_stride_colwise) {
compute_ref_x1<InputType, OutputType, OP>(
processing_method, input, grad, output_rowwise, scales_rowwise, output_dbias,
rows, cols, 1, block_size_X, scales_stride_rowwise);
compute_ref_x1<InputType, OutputType, OP>(
processing_method, input, grad, output_colwise, scales_colwise, output_dbias,
rows, cols, block_size_Y, 1, scales_stride_colwise);
}
/**
 * Scaling along a single dimension (either rows or columns).
 * Produces one set of scaled output data and the corresponding output of the
 * fused operation (dbias):
 * 1) Scaled rows + row-wise scaling factors
 * OR
 * 2) Scaled columns + column-wise scaling factors
 */
template <typename InputType, typename OutputType, float (*OP)(const float)>
void performTest_x1(const ProcessingMethod processing_method,
const std::vector<size_t>& shape,
const bool rowwise,
const bool colwise,
InputsFillCase fill_case) {
using namespace test;
using EncodingType = fp32;
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
const size_t rows = first_dimension(shape);
const size_t cols = last_dimension(shape);
if (shape.size() < 2 && colwise) {
GTEST_SKIP();
}
const size_t block_size_rows = rowwise ? 1 : 32;
const size_t block_size_cols = colwise ? 1 : 32;
const std::array<size_t,4> scale_dims = get_scale_tensor_dims(rows, cols, block_size_rows,
block_size_cols);
const size_t unpadded_blocks_Y = scale_dims[0];
const size_t unpadded_blocks_X = scale_dims[1];
const size_t blocks_Y = scale_dims[2];
const size_t blocks_X = scale_dims[3];
const size_t scales_stride = blocks_X;
Tensor input("input", shape, itype);
Tensor grad("grad", shape, itype);
Tensor output_c("output_c", shape, otype, rowwise, colwise, NVTE_MXFP8_1D_SCALING);
Tensor output_dbias("output_dbias", { cols }, itype);
std::unique_ptr<OutputType[]> ref_output_c = std::make_unique<OutputType[]>(rows * cols);
std::unique_ptr<InputType[]> ref_output_dbias = std::make_unique<InputType[]>(cols);
std::unique_ptr<fp8e8m0[]> ref_output_scales = std::make_unique<fp8e8m0[]>(blocks_Y * blocks_X);
fillCase<EncodingType>(&input, fill_case);
fillUniform(&grad);
Tensor workspace;
switch (processing_method) {
case ProcessingMethod::CAST_ONLY: {
nvte_quantize(input.data(), output_c.data(), 0);
break;
}
case ProcessingMethod::CAST_DBIAS: {
nvte_quantize_dbias(grad.data(),
output_c.data(),
output_dbias.data(),
workspace.data(),
0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_quantize_dbias(grad.data(),
output_c.data(),
output_dbias.data(),
workspace.data(),
0);
break;
}
case ProcessingMethod::CAST_DBIAS_DACT: {
nvte_quantize_dbias_dgelu(grad.data(),
input.data(),
output_c.data(),
output_dbias.data(),
workspace.data(),
0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_quantize_dbias_dgelu(grad.data(),
input.data(),
output_c.data(),
output_dbias.data(),
workspace.data(),
0);
break;
}
case ProcessingMethod::CAST_DACT: {
nvte_dgelu(grad.data(), input.data(), output_c.data(), 0);
break;
}
case ProcessingMethod::CAST_ACT: {
nvte_gelu(input.data(), output_c.data(), 0);
break;
}
}
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
compute_ref_x1<InputType, OutputType, OP>(processing_method,
input.rowwise_cpu_dptr<InputType>(),
grad.rowwise_cpu_dptr<InputType>(),
ref_output_c.get(),
ref_output_scales.get(),
ref_output_dbias.get(),
rows,
cols,
block_size_rows,
block_size_cols,
scales_stride);
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol);
const uint8_t * const gpu_scales_ptr = rowwise
? output_c.rowwise_cpu_scale_inv_ptr<fp8e8m0>()
: output_c.columnwise_cpu_scale_inv_ptr<fp8e8m0>();
compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride);
if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
auto [atol_dbias, rtol_dbias] = getTolerances(itype);
if (itype == DType::kFloat32) {
atol_dbias = 1e-4;
      rtol_dbias *= sqrt(static_cast<double>(rows));
} else {
rtol_dbias *= 4;
}
compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
}
}
/**
 * Scaling along both dimensions (rows and columns).
 * Produces two sets of scaled output data and the corresponding output of the
 * fused operation (dbias):
 * 1) Scaled rows + row-wise scaling factors
 * AND
 * 2) Scaled columns + column-wise scaling factors
 */
template <typename InputType, typename OutputType, float (*OP)(const float)>
void performTest_x2(const ProcessingMethod processing_method,
const std::vector<size_t>& shape,
const size_t block_size_rows,
const size_t block_size_cols,
InputsFillCase fill_case) {
using namespace test;
using EncodingType = fp32;
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
if (shape.size() < 2) {
GTEST_SKIP();
}
const size_t rows = first_dimension(shape);
const size_t cols = last_dimension(shape);
const std::array<size_t,4> scale_dims_rowwise = get_scale_tensor_dims(rows, cols, 1, 32);
const std::array<size_t,4> scale_dims_colwise = get_scale_tensor_dims(rows, cols, 32, 1);
const size_t unpadded_blocks_Y_rowwise = scale_dims_rowwise[0];
const size_t unpadded_blocks_X_rowwise = scale_dims_rowwise[1];
const size_t blocks_Y_rowwise = scale_dims_rowwise[2];
const size_t blocks_X_rowwise = scale_dims_rowwise[3];
const size_t scales_stride_rowwise = blocks_X_rowwise;
const size_t unpadded_blocks_Y_colwise = scale_dims_colwise[0];
const size_t unpadded_blocks_X_colwise = scale_dims_colwise[1];
const size_t blocks_Y_colwise = scale_dims_colwise[2];
const size_t blocks_X_colwise = scale_dims_colwise[3];
const size_t scales_stride_colwise = blocks_X_colwise;
Tensor input("input", shape, itype);
Tensor grad("grad", shape, itype);
Tensor output("output", shape, otype, true, true, NVTE_MXFP8_1D_SCALING);
Tensor output_dbias("output_dbias", { cols }, itype);
std::unique_ptr<OutputType[]> ref_output_c_rowwise = std::make_unique<OutputType[]>(rows * cols);
std::unique_ptr<OutputType[]> ref_output_c_colwise = std::make_unique<OutputType[]>(rows * cols);
std::unique_ptr<fp8e8m0[]> ref_scales_rowwise = std::make_unique<fp8e8m0[]>(blocks_Y_rowwise * blocks_X_rowwise);
std::unique_ptr<fp8e8m0[]> ref_scales_colwise = std::make_unique<fp8e8m0[]>(blocks_Y_colwise * blocks_X_colwise);
std::unique_ptr<InputType[]> ref_output_dbias = std::make_unique<InputType[]>(cols);
fillCase<EncodingType>(&input, fill_case);
fillUniform(&grad);
Tensor workspace;
switch (processing_method) {
case ProcessingMethod::CAST_ONLY: {
nvte_quantize(input.data(), output.data(), 0);
break;
}
case ProcessingMethod::CAST_DBIAS: {
nvte_quantize_dbias(grad.data(),
output.data(),
output_dbias.data(),
workspace.data(),
0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_quantize_dbias(grad.data(),
output.data(),
output_dbias.data(),
workspace.data(),
0);
break;
}
case ProcessingMethod::CAST_DBIAS_DACT: {
nvte_quantize_dbias_dgelu(grad.data(),
input.data(),
output.data(),
output_dbias.data(),
workspace.data(),
0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_quantize_dbias_dgelu(grad.data(),
input.data(),
output.data(),
output_dbias.data(),
workspace.data(),
0);
break;
}
case ProcessingMethod::CAST_DACT: {
nvte_dgelu(grad.data(), input.data(), output.data(), 0);
break;
}
case ProcessingMethod::CAST_ACT: {
nvte_gelu(input.data(), output.data(), 0);
break;
}
}
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
compute_ref_x2<InputType, OutputType, OP>(processing_method,
input.rowwise_cpu_dptr<InputType>(),
grad.rowwise_cpu_dptr<InputType>(),
ref_output_c_rowwise.get(),
ref_output_c_colwise.get(),
ref_scales_rowwise.get(),
ref_scales_colwise.get(),
ref_output_dbias.get(),
rows,
cols,
block_size_rows,
block_size_cols,
scales_stride_rowwise,
scales_stride_colwise);
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol);
compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol);
compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
unpadded_blocks_X_rowwise, scales_stride_rowwise);
compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise);
if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
auto [atol_dbias, rtol_dbias] = getTolerances(itype);
if (itype == DType::kFloat32) {
atol_dbias = 1e-4;
      rtol_dbias *= sqrt(static_cast<double>(rows));
} else {
rtol_dbias *= 4;
}
compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
}
}
std::vector<std::vector<size_t>> matrix_sizes = {
{1, 16},
{16, 48},
{65, 96},
{128, 128},
{256, 256},
{993, 512},
{256, 65536},
{2048, 6144},
{16384, 128},
{32768, 160},
{4096, 1632},
{1024},
{8, 32, 1024},
{16, 8, 4, 512},
};
std::vector<std::pair<size_t, size_t>> block_sizes = {
{1, 32},
{32, 1},
{32, 32},
};
std::vector<InputsFillCase> input_scenarios = {
InputsFillCase::uniform,
// InputsFillCase::zeros,
// InputsFillCase::zero_to_minNorm,
// InputsFillCase::minNorm_to_maxNorm,
// InputsFillCase::maxNorm_to_inf
};
std::vector<ProcessingMethod> processing_methods = {
ProcessingMethod::CAST_ONLY,
ProcessingMethod::CAST_DBIAS,
ProcessingMethod::CAST_DBIAS_DACT,
ProcessingMethod::CAST_DACT,
ProcessingMethod::CAST_ACT,
};
// Only the Identity and GeLU activations are currently tested
std::vector<ActivationType> Activation_types = {
ActivationType::Identity,
ActivationType::GeLU,
// ActivationType::SiLU,
// ActivationType::ReLU,
// ActivationType::QGeLU,
// ActivationType::SReLU,
};
} // namespace
class FusedCastMXFP8TestSuite : public ::testing::TestWithParam
<std::tuple<ProcessingMethod,
ActivationType,
std::vector<size_t>,
std::pair<size_t, size_t>,
transformer_engine::DType,
transformer_engine::DType,
InputsFillCase>> {};
#define DACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \
switch (OP_FUNC_TYPE) { \
case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \
case ActivationType::GeLU: { constexpr auto OP = &dgelu; { __VA_ARGS__ } } break; \
case ActivationType::SiLU: { constexpr auto OP = &dsilu; { __VA_ARGS__ } } break; \
case ActivationType::ReLU: { constexpr auto OP = &drelu; { __VA_ARGS__ } } break; \
case ActivationType::QGeLU: { constexpr auto OP = &dqgelu; { __VA_ARGS__ } } break; \
case ActivationType::SReLU: { constexpr auto OP = &dsrelu; { __VA_ARGS__ } } break; \
}
#define ACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \
switch (OP_FUNC_TYPE) { \
case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \
case ActivationType::GeLU: { constexpr auto OP = &gelu; { __VA_ARGS__ } } break; \
case ActivationType::SiLU: { constexpr auto OP = &silu; { __VA_ARGS__ } } break; \
case ActivationType::ReLU: { constexpr auto OP = &relu; { __VA_ARGS__ } } break; \
case ActivationType::QGeLU: { constexpr auto OP = &qgelu; { __VA_ARGS__ } } break; \
case ActivationType::SReLU: { constexpr auto OP = &srelu; { __VA_ARGS__ } } break; \
}
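// The switches above bind the runtime ActivationType to a compile-time
// function pointer (constexpr auto OP), so the reference implementation can
// be instantiated once per activation function.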
TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) {
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
using namespace transformer_engine;
using namespace test;
const ProcessingMethod processing_method = std::get<0>(GetParam());
const ActivationType Act_type = std::get<1>(GetParam());
const auto matrix_size = std::get<2>(GetParam());
const auto block_size = std::get<3>(GetParam());
const DType input_type = std::get<4>(GetParam());
const DType output_type = std::get<5>(GetParam());
const InputsFillCase fill_case = std::get<6>(GetParam());
  // Skip the non-activation methods when the activation type is not Identity
if ((processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS)
&& Act_type != ActivationType::Identity) {
GTEST_SKIP();
}
  // Skip the activation methods when the activation type is Identity
if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT
|| processing_method == ProcessingMethod::CAST_DACT
|| processing_method == ProcessingMethod::CAST_ACT) && (Act_type == ActivationType::Identity)) {
GTEST_SKIP();
}
const bool rowwise = block_size.second != 1;
const bool colwise = block_size.first != 1;
if (processing_method == ProcessingMethod::CAST_ACT) {
// Forward activations
ACT_FUNC_SWITCH(Act_type, OP,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
if (block_size.first == 1 || block_size.second == 1) {
performTest_x1<InputType, OutputType, OP>(
processing_method, matrix_size,
rowwise, colwise, fill_case);
} else {
performTest_x2<InputType, OutputType, OP>(
processing_method, matrix_size,
block_size.first, block_size.second, fill_case);
}
);
);
);
} else {
DACT_FUNC_SWITCH(Act_type, OP,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
if (block_size.first == 1 || block_size.second == 1) {
performTest_x1<InputType, OutputType, OP>(
processing_method, matrix_size,
rowwise, colwise, fill_case);
} else {
performTest_x2<InputType, OutputType, OP>(
processing_method, matrix_size,
block_size.first, block_size.second, fill_case);
}
);
);
);
}
}
std::string to_string(const ProcessingMethod method) {
switch (method) {
case ProcessingMethod::CAST_ONLY: return "CAST_ONLY";
case ProcessingMethod::CAST_DBIAS: return "CAST_DBIAS";
case ProcessingMethod::CAST_DBIAS_DACT: return "CAST_DBIAS_DACT";
case ProcessingMethod::CAST_DACT: return "CAST_DACT";
case ProcessingMethod::CAST_ACT: return "CAST_ACT";
default: return "";
}
}
std::string to_string(const ActivationType Act_type) {
switch (Act_type) {
case ActivationType::Identity: return "Identity";
case ActivationType::GeLU: return "GeLU";
case ActivationType::SiLU: return "SiLU";
case ActivationType::ReLU: return "ReLU";
case ActivationType::QGeLU: return "QGeLU";
case ActivationType::SReLU: return "SReLU";
default: return "";
}
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
FusedCastMXFP8TestSuite,
::testing::Combine(
::testing::ValuesIn(processing_methods),
::testing::ValuesIn(Activation_types),
::testing::ValuesIn(matrix_sizes),
::testing::ValuesIn(block_sizes),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(input_scenarios)),
[](const testing::TestParamInfo<FusedCastMXFP8TestSuite::ParamType>& info) {
std::string name = to_string(std::get<0>(info.param)) + "X" +
to_string(std::get<1>(info.param));
const auto& shape = std::get<2>(info.param);
    for (const auto& s : shape) {
name += "X" + std::to_string(s);
}
name += "X" + std::to_string(std::get<3>(info.param).first) +
"X" + std::to_string(std::get<3>(info.param).second) +
"X" + test::typeName(std::get<4>(info.param)) +
"X" + test::typeName(std::get<5>(info.param)) +
"X" + test::caseName(std::get<6>(info.param));
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <algorithm>
#include <cmath>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"
using namespace transformer_engine;
using namespace test;
namespace {
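// MXFP8 reference for the gated SwiGLU kernels. In the dgated case the
// activation and gate halves of a block get separate E8M0 scales: scale_idx
// addresses the activation half and scale_idx_gate the gate half, which is
// offset by cols / block_size_X scale columns.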
template <bool IS_DGATED, typename IType, typename OType>
void scale_block(const IType* grad,
const IType* input,
OType* output,
fp8e8m0* output_scales,
const size_t scale_idx,
const size_t scale_idx_gate,
float& thread_amax,
const size_t i_min,
const size_t i_max,
const size_t j_min,
const size_t j_max,
const size_t cols) {
float block_amax = 0.0f;
float block_amax_gate = 0.0f;
const size_t stride = cols * 2;
// Find the absolute maximum value in the block
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
float silu_elt = static_cast<float>(input[i * stride + j]);
float gate_elt = static_cast<float>(input[i * stride + cols + j]);
float gated_amax_act = 0;
float gated_amax_gate = 0;
if constexpr (IS_DGATED) {
const float grad_elt = static_cast<float>(grad[i * cols + j]);
const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt;
const float after_dgate = silu(silu_elt) * grad_elt;
        gated_amax_act = fabsf(after_dsilu);
        gated_amax_gate = fabsf(after_dgate);
      } else {
        const float after_silu = silu(silu_elt) * gate_elt;
        gated_amax_act = fabsf(after_silu);
}
if (gated_amax_act > block_amax) { block_amax = gated_amax_act; }
if (gated_amax_gate > block_amax_gate) { block_amax_gate = gated_amax_gate; }
}
}
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax *
Quantized_Limits<OType>::max_reciprocal());
const float scale_reciprocal = exp2f_rcp(biased_exponent);
output_scales[scale_idx] = biased_exponent;
float scale_reciprocal_gate = 1;
if constexpr (IS_DGATED) {
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax_gate *
Quantized_Limits<OType>::max_reciprocal());
scale_reciprocal_gate = exp2f_rcp(biased_exponent);
output_scales[scale_idx_gate] = biased_exponent;
}
// Quantize elements in the block
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
float silu_elt = static_cast<float>(input[i * stride + j]);
float gate_elt = static_cast<float>(input[i * stride + cols + j]);
if constexpr (IS_DGATED) {
const float grad_elt = static_cast<float>(grad[i * cols + j]);
const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt;
const float after_dgate = silu(silu_elt) * grad_elt;
output[i * stride + j] = static_cast<OType>(after_dsilu * scale_reciprocal);
output[i * stride + cols + j] = static_cast<OType>(after_dgate *
scale_reciprocal_gate);
} else {
const float after_silu = silu(silu_elt) * gate_elt;
output[i * cols + j] = static_cast<OType>(after_silu * scale_reciprocal);
}
}
}
thread_amax = std::max(thread_amax, block_amax);
thread_amax = std::max(thread_amax, block_amax_gate);
}
template <bool IS_DGATED, typename IType, typename OType>
void compute_ref_x1(const IType* grad,
const IType* input,
OType* output,
fp8e8m0* output_scales,
float& ref_amax,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X,
const size_t scales_stride) {
  const size_t tile_size_Y = std::max<size_t>(32, block_size_Y);
  const size_t tile_size_X = std::max<size_t>(64, block_size_X);
const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y;
const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X;
const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y;
const size_t blocks_per_tile_X = tile_size_X / block_size_X;
float amax = 0;
#pragma omp parallel reduction(max: amax) proc_bind(spread)
{
float thread_amax = 0;
#pragma omp for schedule(static)
for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) {
const size_t tile_Y = t / tiles_num_X;
const size_t tile_X = t % tiles_num_X;
const size_t tile_offset_Y = tile_Y * tile_size_Y;
const size_t tile_offset_X = tile_X * tile_size_X;
for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) {
const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii;
const size_t block_offset_Y = ii * block_size_Y;
const size_t i_min = tile_offset_Y + block_offset_Y;
if (i_min >= rows) continue;
const size_t i_max = std::min(i_min + block_size_Y, rows);
for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) {
const size_t block_idx_X = tile_X * blocks_per_tile_X + jj;
const size_t block_offset_X = jj * block_size_X;
const size_t j_min = tile_offset_X + block_offset_X;
if (j_min >= cols) continue;
const size_t j_max = std::min(j_min + block_size_X, cols);
const size_t mx_scale_idx = block_idx_Y * scales_stride + block_idx_X;
const size_t mx_scale_idx_gate = block_idx_Y * scales_stride + block_idx_X +
cols / block_size_X;
scale_block<IS_DGATED, IType, OType>(
grad, input, output, output_scales, mx_scale_idx, mx_scale_idx_gate,
thread_amax, i_min, i_max, j_min, j_max, cols);
}
}
}
if (thread_amax > amax) {
amax = thread_amax;
}
}
ref_amax = amax;
}
template <bool IS_DGATED, typename IType, typename OType>
void compute_ref_x2(const IType* grad,
const IType* input,
OType* output_rowwise,
OType* output_colwise,
fp8e8m0* scales_rowwise,
fp8e8m0* scales_colwise,
float& ref_amax,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X,
const size_t scales_stride_rowwise,
const size_t scales_stride_colwise) {
compute_ref_x1<IS_DGATED, IType, OType>(
grad, input, output_rowwise, scales_rowwise, ref_amax, rows, cols, 1, block_size_X, scales_stride_rowwise);
compute_ref_x1<IS_DGATED, IType, OType>(
grad, input, output_colwise, scales_colwise, ref_amax, rows, cols, block_size_Y, 1, scales_stride_colwise);
}
/**
 * Scaling along a single dimension (either rows or columns).
 * Produces one set of scaled output data:
 * 1) Scaled rows + row-wise scaling factors
 * OR
 * 2) Scaled columns + column-wise scaling factors
 */
template <bool IS_DGATED, typename IType, typename OType>
void performTest_x1(const size_t rows,
const size_t cols,
const size_t block_size_rows,
const size_t block_size_cols,
InputsFillCase fill_case) {
using namespace test;
using EncodingType = fp32;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
const bool rowwise = (block_size_rows == 1) && (block_size_cols == 32);
const bool colwise = (block_size_rows == 32) && (block_size_cols == 1);
NVTE_CHECK(rowwise || colwise);
Tensor grad("grad", { rows, cols }, itype);
Tensor input("input", { rows, cols * 2 }, itype);
const size_t output_cols = (IS_DGATED ? 2 : 1) * cols;
const std::array<size_t,4> scale_dims = get_scale_tensor_dims(rows, output_cols, block_size_rows,
block_size_cols);
const size_t unpadded_blocks_Y = scale_dims[0];
const size_t unpadded_blocks_X = scale_dims[1];
const size_t blocks_Y = scale_dims[2];
const size_t blocks_X = scale_dims[3];
const size_t scales_stride = blocks_X;
Tensor output("output", std::vector<size_t>{ rows, output_cols }, otype,
rowwise, colwise, NVTE_MXFP8_1D_SCALING);
std::unique_ptr<OType[]> ref_output = std::make_unique<OType[]>(rows * output_cols);
std::unique_ptr<fp8e8m0[]> ref_output_scales = std::make_unique<fp8e8m0[]>(blocks_Y * blocks_X);
for (size_t i = 0; i < blocks_Y * blocks_X; ++i) {
ref_output_scales[i] = 0;
}
// fillCase<EncodingType>(&grad, fill_case);
if constexpr (IS_DGATED) {
fillUniform(&grad);
}
fillUniform(&input);
if constexpr (IS_DGATED) {
nvte_dswiglu(grad.data(), input.data(), output.data(), 0);
} else {
nvte_swiglu(input.data(), output.data(), 0);
}
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
float ref_amax = 0;
compute_ref_x1<IS_DGATED, IType, OType>(grad.rowwise_cpu_dptr<IType>(),
input.rowwise_cpu_dptr<IType>(),
ref_output.get(),
ref_output_scales.get(),
ref_amax,
rows,
cols,
block_size_rows,
block_size_cols,
scales_stride);
auto [atol, rtol] = getTolerances(otype);
compareResults("output", output, ref_output.get(), rowwise, atol, rtol);
const uint8_t * const gpu_scales_ptr = rowwise
? output.rowwise_cpu_scale_inv_ptr<fp8e8m0>()
: output.columnwise_cpu_scale_inv_ptr<fp8e8m0>();
  compare_e8m0_scaling_factors(rowwise ? "rowwise scales" : "colwise scales",
                               gpu_scales_ptr, ref_output_scales.get(),
                               unpadded_blocks_Y, unpadded_blocks_X, scales_stride);
}
/**
 * Scaling along both dimensions (rows and columns).
 * Produces two sets of scaled output data:
 * 1) Scaled rows + row-wise scaling factors
 * AND
 * 2) Scaled columns + column-wise scaling factors
 */
template <bool IS_DGATED, typename IType, typename OType>
void performTest_x2(const size_t rows,
const size_t cols,
const size_t block_size_rows,
const size_t block_size_cols,
InputsFillCase fill_case) {
using namespace test;
using EncodingType = fp32;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
Tensor grad("grad", { rows, cols }, itype);
Tensor input("input", { rows, cols * 2 }, itype);
const size_t output_cols = (IS_DGATED ? 2 : 1) * cols;
const std::array<size_t,4> scale_dims_rowwise = get_scale_tensor_dims(rows, output_cols, 1, 32);
const std::array<size_t,4> scale_dims_colwise = get_scale_tensor_dims(rows, output_cols, 32, 1);
const size_t unpadded_blocks_Y_rowwise = scale_dims_rowwise[0];
const size_t unpadded_blocks_X_rowwise = scale_dims_rowwise[1];
const size_t blocks_Y_rowwise = scale_dims_rowwise[2];
const size_t blocks_X_rowwise = scale_dims_rowwise[3];
const size_t scales_stride_rowwise = blocks_X_rowwise;
const size_t unpadded_blocks_Y_colwise = scale_dims_colwise[0];
const size_t unpadded_blocks_X_colwise = scale_dims_colwise[1];
const size_t blocks_Y_colwise = scale_dims_colwise[2];
const size_t blocks_X_colwise = scale_dims_colwise[3];
const size_t scales_stride_colwise = blocks_X_colwise;
Tensor output("output", std::vector<size_t>{ rows, output_cols }, otype,
true, true, NVTE_MXFP8_1D_SCALING);
std::unique_ptr<OType[]> ref_output_rowwise = std::make_unique<OType[]>(rows * output_cols);
std::unique_ptr<OType[]> ref_output_colwise = std::make_unique<OType[]>(rows * output_cols);
std::unique_ptr<fp8e8m0[]> ref_scales_rowwise = std::make_unique<fp8e8m0[]>(blocks_Y_rowwise * blocks_X_rowwise);
std::unique_ptr<fp8e8m0[]> ref_scales_colwise = std::make_unique<fp8e8m0[]>(blocks_Y_colwise * blocks_X_colwise);
for (size_t i = 0; i < blocks_Y_rowwise * blocks_X_rowwise; ++i) {
ref_scales_rowwise[i] = 0;
}
for (size_t i = 0; i < blocks_Y_colwise * blocks_X_colwise; ++i) {
ref_scales_colwise[i] = 0;
}
// fillCase<EncodingType>(&grad, fill_case);
if constexpr (IS_DGATED) {
fillUniform(&grad);
}
fillUniform(&input);
if constexpr (IS_DGATED) {
nvte_dswiglu(grad.data(), input.data(), output.data(), 0);
} else {
nvte_swiglu(input.data(), output.data(), 0);
}
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
float ref_amax = 0;
compute_ref_x2<IS_DGATED, IType, OType>(grad.rowwise_cpu_dptr<IType>(),
input.rowwise_cpu_dptr<IType>(),
ref_output_rowwise.get(),
ref_output_colwise.get(),
ref_scales_rowwise.get(),
ref_scales_colwise.get(),
ref_amax,
rows,
cols,
block_size_rows,
block_size_cols,
scales_stride_rowwise,
scales_stride_colwise);
auto [atol, rtol] = getTolerances(otype);
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol);
compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol);
compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
unpadded_blocks_X_rowwise, scales_stride_rowwise);
compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise);
}
std::vector<std::pair<size_t, size_t>> matrix_sizes = {
{1, 32},
{16, 64},
{65, 96},
{128, 128},
{256, 256},
{993, 512},
{768, 1024},
{65536, 128},
{16384, 1632},
};
std::vector<std::pair<size_t, size_t>> block_sizes = {
{1, 32},
{32, 1},
{32, 32},
};
std::vector<InputsFillCase> input_scenarios = {
InputsFillCase::uniform,
// InputsFillCase::zeros,
// InputsFillCase::zero_to_minNorm,
// InputsFillCase::minNorm_to_maxNorm,
// InputsFillCase::maxNorm_to_inf
};
std::vector<bool> is_dgated_op = {
true,
false
};
} // namespace
class CastMXFP8_GatedActTestSuite : public ::testing::TestWithParam
<std::tuple<std::pair<size_t, size_t>,
std::pair<size_t, size_t>,
transformer_engine::DType,
transformer_engine::DType,
InputsFillCase,
bool>> {};
TEST_P(CastMXFP8_GatedActTestSuite, TestCastMXFP8Swiglu) {
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
using namespace transformer_engine;
using namespace test;
const auto matrix_size = std::get<0>(GetParam());
const auto block_size = std::get<1>(GetParam());
const DType input_type = std::get<2>(GetParam());
const DType output_type = std::get<3>(GetParam());
const InputsFillCase fill_case = std::get<4>(GetParam());
const bool IS_DGATED = std::get<5>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OType,
if (block_size.first == 1 || block_size.second == 1) {
if (IS_DGATED) {
performTest_x1<true, IType, OType>(matrix_size.first, matrix_size.second,
block_size.first, block_size.second, fill_case);
} else {
performTest_x1<false, IType, OType>(matrix_size.first, matrix_size.second,
block_size.first, block_size.second, fill_case);
}
} else {
if (IS_DGATED) {
performTest_x2<true, IType, OType>(matrix_size.first, matrix_size.second,
block_size.first, block_size.second, fill_case);
} else {
performTest_x2<false, IType, OType>(matrix_size.first, matrix_size.second,
block_size.first, block_size.second, fill_case);
}
}
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
CastMXFP8_GatedActTestSuite,
::testing::Combine(
::testing::ValuesIn(matrix_sizes),
::testing::ValuesIn(block_sizes),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(input_scenarios),
::testing::ValuesIn(is_dgated_op)),
[](const testing::TestParamInfo<CastMXFP8_GatedActTestSuite::ParamType>& info) {
std::string name = std::to_string(std::get<0>(info.param).first) + "X" +
std::to_string(std::get<0>(info.param).second) + "X" +
std::to_string(std::get<1>(info.param).first) + "X" +
std::to_string(std::get<1>(info.param).second) + "X" +
test::typeName(std::get<2>(info.param)) + "X" +
test::typeName(std::get<3>(info.param)) + "X" +
test::caseName(std::get<4>(info.param)) + "X" +
(std::get<5>(info.param) ? "DGATED" : "GATED");
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <limits>
#include <memory>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
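// Reference cast + transpose: one pass over the input writes the cast data
// twice, in row-major layout (output_c) and transposed (output_t), while
// tracking the global amax.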
template <typename InputType, typename OutputType>
void compute_ref(const InputType *data, OutputType *output_c, OutputType *output_t,
const size_t N, const size_t H,
float *amax, float scale) {
using compute_t = float;
  compute_t current_max = -std::numeric_limits<compute_t>::infinity();
for (size_t i = 0; i < N; ++i) {
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
current_max = fmaxf(current_max, fabsf(current));
output_c[i * H + j] = OutputType(scale * current);
output_t[j * N + i] = OutputType(scale * current);
}
}
*amax = current_max;
}
// delayed tensor scaling test
template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H) {
using namespace test;
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
Tensor input("input", { N, H }, itype);
Tensor output("output", { N, H }, otype, true, true);
std::unique_ptr<OutputType[]> ref_output_c = std::make_unique<OutputType[]>(N * H);
std::unique_ptr<OutputType[]> ref_output_t = std::make_unique<OutputType[]>(N * H);
fillUniform(&input);
setRandomScale(&output);
nvte_quantize(input.data(), output.data(), 0);
float ref_amax;
compute_ref<InputType, OutputType>(input.rowwise_cpu_dptr<InputType>(), ref_output_c.get(),
ref_output_t.get(), N, H, &ref_amax,
output.scale());
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output.scale();
compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output, ref_output_c.get(), true, atol, rtol);
compareResults("output_t", output, ref_output_t.get(), false, atol, rtol);
}
std::vector<std::pair<size_t, size_t>> test_cases = {{2048, 12288},
{768, 1024},
{256, 65536},
{65536, 128},
{256, 256},
{120, 2080},
{8, 8},
{1, 3221}, // 456th prime
{2333, 1}, // 345th prime
{1481, 677}}; // 234th and 123rd primes
} // namespace
class CTTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
transformer_engine::DType,
std::pair<size_t, size_t>>> {};
TEST_P(CTTestSuite, TestCastTranspose) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
// delayed tensor scaling
performTest<InputType, OutputType>(size.first, size.second);
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
CTTestSuite,
::testing::Combine(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::ValuesIn(test::all_fp_types),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<CTTestSuite::ParamType>& info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param)) + "X" +
std::to_string(std::get<2>(info.param).first) + "X" +
std::to_string(std::get<2>(info.param).second);
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/recipe.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename InputType, typename OutputType>
void compute_ref(const InputType *data, OutputType *output_c, OutputType *output_t,
const size_t N, const size_t H,
float *amax, float scale) {
using compute_t = float;
compute_t current_max = -1e100;
for (size_t i = 0; i < N; ++i) {
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
current_max = fmaxf(current_max, fabsf(current));
output_c[i * H + j] = OutputType(scale * current);
output_t[j * N + i] = OutputType(scale * current);
}
}
}
template <typename InputType, typename OutputType>
void compute_amax_scale_ref(const InputType *data,
const size_t N, const size_t H,
float *amax_ptr, float *scale_ptr, float* scale_inv_ptr,
float max_fp8, float epsilon) {
using compute_t = float;
compute_t current_max = -1e100;
for (size_t i = 0; i < N; ++i) {
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
current_max = fmaxf(current_max, fabsf(current));
}
}
*amax_ptr = current_max;
// compute scale from amax
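// (scale = max_fp8 / max(amax, epsilon)). Worked example with hypothetical
// values: for an E4M3 output (max_fp8 = 448) and amax = 3.5,
// scale = 448 / 3.5 = 128 and scale_inv = 1 / 128.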
float clamp_amax = current_max;
if (current_max <= epsilon){
clamp_amax = epsilon;
}
float scale = 1.f;
float scale_inv = 1.f;
if (isinf(clamp_amax) || clamp_amax == 0.f) {
*scale_ptr = scale;
*scale_inv_ptr = scale_inv;
return;
}
// Use IEEE division on the CPU
scale = max_fp8 / clamp_amax;
// The amax is so small that the scale overflows to infinity in FP32,
// i.e. the scale is not representable in FP32.
if (isinf(scale)) {
scale = std::numeric_limits<float>::max();
}
if (isnan(scale)) {
scale = 1.f;
}
scale_inv = 1.0f / scale;
*scale_ptr = scale;
*scale_inv_ptr = scale_inv;
}
// current tensor scaling test
template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H) {
using namespace test;
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
bool is_out_fp8 = isFp8Type(otype);
// find out max fp8 value
float max_fp8;
if (is_out_fp8){
switch (otype) {
case DType::kFloat8E5M2: {
max_fp8 = Quantized_Limits<fp8e5m2>::max();
} break;
case DType::kFloat8E4M3: {
max_fp8 = Quantized_Limits<fp8e4m3>::max();
} break;
default:
NVTE_ERROR("Invalid type.");
}
}
Tensor input("input", { N, H }, itype);
Tensor output("output", { N, H }, otype, true, true);
std::unique_ptr<OutputType[]> ref_output_c = std::make_unique<OutputType[]>(N * H);
std::unique_ptr<OutputType[]> ref_output_t = std::make_unique<OutputType[]>(N * H);
fillUniform(&input);
// compute amax
float amax_to_check = 0.0f;
if (is_out_fp8){
nvte_compute_amax(input.data(), output.data(), 0);
QuantizationConfigWrapper config;
nvte_compute_scale_from_amax(output.data(), config, 0);
// Avoid the atomic amax update in the CUDA cast kernel: with current
// per-tensor scaling the amax has already been computed above.
amax_to_check = output.amax();
output.set_tensor_amax_nullptr();
}
nvte_quantize(input.data(), output.data(), 0);
float ref_amax;
float ref_scale;
float ref_scale_inv;
if (is_out_fp8){
compute_amax_scale_ref<InputType, OutputType>(input.rowwise_cpu_dptr<InputType>(),
N, H, &ref_amax, &ref_scale, &ref_scale_inv, max_fp8, 0.0f);
}
compute_ref<InputType, OutputType>(input.rowwise_cpu_dptr<InputType>(), ref_output_c.get(),
ref_output_t.get(), N, H, nullptr,
is_out_fp8 ? output.scale() : 1.0f);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (isFp8Type(otype)) {
auto [atol_fp32, rtol_fp32] = getTolerances(DType::kFloat32);
compareResults("amax", amax_to_check, ref_amax, 0.0f, rtol_fp32);
compareResults("scale", output.scale(), ref_scale, 0.0f, rtol_fp32);
compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, 0.0f, rtol_fp32);
compareResults("scale_inv_columnwise", output.columnwise_cpu_scale_inv_ptr<float>()[0], ref_scale_inv, 0.0f, rtol_fp32);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output, ref_output_c.get(), true, 0.0f, rtol);
compareResults("output_t", output, ref_output_t.get(), false, 0.0f, rtol);
}
std::vector<std::pair<size_t, size_t>> test_cases = {{2048, 12288},
{768, 1024},
{256, 65536},
{65536, 128},
{256, 256},
{120, 2080},
{8, 8},
{1, 3221}, // 456th prime
{2333, 1}, // 345th prime
{1481, 677}}; // 234th and 123rd primes
} // namespace
class CTCSTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
transformer_engine::DType,
std::pair<size_t, size_t>>> {};
TEST_P(CTCSTestSuite, TestCastTransposeCS) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
// current tensor scaling
performTest<InputType, OutputType>(size.first, size.second);
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
CTCSTestSuite,
::testing::Combine(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::ValuesIn(test::all_fp_types),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<CTCSTestSuite::ParamType>& info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param)) + "X" +
std::to_string(std::get<2>(info.param).first) + "X" +
std::to_string(std::get<2>(info.param).second);
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <memory>
#include <iomanip>
#include <iostream>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename IT, typename OT, typename CT>
void compute_ref_cast_transpose_dbias(const IT *input_h,
const CT scale,
OT *output_c_h,
OT *output_t_h,
CT *amax_h,
IT *dbias_h,
const size_t N,
const size_t H) {
CT amax = 0.;
std::vector<CT> acc_dbias(H, 0.);
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT elt = static_cast<CT>(input_h[i * H + j]);
// update amax
amax = std::abs(elt) > amax ? std::abs(elt) : amax;
output_c_h[i * H + j] = static_cast<OT>(scale * elt);
output_t_h[j * N + i] = static_cast<OT>(scale * elt);
// dbias
acc_dbias[j] += elt;
}
}
*amax_h = amax;
for (size_t i = 0; i < H; i++) {
dbias_h[i] = static_cast<IT>(acc_dbias[i]);
}
}
template <typename IType, typename OType>
void performTest(const size_t N, const size_t H) {
using namespace test;
using CType = fp32;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
Tensor input("input", {N, H}, itype);
Tensor output("output", {N, H}, otype, true, true);
// dbias has the same data type as the "output grad"
Tensor dbias("dbias", {H}, itype);
fillUniform(&input);
setRandomScale(&output);
std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N*H);
std::unique_ptr<OType[]> ref_output_t = std::make_unique<OType[]>(N*H);
std::unique_ptr<IType[]> ref_output_dbias = std::make_unique<IType[]>(H);
CType ref_amax;
compute_ref_cast_transpose_dbias(input.rowwise_cpu_dptr<IType>(),
output.scale(),
ref_output_c.get(),
ref_output_t.get(),
&ref_amax,
ref_output_dbias.get(),
N, H);
Tensor workspace;
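// The first nvte_quantize_dbias call with an empty workspace only queries the
// required workspace shape/dtype; after allocating it, the second call below
// performs the actual computation.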
nvte_quantize_dbias(input.data(),
output.data(),
dbias.data(),
workspace.data(),
0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_quantize_dbias(input.data(),
output.data(),
dbias.data(),
workspace.data(),
0);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output.scale();
compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output, ref_output_c.get(), true, atol, rtol);
compareResults("output_t", output, ref_output_t.get(), false, atol, rtol);
auto [atol_dbias, rtol_dbias] = getTolerances(itype);
rtol_dbias *= 4;
compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
}
std::vector<std::pair<size_t, size_t>> test_cases = {{64, 400},
{2048, 12288},
{768, 1024},
{256, 65536},
{65536, 128},
{256, 256}};
} // namespace
class CTDBiasTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
transformer_engine::DType,
std::pair<size_t, size_t>>> {};
TEST_P(CTDBiasTestSuite, TestCTDBias) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>(size.first, size.second);
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
CTDBiasTestSuite,
::testing::Combine(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::ValuesIn(test::all_fp_types),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<CTDBiasTestSuite::ParamType>& info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param)) + "X" +
std::to_string(std::get<2>(info.param).first) + "X" +
std::to_string(std::get<2>(info.param).second);
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <memory>
#include <iomanip>
#include <iostream>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/transpose.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
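// dgelu below is the analytic derivative of the tanh approximation of GELU,
//   GELU(x) ~= 0.5 * x * (1 + tanh(0.79788456 * (x + 0.044715 * x^3))),
// where 0.79788456 ~= sqrt(2/pi) and 0.1070322243 = 3 * 0.044715 * 0.79788456.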
template <typename CType>
CType dgelu(const CType cval) {
const CType tanh_out = tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval));
return 0.5f * cval * ((1.f - tanh_out * tanh_out) *
(0.79788456f + 0.1070322243f * cval * cval)) +
0.5f * (1.f + tanh_out);
}
template <typename IT, typename OT, typename CT>
void compute_ref_cast_transpose_dbias_dgelu(const IT *input,
const IT *gelu_input,
const CT scale,
OT *output_c,
OT *output_t,
CT *amax_h,
IT *dbias,
const size_t N,
const size_t H) {
CT amax = 0.;
std::vector<CT> acc_dbias(H, 0.);
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT elt = static_cast<CT>(input[i * H + j]);
const CT gelu_in = static_cast<CT>(gelu_input[i * H + j]);
elt = dgelu(gelu_in) * elt;
// update amax
amax = std::abs(elt) > amax ? std::abs(elt) : amax;
output_c[i * H + j] = static_cast<OT>(scale * elt);
output_t[j * N + i] = static_cast<OT>(scale * elt);
// dbias
acc_dbias[j] += elt;
}
}
*amax_h = amax;
for (size_t i = 0; i < H; i++) {
dbias[i] = static_cast<IT>(acc_dbias[i]);
}
}
template <typename IType, typename OType>
void performTest(const size_t N, const size_t H) {
using namespace test;
using CType = fp32;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
Tensor input("input", {N, H}, itype);
Tensor gelu_input("gelu_input", {N, H}, itype);
Tensor output("output", {N, H}, otype, true, true);
// dbias has the same data type as the "output grad"
Tensor dbias("dbias", {H}, itype);
fillUniform(&input);
fillUniform(&gelu_input);
setRandomScale(&output);
std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N*H);
std::unique_ptr<OType[]> ref_output_t = std::make_unique<OType[]>(N*H);
std::unique_ptr<IType[]> ref_output_dbias = std::make_unique<IType[]>(H);
CType ref_amax;
compute_ref_cast_transpose_dbias_dgelu(input.rowwise_cpu_dptr<IType>(),
gelu_input.rowwise_cpu_dptr<IType>(),
output.scale(),
ref_output_c.get(),
ref_output_t.get(),
&ref_amax,
ref_output_dbias.get(),
N, H);
Tensor workspace;
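// The first nvte_cast_transpose_dbias_dgelu call with an empty workspace only
// queries the required workspace shape/dtype; the second call below performs
// the actual computation.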
nvte_cast_transpose_dbias_dgelu(input.data(),
gelu_input.data(),
output.data(),
dbias.data(),
workspace.data(),
0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_cast_transpose_dbias_dgelu(input.data(),
gelu_input.data(),
output.data(),
dbias.data(),
workspace.data(),
0);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output.scale();
compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output, ref_output_c.get(), true, atol, rtol);
compareResults("output_t", output, ref_output_t.get(), false, atol, rtol);
auto [atol_dbias, rtol_dbias] = getTolerances(itype);
rtol_dbias *= 4;
compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
}
std::vector<std::pair<size_t, size_t>> test_cases = {{64, 400},
{2048, 12288},
{768, 1024},
{256, 65536},
{65536, 128},
{256, 256}};
} // namespace
class CTDBiasDGeluTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
transformer_engine::DType,
std::pair<size_t,
size_t>>> {};
TEST_P(CTDBiasDGeluTestSuite, TestCTDBiasDgelu) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>(size.first, size.second);
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
CTDBiasDGeluTestSuite,
::testing::Combine(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::ValuesIn(test::all_fp_types),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<CTDBiasDGeluTestSuite::ParamType>& info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param)) + "X" +
std::to_string(std::get<2>(info.param).first) + "X" +
std::to_string(std::get<2>(info.param).second);
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/transpose.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
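// Tanh approximations of GELU and its derivative, matching the kernel
// constants: 0.79788456 ~= sqrt(2/pi), 0.03567741 ~= 0.044715 * sqrt(2/pi),
// and 0.1070322243 = 3 * 0.044715 * 0.79788456.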
template <typename CType, typename IType>
inline CType gelu(const IType val) {
CType cval = val;
return cval * (0.5f + 0.5f * tanhf(cval * (0.79788456f + 0.03567741f * cval * cval)));
}
template <typename CType, typename IType>
inline CType dgelu(const IType val) {
CType cval = val;
const CType tanh_out = tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval));
return 0.5f * cval * ((1.f - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * cval * cval)) +
0.5f * (1.f + tanh_out);
}
template <typename IT, typename OT, typename CT>
void compute_ref_cast_transpose_dgated_gelu(const IT *grad_h, const IT *input_h, const CT scale,
OT *output_c_h, OT *output_t_h, CT *amax_h,
const size_t N, const size_t H) {
CT amax = 0.;
const size_t col = H * 2;
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT grad_elt = CT(grad_h[i * H + j]);
CT gelu_elt = CT(input_h[i * col + j]);
CT gate_elt = CT(input_h[i * col + H + j]);
CT after_dgelu = dgelu<CT, CT>(gelu_elt) * grad_elt * gate_elt;
CT after_dgate = grad_elt * gelu<CT, CT>(gelu_elt);
amax = std::abs(after_dgelu) > amax ? std::abs(after_dgelu) : amax;
amax = std::abs(after_dgate) > amax ? std::abs(after_dgate) : amax;
output_c_h[i * col + j] = static_cast<OT>(scale * after_dgelu);
output_c_h[i * col + H + j] = static_cast<OT>(scale * after_dgate);
output_t_h[j * N + i] = static_cast<OT>(scale * after_dgelu);
output_t_h[(j + H) * N + i] = static_cast<OT>(scale * after_dgate);
}
}
*amax_h = amax;
}
template <typename IType, typename OType>
void performTest(const size_t N, const size_t H) {
using namespace test;
using CType = fp32;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
Tensor grad("grad", {N, H}, itype);
Tensor input("input", {N, H * 2}, itype);
Tensor output("output", {N, H * 2}, otype, true, true);
fillUniform(&grad);
fillUniform(&input);
setRandomScale(&output);
std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N * H * 2);
std::unique_ptr<OType[]> ref_output_t = std::make_unique<OType[]>(N * H * 2);
nvte_dgeglu_cast_transpose(grad.data(), input.data(), output.data(), 0);
CType ref_amax;
compute_ref_cast_transpose_dgated_gelu(grad.rowwise_cpu_dptr<IType>(), input.rowwise_cpu_dptr<IType>(),
output.scale(), ref_output_c.get(), ref_output_t.get(),
&ref_amax, N, H);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output.scale();
compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output, ref_output_c.get(), true, atol, rtol);
compareResults("output_t", output, ref_output_t.get(), false, atol, rtol);
}
std::vector<std::pair<size_t, size_t>> test_cases = {{64, 400}, {4096, 2048}, {768, 2816},
{256, 5120}, {128, 10240}, {256, 256}};
} // namespace
class DGeGLUCTTestSuite
: public ::testing::TestWithParam<std::tuple<
transformer_engine::DType, transformer_engine::DType, std::pair<size_t, size_t>>> {};
TEST_P(DGeGLUCTTestSuite, TestDGeGLUCT) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
output_type, OutputType, performTest<InputType, OutputType>(size.first, size.second);););
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest, DGeGLUCTTestSuite,
::testing::Combine(::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E5M2, DType::kFloat8E4M3),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<DGeGLUCTTestSuite::ParamType> &info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param)) + "X" +
std::to_string(std::get<2>(info.param).first) + "X" +
std::to_string(std::get<2>(info.param).second);
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/softmax.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
using compute_t = float;
template <typename Type>
void compute_single_head_fwd(
Type *softmax_out,
const Type *data_in,
compute_t *buff,
const float scaling_factor,
const int rows,
const int cols)
{
for (int i = 0; i < rows; ++i) {
size_t offset = i * cols;
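// Bottom-right aligned causal mask: row i of the (rows x cols) score matrix
// attends to the first (i + cols - rows + 1) keys.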
const int masked_elements = i + cols - rows + 1;
compute_t max_value = static_cast<compute_t>(-10'000.f);
for (int j = 0; j < masked_elements; ++j) {
compute_t tmp = scaling_factor * static_cast<compute_t>(data_in[offset + j]);
buff[offset + j] = tmp;
max_value = std::max(max_value, tmp);
}
compute_t accumulator = static_cast<compute_t>(0.f);
for (int j = 0; j < masked_elements; ++j) {
compute_t tmp = std::exp(buff[offset + j] - max_value);
buff[offset + j] = tmp;
accumulator += tmp;
}
for (int j = 0; j < cols; ++j) {
if (j < masked_elements) {
compute_t tmp = buff[offset + j] / accumulator;
softmax_out[offset + j] = static_cast<Type>(tmp);
} else {
softmax_out[offset + j] = static_cast<Type>(0.f);
}
}
}
}
template <typename Type>
void compute_single_head_bwd(
Type *grad_out,
const Type *grad_in,
const Type *softmax_in,
compute_t *buff,
const float scaling_factor,
const int batches,
const int heads,
const int rows,
const int cols)
{
for (int i = 0; i < rows; ++i) {
size_t offset = i * cols;
const int masked_elements = i + cols - rows + 1;
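// Softmax backward over the unmasked prefix:
//   dx_j = scale * s_j * (dz_j - sum_k(s_k * dz_k))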
compute_t accumulator = static_cast<compute_t>(0.f);
for (int j = 0; j < masked_elements; ++j) {
compute_t tmp = static_cast<compute_t>(softmax_in[offset + j])
* static_cast<compute_t>(grad_in[offset + j]);
buff[offset + j] = tmp;
accumulator += tmp;
}
for (int j = 0; j < cols; ++j) {
if (j < masked_elements) {
compute_t tmp = buff[offset + j]
- static_cast<compute_t>(softmax_in[offset + j]) * accumulator;
grad_out[offset + j] = static_cast<Type>(scaling_factor * tmp);
} else {
grad_out[offset + j] = static_cast<Type>(0.f);
}
}
}
}
template <typename Type>
void compute_fwd_ref(
Type *softmax_out,
const Type *data_in,
compute_t *buff,
const float scaling_factor,
const int batches,
const int heads,
const int rows,
const int cols)
{
size_t head_size = rows * cols;
size_t batch_size = heads * head_size;
for (int b = 0; b < batches; ++b) {
for (int h = 0; h < heads; ++h) {
size_t offset = b * batch_size + h * head_size;
compute_single_head_fwd(softmax_out + offset, data_in + offset,
buff + offset, scaling_factor, rows, cols);
}
}
}
template <typename Type>
void compute_bwd_ref(
Type *grad_out,
const Type *grad_in,
const Type *softmax_in,
compute_t *buff,
const float scaling_factor,
const int batches,
const int heads,
const int rows,
const int cols)
{
size_t head_size = rows * cols;
size_t batch_size = heads * head_size;
for (int b = 0; b < batches; ++b) {
for (int h = 0; h < heads; ++h) {
size_t offset = b * batch_size + h * head_size;
compute_single_head_bwd(grad_out + offset, grad_in + offset, softmax_in + offset,
buff + offset, scaling_factor, batches, heads, rows, cols);
}
}
}
// Query Sequence Length = rows
// Key Sequence Length = cols
template <typename Type>
void performTest(
const size_t batches,
const size_t heads,
const size_t rows,
const size_t cols,
float scaling_factor)
{
using namespace test;
DType itype = TypeInfo<Type>::dtype;
Tensor data_in("data_in", { batches, heads, rows, cols }, itype);
Tensor softmax_out("softmax_out", { batches, heads, rows, cols }, itype);
Tensor softmax_in("softmax_in", { batches, heads, rows, cols }, itype);
Tensor grads_in("grads_in", { batches, heads, rows, cols }, itype);
Tensor grads_out("grads_out", { batches, heads, rows, cols }, itype);
const size_t elements_total = batches * heads * rows * cols;
std::unique_ptr<Type[]> softmax_out_ref = std::make_unique<Type[]>(elements_total);
std::unique_ptr<Type[]> grads_out_ref = std::make_unique<Type[]>(elements_total);
std::unique_ptr<compute_t[]> compute_buffer = std::make_unique<compute_t[]>(elements_total);
fillUniform(&data_in);
fillUniform(&softmax_in);
fillUniform(&grads_in);
nvte_scaled_aligned_causal_masked_softmax_forward(
data_in.data(), softmax_out.data(), scaling_factor, 0);
nvte_scaled_aligned_causal_masked_softmax_backward(
grads_in.data(), softmax_in.data(), grads_out.data(), scaling_factor, 0);
// Reference implementations
compute_fwd_ref(softmax_out_ref.get(), data_in.rowwise_cpu_dptr<Type>(),
compute_buffer.get(), scaling_factor, batches, heads, rows, cols);
compute_bwd_ref(grads_out_ref.get(), grads_in.rowwise_cpu_dptr<Type>(), softmax_in.rowwise_cpu_dptr<Type>(),
compute_buffer.get(), scaling_factor, batches, heads, rows, cols);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol, rtol] = getTolerances(itype);
if (itype == DType::kBFloat16) {
atol = 1e-3;
}
compareResults("softmax_fwd", softmax_out, softmax_out_ref.get(), true, atol, rtol);
compareResults("softmax_bwd", grads_out, grads_out_ref.get(), true, atol, rtol);
}
// [Batches, Attention Heads, Query Sequence Length, Key Sequence Length, Scaling Factor]
std::vector<std::tuple<size_t, size_t, size_t, size_t, float>> test_cases = {
{ 1, 1, 1, 16, -1.0f},
{ 1, 2, 17, 32, 0.8f},
{ 2, 1, 37, 112, 1.0f},
{ 2, 4, 127, 128, -0.2f},
{ 8, 6, 128, 256, 1.3f},
{ 1, 4, 270, 256, 0.8f},
{ 2, 2, 512, 512, -1.5f},
{ 1, 2, 819, 1024, 2.1f},
{ 1, 2, 281, 1024, 0.2f},
{ 1, 2, 277, 1024, -2.1f},
{ 1, 2, 127, 1024, 1.1f},
{ 2, 2, 107, 2048, 0.4f},
{ 2, 1, 103, 2048, -3.0f},
{ 2, 2, 101, 2048, 2.6f},
{ 1, 1, 1024, 4096, 0.6f},
{ 1, 2, 61, 4096, 0.6f},
{ 1, 2, 59, 4096, -4.9f},
{ 1, 2, 53, 4096, 3.5f},
{ 1, 1, 37, 8192, 0.7f},
{ 1, 1, 31, 8192, -5.8f},
{ 1, 1, 29, 8192, 4.4f},
{ 1, 1, 23, 12288, 0.8f},
{ 1, 1, 19, 12288, -6.7f},
{ 1, 1, 17, 12288, 3.3f},
{ 1, 1, 13, 16384, 0.9f},
{ 1, 1, 11, 16384, -7.6f},
{ 1, 1, 7, 16384, 6.2f}};
} // namespace
class CausalSoftmaxTestSuite
: public ::testing::TestWithParam<std::tuple<
transformer_engine::DType,
std::tuple<size_t, size_t, size_t, size_t, float>>> {};
TEST_P(CausalSoftmaxTestSuite, TestCausalSoftmax) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const auto size = std::get<1>(GetParam());
const size_t batches = std::get<0>(size);
const size_t heads = std::get<1>(size);
const size_t query_seq_len = std::get<2>(size);
const size_t key_seq_len = std::get<3>(size);
const float scaling_factor = std::get<4>(size);
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
performTest<InputType>(batches, heads, query_seq_len, key_seq_len, scaling_factor);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
CausalSoftmaxTestSuite,
::testing::Combine(
::testing::Values(DType::kFloat16, DType::kBFloat16),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<CausalSoftmaxTestSuite::ParamType>& info) {
const auto size = std::get<1>(info.param);
const size_t batches = std::get<0>(size);
const size_t heads = std::get<1>(size);
const size_t query_seq_len = std::get<2>(size);
const size_t key_seq_len = std::get<3>(size);
std::string scaling_factor = std::to_string(std::get<4>(size));
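// GTest parameter names must be alphanumeric, so encode '-' as 'N' and '.' as 'p'.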
for (char& c : scaling_factor) {
if (c == '-') { c = 'N'; }
if (c == '.') { c = 'p'; }
}
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
std::to_string(batches) + "X" +
std::to_string(heads) + "X" +
std::to_string(query_seq_len) + "X" +
std::to_string(key_seq_len) + "X" +
scaling_factor;
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <limits>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/activation.h>
#include "../test_common.h"
#include "transformer_engine/transformer_engine.h"
using namespace transformer_engine;
using namespace test;
namespace {
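// MXFP8 1D scaling: each 32-element block shares an FP8 E8M0 scale, a biased
// 8-bit exponent decoding to 2^(exponent - FP32_EXPONENT_BIAS), i.e. 2^(e - 127).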
template <typename InputType, typename OutputType>
void dequantize_block(const InputType* input,
OutputType* output,
fp8e8m0* scales,
const size_t scale_idx,
const size_t i_min,
const size_t i_max,
const size_t j_min,
const size_t j_max,
const size_t cols)
{
const fp8e8m0 biased_exponent = scales[scale_idx];
const float block_scale = exp2f(static_cast<float>(biased_exponent) - FP32_EXPONENT_BIAS);
const float elem_scale = block_scale;
// Dequantize elements in the block
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
const float elt = static_cast<float>(input[idx]);
output[idx] = static_cast<OutputType>(elt * elem_scale);
}
}
}
template <typename InputType, typename OutputType>
void compute_ref_x1(const InputType* input,
OutputType* output,
fp8e8m0* scales,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X,
const size_t scales_stride)
{
const size_t blocks_Y = (rows + block_size_Y - 1) / block_size_Y;
const size_t blocks_X = (cols + block_size_X - 1) / block_size_X;
for (size_t ii = 0; ii < blocks_Y; ++ii) {
const size_t i_min = ii * block_size_Y;
const size_t i_max = std::min((ii + 1) * block_size_Y, rows);
for (size_t jj = 0; jj < blocks_X; ++jj) {
const size_t j_min = jj * block_size_X;
const size_t j_max = std::min((jj + 1) * block_size_X, cols);
const size_t scale_idx = ii * scales_stride + jj;
dequantize_block<InputType, OutputType>(
input, output, scales, scale_idx, i_min, i_max, j_min, j_max, cols);
}
}
}
template <typename InputType, typename OutputType>
void compute_ref_x2(const InputType* input,
OutputType* output_rowwise,
OutputType* output_colwise,
fp8e8m0* scales_rowwise,
fp8e8m0* scales_colwise,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X,
const size_t scales_stride_rowwise,
const size_t scales_stride_colwise)
{
compute_ref_x1<InputType, OutputType>(input, output_rowwise, scales_rowwise, rows, cols, 1, block_size_X, scales_stride_rowwise);
compute_ref_x1<InputType, OutputType>(input, output_colwise, scales_colwise, rows, cols, block_size_Y, 1, scales_stride_colwise);
}
void generate_scales(fp8e8m0 * const scales_ref,
fp8e8m0 * const scales,
const size_t blocks_num,
std::mt19937& gen,
std::uniform_int_distribution<fp8e8m0> dis)
{
for (size_t i = 0; i < blocks_num; ++i) {
const fp8e8m0 val = dis(gen);
scales_ref[i] = val;
scales[i] = val;
}
}
template<typename InputType>
void generate_data(InputType * const data,
const size_t rows,
const size_t cols,
std::mt19937& gen,
std::uniform_real_distribution<>& dis,
std::uniform_real_distribution<>& dis_sign)
{
for (size_t i = 0; i < rows; ++i) {
for (size_t j = 0; j < cols; ++j) {
const size_t idx = i * cols + j;
const bool is_negative = (dis_sign(gen) < 0.0);
double val = dis(gen);
if (is_negative) {
val = -val;
}
data[idx] = static_cast<InputType>(val);
}
}
}
template<typename InputType>
void fill_tensor_data(Tensor& input,
fp8e8m0 * const scales_rowwise,
fp8e8m0 * const scales_colwise,
const bool is_rowwise_scaling,
const bool is_colwise_scaling,
const size_t rows,
const size_t cols,
const size_t blocks_num_rowwise,
const size_t blocks_num_colwise)
{
const double minAbs = Numeric_Traits<InputType>::minNorm;
const double maxAbs = Numeric_Traits<InputType>::maxNorm;
static std::mt19937 gen(12345);
std::uniform_real_distribution<> dis(minAbs, maxAbs);
std::uniform_real_distribution<> dis_sign(-1.0, 1.0);
std::uniform_int_distribution<fp8e8m0> int_dis(0, 255);
if (is_rowwise_scaling) {
generate_scales(scales_rowwise, input.rowwise_cpu_scale_inv_ptr<fp8e8m0>(), blocks_num_rowwise, gen, int_dis);
generate_data(input.rowwise_cpu_dptr<InputType>(), rows, cols, gen, dis, dis_sign);
}
if (is_colwise_scaling) {
generate_scales(scales_colwise, input.columnwise_cpu_scale_inv_ptr<fp8e8m0>(), blocks_num_colwise, gen, int_dis);
generate_data(input.columnwise_cpu_dptr<InputType>(), rows, cols, gen, dis, dis_sign);
}
input.from_cpu();
}
// Dequantize along single dimension (either row- or columnwise)
template <typename InputType, typename OutputType>
void performTest_x1(const size_t rows,
const size_t cols,
const bool rowwise,
const bool colwise)
{
using namespace test;
using EncodingType = fp32;
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
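// A rowwise pass scales 1x32 blocks; a colwise pass scales 32x1 blocks.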
const size_t block_size_rows = rowwise ? 1 : 32;
const size_t block_size_cols = colwise ? 1 : 32;
const size_t unpadded_blocks_Y_rowwise = rows;
const size_t unpadded_blocks_X_rowwise = divide_round_up(cols, block_size_cols);
const size_t unpadded_blocks_Y_colwise = divide_round_up(rows, block_size_rows);
const size_t unpadded_blocks_X_colwise = cols;
const size_t blocks_Y_rowwise = round_up_to_nearest_multiple(unpadded_blocks_Y_rowwise,
scale_tensor_alignment_Y_rowwise);
const size_t blocks_X_rowwise = round_up_to_nearest_multiple(unpadded_blocks_X_rowwise,
scale_tensor_alignment_X_rowwise);
const size_t blocks_Y_colwise = round_up_to_nearest_multiple(unpadded_blocks_Y_colwise,
scale_tensor_alignment_Y_colwise);
const size_t blocks_X_colwise = round_up_to_nearest_multiple(unpadded_blocks_X_colwise,
scale_tensor_alignment_X_colwise);
const size_t blocks_num_rowwise = blocks_Y_rowwise * blocks_X_rowwise;
const size_t blocks_num_colwise = blocks_Y_colwise * blocks_X_colwise;
const size_t blocks_num = rowwise ? blocks_num_rowwise : blocks_num_colwise;
const size_t scales_stride = rowwise ? blocks_X_rowwise : blocks_X_colwise;
Tensor input("input", { rows, cols }, itype, rowwise, colwise, NVTE_MXFP8_1D_SCALING);
// Output data are written to the rowwise ptr regardless of the scaling direction
Tensor output("output", { rows, cols }, otype, true, false);
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(rows * cols);
std::unique_ptr<fp8e8m0[]> scales = std::make_unique<fp8e8m0[]>(blocks_num);
fill_tensor_data<InputType>(input, scales.get(), scales.get(), rowwise, colwise, rows, cols,
blocks_num_rowwise, blocks_num_colwise);
nvte_dequantize(input.data(), output.data(), 0);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
InputType * data_ptr = rowwise
? input.rowwise_cpu_dptr<InputType>()
: input.columnwise_cpu_dptr<InputType>();
compute_ref_x1<InputType, OutputType>(data_ptr,
ref_output.get(),
scales.get(),
rows,
cols,
block_size_rows,
block_size_cols,
scales_stride);
auto [atol, rtol] = getTolerances(otype);
compareResults("output", output, ref_output.get(), true, atol, rtol);
}
// Quantize then dequantize (round trip) along a single dimension (either row- or columnwise)
template <typename InputType, typename IntermediateType>
void performTest_quantize_then_dequantize(const size_t rows,
const size_t cols,
const bool rowwise,
const bool colwise)
{
using namespace test;
using EncodingType = fp32;
DType in_type = TypeInfo<InputType>::dtype;
DType intermed_type = TypeInfo<IntermediateType>::dtype;
DType out_type = TypeInfo<InputType>::dtype;
std::unique_ptr<InputType[]> input_cpu = std::make_unique<InputType[]>(rows * cols);
std::unique_ptr<IntermediateType[]> quantized_cpu = std::make_unique<IntermediateType[]>(rows * cols);
std::unique_ptr<InputType[]> output_cpu = std::make_unique<InputType[]>(rows * cols);
// input --> quantized --> output (dequantized)
// input == output
Tensor input("input", { rows, cols }, in_type);
Tensor quantized("quantized", { rows, cols }, intermed_type, rowwise, colwise, NVTE_MXFP8_1D_SCALING);
// Output data are written to the rowwise ptr regardless of the scaling direction
Tensor output("output", { rows, cols }, out_type, true, false);
// fillCase<EncodingType>(&input, InputsFillCase::minNorm_to_maxNorm);
fillCase<EncodingType>(&input, InputsFillCase::uniform);
const size_t copy_size = sizeof(InputType) * rows * cols;
cudaMemcpy(input_cpu.get(), input.rowwise_dptr(), copy_size, cudaMemcpyDeviceToHost);
nvte_quantize(input.data(), quantized.data(), 0);
cudaDeviceSynchronize();
const size_t copy_size_quantized = sizeof(IntermediateType) * rows * cols;
if (rowwise) {
cudaMemcpy(quantized_cpu.get(), quantized.rowwise_dptr(), copy_size_quantized, cudaMemcpyDeviceToHost);
}
if (colwise) {
cudaMemcpy(quantized_cpu.get(), quantized.columnwise_dptr(), copy_size_quantized, cudaMemcpyDeviceToHost);
}
nvte_dequantize(quantized.data(), output.data(), 0);
cudaDeviceSynchronize();
cudaMemcpy(output_cpu.get(), output.rowwise_dptr(), copy_size, cudaMemcpyDeviceToHost);
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol, rtol] = getTolerances(intermed_type);
compareResults("Quantize-Dequantize", input, output_cpu.get(), true, atol, rtol);
}
// Dequantize along both dimensions (row- and columnwise)
template <typename InputType, typename OutputType>
void performTest_x2(const size_t rows,
const size_t cols,
const size_t block_size_rows,
const size_t block_size_cols)
{
using namespace test;
using EncodingType = fp32;
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
const size_t unpadded_blocks_Y_rowwise = rows;
const size_t unpadded_blocks_X_rowwise = divide_round_up(cols, block_size_cols);
const size_t unpadded_blocks_Y_colwise = divide_round_up(rows, block_size_rows);
const size_t unpadded_blocks_X_colwise = cols;
const size_t blocks_Y_rowwise = round_up_to_nearest_multiple(unpadded_blocks_Y_rowwise,
scale_tensor_alignment_Y_rowwise);
const size_t blocks_X_rowwise = round_up_to_nearest_multiple(unpadded_blocks_X_rowwise,
scale_tensor_alignment_X_rowwise);
const size_t blocks_Y_colwise = round_up_to_nearest_multiple(unpadded_blocks_Y_colwise,
scale_tensor_alignment_Y_colwise);
const size_t blocks_X_colwise = round_up_to_nearest_multiple(unpadded_blocks_X_colwise,
scale_tensor_alignment_X_colwise);
const size_t scales_stride_rowwise = blocks_X_rowwise;
const size_t scales_stride_colwise = blocks_X_colwise;
const size_t blocks_num_rowwise = blocks_Y_rowwise * blocks_X_rowwise;
const size_t blocks_num_colwise = blocks_Y_colwise * blocks_X_colwise;
Tensor input("input", { rows, cols }, itype, true, true, NVTE_MXFP8_1D_SCALING);
Tensor output("output", { rows, cols }, otype);
std::unique_ptr<OutputType[]> ref_output_rowwise = std::make_unique<OutputType[]>(rows * cols);
std::unique_ptr<OutputType[]> ref_output_colwise = std::make_unique<OutputType[]>(rows * cols);
std::unique_ptr<fp8e8m0[]> ref_scales_rowwise = std::make_unique<fp8e8m0[]>(blocks_num_rowwise);
std::unique_ptr<fp8e8m0[]> ref_scales_colwise = std::make_unique<fp8e8m0[]>(blocks_num_colwise);
constexpr bool rowwise = true;
constexpr bool colwise = true;
fill_tensor_data<InputType>(input, ref_scales_rowwise.get(), ref_scales_colwise.get(),
rowwise, colwise, rows, cols, blocks_num_rowwise, blocks_num_colwise);
nvte_dequantize(input.data(), output.data(), 0);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
compute_ref_x2<InputType, OutputType>(input.rowwise_cpu_dptr<InputType>(),
ref_output_rowwise.get(),
ref_output_colwise.get(),
ref_scales_rowwise.get(),
ref_scales_colwise.get(),
rows,
cols,
block_size_rows,
block_size_cols,
scales_stride_rowwise,
scales_stride_colwise);
auto [atol, rtol] = getTolerances(otype);
compareResults("output_rowwise", output, ref_output_rowwise.get(), true, atol, rtol);
compareResults("output_colwise", output, ref_output_colwise.get(), false, atol, rtol);
}
std::vector<std::pair<size_t, size_t>> tensor_dims = {
{1, 16},
{16, 48},
{65, 96},
{128, 128},
{256, 256},
{993, 512},
{768, 1024},
// {2048, 12288},
// {65536, 128},
// {16384, 1632},
// {16384, 6144},
};
std::vector<std::pair<size_t, size_t>> block_sizes = {
{1, 32},
{32, 1},
// {32, 32},
};
} // namespace
class DequantizeMXFP8TestSuite : public ::testing::TestWithParam
<std::tuple<std::pair<size_t, size_t>,
std::pair<size_t, size_t>,
transformer_engine::DType,
transformer_engine::DType,
bool>> {};
TEST_P(DequantizeMXFP8TestSuite, TestDequantizeMXFP8)
{
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
using namespace transformer_engine;
using namespace test;
const auto tensor_size = std::get<0>(GetParam());
const auto block_size = std::get<1>(GetParam());
const DType input_type = std::get<2>(GetParam());
const DType output_type = std::get<3>(GetParam());
const bool quantize_then_dequantize = std::get<4>(GetParam());
const bool rowwise = block_size.second != 1;
const bool colwise = block_size.first != 1;
// Skip tests for dequantization along both dimensions
if (rowwise && colwise) {
GTEST_SKIP();
}
// Skip cases with invalid alignment
if (rowwise && tensor_size.second % 32 != 0) {
GTEST_SKIP();
}
if (colwise && tensor_size.first % 32 != 0) {
GTEST_SKIP();
}
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType,
if (quantize_then_dequantize) {
// Mind the order of the Output/Input template parameters
performTest_quantize_then_dequantize<OutputType, InputType>(
tensor_size.first, tensor_size.second, rowwise, colwise);
} else {
if (block_size.first == 1 || block_size.second == 1) {
performTest_x1<InputType, OutputType>(tensor_size.first, tensor_size.second,
rowwise, colwise);
} else {
performTest_x2<InputType, OutputType>(tensor_size.first, tensor_size.second,
block_size.first, block_size.second);
}
}
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
DequantizeMXFP8TestSuite,
::testing::Combine(
::testing::ValuesIn(tensor_dims),
::testing::ValuesIn(block_sizes),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(false)),
[](const testing::TestParamInfo<DequantizeMXFP8TestSuite::ParamType>& info)
{
std::string name = std::to_string(std::get<0>(info.param).first) + "X" +
std::to_string(std::get<0>(info.param).second) + "X" +
std::to_string(std::get<1>(info.param).first) + "X" +
std::to_string(std::get<1>(info.param).second) + "X" +
test::typeName(std::get<2>(info.param)) + "X" +
test::typeName(std::get<3>(info.param)) + "X" +
(std::get<4>(info.param) ? "QD" : "D");
return name;
}
);
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <vector>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/transpose.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename InputType, typename OutputType>
void compute_ref(const std::vector<std::vector<InputType>>& input_list,
std::vector<std::vector<OutputType>>& output_c_list,
std::vector<std::vector<OutputType>>& output_t_list,
const std::vector<float>& scale_list,
std::vector<float>& amax_list,
const std::vector<size_t>& height_list,
const std::vector<size_t>& width_list) {
using compute_t = float;
for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
const auto& input = input_list[tensor_id];
auto& output_c = output_c_list[tensor_id];
auto& output_t = output_t_list[tensor_id];
const compute_t scale = scale_list[tensor_id];
compute_t& amax = amax_list[tensor_id];
const size_t height = height_list[tensor_id];
const size_t width = width_list[tensor_id];
amax = -1e100;
for (size_t i = 0; i < height; ++i) {
for (size_t j = 0; j < width; ++j) {
const compute_t x = static_cast<compute_t>(input[i * width + j]);
const OutputType y = static_cast<OutputType>(scale * x);
amax = fmaxf(amax, fabsf(x));
output_c[i * width + j] = y;
output_t[j * height + i] = y;
}
}
}
}
template <typename InputType, typename OutputType>
void performTest() {
using namespace test;
const DType itype = TypeInfo<InputType>::dtype;
const DType otype = TypeInfo<OutputType>::dtype;
const std::vector<std::pair<size_t, size_t>> tensor_dims = {{1,1},
{1,768},
{768,1},
{768,768},
{43,43},
{43,256},
{256,43},
{256,256}};
const size_t num_tensors = tensor_dims.size();
// Buffers for Transformer Engine implementation
std::vector<Tensor> input_list, output_list;
// Buffers for reference implementation
std::vector<std::vector<InputType>> ref_input_list;
std::vector<std::vector<OutputType>> ref_output_c_list, ref_output_t_list;
std::vector<float> ref_scale_list(num_tensors), ref_amax_list(num_tensors);
std::vector<size_t> ref_height_list(num_tensors), ref_width_list(num_tensors);
// Initialize buffers
for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
const size_t height = tensor_dims[tensor_id].first;
const size_t width = tensor_dims[tensor_id].second;
input_list.emplace_back(Tensor("input_" + std::to_string(tensor_id), { height, width }, itype));
output_list.emplace_back(Tensor("output_" + std::to_string(tensor_id),
{ height, width }, otype, true, true));
auto& input = input_list.back();
auto& output = output_list.back();
fillUniform(&input);
setRandomScale(&output);
ref_input_list.emplace_back(height*width);
ref_output_c_list.emplace_back(height*width);
ref_output_t_list.emplace_back(width*height);
std::copy(input.rowwise_cpu_dptr<InputType>(),
input.rowwise_cpu_dptr<InputType>() + height * width,
ref_input_list.back().begin());
ref_scale_list[tensor_id] = output.scale();
ref_height_list[tensor_id] = height;
ref_width_list[tensor_id] = width;
}
// Transformer Engine implementation
auto make_nvte_vector = [](std::vector<Tensor>& tensor_list)
-> std::vector<NVTETensor> {
std::vector<NVTETensor> nvte_tensor_list;
for (auto& tensor : tensor_list) {
nvte_tensor_list.emplace_back(tensor.data());
}
return nvte_tensor_list;
};
nvte_multi_cast_transpose(num_tensors,
make_nvte_vector(input_list).data(),
make_nvte_vector(output_list).data(),
0);
// Reference implementation
compute_ref<InputType, OutputType>(ref_input_list,
ref_output_c_list,
ref_output_t_list,
ref_scale_list,
ref_amax_list,
ref_height_list,
ref_width_list);
// Check correctness
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax",
output_list[tensor_id].amax(),
ref_amax_list[tensor_id],
atol_amax, rtol_amax);
compareResults("scale_inv",
output_list[tensor_id].rowwise_scale_inv(),
1.f / output_list[tensor_id].scale(),
atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c",
output_list[tensor_id],
ref_output_c_list[tensor_id].data(),
true, atol, rtol);
compareResults("output_t",
output_list[tensor_id],
ref_output_t_list[tensor_id].data(),
false, atol, rtol);
}
}
} // namespace
class MultiCastTransposeTestSuite
: public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
transformer_engine::DType>> {};
TEST_P(MultiCastTransposeTestSuite, TestMultiCastTranspose) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>();
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
MultiCastTransposeTestSuite,
::testing::Combine(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::ValuesIn(test::all_fp_types)),
[](const testing::TestParamInfo<MultiCastTransposeTestSuite::ParamType>& info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param));
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <vector>
#include <cstdio>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/padding.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename InputType, typename OutputType>
void compute_ref(const std::vector<std::vector<InputType>>& input_list,
std::vector<std::vector<OutputType>>& output_list,
const std::vector<size_t>& height_list,
const std::vector<size_t>& width_list,
const std::vector<int>& padded_height_list) {
using compute_t = float;
for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
const auto& input = input_list[tensor_id];
auto& output = output_list[tensor_id];
const size_t height = height_list[tensor_id];
const size_t width = width_list[tensor_id];
const size_t padded_height = padded_height_list[tensor_id];
for (size_t i = 0; i < padded_height; ++i) {
if (i < height) {
for (size_t j = 0; j < width; ++j) {
const compute_t x = static_cast<compute_t>(input[i * width + j]);
const OutputType y = static_cast<OutputType>(x);
output[i * width + j] = y;
}
} else {
for (size_t j = 0; j < width; ++j) {
output[i * width + j] = static_cast<OutputType>(0.f);
}
}
}
}
}
template <typename InputType, typename OutputType>
void performTest() {
using namespace test;
const DType itype = TypeInfo<InputType>::dtype;
const DType otype = TypeInfo<OutputType>::dtype;
const std::vector<std::pair<size_t, size_t>> tensor_dims = {{1,1},
{1,768},
{768,1},
{768,768},
{43,43},
{43,256},
{256,43},
{256,256}};
const size_t num_tensors = tensor_dims.size();
constexpr int align = 16;
// Buffers for Transformer Engine implementation
std::vector<Tensor> input_list, output_list;
// Buffers for reference implementation
std::vector<std::vector<InputType>> ref_input_list;
std::vector<std::vector<OutputType>> ref_output_list;
std::vector<size_t> ref_height_list(num_tensors), ref_width_list(num_tensors);
std::vector<int> ref_padded_height_list(num_tensors);
// Initialize buffers
for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
const size_t height = tensor_dims[tensor_id].first;
const size_t width = tensor_dims[tensor_id].second;
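// Round the height up to the next multiple of align.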
const size_t padded_height = (height + align - 1) / align * align;
input_list.emplace_back(Tensor("input_" + std::to_string(tensor_id), { height, width }, itype));
output_list.emplace_back(Tensor("output_" + std::to_string(tensor_id), { padded_height, width }, otype));
auto& input = input_list.back();
auto& output = output_list.back();
fillUniform(&input);
setRandomScale(&output);
ref_input_list.emplace_back(height*width);
ref_output_list.emplace_back(padded_height*width);
std::copy(input.rowwise_cpu_dptr<InputType>(),
input.rowwise_cpu_dptr<InputType>() + height * width,
ref_input_list.back().begin());
ref_height_list[tensor_id] = height;
ref_width_list[tensor_id] = width;
ref_padded_height_list[tensor_id] = padded_height;
}
// Transformer Engine implementation
auto make_nvte_vector = [](std::vector<Tensor>& tensor_list)
-> std::vector<NVTETensor> {
std::vector<NVTETensor> nvte_tensor_list;
for (auto& tensor : tensor_list) {
nvte_tensor_list.emplace_back(tensor.data());
}
return nvte_tensor_list;
};
nvte_multi_padding(num_tensors,
make_nvte_vector(input_list).data(),
make_nvte_vector(output_list).data(),
ref_padded_height_list.data(),
0);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
// Reference implementation
compute_ref<InputType, OutputType>(ref_input_list,
ref_output_list,
ref_height_list,
ref_width_list,
ref_padded_height_list);
// Check correctness
for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
auto [atol, rtol] = getTolerances(otype);
compareResults("output",
output_list[tensor_id],
ref_output_list[tensor_id].data(),
true,
atol, rtol);
}
}
} // namespace
class MultiPaddingTestSuite
: public ::testing::TestWithParam<
transformer_engine::DType> {};
TEST_P(MultiPaddingTestSuite, TestMultiPaddingTranspose) {
using namespace transformer_engine;
using namespace test;
const DType input_type = GetParam();
const DType output_type = input_type;
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>();
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
MultiPaddingTestSuite,
::testing::ValuesIn(test::all_fp_types),
[](const testing::TestParamInfo<MultiPaddingTestSuite::ParamType>& info) {
std::string name = test::typeName(info.param);
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <memory>
#include <iomanip>
#include <iostream>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"
using namespace transformer_engine;
using namespace test;
namespace {
enum NormType {
LayerNorm,
RMSNorm
};
std::map<NormType, std::string> normToString = {
{NormType::LayerNorm, "LayerNorm"},
{NormType::RMSNorm, "RmsNorm"}
};
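// Reference statistics: LayerNorm uses mu_i = mean(x_i) and
// rsigma_i = 1 / sqrt(var(x_i) + epsilon); RMSNorm sets mu = 0, so
// rsigma_i = 1 / sqrt(mean(x_i^2) + epsilon).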
template <typename InputType>
void compute_ref_stats(NormType norm_type,
const InputType *data, float *mu, float *rsigma,
const size_t N, const size_t H, const double epsilon){
using compute_t = float;
compute_t current, m;
for (size_t i = 0; i < N; ++i) {
compute_t sum = 0;
for (size_t j = 0; j < H; ++j) {
sum += static_cast<compute_t>(data[i * H + j]);
}
if (norm_type == LayerNorm){
mu[i] = sum / H;
m = mu[i];
} else { m = 0;}
compute_t sum_sq = 0;
for (size_t j = 0; j < H; ++j) {
current = static_cast<compute_t>(data[i * H + j]);
sum_sq += (current - m) * (current - m);
}
rsigma[i] = rsqrtf((sum_sq / H) + epsilon);
}
}
// For now, cuDNN computes static_cast<compute_t>(gamma + static_cast<input_t>(1.0)).
// This will change in a future release.
template <typename InputType>
inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn){
using compute_t = float;
if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3>){
compute_t g = static_cast<compute_t>(gamma);
if (zero_centered_gamma) {
g += static_cast<compute_t>(1.f);
}
return g;
} else {
if (use_cudnn){
compute_t g = static_cast<compute_t>(0.f);
InputType gi = gamma;
if (zero_centered_gamma) {
gi = gi + static_cast<InputType>(1.f);
}
g = static_cast<compute_t>(gi);
return g;
} else {
compute_t g = static_cast<compute_t>(gamma);
if (zero_centered_gamma) {
g += static_cast<compute_t>(1.f);
}
return g;
}
}
}
template <typename InputType, typename OutputType>
void compute_ref_output(NormType norm_type,
const InputType *data, const InputType *gamma, const InputType *beta,
OutputType* output,
const float *mu, const float *rsigma,
const size_t N, const size_t H,
float *amax, float scale, const bool zero_centered_gamma, const bool use_cudnn) {
using compute_t = float;
compute_t current_max = -1e100;
for (size_t i = 0; i < N; ++i) {
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn);
compute_t tmp;
if (norm_type == LayerNorm) {
tmp = (current - mu[i]) * rsigma[i] * g + static_cast<compute_t>(beta[j]);
} else { // RMSNorm
tmp = current * rsigma[i] * g;
}
output[i * H + j] = static_cast<OutputType>(tmp * scale);
current_max = fmaxf(current_max, fabsf(tmp));
}
}
*amax = current_max;
}
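// Note: amax is taken over the unscaled normalized values (tmp), while the
// stored output is tmp * scale. This mirrors the per-tensor FP8 convention in
// which amax feeds the next scale update and the cast applies the current scale.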
template <typename InputType, typename OutputType>
void compute_ref_backward(const NormType norm_type, const OutputType *output_grad, const InputType *data,
const float *mu, const float *rsigma,
const InputType *gamma,
InputType *data_grad,
InputType *gamma_grad, InputType *beta_grad,
const size_t N, const size_t H,
const bool zero_centered_gamma, const bool use_cudnn) {
using compute_t = float;
std::vector<compute_t> dgamma(H, 0.f);
std::vector<compute_t> dbeta(H, 0.f);
for (size_t i = 0 ; i < N; ++i) {
// Reductions
auto local_mu = (norm_type == LayerNorm) ? mu[i] : 0.;
compute_t mdy = 0, mdyy = 0;
for (size_t j = 0; j < H; ++j) {
const compute_t x = static_cast<compute_t>(data[i * H + j]);
const compute_t y = (x - local_mu) * rsigma[i];
compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn);
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz;
dgamma[j] += y * dz;
if (norm_type == LayerNorm) {
dbeta[j] += dz;
mdy += dy;
}
mdyy += dy * y;
}
mdy /= H;
mdyy /= H;
// Input grads
for (size_t j = 0; j < H; ++j) {
const compute_t x = static_cast<compute_t>(data[i * H + j]);
const compute_t y = (x - local_mu) * rsigma[i];
compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn);
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz;
const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy);
data_grad[i * H + j] = static_cast<InputType>(dx);
}
}
// Weight grads
for (size_t j = 0; j < H; ++j) gamma_grad[j] = static_cast<InputType>(dgamma[j]);
if (norm_type == LayerNorm) for (size_t j = 0; j < H; ++j) beta_grad[j] = static_cast<InputType>(dbeta[j]);
}
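// The input gradient above is the standard LayerNorm backward:
//   dx = rsigma * (dy - mean_j(dy * y) * y - mean_j(dy))
// with y = (x - mu) * rsigma and dy = g * dz. RMSNorm drops the mu and
// mean_j(dy) terms (local_mu stays 0 and mdy is never accumulated).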
template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
NormType norm_type, bool use_cudnn) {
if (sizeof(InputType) < sizeof(OutputType)) {
GTEST_SKIP() << "LN kernel does not support OutputType > InputType";
return;
}
if (getDeviceComputeCapability() < blackwellComputeCapability && use_cudnn) {
GTEST_SKIP() << "cuDNN normalizations not supported on pre-Blackwell GPUs yet!";
}
using WeightType = InputType;
DType itype = TypeInfo<InputType>::dtype;
DType wtype = TypeInfo<WeightType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
if ((itype == DType::kBFloat16 && otype == DType::kFloat16) ||
(itype == DType::kFloat16 && otype == DType::kBFloat16)) {
GTEST_SKIP() << "LN kernel does not support mixing Float16 and BFloat16";
return;
}
Tensor input("input", { N, H }, itype);
Tensor z("z", { N, H }, otype);
Tensor gamma("gamma", { H }, wtype);
Tensor beta("beta", { H }, wtype);
Tensor mu("mu", { N }, DType::kFloat32);
Tensor rsigma("rsigma", { N }, DType::kFloat32);
Tensor dz("dz", { N, H }, wtype);
Tensor dx("dx", { N, H }, itype);
Tensor dgamma("dgamma", { H }, wtype);
Tensor dbeta("dbeta", { H }, wtype);
Tensor workspace_fwd, workspace_bwd;
fillUniform(&input);
fillUniform(&gamma);
fillUniform(&beta);
setRandomScale(&z);
fillUniform(&dz);
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N * H);
std::unique_ptr<float[]> ref_mu = std::make_unique<float[]>(N);
std::unique_ptr<float[]> ref_rsigma = std::make_unique<float[]>(N);
std::unique_ptr<InputType[]> ref_dx = std::make_unique<InputType[]>(N * H);
std::unique_ptr<WeightType[]> ref_dgamma = std::make_unique<WeightType[]>(H);
std::unique_ptr<WeightType[]> ref_dbeta = std::make_unique<WeightType[]>(H);
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
if (use_cudnn){
nvte_enable_cudnn_norm_fwd(true);
nvte_enable_cudnn_norm_bwd(true);
}
// Forward and backward kernels. Each nvte_* call is made twice: the first call
// with an empty workspace only queries the required workspace size; the
// workspace is then allocated and the call is repeated to launch the kernel.
float epsilon = 1e-5;
if (norm_type == LayerNorm){
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon,
z.data(), mu.data(), rsigma.data(), workspace_fwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
workspace_fwd = Tensor("workspace", workspace_fwd.rowwise_shape(), workspace_fwd.dtype());
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon,
z.data(), mu.data(), rsigma.data(), workspace_fwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
nvte_layernorm_bwd(dz.data(), input.data(),
mu.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), dbeta.data(),
workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype());
nvte_layernorm_bwd(dz.data(), input.data(),
mu.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), dbeta.data(),
workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
} else {
nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon,
z.data(), rsigma.data(), workspace_fwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
workspace_fwd = Tensor("workspace", workspace_fwd.rowwise_shape(), workspace_fwd.dtype());
nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon,
z.data(), rsigma.data(), workspace_fwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(),
workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype());
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(),
workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
}
if (use_cudnn){
nvte_enable_cudnn_norm_fwd(false);
nvte_enable_cudnn_norm_bwd(false);
}
// Reference implementations
// use the GPU stats to tighten the tolerances
mu.to_cpu();
rsigma.to_cpu();
float ref_amax;
compute_ref_stats(norm_type, input.rowwise_cpu_dptr<InputType>(), ref_mu.get(),
ref_rsigma.get(), N, H, epsilon);
float ref_scale = isFp8Type(otype) ? z.scale() : 1.f;
compute_ref_output(norm_type, input.rowwise_cpu_dptr<InputType>(),
gamma.rowwise_cpu_dptr<WeightType>(),
beta.rowwise_cpu_dptr<WeightType>(),
ref_output.get(),
mu.rowwise_cpu_dptr<float>(),
rsigma.rowwise_cpu_dptr<float>(),
N, H,
&ref_amax,
ref_scale,
zero_centered_gamma,
use_cudnn);
compute_ref_backward(norm_type, dz.rowwise_cpu_dptr<WeightType>(),
input.rowwise_cpu_dptr<InputType>(),
mu.rowwise_cpu_dptr<float>(), rsigma.rowwise_cpu_dptr<float>(),
gamma.rowwise_cpu_dptr<WeightType>(),
ref_dx.get(), ref_dgamma.get(), ref_dbeta.get(),
N, H, zero_centered_gamma,
use_cudnn);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
if (isFp8Type(otype)) {
compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / z.scale();
compareResults("scale_inv", z.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
rtol_stats = 5e-5;
compareResults("mu", mu, ref_mu.get(), true, atol_stats, rtol_stats);
compareResults("rsigma", rsigma, ref_rsigma.get(), true, atol_stats, rtol_stats);
auto [atol, rtol] = getTolerances(otype);
if (otype == DType::kFloat32) {
atol = 5e-7;
}
compareResults("output", z, ref_output.get(), true, atol, rtol);
double atol_bwd = 5e-4;
double rtol_bwd = 5e-4;
compareResults("dx", dx, ref_dx.get(), true, atol_bwd, rtol_bwd);
compareResults("dgamma", dgamma, ref_dgamma.get(), true, atol_bwd, rtol_bwd);
compareResults("dbeta", dbeta, ref_dbeta.get(), true, atol_bwd, rtol_bwd);
}
std::vector<std::pair<size_t, size_t>> test_cases = {
{71, 229},
{29, 541},
{768, 6144},
{2048, 12288},
};
} // namespace
class NormTestSuite : public ::testing::TestWithParam<std::tuple<bool,
NormType,
transformer_engine::DType,
transformer_engine::DType,
std::pair<size_t, size_t>,
bool>> {};
TEST_P(NormTestSuite, TestNorm) {
using namespace transformer_engine;
using namespace test;
const bool use_cudnn = std::get<0>(GetParam());
const NormType norm_type = std::get<1>(GetParam());
const DType input_type = std::get<2>(GetParam());
const DType output_type = std::get<3>(GetParam());
const auto size = std::get<4>(GetParam());
const bool zero_centered_gamma = std::get<5>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma, norm_type, use_cudnn);
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
NormTestSuite,
::testing::Combine(
::testing::Values(true, false),
::testing::Values(NormType::LayerNorm, NormType::RMSNorm),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3),
::testing::ValuesIn(test_cases),
::testing::Values(false, true)),
[](const testing::TestParamInfo<NormTestSuite::ParamType>& info) {
auto backend = std::get<0>(info.param) == false ? "Te" : "Cudnn";
std::string name =
backend +
normToString.at(std::get<1>(info.param)) + "_" +
test::typeName(std::get<2>(info.param)) + "X" +
test::typeName(std::get<3>(info.param)) + "X" +
std::to_string(std::get<4>(info.param).first) + "X" +
std::to_string(std::get<4>(info.param).second) + "X" +
std::to_string(std::get<5>(info.param));
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <memory>
#include <map>
#include <iomanip>
#include <iostream>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"
using namespace transformer_engine;
using namespace test;
namespace {
using fp8e8m0 = byte;
enum NormType {
LayerNorm,
RMSNorm
};
std::map<NormType, std::string> normToString = {
{NormType::LayerNorm, "LayerNorm"},
{NormType::RMSNorm, "RMSNorm"}
};
template <typename InputType, typename ScaleType, typename OutputType>
void dequantize_1x_kernel(InputType* input_ptr, ScaleType* scale_ptr, OutputType* output_ptr,
size_t rows, size_t cols, size_t scaling_mode_x, size_t scaling_mode_y){
const size_t block_size_Y = scaling_mode_x; // mind the mapping Y <-- x
const size_t block_size_X = scaling_mode_y; // and X <-- y
const size_t tile_size_Y = std::max(32lu, block_size_Y);
const size_t tile_size_X = std::max(64lu, block_size_X);
const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y;
const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X;
const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y;
const size_t blocks_per_tile_X = tile_size_X / block_size_X;
const size_t blocks_per_row = (cols + block_size_X - 1) / block_size_X;
#pragma omp parallel for proc_bind(spread) schedule(static)
for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) {
const size_t tile_Y = t / tiles_num_X;
const size_t tile_X = t % tiles_num_X;
const size_t tile_offset_Y = tile_Y * tile_size_Y;
const size_t tile_offset_X = tile_X * tile_size_X;
for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) {
const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii;
const size_t block_offset_Y = ii * block_size_Y;
const size_t i_min = tile_offset_Y + block_offset_Y;
const size_t i_max = std::min(i_min + block_size_Y, rows);
for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) {
const size_t block_idx_X = tile_X * blocks_per_tile_X + jj;
const size_t block_offset_X = jj * block_size_X;
const size_t j_min = tile_offset_X + block_offset_X;
const size_t j_max = std::min(j_min + block_size_X, cols);
const size_t mx_scale_idx = block_idx_Y * blocks_per_row + block_idx_X;
// TODO: padded SFs i.e. (4,128)
const float scale_inv = exp2f(static_cast<float>(scale_ptr[mx_scale_idx]) - FP32_EXPONENT_BIAS);
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
const float elem = static_cast<float>(input_ptr[idx]);
output_ptr[idx] = static_cast<OutputType>(elem * scale_inv);
}
}
}
}
}
}
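// The scale decode above treats each e8m0 byte as a biased power-of-two
// exponent: scale_inv = 2^(raw - FP32_EXPONENT_BIAS), i.e. 2^(raw - 127)
// with the standard FP32 exponent bias.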
template <typename InputType, typename ScaleType>
void dequantize_2x(Tensor& input, Tensor& output, bool is_training)
{
input.to_cpu();
auto scaling_mode = input.scaling_mode();
assert(input.rowwise_shape().ndim == 2);
if (is_training) {
assert(input.columnwise_shape().ndim == 2);
}
dequantize_1x_kernel(input.rowwise_cpu_dptr<InputType>(),
input.rowwise_cpu_scale_inv_ptr<ScaleType>(),
output.rowwise_cpu_dptr<float>(),
input.rowwise_shape().data[0], input.rowwise_shape().data[1],
1, 32);
if (is_training)
dequantize_1x_kernel(input.columnwise_cpu_dptr<InputType>(),
input.columnwise_cpu_scale_inv_ptr<ScaleType>(),
output.columnwise_cpu_dptr<float>(),
input.columnwise_shape().data[0], input.columnwise_shape().data[1],
32, 1);
}
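// dequantize_2x reconstructs FP32 references from both quantized copies of the
// tensor: the rowwise copy uses 1x32 scaling blocks and, when is_training is
// set, the columnwise copy uses 32x1 blocks, matching NVTE_MXFP8_1D_SCALING.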
template <typename InputType>
void compute_ref_stats(NormType norm_type,
const InputType *data, float *mu, float *rsigma,
const size_t N, const size_t H, const double epsilon){
using compute_t = float;
#pragma omp parallel for proc_bind(spread)
for (size_t i = 0; i < N; ++i) {
compute_t sum = 0;
for (size_t j = 0; j < H; ++j) {
sum += static_cast<compute_t>(data[i * H + j]);
}
compute_t m;
if (norm_type == LayerNorm){
mu[i] = sum / H;
m = mu[i];
} else { m = 0;}
compute_t sum_sq = 0;
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
sum_sq += (current - m) * (current - m);
}
rsigma[i] = rsqrtf((sum_sq / H) + epsilon);
}
}
template <typename InputType, typename OutputType>
void compute_ref_output(NormType norm_type,
const InputType *data, const InputType *gamma, const InputType *beta,
const float *mu, const float *rsigma,
const size_t N, const size_t H,
OutputType* output,
const bool zero_centered_gamma){
using compute_t = float;
#pragma omp parallel for proc_bind(spread)
for (size_t i = 0; i < N; ++i) {
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
compute_t g = static_cast<compute_t>(gamma[j]);
if (zero_centered_gamma) {
g += 1.0;
}
compute_t tmp;
if (norm_type == LayerNorm) {
tmp = (current - mu[i]) * rsigma[i] * g + static_cast<compute_t>(beta[j]);
} else { // RMSNorm
tmp = current * rsigma[i] * g;
}
output[i * H + j] = tmp;
}
}
}
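// Unlike the per-tensor-scaling reference earlier in this suite, no scale or
// amax is computed here: the MX block quantization happens inside the fused
// kernel, and this FP32 reference is compared against the dequantized output.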
template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, NormType norm_type, bool is_training) {
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
using WeightType = InputType;
DType itype = TypeInfo<InputType>::dtype;
DType wtype = TypeInfo<WeightType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
Tensor input("input", { N, H }, itype);
Tensor z("z", { N, H }, otype, true, is_training, NVTE_MXFP8_1D_SCALING);
Tensor gamma("gamma", { H }, wtype);
Tensor beta("beta", { H }, wtype);
Tensor mu("mu", { N }, DType::kFloat32);
Tensor rsigma("rsigma", { N }, DType::kFloat32);
Tensor workspace;
fillUniform(&input);
fillUniform(&gamma);
fillUniform(&beta);
// Forward kernel (the first call with an empty workspace queries the required
// workspace size; the second call runs the kernel)
float epsilon = 1e-5;
if (norm_type == NormType::LayerNorm){
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon,
z.data(), mu.data(), rsigma.data(), workspace.data(),
prop.multiProcessorCount, zero_centered_gamma,
0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon,
z.data(), mu.data(), rsigma.data(), workspace.data(),
prop.multiProcessorCount, zero_centered_gamma,
0);
} else {
nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon,
z.data(), rsigma.data(), workspace.data(),
prop.multiProcessorCount, zero_centered_gamma,
0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon,
z.data(), rsigma.data(), workspace.data(),
prop.multiProcessorCount, zero_centered_gamma,
0);
}
Tensor dequantized_output("dequantized_output", { N, H }, DType::kFloat32, true, true);
dequantize_2x<OutputType, fp8e8m0>(z, dequantized_output, is_training);
// Reference implementations
std::unique_ptr<float[]> ref_mu = std::make_unique<float[]>(N);
std::unique_ptr<float[]> ref_rsigma = std::make_unique<float[]>(N);
std::unique_ptr<float[]> ref_output = std::make_unique<float[]>(N * H);
compute_ref_stats(norm_type, input.rowwise_cpu_dptr<InputType>(), ref_mu.get(),
ref_rsigma.get(), N, H, epsilon);
// use the GPU stats to tighten the tolerances
float *ref_mu_ptr, *ref_rsigma_ptr;
if (is_training){
mu.to_cpu();
rsigma.to_cpu();
ref_mu_ptr = mu.rowwise_cpu_dptr<float>();
ref_rsigma_ptr = rsigma.rowwise_cpu_dptr<float>();
} else {
ref_mu_ptr = ref_mu.get();
ref_rsigma_ptr = ref_rsigma.get();
}
compute_ref_output(norm_type, input.rowwise_cpu_dptr<InputType>(),
gamma.rowwise_cpu_dptr<WeightType>(),
beta.rowwise_cpu_dptr<WeightType>(),
ref_mu_ptr,
ref_rsigma_ptr,
N, H,
ref_output.get(),
zero_centered_gamma);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
rtol_stats = 5e-5;
if (is_training){
compareResults("mu", mu, ref_mu.get(), true, atol_stats, rtol_stats);
compareResults("rsigma", rsigma, ref_rsigma.get(), true, atol_stats, rtol_stats);
}
float atol = 0.f, rtol = 0.f;  // only FP8 output types are instantiated below
if (otype == DType::kFloat8E5M2){
atol = 1.25e-1;
rtol = 1.25e-1;
} else if (otype == DType::kFloat8E4M3){
if (itype == DType::kBFloat16){
atol = 7e-2;
rtol = 7e-2;
} else {
atol = 6.25e-2;
rtol = 6.25e-2;
}
}
compareResults("output_rowwise", dequantized_output, ref_output.get(), true, atol, rtol, false);
if (is_training)
compareResults("output_colwise", dequantized_output, ref_output.get(), false, atol, rtol, false);
}
std::vector<std::pair<size_t, size_t>> test_cases = {
{32, 32},
{768, 2304},
{2048, 12288},
};
std::vector<NormType> norms = {
NormType::LayerNorm,
NormType::RMSNorm
};
} // namespace
class MxNormTestSuite : public ::testing::TestWithParam< std::tuple<NormType,
transformer_engine::DType,
transformer_engine::DType,
std::pair<size_t, size_t>,
bool, bool>> {};
TEST_P(MxNormTestSuite, TestMxNorm) {
using namespace transformer_engine;
using namespace test;
const NormType norm_type = std::get<0>(GetParam());
const DType input_type = std::get<1>(GetParam());
const DType output_type = std::get<2>(GetParam());
const auto size = std::get<3>(GetParam());
const bool zero_centered_gamma = std::get<4>(GetParam());
const bool is_training = std::get<5>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma, norm_type, is_training);
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
MxNormTestSuite,
::testing::Combine(
::testing::Values(NormType::LayerNorm, NormType::RMSNorm),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E5M2, DType::kFloat8E4M3),
::testing::ValuesIn(test_cases),
::testing::Values(true, false),
::testing::Values(true, false)),
[](const testing::TestParamInfo<MxNormTestSuite::ParamType>& info) {
std::string name = normToString.at(std::get<0>(info.param)) + "_" +
test::typeName(std::get<1>(info.param)) + "X" +
test::typeName(std::get<2>(info.param)) + "X" +
std::to_string(std::get<3>(info.param).first) + "X" +
std::to_string(std::get<3>(info.param).second) + "X" +
std::to_string(std::get<4>(info.param)) + "out" +
std::to_string(int(std::get<5>(info.param)) + 1) + "x";
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename InputType, typename OutputType>
void compute_ref_q(const InputType *data, OutputType *output,
const size_t N,
float *amax, float scale) {
using compute_t = float;
compute_t current_max = -1e100;
for (size_t i = 0; i < N; ++i) {
compute_t current = static_cast<compute_t>(data[i]);
current_max = fmaxf(current_max, fabsf(current));
if (std::is_same<OutputType, test::fp8e4m3>::value ||
std::is_same<OutputType, test::fp8e5m2>::value) {
output[i] = OutputType(scale * current);
} else {
output[i] = OutputType(current);
}
}
*amax = current_max;
}
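// Quantization reference: for FP8 outputs, output = FP8(scale * x); for other
// types the value is cast directly. amax = max_i |x_i| is taken on the
// unscaled input, mirroring the per-tensor scaling recipe.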
template <typename InputType, typename OutputType>
void compute_ref_dq(const InputType *data, OutputType *output,
const size_t N, float scale_inv) {
using compute_t = float;
for (size_t i = 0; i < N; ++i) {
compute_t current = static_cast<compute_t>(data[i]);
output[i] = OutputType(scale_inv * current);
}
}
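// Dequantization reference: output = scale_inv * x, undoing the scale that was
// applied at quantization time.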
template <typename InputType, typename OutputType>
void performTestQ(const size_t N) {
using namespace test;
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
Tensor input("input", { N }, itype);
Tensor output("output", { N }, otype);
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N);
fillUniform(&input);
setRandomScale(&output);
nvte_quantize(input.data(), output.data(), 0);
float ref_amax;
compute_ref_q<InputType, OutputType>(input.rowwise_cpu_dptr<InputType>(), ref_output.get(),
N, &ref_amax, output.scale());
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
auto [atol, rtol] = getTolerances(otype);
compareResults("output_q", output, ref_output.get(), true, atol, rtol);
}
template <typename InputType, typename OutputType>
void performTestDQ(const size_t N) {
using namespace test;
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
Tensor input("input", { N }, itype);
Tensor output("output", { N }, otype);
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N);
fillUniform(&input);
nvte_dequantize(input.data(), output.data(), 0);
compute_ref_dq<InputType, OutputType>(input.rowwise_cpu_dptr<InputType>(), ref_output.get(),
N, input.rowwise_scale_inv());
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol, rtol] = getTolerances(otype);
compareResults("output_dq", output, ref_output.get(), true, atol, rtol);
}
std::vector<size_t> qdq_test_cases = {2048 * 12288,
768 * 1024,
256 * 65536,
65536 * 128,
257 * 259,
128 * 128 + 1};
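// Sizes such as 257 * 259 and 128 * 128 + 1 are likely here to exercise shapes
// that are not multiples of the kernels' vectorized access width.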
} //namespace
class QDQTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
transformer_engine::DType,
size_t>> {};
TEST_P(QDQTestSuite, TestQ) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const size_t N = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTestQ<InputType, OutputType>(N);
);
);
}
TEST_P(QDQTestSuite, TestDQ) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const size_t N = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTestDQ<OutputType, InputType>(N);
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
QDQTestSuite,
::testing::Combine(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(qdq_test_cases)),
[](const testing::TestParamInfo<QDQTestSuite::ParamType>& info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param)) + "X" +
std::to_string(std::get<2>(info.param));
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstdint>
#include <cstring>
#include <memory>
#include <iomanip>
#include <iostream>
#include <random>
#include <type_traits>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/swizzle.h>
#include "../test_common.h"
#include "transformer_engine/transformer_engine.h"
using namespace transformer_engine;
constexpr int MAT_TILE_DIM_M = 128;
constexpr int MAT_TILE_DIM_K = 128;
template <int SF_TILE_DIM_M, int SF_TILE_DIM_K, bool row_scaling>
void compute_ref_swizzle(const uint8_t *h_input, uint8_t *h_output,
const size_t M, const size_t K) {
constexpr int NEW_SF_TILE_DIM_M = SF_TILE_DIM_M / 4;
constexpr int NEW_SF_TILE_DIM_K = SF_TILE_DIM_K * 4;
constexpr int SF_TILE_SIZE = SF_TILE_DIM_M * SF_TILE_DIM_K;
for (int m = 0; m < M; m++) {
for (int k = 0; k < K; k++) {
int tile_id_m = m / SF_TILE_DIM_M;
int tile_id_k = k / SF_TILE_DIM_K;
int m_in_tile = m % SF_TILE_DIM_M;
int k_in_tile = k % SF_TILE_DIM_K;
int row_in_new_tile = m_in_tile % NEW_SF_TILE_DIM_M;
int col_in_new_tile = m_in_tile / NEW_SF_TILE_DIM_M * SF_TILE_DIM_K + k_in_tile;
int tile_output_ptr = tile_id_m * SF_TILE_DIM_M * K + tile_id_k * SF_TILE_SIZE;
int out_index = tile_output_ptr + row_in_new_tile * NEW_SF_TILE_DIM_K + col_in_new_tile;
if constexpr(row_scaling)
h_output[out_index] = h_input[k + m * K];
else
h_output[out_index] = h_input[k * M + m];
}
}
}
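// With SF_TILE_DIM_M = 128 and SF_TILE_DIM_K = 4, each 128x4 tile of scaling
// factors is repacked into a 32x16 tile: row m maps to row (m % 32), and its
// four columns shift right by (m / 32) * 4, the interleaved layout expected by
// the MXFP8 GEMM path.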
void performTestSwizzle1D(const int num_tiles_M, const int num_tiles_K, bool rowwise, bool columnwise, const bool transa) {
using namespace test;
int SF_MODE_X = 0, SF_MODE_Y = 0;
if (rowwise) {
SF_MODE_X = 1;
SF_MODE_Y = 32;
}
if (columnwise) {
SF_MODE_X = 32;
SF_MODE_Y = 1;
}
if ((rowwise && columnwise) || !(rowwise || columnwise)) {
GTEST_SKIP() << "TEST SKIPPED: the scaling mode " + std::to_string(SF_MODE_X) + "x" +
std::to_string(SF_MODE_Y) + " is not implemented.";
}
DType dtype = DType::kFloat8E4M3;
const size_t M = num_tiles_M * MAT_TILE_DIM_M;
const size_t K = num_tiles_K * MAT_TILE_DIM_K;
const auto data_shape = transa ? std::vector<size_t>{M, K} : std::vector<size_t>{K, M};
const auto scale_shape = std::vector<size_t>{data_shape[0] / SF_MODE_X, data_shape[1] / SF_MODE_Y};
std::vector<int> scaling_mode = {SF_MODE_X, SF_MODE_Y, 0};
Tensor input("input", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING);
Tensor output("output", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING);
fillUniform(&input);
std::unique_ptr<uint8_t[]> ref_output = std::make_unique<uint8_t[]>(scale_shape[0] * scale_shape[1]);
nvte_swizzle_scaling_factors(input.data(), output.data(), 0);
if (rowwise)
compute_ref_swizzle<128, 4, true>(input.rowwise_cpu_scale_inv_ptr<uint8_t>(), ref_output.get(), scale_shape[0], scale_shape[1]);
else
compute_ref_swizzle<128, 4, false>(input.columnwise_cpu_scale_inv_ptr<uint8_t>(), ref_output.get(), scale_shape[1], scale_shape[0]);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
output.to_cpu();
if (rowwise) {
compareResults("output_swizzle", output.rowwise_cpu_scale_inv_ptr<uint8_t>(), ref_output.get(), scale_shape[0] * scale_shape[1]);
} else {
compareResults("output_swizzle", output.columnwise_cpu_scale_inv_ptr<uint8_t>(), ref_output.get(), scale_shape[0] * scale_shape[1]);
}
}
class SwizzleTestSuite : public ::testing::TestWithParam<std::tuple<std::pair<int, int>, std::pair<bool, bool>, bool>> {};
TEST_P(SwizzleTestSuite, TestSwizzle) {
using namespace transformer_engine;
using namespace test;
const auto num_tiles = std::get<0>(GetParam());
const auto scaling_mode = std::get<1>(GetParam());
const auto transa = std::get<2>(GetParam());
performTestSwizzle1D(num_tiles.first, num_tiles.second,
scaling_mode.first, scaling_mode.second,
transa);
}
namespace {
std::vector<std::pair<int, int>> num_tiles = {
{1, 1},
{1, 132},
{132, 1},
{65, 256},
{65, 257},
{65, 258},
{65, 259},
};
std::vector<std::pair<bool, bool>> scaling_mode = {
{true, false},
{false, true}
};
std::vector<bool> transa = {true, false};
} // namespace
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
SwizzleTestSuite,
::testing::Combine(
::testing::ValuesIn(num_tiles),
::testing::ValuesIn(scaling_mode),
::testing::ValuesIn(transa)
),
[](const testing::TestParamInfo<SwizzleTestSuite::ParamType>& info) {
std::string name = "ntiles" +
std::to_string(std::get<0>(info.param).first) + "X" +
std::to_string(std::get<0>(info.param).second) + "smode" +
std::to_string(std::get<1>(info.param).first) + "X"+
std::to_string(std::get<1>(info.param).second) + "trans" +
std::to_string(std::get<2>(info.param));
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/transpose.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename Type>
void compute_ref(const Type *data, Type *output,
const size_t N, const size_t H) {
for (size_t i = 0; i < N; ++i) {
for (size_t j = 0; j < H; ++j) {
output[j * N + i] = data[i * H + j];
}
}
}
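// Row-major transpose: element (i, j) of the N x H input lands at (j, i) of
// the H x N output.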
template <typename Type>
void performTest(const size_t N, const size_t H) {
using namespace test;
DType dtype = TypeInfo<Type>::dtype;
Tensor input("input", { N, H }, dtype);
Tensor output("output", { H, N }, dtype);
std::unique_ptr<Type[]> ref_output = std::make_unique<Type[]>(N * H);
fillUniform(&input);
nvte_transpose(input.data(), output.data(), 0);
compute_ref<Type>(input.rowwise_cpu_dptr<Type>(), ref_output.get(), N, H);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol, rtol] = getTolerances(dtype);
compareResults("output", output, ref_output.get(), true, atol, rtol);
}
std::vector<std::pair<size_t, size_t>> test_cases = {{2048, 12288},
{768, 1024},
{256, 65536},
{65536, 128},
{256, 256},
{120, 2080},
{8, 8},
{1223, 1583}, // 200th and 250th primes
{1, 541}, // 100th prime
{1987, 1}}; // 300th prime
} // namespace
class TTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
std::pair<size_t, size_t>>> {};
TEST_P(TTestSuite, TestTranspose) {
using namespace transformer_engine;
using namespace test;
const DType type = std::get<0>(GetParam());
const auto size = std::get<1>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
performTest<T>(size.first, size.second);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
TTestSuite,
::testing::Combine(
::testing::ValuesIn(test::all_fp_types),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<TTestSuite::ParamType>& info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
std::to_string(std::get<1>(info.param).first) + "X" +
std::to_string(std::get<1>(info.param).second);
return name;
});