// Commit 5b6ef054 authored by yuguo
// Parents: 76060570 a7eeb28b
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "test_common.h"
#include <algorithm>
#include <memory>
#include <random>
#include <cassert>
#include <cmath>
#include <string>
#include <gtest/gtest.h>
#include <omp.h>
#include <transformer_engine/transformer_engine.h>
#include "util/logging.h"
namespace test {
// Derives a deterministic RNG seed from the current gtest test name combined
// with the tensor name, so every tensor in every test gets reproducible but
// distinct random data.
size_t create_seed_from_tensor_name(const std::string& tensor_name) {
  auto full_name = std::string(testing::UnitTest::GetInstance()->current_test_info()->name()) +
                   "/" + tensor_name;
  return std::hash<std::string>{}(full_name);
}
// All floating-point dtypes exercised by the tests (FP32/FP16/BF16 + FP8).
std::vector<DType> all_fp_types = {DType::kFloat32,
                                   DType::kFloat16,
                                   DType::kBFloat16,
                                   DType::kFloat8E5M2,
                                   DType::kFloat8E4M3};
// Two shapes are equal iff they have the same rank and identical extents.
bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2) {
  if (s1.ndim != s2.ndim) {
    return false;
  }
  return std::equal(s1.data, s1.data + s1.ndim, s2.data);
}
// Returns the size in bytes of a single element of the given dtype.
// Relies on the TYPE_SWITCH macro returning from every branch; there is no
// fall-through return for an unhandled dtype.
size_t typeToSize(DType type) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
  {
    return TypeInfo<T>::size;
  });
}
// Human-readable name for a dtype. Throws std::out_of_range (via .at) for a
// dtype missing from the map.
const std::string &typeName(DType type) {
  static const std::unordered_map<DType, std::string> name_map = {
    {DType::kByte, "byte"},
    {DType::kInt32, "int32"},
    {DType::kInt64, "int64"},
    {DType::kFloat32, "float32"},
    {DType::kFloat16, "float16"},
    {DType::kBFloat16, "bfloat16"},
    {DType::kFloat8E4M3, "float8e4m3"},
    {DType::kFloat8E5M2, "float8e5m2"},
    {DType::kFloat8E8M0, "float8e8m0"}};
  return name_map.at(type);
}
// Human-readable name for an input-fill case (used in test output). Throws
// std::out_of_range (via .at) for an unknown case.
const std::string& caseName(InputsFillCase type) {
  static const std::unordered_map<InputsFillCase, std::string> name_map = {
    {InputsFillCase::uniform, "uniform"},
    {InputsFillCase::zeros, "zeros"},
    {InputsFillCase::zero_to_minNorm, "zero_to_minNorm"},
    {InputsFillCase::minNorm_to_maxNorm, "minNorm_to_maxNorm"},
    {InputsFillCase::maxNorm_to_inf, "maxNorm_to_inf"}};
  return name_map.at(type);
}
// Multiplies the extents in [begin, end); an empty range yields 1.
size_t product(const NVTEShape &shape, size_t begin, size_t end) {
  NVTE_CHECK(end <= shape.ndim);
  size_t result = 1;
  for (size_t idx = begin; idx < end; ++idx) {
    result *= shape.data[idx];
  }
  return result;
}
// Total number of elements described by the shape.
size_t product(const NVTEShape &shape) {
  return product(shape, 0, shape.ndim);
}
// Multiplies shape[begin..end); an empty range yields 1.
// NOTE(review): the vector is taken by value, copying it on every call —
// presumably accidental. Switching to a const reference changes the mangled
// signature, so the matching declaration (in the header) must be updated at
// the same time; left as-is here.
size_t product(const std::vector<size_t> shape, size_t begin, size_t end) {
  size_t ret = 1;
  NVTE_CHECK(end <= shape.size());
  for (size_t i = begin; i < end; ++i) {
    ret *= shape[i];
  }
  return ret;
}
// Total number of elements described by the shape vector.
size_t product(const std::vector<size_t>& shape) {
  return product(shape, 0, shape.size());
}
// Integer ceiling division: the smallest integer >= x / y.
size_t DIVUP(const size_t &x, const size_t &y){
  return (x + y - 1) / y;
}
// Describes the layout of a scale-inverse tensor.
struct scale_inv_meta {
  std::vector<size_t> shape;  // dimensions of the scale tensor
  DType type;                 // element dtype (fp32 scalar or e8m0 per-block)
  size_t type_size;           // bytes per element
};
// Wraps a vector as an NVTEShape view. The result aliases the vector's
// storage, so the vector must outlive the returned shape.
NVTEShape convertShape(const std::vector<size_t>& shape) {
  return {shape.data(), shape.size()};
}
// Computes {rowwise, columnwise} scale-inverse layouts for a data shape under
// the given scaling mode:
//  - Delayed tensor scaling: a single fp32 scalar, shared by both directions.
//  - MXFP8 1D scaling: one e8m0 scale per 32-element block, with the scale
//    grid padded up to [128, 4] (rowwise) / [4, 128] (colwise) multiples.
// Any other mode is a hard error.
std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
                                                     const NVTEScalingMode scaling_mode) {
  if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
    scale_inv_meta ret;
    ret.shape = {1};
    ret.type = DType::kFloat32;
    ret.type_size = sizeof(float);
    return {ret, ret};
  }
  if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
    std::vector<size_t> shape_vec;
    for (size_t i = 0; i < shape.ndim; ++i) {
      shape_vec.push_back(shape.data[i]);
    }
    // Treat the tensor as 2D: [product of leading dims, last dim].
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);
    scale_inv_meta ret_rowwise, ret_colwise;
    auto block_alignment = std::vector<size_t>{128ul,4ul};
    {
      // Rowwise: 1x32 blocks, padded to multiples of [128, 4].
      auto alignment = block_alignment[0];
      auto scale_dim_0 = DIVUP(DIVUP(first_dim,
                                     static_cast<size_t>(1)),
                               alignment) * alignment;
      alignment = block_alignment[1];
      auto scale_dim_1 = DIVUP(DIVUP(last_dim,
                                     static_cast<size_t>(32)),
                               alignment) * alignment;
      ret_rowwise.shape = {scale_dim_0, scale_dim_1};
    }
    {
      // Columnwise: 32x1 blocks, padded to multiples of [4, 128].
      auto alignment = block_alignment[1];
      auto scale_dim_0 = DIVUP(DIVUP(first_dim,
                                     static_cast<size_t>(32)),
                               alignment) * alignment;
      alignment = block_alignment[0];
      auto scale_dim_1 = DIVUP(DIVUP(last_dim,
                                     static_cast<size_t>(1)),
                               alignment) * alignment;
      ret_colwise.shape = {scale_dim_0, scale_dim_1};
    }
    // MX scales are stored as raw 8-bit E8M0 exponents.
    ret_rowwise.type = DType::kFloat8E8M0;
    ret_colwise.type = DType::kFloat8E8M0;
    ret_rowwise.type_size = sizeof(uint8_t);
    ret_colwise.type_size = sizeof(uint8_t);
    return {ret_rowwise, ret_colwise};
  }
  NVTE_ERROR("Invalid scaling mode!");
}
// Allocates zero-initialized device buffers and mirroring host buffers for a
// test tensor: rowwise/columnwise data, and (for FP8 dtypes) amax, scale and
// scale-inverse metadata appropriate for the scaling mode.
Tensor::Tensor(const std::string& name,
               const NVTEShape &shape, const DType type,
               const bool rowwise, const bool columnwise,
               const NVTEScalingMode &scaling_mode) {
  name_ = name;
  // Deterministic per-tensor RNG seed derived from the test + tensor name.
  const size_t seed = create_seed_from_tensor_name(name);
  gen_.seed(seed);
  rowwise_ = rowwise;
  columnwise_ = columnwise;
  size_t s = typeToSize(type);
  size_t total_size = product(shape) * s;
  void *dptr_rowwise = nullptr;
  void *dptr_columnwise = nullptr;
  cpu_data_rowwise_ = nullptr;
  cpu_data_columnwise_ = nullptr;
  amax_cpu_data_ = nullptr;
  scale_cpu_data_ = nullptr;
  rowwise_scale_inv_cpu_data_ = nullptr;
  columnwise_scale_inv_cpu_data_ = nullptr;
  float *amax = nullptr, *scale = nullptr;
  float *rowwise_scale_inv = nullptr, *columnwise_scale_inv = nullptr;
  if (columnwise) {
    NVTE_CHECK(shape.ndim >= 2);
  }
  // Collapse to 2D: [product of leading dims, last dim].
  std::vector<size_t> normalized_shape_v = {product(shape, 0, shape.ndim - 1),
                                            shape.data[shape.ndim - 1]};
  NVTEShape normalized_shape = convertShape(normalized_shape_v);
  NVTEShape columnwise_shape{nullptr, 0};
  std::vector<size_t> columnwise_shape_vec;
  if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
    // Transpose when tensor scaling
    columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]);
    for (size_t i = 0; i < shape.ndim - 1; ++i) {
      columnwise_shape_vec.emplace_back(shape.data[i]);
    }
  } else {
    // Same shape for MX
    for (size_t i = 0; i < shape.ndim; ++i) {
      columnwise_shape_vec.emplace_back(shape.data[i]);
    }
  }
  if (columnwise) {
    // columnwise_shape aliases columnwise_shape_vec, which lives until the
    // end of this constructor (TensorWrapper copies what it needs).
    columnwise_shape.data = columnwise_shape_vec.data();
    columnwise_shape.ndim = columnwise_shape_vec.size();
  }
  tensor_ = TensorWrapper(scaling_mode);
  if (total_size != 0) {
    if (rowwise) {
      cudaMalloc((void**)&dptr_rowwise, total_size); // NOLINT(*)
      cudaMemset(dptr_rowwise, 0, total_size);
      cpu_data_rowwise_ = std::make_unique<unsigned char[]>(total_size);
      std::fill_n(cpu_data_rowwise_.get(), total_size, 0);
    }
    if (columnwise) {
      cudaMalloc((void**)&dptr_columnwise, total_size); // NOLINT(*)
      cudaMemset(dptr_columnwise, 0, total_size);
      cpu_data_columnwise_ = std::make_unique<unsigned char[]>(total_size);
      std::fill_n(cpu_data_columnwise_.get(), total_size, 0);
    }
  }
  tensor_.set_rowwise_data(dptr_rowwise, type, shape);
  tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape);
  if (isFp8Type(type)) {
    if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
      // Delayed scaling: one fp32 amax/scale/scale_inv scalar each.
      cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*)
      cudaMemset(amax, 0, sizeof(float));
      cudaMalloc((void**)&scale, sizeof(float)); // NOLINT(*)
      cudaMemset(scale, 0, sizeof(float));
      amax_cpu_data_ = std::make_shared<float>(0);
      scale_cpu_data_ = std::make_shared<float>(0);
      tensor_.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
      tensor_.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
      // NOTE(review): unlike amax/scale, this buffer is not cudaMemset to 0
      // before use — confirm callers always write it before reading.
      cudaMalloc((void**)&rowwise_scale_inv, sizeof(float)); // NOLINT(*)
      if (rowwise) {
        tensor_.set_rowwise_scale_inv(rowwise_scale_inv, DType::kFloat32,
                                      std::vector<size_t>{1});
        rowwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(sizeof(float));
        std::fill_n(rowwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
      }
      if (columnwise) {
        // Deliberately shares the single rowwise_scale_inv device buffer for
        // both directions; the destructor detects this aliasing and frees it
        // only once.
        tensor_.set_columnwise_scale_inv(rowwise_scale_inv, DType::kFloat32,
                                         std::vector<size_t>{1});
        columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(sizeof(float));
        std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
      }
    } else {
      // MX scaling: per-block e8m0 scale grids sized by get_scales().
      auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape,
                                                                 tensor_.scaling_mode());
      auto rowwise_scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
      auto columnwise_scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size;
      auto scale_shape = rowwise_scale_meta.shape;
      auto columnwise_scale_shape = colwise_scale_meta.shape;
      if (rowwise) {
        cudaMalloc((void**)&rowwise_scale_inv, rowwise_scale_size); // NOLINT(*)
        cudaMemset(rowwise_scale_inv, 0, rowwise_scale_size);
        rowwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(rowwise_scale_size);
        std::fill_n(rowwise_scale_inv_cpu_data_.get(), rowwise_scale_size, 0);
        tensor_.set_rowwise_scale_inv(rowwise_scale_inv, DType::kFloat8E8M0, scale_shape);
      }
      if (columnwise) {
        cudaMalloc((void**)&columnwise_scale_inv, columnwise_scale_size); // NOLINT(*)
        cudaMemset(columnwise_scale_inv, 0, columnwise_scale_size);
        columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(columnwise_scale_size);
        std::fill_n(columnwise_scale_inv_cpu_data_.get(), columnwise_scale_size, 0);
        tensor_.set_columnwise_scale_inv(columnwise_scale_inv, DType::kFloat8E8M0, columnwise_scale_shape);
      }
    }
  }
}
// Copies device-side data (and FP8 metadata when applicable) into the host
// mirrors. Uses the rowwise element count for both copies — for delayed
// scaling the columnwise buffer is a transpose with the same element count.
void Tensor::to_cpu() const {
  const NVTEShape s = tensor_.shape();
  const size_t size = product(s) * typeToSize(tensor_.dtype());
  if (rowwise_) {
    cudaMemcpy(cpu_data_rowwise_.get(),
               tensor_.get_rowwise_data().data_ptr,
               size,
               cudaMemcpyDeviceToHost);
  }
  if (columnwise_) {
    cudaMemcpy(cpu_data_columnwise_.get(),
               tensor_.get_columnwise_data().data_ptr,
               size,
               cudaMemcpyDeviceToHost);
  }
  if (isFp8Type(dtype())) {
    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
      // amax may have been nulled via set_tensor_amax_nullptr().
      if (tensor_.amax() != nullptr){
        cudaMemcpy(amax_cpu_data_.get(),
                   tensor_.amax(),
                   sizeof(float),
                   cudaMemcpyDeviceToHost);
      }
      cudaMemcpy(scale_cpu_data_.get(),
                 tensor_.scale(),
                 sizeof(float),
                 cudaMemcpyDeviceToHost);
    }
    // Scale-inverse buffers are copied for every FP8 scaling mode.
    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
    if (rowwise_) {
      auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
      cudaMemcpy(rowwise_scale_inv_cpu_data_.get(),
                 tensor_.get_rowwise_scale_inv().data_ptr,
                 scale_size,
                 cudaMemcpyDeviceToHost);
    }
    if (columnwise_) {
      auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size;
      cudaMemcpy(columnwise_scale_inv_cpu_data_.get(),
                 tensor_.get_columnwise_scale_inv().data_ptr,
                 scale_size,
                 cudaMemcpyDeviceToHost);
    }
  }
}
// Uploads the host mirrors (data plus FP8 metadata) back to the device;
// the exact inverse of to_cpu().
void Tensor::from_cpu() const {
  const NVTEShape s = tensor_.shape();
  const size_t size = product(s) * typeToSize(tensor_.dtype());
  if (rowwise_) {
    cudaMemcpy(tensor_.get_rowwise_data().data_ptr,
               cpu_data_rowwise_.get(), size, cudaMemcpyHostToDevice);
  }
  if (columnwise_) {
    cudaMemcpy(tensor_.get_columnwise_data().data_ptr,
               cpu_data_columnwise_.get(), size, cudaMemcpyHostToDevice);
  }
  if (isFp8Type(dtype())) {
    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
      // amax may have been nulled via set_tensor_amax_nullptr().
      if (tensor_.amax() != nullptr){
        cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float),
                   cudaMemcpyHostToDevice);
      }
      cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float),
                 cudaMemcpyHostToDevice);
    }
    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
    if (rowwise_) {
      auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
      cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr,
                 rowwise_scale_inv_cpu_data_.get(), scale_size,
                 cudaMemcpyHostToDevice);
    }
    if (columnwise_) {
      auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size;
      cudaMemcpy(tensor_.get_columnwise_scale_inv().data_ptr,
                 columnwise_scale_inv_cpu_data_.get(), scale_size,
                 cudaMemcpyHostToDevice);
    }
  }
}
// Sets the FP8 scale on the host mirror and uploads it to the device.
// No-op for non-FP8 dtypes and for scaling modes other than delayed scaling.
void Tensor::set_scale(float scale) {
  if (!isFp8Type(dtype())) {
    return;
  }
  NVTE_CHECK(scale_cpu_data_);
  if (tensor_.scaling_mode() != NVTE_DELAYED_TENSOR_SCALING) {
    return;
  }
  *scale_cpu_data_ = scale;
  from_cpu();
}
// Sets the scale-inverse metadata and uploads it to the device.
// For delayed scaling (a single scalar) the given value is used as-is; for
// block-scaled modes the argument is ignored and each e8m0 scale byte is
// filled with a random biased exponent in [0, 127].
void Tensor::set_scale_inv(float scale_inv) {
  if (isFp8Type(dtype())) {
    if (rowwise_) {
      NVTE_CHECK(rowwise_scale_inv_cpu_data_);
    }
    if (columnwise_) {
      NVTE_CHECK(columnwise_scale_inv_cpu_data_);
    }
    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode());
    if (rowwise_) {
      auto num_scales = product(rowwise_scale_meta.shape);
      if (num_scales == 1){
        // Single-scalar case (delayed scaling): store the requested value.
        rowwise_cpu_scale_inv_ptr<float>()[0] = scale_inv;
      } else{
        // Block-scaled case: randomize each e8m0 scale byte.
        std::uniform_int_distribution<uint8_t> dis(0, 127);
        auto* scale_inv_ptr = rowwise_cpu_scale_inv_ptr<uint8_t>();
        for (size_t i = 0; i < num_scales; i++){
          scale_inv_ptr[i] = dis(gen_);
        }
      }
    }
    if (columnwise_) {
      auto num_scales = product(colwise_scale_meta.shape);
      if (num_scales == 1){
        columnwise_cpu_scale_inv_ptr<float>()[0] = scale_inv;
      } else{
        std::uniform_int_distribution<uint8_t> dis(0, 127);
        auto* scale_inv_ptr = columnwise_cpu_scale_inv_ptr<uint8_t>();
        for (size_t i = 0; i < num_scales; i++){
          scale_inv_ptr[i] = dis(gen_);
        }
      }
    }
    from_cpu();
  }
}
// Re-points this tensor's FP8 metadata (amax, scale, scale-inverse) at the
// buffers owned by `other`, keeping this tensor's own data buffers. Used to
// test kernels that consume shared quantization metadata.
// NOTE(review): after this call both tensors reference the same device
// metadata pointers; verify destruction order / dtor aliasing assumptions
// for the shared buffers.
void Tensor::shareFP8Meta(const Tensor &other) {
  if(isFp8Type(dtype()) && isFp8Type(other.dtype())) {
    auto new_tensor = TensorWrapper(other.tensor_.scaling_mode());
    // Keep our own data buffers...
    auto my_rowwise_data = tensor_.get_rowwise_data();
    new_tensor.set_rowwise_data(my_rowwise_data.data_ptr,
                                static_cast<DType>(my_rowwise_data.dtype),
                                my_rowwise_data.shape);
    auto my_columnwise_data = tensor_.get_columnwise_data();
    new_tensor.set_columnwise_data(my_columnwise_data.data_ptr,
                                   static_cast<DType>(my_columnwise_data.dtype),
                                   my_columnwise_data.shape);
    // ...but adopt the other tensor's metadata buffers.
    auto other_amax = other.tensor_.get_amax();
    new_tensor.set_amax(other_amax.data_ptr,
                        static_cast<DType>(other_amax.dtype),
                        other_amax.shape);
    auto other_scale = other.tensor_.get_scale();
    new_tensor.set_scale(other_scale.data_ptr,
                         static_cast<DType>(other_scale.dtype),
                         other_scale.shape);
    auto other_row_scale_inv = other.tensor_.get_rowwise_scale_inv();
    new_tensor.set_rowwise_scale_inv(other_row_scale_inv.data_ptr,
                                     static_cast<DType>(other_row_scale_inv.dtype),
                                     other_row_scale_inv.shape);
    auto other_col_scale_inv = other.tensor_.get_columnwise_scale_inv();
    new_tensor.set_columnwise_scale_inv(other_col_scale_inv.data_ptr,
                                        static_cast<DType>(other_col_scale_inv.dtype),
                                        other_col_scale_inv.shape);
    tensor_ = std::move(new_tensor);
    // Refresh host mirrors from the (now shared) device metadata.
    to_cpu();
  }
}
using std::to_string;

/**
 * @brief Formats a vector as "[a, b, c]".
 *
 * Fix: the previous version unconditionally called pop_back() twice to strip
 * the trailing ", ", which is undefined behavior for an empty vector (the
 * second pop_back() runs on an empty string). An empty vector now yields "[]".
 */
template <typename T>
std::string to_string(const std::vector<T> &v) {
  std::string s = "[";
  for (const auto& x : v) {
    s += to_string(x) + ", ";
  }
  if (s.size() > 1) {
    // Drop the trailing ", " left by the loop.
    s.pop_back();
    s.pop_back();
  }
  return s + "]";
}
/**
 * @brief Converts a flat index into per-axis coordinates for the given shape
 *        (row-major order).
 *
 * Fix: for a 0-dim shape the old code computed `shape.ndim - 1` on a size_t,
 * underflowing to SIZE_MAX and walking garbage indices; a scalar shape now
 * returns an empty coordinate vector.
 */
std::vector<size_t> unravel(const size_t i, const NVTEShape &shape) {
  if (shape.ndim == 0) return {};
  std::vector<size_t> ret;
  ret.reserve(shape.ndim);
  size_t remaining = i;
  // Peel off coordinates from the innermost (fastest-varying) axis outward.
  for (size_t axis = shape.ndim - 1; axis > 0; --axis) {
    ret.push_back(remaining % shape.data[axis]);
    remaining /= shape.data[axis];
  }
  ret.push_back(remaining);  // outermost axis keeps the remainder
  std::reverse(ret.begin(), ret.end());
  return ret;
}
// Element-by-element comparison of a Tensor against reference data with
// absolute + relative tolerances. For non-FP32 dtypes a mismatch is forgiven
// when it can be explained by round-to-nearest landing on the other side of
// the midpoint between the two values.
void compareResults_sequential(const std::string &name, const Tensor &test,
                               const void *ref, const bool rowwise,
                               double atol, double rtol, bool if_on_gpus) {
  if (if_on_gpus) test.to_cpu();
  const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape();
  const size_t N = product(shape);
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T,
    const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
    const T *ref_data = reinterpret_cast<const T*>(ref);
    for (size_t i = 0; i < N; ++i) {
      double t = static_cast<double>(test_data[i]);
      double r = static_cast<double>(ref_data[i]);
      // Mismatch when both the absolute and (where defined) relative error
      // exceed their tolerances.
      bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
      /* For Float32 the floating point comparison is enough to error out */
      bool assertion = mismatch && test.dtype() == DType::kFloat32;
      if (mismatch && !assertion) {
        /* Check if it is just a failure of round to nearest choosing different
           side of the real value */
        const double mean = (t + r) / 2;
        const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
        const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
        const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
        const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
        assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
      }
      std::string direction = rowwise ? "rowwise" : "columnwise";
      ASSERT_FALSE(assertion) << "Error in tensor " << name << " in "
                              << direction << " direction." << std::endl
                              << "Mismatch at place " << to_string(unravel(i, shape))
                              << " (" << std::to_string(i) << "): " << t << " vs " << r;
    }
  );
}
/**
 * @brief Returns the index of the first element of test/ref that differs
 *        beyond the tolerances, or N when every element matches.
 *
 * OpenMP-parallel; each thread stops scanning its chunk once it has recorded
 * a mismatch, and the min-reduction picks the earliest index across threads.
 * Non-FP32 mismatches are forgiven when explainable by round-to-nearest
 * landing on the other side of the midpoint of the two values.
 *
 * Fix: the sentinel/result was previously an `int`, which truncates N (and
 * makes `i < first_mismatch_idx` a mixed-sign comparison) for element counts
 * above INT_MAX; it is now size_t end to end.
 */
template <typename T>
static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, const T* ref_data,
                                  const size_t N, const double atol, const double rtol) {
  size_t first_mismatch_idx = N;
  bool is_mismatch_found = false;
#pragma omp parallel for schedule(static) firstprivate(is_mismatch_found) \
    reduction(min: first_mismatch_idx) proc_bind(spread)
  for (size_t i = 0; i < N; ++i) {
    if (is_mismatch_found) {  // early escape of the omp thread
      continue;
    }
    double t = static_cast<double>(test_data[i]);
    double r = static_cast<double>(ref_data[i]);
    // Mismatch when both the absolute and (where defined) relative error
    // exceed their tolerances.
    bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
    /* For Float32 the floating point comparison is enough to error out */
    bool assertion = mismatch && (data_type == DType::kFloat32);
    if (mismatch && !assertion) {
      /* Check if it is just a failure of round to nearest choosing different
         side of the real value */
      const double mean = (t + r) / 2;
      const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
      const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
      const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
      const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
      assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
    }
    if (assertion && i < first_mismatch_idx) {
      first_mismatch_idx = i;
      is_mismatch_found = true;
    }
  }
  return first_mismatch_idx;
}
// Parallel variant of compareResults_sequential: finds the first mismatching
// element with getFirstMismatchIdx and fails the test if one exists.
void compareResults_parallel(const std::string &name, const Tensor &test, const void *ref,
                             const bool rowwise, double atol, double rtol, bool if_on_gpus) {
  if (if_on_gpus) test.to_cpu();
  const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape();
  const size_t N = product(shape);
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T,
    const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
    const T *ref_data = reinterpret_cast<const T*>(ref);
    // i == N means "no mismatch".
    const size_t i = getFirstMismatchIdx<T>(test.dtype(), test_data, ref_data, N, atol, rtol);
    if (i != N) {
      const double t = static_cast<double>(test_data[i]);
      const double r = static_cast<double>(ref_data[i]);
      std::string direction = rowwise ? "rowwise" : "columnwise";
      ASSERT_FALSE(true) << "Error in tensor " << name << " in "
                         << direction << " direction." << std::endl
                         << "Mismatch at place " << to_string(unravel(i, shape))
                         << " (" << std::to_string(i) << "): " << t << " vs " << r;
    }
  );
}
// Tensor-vs-reference comparison entry point. Dispatches to the parallel
// implementation; flip the constant below to debug with the sequential one.
void compareResults(const std::string &name, const Tensor &test, const void *ref,
                    const bool rowwise, double atol, double rtol, bool if_on_gpus) {
  constexpr bool use_parallel_impl = true;
  if constexpr (use_parallel_impl) {
    compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus);
  } else {
    compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus);
  }
}
// Scalar comparison: passes when either the absolute error is within atol or
// the relative error (for nonzero reference) is within rtol.
void compareResults(const std::string &name, const float test, const float ref,
                    double atol, double rtol) {
  const double t = static_cast<double>(test);
  const double r = static_cast<double>(ref);
  const bool within_abs = fabs(t - r) <= atol;
  const bool within_rel = (r != 0) && (fabs((t - r) / r) <= rtol);
  const bool mismatch = !within_abs && !within_rel;
  ASSERT_FALSE(mismatch) << "Error in " << name << std::endl
                         << "Mismatch: " << t << " vs " << r;
}
/**
 * @brief Byte-exact comparison that tolerates up to ceil(N * mismatch_rate_tol)
 *        differing elements before failing the test.
 *
 * Fix: the failure report previously printed test[i]/ref[i] — the current
 * loop element — for every recorded position instead of the values at the
 * stored mismatch indices. The loop counter is also size_t now, matching N.
 */
void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref,
                    size_t N, float mismatch_rate_tol) {
  const size_t max_mismatches = std::ceil(N * mismatch_rate_tol);
  size_t n_mismatches = 0;
  std::vector<size_t> mismatch_indices;
  for (size_t i = 0; i < N; i++){
    bool mismatch = test[i] != ref[i];
    if (mismatch){
      n_mismatches++;
      mismatch_indices.push_back(i);
    }
    if (n_mismatches > max_mismatches){
      std::cout << "Error in " << name << std::endl;
      // Report the values at each recorded mismatch position.
      for (const auto &index : mismatch_indices)
        std::cout << "Mismatch at (" << index << "):" << static_cast<int>(test[index]) << " vs "
                  << static_cast<int>(ref[index]) << std::endl;
      GTEST_FAIL() << n_mismatches << " mismatche(s) which is more than mismatch tol.";
    }
  }
}
/**
 * @brief Exact comparison of a 2D grid of e8m0 scaling factors, where each row
 *        of `col_blocks` valid entries is laid out with the given row stride.
 *
 * Fix: loop counters and the flattened index were `int`, mixing signedness
 * with the size_t bounds and overflowing `i * stride + j` for large grids;
 * all are size_t now.
 */
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
                                  const size_t row_blocks, const size_t col_blocks, const size_t stride)
{
  for (size_t i = 0; i < row_blocks; ++i) {
    for (size_t j = 0; j < col_blocks; ++j) {
      const size_t idx = i * stride + j;
      ASSERT_FALSE(test[idx] != ref[idx]) << "Error in " << name << std::endl
                                          << "Mismatch: " << static_cast<int>(test[idx]) << " vs "
                                          << static_cast<int>(ref[idx]) << " at index " << idx;
    }
  }
}
/**
 * @brief Exact comparison of a flat array of N e8m0 scaling factors.
 *
 * Fix: the loop counter was `int`, a mixed-sign comparison against the size_t
 * bound N that breaks for N > INT_MAX; it is size_t now.
 */
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
                                  const size_t N)
{
  for (size_t i = 0; i < N; i++) {
    ASSERT_FALSE(test[i] != ref[i]) << "Error in " << name << std::endl
                                    << "Mismatch: " << static_cast<int>(test[i]) << " vs "
                                    << static_cast<int>(ref[i]) << " at index " << i;
  }
}
std::pair<double, double> getTolerances(const DType type) {
switch(type) {
case DType::kFloat32:
return {1e-6, 5e-6};
case DType::kFloat16:
return {1e-5, 1e-3};
case DType::kBFloat16:
return {1e-5, 1e-2};
case DType::kFloat8E4M3:
case DType::kFloat8E5M2:
case DType::kFloat8E8M0:
return {1e-2, 1e-2};
default:
NVTE_CHECK("Invalid type!");
}
return {0, 0};
}
// Fills `data` with uniform random values in [-2, 1) using per-thread copies
// of the shared generator, then advances the shared generator by `size` draws
// so subsequent use does not repeat the sequence.
// NOTE(review): each thread offsets its substream by thread_num * 599 draws;
// threads producing more than 599 values will overlap substreams — probably
// acceptable for test data, but confirm this is intended.
template <typename T>
void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
#pragma omp parallel proc_bind(spread)
  {
    std::mt19937 gen_local = *gen;
    gen_local.discard(omp_get_thread_num() * 599);
    std::uniform_real_distribution<> dis(-2.0, 1.0);
#pragma omp for schedule(static)
    for (size_t i = 0; i < size; ++i) {
      data[i] = static_cast<T>(dis(gen_local));
    }
  }
  gen->discard(size);
}
// Fills the tensor's host buffer with uniform random data, randomizes the
// scale-inverse, and uploads everything to the device.
// NOTE(review): when a tensor is both rowwise and columnwise, only the
// rowwise host buffer is filled here (the else branch is skipped) and
// from_cpu() uploads the still-zeroed columnwise host buffer — confirm
// intended.
void fillUniform(Tensor *t) {
  if (t->rowwise()) {
    const size_t size = product(t->rowwise_shape());
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T,
      {
        T *data = t->rowwise_cpu_dptr<T>();
        generate_data_uniformly(data, size, &(t->gen()));
      }
    );
  } else {
    const size_t size = product(t->columnwise_shape());
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T,
      {
        T *data = t->columnwise_cpu_dptr<T>();
        generate_data_uniformly(data, size, &(t->gen()));
      }
    );
  }
  std::uniform_real_distribution<> dis(-2.0, 1.0);
  t->set_scale_inv(dis(t->gen()));
  t->from_cpu();
}
// Fills the rowwise host buffer according to the compile-time fill case:
// all zeros, a plain uniform range, or magnitudes drawn from one of the
// encoding-specific ranges (with random sign), then resets scale_inv to 1
// and uploads to the device. Assumes a 2D rowwise shape (rows x cols).
template<typename InputEncoding, InputsFillCase Case>
void fillCase_special(Tensor *t) {
  const size_t size = product(t->rowwise_shape());
  const size_t rows = t->rowwise_shape().data[0];
  const size_t cols = t->rowwise_shape().data[1];
  if constexpr (Case == InputsFillCase::zeros) {
    TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {
      InputType *data = t->rowwise_cpu_dptr<InputType>();
      for (size_t i = 0; i < size; ++i) {
        data[i] = static_cast<InputType>(0);
      }
    });
  } else {
    // Default (uniform) range; the special cases pick a magnitude band from
    // the encoding's Quantized_Limits table.
    double minAbs = -2.0;
    double maxAbs = 1.0;
    if constexpr (Case != InputsFillCase::uniform) {
      minAbs = Quantized_Limits<InputEncoding>::ranges[Case];
      maxAbs = Quantized_Limits<InputEncoding>::ranges[Case + 1];
    }
    std::uniform_real_distribution<> dis(minAbs, maxAbs);
    std::uniform_real_distribution<> dis_sign(-1.0, 1.0);
    TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {
      InputType *data = t->rowwise_cpu_dptr<InputType>();
      for (size_t i = 0; i < rows; ++i) {
        for (size_t j = 0; j < cols; ++j) {
          const size_t idx = i * cols + j;
          const bool is_negative = (dis_sign(t->gen()) < 0.0);
          double val = dis(t->gen());
          if (is_negative) {
            val = -val;
          }
          data[idx] = static_cast<InputType>(val);
        }
      }
    });
  }
  t->set_scale_inv(1.0);
  t->from_cpu();
}
// Dispatches the runtime fill_case value to the matching compile-time
// fillCase_special instantiation.
template <typename InputEncoding>
void fillCase(Tensor *t, const InputsFillCase fill_case) {
  switch (fill_case) {
    case InputsFillCase::uniform:
      fillCase_special<InputEncoding, InputsFillCase::uniform>(t); break;
    case InputsFillCase::zeros:
      fillCase_special<InputEncoding, InputsFillCase::zeros>(t); break;
    case InputsFillCase::zero_to_minNorm:
      fillCase_special<InputEncoding, InputsFillCase::zero_to_minNorm>(t); break;
    case InputsFillCase::minNorm_to_maxNorm:
      fillCase_special<InputEncoding, InputsFillCase::minNorm_to_maxNorm>(t); break;
    case InputsFillCase::maxNorm_to_inf:
      fillCase_special<InputEncoding, InputsFillCase::maxNorm_to_inf>(t); break;
  }
}

// Explicit instantiations for the encodings exercised by the tests.
template void fillCase<fp8e4m3>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp8e5m2>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp32>(Tensor *t, const InputsFillCase fill_case);
// Assigns the tensor a random FP8 scale drawn from [-2, 1).
void setRandomScale(Tensor *t) {
  std::uniform_real_distribution<> scale_dist(-2.0, 1.0);
  const float random_scale = static_cast<float>(scale_dist(t->gen()));
  t->set_scale(random_scale);
}
// Assigns the tensor a random scale-inverse drawn from [-2, 1).
void setRandomScaleInv(Tensor *t) {
  std::uniform_real_distribution<> scale_inv_dist(-2.0, 1.0);
  const float random_scale_inv = static_cast<float>(scale_inv_dist(t->gen()));
  t->set_scale_inv(random_scale_inv);
}
// True for any of the 8-bit floating-point dtypes.
bool isFp8Type(DType type) {
  switch (type) {
    case DType::kFloat8E4M3:
    case DType::kFloat8E5M2:
    case DType::kFloat8E8M0:
      return true;
    default:
      return false;
  }
}
// Returns the compute capability of CUDA device 0 encoded as major*10+minor
// (e.g. 90 for sm_90).
int32_t getDeviceComputeCapability()
{
  cudaDeviceProp deviceProp;
  cudaGetDeviceProperties(&deviceProp, 0);
  return 10 * deviceProp.major + deviceProp.minor;
}
// Product of all dimensions except the last; 1 for shapes of rank 0 or 1.
size_t first_dimension(const std::vector<size_t> &shape) {
  if (shape.size() <= 1) {
    return 1;
  }
  size_t leading = 1;
  for (size_t i = 0; i + 1 < shape.size(); ++i) {
    leading *= shape[i];
  }
  return leading;
}
// Innermost dimension of the shape; 1 for an empty shape.
size_t last_dimension(const std::vector<size_t> &shape) {
  return shape.empty() ? 1 : shape.back();
}
// Computes the scale-tensor block grid for a rows x cols tensor partitioned
// into block_size_rows x block_size_cols blocks. Returns
// {unpadded_Y, unpadded_X, padded_Y, padded_X}, where the padded extents are
// rounded up to the direction-specific alignment multiples.
std::array<size_t, 4> get_scale_tensor_dims(const size_t rows,
                                            const size_t cols,
                                            const size_t block_size_rows,
                                            const size_t block_size_cols) {
  // A 1x32 block shape corresponds to rowwise scaling; anything else colwise.
  const bool rowwise = (block_size_rows == 1) && (block_size_cols == 32);
  size_t align_Y, align_X;
  if (rowwise) {
    align_Y = scale_tensor_alignment_Y_rowwise;
    align_X = scale_tensor_alignment_X_rowwise;
  } else {
    align_Y = scale_tensor_alignment_Y_colwise;
    align_X = scale_tensor_alignment_X_colwise;
  }
  const size_t unpadded_blocks_Y = divide_round_up(rows, block_size_rows);
  const size_t unpadded_blocks_X = divide_round_up(cols, block_size_cols);
  const size_t blocks_Y = round_up_to_nearest_multiple(unpadded_blocks_Y, align_Y);
  const size_t blocks_X = round_up_to_nearest_multiple(unpadded_blocks_X, align_X);
  return {unpadded_blocks_Y, unpadded_blocks_X, blocks_Y, blocks_X};
}
} // namespace test
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#pragma once
#include <memory>
#include <vector>
#include <array>
#include <random>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime_api.h>
#include <transformer_engine/transformer_engine.h>
#include "util/logging.h"
namespace test {
using namespace transformer_engine;
// Maps a byte count (1/2/4/8) to the unsigned integer type of that width;
// unsupported widths leave the primary template without a ::Type member.
template <size_t i>
struct BytesToType {};

template <>
struct BytesToType<1> {
  using Type = uint8_t;
};

template <>
struct BytesToType<2> {
  using Type = uint16_t;
};

template <>
struct BytesToType<4> {
  using Type = uint32_t;
};

template <>
struct BytesToType<8> {
  using Type = uint64_t;
};
// Short element-type aliases matching the DType enumerators.
using byte = uint8_t;
using int32 = int32_t;
using int64 = int64_t;
using fp32 = float;
using fp16 = half;
using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
using fp8e8m0 = uint8_t;  // e8m0 scales are stored as raw 8-bit exponents
// Compile-time mapping from a C++ element type to its DType enumerator and
// byte size, implemented by linearly scanning the `types` tuple.
template <typename T>
struct TypeInfo{
  // Order must mirror the DType enum, starting at kByte.
  using types = std::tuple<byte,
                           int32,
                           int64,
                           fp32,
                           fp16,
                           bf16,
                           fp8e4m3,
                           fp8e5m2,
                           fp8e8m0>;

  // Recursive scan: compares U against the tuple element at `current` and
  // advances to the next enumerator on mismatch.
  template <typename U, DType current>
  struct Helper {
    constexpr static DType getType() {
      constexpr int i = static_cast<int>(current);
      if (std::is_same<U, typename std::tuple_element<i, types>::type>::value) {
        return current;
      } else {
        return Helper<U, static_cast<DType>(i + 1)>::getType();
      }
    }
  };

  // Recursion terminator: U not found -> kNumTypes.
  template <typename U>
  struct Helper<U, DType::kNumTypes> {
    constexpr static DType getType() {
      return DType::kNumTypes;
    }
  };

  template <typename U>
  constexpr static DType getType() {
    return Helper<U, DType::kByte>::getType();
  }

  constexpr static DType dtype = getType<T>();
  constexpr static size_t size = sizeof(T);
};
// Test-side wrapper around TensorWrapper. Owns the device buffers (data,
// scale-inverse) plus host mirrors for data, amax, scale and scale-inverse,
// with helpers to shuttle values between CPU and GPU.
class Tensor {
 public:
  Tensor(const std::string& name,
         const NVTEShape &shape, const DType type,
         const bool rowwise = true,
         const bool columnwise = false,
         const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING);

  // Convenience overload: wraps the vector as an NVTEShape view. The vector
  // parameter outlives the delegated constructor call, so the view is safe.
  Tensor(const std::string& name,
         const std::vector<size_t> &shape,
         const DType type,
         const bool rowwise = true,
         const bool columnwise = false,
         const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) :
    Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode) {}

  Tensor() {}

  // Move-only: each device buffer must have exactly one owner.
  Tensor& operator=(const Tensor &other) = delete;
  Tensor(const Tensor &other) = delete;
  Tensor(Tensor &&other) = default;
  Tensor& operator=(Tensor &&other) = default;

  // Frees the device buffers, de-aliasing pointers that may be shared
  // between the rowwise and columnwise views (delayed scaling shares one
  // scale_inv buffer) so nothing is freed twice.
  // NOTE(review): the amax/scale device buffers allocated in the constructor
  // are not freed here — possibly deliberate because shareFP8Meta() can make
  // several tensors reference the same metadata buffers; verify.
  ~Tensor() {
    void *data_ptr = tensor_.dptr();
    void *scale_inv = tensor_.scale_inv();
    void *columnwise_data_ptr = tensor_.get_columnwise_data().data_ptr;
    void *columnwise_scale_inv = tensor_.get_columnwise_scale_inv().data_ptr;
    if (columnwise_data_ptr == data_ptr) {
      columnwise_data_ptr = nullptr;
    }
    if (columnwise_scale_inv == scale_inv) {
      columnwise_scale_inv = nullptr;
    }
    if (data_ptr != nullptr) {
      cudaFree(data_ptr);
    }
    if (scale_inv != nullptr) {
      cudaFree(scale_inv);
    }
    if (columnwise_data_ptr != nullptr){
      cudaFree(columnwise_data_ptr);
    }
    if (columnwise_scale_inv != nullptr){
      cudaFree(columnwise_scale_inv);
    }
  }

  // Underlying NVTETensor handle for passing to the C API.
  NVTETensor data() const noexcept {
    return tensor_.data();
  }

  NVTEShape rowwise_shape() const noexcept {
    return tensor_.get_rowwise_data().shape;
  }

  NVTEShape columnwise_shape() const noexcept {
    return tensor_.get_columnwise_data().shape;
  }

  NVTEShape rowwise_scale_inv_shape() const {
    NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!");
    return tensor_.get_rowwise_scale_inv().shape;
  }

  NVTEShape columnwise_scale_inv_shape() const {
    NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!");
    return tensor_.get_columnwise_scale_inv().shape;
  }

  NVTEScalingMode scaling_mode() const noexcept {
    return tensor_.scaling_mode();
  }

  DType dtype() const noexcept {
    return tensor_.dtype();
  }

  // Raw device pointers (checked against the tensor's configured directions).
  void *rowwise_dptr() const {
    NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!");
    return tensor_.get_rowwise_data().data_ptr;
  }

  void *columnwise_dptr() const {
    NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!");
    return tensor_.get_columnwise_data().data_ptr;
  }

  // Typed host-mirror pointers; T must match the tensor dtype exactly.
  template <typename T>
  T *rowwise_cpu_dptr() const {
    NVTE_CHECK(TypeInfo<T>::dtype == tensor_.dtype(), "Invalid type!");
    NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!");
    return reinterpret_cast<T *>(cpu_data_rowwise_.get());
  }

  template <typename T>
  T *columnwise_cpu_dptr() const {
    NVTE_CHECK(TypeInfo<T>::dtype == tensor_.dtype(), "Invalid type!");
    NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!");
    return reinterpret_cast<T *>(cpu_data_columnwise_.get());
  }

  // Current amax, refreshed from the device; 0 when the tensor has none.
  float amax() const {
    if(amax_cpu_data_) {
      to_cpu();
      return *amax_cpu_data_;
    } else {
      return 0;
    }
  }

  // Current scale, refreshed from the device; 1 when the tensor has none.
  // Only valid for delayed tensor scaling.
  float scale() const {
    if(scale_cpu_data_) {
      NVTE_CHECK(tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING, "Invalid scaling_mode!");
      to_cpu();
      return *scale_cpu_data_;
    } else {
      return 1;
    }
  }

  // Typed host scale-inverse pointers. Delayed scaling stores fp32 scalars;
  // block-scaled modes store raw bytes (e8m0).
  template <typename T>
  T *rowwise_cpu_scale_inv_ptr(){
    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){
      NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
    } else {
      NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
    }
    to_cpu();
    return reinterpret_cast<T*>(rowwise_scale_inv_cpu_data_.get());
  }

  template <typename T>
  T *columnwise_cpu_scale_inv_ptr(){
    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){
      NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
    } else {
      NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
    }
    to_cpu();
    return reinterpret_cast<T*>(columnwise_scale_inv_cpu_data_.get());
  }

  // Scalar rowwise scale-inverse; 1 when the tensor has none.
  float rowwise_scale_inv(){
    if(rowwise_scale_inv_cpu_data_) {
      float scale_inv = rowwise_cpu_scale_inv_ptr<float>()[0];
      return scale_inv;
    } else {
      return 1;
    }
  }

  bool rowwise() const {
    return rowwise_;
  }

  bool columnwise() const {
    return columnwise_;
  }

  // Detaches the device amax pointer from the wrapper (to test kernels that
  // must tolerate a missing amax). The buffer itself is not freed.
  void set_tensor_amax_nullptr(){
    tensor_.set_amax(nullptr, DType::kFloat32, tensor_.defaultShape);
  }

  void to_cpu() const;
  void from_cpu() const;
  void set_scale(float scale);
  void set_scale_inv(float scale_inv);
  void shareFP8Meta(const Tensor &other);

  // Per-tensor RNG, seeded deterministically from the tensor's name.
  std::mt19937& gen() { return gen_; }

 private:
  TensorWrapper tensor_;                                   // owns the NVTETensor handle
  std::unique_ptr<unsigned char[]> cpu_data_rowwise_;      // host mirror of rowwise data
  std::unique_ptr<unsigned char[]> cpu_data_columnwise_;   // host mirror of columnwise data
  std::shared_ptr<float> amax_cpu_data_;                   // host mirror of amax (fp8 only)
  std::shared_ptr<float> scale_cpu_data_;                  // host mirror of scale (fp8 only)
  std::unique_ptr<unsigned char[]> rowwise_scale_inv_cpu_data_;
  std::unique_ptr<unsigned char[]> columnwise_scale_inv_cpu_data_;
  bool rowwise_;
  bool columnwise_;
  std::string name_;
  std::mt19937 gen_;
};
// Bit-layout constants of IEEE-754 binary32.
constexpr uint32_t FP32_EXPONENT_BIAS = 127;
constexpr uint32_t FP32_MANTISSA_BITS = 23;
// [128,4] rowwise and [4,128] colwise alignment requirement
constexpr size_t scale_tensor_alignment_X_rowwise = 4;
constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
constexpr size_t scale_tensor_alignment_X_colwise = 128;
constexpr size_t scale_tensor_alignment_Y_colwise = 4;
// Ceiling integer division: smallest k such that k * M >= N.
inline size_t divide_round_up(const size_t N, const size_t M) {
  return (N + M - 1) / M;
}
// Round N up to the nearest multiple of M.
inline size_t round_up_to_nearest_multiple(const size_t N, const size_t M) {
  const size_t num_blocks = (N + M - 1) / M;  // same as divide_round_up(N, M)
  return num_blocks * M;
}
// Numeric limits per data type. The primary template holds neutral
// placeholder values (1.0); real values come from the specializations below.
template <typename T>
struct Numeric_Traits {
    static constexpr double minSubnorm = 1.0;
    static constexpr double maxSubnorm = 1.0;
    static constexpr double minNorm   = 1.0;
    static constexpr double maxNorm   = 1.0;
    static constexpr double artifInf  = 1.0;
    static constexpr int maxBiasedExponent = 1;
};
// FP8 E4M3 limits (4 exponent bits, 3 mantissa bits, finite max 448).
template <>
struct Numeric_Traits<fp8e4m3> {
    static constexpr double minSubnorm = 1.0  / static_cast<double>(1 << 9);   // std::pow(2.0, -9.0);
    static constexpr double maxSubnorm = 0.875 / static_cast<double>(1 << 6);  // std::pow(2.0, -6.0);
    static constexpr double minNorm    = 1.0  / static_cast<double>(1 << 6);   // std::pow(2.0, -6.0);
    static constexpr double maxNorm    = 448.0;
    static constexpr double artifInf   = 10.0 * maxNorm;  // artificial Infinity
    static constexpr int maxBiasedExponentAsFP32 = 8 + FP32_EXPONENT_BIAS;
    static constexpr int maxUnbiasedExponentAsFP32 = 8;
    static constexpr int maxExpNorm = 1 << maxUnbiasedExponentAsFP32;
};
// FP8 E5M2 limits (5 exponent bits, 2 mantissa bits, finite max 57344).
template <>
struct Numeric_Traits<fp8e5m2> {
    static constexpr double minSubnorm = 1.0  / static_cast<double>(1 << 16);  // std::pow(2.0, -16.0);
    static constexpr double maxSubnorm = 0.75 / static_cast<double>(1 << 14);  // std::pow(2.0, -14.0);
    static constexpr double minNorm    = 1.0  / static_cast<double>(1 << 14);  // std::pow(2.0, -14.0);
    static constexpr double maxNorm    = 57344.0;
    static constexpr double artifInf   = 10.0 * maxNorm;  // artificial Infinity
    static constexpr int maxBiasedExponentAsFP32 = 15 + FP32_EXPONENT_BIAS;
    static constexpr int maxUnbiasedExponentAsFP32 = 15;
    static constexpr int maxExpNorm = 1 << maxUnbiasedExponentAsFP32;
};
// IEEE-754 binary32 limits, taken from std::numeric_limits.
template <>
struct Numeric_Traits<fp32> {
    static constexpr double minSubnorm = std::numeric_limits<fp32>::denorm_min();  // std::pow(2.0, -149.0);
    static constexpr double maxSubnorm = std::numeric_limits<fp32>::min()
                                         - std::numeric_limits<fp32>::denorm_min();  // minNormalized - minDenormalized
    static constexpr double minNorm    = std::numeric_limits<fp32>::min();      // std::pow(2.0, -126.0);
    static constexpr double maxNorm    = std::numeric_limits<fp32>::max();      // (1 - pow(2, -24)) * pow(2, 128)
    static constexpr double artifInf   = std::numeric_limits<fp32>::infinity();
    static constexpr int maxBiasedExponentAsFP32 = 255;
    static constexpr int maxUnbiasedExponentAsFP32 = 128;
};
// Convenience accessors over Numeric_Traits<T>, used when filling inputs
// (the `ranges` boundaries map to the InputsFillCase intervals below) and
// when computing scales from amax.
template <typename T>
struct Quantized_Limits {
    static constexpr double ranges[] = {
        0.0,
        Numeric_Traits<T>::minNorm,
        Numeric_Traits<T>::maxNorm,
        Numeric_Traits<T>::artifInf
    };
    static constexpr inline fp32 max() { return static_cast<fp32>(Numeric_Traits<T>::maxNorm); }
    static constexpr inline fp32 max_reciprocal() { return static_cast<fp32>(1.0 / max()); }
    static constexpr inline fp32 emax() { return static_cast<fp32>(Numeric_Traits<T>::maxExpNorm); }
    static constexpr inline fp32 emax_reciprocal() { return static_cast<fp32>(1.0 / emax()); }
    static constexpr inline int max_norm_biased_exponent() { return Numeric_Traits<T>::maxBiasedExponentAsFP32; }
    static constexpr inline int max_norm_unbiased_exponent() { return Numeric_Traits<T>::maxUnbiasedExponentAsFP32; }
};
// Input data filling cases
// Considering normal and subnormal magnitudes of E4M3 and E5M2 formats
// with nearest to even rounding per OFP8 specification
// Each case selects the magnitude interval that fillCase() draws inputs from.
enum InputsFillCase {
    zero_to_minNorm     = 0,  // [0, min_normal)
    minNorm_to_maxNorm  = 1,  // [min_normal, max_normal)
    maxNorm_to_inf      = 2,  // [max_normal, inf)
    zeros               = 3,  // {0}
    uniform             = 4,  // std::uniform_real_distribution<> dis(-2.0, 1.0)
};
// Convert a float to an E8M0 biased exponent (round-up, saturating-finite).
inline fp8e8m0 float_to_e8m0(float val) {
    // TODO: nan/inf needs to be set for any value
    // of nan/inf in input not just amax.
    if (std::isnan(val)) {
        return 0xFF;  // E8M0 NaN encoding
    }
    if (std::isinf(val)) {
        return 0xFE;  // saturate to the largest finite exponent
    }
    if (val == 0.0f) {
        return 0x00;
    }
    // NOTE(review): type-punning via reinterpret_cast is a strict-aliasing
    // violation; std::memcpy/std::bit_cast would be the well-defined form.
    uint32_t val_u32 = *reinterpret_cast<uint32_t*>(&val);
    fp8e8m0 exponent = (val_u32 >> FP32_MANTISSA_BITS);
    uint32_t mantissa = val_u32 & 0x7FFFFF;
    // Round up exponent and deal with satfinite.
    // Increment when the mantissa is non-zero, unless already saturated (0xFE)
    // or the value is a small subnormal (exponent 0, mantissa <= 0x400000).
    if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) {
        ++exponent;
    }
    return exponent;
}
// Reciprocal of the power of two encoded by an E8M0 biased exponent:
// 2^(FP32_EXPONENT_BIAS - biased_exp). A zero biased exponent maps to 1.
inline float exp2f_rcp(fp8e8m0 biased_exp) {
  if (biased_exp == 0) {
    return 1;
  }
  return exp2f(FP32_EXPONENT_BIAS - static_cast<float>(biased_exp));
}
// Identity activation (and its own derivative's companion in the tables below).
inline float identity(const float x) { return x; }
inline float gelu(const float x) { return x * (0.5f + 0.5f * tanhf(x * (0.79788456f + 0.03567741f * x * x))); }
// Derivative of the tanh-approximation GELU.
inline float dgelu(const float x) {
  const float tanh_out = tanhf(0.79788456f * x * (1 + 0.044715f * x * x));
  const float sech_term = (1 - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * x * x);
  return 0.5f * x * sech_term + 0.5f * (1 + tanh_out);
}
// Logistic sigmoid: 1 / (1 + e^-x).
inline float sigmoid(const float x) {
  const float exp_neg = expf(-x);
  return 1 / (1 + exp_neg);
}
// Derivative of the sigmoid: s * (1 - s), with s = sigmoid(x) (inlined here
// so the function is self-contained; the value is identical).
inline float dsigmoid(const float x) {
  const float s = 1 / (1 + expf(-x));
  return s * (1 - s);
}
// Quick-GELU: x * sigmoid(1.702 * x), a cheap GELU approximation.
inline float qgelu(const float x) { return x * sigmoid(1.702f * x); }
// Derivative of quick-GELU by the product rule.
inline float dqgelu(const float x) { return 1.702f * x * dsigmoid(1.702f * x) + sigmoid(1.702f * x); }
// ReLU via fmaxf (note: fmaxf maps NaN input to 0, unlike a plain ternary).
inline float relu(const float x) { return fmaxf(0, x); }
// Derivative of ReLU: 1 for positive inputs, 0 otherwise (including x == 0).
inline float drelu(const float x) {
  if (x > 0) {
    return 1;
  }
  return 0;
}
// SiLU (swish): x * sigmoid(x), with the sigmoid inlined (identical value).
inline float silu(const float x) {
  const float sig = 1 / (1 + expf(-x));
  return x * sig;
}
// Derivative of SiLU by the product rule: x * sigmoid'(x) + sigmoid(x).
inline float dsilu(const float x) { return x * dsigmoid(x) + sigmoid(x); }
// Squared ReLU: x^2 for positive inputs, 0 otherwise.
inline float srelu(const float x) {
  if (x > 0) {
    return x * x;
  }
  return 0;
}
// Derivative of squared ReLU: max(0, 2x) (fmaxf kept for NaN behavior).
inline float dsrelu(const float x) {
  const float doubled = 2 * x;
  return fmaxf(0, doubled);
}
// --- Free-function declarations (definitions live in test_common.cu) ---
// Shape/size helpers.
size_t typeToSize(DType type);
size_t product(const NVTEShape &shape);
size_t product(const std::vector<size_t> &shape);
size_t first_dimension(const std::vector<size_t> &shape);
size_t last_dimension(const std::vector<size_t> &shape);
bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2);
// Result comparison helpers (tolerances default to tight fp32-style bounds).
void compareResults(const std::string &name, const Tensor &test, const void *ref,
                    bool rowwise, double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true);
void compareResults(const std::string &name, const float test, const float ref,
                    double atol = 1e-5, double rtol = 1e-8);
void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref,
                    size_t N, float mismatch_rate_tol = 0.);
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
                                  const size_t row_blocks, const size_t col_blocks, const size_t stride);
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
                                  const size_t N);
std::array<size_t, 4> get_scale_tensor_dims(const size_t rows, const size_t cols,
                                            const size_t block_size_rows, const size_t block_size_cols);
std::pair<double, double> getTolerances(const DType type);
// Input-filling helpers (fillCase draws from the InputsFillCase interval).
void fillUniform(Tensor *t);
template <typename InputEncoding>
void fillCase(Tensor *t, const InputsFillCase fill_case);
void setRandomScale(Tensor *t);
void setRandomScaleInv(Tensor *t);
constexpr int THREADS_PER_WARP = 32;
const std::string &typeName(DType type);
const std::string& caseName(InputsFillCase type);
extern std::vector<DType> all_fp_types;
bool isFp8Type(DType type);
int32_t getDeviceComputeCapability();
constexpr int32_t blackwellComputeCapability = 100;
}  // namespace test
// Dispatch over every supported DType: binds `type` to the corresponding C++
// type and executes __VA_ARGS__ once; NVTE_ERROR on unsupported dtypes.
#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \
    switch (dtype) { \
        using namespace transformer_engine; \
        case DType::kByte: \
            { \
                using type = byte; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kInt32: \
            { \
                using type = int32; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kInt64: \
            { \
                using type = int64; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat32: \
            { \
                using type = float; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat16: \
            { \
                using type = fp16; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kBFloat16: \
            { \
                using type = bf16; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat8E4M3: \
            { \
                using type = fp8e4m3; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat8E5M2: \
            { \
                using type = fp8e5m2; \
                {__VA_ARGS__} \
            } \
        break; \
        default: \
            NVTE_ERROR("Invalid type."); \
    }
// Dispatch restricted to the two FP8 dtypes; NVTE_ERROR otherwise.
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \
    switch (dtype) { \
        using namespace transformer_engine; \
        case DType::kFloat8E4M3: \
            { \
                using type = fp8e4m3; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat8E5M2: \
            { \
                using type = fp8e5m2; \
                {__VA_ARGS__} \
            } \
        break; \
        default: \
            NVTE_ERROR("Invalid type."); \
    }
// Dispatch restricted to fp32 / fp16 / bf16; NVTE_ERROR otherwise.
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...)  \
    switch (dtype) { \
        using namespace transformer_engine; \
        case DType::kFloat32: \
            { \
                using type = float; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat16: \
            { \
                using type = fp16; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kBFloat16: \
            { \
                using type = bf16; \
                {__VA_ARGS__} \
            } \
        break; \
        default: \
            NVTE_ERROR("Invalid type."); \
    }
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Build the test_util gtest binary from the util tests plus the shared
# test helpers compiled as CUDA (test_common.cu).
add_executable(test_util
               test_nvrtc.cpp
               test_string.cpp
               ../test_common.cu)
# OpenMP is required by test_common (see its -fopenmp usage / omp.h include).
find_package(OpenMP REQUIRED)
target_link_libraries(test_util PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn OpenMP::OpenMP_CXX)
target_compile_options(test_util PRIVATE -O2 -fopenmp)
# Register each gtest case individually with CTest.
include(GoogleTest)
gtest_discover_tests(test_util DISCOVERY_TIMEOUT 600)
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <stdexcept>
#include <vector>
#include <gtest/gtest.h>
#include "util/rtc.h"
using namespace transformer_engine;
// Exercises the NVRTC KernelManager: kernels are unavailable before
// compilation, compile independently, and produce the expected results.
TEST(UtilTest, NVRTC) {
  if (!rtc::is_enabled()) {
    GTEST_SKIP() << "NVRTC not enabled, skipping tests";
  }
  // GPU data buffer
  int *device_buffer;
  std::vector<int> host_buffer(2);
  cudaMalloc((void**)&device_buffer, 2*sizeof(int));  // NOLINT(*)
  cudaMemset(device_buffer, 0, 2*sizeof(int));
  // CUDA kernel implementations
  const char code1[] = R"code(
#include <cuda_runtime.h>
__global__ void my_kernel(int2 *data) {
  data->x = 123;
  data->y = -456;
}
)code";
  const char code2[] = R"code(
#include "utils.cuh"
__global__ void my_kernel(uint32_t *data) {
  data[0] = 789;
  data[1] = 12;
}
)code";
  // Make sure kernels are not available
  auto& nvrtc_manager = rtc::KernelManager::instance();
  EXPECT_FALSE(nvrtc_manager.is_compiled("my gtest kernel1"));
  EXPECT_FALSE(nvrtc_manager.is_compiled("my gtest kernel2"));
  EXPECT_THROW(nvrtc_manager.launch("my gtest kernel1", 1, 1, 0, 0,
                                    device_buffer),
               std::runtime_error);
  EXPECT_THROW(nvrtc_manager.launch("my gtest kernel2", 1, 1, 0, 0,
                                    device_buffer),
               std::runtime_error);
  // Compile and run first kernel
  EXPECT_NO_THROW(nvrtc_manager.compile("my gtest kernel1",
                                        "my_kernel",
                                        code1,
                                        "test_nvrtc_kernel1.cu"));
  EXPECT_TRUE(nvrtc_manager.is_compiled("my gtest kernel1"));
  EXPECT_FALSE(nvrtc_manager.is_compiled("my gtest kernel2"));
  EXPECT_NO_THROW(nvrtc_manager.launch("my gtest kernel1", 1, 1, 0, 0,
                                       device_buffer));
  EXPECT_EQ(cudaMemcpy(host_buffer.data(), device_buffer, 2*sizeof(int),
                       cudaMemcpyDeviceToHost),
            cudaSuccess);
  EXPECT_EQ(host_buffer[0], 123);
  EXPECT_EQ(host_buffer[1], -456);
  // Compile and run second kernel
  EXPECT_NO_THROW(nvrtc_manager.compile("my gtest kernel2",
                                        "my_kernel",
                                        code2,
                                        "test_nvrtc_kernel2.cu"));
  EXPECT_TRUE(nvrtc_manager.is_compiled("my gtest kernel1"));
  EXPECT_TRUE(nvrtc_manager.is_compiled("my gtest kernel2"));
  EXPECT_NO_THROW(nvrtc_manager.launch("my gtest kernel2", 1, 1, 0, 0, device_buffer));
  EXPECT_EQ(cudaMemcpy(host_buffer.data(), device_buffer, 2*sizeof(int),
                       cudaMemcpyDeviceToHost),
            cudaSuccess);
  EXPECT_EQ(host_buffer[0], 789);
  EXPECT_EQ(host_buffer[1], 12);
  // Fix: release the device allocation — the original test leaked it.
  EXPECT_EQ(cudaFree(device_buffer), cudaSuccess);
}
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <string>
#include <vector>
#include <gtest/gtest.h>
#include "util/string.h"
using namespace transformer_engine;
// Checks to_string_like over strings, integers, floats, and containers.
TEST(UtilTest, ToStringLike) {  // to_string_like
  // Strings
  using namespace std::string_literals;
  EXPECT_EQ(to_string_like(std::string("")), "");
  EXPECT_EQ(to_string_like(""), "");
  EXPECT_EQ(to_string_like(std::string("Hello")), "Hello");
  EXPECT_EQ(to_string_like("world!"), "world!");
  // Embedded NULs require the ""s literal so the full buffer is compared.
  EXPECT_EQ(to_string_like(" \0\n\\\t\"\' This is a weird C++ string"s),
            " \0\n\\\t\"\' This is a weird C++ string"s);
  EXPECT_EQ(to_string_like(" Here is a bizarre C string \n\\\t\"\'"),
            " Here is a bizarre C string \n\\\t\"\'");
  // Zero integer types
  EXPECT_EQ(to_string_like(19), "19");
  EXPECT_EQ(to_string_like(static_cast<char>(0)), "0");
  EXPECT_EQ(to_string_like(static_cast<unsigned char>(0)), "0");
  EXPECT_EQ(to_string_like(static_cast<short int>(0)), "0");
  EXPECT_EQ(to_string_like(static_cast<unsigned short int>(0)), "0");
  EXPECT_EQ(to_string_like(static_cast<int>(0)), "0");
  EXPECT_EQ(to_string_like(static_cast<unsigned int>(0)), "0");
  EXPECT_EQ(to_string_like(static_cast<long long int>(0)), "0");
  EXPECT_EQ(to_string_like(static_cast<unsigned long long int>(0)), "0");
  // Non-zero integer types
  EXPECT_EQ(to_string_like(static_cast<char>(1)), "1");
  EXPECT_EQ(to_string_like(static_cast<char>(-1)), "-1");
  EXPECT_EQ(to_string_like(static_cast<unsigned char>(2)), "2");
  EXPECT_EQ(to_string_like(static_cast<short>(3)), "3");
  EXPECT_EQ(to_string_like(static_cast<short>(-5)), "-5");
  EXPECT_EQ(to_string_like(static_cast<unsigned short>(8)), "8");
  EXPECT_EQ(to_string_like(static_cast<int>(13)), "13");
  EXPECT_EQ(to_string_like(static_cast<int>(-21)), "-21");
  EXPECT_EQ(to_string_like(static_cast<unsigned int>(34)), "34");
  EXPECT_EQ(to_string_like(static_cast<long long>(55)), "55");
  EXPECT_EQ(to_string_like(static_cast<long long>(-89)), "-89");
  EXPECT_EQ(to_string_like(static_cast<unsigned long long>(144)), "144");
  EXPECT_EQ(to_string_like(static_cast<size_t>(233)), "233");
  // Floating-point types — round-trip through stof/stod instead of comparing
  // text, since the formatting precision is unspecified.
  EXPECT_EQ(std::stof(to_string_like(0.f)), 0.f);
  EXPECT_EQ(std::stod(to_string_like(0.)), 0.);
  EXPECT_EQ(std::stof(to_string_like(1.25f)), 1.25f);
  EXPECT_EQ(std::stof(to_string_like(-2.5f)), -2.5f);
  EXPECT_EQ(std::stod(to_string_like(2.25)), 2.25);
  EXPECT_EQ(std::stod(to_string_like(-4.5)), -4.5);
  // Container types
  EXPECT_EQ(to_string_like(std::vector<int>{-3,1,-4}), "(-3,1,-4)");
  EXPECT_EQ(to_string_like(std::vector<std::string>{"Accept", "no", "substitutes", ".",
                                                    "Buy", "N", "V", "IDIA"}),
            "(Accept,no,substitutes,.,Buy,N,V,IDIA)");
}
// Checks concat_strings over mixed string, numeric, and container arguments.
TEST(UtilTest, ConcatStringsTest) {  // concat_strings
  // Strings
  EXPECT_EQ(concat_strings(), "");
  EXPECT_EQ(concat_strings(std::string("")), "");
  EXPECT_EQ(concat_strings(""), "");
  EXPECT_EQ(concat_strings(std::string(""), "", std::string(""), ""), "");
  EXPECT_EQ(concat_strings("C string"), "C string");
  EXPECT_EQ(concat_strings(std::string("C++ string")), "C++ string");
  EXPECT_EQ(concat_strings("C string ", std::string("and"),
                           " ", std::string("C++ string")),
            "C string and C++ string");
  // Numeric types
  EXPECT_EQ(concat_strings("int ", static_cast<int>(-123),
                           ", uint ", static_cast<unsigned int>(456)),
            "int -123, uint 456");
  EXPECT_EQ(concat_strings("char ", static_cast<char>(13),
                           ", uchar ", static_cast<unsigned char>(24)),
            "char 13, uchar 24");
  EXPECT_EQ(concat_strings("int16 ", static_cast<short>(-35),
                           ", uint16 ", static_cast<unsigned short>(46)),
            "int16 -35, uint16 46");
  EXPECT_EQ(concat_strings("int64 ", static_cast<long long>(57),
                           ", uint64 ", static_cast<unsigned long long>(68)),
            "int64 57, uint64 68");
  // Floats round-trip through stof/stod since formatting precision varies.
  EXPECT_EQ(std::stof(concat_strings("-", 3.25f)), -3.25f);
  EXPECT_EQ(std::stof(concat_strings(6.5f)), 6.5f);
  EXPECT_EQ(std::stod(concat_strings("-", 4.25)), -4.25);
  EXPECT_EQ(std::stod(concat_strings(8.5)), 8.5);
  // Container types
  EXPECT_EQ(concat_strings("vector ", std::vector<int>{1,-2,3}), "vector (1,-2,3)");
}
// Checks regex_replace with string, numeric, and regex-pattern replacements.
TEST(UtilTest, RegexReplaceTest) {  // regex_replace
  EXPECT_EQ(regex_replace("this test FAILS", "FAILS", "PASSES"),
            "this test PASSES");
  // Non-string replacement values are stringified.
  EXPECT_EQ(regex_replace("status = 0000", "0", 1), "status = 1111");
  EXPECT_EQ(regex_replace("I um sound um \t very umconfident", R"(um\s*)", ""),
            "I sound very confident");
}
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""conftest for tests/jax"""
import os
import jax
import pytest
import transformer_engine.jax
from transformer_engine_jax import get_device_compute_capability
@pytest.fixture(autouse=True, scope="function")
def clear_live_arrays():
    """
    Clear all live arrays to keep the resource clean
    """
    yield
    # Teardown: delete every live JAX array so device memory does not
    # accumulate across test functions.
    for arr in jax.live_arrays():
        arr.delete()
@pytest.fixture(autouse=True, scope="module")
def enable_fused_attn_after_hopper():
    """
    Enable fused attn for hopper+ arch.
    Fused attn kernels on pre-hopper arch are not deterministic.
    """
    # Fix: remember any pre-existing value so teardown restores the user's
    # environment. The original unconditionally deleted NVTE_FUSED_ATTN,
    # even when this fixture never set it (e.g. pre-Hopper with the flag
    # exported by the user).
    prev_value = os.environ.get("NVTE_FUSED_ATTN")
    if get_device_compute_capability(0) >= 90:
        os.environ["NVTE_FUSED_ATTN"] = "1"
    yield
    if prev_value is None:
        os.environ.pop("NVTE_FUSED_ATTN", None)
    else:
        os.environ["NVTE_FUSED_ATTN"] = prev_value
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import operator
import re
from functools import reduce
from itertools import product
import pytest
import jax
from jax.experimental.pjit import pjit, _UNSPECIFIED
from transformer_engine.jax.sharding import MeshResource
from utils import assert_allclose, is_devices_enough
def generate_configs():
    """Build pytest params for 2- and 4-device (dp, tp) mesh layouts that fit
    the available device count. Each param is (num_devices, mesh_shape,
    mesh_axes, MeshResource)."""
    configs = []
    if is_devices_enough(2):
        configs.append(
            pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1")
        )
        configs.append(
            pytest.param(2, (2,), ("tp",), MeshResource(tp_resource="tp"), id="n2_dp1_tp2")
        )
    if is_devices_enough(4):
        configs.append(
            pytest.param(
                4,
                (2, 2),
                ("dp", "tp"),
                MeshResource(dp_resource="dp", tp_resource="tp"),
                # Fix: plain string — the f-prefix was spurious (no placeholders)
                # and inconsistent with the other ids above.
                id="n4_dp2_tp2",
            )
        )
    return configs
def generate_context_parallel_configs():
    """Build pytest params for every (dp, cp, tp) mesh combination that fits
    the available device count."""
    configs = []
    resource = MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp")
    mesh_axes = ("dp", "cp", "tp")
    for dp_size, cp_size, tp_size in product((1, 2), (1, 2, 4, 8), (1, 2)):
        num_devices = cp_size * tp_size * dp_size
        if is_devices_enough(num_devices):
            configs.append(
                pytest.param(
                    num_devices,
                    (dp_size, cp_size, tp_size),
                    mesh_axes,
                    resource,
                    id=f"n{num_devices}_dp{dp_size}_cp{cp_size}_tp{tp_size}",
                )
            )
    return configs
# Keys of the per-collective byte-count dictionary.
COLL_AR_KEY = "all-reduce"
COLL_AG_KEY = "all-gather"
COLL_OTHER_KEY = "other"


def generate_collectives_count(allreduce, allgather, other):
    """Build the expected-bytes dict keyed by collective kind."""
    counts = {}
    counts[COLL_AR_KEY] = allreduce
    counts[COLL_AG_KEY] = allgather
    counts[COLL_OTHER_KEY] = other
    return counts
def assert_equal_collectives(target_hlo, coll_count_ref):
    # Parse the compiled HLO text and assert that the per-collective byte
    # counts match the reference dict from generate_collectives_count().
    target_splitted_hlo = target_hlo.splitlines()
    start_symb = "-start"  # async collective ops appear as "<op>-start" in HLO
    def count_bytes(hlo_text):
        # Sum the byte sizes of the result operand(s) of one "-start" op.
        bytes_count = 0
        def get_bytes_per_txt(t):
            """
            The pattern of t would be like:
                'f32[]',
                '(f32[1024]{0}',
                'f32[1024]{0})',
                'f8E4M3FN[1024]{0}',
                'i32[1024]{0}',
                'bf16[1024,1024]{0}'
            """
            match = re.search(r"(i|f)(\d+).*\[([0-9,]*)\]", t)
            _, bits_of_type, shape = match.groups()
            bytes_of_type = int(bits_of_type) // 8
            # An empty shape "[]" denotes a scalar (one element).
            if shape == "":
                num_of_elements = 1
            else:
                num_of_elements = reduce(operator.mul, map(int, shape.split(",")))
            return bytes_of_type * num_of_elements
        # ['xxx-start', '=', '(bf16[xxx]', 'bf16[xxx])', 'xxx-start(', ...]
        if "(" in hlo_text[2]:
            # Tuple-shaped result: accumulate each element until ")".
            for txt in hlo_text[2:]:
                bytes_count += get_bytes_per_txt(txt)
                if ")" in txt:
                    break
        else:  # ['xxx-start', '=', 'fp32[]', 'xxx-start(', ...]
            bytes_count = get_bytes_per_txt(hlo_text[2])
        return bytes_count
    def count_collectives(splitted_hlo):
        # Tally bytes per collective kind over all HLO lines.
        result = generate_collectives_count(0, 0, 0)
        for line in splitted_hlo:
            txt = line.split()
            if len(txt) > 0 and start_symb in txt[0]:
                if COLL_AR_KEY in txt[0]:
                    result[COLL_AR_KEY] += count_bytes(txt)
                elif COLL_AG_KEY in txt[0]:
                    result[COLL_AG_KEY] += count_bytes(txt)
                else:
                    result[COLL_OTHER_KEY] += count_bytes(txt)
        return result
    target_result = count_collectives(target_splitted_hlo)
    assert (
        target_result == coll_count_ref
    ), f"Expected collective count is {coll_count_ref}, but got {target_result}."
def compare_ops(
    target_func,
    ref_func,
    inputs,
    coll_count_ref,
    *,
    grad_args=None,
    metric_fwd_dtype=None,
    metric_bwd_dtype=None,
    in_shardings=_UNSPECIFIED,
    out_shardings=_UNSPECIFIED,
    **kwargs,
):
    """Run value_and_grad of target_func and ref_func under pjit, assert that
    forward values and all gradients are close (tolerances picked per dtype),
    and that the target's compiled HLO contains exactly the collective byte
    counts given in coll_count_ref."""
    assert len(inputs) >= 1
    # Default comparison dtypes to the first input's dtype.
    if metric_fwd_dtype is None:
        metric_fwd_dtype = inputs[0].dtype
    if metric_bwd_dtype is None:
        metric_bwd_dtype = inputs[0].dtype
    # By default differentiate w.r.t. every input.
    if grad_args is None:
        grad_args = tuple(range(len(inputs)))
    target_grad_func = jax.value_and_grad(target_func, argnums=grad_args)
    target_pjitter = pjit(target_grad_func, in_shardings=in_shardings, out_shardings=out_shardings)
    target_fwd, target_grads = target_pjitter(*inputs, **kwargs)
    # Lower + compile once more to obtain the HLO text for collective counting.
    target_hlo = target_pjitter.lower(*inputs, **kwargs).compile().as_text()
    ref_grad_func = jax.value_and_grad(ref_func, argnums=grad_args)
    ref_pjitter = pjit(ref_grad_func, in_shardings=in_shardings, out_shardings=out_shardings)
    ref_fwd, ref_grads = ref_pjitter(*inputs, **kwargs)
    assert_allclose(target_fwd, ref_fwd, dtype=metric_fwd_dtype)
    for i in range(len(target_grads)):
        assert_allclose(target_grads[i], ref_grads[i], dtype=metric_bwd_dtype)
    assert_equal_collectives(target_hlo, coll_count_ref)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
[pytest]
filterwarnings=
ignore:Fused attention is not enabled.*:UserWarning
ignore:The hookimpl.*:DeprecationWarning
ignore:xmap is an experimental feature and probably has bugs!
ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning
ignore:can't resolve package from __spec__ or __package__:ImportWarning
ignore:Using or importing the ABCs.*:DeprecationWarning
ignore:numpy.ufunc size changed
ignore:.*experimental feature
ignore:The distutils.* is deprecated.*:DeprecationWarning
ignore:backend and device argument on jit is deprecated.*:DeprecationWarning
ignore:ml_dtypes.float8_e4m3b11 is deprecated.
ignore:np.find_common_type is deprecated.*:DeprecationWarning
ignore:jax.numpy.in1d is deprecated.*:DeprecationWarning
ignore:The numpy.array_api submodule is still experimental.*:UserWarning
ignore:case not machine-readable.*:UserWarning
ignore:not machine-readable.*:UserWarning
ignore:Special cases found for .* but none were parsed.*:UserWarning
ignore:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning
ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning
ignore:The host_callback APIs are deprecated .*:DeprecationWarning
ignore:Scan loop is disabled for fused ring attention.*:UserWarning
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from contextlib import nullcontext
from typing import Callable, List, Sequence, Union
import os
import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax import jit, value_and_grad
from flax import linen as nn
from utils import assert_allclose, assert_tree_like_allclose
from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper, is_fp8_available
from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot
from transformer_engine.jax.layernorm_mlp import activation_lu, fused_layernorm_fp8_mlp
from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu
from transformer_engine.jax.cpp_extensions.transpose import (
_jax_transpose,
_jax_cast_transpose,
_jax_dbias_cast_transpose,
)
from transformer_engine.jax.cpp_extensions.quantization import _jax_cast_fp8
from transformer_engine.jax import cpp_extensions as tex
# (m, n, k) GEMM problem sizes exercised by the dot tests.
GEMM_CASES = [
    (256, 256, 512),
    (32, 32, 32),
    (2048, 1024, 2048),
    (2048, 2048, 1024),
    (2048, 1024, 1024),
]
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES = [(512, 1024)]
DTYPES = [jnp.bfloat16, jnp.float32]
# Skip FP8 tests (with the returned reason) when the device lacks support.
is_fp8_supported, reason = is_fp8_available()
class TestFP8Dot:
    """FP8 quantization, GEMM, and fused-MLP correctness tests."""

    @staticmethod
    def _generate_fp8_meta():
        """Default FP8 metadata for one GEMM: dtypes plus zeroed amax
        histories and unit scales for (fwd x, fwd kernel, bwd grad)."""
        fp8_dtype_list = [FP8Helper.FWD_DTYPE, FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE]
        amax_list = [
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32) for _ in range(3)
        ]
        scale_list = [jnp.ones((1,), jnp.float32) for _ in range(3)]
        return fp8_dtype_list, amax_list, scale_list
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    def test_qdq(self):
        # Quantize-dequantize round trip should reproduce x within the
        # FP8 E4M3 tolerance used by assert_allclose.
        FP8_E4M3_MAX = (jnp.finfo(jnp.float8_e4m3fn).max).astype(jnp.float32)
        x = jnp.asarray([[-1, 0.1], [2, 3]], jnp.float32)
        amax = jnp.max(jnp.abs(x)).reshape(1)
        # Scale so that amax maps onto the E4M3 representable maximum.
        scale = jnp.asarray(FP8_E4M3_MAX / amax, jnp.float32).reshape(1)
        scale_inv = (1 / scale).reshape(1)
        y, _ = quantize(x, q_dtype=jnp.float8_e4m3fn, scale=scale)
        z = dequantize(y, dq_dtype=jnp.float32, scale_inv=scale_inv)
        assert_allclose(z, x, dtype=jnp.float8_e4m3fn)
    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_forward_bf16(self, m, n, k):
        # Without an FP8 meta package, type_safe_dot_general should match
        # a plain bf16 jnp.dot.
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)
        primitive_out = type_safe_dot_general(a, b)
        ref_out = jnp.dot(a, b)
        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_forward_fp8_randint(self, m, n, k):
        # FP8 dot on small integers (exactly representable in FP8) should
        # match the reference dot within FP8 forward-dtype tolerance.
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        dtype = jnp.bfloat16
        # TODO(rewang): add float random test
        min_val, max_val = -8, 8
        a = jax.random.randint(subkeys[0], (m, k), min_val, max_val).astype(dtype)
        b = jax.random.randint(subkeys[1], (k, n), min_val, max_val).astype(dtype)
        _, amax_list, scale_list = TestFP8Dot._generate_fp8_meta()
        fp8_meta_pkg = FP8MetaPackage(
            amax_list[0],
            scale_list[0],
            amax_list[1],
            scale_list[1],
            amax_list[2],
            scale_list[2],
        )
        primitive_out = type_safe_dot_general(a, b, fp8_meta_pkg)
        ref_out = jnp.dot(a, b)
        # Compare in fp32 to avoid bf16 rounding in the comparison itself.
        ref_out = ref_out.astype(jnp.float32)
        primitive_out = primitive_out.astype(jnp.float32)
        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_grad_bf16(self, m, n, k):
        # Gradients of type_safe_dot_general (no FP8 meta) should match
        # the gradients of a plain jnp.dot.
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)
        def primitive_func(x, y):
            primitive_out = type_safe_dot_general(x, y)
            return jnp.mean(primitive_out)
        def ref_func(x, y):
            return jnp.mean(jnp.dot(x, y))
        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))
        primitive_out, (primitive_a_grad, primitive_b_grad) = value_n_grad_primitive_func(a, b)
        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)
        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=jnp.bfloat16)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=jnp.bfloat16)
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_grad_fp8_dot(self, m, n, k):
        # Run the FP8 dot for a few iterations (amax/scale are fed back each
        # step), then compare value and grads against a bf16 reference.
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)
        _, amax_list, scale_list = TestFP8Dot._generate_fp8_meta()
        def primitive_func(x, y, amax_list, scale_list):
            # Re-pack the FP8 metadata each step so it flows through autodiff.
            fp8_meta_pkg = FP8MetaPackage(
                amax_list[0],
                scale_list[0],
                amax_list[1],
                scale_list[1],
                amax_list[2],
                scale_list[2],
            )
            primitive_out = type_safe_dot_general(x, y, fp8_meta_pkg)
            return jnp.mean(primitive_out)
        def ref_func(x, y):
            return jnp.mean(jnp.dot(x, y))
        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2, 3))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))
        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)
        for _ in range(3):
            primitive_out, (primitive_a_grad, primitive_b_grad, amax_list, scale_list) = (
                value_n_grad_primitive_func(a, b, amax_list, scale_list)
            )
        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize(
"m,n,k", [(256, 128, 512), (16384, 1024, 2816), (16384, 2816, 1024), (16384, 1024, 1024)]
)
@pytest.mark.parametrize(
"activation_type",
[
("gelu",),
("gelu", "linear"),
("silu",),
("silu", "linear"),
("relu",),
("relu", "linear"),
("quick_gelu",),
("quick_gelu", "linear"),
("squared_relu",),
("squared_relu", "linear"),
],
)
@pytest.mark.parametrize("use_bias", [True, False])
def test_grad_fused_layernorm_fp8_mlp(
self, m, n, k, activation_type: Sequence[Union[str, Callable]], use_bias: bool
):
"""N/a"""
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 6)
a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
k1 = jax.random.normal(subkeys[1], (k, len(activation_type), n), jnp.bfloat16) / jnp.sqrt(k)
k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
if use_bias:
b1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
else:
b1 = None
b2 = None
def primitive_func(
x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2
):
# x is input tensor, matrix 2d
# y, z are weights, matrix 2d
# out = ((x * y) + w) * z + v
fp8_meta_pkg_1 = FP8MetaPackage(
amax_list_1[0],
scale_list_1[0],
amax_list_1[1],
scale_list_1[1],
amax_list_1[2],
scale_list_1[2],
)
fp8_meta_pkg_2 = FP8MetaPackage(
amax_list_2[0],
scale_list_2[0],
amax_list_2[1],
scale_list_2[1],
amax_list_2[2],
scale_list_2[2],
)
return jnp.mean(
fused_layernorm_fp8_mlp(
x,
ln_s,
None,
[y, z],
[w, v],
[fp8_meta_pkg_1, fp8_meta_pkg_2],
"rmsnorm",
activation_type=activation_type,
use_bias=use_bias,
)
)
def layernorm_fp8_mlp_ref(
x: jnp.ndarray,
ln_scale: jnp.ndarray,
kernel_1: jnp.ndarray,
kernel_2: jnp.ndarray,
bias_1: jnp.ndarray,
bias_2: jnp.ndarray,
amax_list_1: List[jnp.ndarray],
amax_list_2: List[jnp.ndarray],
scale_list_1: List[jnp.ndarray],
scale_list_2: List[jnp.ndarray],
) -> jnp.ndarray:
x = jnp.asarray(x, jnp.float32)
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
y = jnp.asarray(x * jax.lax.rsqrt(mean2 + 1e-6), jnp.bfloat16)
ln_out = y * ln_scale
ln_out = jnp.asarray(ln_out, jnp.bfloat16)
fp8_meta_pkg_1 = FP8MetaPackage(
amax_list_1[0],
scale_list_1[0],
amax_list_1[1],
scale_list_1[1],
amax_list_1[2],
scale_list_1[2],
)
linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_meta_pkg_1, ((1,), (0,)))
if use_bias:
bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
linear_1_out += jnp.reshape(bias_1, bias_1_shape)
x = _jax_act_lu(linear_1_out, activation_type)
fp8_meta_pkg_2 = FP8MetaPackage(
amax_list_2[0],
scale_list_2[0],
amax_list_2[1],
scale_list_2[1],
amax_list_2[2],
scale_list_2[2],
)
output = type_safe_dot_general(x, kernel_2, fp8_meta_pkg_2, ((1,), (0,)))
if use_bias:
bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape
output += jnp.reshape(bias_2, bias_2_shape)
return output
def ref_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2):
    """Scalar loss (mean) of the pure-JAX reference MLP, for value_and_grad."""
    mlp_out = layernorm_fp8_mlp_ref(
        x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2
    )
    return jnp.mean(mlp_out)
value_n_grad_primitive_func = jit(
value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
)
value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))
_, amax_list_1, scale_list_1 = TestFP8Dot._generate_fp8_meta()
_, amax_list_2, scale_list_2 = TestFP8Dot._generate_fp8_meta()
ref_amax_list_1 = amax_list_1
ref_scale_list_1 = scale_list_1
ref_amax_list_2 = amax_list_2
ref_scale_list_2 = scale_list_2
primitive_amax_list_1 = amax_list_1
primitive_scale_list_1 = scale_list_1
primitive_amax_list_2 = amax_list_2
primitive_scale_list_2 = scale_list_2
primitive_amax_list_1, primitive_scale_list_1, primitive_amax_list_2, primitive_scale_list_2
# Convert str to index as str is not a valid type for JAX JIT
for _ in range(3):
ref_out, (
ref_a_grad,
ref_s_grad,
ref_k1_grad,
ref_k2_grad,
ref_b1_grad,
ref_b2_grad,
ref_amax_list_1,
ref_amax_list_2,
ref_scale_list_1,
ref_scale_list_2,
) = value_n_grad_ref_func(
a,
s,
k1,
k2,
b1,
b2,
ref_amax_list_1,
ref_amax_list_2,
ref_scale_list_1,
ref_scale_list_2,
)
for _ in range(3):
primitive_out, (
primitive_a_grad,
primitive_s_grad,
primitive_k1_grad,
primitive_k2_grad,
primitive_b1_grad,
primitive_b2_grad,
primitive_amax_list_1,
primitive_amax_list_2,
primitive_scale_list_1,
primitive_scale_list_2,
) = value_n_grad_primitive_func(
a,
s,
k1,
k2,
b1,
b2,
primitive_amax_list_1,
primitive_amax_list_2,
primitive_scale_list_1,
primitive_scale_list_2,
)
assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(
jnp.asarray(primitive_a_grad, np.float32),
jnp.asarray(ref_a_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE,
)
assert_allclose(
jnp.asarray(primitive_k1_grad, np.float32),
jnp.asarray(ref_k1_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE,
)
assert_allclose(
jnp.asarray(primitive_s_grad, np.float32),
jnp.asarray(ref_s_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE,
)
assert_allclose(
jnp.asarray(primitive_k2_grad, np.float32),
jnp.asarray(ref_k2_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE,
)
if use_bias:
assert_allclose(
jnp.asarray(primitive_b2_grad, np.float32),
jnp.asarray(ref_b2_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE,
)
assert_allclose(
jnp.asarray(primitive_b1_grad, np.float32),
jnp.asarray(ref_b1_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE,
)
@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
    """Deterministic bfloat16 tensor of `shape`, uniform in [5, 8)."""
    rng = jax.random.PRNGKey(0)
    # The 4-way split mirrors the original key derivation so the
    # generated values stay bit-identical across refactors.
    derived = jax.random.split(rng, 4)
    return jax.random.uniform(derived[0], shape, jnp.bfloat16, 5, 8)
class TestActivationLu:
    """Compare TE's activation primitive against a pure-JAX reference (value and grad)."""

    def ref_func(self, x, activation_type):
        """Reference mean-activation loss and its gradient w.r.t. x."""

        def _loss(inputs):
            activated = _jax_act_lu(inputs, activation_type)
            return jnp.mean(activated)

        return jit(value_and_grad(_loss, (0,)))(x)

    def primitive_func(self, inputs):
        """Mean of TE's activation_lu primitive; activation kind comes from self."""
        out = activation_lu(inputs, activation_type=self.activation_type)
        return jnp.mean(out)

    @pytest.mark.parametrize("shape", [(32, 1, 64), (16, 64, 1, 256)])
    @pytest.mark.parametrize(
        "activation_type",
        [
            ("gelu",),
            ("gelu", "linear"),
            ("silu",),
            ("silu", "linear"),
            ("relu",),
            ("relu", "linear"),
            ("quick_gelu",),
            ("quick_gelu", "linear"),
            ("squared_relu",),
            ("squared_relu", "linear"),
        ],
    )
    def test_activation_lu(self, random_inputs, activation_type):
        # Tile the input along axis -2 once per activation so gated
        # (two-activation) variants see the doubled feature rows they expect.
        data = jnp.repeat(random_inputs, len(activation_type), axis=-2)
        self.activation_type = activation_type
        fused = jit(value_and_grad(self.primitive_func, (0,)))
        prim_out, (prim_grad,) = fused(data)
        ref_out, (ref_grad,) = self.ref_func(data, activation_type)
        assert_allclose(prim_out, ref_out, dtype=data.dtype)
        assert_allclose(prim_grad, ref_grad, dtype=data.dtype)
class TestActivationLuFP8(TestActivationLu):
    """FP8 variant: checks the fused FP8 activation forward/backward
    (including cast-transpose, dbias and the amax it reports) against the
    higher-precision reference from TestActivationLu."""

    def prim_func(self, x):
        """Value-and-grad of a mean loss over the FP8 activation.

        Wraps the TE custom calls in a custom_vjp so the backward pass runs
        the fused d-activation/cast/transpose kernels; returns
        (loss, (dx, dx_transposed, dbias, amax_out)).
        """
        amax = self.amax
        scale = self.scale
        scale_inv = self.scale_inv
        activation_type = self.activation_type

        @jax.custom_vjp
        def _prim_func(x, _x_t, _dbias, _amax):
            output = _prim_func_fwd(x, _x_t, _dbias, _amax)
            return output

        def _prim_func_fwd(x, _x_t, _dbias, _amax):
            activation_lu_out, _ = tex.act_lu_fp8(
                x, amax, scale, scale_inv, FP8Helper.FWD_DTYPE, activation_type
            )
            # Dequantize so the loss (mean) is computed in the input precision.
            activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv)
            ctx = x
            return activation_lu_out, ctx

        def _prim_func_bwd(ctx, g):
            x = ctx
            if len(self.activation_type) > 1:  # gated, no bias
                dactivation_lu, dactivation_lu_trans, amax_out = tex.dgated_act_lu_cast_transpose(
                    g, x, amax, scale, scale_inv, FP8Helper.BWD_DTYPE, -1, activation_type
                )
                # Gated kernel produces no dbias; fill the slot with a placeholder.
                dbias = jnp.empty(x.shape[-1], x.dtype)
            else:  # not gated, with bias
                dactivation_lu, dactivation_lu_trans, dbias, amax_out = (
                    tex.dact_lu_dbias_cast_transpose(
                        g,
                        x,
                        amax,
                        scale,
                        scale_inv,
                        FP8Helper.BWD_DTYPE,
                        -1,
                        self.activation_type,
                    )
                )
            dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv)
            dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv)
            ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out)
            return ctx

        _prim_func.defvjp(_prim_func_fwd, _prim_func_bwd)
        # Placeholder cotangent carriers: only their shapes/dtypes matter; the
        # custom VJP overwrites their "gradients" with the fused kernel outputs.
        dx_trans_no_use = jnp.empty([x.shape[i] for i in self.transpose_axes], dtype=x.dtype)
        dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)
        amax_no_use = jnp.zeros(1, jnp.float32)
        value_n_grad_primitive_func = value_and_grad(
            lambda a, b, c, d: jnp.mean(_prim_func(a, b, c, d)), (0, 1, 2, 3)
        )
        return value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use, amax_no_use)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("shape", [(32, 1, 64), (16, 64, 1, 256)])
    @pytest.mark.parametrize(
        "activation_type",
        [
            ("gelu",),
            ("gelu", "linear"),
            ("silu",),
            ("silu", "linear"),
            ("relu",),
            ("relu", "linear"),
            ("quick_gelu",),
            ("quick_gelu", "linear"),
            ("squared_relu",),
            ("squared_relu", "linear"),
        ],
    )
    def test_activation_lu(self, random_inputs, activation_type):
        # Fresh FP8 meta: zero amax, unit scaling.
        self.amax = jnp.zeros(1, jnp.float32)
        self.scale = jnp.ones(1, jnp.float32)
        self.scale_inv = jnp.ones(1, jnp.float32)
        self.activation_type = activation_type
        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=-2)
        # Transposed layout: last two axes first (as produced by cast-transpose).
        axes = jnp.arange(x.ndim)
        self.transpose_axes = tuple([*axes[-2:]] + [*axes[:-2]])
        prim_out, (prim_grad, prim_grad_trans, dbias, amax) = self.prim_func(x)
        ref_out, (ref_grad,) = self.ref_func(x, activation_type)
        assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        # Reported amax must track the true max-abs of the gradient.
        assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
        if "linear" not in activation_type:
            # BUGFIX: `axis` must be an int or a tuple of ints; the previous
            # generator expression `(i for i in range(...))` raises TypeError
            # inside jnp.sum. Reduce over all leading (non-feature) axes.
            assert_allclose(dbias, jnp.sum(ref_grad, axis=tuple(range(x.ndim - 1))))
        assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(
            prim_grad_trans,
            jnp.transpose(ref_grad, self.transpose_axes),
            dtype=FP8Helper.BWD_DTYPE,
        )
class TestNorm:
    """
    Test transformer_engine.jax.layernorm APIs
    """
    @staticmethod
    def _generate_fp8_meta():
        # Fresh FP8 metadata: one (dtype, amax-history, scale) triple per FP8
        # tensor — presumably the two forward GEMM operands plus the backward
        # gradient; confirm against FP8MetaPackage's positional arguments.
        fp8_dtype_list = [FP8Helper.FWD_DTYPE, FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE]
        amax_list = [
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
        ]
        scale_list = [
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
        ]
        return fp8_dtype_list, amax_list, scale_list
    def reference_layernorm(self, x, scale, bias, zero_centered_gamma, eps):
        """
        JAX native layernorm implementations
        - bias is not None: layernorm
        - bias is None: rmsnorm
        """
        # Normalize in float32 regardless of input dtype; cast back at the end.
        x_ = jnp.asarray(x, jnp.float32)
        if bias is None:
            # RMSNorm: no mean subtraction.
            mean = 0.0
        else:
            mean = jnp.mean(x_, axis=-1, keepdims=True)
        var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
        normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
        if zero_centered_gamma:
            # jnp arrays are immutable: += rebinds locally, caller's array untouched.
            scale += 1.0
        if bias is None:
            bias = 0.0
        return jnp.asarray(normed_input * scale + bias).astype(x.dtype)
    @pytest.mark.parametrize("n, hidden", LN_CASES)
    @pytest.mark.parametrize("dtype", DTYPES)
    @pytest.mark.parametrize("ln_type", ["layernorm", "rmsnorm"])
    @pytest.mark.parametrize("zero_centered_gamma", [False, True])
    @pytest.mark.parametrize("epsilon", [1e-2, 1e-6])
    def test_layernorm_forward_backward(
        self, n, hidden, ln_type, zero_centered_gamma, epsilon, dtype
    ):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        expect_assert = False
        if ln_type == "rmsnorm" and zero_centered_gamma:
            # zero_centered_gamma is not supported for rmsnorm, expect an assertion.
            expect_assert = True
        with (
            pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*")
            if expect_assert
            else nullcontext()
        ):
            key = jax.random.PRNGKey(0)
            subkeys = jax.random.split(key, 3)
            x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -1, 1)
            # Gamma centered on 0 for zero-centered mode, on 1 otherwise.
            gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
            gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
            gamma = jnp.asarray(gamma, dtype)
            if ln_type == "layernorm":
                beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
                beta = jnp.asarray(beta, dtype)
            else:
                beta = None
            def compute_loss(x):
                # Higher precision to compute the loss
                x_ = x.astype(jnp.float32)
                return jnp.mean(jnp.square(x_)).astype(x.dtype)
            jitted_primitive = jit(
                value_and_grad(
                    lambda x, gamma, beta: compute_loss(
                        layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)
                    ),
                    (0, 1, 2),
                )
            )
            jitted_reference = jit(
                value_and_grad(
                    lambda x, gamma, beta: compute_loss(
                        self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
                    ),
                    (0, 1, 2),
                )
            )
            primitive_out, (primitive_dx, primitive_dgamma, primitive_dbeta) = jitted_primitive(
                x, gamma, beta
            )
            reference_out, (reference_dx, reference_dgamma, reference_dbeta) = jitted_reference(
                x, gamma, beta
            )
            assert_allclose(primitive_out, reference_out, dtype=dtype)
            assert_allclose(primitive_dx, reference_dx, dtype=dtype)
            assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype)
            if beta is not None:
                assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype)
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    @pytest.mark.parametrize("ln_type", ["layernorm", "rmsnorm"])
    @pytest.mark.parametrize("zero_centered_gamma", [True, False])
    @pytest.mark.parametrize("epsilon", [1e-2, 1e-6])
    def test_ln_fp8_dot_forward_backward(self, m, n, k, ln_type, zero_centered_gamma, epsilon):
        """
        Test transformer_engine.jax.layernorm.layernorm_fp8_dot
        """
        expect_assert = False
        if ln_type == "rmsnorm" and zero_centered_gamma:
            # zero_centered_gamma is not supported for rmsnorm, expect an assertion.
            expect_assert = True
        with (
            pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*")
            if expect_assert
            else nullcontext()
        ):
            key = jax.random.PRNGKey(0)
            subkeys = jax.random.split(key, 4)
            a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16)
            b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)
            gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
            if ln_type == "layernorm":
                beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
            else:
                beta = None
            _, amax_list_1, scale_list_1 = TestNorm._generate_fp8_meta()
            def primitive_func(x, y, gamma, beta, amax_list_1, scale_list_1):
                fp8_meta_pkg = FP8MetaPackage(
                    amax_list_1[0],
                    scale_list_1[0],
                    amax_list_1[1],
                    scale_list_1[1],
                    amax_list_1[2],
                    scale_list_1[2],
                )
                primitive_out = layernorm_fp8_dot(
                    x, y, gamma, beta, fp8_meta_pkg, ln_type, zero_centered_gamma
                )
                return jnp.mean(primitive_out)
            def ref_func(x, y, gamma, beta, zero_centered_gamma):
                x = self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
                return jnp.mean(jnp.dot(x, y))
            # Differentiating w.r.t. the FP8 meta (args 4, 5) makes the loop
            # below receive updated amax/scale state in the grads tuple —
            # apparently the TE custom-VJP convention; confirm in layernorm_fp8_dot.
            value_n_grad_primitive_func = value_and_grad(primitive_func, range(6))
            value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))
            ref_out, (ref_a_grad, ref_b_grad, ref_gamma_grad, ref_beta_grad) = (
                value_n_grad_ref_func(a, b, gamma, beta, zero_centered_gamma)
            )
            # Re-run a few times, feeding the FP8 meta back in each iteration.
            for _ in range(3):
                primitive_out, (
                    primitive_a_grad,
                    primitive_b_grad,
                    primitive_gamma_grad,
                    primitive_beta_grad,
                    amax_list_1,
                    scale_list_1,
                ) = value_n_grad_primitive_func(a, b, gamma, beta, amax_list_1, scale_list_1)
            assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
            assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
            assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)
            assert_allclose(primitive_gamma_grad, ref_gamma_grad, dtype=FP8Helper.BWD_DTYPE)
            if beta is not None:
                assert_allclose(primitive_beta_grad, ref_beta_grad, dtype=FP8Helper.BWD_DTYPE)
@pytest.mark.parametrize(
    "in_dtype",
    [
        pytest.param(jnp.float32, id="input_float32"),
        pytest.param(jnp.float16, id="input_float16"),
        pytest.param(jnp.bfloat16, id="input_bfloat16"),
    ],
)
@pytest.mark.parametrize(
    "input_shape, transpose_axis",
    [
        pytest.param((16, 16), 1, id="(16, 16)-1"),
        pytest.param((256, 128), 1, id="(256, 128)-1"),
        pytest.param((128, 512), 1, id="(128, 512)-1"),
        pytest.param((64, 16, 4, 256), 1, id="(64, 16, 4, 256)-1"),
        pytest.param((64, 16, 4, 256), 2, id="(64, 16, 4, 256)-2"),
        pytest.param((64, 16, 4, 256), 3, id="(64, 16, 4, 256)-3"),
    ],
)
class TestTranspose:
    """Cross-check transpose / cast-transpose custom calls against JAX references.

    Each test runs three implementations — pure JAX, the non-FFI custom call,
    and the FFI custom call (selected via NVTE_JAX_WITH_FFI) — and requires
    them to agree.
    """
    def test_transpose(self, in_dtype, input_shape, transpose_axis):
        key = jax.random.PRNGKey(0)
        input_tensor = jax.random.uniform(key, input_shape, in_dtype)
        static_axis_boundary = -1  # presumably "no static leading axes"; confirm helper semantics
        jax_output = _jax_transpose(input_tensor, static_axis_boundary, transpose_axis)
        # Toggle between the legacy custom-call path and the FFI path.
        os.environ["NVTE_JAX_WITH_FFI"] = "0"
        noffi_output = tex.transpose(input_tensor, static_axis_boundary, transpose_axis)
        os.environ["NVTE_JAX_WITH_FFI"] = "1"
        ffi_output = tex.transpose(input_tensor, static_axis_boundary, transpose_axis)
        assert_allclose(jax_output, noffi_output)
        assert_allclose(noffi_output, ffi_output)
    @pytest.mark.parametrize(
        "out_dtype",
        [
            pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"),
            pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"),
        ],
    )
    def test_cast_transpose(self, in_dtype, input_shape, transpose_axis, out_dtype):
        # Fresh FP8 meta: zero amax, unit scaling.
        amax = jnp.zeros(1, jnp.float32)
        scale = jnp.ones(1, jnp.float32)
        scale_inv = jnp.ones(1, jnp.float32)
        key = jax.random.PRNGKey(0)
        input = jax.random.uniform(key, input_shape, in_dtype)
        static_axis_boundary = -1
        jax_output = _jax_cast_transpose(
            input, scale, amax, out_dtype, static_axis_boundary, transpose_axis
        )
        os.environ["NVTE_JAX_WITH_FFI"] = "0"
        noffi_output = tex.cast_transpose(
            input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
        )
        os.environ["NVTE_JAX_WITH_FFI"] = "1"
        ffi_output = tex.cast_transpose(
            input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
        )
        assert_tree_like_allclose(jax_output, ffi_output)
        assert_tree_like_allclose(noffi_output, ffi_output)
    @pytest.mark.parametrize(
        "out_dtype",
        [
            pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"),
            pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"),
        ],
    )
    def test_dbias_cast_transpose(self, in_dtype, input_shape, transpose_axis, out_dtype):
        amax = jnp.zeros(1, jnp.float32)
        scale = jnp.ones(1, jnp.float32)
        scale_inv = jnp.ones(1, jnp.float32)
        key = jax.random.PRNGKey(0)
        input = jax.random.uniform(key, input_shape, in_dtype)
        static_axis_boundary = -1
        jax_output = _jax_dbias_cast_transpose(
            input, amax, scale, out_dtype, static_axis_boundary, transpose_axis
        )
        os.environ["NVTE_JAX_WITH_FFI"] = "0"
        noffi_output = tex.dbias_cast_transpose(
            input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
        )
        os.environ["NVTE_JAX_WITH_FFI"] = "1"
        ffi_output = tex.dbias_cast_transpose(
            input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
        )
        assert_tree_like_allclose(jax_output, ffi_output)
        assert_tree_like_allclose(noffi_output, ffi_output)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize(
    "input_shape",
    [
        pytest.param((256, 128), id="(256, 128)"),
        pytest.param((128, 512, 8), id="(128, 512, 8)"),
    ],
)
@pytest.mark.parametrize(
    "in_dtype",
    [
        pytest.param(jnp.float32, id="input_float32"),
        pytest.param(jnp.float16, id="input_float16"),
        pytest.param(jnp.bfloat16, id="input_bfloat16"),
    ],
)
@pytest.mark.parametrize(
    "out_dtype",
    [
        pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"),
        pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"),
    ],
)
def test_quantize(input_shape, in_dtype, out_dtype):
    """FP8 cast must agree across the JAX reference, the non-FFI custom
    call, and the FFI custom call."""
    # Fresh FP8 meta: zero amax, unit scaling.
    amax = jnp.zeros(1, jnp.float32)
    scale = jnp.ones(1, jnp.float32)
    scale_inv = jnp.ones(1, jnp.float32)
    rng = jax.random.PRNGKey(0)
    x = jax.random.uniform(rng, input_shape, in_dtype)
    reference = _jax_cast_fp8(x, scale, amax, out_dtype)
    # Exercise both custom-call dispatch paths via the env toggle.
    os.environ["NVTE_JAX_WITH_FFI"] = "0"
    legacy = tex.cast_fp8(x, amax, scale, scale_inv, out_dtype)
    os.environ["NVTE_JAX_WITH_FFI"] = "1"
    ffi = tex.cast_fp8(x, amax, scale, scale_inv, out_dtype)
    assert_tree_like_allclose(reference, ffi)
    assert_tree_like_allclose(legacy, ffi)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import pytest
import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from distributed_test_base import (
generate_configs,
generate_context_parallel_configs,
generate_collectives_count,
)
from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat
from transformer_engine.jax.attention import (
is_fused_attn_kernel_available,
AttnBiasType,
AttnMaskType,
QKVLayout,
QKVFormat,
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
CPStrategy,
ReorderStrategy,
)
# Data types exercised by the distributed fused-attention tests below.
DTYPES = [jnp.bfloat16]
class TestDistributedSelfAttn:
    """Multi-device self-attention: compares backward pass and collective volume."""
    def generate_collectives_count_ref(
        self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype
    ):
        """Expected all-reduce byte count: the loss scalar plus, under data
        parallelism, the dbias tensor of shape (heads/tp, seqlen, seqlen)."""
        jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
        _, seqlen, heads, _ = shape
        is_dp_enabled = mesh_resource.dp_resource is not None
        tp_size = 1
        if mesh_resource.tp_resource is not None:
            idx = mesh_axes.index(mesh_resource.tp_resource)
            tp_size = mesh_shape[idx]
        all_reduce_loss_bytes = 4  # 1 * FP32
        bias_bytes = int(with_bias) * (heads // tp_size) * seqlen * seqlen * jax_dtype.itemsize
        allreduce_total_bytes = all_reduce_loss_bytes + (bias_bytes * is_dp_enabled)
        # for loss and dbias
        return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0)
    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize(
        "data_shape",
        [
            pytest.param((32, 512, 12, 64), id="32-512-12-64"),
            pytest.param((32, 1024, 16, 128), id="32-1024-16-128"),
        ],
    )
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
        ],
    )
    @pytest.mark.parametrize(
        "attn_mask_type",
        [
            pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"),
            pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
        ],
    )
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_self_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_bias_type,
        bias_shape,
        attn_mask_type,
        dtype,
    ):
        dropout_prob = 0.0
        is_training = True
        batch, seqlen, num_head, hidden = data_shape
        # Skip early if no fused-attention backend supports this configuration.
        if not is_fused_attn_kernel_available(
            dtype,
            dtype,
            QKVLayout.BS3HD,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            num_head,
            num_head,
            seqlen,
            seqlen,
            hidden,
            None,  # no window
        ):
            pytest.skip(f"No FusedAttn backend found")
        col_ref = self.generate_collectives_count_ref(
            mesh_shape,
            mesh_axes,
            mesh_resource,
            attn_bias_type != AttnBiasType.NO_BIAS,
            data_shape,
            dtype,
        )
        runner = FusedAttnRunner(
            batch,
            seqlen,
            seqlen,
            num_head,
            num_head,
            hidden,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            is_training,
            QKVLayout.BS3HD,
            bias_shape,
            None,
            SeqDescFormat.Seqlens,
            number_of_devices=device_count,
            mesh_shape=mesh_shape,
            mesh_axes=mesh_axes,
            mesh_resource=mesh_resource,
            coll_count_ref=col_ref,
        )
        runner.test_backward()
class TestDistributedCrossAttn:
    """Multi-device cross-attention (BSHD query + packed KV) backward comparison."""
    def generate_collectives_count_ref(self):
        # for loss
        all_reduce_loss_bytes = 4  # 1 * FP32
        return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)
    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize("data_shape", [[32, 128, 12, 64], [32, 512, 16, 64]])
    @pytest.mark.parametrize(
        "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
    )
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_cross_attn(
        self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype
    ):
        attn_bias_type = AttnBiasType.NO_BIAS
        bias_shape = None
        dropout_prob = 0.0
        is_training = True
        batch, seqlen, num_head, hidden = data_shape
        # Skip early if no fused-attention backend supports this configuration.
        if not is_fused_attn_kernel_available(
            dtype,
            dtype,
            QKVLayout.BSHD_BS2HD,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            num_head,
            num_head,
            seqlen,
            seqlen,
            hidden,
            None,  # no window
        ):
            pytest.skip(f"No FusedAttn backend found")
        col_ref = self.generate_collectives_count_ref()
        runner = FusedAttnRunner(
            batch,
            seqlen,
            seqlen,
            num_head,
            num_head,
            hidden,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            is_training,
            QKVLayout.BSHD_BS2HD,
            bias_shape,
            None,
            SeqDescFormat.Seqlens,
            number_of_devices=device_count,
            mesh_shape=mesh_shape,
            mesh_axes=mesh_axes,
            mesh_resource=mesh_resource,
            coll_count_ref=col_ref,
        )
        runner.test_backward()
@pytest.mark.parametrize(
    "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize(
    "data_shape",
    [
        # Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
        pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"),
        pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
    ],
)
@pytest.mark.parametrize("kv_groups", [1, 8])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
    "qkv_layout, attn_mask_type",
    [
        pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"),
        pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"),
        pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="HD_KVPACKED-NO_MASK"),
        pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"),
        pytest.param(
            QKVLayout.THD_THD_THD,
            AttnMaskType.PADDING_CAUSAL_MASK,
            id="THD_SEPARATE-PADDING_CAUSAL",
        ),
    ],
)
@pytest.mark.parametrize(
    "load_balanced",
    [pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
)
class TestDistributedContextParallelSelfAttn:
    """Context-parallel fused attention, covering all-gather and ring strategies."""

    def impl_test_context_parallel_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
        cp_strategy,
    ):
        """Shared driver: run the backward-pass comparison on a (dp, cp, tp)
        mesh, skipping configurations without a fused-attention backend or
        with shapes the mesh cannot divide evenly."""
        attn_bias_type = AttnBiasType.NO_BIAS
        bias_shape = None
        dropout_prob = 0.0
        is_training = True
        dp_size, cp_size, tp_size = mesh_shape
        batch, seqlen, num_head, hidden = data_shape
        # Scale the sequence length by 2*CP so its never too small as we scale up test.
        # 2*CP is used since we split into two CP groups for load balancing.
        seqlen = seqlen * cp_size * 2
        data_shape = batch, seqlen, num_head, hidden
        num_kv_heads = num_head // kv_groups

        def check_has_backend_for_mask(mask_type):
            # no SWA (sliding window) for CP
            return is_fused_attn_kernel_available(
                dtype,
                dtype,
                qkv_layout,
                attn_bias_type,
                mask_type,
                dropout_prob,
                num_head,
                num_kv_heads,
                seqlen,
                seqlen,
                hidden,
                None,
            )

        # For causal masking we depend on having bottom right support also.
        # The API does not check this and instead we rely on lower level checks to raise
        # and exception if the step backend is not supported. This was a deliberate API
        # decision to keep the CP size or flag out of the function.
        has_backend = check_has_backend_for_mask(attn_mask_type)
        if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK:
            has_backend &= check_has_backend_for_mask(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)
        if not has_backend:
            pytest.skip(f"No FusedAttn backend found {cp_size=} {attn_mask_type=}.")
        if dp_size > 1 and batch % dp_size != 0:
            pytest.skip(f"Skipping {batch=} not a multiple of {dp_size=}")
        # make sure the mesh even divides cp and tp axis
        if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0:
            pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}")
        # All skip checks passed; only now build the (potentially expensive) runner.
        runner = FusedAttnRunner(
            batch,
            seqlen,
            seqlen,
            num_head,
            num_kv_heads,
            hidden,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            is_training,
            qkv_layout,
            bias_shape,
            None,
            SeqDescFormat.SegmentIDs,
            number_of_devices=device_count,
            mesh_shape=mesh_shape,
            mesh_axes=mesh_axes,
            mesh_resource=mesh_resource,
            cp_strategy=cp_strategy,
            cp_load_balanced=load_balanced,
        )
        runner.test_backward()

    def test_context_parallel_allgather_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
    ):
        """All-gather CP strategy; THD layouts are unsupported there."""
        if qkv_layout.is_thd():
            pytest.skip("THD doesn't support all gather context parallelism.")
        return self.impl_test_context_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups,
            attn_mask_type,
            dtype,
            qkv_layout,
            load_balanced,
            CPStrategy.ALL_GATHER,
        )

    @pytest.mark.parametrize(
        "use_scan",
        [pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")],
    )
    def test_context_parallel_ring_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
        use_scan,
    ):
        """Ring-attention CP strategy, with and without the scan implementation."""
        os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1" if use_scan else "0"
        try:
            if qkv_layout.is_thd() and not load_balanced:
                pytest.skip("THD + ring doesn't support unbalanced context parallelism.")
            return self.impl_test_context_parallel_attn(
                device_count,
                mesh_shape,
                mesh_axes,
                mesh_resource,
                data_shape,
                kv_groups,
                attn_mask_type,
                dtype,
                qkv_layout,
                load_balanced,
                CPStrategy.RING,
            )
        finally:
            # BUGFIX: this cleanup previously sat after the `return` statement
            # and was dead code, leaking the env var into subsequent tests.
            # try/finally also cleans up when pytest.skip raises.
            del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"]
class TestReorderCausalLoadBalancing:
    """reorder/inverse_reorder for causal CP load balancing must round-trip exactly."""

    @pytest.mark.parametrize("cp_size", [2, 4, 8])
    @pytest.mark.parametrize(
        "shape",
        [
            pytest.param([1, 16, 1, 1], id="1-16-1-1"),
            pytest.param([4, 32, 12, 32], id="4-32-12-32"),
            pytest.param([3, 32, 8, 64], id="3-32-8-64"),
        ],
    )
    @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD])
    @pytest.mark.parametrize(
        "reorder_strategy",
        [
            pytest.param(ReorderStrategy.DualChunkSwap, id="DualChunkSwap"),
            pytest.param(ReorderStrategy.Striped, id="Striped"),
        ],
    )
    def test(self, cp_size, shape, qkv_format, reorder_strategy):
        data = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16)
        if qkv_format == QKVFormat.SBHD:
            # SBHD places the sequence axis first; swap it into position.
            data = data.swapaxes(0, 1)
            seq_dim = 0
        else:
            seq_dim = 1
        expected = data.copy()
        # Strategy / cp_size / seq_dim are static so jit can specialize the permutation.
        fwd = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3])
        bwd = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3])
        shuffled = fwd(data, reorder_strategy, cp_size, seq_dim)
        round_tripped = bwd(shuffled, reorder_strategy, cp_size, seq_dim)
        assert jnp.array_equal(round_tripped, expected)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import warnings
import pytest
import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import generate_configs, generate_collectives_count
from distributed_test_base import compare_ops
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.layernorm import layernorm
# Data types exercised by the distributed layernorm/rmsnorm tests below.
DTYPES = [jnp.bfloat16, jnp.float32]
class TestDistributedLayernorm:
    """Distributed layernorm/rmsnorm: numerics and collective counts vs pure JAX."""
    def generate_inputs(self, shape, mesh_resource, dtype, shard_weights):
        """Create x/gamma/beta and their PartitionSpecs (x sharded over dp)."""
        weight_shape = (shape[-1],)
        x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
        gamma = jnp.ones(weight_shape, dtype=dtype)
        beta = jnp.ones(weight_shape, dtype=dtype)
        if len(shape) == 2:
            x_pspec = PartitionSpec(mesh_resource.dp_resource, None)
        elif len(shape) == 3:
            x_pspec = PartitionSpec(mesh_resource.dp_resource, None, None)
        else:
            raise NotImplementedError
        # When shard_weights is set, deliberately shard the 1-D weights so the
        # tests below can check the primitive's "no weight sharding" warning.
        g_pspec = b_pspec = (
            PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None)
        )
        return (x, gamma, beta), (x_pspec, g_pspec, b_pspec)
    def generate_collectives_count_ref(self, mesh_resource, ln_type, shape, dtype):
        """Expected all-reduce bytes: loss scalar + dgamma (+ dbeta) under DP."""
        jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
        is_dp_enabled = mesh_resource.dp_resource is not None
        assert ln_type in ["layernorm", "rmsnorm"]
        all_reduce_loss_bytes = 4  # 1 * FP32
        # for loss, dgamma and dbeta
        weight_count = 2 if ln_type == "layernorm" else 1
        allreduce_total_bytes = (
            all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
        )
        return generate_collectives_count(
            allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=0
        )
    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]])
    @pytest.mark.parametrize("dtype", DTYPES)
    @pytest.mark.parametrize("zero_centered_gamma", [False, True])
    @pytest.mark.parametrize("shard_weights", [False, True])
    def test_layernorm(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        dtype,
        zero_centered_gamma,
        shard_weights,
    ):
        epsilon = 1e-6
        ln_type = "layernorm"
        def target_func(x, gamma, beta):
            return jnp.mean(layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon))
        def ref_func(x, gamma, beta):
            # Pure-JAX layernorm in float32, cast back to the input dtype.
            x_ = jnp.asarray(x, jnp.float32)
            mean = jnp.mean(x_, axis=-1, keepdims=True)
            var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
            normed_input = (x_ - mean) * jax.lax.rsqrt(var + epsilon)
            if zero_centered_gamma:
                output = jnp.asarray(normed_input * (gamma + 1) + beta).astype(x.dtype)
            else:
                output = jnp.asarray(normed_input * gamma + beta).astype(x.dtype)
            return jnp.mean(output)
        (x, gamma, beta), (x_pspec, g_pspec, b_pspec) = self.generate_inputs(
            data_shape, mesh_resource, dtype, shard_weights
        )
        collective_count_ref = self.generate_collectives_count_ref(
            mesh_resource, ln_type, data_shape, dtype
        )
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(mesh_resource=mesh_resource):
            x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
            gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
            beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec))
            # Record warnings so the finally-block can verify the expected
            # "no weight sharding" enforcement message.
            with warnings.catch_warnings(record=True) as warns:
                try:
                    compare_ops(
                        target_func,
                        ref_func,
                        [x_, gamma_, beta_],
                        collective_count_ref,
                        grad_args=(0, 1, 2),
                        metric_fwd_dtype=dtype,
                        metric_bwd_dtype=dtype,
                        in_shardings=(x_pspec, g_pspec, b_pspec),
                        out_shardings=(None, (x_pspec, g_pspec, b_pspec)),
                    )
                except AssertionError as err:
                    # Layernorm should still produce the correct numerical result with
                    # gamma/beta sharded. However, the collective count may not be the same
                    # when XLA is forced to unshard gamma and/or beta. We can catch
                    # and ignore that specific error here.
                    if (
                        g_pspec[-1] is None and b_pspec[-1] is None
                    ) or "Expected collective count" not in str(err):
                        raise err
                finally:
                    for w in warns:
                        assert "Enforcing no sharding of parameters hidden dim!" in str(w), (
                            "Layernorm primitive did not raise the correct warning for "
                            "unsupported sharding of gamma and/or beta"
                        )
    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]])
    @pytest.mark.parametrize("dtype", DTYPES)
    @pytest.mark.parametrize("shard_weights", [False, True])
    def test_rmsnorm(
        self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights
    ):
        epsilon = 1e-6
        ln_type = "rmsnorm"
        def target_func(x, gamma):
            return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon))
        def ref_func(x, gamma):
            # Pure-JAX rmsnorm: no mean subtraction, no beta.
            x = jnp.asarray(x, jnp.float32)
            mean2 = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
            y = jnp.asarray(x * jax.lax.rsqrt(mean2 + epsilon), dtype)
            output = y * gamma
            return jnp.mean(output)
        (x, gamma, _), (x_pspec, g_pspec, _) = self.generate_inputs(
            data_shape, mesh_resource, dtype, shard_weights
        )
        collective_count_ref = self.generate_collectives_count_ref(
            mesh_resource, ln_type, data_shape, dtype
        )
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(mesh_resource=mesh_resource):
            x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
            gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
            with warnings.catch_warnings(record=True) as warns:
                try:
                    compare_ops(
                        target_func,
                        ref_func,
                        [x_, gamma_],
                        collective_count_ref,
                        grad_args=(0, 1),
                        metric_fwd_dtype=dtype,
                        metric_bwd_dtype=dtype,
                        in_shardings=(x_pspec, g_pspec),
                        out_shardings=(None, (x_pspec, g_pspec)),
                    )
                except AssertionError as err:
                    # RmsNorm should still produce the correct numerical result with
                    # gamma/beta sharded. However, the collective count may not be the same
                    # when XLA is forced to unshard gamma. We can catch
                    # and ignore that specific error here.
                    if g_pspec[-1] is None or "Expected collective count" not in str(err):
                        raise err
                finally:
                    for w in warns:
                        assert "Enforcing no sharding of parameters hidden dim!" in str(w), (
                            "RmsNorm primitive did not raise the correct warning for "
                            "unsupported sharding of gamma and/or beta"
                        )
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
from typing import Callable, List, Sequence, Union
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.flax import LayerNormMLP
from transformer_engine.jax.layernorm_mlp import fused_layernorm_fp8_mlp
from transformer_engine.jax.sharding import (
HIDDEN_AXES,
HIDDEN_TP_AXES,
BATCH_AXES,
SEQLEN_TP_AXES,
SEQLEN_AXES,
W_NO_SHARD_AXES,
W_FSDP_AXES,
W_TP_AXES,
W_JOINED_AXES,
)
from transformer_engine.jax.sharding import MeshResource
from utils import assert_allclose, assert_tree_like_allclose, is_devices_enough
# Skip FP8 tests (via pytest.mark.skipif) when the device does not support FP8.
is_fp8_supported, reason = is_fp8_available()
DTYPES = [jnp.bfloat16, jnp.float16]
INPUT_SHAPE = [[64, 128, 32]]  # [batch, seqlen, hidden_in]
# Logical axis names used to shard the layernorm input and the two GEMM inputs.
LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
# Size of the MLP intermediate (hidden) dimension used by all tests below.
INTERMEDIATE = 16
# Only test with FSDP and TP as DP is not used
def generate_fsdp_and_tp_configs():
    """Build [device_count, mesh_shape, mesh_axes, mesh_resource] configs for FSDP+TP meshes."""
    configs = []
    for num_devices, mesh_shape in ((2, (1, 2)), (4, (2, 2))):
        if is_devices_enough(num_devices):
            configs.append(
                [
                    num_devices,
                    mesh_shape,
                    ("fsdp", "tp"),
                    MeshResource(fsdp_resource="fsdp", tp_resource="tp"),
                ]
            )
    return configs
class TestDistributedLayernormMLP:
    """Distributed tests for the fused LayerNorm + MLP (primitive and flax layer).

    Each test runs the computation once on a single device and once on an
    FSDP+TP mesh, then checks that outputs (and gradients, for the primitive
    test) match.
    """
    def generate_inputs(self, input_shape, activation_type, use_bias, dtype):
        """Create random (x, gamma, k1, k2, b1, b2) tensors; b1/b2 are None without bias."""
        batch, seqlen, hidden_in = input_shape
        hidden_out = hidden_in
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)
        x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
        gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
        # Kernels are scaled by 1/sqrt(fan_in) to keep activations bounded.
        k1 = jax.random.normal(
            subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype
        ) / jnp.sqrt(hidden_in)
        k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(
            INTERMEDIATE
        )
        if use_bias:
            b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype)
            b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
        else:
            b1 = None
            b2 = None
        return (x, gamma, k1, k2, b1, b2)
    def layernorm_fp8_mlp_prim_func(
        self,
        x: jnp.ndarray,
        ln_scale: jnp.ndarray,
        kernel_1: jnp.ndarray,
        kernel_2: jnp.ndarray,
        bias_1: jnp.ndarray,
        bias_2: jnp.ndarray,
        amax_list_1: List[jnp.ndarray],
        amax_list_2: List[jnp.ndarray],
        scale_list_1: List[jnp.ndarray],
        scale_list_2: List[jnp.ndarray],
        layernorm_type: str = "rmsnorm",
        activation_type: Sequence[Union[str, Callable]] = ("gelu",),
        use_bias: bool = True,
        multi_gpus: bool = False,
    ) -> jnp.ndarray:
        """Scalar loss through fused_layernorm_fp8_mlp.

        The three (amax, scale) pairs in each list feed one FP8MetaPackage per
        GEMM. When `multi_gpus` is set, the module-level logical axes are
        passed so the fused op knows how its inputs are sharded.
        """
        # Interleave amax[i]/scale[i] (i = 0..2) into the per-GEMM FP8 metadata.
        fp8_meta_pkg1 = FP8MetaPackage(
            amax_list_1[0],
            scale_list_1[0],
            amax_list_1[1],
            scale_list_1[1],
            amax_list_1[2],
            scale_list_1[2],
        )
        fp8_meta_pkg2 = FP8MetaPackage(
            amax_list_2[0],
            scale_list_2[0],
            amax_list_2[1],
            scale_list_2[1],
            amax_list_2[2],
            scale_list_2[2],
        )
        if multi_gpus:
            layernorm_input_axes = LAYERNORM_INPUT_AXES
            dot_1_input_axes = DOT_1_INPUT_AXES
            dot_2_input_axes = DOT_2_INPUT_AXES
        else:
            layernorm_input_axes = None
            dot_1_input_axes = None
            dot_2_input_axes = None
        # out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2
        return jnp.mean(
            fused_layernorm_fp8_mlp(
                x,
                ln_scale,
                None,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                [fp8_meta_pkg1, fp8_meta_pkg2],
                layernorm_type,
                layernorm_input_axes=layernorm_input_axes,
                dot_1_input_axes=dot_1_input_axes,
                dot_2_input_axes=dot_2_input_axes,
                activation_type=activation_type,
                use_bias=use_bias,
            )
        )
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
    @pytest.mark.parametrize("input_shape", INPUT_SHAPE)
    @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest.mark.parametrize("dtype", DTYPES)
    @pytest.mark.parametrize("use_bias", [True, False])
    def test_layernorm_fp8_mlp_primitive(
        self, mesh_config, activation_type, use_bias, input_shape, dtype
    ):
        """Single-GPU vs. multi-GPU value-and-grad of the fused FP8 MLP must agree."""
        device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
        layernorm_type = "rmsnorm"
        # Fresh FP8 metadata: zeroed amax histories and unit scales per GEMM.
        fp8_amax_list_1 = [
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
        ]
        fp8_amax_list_2 = [
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
        ]
        fp8_scale_list_1 = [
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
        ]
        fp8_scale_list_2 = [
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
        ]
        # Chained assignment: keep the whole tuple AND unpack the tensors.
        inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs(
            input_shape, activation_type, use_bias, dtype
        )
        inputs = [*inputs, fp8_amax_list_1, fp8_amax_list_2, fp8_scale_list_1, fp8_scale_list_2]
        static_inputs = [layernorm_type, activation_type, use_bias]
        # Differentiate w.r.t. every (non-static) input.
        value_and_grad_func = jax.value_and_grad(
            self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
        )
        # Single GPU
        single_jitter = jax.jit(
            value_and_grad_func, static_argnums=range(len(inputs), len(static_inputs) + len(inputs))
        )
        with fp8_autocast(enabled=True):
            single_fwd, single_grads = single_jitter(*inputs, *static_inputs)
        # Multi GPUs
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(enabled=True, mesh_resource=mesh_resource):
            # Kernel 1 is sharded FSDP x TP, kernel 2 TP x FSDP (column/row parallel).
            k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
            k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
            k1_ = jax.device_put(k1, k1_sharding)
            k2_ = jax.device_put(k2, k2_sharding)
            if use_bias:
                b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
                b1_ = jax.device_put(b1, b1_sharding)
            else:
                b1_sharding = b1_ = None
            multi_inputs = [*inputs[:2], k1_, k2_, b1_, *inputs[5:]]
            # Position ref for sharding pspec lists
            #   x, gamma, k1, k2, b1,
            #   b2, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv
            in_shardings = (
                None,
                None,
                k1_sharding,
                k2_sharding,
                b1_sharding,
                None,
                None,
                None,
                None,
                None,
            )
            out_shardings = (
                None,
                (None, None, k1_sharding, k2_sharding, b1_sharding, None, None, None, None, None),
            )
            multi_jitter = jax.jit(
                value_and_grad_func,
                in_shardings=in_shardings,
                out_shardings=out_shardings,
                static_argnums=range(len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1),
            )  # +1 for multi_gpus
            multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)
        assert_allclose(multi_fwd, single_fwd, dtype=dtype)
        # Gradients may be tensors or lists of tensors (the FP8 meta lists).
        for i in range(len(inputs)):
            if multi_grads[i] is not None:
                if isinstance(multi_grads[i], list):
                    assert isinstance(single_grads[i], list)
                    for m_grad, s_grad in zip(multi_grads[i], single_grads[i]):
                        assert_allclose(
                            m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close"
                        )
                else:
                    assert_allclose(
                        multi_grads[i],
                        single_grads[i],
                        dtype=dtype,
                        err_msg=f"multi_grads[{i}] is not close",
                    )
    def _test_layernorm_mlp(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8
    ):
        """Shared body: a sharded LayerNormMLP must match its single-device twin."""
        batch, seqlen, hidden_in = input_shape
        layernorm_type = "rmsnorm"
        rng = jax.random.PRNGKey(0)
        subkeys = jax.random.split(rng, 2)
        x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
        # Same init RNG for both modules so their parameters should coincide.
        init_rngs = {"params": subkeys[1]}
        # Single GPUs
        with fp8_autocast(enabled=use_fp8):
            ln_mlp_single = LayerNormMLP(
                layernorm_type=layernorm_type,
                transpose_batch_sequence=False,  # input: [batch, seqlen, hidden]
                intermediate_dim=INTERMEDIATE,
                activations=activation_type,
                use_bias=use_bias,
            )
            params_single = ln_mlp_single.init(init_rngs, x)
            mlp_out_single, ln_out_single = ln_mlp_single.apply(
                params_single, x, deterministic=True
            )
        # Multi GPUs
        device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(enabled=use_fp8, mesh_resource=mesh_resource):
            # Same module, but with explicit logical axes for every weight/input.
            ln_mlp_sharded = LayerNormMLP(
                layernorm_type=layernorm_type,
                transpose_batch_sequence=False,
                intermediate_dim=INTERMEDIATE,
                activations=activation_type,
                scale_axes=(W_NO_SHARD_AXES,),
                ln_bias_axes=(W_NO_SHARD_AXES,),
                kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
                kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
                use_bias=use_bias,
                bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
                bias_axes_2=(W_NO_SHARD_AXES,),
                layernorm_input_axes=LAYERNORM_INPUT_AXES,
                dot_1_input_axes=DOT_1_INPUT_AXES,
                dot_2_input_axes=DOT_2_INPUT_AXES,
                name="mlp",
            )
            params_sharded = ln_mlp_sharded.init(init_rngs, x)
            mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
                params_sharded, x, deterministic=True
            )
        # Make sure params values are the same
        assert_tree_like_allclose(params_sharded["params"], params_single["params"])
        assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)
        assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype)
    @pytest.mark.parametrize("input_shape", INPUT_SHAPE)
    @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
    @pytest.mark.parametrize("activation_type", [("gelu",), ("silu", "linear"), ("gelu", "gelu")])
    @pytest.mark.parametrize("dtype", DTYPES)
    @pytest.mark.parametrize("use_bias", [True, False])
    def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, dtype):
        """LayerNormMLP flax layer without FP8."""
        self._test_layernorm_mlp(
            mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=False
        )
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
    @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear"), ("gelu", "gelu")])
    @pytest.mark.parametrize("use_bias", [True, False])
    @pytest.mark.parametrize("input_shape", INPUT_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_layernorm_fp8_mlp_layer(
        self, mesh_config, activation_type, use_bias, input_shape, dtype
    ):
        """LayerNormMLP flax layer with FP8 autocast enabled."""
        self._test_layernorm_mlp(
            mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=True
        )
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import warnings
import pytest
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import generate_configs, generate_collectives_count
from distributed_test_base import compare_ops
from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.softmax import SoftmaxType, softmax
DTYPES = [jnp.float16, jnp.bfloat16]
class TestDistributedSoftmax:
    """Distributed tests for the TE softmax primitive against a pure-JAX reference."""

    def generate_collectives_count_ref(self):
        """Expected communication: a single 4-byte (FP32) all-reduce for the scalar loss."""
        return generate_collectives_count(allreduce=4, allgather=0, other=0)

    def generate_inputs(self, shape, mesh_resource, softmax_type, dtype, bad_sharding):
        """Create random logits plus a mask, and the partition specs for both."""
        batch, _, seqlen, _ = shape
        x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
        if softmax_type == SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
            mask = make_causal_mask(batch, seqlen)
        else:
            mask = make_self_mask(batch, seqlen)
        if bad_sharding:
            # Deliberately shard the last (hidden) dimension to exercise the warning path.
            x_pspec = PartitionSpec(
                mesh_resource.dp_resource, None, None, mesh_resource.tp_resource
            )
        else:
            x_pspec = PartitionSpec(
                mesh_resource.dp_resource, mesh_resource.tp_resource, None, None
            )
        mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None)
        return (x, mask), (x_pspec, mask_pspec)

    @staticmethod
    def target_func(x, mask, scale_factor=1.0, softmax_type=SoftmaxType.SCALED):
        """Scalar loss through the TE softmax primitive."""
        out = softmax(x, mask, scale_factor=scale_factor, softmax_type=softmax_type)
        return jnp.mean(out)

    @staticmethod
    def ref_func(x, mask, scale_factor=1.0, dtype=jnp.float16):
        """Pure-JAX reference: masked positions receive a large negative bias."""
        if mask is not None:
            masked_bias = jnp.full(mask.shape, -1e10).astype(dtype)
            no_bias = jnp.full(mask.shape, 0.0).astype(dtype)
            x = x + jax.lax.select(mask > 0, masked_bias, no_bias).astype(dtype)
        probs = jax.nn.softmax(x * scale_factor)
        return jnp.mean(probs)

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]])
    @pytest.mark.parametrize(
        "softmax_type",
        [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],
    )
    @pytest.mark.parametrize("scale_factor", [1.0, 3.0])
    @pytest.mark.parametrize("dtype", DTYPES)
    @pytest.mark.parametrize("bad_sharding", [False, True])
    def test_softmax(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        softmax_type,
        scale_factor,
        dtype,
        bad_sharding,
    ):
        """Compare values, gradients and collective counts of the sharded softmax."""
        target_func = partial(
            self.target_func, scale_factor=scale_factor, softmax_type=softmax_type
        )
        ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype)
        (x, mask), (x_pspec, mask_pspec) = self.generate_inputs(
            data_shape, mesh_resource, softmax_type, dtype, bad_sharding
        )
        expected_collectives = self.generate_collectives_count_ref()
        device_grid = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(device_grid, mesh_axes)
        with mesh, fp8_autocast(mesh_resource=mesh_resource):
            sharded_x = jax.device_put(x, NamedSharding(mesh, x_pspec))
            sharded_mask = jax.device_put(mask, NamedSharding(mesh, mask_pspec))
            with warnings.catch_warnings(record=True) as warns:
                try:
                    compare_ops(
                        target_func,
                        ref_func,
                        [sharded_x, sharded_mask],
                        expected_collectives,
                        grad_args=(0,),
                        metric_fwd_dtype=dtype,
                        metric_bwd_dtype=dtype,
                        in_shardings=(x_pspec, mask_pspec),
                        out_shardings=(None, (x_pspec,)),
                    )
                except AssertionError as err:
                    # Softmax should still produce the correct numerical result with
                    # bad sharding, but the collective count may differ when XLA is
                    # forced to unshard the hidden dimension. Only that specific
                    # mismatch is tolerated here.
                    if not bad_sharding or "Expected collective count" not in str(err):
                        raise err
                finally:
                    for w in warns:
                        assert "Sharding the hidden dimension is not supported" in str(w), (
                            "Softmax primitive did not raise the correct warning for "
                            "unsupported sharding in the hidden dimension."
                        )
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import jax
import jax.numpy as jnp
from utils import assert_allclose
from transformer_engine.jax.flax.module import _apply_low_rank_adaptation
from transformer_engine.jax.flax.module import _normalize_axes
from transformer_engine.jax.flax.transformer import LoRAScope
from transformer_engine.jax.flax.transformer import _canonicalize_lora_scope
class TestLoRA:
    """Unit tests for the low-rank adaptation (LoRA) helpers in the flax modules."""

    # Fix: this helper had neither `self` nor a decorator; it only worked because it
    # was invoked through the class object. Mark it @staticmethod so the intent is
    # explicit and instance access would also work. Callers (TestLoRA.reference(...))
    # are unaffected.
    @staticmethod
    def reference(x, la, lb, pattern, scale):
        """Reference LoRA output: einsum of (x, lora_a, lora_b) scaled by `scale`."""
        out = jnp.einsum(pattern, x, la, lb)
        return out * scale

    @pytest.mark.parametrize("shape", [(32, 1024), (32, 128, 1024)])
    @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16])
    @pytest.mark.parametrize(
        "axis_features_pattern",
        [((-1,), (1024,), "...h,hr,rk->...k"), ((-1,), (3, 1024), "...h,hkr,krz->...kz")],
    )
    @pytest.mark.parametrize("rank", [32, 16])
    @pytest.mark.parametrize("alpha", [None, 4, 8])
    def test_lora(self, shape, dtype, axis_features_pattern, rank, alpha):
        """_apply_low_rank_adaptation must match the plain einsum reference."""
        axis, features, pattern = axis_features_pattern
        axis = _normalize_axes(axis, len(shape))
        shape_in_axis = tuple(shape[ax] for ax in axis)
        key = jax.random.key(1124)
        key, x_key = jax.random.split(key)
        x = jax.random.normal(x_key, shape, dtype)
        key, la_key = jax.random.split(key)
        # lora_a maps the contracted input axes to (feature prefix..., rank).
        la_shape = (*shape_in_axis, *features[:-1], rank)
        la = jax.random.normal(la_key, la_shape, dtype)
        key, lb_key = jax.random.split(key)
        # lora_b maps (feature prefix..., rank) to the final feature dimension.
        lb_shape = (*features[:-1], rank, features[-1])
        lb = jax.random.normal(lb_key, lb_shape, dtype)
        out_target = _apply_low_rank_adaptation(x, axis, features, la, lb, alpha)
        # LoRA convention: output scaled by alpha / rank (1.0 when alpha is unset).
        scale_ref = alpha / rank if alpha is not None else 1.0
        out_ref = TestLoRA.reference(x, la, lb, pattern, scale_ref)
        assert_allclose(out_target, out_ref, dtype=dtype)

    @pytest.mark.parametrize(
        "scope_ref_assert",
        [
            ("none", LoRAScope(False, False, False), False),
            ("all", LoRAScope(True, True, True), False),
            ("qkv_proj", LoRAScope(True, False, False), False),
            ("output_proj", LoRAScope(False, True, False), False),
            ("mlp", LoRAScope(False, False, True), False),
            ("exclude_qkv_proj", LoRAScope(False, True, True), False),
            ("exclude_output_proj", LoRAScope(True, False, True), False),
            ("exclude_mlp", LoRAScope(True, True, False), False),
            ("messing_up", LoRAScope(), True),
        ],
    )
    def test_lora_scope_generator(self, scope_ref_assert):
        """_canonicalize_lora_scope maps scope strings to LoRAScope flags (or asserts)."""
        scope, reference, need_assert = scope_ref_assert
        try:
            lora_scope = _canonicalize_lora_scope(scope)
            assert lora_scope == reference
        except AssertionError as ae:
            # Only invalid scopes (need_assert=True) are allowed to raise.
            assert need_assert, f"{ae.args}"
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for fused attention"""
from enum import Enum, auto
from dataclasses import dataclass, field
from functools import partial
from math import sqrt
from typing import Tuple, Optional, Dict
import random
import jax
import jax.numpy as jnp
import numpy as np
import pytest
from flax.linen import combine_masks
from flax.linen import make_attention_mask
from flax.linen.dtypes import promote_dtype
from jax import Array
from jax import value_and_grad, jit
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax.typing import ArrayLike, DTypeLike
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.attention import (
AttnBiasType,
AttnMaskType,
QKVLayout,
QKVFormat,
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
fused_attn,
make_swa_mask,
SequenceDescriptor,
CPStrategy,
ReorderStrategy,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine_jax import (
NVTE_Fused_Attn_Backend,
get_cudnn_version,
)
from distributed_test_base import assert_equal_collectives
from utils import assert_allclose, print_debug_tensor_stats
@pytest.fixture(autouse=True, scope="module")
def init():
    """
    WAR for CUDA uninitialize error: touch JAX once per module, before any
    test body runs, so JAX initializes CUDA first.
    """
    # Calling customcalls before jax may cause CUDA uninitialize error
    _ = jnp.zeros(0)
    yield
@partial(jax.jit, static_argnums=(5, 6, 7, 9))
def general_dot_product_attention(
    query: ArrayLike,
    key: ArrayLike,
    value: ArrayLike,
    bias: ArrayLike,
    mask: ArrayLike,
    deterministic: bool,
    scale_factor: float,
    dropout_rate: float,
    dropout_rng: ArrayLike,
    dtype: DTypeLike,
) -> Array:
    """
    Similar to flax.linen.dot_product_attention but with GQA support.

    Query heads are folded into (kv_head, group) pairs so each KV head serves
    its whole group of query heads.
    """
    query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
    dtype = query.dtype
    batch, q_len, q_heads, head_dim = query.shape
    kv_len, kv_heads = key.shape[1], key.shape[2]
    assert (q_heads % kv_heads == 0) and (q_heads >= kv_heads)
    group_size = q_heads // kv_heads
    grouped_query = jnp.reshape(query, (batch, q_len, kv_heads, group_size, head_dim))
    # logits with shape (batch, kv_heads, group_size, q_len, kv_len)
    logits = scale_factor * jnp.einsum("...qhgd,...khd->...hgqk", grouped_query, key)
    if bias is not None:
        # Bias comes without the group split: flatten heads, add, then restore.
        logits = logits.reshape((batch, kv_heads * group_size, q_len, kv_len))
        logits = logits + bias
        logits = logits.reshape((batch, kv_heads, group_size, q_len, kv_len))
    if mask is not None:
        if mask.ndim != logits.ndim:
            mask = jnp.expand_dims(mask, axis=-3)
        # True in the mask means "masked out".
        logits = jnp.where(mask, jnp.finfo(dtype).min, logits)
    attn_weights = jax.nn.softmax(logits).astype(dtype)
    if not deterministic and dropout_rate > 0.0:
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
        attn_weights = attn_weights * (keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype))
    context = jnp.einsum("...hgqk,...khd->...qhgd", attn_weights, value)
    return jnp.reshape(context, query.shape)
@jax.jit
def make_causal_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: ArrayLike = None,
    segment_pos_kv: ArrayLike = None,
) -> Array:
    """
    Create an inverse padded causal mask where `True` means the position may
    participate in attention and `False` means it is masked out.
    When segment positions are not provided, an arange over the sequence
    dimension of the corresponding segment ids is used instead.
    """
    def _default_positions(segment_ids):
        # Positions 0..seqlen-1, broadcast to the segment-id batch shape.
        return jnp.broadcast_to(
            jnp.arange(segment_ids.shape[-1], dtype=jnp.int32), segment_ids.shape
        )

    if segment_pos_q is None:
        segment_pos_q = _default_positions(segment_ids_q)
    if segment_pos_kv is None:
        segment_pos_kv = _default_positions(segment_ids_kv)
    # Allow attention only where query position >= key position.
    return make_attention_mask(segment_pos_q, segment_pos_kv, jnp.greater_equal)
@partial(jax.jit, static_argnums=(4, 5))
def make_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: ArrayLike,
    segment_pos_kv: ArrayLike,
    attn_mask_type: AttnMaskType,
    window_size: Optional[Tuple[int, int]] = None,
) -> Array:
    """
    Create attention mask based on mask type. A `True` value in the mask means
    masking out the corresponding position and a `False` value means allowing
    that position to participate in attention.
    - segment_ids should start with 1, and using 0s for the paddings.
      Expected that each segment starts without paddings.
    - segment_pos marks the token position in the segments.
    An example pair of segment_ids and segment_pos:
    segment_ids: [1, 1, 1, 0, 2, 2, 2, 3, 3, 3, 4, 0, 0, 5, 5, 5]
    segment_pos: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
    """
    # segment masks: attention only within the same non-padding segment.
    inv_mask = make_attention_mask(
        segment_ids_q, segment_ids_kv, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0))
    )
    # Default segment positions: arange over the sequence dimension.
    if segment_pos_q is None:
        segment_pos_q = jnp.broadcast_to(
            jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
        )
    if segment_pos_kv is None:
        segment_pos_kv = jnp.broadcast_to(
            jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
        )
    # causal mask: additionally require query position >= key position.
    if attn_mask_type.is_causal():
        inv_causal_mask = make_attention_mask(
            segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y)
        )
        inv_mask = combine_masks(inv_causal_mask, inv_mask)
    # sliding window mask (combine_masks ANDs the "allowed" masks together)
    inv_swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, jnp.bool_)
    inv_mask = combine_masks(inv_mask, inv_swa_mask)
    # Invert: the helpers above build "allowed" masks, the caller wants "masked out".
    mask = jnp.logical_not(inv_mask)
    return mask
@jax.jit
def get_seqlens_and_offsets(segment_ids):
    """
    Compute per-row segment lengths and segment start offsets from segment ids.

    Both outputs get one extra trailing slot; unused slots are filled with -1
    (empty length slots also become -1).
    """
    _, max_seqlen = segment_ids.shape
    # Count occurrences of each id per row, then drop the id-0 (padding) bucket.
    counts = jax.vmap(partial(jnp.bincount, length=max_seqlen))(segment_ids.astype(jnp.int32))
    lengths = counts[..., 1:]

    def _segment_starts(ids):
        # A position starts a segment when its id is non-zero and differs from
        # its left neighbor; the first column starts one whenever it is non-zero.
        changed = jnp.logical_and(ids[..., 1:] != ids[..., :-1], ids[..., 1:] != 0)
        starts = jnp.hstack((ids[..., :1] != 0, changed))
        locate = partial(jnp.argwhere, size=ids.shape[1], fill_value=-1)
        return jax.vmap(locate)(starts).squeeze(-1)

    offsets = _segment_starts(segment_ids)
    # Append a sentinel slot to both outputs and mark empty length slots with -1.
    offsets = jnp.insert(offsets, offsets.shape[-1], values=-1, axis=-1)
    lengths = jnp.insert(lengths, lengths.shape[-1], values=0, axis=-1)
    lengths = jnp.where(lengths, lengths, -1)
    return lengths, offsets
@jax.jit
def _split_valid_and_invalid(primitive, reference, pad):
"""Use JIT to speed up the verifications"""
primitive_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, primitive)
primitive_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], primitive, 0)
reference_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, reference)
reference_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], reference, 0)
return primitive_valid, primitive_invalid, reference_valid, reference_invalid
def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
    """
    JAX native dot product attention implementation (reference path).
    """
    deterministic = not kwargs["is_training"]
    scale = kwargs["scaling_factor"]
    dropout_p = kwargs["dropout_probability"]
    # Compute in float32 for accuracy, then cast back to the input dtype.
    out = general_dot_product_attention(
        query,
        key,
        value,
        bias,
        mask,
        deterministic=deterministic,
        scale_factor=scale,
        dropout_rate=dropout_p,
        dropout_rng=dropout_rng,
        dtype=jnp.float32,
    )
    return out.astype(query.dtype)
def customcall_fused_dpa(
    query,
    key,
    value,
    bias,
    sequence_descriptor,
    dropout_rng,
    **kwargs,
):
    """
    TE customcall dot product attention implementation.

    Packs q/k/v according to kwargs["qkv_layout"] before calling fused_attn.
    """
    qkv_layout = kwargs["qkv_layout"]
    if qkv_layout in (QKVLayout.BS3HD, QKVLayout.T3HD):
        # Fully packed: q, k and v are stacked along a new axis.
        query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
        qkv_args = (jnp.concatenate((query, key, value), axis=-3),)
    elif qkv_layout in (QKVLayout.BSHD_BS2HD, QKVLayout.THD_T2HD):
        # KV packed: q stays separate, k and v are stacked together.
        key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
        qkv_args = (query, jnp.concatenate((key, value), axis=-3))
    elif qkv_layout in (QKVLayout.BSHD_BSHD_BSHD, QKVLayout.THD_THD_THD):
        # Separate tensors, no packing.
        qkv_args = (query, key, value)
    else:
        raise ValueError(f"Unsupported {qkv_layout=}")
    return fused_attn(qkv_args, bias, sequence_descriptor, dropout_rng, **kwargs).astype(
        query.dtype
    )
class BiasShape(Enum):
    """
    Enum class to represent the different bias shapes used in the fused attention.
    """
    # Axis letters: B = batch, H = heads, S = sequence; "1" is a broadcast axis.
    _1HSS = "1HSS"
    _B1SS = "B1SS"
    _BHSS = "BHSS"
    _11SS = "11SS"
class SeqDescFormat(Enum):
    """Ways the tests describe sequence layout: a dense mask, explicit sequence
    lengths, or per-token segment ids."""
    Mask = auto()
    Seqlens = auto()
    SegmentIDs = auto()
@dataclass
class FusedAttnRunner:
    """
    Fused attention runner
    """
    # Core problem configuration.
    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads_q: int
    num_heads_kv: int
    head_dim: int
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    dropout_prob: float
    dtype: DTypeLike
    is_training: bool
    qkv_layout: QKVLayout
    bias_shape: BiasShape
    window_size: Tuple[int, int]
    seq_desc_format: SeqDescFormat
    # Specifies sharding resources for distributed tests
    number_of_devices: int = 1
    mesh_shape: tuple[int, ...] = (1, 1, 1)
    mesh_axes: tuple[str, ...] = ("dp", "cp", "tp")
    mesh_resource: MeshResource = field(default_factory=partial(MeshResource, "dp", "cp", "tp"))
    # Context parallel aux arguments
    cp_strategy: CPStrategy = CPStrategy.DEFAULT
    cp_load_balanced: bool = True
    # dictionary of expected collective comm bytes
    coll_count_ref: Optional[Dict[str, int]] = None
# See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue
# generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
def _get_max_segments_per_sequence(self):
if self.qkv_layout.is_thd():
if 90400 <= get_cudnn_version() < 90500:
return self.num_segments_per_seq
else:
# +1 for testing runtime_segments < max_segments
return self.num_segments_per_seq + 1
else:
return 1
    def _check_configs(self):
        """Skip unsupported parameter combinations and resolve the fused-attn backend."""
        # TODO(rewang): probably adds this in is_fused_attn_available
        if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
            pytest.skip("THD format requires padding masks.")
        # Fully packed layouts store q/k/v in one tensor, so their shapes must match.
        if self.qkv_layout.is_qkvpacked():
            if self.max_seqlen_q != self.max_seqlen_kv:
                pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv")
            if self.num_heads_q != self.num_heads_kv:
                pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv")
        if self.max_seqlen_q > self.max_seqlen_kv and self.window_size is not None:
            pytest.skip(
                "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
            )
        # Ask TE which fused-attention backend (if any) supports this configuration.
        self.backend = FusedAttnHelper(
            self.dtype,
            self.dtype,
            self.qkv_layout,
            self.attn_bias_type,
            self.attn_mask_type,
            self.dropout_prob,
            self.num_heads_q,
            self.num_heads_kv,
            self.max_seqlen_q,
            self.max_seqlen_kv,
            self.head_dim,
            (-1, -1) if self.window_size is None else self.window_size,
        ).get_fused_attn_backend()
        if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
            pytest.skip("Unsupported inputs combination or device compute capability.")
        # Non-1HSS post-scale bias shapes have extra backend/mask restrictions.
        if (
            self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
            and self.bias_shape != BiasShape._1HSS
        ):
            if self.attn_mask_type.is_padding():
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for non-padding mask"
                )
            elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for "
                    "the F16_arbitrary_seqlen backend."
                )
def _setup_inputs(self):
self._check_configs()
# Create a mesh for distributed tests
self.devices = np.asarray(jax.devices()[: self.number_of_devices]).reshape(*self.mesh_shape)
self.mesh = Mesh(self.devices, self.mesh_axes)
self.dp_size = self.mesh.shape.get(self.mesh_resource.dp_resource, 1)
self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1)
self.tp_size = self.mesh.shape.get(self.mesh_resource.tp_resource, 1)
key = jax.random.PRNGKey(0)
q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
k_shape = v_shape = (
self.batch_size,
self.max_seqlen_kv,
self.num_heads_kv,
self.head_dim,
)
if self.attn_bias_type == AttnBiasType.NO_BIAS:
bias_shape = None
elif self.bias_shape == BiasShape._1HSS:
bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
elif self.bias_shape == BiasShape._B1SS:
bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
elif self.bias_shape == BiasShape._BHSS:
bias_shape = (
self.batch_size,
self.num_heads_q,
self.max_seqlen_q,
self.max_seqlen_kv,
)
elif self.bias_shape == BiasShape._11SS:
bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
else:
pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")
self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0)
self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0)
self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0)
if self.attn_bias_type != AttnBiasType.NO_BIAS:
if self.bias_shape == BiasShape._1HSS:
self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0)
else:
# [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
# an arbitrary mask where (True/False -> 0/-Inf)
cudnn_neg_inf = -(2.0**27.0) if self.dtype == jnp.bfloat16 else -(2.0**15.0)
self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
seq_id_size = max_id * 5 // 128 # 5 ids per interval of 128 sequences
seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
for i in range(1, len(seq_id)):
self.bias = self.bias.at[
:, :, seq_id[i - 1] : seq_id[i], seq_id[i - 1] : seq_id[i]
].set(0.0)
else:
self.bias = None
if self.attn_mask_type.is_padding():
pad_ratio = 0.3
else:
pad_ratio = 0.0
def gen_valid(bs, max_seqlen, pad_ratio):
pad_len = int(max_seqlen * pad_ratio)
valid_len = max_seqlen - pad_len
tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
return tokens, jnp.logical_not(tokens)
def generate_random_segment_ids(
batch_size,
sequence_length,
num_segments,
seed,
with_segment_pad=True,
min_segment_len=None,
):
rng = np.random.default_rng(seed=seed)
# [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
segment_ids = np.zeros((batch_size, sequence_length), dtype=np.int32)
segment_pos = np.zeros((batch_size, sequence_length), dtype=np.int32)
# [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 0]
# [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad
segment_pad = np.zeros((batch_size, sequence_length), dtype=np.int32)
# Not include paddings
max_segment_size = sequence_length // num_segments
for i in range(batch_size):
current_pos = 0
segment_id = 1
for seg_id in range(num_segments):
# min_segment_len is to force kv_len >= q_len because cuDNN kernels failed
# TODO(rewang): Remove this constrain after cuDNN supports
min_segment_size = 1
if min_segment_len is not None:
min_segment_size = min_segment_len[i][seg_id]
segment_size = rng.integers(min_segment_size, max_segment_size + 1)
if current_pos + segment_size > sequence_length:
break
segment_end = current_pos + segment_size
segment_ids[i, current_pos:segment_end] = segment_id
segment_pos[i, current_pos:segment_end] = np.arange(segment_size)
if with_segment_pad:
num_valid = rng.integers(min_segment_size, segment_size + 1)
segment_pad[i, current_pos + num_valid : segment_end] = 1
current_pos = segment_end
segment_id += 1
segment_pad[i, current_pos:sequence_length] = 1
segment_ids, segment_pos, segment_pad = map(
jnp.asarray, [segment_ids, segment_pos, segment_pad]
)
segment_ids = jnp.where(segment_pad, 0, segment_ids)
return segment_ids, segment_pos, segment_pad
if self.qkv_layout.is_thd():
self.num_segments_per_seq = 2
self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
)
self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
# TODO(rewang): record only self attention and find the reason of cross attention
if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv:
self.segment_ids_kv = self.segment_ids_q
self.segment_pos_kv = self.segment_pos_q
self.pad_kv = self.pad_q
else:
# Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support
min_segment_len = None if self.window_size is None else self.seqlens_q
self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
self.batch_size,
self.max_seqlen_kv,
self.num_segments_per_seq,
seed=2024,
min_segment_len=min_segment_len,
)
self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
else:
self.num_segments_per_seq = 1
self.segment_ids_q, self.pad_q = gen_valid(
self.batch_size, self.max_seqlen_q, pad_ratio
)
self.segment_ids_kv, self.pad_kv = gen_valid(
self.batch_size, self.max_seqlen_kv, pad_ratio
)
self.segment_pos_q = self.segment_pos_kv = None
self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None
# For reference code
self.mask = make_mask(
self.segment_ids_q,
self.segment_ids_kv,
self.segment_pos_q,
self.segment_pos_kv,
self.attn_mask_type,
self.window_size,
)
if self.cp_size > 1 and self.cp_load_balanced:
if self.qkv_layout.is_thd():
reorder_strategy = ReorderStrategy.Striped
else:
reorder_strategy = ReorderStrategy.DualChunkSwap
seq_dim = 0 if self.qkv_layout.get_qkv_format() == QKVFormat.SBHD else 1
self.cp_reorder_fn = partial(
reorder_causal_load_balancing,
strategy=reorder_strategy,
cp_size=self.cp_size,
seq_dim=seq_dim,
)
self.cp_inverse_reorder_fn = partial(
inverse_reorder_causal_load_balancing,
strategy=reorder_strategy,
cp_size=self.cp_size,
seq_dim=seq_dim,
)
else:
# no-ops for non cp or non load balanced
self.cp_reorder_fn = lambda x: x
self.cp_inverse_reorder_fn = lambda x: x
# Test different input formats
if self.qkv_layout.is_thd():
match self.seq_desc_format:
case SeqDescFormat.Mask:
pytest.skip("THD doesn't support mask input")
case SeqDescFormat.Seqlens:
self.sequence_desciptor = SequenceDescriptor.from_seqlens_and_offsets(
(self.seqlens_q, self.seqlens_kv),
(self.offsets_q, self.offsets_kv),
)
case SeqDescFormat.SegmentIDs:
self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
(
self.cp_reorder_fn(self.segment_ids_q),
self.cp_reorder_fn(self.segment_ids_kv),
),
(
self.cp_reorder_fn(self.segment_pos_q),
self.cp_reorder_fn(self.segment_pos_kv),
),
)
case _:
raise ValueError(f"Unknown {self.seq_desc_format=}")
else:
match self.seq_desc_format:
case SeqDescFormat.Mask:
if self.attn_mask_type == AttnMaskType.NO_MASK:
self.sequence_desciptor = None
else:
self.sequence_desciptor = make_mask(
self.segment_ids_q,
self.segment_ids_kv,
self.segment_pos_q,
self.segment_pos_kv,
self.attn_mask_type,
)
case SeqDescFormat.Seqlens:
self.sequence_desciptor = SequenceDescriptor.from_seqlens(
(
self.segment_ids_q.sum(axis=-1).astype(jnp.int32),
self.segment_ids_kv.sum(axis=-1).astype(jnp.int32),
),
)
case SeqDescFormat.SegmentIDs:
self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
(self.segment_ids_q, self.segment_ids_kv),
None,
)
case _:
raise ValueError(f"Unknown {self.seq_desc_format=}")
self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
self.scaling_factor = 1.0 / sqrt(self.head_dim)
# Setup distributed sharding specs
# Setup shardings for distributed tests
self.qkvo_psec = PartitionSpec(
self.mesh_resource.dp_resource,
self.mesh_resource.cp_resource,
self.mesh_resource.tp_resource,
None,
)
self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_psec)
mask_pspec = PartitionSpec(
self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
)
self.mask_sharding = NamedSharding(self.mesh, mask_pspec)
match self.seq_desc_format:
case SeqDescFormat.Mask:
self.seq_desc_sharding = self.mask_sharding
case _:
def to_dp_shardings(x):
if x.ndim == 1:
pspec = PartitionSpec(self.mesh_resource.dp_resource)
else:
pspec = PartitionSpec(
self.mesh_resource.dp_resource, self.mesh_resource.cp_resource
)
return NamedSharding(self.mesh, pspec)
self.seq_desc_sharding = jax.tree.map(to_dp_shardings, self.sequence_desciptor)
if self.bias_shape == BiasShape._1HSS:
self.bias_pspec = PartitionSpec(
None, self.mesh_resource.tp_resource, self.mesh_resource.cp_resource, None
)
elif self.bias_shape == BiasShape._B1SS:
self.bias_pspec = PartitionSpec(
self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
)
elif self.bias_shape == BiasShape._11SS:
self.bias_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
else:
self.bias_pspec = PartitionSpec()
self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec)
self.dropout_rng_pspec = PartitionSpec(
None,
)
self.dropout_rng_sharding = NamedSharding(self.mesh, self.dropout_rng_pspec)
self.logit_scale_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
self.logit_scale_sharding = NamedSharding(self.mesh, self.logit_scale_pspec)
# [batch][max_segments_per_batch]
# TODO(mgoldfarb-nvidia): Will need to handle CP cases of replicated or distributed length/offset.
self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None)
self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec)
    def test_forward(self):
        """
        Test forward without JIT.

        Runs the fused-attention custom call on sharded (and, under context
        parallelism, load-balance-reordered) inputs and compares it
        element-wise against the pure-JAX reference implementation. When
        ``coll_count_ref`` is set, the compiled HLO is also checked for the
        expected communication collectives.
        """
        self._setup_inputs()
        # Reference path consumes the unsharded, unreordered inputs as-is.
        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
        customcall_args = [
            # Put test data onto each GPU for distributed.
            # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mas and
            # THD params once we support those features on CP.
            jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
            jax.device_put(self.bias, self.bias_sharding),
            jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
        ]
        # All of these are compile-time configuration, hence static_argnames below.
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
            "context_parallel_strategy": self.cp_strategy,
            "context_parallel_causal_load_balanced": self.cp_load_balanced,
        }
        customcall_fused_dpa_jit = jit(
            partial(customcall_fused_dpa, **kwargs),
            static_argnames=kwargs.keys(),
            in_shardings=[
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
                self.seq_desc_sharding,
                self.dropout_rng_sharding,
            ],
        )
        with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
            primitive_out = customcall_fused_dpa_jit(*customcall_args)
        # Undo the causal-load-balancing reorder (no-op when CP is off) so the
        # output lines up with the reference ordering.
        primitive_out = self.cp_inverse_reorder_fn(primitive_out)
        reference_out = jax_dpa(*args, **kwargs)
        # Dropout masks differ between the two implementations, so an
        # element-wise comparison is meaningless when dropout is active.
        if self.is_training and self.dropout_prob > 0.0:
            return
        primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
            _split_valid_and_invalid(primitive_out, reference_out, self.pad_q)
        )
        # Padded positions must produce exact zeros; valid positions must match.
        assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
        assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)
        if self.coll_count_ref is not None:
            # Compile (without executing) and inspect the HLO for expected comms.
            with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
                target_hlo = (
                    customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text()
                )
            assert_equal_collectives(target_hlo, self.coll_count_ref)
    def test_backward(self):
        """
        Test value_and_grad with JIT, which includes both forward and backward.
        If coll_count_ref is not None then the HLO of the backward function
        will be examined for the expected comms.
        """
        self._setup_inputs()

        def grad_func(func, *args, cp_reverse_out=False, **kwargs):
            # Gradient is small, use a gradient multiplier to amplify the gradient
            gradient_multiplier = self.max_seqlen_q * self.num_heads_q
            if self.attn_mask_type.is_causal():
                gradient_multiplier /= 10
            # Keep only valid result for the gradient
            if not cp_reverse_out:
                ret_valid = jnp.where(
                    self.pad_q[..., jnp.newaxis, jnp.newaxis],
                    0,
                    func(*args, **kwargs),
                )
            else:
                # Undo the CP load-balancing reorder before zeroing padded rows.
                ret_valid = jnp.where(
                    self.pad_q[..., jnp.newaxis, jnp.newaxis],
                    0,
                    self.cp_inverse_reorder_fn(func(*args, **kwargs)),
                )
            # Reduce to a scalar loss in fp32 to avoid overflow, then cast back.
            return (
                jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
            ).astype(self.dtype)

        # Reference path consumes unsharded, unreordered inputs.
        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
        customcall_args = [
            # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mas and
            # THD params once we support those features on CP.
            jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
            jax.device_put(self.bias, self.bias_sharding),
            jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
            "context_parallel_strategy": self.cp_strategy,
            "context_parallel_causal_load_balanced": self.cp_load_balanced,
        }
        # We can compute dBias only for the [1, h, s, s] layout
        if self.bias_shape == BiasShape._1HSS:
            arg_nums = (0, 1, 2, 3)
            grad_shardings = (
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
            )
        else:
            arg_nums = (0, 1, 2)
            grad_shardings = (self.qkvo_sharding, self.qkvo_sharding, self.qkvo_sharding)
        # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
        jitted_primitive = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(
                    customcall_fused_dpa, q, k, v, bias, *args, cp_reverse_out=True, **kwargs
                ),
                arg_nums,
            ),
            in_shardings=(
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
                self.seq_desc_sharding,
                self.dropout_rng_sharding,
            ),
            out_shardings=(None, grad_shardings),
        )
        jitted_reference = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
                arg_nums,
            )
        )
        with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
            primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)
        reference_out, reference_dgrad = jitted_reference(*args)
        # Skip elementwise comparison when dropout enabled
        if self.dropout_prob > 0.0:
            return
        print_debug_tensor_stats(f"primitive_out", primitive_out)
        print_debug_tensor_stats(f"reference_grad_valid", reference_out)
        print_debug_tensor_stats(f"diff_grad", jnp.abs(primitive_out - reference_out))
        assert_allclose(primitive_out, reference_out, dtype=self.dtype)

        def check_dqkv(primitive, reference, pad, idx):
            # Padded positions must be exact zeros and agree with the
            # reference; valid positions must match within dtype tolerance.
            primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
                _split_valid_and_invalid(primitive, reference, pad)
            )
            print_debug_tensor_stats(f"primitive_grad_valid[{idx}]", primitive_valid[idx])
            print_debug_tensor_stats(f"reference_grad_valid[{idx}]", reference_valid[idx])
            print_debug_tensor_stats(
                f"diff_grad[{idx}]", jnp.abs(primitive_valid[idx] - reference_valid[idx])
            )
            assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
            assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
            assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        primitive_dq, primitive_dk, primitive_dv = primitive_dgrad[:3]
        reference_dq, reference_dk, reference_dv = reference_dgrad[:3]
        # Map the primitive grads back to the reference ordering (no-op without CP).
        primitive_dq = self.cp_inverse_reorder_fn(primitive_dq)
        primitive_dk = self.cp_inverse_reorder_fn(primitive_dk)
        primitive_dv = self.cp_inverse_reorder_fn(primitive_dv)
        check_dqkv(primitive_dq, reference_dq, self.pad_q, 0)
        check_dqkv(primitive_dk, reference_dk, self.pad_kv, 1)
        check_dqkv(primitive_dv, reference_dv, self.pad_kv, 2)
        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS:
            # TODO(mgoldfarb-nvidia): Inverse reorder bias once supported by a CP implementation.
            primitive_dbias = primitive_dgrad[3]
            reference_dbias = reference_dgrad[3]
            # Assume all batch has the same actual_seqlen, probably needs to extend the tests
            bias_mask = self.mask[0, 0]
            # Assert all masked dbias are 0s
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.zeros_like(primitive_dbias),
                dtype=self.dtype,
            )
            # dbias padded part
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.where(bias_mask, reference_dbias, 0),
                dtype=self.dtype,
            )
            # dbias valid part
            assert_allclose(
                jnp.where(bias_mask, 0, primitive_dbias),
                jnp.where(bias_mask, 0, reference_dbias),
                dtype=self.dtype,
            )
        if self.coll_count_ref is not None:
            # Compile the fwd+bwd graph and verify the expected collectives appear.
            with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
                target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text()
            assert_equal_collectives(target_hlo, self.coll_count_ref)
# Parametrization shared by every test in TestFusedAttn; per-test params
# (bias type/shape, is_training) are applied on the methods themselves.
@pytest.mark.parametrize(
    "attn_mask_type",
    [
        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
        pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
        pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
    ],
)
@pytest.mark.parametrize(
    "qkv_layout",
    [
        pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
        pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
        pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
        pytest.param(QKVLayout.T3HD, id="RAGGED_QKV_PACKED"),
        pytest.param(QKVLayout.THD_T2HD, id="RAGGED_KV_PACKED"),
        pytest.param(QKVLayout.THD_THD_THD, id="RAGGED_SEPARATE"),
    ],
)
# (batch, q seqlen, kv seqlen, q heads, kv heads, head dim, dtype)
@pytest.mark.parametrize(
    "b, s_q, s_kv, h_q, h_kv, d, dtype",
    [
        pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
        pytest.param(
            2,
            2048,
            1024,
            12,
            12,
            64,
            jnp.bfloat16,
            id="2-2048-1024-12-12-64-BF16-CROSS",
        ),
        pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
        pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
    ],
)
@pytest.mark.parametrize(
    "dropout_prob",
    [
        pytest.param(0.0, id="DROP_0.0"),
        pytest.param(0.1, id="DROP_0.1"),
    ],
)
# swa: sliding-window attention on/off
@pytest.mark.parametrize(
    "swa",
    [
        pytest.param(False, id="NO_SWA"),
        pytest.param(True, id="SWA"),
    ],
)
@pytest.mark.parametrize(
    "seq_desc_format",
    [
        pytest.param(SeqDescFormat.Mask, id="Mask"),
        pytest.param(SeqDescFormat.Seqlens, id="Seqlens"),
        pytest.param(SeqDescFormat.SegmentIDs, id="SegmentIDs"),
    ],
)
class TestFusedAttn:
    """
    Fused attention tester
    """

    # NOTE: leading underscore keeps pytest from collecting this method;
    # rename to test_forward to run it manually (see docstring below).
    @staticmethod
    @pytest.mark.parametrize(
        "is_training",
        [
            pytest.param(True, id="TRAINING"),
            pytest.param(False, id="INFERENCE"),
        ],
    )
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._B1SS, id="POST_SCALE_BIAS-B1SS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._BHSS, id="POST_SCALE_BIAS-BHSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._11SS, id="POST_SCALE_BIAS-11SS"),
        ],
    )
    def _test_forward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        is_training,
        qkv_layout,
        bias_shape,
        swa,
        seq_desc_format,
    ):
        """
        Test forward with parameterized configs
        This test is not intended to run automatically during CI as it is time-consuming
        It is kept for development and debugging
        """
        window_size = None
        if swa:
            # Causal-style window: look back s_kv // 10 tokens, none ahead.
            window_size = (s_kv // 10, 0)
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            is_training,
            qkv_layout,
            bias_shape,
            window_size,
            seq_desc_format,
        )
        runner.test_forward()

    @staticmethod
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
        ],
    )
    def test_backward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        qkv_layout,
        bias_shape,
        swa,
        seq_desc_format,
    ):
        """
        Test backward with parameterized configs
        """
        window_size = None
        if swa:
            # Causal-style window: look back s_kv // 10 tokens, none ahead.
            window_size = (s_kv // 10, 0)
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            True,  # backward requires is_training
            qkv_layout,
            bias_shape,
            window_size,
            seq_desc_format,
        )
        runner.test_backward()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import unittest
import flax
import jax
import jax.numpy as jnp
import numpy as np
from utils import assert_allclose
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast, get_delayed_scaling
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available, AmaxComputeAlgo
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
# Probe FP8 availability once at import time; `reason` carries the skip
# message used by the unittest.skipIf decorators below.
is_fp8_supported, reason = is_fp8_available()
class TestFP8Helper(unittest.TestCase):
    """Unit tests for the FP8Helper static configuration interface."""

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_initialize(self):
        """initialize() must publish margin/format/history length as class attrs."""
        margin = 5.0
        fp8_format = FP8Format.E4M3
        amax_history_len = 10
        FP8Helper.initialize(
            margin=margin, fp8_format=fp8_format, amax_history_len=amax_history_len
        )
        # Check every published attribute against the value we passed in.
        expectations = (
            ("MARGIN", margin),
            ("FP8_FORMAT", fp8_format),
            ("AMAX_HISTORY_LEN", amax_history_len),
        )
        for attr, expected in expectations:
            actual = getattr(FP8Helper, attr)
            self.assertEqual(
                actual,
                expected,
                f"FP8Helper.{attr} initialization failed, should be {expected}"
                f" but got {actual}.",
            )
        FP8Helper.finalize()

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_update_collections(self):
        """update_collections() must replace listed keys and keep the rest,
        for both plain dicts and flax FrozenDicts."""
        original_val = 0.0
        updated_val = 10.0
        base_state = {
            "test1": original_val,
            "test2": original_val,
        }
        for state in (base_state, flax.core.frozen_dict.FrozenDict(base_state)):
            updated_state = FP8Helper.update_collections({"test1": updated_val}, state)
            self.assertEqual(updated_state["test1"], updated_val)
            self.assertEqual(updated_state["test2"], original_val)
class TestFP8Functions(unittest.TestCase):
    """Tests for fp8_autocast() context-manager behavior."""

    # Fix: helper was previously misspelled `_check_defult_state`; renamed
    # here (all call sites are inside this class).
    def _check_default_state(self):
        """FP8 must be disabled whenever we are outside an fp8_autocast() context."""
        self.assertFalse(FP8Helper.is_fp8_enabled())

    def _compare_delay_scaling(self, ref, test):
        """Assert that two DelayedScaling recipes carry identical settings."""
        self.assertTrue(ref.margin == test.margin)
        self.assertTrue(ref.fp8_format == test.fp8_format)
        self.assertTrue(ref.amax_history_len == test.amax_history_len)
        self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_fp8_autocast(self):
        """fp8_autocast() enables FP8 with the given recipe only inside the
        context and restores the disabled state afterwards."""
        FP8Helper.finalize()  # Ensure this test is not affected by previous tests.
        self._check_default_state()

        # Disabled context: FP8 stays off, but the recipe is still visible.
        with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
            self.assertFalse(FP8Helper.is_fp8_enabled())
            self._compare_delay_scaling(get_delayed_scaling(), DelayedScaling())
        self._check_default_state()

        ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            self.assertTrue(FP8Helper.is_fp8_enabled())
            self._compare_delay_scaling(get_delayed_scaling(), ds)
        self._check_default_state()

        ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            self.assertTrue(FP8Helper.is_fp8_enabled())
            self._compare_delay_scaling(get_delayed_scaling(), ds)
        self._check_default_state()

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_fp8_autocast_with_sharding_resource(self):
        """fp8_autocast() must install the provided MeshResource globally for
        every dp/tp resource combination."""
        FP8Helper.finalize()  # Ensure this test is not affected by previous tests.
        self._check_default_state()
        ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
        mesh_s = (
            (MeshResource(None, None)),
            (MeshResource("dp", None)),
            (MeshResource(None, "tp")),
            (MeshResource("dp", "tp")),
        )
        # TODO (Ming Huang): Support multi-GPUs testing. # pylint: disable=fixme
        mesh_shape = (1, 1)
        devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape)
        with jax.sharding.Mesh(devices, ("dp", "tp")):
            for sr in mesh_s:
                with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=sr):
                    self.assertTrue(FP8Helper.is_fp8_enabled())
                    self._compare_delay_scaling(get_delayed_scaling(), ds)
                    self.assertEqual(sr, global_mesh_resource())
        self._check_default_state()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test transformer_engine.jax.flax.TransformerLayer"""
import os
from functools import partial
from typing import Dict, Tuple, Optional
import flax
import jax
import jax.numpy as jnp
import pytest
from utils import (
assert_allclose,
assert_tree_like_allclose,
dtype_tols,
sync_params_values,
)
from utils import DecoderLayer as RefDecoderLayer
from utils import EncoderLayer as RefEncoderLayer
from transformer_engine.common.recipe import Format
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
# Probe FP8 availability once at import time; `reason` is the skip message
# for FP8-only parametrizations in this module.
is_fp8_supported, reason = is_fp8_available()
@pytest.fixture(autouse=True, scope="function")
def enable_fused_attn():
    """Enable fused attention"""
    # Turn the fused-attention code path on for the duration of each test,
    # then drop the flag again during teardown.
    os.environ["NVTE_FUSED_ATTN"] = "1"
    yield
    os.environ.pop("NVTE_FUSED_ATTN")
DATA_SHAPE = [ # (batch, seqlen, emb_dim)
    pytest.param((32, 128, 1024), id="32-128-1024"),
    pytest.param((32, 512, 1024), id="32-512-1024"),
]
DTYPE = [jnp.float32, jnp.bfloat16]
FP8_FORMATS = [Format.E4M3, Format.HYBRID]
# Keyword names for TransformerLayer / reference-layer constructor arguments,
# used as keys in BASE_ATTRS / ATTRS below.
_KEY_OF_RESIDUAL_POST_LAYERNORM = "apply_residual_connection_post_layernorm"
_KEY_OF_OUTPUT_LAYERNORM = "output_layernorm"
_KEY_OF_DROP_PATH = "drop_path"
_KEY_OF_FUSE_QKV_PARAMS = "fuse_qkv_params"
_KEY_OF_HIDDEN_DROPOUT = "hidden_dropout"
_KEY_OF_ATTENTION_DROPOUT = "attention_dropout"
_KEY_OF_INTERMEDIATE_DROPOUT = "intermediate_dropout"
_KEY_OF_HIDDEN_DROPOUT_DIMS = "hidden_dropout_dims"
_KEY_OF_INTERMEDIATE_DROPOUT_DIMS = "intermediate_dropout_dims"
_KEY_OF_MLP_ACTIVATIONS = "mlp_activations"
_KEY_OF_LAYERNORM_TYPE = "layernorm_type"
_KEY_OF_LAYERNORM_EPS = "layernorm_epsilon"
_KEY_OF_ZERO_CENTERED_GAMMA = "zero_centered_gamma"
_KEY_OF_TRANSPOSE_BS = "transpose_batch_sequence"
_KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits"
_KEY_OF_NUM_HEADS = "num_attention_heads"
_KEY_OF_NUM_GQA_GROUPS = "num_gqa_groups"
_KEY_OF_ENABLE_ROPE = "enable_rotary_pos_emb"
_KEY_OF_ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
_KEY_OF_SELF_ATTN_BIAS_TYPE = "self_attn_bias_type"
_KEY_OF_SELF_ATTN_MASK_TYPE = "self_attn_mask_type"
_KEY_OF_FLOAT32_ATTENTION_LOGITS = "float32_attention_logits"
_KEY_OF_USE_BIAS = "use_bias"
_KEY_OF_RELATIVE_EMBEDDING = "enable_relative_embedding"
_KEY_OF_WINDOW_SIZE = "window_size"
# Settings shared by every test configuration; entries in ATTRS override these.
BASE_ATTRS = {
    _KEY_OF_TRANSPOSE_BS: True,
    _KEY_OF_NUM_HEADS: 8,
    _KEY_OF_HIDDEN_DROPOUT: 0,
    _KEY_OF_ATTENTION_DROPOUT: 0.0,
    _KEY_OF_INTERMEDIATE_DROPOUT: 0,
    _KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal",
    _KEY_OF_LAYERNORM_TYPE: "layernorm",
    _KEY_OF_WINDOW_SIZE: (-1, -1),
}
# Each dict below is one layer configuration to test; it is merged on top of
# BASE_ATTRS at the bottom of this block.
ATTRS = [
    {},
    {
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
    },
    {
        _KEY_OF_ZERO_CENTERED_GAMMA: True,
        _KEY_OF_LAYERNORM_EPS: 1e-2,
    },
    {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_RESIDUAL_POST_LAYERNORM: True},
    {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_OUTPUT_LAYERNORM: True},
    {
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_RESIDUAL_POST_LAYERNORM: True,
        _KEY_OF_OUTPUT_LAYERNORM: True,
    },
    {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_DROP_PATH: 0.1},
    {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_FUSE_QKV_PARAMS: False},
    {
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
    },
    {
        _KEY_OF_SCALE_ATTN_LOGITS: True,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_HIDDEN_DROPOUT: 0.8,
        _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
        _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
        _KEY_OF_USE_BIAS: True,
    },
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_SCALE_ATTN_LOGITS: True,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
    },
    {
        _KEY_OF_NUM_HEADS: 8,
        _KEY_OF_NUM_GQA_GROUPS: 4,
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_SCALE_ATTN_LOGITS: True,
        _KEY_OF_MLP_ACTIVATIONS: ("gelu",),
        _KEY_OF_USE_BIAS: True,
    },
    {
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_MLP_ACTIVATIONS: (("silu", "linear")),
    },
    {
        _KEY_OF_SCALE_ATTN_LOGITS: True,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_HIDDEN_DROPOUT: 0.8,
        _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
        _KEY_OF_MLP_ACTIVATIONS: (("silu", "linear")),
        _KEY_OF_USE_BIAS: True,
    },
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_SCALE_ATTN_LOGITS: True,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_MLP_ACTIVATIONS: (("silu", "linear")),
    },
    {
        _KEY_OF_NUM_HEADS: 8,
        _KEY_OF_NUM_GQA_GROUPS: 4,
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_SCALE_ATTN_LOGITS: True,
        _KEY_OF_LAYERNORM_TYPE: "layernorm",
        _KEY_OF_MLP_ACTIVATIONS: (("silu",)),
        _KEY_OF_USE_BIAS: True,
    },
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_NUM_GQA_GROUPS: 1,
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "consecutive",
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    {
        _KEY_OF_TRANSPOSE_BS: True,
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "consecutive",
        _KEY_OF_USE_BIAS: True,
    },
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_LAYERNORM_TYPE: "layernorm",
        _KEY_OF_NUM_GQA_GROUPS: 2,
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "alternate",
        _KEY_OF_USE_BIAS: True,
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    {
        _KEY_OF_TRANSPOSE_BS: True,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "alternate",
        _KEY_OF_USE_BIAS: True,
    },
    {
        _KEY_OF_HIDDEN_DROPOUT: 0.3,
        _KEY_OF_HIDDEN_DROPOUT_DIMS: (0,),
        _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
        _KEY_OF_INTERMEDIATE_DROPOUT_DIMS: (1,),
    },
    {
        _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
        _KEY_OF_USE_BIAS: True,
    },
    {
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias",
    },
    {
        _KEY_OF_ATTENTION_DROPOUT: 0.3,
    },
    {
        _KEY_OF_MLP_ACTIVATIONS: (("relu", "relu")),
    },
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
        _KEY_OF_WINDOW_SIZE: (64, 0),  # Left size must < DATA_SHAPE seqlen
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
        _KEY_OF_WINDOW_SIZE: (2, 2),
    },
]
# Merge each configuration over the shared defaults.
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
class BaseRunner:
    """Base runner to define forward and backward tests"""

    # Subclasses set these: which TransformerLayer mode to build, the matching
    # reference flax module, and the param-path renames used by
    # sync_params_values() to copy weights between the two implementations.
    layer_type: TransformerLayerType = None
    reference_layer: flax.linen.Module = None
    transformations: Dict[str, str] = None

    def __init__(self, attrs):
        self.attrs = attrs
        self._generate_test_rngs()
        # Disable fused attention when attention dropout is enabled, because the
        # fused and unfused dropout implementations differ.
        if attrs.get(_KEY_OF_ATTENTION_DROPOUT, False) and os.getenv("NVTE_FUSED_ATTN"):
            os.environ["NVTE_FUSED_ATTN"] = "0"

    def _generate_test_rngs(self):
        """Derive deterministic PRNG key dicts for layer init and apply."""
        root_rng = jax.random.PRNGKey(0)
        params_rng, init_dropout_rng, apply_dropout_rng = jax.random.split(root_rng, 3)
        self.init_rng = {"params": params_rng, "dropout": init_dropout_rng}
        self.apply_rng = {"dropout": apply_dropout_rng}

    def _generate_layer(self, layer_cls, diff_inputs, no_diff_inputs):
        """Instantiate layer_cls and split its variables into (layer, params, others)."""
        layer = layer_cls()
        variables = layer.init(self.init_rng, *diff_inputs, *no_diff_inputs)
        others, params = flax.core.pop(variables, "params")
        del variables
        return layer, params, others

    def _loss_fn(self, diff_xs, no_diff_xs, params, others, model):
        """Scalar loss: mean of the model output, reduced in fp32 then cast back."""
        variables = {"params": params, **others}
        output = model.apply(variables, *diff_xs, *no_diff_xs, rngs=self.apply_rng)
        return jnp.mean(output, dtype=jnp.float32).astype(output.dtype)

    def _sync_params(self, ref, target):
        """Copy the reference params to target"""
        target = sync_params_values(target, ref, self.transformations)
        return ref, target

    def test_forward(
        self,
        data_shape: Tuple[int],
        dtype: jnp.dtype,
        rtol: Optional[float] = None,
        atol: Optional[float] = None,
    ) -> None:
        """Test only the forward"""
        inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)
        ref_layer_cls = partial(self.reference_layer, **self.attrs)
        layer_cls = partial(TransformerLayer, layer_type=self.layer_type, **self.attrs)
        ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks)
        test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)
        # Both layers must run with identical weights for a fair comparison.
        ref_params, test_params = self._sync_params(ref_params, test_params)
        ref_out = self._loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer)
        test_out = self._loss_fn(inputs, test_masks, test_params, test_others, test_layer)
        tols = dtype_tols(dtype, rtol=rtol, atol=atol)
        assert_allclose(ref_out, test_out, **tols)

    def test_backward(
        self,
        data_shape: Tuple[int],
        dtype: jnp.dtype,
        rtol: Optional[float] = None,
        atol: Optional[float] = None,
    ) -> None:
        """Test forward and backward through value_and_grad()"""
        inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)
        ref_layer_cls = partial(self.reference_layer, **self.attrs)
        layer_cls = partial(TransformerLayer, layer_type=self.layer_type, **self.attrs)
        ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks)
        test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)
        ref_params, test_params = self._sync_params(ref_params, test_params)
        if FP8Helper.is_fp8_enabled():
            # Prime the FP8 amax/scale metadata: run a few grad steps and fold
            # the FP8-collection gradients back into the layer state.
            for _ in range(4):
                _, tmp_grad = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)(
                    inputs,
                    test_masks,
                    test_params,
                    test_others,
                    test_layer,
                )
                _, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME)
                test_others = FP8Helper.update_collections(
                    {FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others
                )
                del tmp_grad, fp8_meta_grad
        # Differentiate w.r.t. the inputs (argnum 0) and the params (argnum 2).
        grad_fn = jax.value_and_grad(self._loss_fn, argnums=(0, 2), has_aux=False)
        ref_out, (ref_dgrads, ref_wgrads) = grad_fn(
            inputs, ref_masks, ref_params, ref_others, ref_layer
        )
        test_out, (test_dgrads, test_wgrads) = grad_fn(
            inputs, test_masks, test_params, test_others, test_layer
        )
        tols = dtype_tols(dtype, rtol=rtol, atol=atol)
        assert_allclose(ref_out, test_out, **tols)
        assert_tree_like_allclose(ref_dgrads, test_dgrads, **tols)
        # Map the reference wgrads into the test layer's tree structure before comparing.
        _, restructed_ref_wgrads = self._sync_params(ref_wgrads, test_wgrads)
        assert_tree_like_allclose(restructed_ref_wgrads, test_wgrads, **tols)
class EncoderRunner(BaseRunner):
    """Encoder runner implementations"""

    layer_type = TransformerLayerType.ENCODER
    reference_layer = RefEncoderLayer
    # Maps TE TransformerLayer parameter paths to the corresponding
    # reference-layer parameter paths; presumably consumed by the
    # param-sync helpers in BaseRunner -- TODO confirm.
    transformations = {
        "attention/qkv/scale": "pre_attention_layer_norm/scale",
        "attention/qkv/ln_bias": "pre_attention_layer_norm/ln_bias",
        "attention/query/scale": "pre_attention_layer_norm/scale",
        "attention/query/ln_bias": "pre_attention_layer_norm/ln_bias",
        "mlp/wi_kernel": "mlp/wi/kernel",
        "mlp/wi_bias": "mlp/wi/bias",
        "mlp/wo_kernel": "mlp/wo/kernel",
        "mlp/wo_bias": "mlp/wo/bias",
        "mlp/scale": "pre_mlp_layer_norm/scale",
        "mlp/ln_bias": "pre_mlp_layer_norm/ln_bias",
    }

    def generate_inputs(self, data_shape, dtype):
        """
        Return inputs, (ref_masks, test_masks)
        """
        transpose_batch_sequence = self.attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            # Swap (B, S, ...) -> (S, B, ...).
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])
        data_rng = jax.random.PRNGKey(2024)
        inputs = (jax.random.normal(data_rng, data_shape, dtype),)
        # padded_mask: all zeros (nothing padded); causal_mask: ones strictly
        # above the diagonal (future positions).
        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
        if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
            mask = causal_mask
        else:
            mask = padded_mask
        # The reference layer receives the inverted mask (1 - mask).
        ref_masks = (1 - mask,)
        test_masks = (None, mask)  # The second arg of Transformer is encoded tokens.
        return inputs, (ref_masks, test_masks)
class DecoderRunner(BaseRunner):
    """
    Decoder runner implementations
    """

    layer_type = TransformerLayerType.DECODER
    reference_layer = RefDecoderLayer
    # Maps TE TransformerLayer parameter paths to the corresponding
    # reference-layer parameter paths; presumably consumed by the
    # param-sync helpers in BaseRunner -- TODO confirm.
    transformations = {
        "encoder_decoder_attention/qkv/scale": "pre_cross_attention_layer_norm/scale",
        "encoder_decoder_attention/qkv/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
        "encoder_decoder_attention/query/scale": "pre_cross_attention_layer_norm/scale",
        "encoder_decoder_attention/query/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
        "self_attention/qkv/scale": "pre_self_attention_layer_norm/scale",
        "self_attention/qkv/ln_bias": "pre_self_attention_layer_norm/ln_bias",
        "self_attention/query/scale": "pre_self_attention_layer_norm/scale",
        "self_attention/query/ln_bias": "pre_self_attention_layer_norm/ln_bias",
        "mlp/wi_kernel": "mlp/wi/kernel",
        "mlp/wi_bias": "mlp/wi/bias",
        "mlp/wo_kernel": "mlp/wo/kernel",
        "mlp/wo_bias": "mlp/wo/bias",
        "mlp/scale": "pre_mlp_layer_norm/scale",
        "mlp/ln_bias": "pre_mlp_layer_norm/ln_bias",
    }

    def generate_inputs(self, data_shape, dtype):
        """
        Return inputs, (ref_masks, test_masks)
        """
        transpose_batch_sequence = self.attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            # Swap (B, S, ...) -> (S, B, ...).
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])
        data_rng = jax.random.PRNGKey(0)
        data_rng_0, data_rng_1 = jax.random.split(data_rng, 2)
        # Two tensors: decoder input and the (encoder-side) second input.
        inputs = (
            jax.random.normal(data_rng_0, data_shape, dtype),
            jax.random.normal(data_rng_1, data_shape, dtype),
        )
        # padded_mask: all zeros (nothing padded); causal_mask: ones strictly
        # above the diagonal (future positions).
        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
        if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
            self_mask = causal_mask
        else:
            self_mask = padded_mask
        # Reference masks use the inverted convention (1 - mask); the
        # cross-attention mask is always the padded (all-attend) mask.
        ref_masks = (1 - self_mask, 1 - padded_mask)
        test_masks = (self_mask, padded_mask)
        return inputs, (ref_masks, test_masks)
@pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("attrs", ATTRS)
class BaseTester:
    """
    Pytest interface to invoke the runner
    """

    # Subclasses override this with the concrete runner class.
    runner = BaseRunner

    def test_forward(self, data_shape, dtype, attrs):
        """Test normal datatype forward"""
        FP8Helper.finalize()  # Ensure FP8 disabled.
        self.runner(attrs).test_forward(data_shape, dtype)

    def test_backward(self, data_shape, dtype, attrs):
        """Test normal datatype backward"""
        FP8Helper.finalize()  # Ensure FP8 disabled.
        self.runner(attrs).test_backward(data_shape, dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
    def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_format):
        """Test forward with fp8 enabled"""
        FP8Helper.initialize(fp8_format=fp8_format)
        # Looser tolerances than the non-FP8 path, since FP8 quantizes.
        self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
        FP8Helper.finalize()

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
    def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_format):
        """Test backward with fp8 enabled"""
        FP8Helper.initialize(fp8_format=fp8_format)
        # Looser tolerances than the non-FP8 path, since FP8 quantizes.
        self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
        FP8Helper.finalize()
class TestEncoderLayer(BaseTester):
    """
    Test transformer_engine.jax.flax.TransformerLayer(layer_type=Encoder)
    """

    # Drive the shared BaseTester tests with the encoder-specific runner.
    runner = EncoderRunner
class TestDecoderLayer(BaseTester):
    """
    Test transformer_engine.jax.flax.TransformerLayer(layer_type=Decoder)
    """

    # Drive the shared BaseTester tests with the decoder-specific runner.
    runner = DecoderRunner
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
from functools import partial
import os
from transformer_engine.jax.cpp_extensions.misc import get_xla_flag
@pytest.fixture(autouse=True, scope="function")
def preserve_xla_flags():
    """Ensures the XLA flags environment variable is restored after any tests in this file run.

    Fix: the previous version only restored XLA_FLAGS when it had been set
    before the test.  If it was originally unset, whatever value a test
    exported leaked into every subsequent test.  Now the variable is removed
    again in that case.
    """
    old_flags = os.getenv("XLA_FLAGS")
    yield
    if old_flags is not None:
        os.environ["XLA_FLAGS"] = old_flags
    else:
        # XLA_FLAGS was not set before the test; make sure it is unset after.
        os.environ.pop("XLA_FLAGS", None)
def test_get_xla_flag(request):
    """Exercise get_xla_flag() parsing of the XLA_FLAGS environment variable."""
    # Empty flag string: nothing can be found.
    os.environ["XLA_FLAGS"] = ""
    assert get_xla_flag("") is None
    assert get_xla_flag("--foo") is None
    assert get_xla_flag("--bar=1") is None

    os.environ["XLA_FLAGS"] = "--foo --bar=1 --baz=biz"
    assert get_xla_flag("--foo") == True
    assert get_xla_flag("--bar") == "1"
    assert get_xla_flag("--bar", cast=int) == 1
    assert get_xla_flag("--bar", cast=bool) == True
    assert get_xla_flag("--baz") == "biz"
    with pytest.raises(ValueError):
        # Casting "biz" to int must fail.  (The previous `assert` around this
        # call was dead code: the call raises before the assert evaluates.)
        get_xla_flag("--baz", cast=int)
    # Prefix of an existing flag must not match.
    assert get_xla_flag("--xla") is None

    os.environ["XLA_FLAGS"] = "--xla_abc --xla_abb"
    assert get_xla_flag("--xla_abc") == True
    assert get_xla_flag("--xla_abb") == True
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
from functools import partial
from typing import Dict, Tuple
import flax
import jax
import jax.numpy as jnp
from praxis import pax_fiddle
from praxis.base_layer import WeightInit, DEFAULT_INIT_MUTABLE_LIST
import pytest
from utils import assert_allclose
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast, update_collections
from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral
from transformer_engine.jax.flax import LayerNorm as flax_LayerNorm
from transformer_engine.jax.flax import LayerNormMLP as flax_LayerNormMLP
from transformer_engine.jax.flax import MultiHeadAttention as flax_MultiHeadAttention
from transformer_engine.jax.flax import DotProductAttention as flax_DotProductAttention
from transformer_engine.jax.flax import RelativePositionBiases as flax_RelativePositionBiases
from transformer_engine.jax.flax import TransformerLayer as flax_TransformerLayer
from transformer_engine.jax.flax.module import Softmax
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
from transformer_engine.jax.praxis import LayerNorm
from transformer_engine.jax.praxis import FusedSoftmax
from transformer_engine.jax.praxis import LayerNormLinear, LayerNormMLP, Linear
from transformer_engine.jax.praxis import DotProductAttention, MultiHeadAttention
from transformer_engine.jax.praxis import RelativePositionBiases, TransformerEngineBaseLayer
from transformer_engine.jax.praxis import TransformerLayer, TransformerLayerType
from transformer_engine.jax.softmax import SoftmaxType
# Whether this GPU supports FP8, plus the human-readable skip reason if not.
is_fp8_supported, reason = is_fp8_available()

DATA_SHAPE = [(32, 128, 512), (32, 512, 512)]  # (B, S, H)
DTYPE = [jnp.float32, jnp.bfloat16]
ENABLE_FP8 = [False, True]
# FP8 recipes to parametrize over: pure E4M3 vs. hybrid (E4M3 fwd / E5M2 bwd).
FP8_FORMATS = [Format.E4M3, Format.HYBRID]
def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
    """Recursively compare two (possibly nested) dicts of values.

    Every key of ``ref_fd`` must exist in ``test_fd`` with a value of the
    same type; nested dicts are compared recursively, leaves with
    ``assert_allclose``.  Extra keys in ``test_fd`` are ignored.

    Args:
        ref_fd: reference (expected) dict.
        test_fd: dict under test.
        rtol: relative tolerance forwarded to ``assert_allclose``.
        atol: absolute tolerance forwarded to ``assert_allclose``.

    Raises:
        AssertionError: on a missing key, a type mismatch, or values that
            are not close within the given tolerances.
    """
    for key in ref_fd:
        assert key in test_fd, f"{key} not found in test dict {test_fd}"
        assert isinstance(
            test_fd[key], type(ref_fd[key])
        ), f"The data type is not match between ref and test Dict on {key=}"
        # Check against the builtin `dict` (not typing.Dict): isinstance
        # with typing aliases is deprecated; behavior is identical.
        if isinstance(ref_fd[key], dict):
            compare_dict(ref_fd[key], test_fd[key], rtol, atol)
        else:
            assert_allclose(
                ref_fd[key], test_fd[key], rtol=rtol, atol=atol, err_msg=f"{key=} is not close"
            )
class TestLayer:
    """Base class for praxis-vs-flax layer parity tests.

    Subclasses supply the inputs, the praxis config and the matching flax
    constructor; this class owns the shared loss/gradient plumbing and the
    forward/backward comparison loop.
    """

    @staticmethod
    def loss(inner_variables, *inner_inputs, module, mean_out=True):
        """Apply `module` and reduce its primary output to a scalar mean."""
        outs = module.apply(inner_variables, *inner_inputs)
        out = outs
        if isinstance(outs, tuple):
            # The first place of outs is the real output, others
            # are auxiliary values.
            out = outs[0]
        return jnp.mean(out) if mean_out else out

    @staticmethod
    def loss_and_grads(module, variables, *inputs):
        """Return (loss, wgrads, dgrad): grads w.r.t. variables and inputs."""
        grad_fn = jax.value_and_grad(TestLayer.loss, argnums=(0, 1))
        loss_val, (wgrads, dgrad) = grad_fn(variables, *inputs, module=module)
        return loss_val, wgrads, dgrad

    def input_getter(self, shape, dtype):
        """Return the positional inputs for the layer under test."""
        raise NotImplementedError

    def get_layer_name(self):
        """Name of the praxis layer; key of its sub-tree in the variables."""
        raise NotImplementedError

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Return (praxis config, flax constructor) built from `attrs`."""
        raise NotImplementedError

    def sync_variables(self, praxis_variables, flax_variables):
        """Copy the flax params into the praxis tree so both layers start
        from identical weights."""
        synced_praxis_variables = praxis_variables
        lyr_name = self.get_layer_name()
        if "params" in flax_variables:
            # Praxis nests the wrapped flax module under <layer_name>/cld.
            synced_praxis_variables["params"][lyr_name]["cld"] = flax.core.unfreeze(
                flax_variables["params"]
            )
        return synced_praxis_variables, flax_variables

    def sync_wgrads(self, praxis_wgrads, flax_wgrads):
        """Strip the praxis <layer_name>/cld nesting so wgrad trees line up."""
        synced_praxis_grads = praxis_wgrads
        lyr_name = self.get_layer_name()
        if "params" in synced_praxis_grads:
            synced_praxis_grads["params"] = synced_praxis_grads["params"][lyr_name]["cld"]
        if FP8Helper.is_fp8_enabled():
            # The FP8 meta collection is nested the same way.
            synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME] = synced_praxis_grads[
                FP8Helper.FP8_COLLECTION_NAME
            ][lyr_name]["cld"]
        return synced_praxis_grads, flax.core.unfreeze(flax_wgrads)

    def forward_backward_runner(
        self, data_shape, dtype, praxis_p, flax_cls, rtol=1e-05, atol=1e-08
    ):
        """Init both layers, run loss+grads, and compare loss/dgrad/wgrads."""
        init_key = jax.random.PRNGKey(seed=1234)
        test_inputs = self.input_getter(data_shape, dtype)
        praxis_layer = praxis_p.Instantiate()
        # This is a workaround to correctly enable FP8 meta generation for Praxis.
        # TODO (Ming Huang): To come out a better solution.
        mutable_list = DEFAULT_INIT_MUTABLE_LIST + [FP8Helper.FP8_COLLECTION_NAME]
        praxis_variables = praxis_layer.init(init_key, *test_inputs, mutable=mutable_list)
        flax_layer = flax_cls()
        flax_variables = flax_layer.init(init_key, *test_inputs)
        if "params_axes" in flax_variables:
            # Drop sharding metadata; only parameter values are compared.
            flax_variables, _ = flax.core.pop(flax_variables, "params_axes")
        if FP8Helper.is_fp8_enabled():
            flax_variables, _ = flax.core.pop(
                flax_variables, FP8Helper.FP8_COLLECTION_NAME + "_axes"
            )
        praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)
        # With FP8 enabled, iterate a few times so the FP8 meta collection is
        # updated from the gradients before the final comparison below.
        iter_times = 5 if FP8Helper.is_fp8_enabled() else 1
        for _ in range(iter_times):
            praxis_loss, praxis_wgrads, praxis_dgrad = TestLayer.loss_and_grads(
                praxis_layer, praxis_variables, *test_inputs
            )
            flax_loss, flax_wgrads, flax_dgrad = TestLayer.loss_and_grads(
                flax_layer, flax_variables, *test_inputs
            )
            if FP8Helper.is_fp8_enabled():
                # Fold the FP8-meta gradients back into the variables
                # (the "params" sub-tree is dropped: weights stay fixed).
                praxis_wgrads.pop("params")
                praxis_variables = update_collections(praxis_wgrads, praxis_variables)
                flax_wgrads, _ = flax.core.pop(flax_wgrads, "params")
                flax_variables = update_collections(flax_wgrads, flax_variables)
                praxis_loss, praxis_wgrads, praxis_dgrad = TestLayer.loss_and_grads(
                    praxis_layer, praxis_variables, *test_inputs
                )
                flax_loss, flax_wgrads, flax_dgrad = TestLayer.loss_and_grads(
                    flax_layer, flax_variables, *test_inputs
                )
        assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol)
        assert_allclose(praxis_dgrad, flax_dgrad, rtol=rtol, atol=atol)
        praxis_wgrads, flax_wgrads = self.sync_wgrads(praxis_wgrads, flax_wgrads)
        compare_dict(praxis_wgrads, flax_wgrads, rtol=rtol, atol=atol)
class LayerNormAttr:
    """Attribute keys and parametrized configurations for LayerNorm tests."""

    LN_TYPE = "layernorm_type"
    ZERO_CEN = "zero_centered_gamma"
    # Each entry is one parametrized test configuration.
    ATTRS = [
        {LN_TYPE: "layernorm", ZERO_CEN: False},
        {LN_TYPE: "layernorm", ZERO_CEN: True},
        {LN_TYPE: "rmsnorm", ZERO_CEN: False},
    ]
class TestLayerNorm(TestLayer):
    """Parity tests between praxis LayerNorm and flax LayerNorm."""

    def input_getter(self, shape, dtype):
        """Produce a single normally-distributed input tensor."""
        rng = jax.random.PRNGKey(seed=1234)
        return (jax.random.normal(rng, shape, dtype),)

    def get_layer_name(self):
        """Praxis layer name; key of its params sub-tree."""
        return "layer_norm"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build the praxis config and the matching flax constructor."""
        ln_type = attrs[LayerNormAttr.LN_TYPE]
        zero_centered = attrs[LayerNormAttr.ZERO_CEN]
        beta_init = WeightInit.Constant(0.0)
        praxis_p = pax_fiddle.Config(
            LayerNorm,
            name="layer_norm",
            dtype=dtype,
            layernorm_type=ln_type,
            zero_centered_gamma=zero_centered,
            scale_init=None,
            bias_init=beta_init,
            transpose_batch_sequence=False,
        )
        flax_cls = partial(
            flax_LayerNorm,
            layernorm_type=ln_type,
            zero_centered_gamma=zero_centered,
            scale_init=None,
            bias_init=TransformerEngineBaseLayer.generate_params_init("ln_bias", beta_init),
            dtype=dtype,
            transpose_batch_sequence=False,
        )
        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LayerNormAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Run the shared forward/backward parity check."""
        cfg, ctor = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, cfg, ctor, rtol, atol)
class FusedSoftmaxAttr:
    """Attribute keys and parametrized configurations for FusedSoftmax tests."""

    SCALE_FACTOR = "scale_factor"
    ST_TYPE = "softmax_type"
    # One configuration per supported softmax variant.
    ATTRS = [
        {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED},
        {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED_MASKED},
        {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED_UPPER_TRIANG_MASKED},
    ]
class TestFusedSoftmax(TestLayer):
    """Parity tests between praxis FusedSoftmax and flax Softmax."""

    def input_getter(self, shape, dtype):
        """Return (logits, mask) inputs of the given shape/dtype."""
        data_key = jax.random.PRNGKey(seed=1234)
        return jax.random.normal(data_key, shape, dtype), jnp.ones(shape, dtype=jnp.uint8)  # Masks

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build the praxis config and the matching flax constructor."""
        scale_factor = attrs[FusedSoftmaxAttr.SCALE_FACTOR]
        softmax_type = attrs[FusedSoftmaxAttr.ST_TYPE]
        praxis_p = pax_fiddle.Config(
            FusedSoftmax, name="fused_softmax", scale_factor=scale_factor, softmax_type=softmax_type
        )
        flax_cls = partial(Softmax, scale_factor=scale_factor, softmax_type=softmax_type)
        return praxis_p, flax_cls

    def sync_variables(self, praxis_variables, flax_variables):
        """Softmax has no parameters; nothing to sync."""
        return praxis_variables, flax_variables

    def sync_wgrads(self, praxis_wgrads, flax_wgrads):
        """Softmax has no parameters; nothing to sync."""
        return praxis_wgrads, flax_wgrads

    @pytest.mark.parametrize("data_shape", [(32, 1, 128, 128), (32, 1, 512, 128)])
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", FusedSoftmaxAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Run the shared parity check (skipping unsupported combinations)."""
        if (attrs[FusedSoftmaxAttr.ST_TYPE] == SoftmaxType.SCALED_UPPER_TRIANG_MASKED) and (
            data_shape[-2] != data_shape[-1]
        ):
            # Fix: report the unsupported combination as SKIPPED instead of
            # silently passing (the old code just fell through with `pass`).
            pytest.skip("SCALED_UPPER_TRIANG_MASKED requires a square mask")
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class LinearAttr:
    """Attribute keys and parametrized configurations for Linear tests."""

    FEATURE = "features"
    USE_BIAS = "use_bias"
    # Cross product of output width x bias usage.
    ATTRS = [
        {FEATURE: 512, USE_BIAS: False},
        {FEATURE: 512, USE_BIAS: True},
        {FEATURE: 1024, USE_BIAS: False},
        {FEATURE: 1024, USE_BIAS: True},
    ]
class TestLinear(TestLayer):
    """Parity tests between praxis Linear and flax DenseGeneral."""

    def input_getter(self, shape, dtype):
        """Return a single normally-distributed input tensor."""
        data_key = jax.random.PRNGKey(seed=1234)
        return (jax.random.normal(data_key, shape, dtype),)

    def get_layer_name(self):
        """Praxis layer name; key of its params sub-tree."""
        return "linear"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build the praxis config and the matching flax constructor."""
        out_features = attrs[LinearAttr.FEATURE]
        kernel_init = WeightInit.Gaussian(1.0)
        use_bias = attrs[LinearAttr.USE_BIAS]
        bias_init = WeightInit.Constant(0.0)
        axis = -1
        transpose_batch_sequence = False
        praxis_p = pax_fiddle.Config(
            Linear,
            name="linear",
            dtype=dtype,
            out_features=out_features,
            params_init=kernel_init,
            use_bias=use_bias,
            bias_init=bias_init,
            axis=axis,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        flax_cls = partial(
            DenseGeneral,
            features=out_features,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init),
            use_bias=use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
            axis=axis,
            dtype=dtype,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LinearAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Parity check without FP8."""
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LinearAttr.ATTRS)
    @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
    def test_forward_backward_fp8(
        self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
    ):
        """Parity check with FP8 autocast enabled."""
        ds = DelayedScaling(fp8_format=fp8_format)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class LayerNormLinearAttr:
    """Attribute keys and parametrized configurations for LayerNormLinear tests."""

    FEATURE = "features"
    USE_BIAS = "use_bias"
    ENABLE_LN = "enable_layernorm"
    LN_TYPE = "layernorm_type"
    ZERO_CEN = "zero_centered_gamma"
    # Each entry is one parametrized configuration.  The previous list
    # contained each of the first three configurations twice, which only
    # re-ran identical tests; the exact duplicates are removed here.
    ATTRS = [
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: False},
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: True},
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "rmsnorm", ZERO_CEN: False},
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: False, LN_TYPE: "layernorm", ZERO_CEN: False},
    ]
class TestLayerNormLinear(TestLayer):
    """Parity tests between praxis LayerNormLinear and flax LayerNormDenseGeneral."""

    def input_getter(self, shape, dtype):
        """Return a single normally-distributed input tensor."""
        data_key = jax.random.PRNGKey(seed=1234)
        return (jax.random.normal(data_key, shape, dtype),)

    def get_layer_name(self):
        """Praxis layer name; key of its params sub-tree."""
        return "ln_linear"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build the praxis config and the matching flax constructor."""
        out_features = attrs[LayerNormLinearAttr.FEATURE]
        enable_layernorm = attrs[LayerNormLinearAttr.ENABLE_LN]
        layernorm_type = attrs[LayerNormLinearAttr.LN_TYPE]
        zero_centered_gamma = attrs[LayerNormLinearAttr.ZERO_CEN]
        kernel_init = WeightInit.Gaussian(1.0)
        use_bias = attrs[LayerNormLinearAttr.USE_BIAS]
        bias_init = WeightInit.Constant(0.0)
        axis = -1
        transpose_batch_sequence = False
        praxis_p = pax_fiddle.Config(
            LayerNormLinear,
            name="ln_linear",
            dtype=dtype,
            out_features=out_features,
            enable_layernorm=enable_layernorm,
            layernorm_type=layernorm_type,
            zero_centered_gamma=zero_centered_gamma,
            params_init=kernel_init,
            use_bias=use_bias,
            bias_init=bias_init,
            axis=axis,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        flax_cls = partial(
            LayerNormDenseGeneral,
            features=out_features,
            enable_layernorm=enable_layernorm,
            layernorm_type=layernorm_type,
            zero_centered_gamma=zero_centered_gamma,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init),
            use_bias=use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
            axis=axis,
            dtype=dtype,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LayerNormLinearAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Parity check without FP8."""
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LayerNormLinearAttr.ATTRS)
    @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
    def test_forward_backward_fp8(
        self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
    ):
        """Parity check with FP8 autocast enabled."""
        ds = DelayedScaling(fp8_format=fp8_format)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class LayerNormMLPAttr:
    """Attribute keys and parametrized configurations for LayerNormMLP tests."""

    INTERMEDIATE_DIM = "intermediate_dim"
    USE_BIAS = "use_bias"
    ENABLE_LN = "enable_layernorm"
    LN_TYPE = "layernorm_type"
    ZERO_CEN = "zero_centered_gamma"
    ACTIVATION = "activations"
    # Combinations of norm type, bias usage, and (possibly gated) activations.
    ATTRS = [
        {
            INTERMEDIATE_DIM: 2048,
            USE_BIAS: True,
            ENABLE_LN: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
        },
        {
            INTERMEDIATE_DIM: 2048,
            USE_BIAS: True,
            ENABLE_LN: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("relu",),
        },
        {
            INTERMEDIATE_DIM: 2048,
            USE_BIAS: True,
            ENABLE_LN: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
        },
        {
            INTERMEDIATE_DIM: 2048,
            USE_BIAS: True,
            ENABLE_LN: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
        },
        {
            INTERMEDIATE_DIM: 2048,
            USE_BIAS: False,
            ENABLE_LN: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
        },
        {
            INTERMEDIATE_DIM: 2048,
            USE_BIAS: True,
            ENABLE_LN: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("silu", "linear"),
        },
        {
            INTERMEDIATE_DIM: 2048,
            USE_BIAS: False,
            ENABLE_LN: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("silu", "linear"),
        },
    ]
class TestLayerNormMLP(TestLayer):
    """Parity tests between praxis LayerNormMLP and flax LayerNormMLP."""

    def input_getter(self, shape, dtype):
        """Return a single normally-distributed input tensor."""
        data_key = jax.random.PRNGKey(seed=1234)
        return (jax.random.normal(data_key, shape, dtype),)

    def get_layer_name(self):
        """Praxis layer name; key of its params sub-tree."""
        return "ln_mlp"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build the praxis config and the matching flax constructor."""
        intermediate_dim = attrs[LayerNormMLPAttr.INTERMEDIATE_DIM]
        enable_layernorm = attrs[LayerNormMLPAttr.ENABLE_LN]
        layernorm_type = attrs[LayerNormMLPAttr.LN_TYPE]
        zero_centered_gamma = attrs[LayerNormMLPAttr.ZERO_CEN]
        kernel_init = WeightInit.Gaussian(1.0)
        use_bias = attrs[LayerNormMLPAttr.USE_BIAS]
        bias_init = WeightInit.Constant(0.0)
        activations = attrs[LayerNormMLPAttr.ACTIVATION]
        axis = -1
        transpose_batch_sequence = False
        praxis_p = pax_fiddle.Config(
            LayerNormMLP,
            name="ln_mlp",
            dtype=dtype,
            intermediate_dim=intermediate_dim,
            enable_layernorm=enable_layernorm,
            layernorm_type=layernorm_type,
            zero_centered_gamma=zero_centered_gamma,
            params_init=kernel_init,
            use_bias=use_bias,
            bias_init=bias_init,
            activations=activations,
            intermediate_dropout_rate=0.0,
            axis=axis,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        flax_cls = partial(
            flax_LayerNormMLP,
            intermediate_dim=intermediate_dim,
            enable_layernorm=enable_layernorm,
            layernorm_type=layernorm_type,
            zero_centered_gamma=zero_centered_gamma,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init),
            use_bias=use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
            activations=activations,
            intermediate_dropout_rate=0.0,
            axis=axis,
            dtype=dtype,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LayerNormMLPAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Parity check without FP8."""
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LayerNormMLPAttr.ATTRS)
    @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
    def test_forward_backward_fp8(
        self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
    ):
        """Parity check with FP8 autocast enabled."""
        ds = DelayedScaling(fp8_format=fp8_format)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class TestRelativePositionBias(TestLayer):
    """Parity tests between praxis and flax RelativePositionBiases (forward only)."""

    def get_layer_name(self):
        """Praxis layer name; key of its params sub-tree."""
        return "relative_position_bias"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build the praxis config and the matching flax constructor."""
        num_buckets = 32
        max_distance = 128
        num_attention_heads = 64
        # Embedding init stddev scaled by 1/sqrt(heads * buckets).
        rb_stddev = (num_attention_heads * num_buckets) ** -0.5
        embedding_init = WeightInit.Gaussian(rb_stddev)
        praxis_p = pax_fiddle.Config(
            RelativePositionBiases,
            name="relative_position_bias",
            dtype=dtype,
            num_buckets=num_buckets,
            max_distance=max_distance,
            num_attention_heads=num_attention_heads,
            embedding_init=embedding_init,
        )
        flax_cls = partial(
            flax_RelativePositionBiases,
            num_buckets=num_buckets,
            max_distance=max_distance,
            num_attention_heads=num_attention_heads,
            embedding_init=TransformerEngineBaseLayer.generate_params_init(
                "rel_embedding", embedding_init
            ),
            dtype=dtype,
        )
        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", [{}])
    def test_forward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Compare forward outputs only (no gradients)."""
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        init_key = jax.random.PRNGKey(seed=1234)
        # Each tuple looks like (q_len, k_len, flag); the flag presumably
        # toggles bidirectional bucketing -- TODO confirm.
        test_inputs = [(128, 128, True), (128, 128, False)]
        for test_input in test_inputs:
            praxis_layer = praxis_p.Instantiate()
            praxis_variables = praxis_layer.init(init_key, *test_input)
            flax_layer = flax_cls()
            flax_variables = flax_layer.init(init_key, *test_input)
            if "params_axes" in flax_variables:
                # Drop sharding metadata; only values are compared.
                flax_variables, _ = flax.core.pop(flax_variables, "params_axes")
            if FP8Helper.is_fp8_enabled():
                flax_variables, _ = flax.core.pop(
                    flax_variables, FP8Helper.FP8_COLLECTION_NAME + "_axes"
                )
            praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)
            # mean_out=False: compare the raw bias tensors element-wise.
            praxis_loss = TestLayer.loss(
                praxis_variables, *test_input, module=praxis_layer, mean_out=False
            )
            flax_loss = TestLayer.loss(
                flax_variables, *test_input, module=flax_layer, mean_out=False
            )
            assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol)
class DotProductAttnAttr:
    """Attribute keys and parametrized configurations for DotProductAttention tests."""

    ATTN_MASK_TYPE = "attn_mask_type"
    # NOTE(review): NUM_GQA_GROUPS is declared but unused in ATTRS below.
    NUM_GQA_GROUPS = "num_gqa_groups"
    TRANSPOSE_BS = "transpose_batch_sequence"
    SCALE_FACTOR = "scale_factor"
    WINDOW_SIZE = "window_size"
    # Mask type x layout x scale-factor combinations; the last entry also
    # enables sliding-window attention.
    ATTRS = [
        {
            ATTN_MASK_TYPE: "padding",
            TRANSPOSE_BS: True,
            SCALE_FACTOR: 0.125,
        },
        {
            ATTN_MASK_TYPE: "padding_causal",
            TRANSPOSE_BS: True,
            SCALE_FACTOR: 0.125,
        },
        {
            ATTN_MASK_TYPE: "causal",
            TRANSPOSE_BS: True,
            SCALE_FACTOR: 0.125,
        },
        {
            ATTN_MASK_TYPE: "padding",
            TRANSPOSE_BS: False,
            SCALE_FACTOR: 0.125,
        },
        {
            ATTN_MASK_TYPE: "padding_causal",
            TRANSPOSE_BS: False,
            SCALE_FACTOR: 2.0,
        },
        {
            ATTN_MASK_TYPE: "causal",
            TRANSPOSE_BS: False,
            SCALE_FACTOR: 1.0,
        },
        {
            ATTN_MASK_TYPE: "no_mask",
            TRANSPOSE_BS: False,
            SCALE_FACTOR: 1.0,
        },
        {
            ATTN_MASK_TYPE: "causal",
            TRANSPOSE_BS: False,
            SCALE_FACTOR: 1.0,
            WINDOW_SIZE: (64, 0),  # Left size must <= S in DATA_SHAPE
        },
    ]
class TestDotProductAttn(TestLayer):
    """Parity tests between praxis and flax DotProductAttention."""

    def input_getter(self, shape, dtype):
        """Return [query, key, value, mask]; mask is all-zero (attend everywhere)."""
        key = jax.random.PRNGKey(seed=1234)
        q_key, k_key, v_key = jax.random.split(key, 3)
        b, s, *_ = shape
        if self.attrs[DotProductAttnAttr.TRANSPOSE_BS]:
            # Swap (B, S, ...) -> (S, B, ...); the mask keeps (B, 1, S, S).
            shape = (shape[1], shape[0]) + shape[2:]
        mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
        return [
            *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]),
            mask,
        ]

    def get_layer_name(self):
        """Praxis layer name; key of its params sub-tree."""
        return "dot_product_attn"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build the praxis config and the matching flax constructor.

        NOTE(review): attrs[SCALE_FACTOR] is never forwarded to either
        layer, so both fall back to their defaults -- confirm whether the
        scale factor was meant to be wired through.
        """
        head_dim = 64
        num_attention_heads = 16
        num_gqa_groups = num_attention_heads
        attn_mask_type = attrs[DotProductAttnAttr.ATTN_MASK_TYPE]
        transpose_batch_sequence = attrs[DotProductAttnAttr.TRANSPOSE_BS]
        window_size = attrs.get(DotProductAttnAttr.WINDOW_SIZE, None)
        praxis_p = pax_fiddle.Config(
            DotProductAttention,
            name="mha",
            dtype=dtype,
            head_dim=head_dim,
            num_attention_heads=num_attention_heads,
            num_gqa_groups=num_gqa_groups,
            attn_mask_type=attn_mask_type,
            transpose_batch_sequence=transpose_batch_sequence,
            window_size=window_size,
        )
        flax_cls = partial(
            flax_DotProductAttention,
            dtype=dtype,
            head_dim=head_dim,
            num_attention_heads=num_attention_heads,
            num_gqa_groups=num_gqa_groups,
            attn_mask_type=attn_mask_type,
            transpose_batch_sequence=transpose_batch_sequence,
            window_size=window_size,
        )
        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", [(32, 128, 16, 64)])
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", DotProductAttnAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Parity check; stores attrs on self since input_getter reads them."""
        self.attrs = attrs
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class MultiHeadAttnAttr:
    """Attribute keys and parametrized configurations for MultiHeadAttention tests."""

    USE_BIAS = "use_bias"
    LN_TYPE = "layernorm_type"
    ATTN_MASK_TYPE = "attn_mask_type"
    ZERO_CEN = "zero_centered_gamma"
    NUM_ATTN_HEADS = "num_attention_heads"
    NUM_GQA_GROUPS = "num_gqa_groups"
    TRANSPOSE_BS = "transpose_batch_sequence"
    ENABLE_ROPE = "enable_rotary_pos_emb"
    ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
    LORA_SCOPE = "low_rank_adaptation_scope"
    WINDOW_SIZE = "window_size"
    # Combinations covering mask types, layouts, GQA, RoPE, LoRA, and
    # sliding-window attention.
    ATTRS = [
        # padding mask, layernorm variants, both layouts
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "padding",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "padding",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "padding",
            TRANSPOSE_BS: True,
        },
        # causal mask, layernorm variants, both layouts
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "causal",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "causal",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "causal",
            TRANSPOSE_BS: False,
        },
        # grouped-query attention (8 heads, 4 KV groups)
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            NUM_ATTN_HEADS: 8,
            NUM_GQA_GROUPS: 4,
            ATTN_MASK_TYPE: "causal",
            TRANSPOSE_BS: True,
        },
        # rotary position embeddings, both grouping methods
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ENABLE_ROPE: True,
            ROPE_GROUP_METHOD: "consecutive",
            NUM_ATTN_HEADS: 8,
            NUM_GQA_GROUPS: 4,
            ATTN_MASK_TYPE: "causal",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ENABLE_ROPE: True,
            ROPE_GROUP_METHOD: "alternate",
            NUM_ATTN_HEADS: 8,
            NUM_GQA_GROUPS: 4,
            ATTN_MASK_TYPE: "causal",
            TRANSPOSE_BS: True,
        },
        # low-rank adaptation enabled on all projections
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "padding",
            LORA_SCOPE: "all",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "causal",
            LORA_SCOPE: "all",
            TRANSPOSE_BS: True,
        },
        # sliding-window attention
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "causal",
            LORA_SCOPE: "all",
            TRANSPOSE_BS: True,
            WINDOW_SIZE: (64, 0),  # Left size must <= S in DATA_SHAPE
        },
    ]
class TestMultiHeadAttn(TestLayer):
    """Checks that the praxis MultiHeadAttention wrapper and the flax
    MultiHeadAttention module produce matching forward/backward results
    for every configuration in MultiHeadAttnAttr.ATTRS."""

    def input_getter(self, shape, dtype):
        """Build deterministic (q, kv, mask) inputs for one test case.

        The mask is all zeros (nothing masked out); the masking behavior
        under test comes from the layer's attn_mask_type setting.
        """
        root_key = jax.random.PRNGKey(seed=1234)
        q_key, kv_key = jax.random.split(root_key, 2)
        batch, seqlen, *_ = shape
        if self.attrs[MultiHeadAttnAttr.TRANSPOSE_BS]:
            # Layer consumes (S, B, ...) tensors in this mode; the mask
            # below stays (B, 1, S, S) either way.
            shape = (shape[1], shape[0]) + shape[2:]
        mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        tensors = [jax.random.normal(k, shape=shape, dtype=dtype) for k in (q_key, kv_key)]
        return tensors + [mask]

    def get_layer_name(self):
        """Name the runner uses to locate this layer's parameters/outputs."""
        return "multi_head_attn"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Translate one ATTRS entry into a praxis config and a flax factory.

        Settings passed verbatim to both frameworks are collected once in
        `shared` so the two constructions cannot drift apart; only the
        initializer plumbing differs between praxis and flax.
        """
        kernel_init = WeightInit.Gaussian(1.0)
        bias_init = WeightInit.Constant(0.0)
        shared = dict(
            dtype=dtype,
            head_dim=64,
            num_attention_heads=16,
            # GQA is exercised only by entries that set NUM_GQA_GROUPS.
            num_gqa_groups=attrs.get(MultiHeadAttnAttr.NUM_GQA_GROUPS, None),
            layernorm_type=attrs[MultiHeadAttnAttr.LN_TYPE],
            zero_centered_gamma=attrs[MultiHeadAttnAttr.ZERO_CEN],
            use_bias=attrs[MultiHeadAttnAttr.USE_BIAS],
            return_layernorm_output=False,
            input_layernorm=False,
            attn_mask_type=attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE],
            enable_rotary_pos_emb=attrs[MultiHeadAttnAttr.ENABLE_ROPE],
            rotary_pos_emb_group_method=attrs[MultiHeadAttnAttr.ROPE_GROUP_METHOD],
            low_rank_adaptation_scope=attrs.get(MultiHeadAttnAttr.LORA_SCOPE, "none"),
            fuse_qkv_params=True,
            transpose_batch_sequence=attrs[MultiHeadAttnAttr.TRANSPOSE_BS],
            scale_attn_logits=False,
            scaled_query_init=True,
            float32_logits=False,
            window_size=attrs.get(MultiHeadAttnAttr.WINDOW_SIZE, None),
        )
        praxis_p = pax_fiddle.Config(
            MultiHeadAttention,
            name="mha",
            params_init=kernel_init,
            bias_init=bias_init,
            **shared,
        )
        flax_cls = partial(
            flax_MultiHeadAttention,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init),
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
            **shared,
        )
        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", MultiHeadAttnAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Compare praxis vs. flax outputs and gradients (no FP8)."""
        self.attrs = attrs
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", MultiHeadAttnAttr.ATTRS)
    @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
    def test_forward_backward_fp8(
        self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
    ):
        """Same comparison with FP8 autocast (delayed scaling) enabled."""
        self.attrs = attrs
        fp8_recipe = DelayedScaling(fp8_format=fp8_format)
        with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class TransformerLayerAttr:
    """Attribute keys and the configuration grid used to parametrize the
    TransformerLayer comparison tests below.

    Each constant is the kwarg name the corresponding value is routed to
    in generate_praxis_p_and_flax_cls; each dict in ATTRS is one complete
    test configuration.
    """

    USE_BIAS = "use_bias"
    LN_TYPE = "layernorm_type"
    ACTIVATION = "activations"
    LYR_TYPE = "layer_type"
    ZERO_CEN = "zero_centered_gamma"
    TRANSPOSE_BS = "transpose_batch_sequence"
    ENABLE_ROPE = "enable_rotary_pos_emb"
    ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
    LORA_SCOPE = "low_rank_adaptation_scope"
    WINDOW_SIZE = "window_size"
    ATTRS = [
        # --- encoder, single ("relu",) MLP activation: layernorm /
        # zero-centered layernorm / rmsnorm, each with and without
        # batch-sequence transposition ---
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        # --- decoder, single ("relu",) MLP activation: same norm sweep ---
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        # --- gated MLP ("gelu", "linear"), encoder then decoder ---
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        # LoRA enabled on all supported projections (encoder).
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
            LORA_SCOPE: "all",
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        # --- rotary position embeddings: both grouping methods, on both
        # encoder and decoder layers ---
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("gelu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: True,
            ROPE_GROUP_METHOD: "alternate",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("gelu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: True,
            ROPE_GROUP_METHOD: "alternate",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("gelu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: True,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("gelu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: True,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        # LoRA enabled on all supported projections (decoder).
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
            LORA_SCOPE: "all",
        },
        # --- sliding-window attention, encoder then decoder ---
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
            WINDOW_SIZE: (64, 0),  # Left window size must be <= S in DATA_SHAPE
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
            WINDOW_SIZE: (64, 0),  # Left window size must be <= S in DATA_SHAPE
        },
    ]
class TestTransformer(TestLayer):
    """Checks that the praxis TransformerLayer wrapper and the flax
    TransformerLayer module produce matching forward/backward results
    for every configuration in TransformerLayerAttr.ATTRS."""

    def input_getter(self, shape, dtype):
        """Build deterministic inputs for one test case.

        Returns [q, kv, self_attn_mask, cross_attn_mask]; both masks are
        the same all-zero (nothing masked) tensor.
        """
        root_key = jax.random.PRNGKey(seed=1234)
        q_key, kv_key = jax.random.split(root_key, 2)
        batch, seqlen, *_ = shape
        if self.attrs[TransformerLayerAttr.TRANSPOSE_BS]:
            # Layer consumes (S, B, ...) tensors in this mode; the masks
            # below stay (B, 1, S, S) either way.
            shape = (shape[1], shape[0]) + shape[2:]
        mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        tensors = [jax.random.normal(k, shape=shape, dtype=dtype) for k in (q_key, kv_key)]
        return tensors + [mask, mask]

    def get_layer_name(self):
        """Name the runner uses to locate this layer's parameters/outputs."""
        return "transformerlayer"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Translate one ATTRS entry into a praxis config and a flax factory.

        Settings passed verbatim to both frameworks are collected once in
        `shared`; the relative-embedding module and the initializer
        plumbing are the only parts built per-framework.

        NOTE(review): ATTRS varies ZERO_CEN, but neither construction
        forwards zero_centered_gamma to the layer — confirm this is
        intentional.
        """
        kernel_init = WeightInit.Gaussian(1.0)
        bias_init = WeightInit.Constant(0.0)
        num_attention_heads = 8
        # Relative position biases: the praxis side consumes the config,
        # the flax side needs an instantiated module with matching fields.
        relative_embedding = pax_fiddle.Config(
            RelativePositionBiases, dtype=dtype, num_attention_heads=num_attention_heads
        )
        rel_embedding_init = RelativePositionBiases.generate_embedding_init(
            relative_embedding.embedding_init,
            relative_embedding.num_attention_heads,
            relative_embedding.num_buckets,
        )
        relative_embedding_flax_module = flax_RelativePositionBiases(
            num_buckets=relative_embedding.num_buckets,
            max_distance=relative_embedding.max_distance,
            num_attention_heads=relative_embedding.num_attention_heads,
            embedding_init=TransformerEngineBaseLayer.generate_params_init(
                "rel_embedding", rel_embedding_init
            ),
            embedding_axes=relative_embedding.embedding_axes,
            dtype=relative_embedding.dtype,
        )
        shared = dict(
            dtype=dtype,
            hidden_size=512,
            mlp_hidden_size=2048,
            num_attention_heads=num_attention_heads,
            layernorm_type=attrs[TransformerLayerAttr.LN_TYPE],
            # All dropout is disabled so the two implementations are
            # numerically comparable.
            hidden_dropout=0.0,
            attention_dropout=0.0,
            intermediate_dropout=0.0,
            mlp_activations=attrs[TransformerLayerAttr.ACTIVATION],
            use_bias=attrs[TransformerLayerAttr.USE_BIAS],
            layer_type=attrs[TransformerLayerAttr.LYR_TYPE],
            enable_relative_embedding=True,
            enable_rotary_pos_emb=attrs[TransformerLayerAttr.ENABLE_ROPE],
            rotary_pos_emb_group_method=attrs[TransformerLayerAttr.ROPE_GROUP_METHOD],
            low_rank_adaptation_scope=attrs.get(TransformerLayerAttr.LORA_SCOPE, "none"),
            drop_path=0.0,
            transpose_batch_sequence=attrs[TransformerLayerAttr.TRANSPOSE_BS],
            window_size=attrs.get(TransformerLayerAttr.WINDOW_SIZE, None),
        )
        praxis_p = pax_fiddle.Config(
            TransformerLayer,
            name="transformer_layer",
            params_init=kernel_init,
            bias_init=bias_init,
            relative_embedding=relative_embedding,
            **shared,
        )
        flax_cls = partial(
            flax_TransformerLayer,
            mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                "mha_kernel", kernel_init
            ),
            mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                "mlp_kernel", kernel_init
            ),
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
            relative_embedding=relative_embedding_flax_module,
            **shared,
        )
        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", TransformerLayerAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Compare praxis vs. flax outputs and gradients (no FP8)."""
        self.attrs = attrs
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", TransformerLayerAttr.ATTRS)
    @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
    def test_forward_backward_fp8(
        self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
    ):
        """Same comparison with FP8 autocast (delayed scaling) enabled."""
        self.attrs = attrs
        fp8_recipe = DelayedScaling(fp8_format=fp8_format)
        with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Sanity check: importing the package must succeed without raising;
# print a marker so the caller can confirm the script ran to completion.
import transformer_engine.jax

print("OK")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment