Commit ab3e5a92 authored by yuguo

Merge commit '04c730c0' of https://github.com/NVIDIA/TransformerEngine
parents a8d19fd9 04c730c0
@@ -11,6 +11,7 @@ list(APPEND test_cuda_sources
test_cast_mxfp8_gated_swiglu.cu
test_qdq.cu
test_cast_mxfp8.cu
# test_cast_float8blockwise.cu
test_dequantize_mxfp8.cu
test_transpose.cu
test_cast_transpose.cu
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
#include "../test_common.h"
#include "transformer_engine/transformer_engine.h"
using namespace transformer_engine;
using namespace test;
namespace {
struct QuantizationOptions {
bool force_pow_2_scales = false;
float amax_epsilon = 0.0;
size_t block_scaling_dim = 2u;
};
constexpr size_t kBlockLen = 128;
enum ProcessingMethod {
CAST_ONLY,
// CAST_DBIAS,
// CAST_DBIAS_DACT,
// CAST_DACT,
// CAST_ACT
};
enum ActivationType {
Identity,
// GeLU,
// SiLU,
// ReLU,
// QGeLU,
// SReLU
};
template <typename InputType, typename OutputType>
void scales_from_amax(float amax, const QuantizationOptions& opts, float* qscale_out,
float* qscale_inv_out) {
float input_type_max_val = Quantized_Limits<InputType>::max();
float quant_type_max_val = Quantized_Limits<OutputType>::max();
float eps = opts.amax_epsilon;
amax = std::max(amax, eps);
float qscale = quant_type_max_val / amax;
if (std::isinf(qscale)) {
qscale = input_type_max_val;
}
if (std::isnan(qscale) || amax == 0) {
qscale = 1.0;
}
if (opts.force_pow_2_scales && qscale != 0.0) {
uint32_t scale_bits = *reinterpret_cast<uint32_t*>(&qscale);
// The scale is positive, so the sign bit is zero and the exponent occupies the top bits.
uint8_t exp = scale_bits >> 23;
ASSERT_FALSE(exp == 0) << "A subnormal scale in this path is a logic error.";
qscale = ldexpf(1.0f, static_cast<int32_t>(exp) - 127);
}
float qscale_inv = 1.0 / qscale;
*qscale_out = qscale;
*qscale_inv_out = qscale_inv;
}
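// Illustrative usage sketch (not part of the original test; assumes
// Quantized_Limits<fp8e4m3>::max() == 448, the FP8 E4M3 maximum):
//
//   QuantizationOptions opts;
//   opts.force_pow_2_scales = true;
//   float qscale, qscale_inv;
//   scales_from_amax<float, fp8e4m3>(3.5f, opts, &qscale, &qscale_inv);
//   // 448 / 3.5 = 128 = 2^7 is already a power of two, so qscale == 128.f
//   // and qscale_inv == 0.0078125f. For amax = 3.0f, 448 / 3 ~ 149.3 and
//   // truncating the mantissa rounds down to qscale == 128.f as well.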
template <typename InputType, typename OutputType>
void ref_quantize(const ProcessingMethod processing_method, const InputType* input,
const std::pair<size_t, size_t>& input_hw, OutputType* output, float* scale_inv,
OutputType* output_t, float* scale_inv_t, const QuantizationOptions& opts) {
constexpr size_t kBlockLenX = kBlockLen;
constexpr size_t kBlockLenY = kBlockLen;
auto quantize_element = [](InputType element, float qscale) -> OutputType {
// Scale in FP32 and cast result to nearest FP8.
return static_cast<OutputType>(float(element) * qscale);
};
size_t height = input_hw.first;
size_t width = input_hw.second;
size_t blocks_x = (width + kBlockLenX - 1) / kBlockLenX;
size_t blocks_y = (height + kBlockLenY - 1) / kBlockLenY;
// Find the absolute maximum value in the block
for (size_t block_x = 0; block_x < blocks_x; ++block_x) {
for (size_t block_y = 0; block_y < blocks_y; ++block_y) {
float amax = 0.0f;
// Calculate amax for a tile.
for (size_t i = 0; i < kBlockLenX; ++i) {
for (size_t j = 0; j < kBlockLenY; ++j) {
size_t x_pos = i + block_x * kBlockLenX;
size_t y_pos = j + block_y * kBlockLenY;
if (y_pos >= height || x_pos >= width) {
continue;
}
float val = static_cast<float>(input[y_pos * width + x_pos]);
amax = std::max(amax, std::abs(val));
}
}
// We've calculated amax for a tile. Calculate scale and
// scale_inv and populate outputs.
float qscale, qscale_inv;
scales_from_amax<InputType, OutputType>(amax, opts, &qscale, &qscale_inv);
// NOTE: This reference function outputs contiguous scale tensors in a
// naive layout; strides are handled in the comparison helpers.
if (scale_inv != nullptr) {
scale_inv[block_y * blocks_x + block_x] = qscale_inv;
}
if (scale_inv_t != nullptr) {
scale_inv_t[block_x * blocks_y + block_y] = qscale_inv;
}
for (size_t i = 0; i < kBlockLenX; ++i) {
for (size_t j = 0; j < kBlockLenY; ++j) {
size_t x_pos = i + block_x * kBlockLenX;
size_t y_pos = j + block_y * kBlockLenY;
if (y_pos >= height || x_pos >= width) {
continue;
}
if (output != nullptr) {
output[y_pos * width + x_pos] = quantize_element(input[y_pos * width + x_pos], qscale);
}
if (output_t != nullptr) {
output_t[x_pos * height + y_pos] =
quantize_element(input[y_pos * width + x_pos], qscale);
}
}
}
}
}
}
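// Layout sketch for the reference above: for a 256 x 300 input with
// kBlockLen = 128, blocks_x = 3 and blocks_y = 2, so scale_inv holds the
// 2 x 3 tile scales in row-major order and scale_inv_t the 3 x 2 transpose.
// The kernel pads the inner scale dimension (see scale_align_stride below);
// this reference stays contiguous, and the stride difference is handled in
// the comparison helpers.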
template <typename InputType, typename OutputType>
void ref_quantize_onedimensional_blocks(const ProcessingMethod processing_method,
const InputType* input,
const std::pair<size_t, size_t>& input_hw,
OutputType* output, float* scale_inv, OutputType* output_t,
float* scale_inv_t, const QuantizationOptions& opts) {
constexpr size_t kBlockLenX = kBlockLen;
auto quantize_element = [](InputType element, float qscale) -> OutputType {
// Scale in FP32 and cast result to nearest FP8.
return static_cast<OutputType>(float(element) * qscale);
};
size_t height = input_hw.first;
size_t width = input_hw.second;
size_t blocks_x = (width + kBlockLenX - 1) / kBlockLenX;
size_t blocks_x_t = (height + kBlockLenX - 1) / kBlockLenX;
if (output != nullptr && scale_inv != nullptr) {
// Find the absolute maximum value in the block
for (size_t block_x = 0; block_x < blocks_x; ++block_x) {
for (size_t y = 0; y < height; ++y) {
float amax = 0.0f;
// Calculate amax for a tile.
for (size_t i = 0; i < kBlockLenX; ++i) {
size_t x_pos = i + block_x * kBlockLenX;
if (x_pos >= width) {
continue;
}
float val = static_cast<float>(input[y * width + x_pos]);
amax = std::max(amax, std::abs(val));
}
// We've calculated amax for a tile. Calculate scale and
// scale_inv and populate outputs.
float qscale, qscale_inv;
scales_from_amax<InputType, OutputType>(amax, opts, &qscale, &qscale_inv);
scale_inv[y + height * block_x] = qscale_inv;
for (size_t i = 0; i < kBlockLenX; ++i) {
size_t x_pos = i + block_x * kBlockLenX;
if (x_pos >= width) {
continue;
}
output[y * width + x_pos] = quantize_element(input[y * width + x_pos], qscale);
}
}
}
}
if (output_t != nullptr && scale_inv_t != nullptr) {
// Find the absolute maximum value in the block
for (size_t block_x_t = 0; block_x_t < blocks_x_t; ++block_x_t) {
for (size_t x = 0; x < width; ++x) {
float amax = 0.0f;
// Calculate amax for a tile.
for (size_t i = 0; i < kBlockLenX; ++i) {
size_t y_pos = i + block_x_t * kBlockLenX;
if (y_pos >= height) {
continue;
}
float val = static_cast<float>(input[x + y_pos * width]);
amax = std::max(amax, std::abs(val));
}
// We've calculated amax for a tile. Calculate scale and
// scale_inv and populate outputs.
float qscale, qscale_inv;
scales_from_amax<InputType, OutputType>(amax, opts, &qscale, &qscale_inv);
scale_inv_t[x + width * block_x_t] = qscale_inv;
for (size_t i = 0; i < kBlockLenX; ++i) {
size_t y_pos = i + block_x_t * kBlockLenX;
if (y_pos >= height) {
continue;
}
output_t[x * height + y_pos] = quantize_element(input[y_pos * width + x], qscale);
}
}
}
}
}
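// Layout sketch for the 1-D reference above: scales are stored transposed
// relative to the data, scale_inv[y + height * block_x], so the per-row
// scales of one 128-wide block column are contiguous. A 256 x 300 input
// yields 256 * 3 = 768 rowwise scales and 300 * 2 = 600 columnwise ones
// before the kernel's stride padding is applied.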
inline size_t scale_align_stride(size_t inner_elements) {
return ((inner_elements + 4u - 1u) / 4u) * 4u;
}
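// For example, scale_align_stride(3) == 4 and scale_align_stride(5) == 8:
// the kernel pads the innermost scale dimension to a multiple of four
// elements, which is why the comparisons below take separate test/ref strides.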
void compare_scaling_factors(const std::string& name, const float* test, const float* ref,
const size_t row_blocks, const size_t col_blocks,
const size_t test_stride, const size_t ref_stride) {
for (size_t i = 0; i < row_blocks; ++i) {
for (size_t j = 0; j < col_blocks; ++j) {
const size_t test_idx = i * test_stride + j;
const size_t ref_idx = i * ref_stride + j;
ASSERT_FALSE(test[test_idx] != ref[ref_idx])
<< "Error in " << name << std::endl
<< "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx] << " at index " << test_idx
<< "," << ref_idx;
}
}
}
void compare_scaling_factors_one_dimensional_blocks(const std::string& name, const float* test,
const float* ref, const size_t rows,
const size_t col_blocks) {
const size_t test_stride = scale_align_stride(rows);
for (size_t i = 0; i < rows; ++i) {
for (size_t j = 0; j < col_blocks; ++j) {
const size_t test_idx = i + test_stride * j;
const size_t ref_idx = i + rows * j;
ASSERT_FALSE(test[test_idx] != ref[ref_idx])
<< "Error in " << name << std::endl
<< "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx] << " at index " << test_idx
<< "," << ref_idx;
}
}
}
template <typename InputType, typename OutputType>
void runTestCase(const ProcessingMethod processing_method, const std::vector<size_t>& shape,
const bool rowwise, const bool colwise, InputsFillCase fill_case,
const QuantizationOptions& opts) {
using namespace test;
using EncodingType = fp32;
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
const size_t rows = first_dimension(shape);
const size_t cols = last_dimension(shape);
size_t blocks_x = (cols + kBlockLen - 1) / kBlockLen;
size_t blocks_y = (rows + kBlockLen - 1) / kBlockLen;
Tensor input("input", shape, itype);
Tensor grad("grad", shape, itype);
Tensor output_c("output_c", shape, otype, rowwise, colwise,
opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D);
Tensor output_dbias("output_dbias", {cols}, itype);
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(rows * cols);
std::unique_ptr<OutputType[]> ref_output_t = std::make_unique<OutputType[]>(rows * cols);
std::unique_ptr<float[]> ref_scale_inv = std::make_unique<float[]>(blocks_y * blocks_x);
std::unique_ptr<float[]> ref_scale_inv_t = std::make_unique<float[]>(blocks_y * blocks_x);
if (!rowwise) {
ref_output = nullptr;
ref_scale_inv = nullptr;
}
if (!colwise) {
ref_output_t = nullptr;
ref_scale_inv_t = nullptr;
}
fillCase<EncodingType>(&input, fill_case);
fillUniform(&grad);
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(opts.force_pow_2_scales);
quant_config.set_amax_epsilon(opts.amax_epsilon);
Tensor workspace;
switch (processing_method) {
case ProcessingMethod::CAST_ONLY: {
nvte_quantize_v2(input.data(), output_c.data(), quant_config, nullptr);
break;
}
}
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
ref_quantize<InputType, OutputType>(processing_method, input.rowwise_cpu_dptr<InputType>(),
{rows, cols}, ref_output.get(), ref_scale_inv.get(),
ref_output_t.get(), ref_scale_inv_t.get(), opts);
float atol = 0.0;
float rtol = 0.0;
if (rowwise) {
compareResults("output_c", output_c, ref_output.get(), true, atol, rtol);
compare_scaling_factors("scale_inv", output_c.rowwise_cpu_scale_inv_ptr<float>(),
ref_scale_inv.get(), blocks_y, blocks_x, scale_align_stride(blocks_x),
blocks_x);
}
if (colwise) {
compareResults("output_c_t", output_c, ref_output_t.get(), false, atol, rtol);
compare_scaling_factors("scale_inv_t", output_c.columnwise_cpu_scale_inv_ptr<float>(),
ref_scale_inv_t.get(), blocks_x, blocks_y, scale_align_stride(blocks_y),
blocks_y);
}
}
template <typename InputType, typename OutputType>
void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method,
const std::vector<size_t>& shape, const bool rowwise,
const bool colwise, InputsFillCase fill_case,
const QuantizationOptions& opts) {
using namespace test;
using EncodingType = fp32;
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
const size_t rows = first_dimension(shape);
const size_t cols = last_dimension(shape);
size_t blocks_x = (cols + kBlockLen - 1) / kBlockLen;
size_t blocks_x_t = (rows + kBlockLen - 1) / kBlockLen;
Tensor input("input", shape, itype);
Tensor grad("grad", shape, itype);
Tensor output_c("output_c", shape, otype, rowwise, colwise,
opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D);
Tensor output_dbias("output_dbias", {cols}, itype);
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(rows * cols);
std::unique_ptr<OutputType[]> ref_output_t = std::make_unique<OutputType[]>(rows * cols);
std::unique_ptr<float[]> ref_scale_inv = std::make_unique<float[]>(rows * blocks_x);
std::unique_ptr<float[]> ref_scale_inv_t = std::make_unique<float[]>(cols * blocks_x_t);
if (!rowwise) {
ref_output = nullptr;
ref_scale_inv = nullptr;
}
if (!colwise) {
ref_output_t = nullptr;
ref_scale_inv_t = nullptr;
}
fillCase<EncodingType>(&input, fill_case);
fillUniform(&grad);
Tensor workspace;
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(opts.force_pow_2_scales);
quant_config.set_amax_epsilon(opts.amax_epsilon);
switch (processing_method) {
case ProcessingMethod::CAST_ONLY: {
nvte_quantize_v2(input.data(), output_c.data(), quant_config, nullptr);
break;
}
}
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
ref_quantize_onedimensional_blocks<InputType, OutputType>(
processing_method, input.rowwise_cpu_dptr<InputType>(), {rows, cols}, ref_output.get(),
ref_scale_inv.get(), ref_output_t.get(), ref_scale_inv_t.get(), opts);
float atol = 0.0;
float rtol = 0.0;
if (rowwise) {
compareResults("output_c", output_c, ref_output.get(), true, atol, rtol);
compare_scaling_factors_one_dimensional_blocks("scale_inv",
output_c.rowwise_cpu_scale_inv_ptr<float>(),
ref_scale_inv.get(), rows, blocks_x);
}
if (colwise) {
compareResults("output_c_t", output_c, ref_output_t.get(), false, atol, rtol);
compare_scaling_factors_one_dimensional_blocks("scale_inv_t",
output_c.columnwise_cpu_scale_inv_ptr<float>(),
ref_scale_inv_t.get(), cols, blocks_x_t);
}
}
std::vector<std::vector<size_t>> matrix_sizes = {
{1, 16}, {65, 96}, {256, 256}, {993, 512},
{256, 65536}, {4096, 1632}, {1024, 1},
{16, 512}, {1024}, {8, 32, 1024}, {16, 8, 4, 512},
};
std::vector<InputsFillCase> input_scenarios = {
InputsFillCase::uniform,
};
std::vector<ProcessingMethod> processing_methods = {
ProcessingMethod::CAST_ONLY,
// ProcessingMethod::CAST_DBIAS,
// ProcessingMethod::CAST_DBIAS_DACT,
// ProcessingMethod::CAST_DACT,
// ProcessingMethod::CAST_ACT,
};
// Only the Identity activation is currently enabled.
std::vector<ActivationType> Activation_types = {
ActivationType::Identity,
// ActivationType::GeLU,
// ActivationType::SiLU,
// ActivationType::ReLU,
// ActivationType::QGeLU,
// ActivationType::SReLU,
};
std::vector<float> amax_epsilons = {
0.0f,
1.0f, // Large enough for the effect to be observable.
};
} // namespace
class FusedCastFloat8BlockwiseTestSuite
: public ::testing::TestWithParam<std::tuple<
ProcessingMethod, ActivationType, std::vector<size_t>, transformer_engine::DType,
transformer_engine::DType, InputsFillCase, bool, float, bool>> {};
class FusedCastFloat8VectorwiseTestSuite
: public ::testing::TestWithParam<std::tuple<
ProcessingMethod, ActivationType, std::vector<size_t>, transformer_engine::DType,
transformer_engine::DType, InputsFillCase, bool, float, bool>> {};
#define DACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \
switch (OP_FUNC_TYPE) { \
case ActivationType::Identity: { \
constexpr auto OP = &identity; \
{ \
__VA_ARGS__ \
} \
} break; \
}
#define ACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \
switch (OP_FUNC_TYPE) { \
case ActivationType::Identity: { \
constexpr auto OP = &identity; \
{ \
__VA_ARGS__ \
} \
} break; \
}
TEST_P(FusedCastFloat8BlockwiseTestSuite, TestFusedCastFloat8Blockwise) {
if (getDeviceComputeCapability() < hopperComputeCapability) {
GTEST_SKIP();
}
using namespace transformer_engine;
using namespace test;
const ProcessingMethod processing_method = std::get<0>(GetParam());
const ActivationType Act_type = std::get<1>(GetParam());
const auto matrix_size = std::get<2>(GetParam());
const DType input_type = std::get<3>(GetParam());
const DType output_type = std::get<4>(GetParam());
const InputsFillCase fill_case = std::get<5>(GetParam());
const bool colwise = std::get<6>(GetParam());
const bool rowwise = true;
const float eps = std::get<7>(GetParam());
const bool force_pow_2 = std::get<8>(GetParam());
QuantizationOptions q_opts;
q_opts.force_pow_2_scales = force_pow_2;
q_opts.amax_epsilon = eps;
q_opts.block_scaling_dim = 2u;
if (colwise && matrix_size.size() < 2) {
// test_common Tensor initialization code does not
// handle this case.
GTEST_SKIP();
}
// Skip non-activation tests when the activation type is not Identity.
if ( // (processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS)
(processing_method == ProcessingMethod::CAST_ONLY) && Act_type != ActivationType::Identity) {
GTEST_SKIP();
}
// Skip activation tests when the activation is Identity.
// if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT
// || processing_method == ProcessingMethod::CAST_DACT
// || processing_method == ProcessingMethod::CAST_ACT) && (Act_type == ActivationType::Identity)) {
// GTEST_SKIP();
// }
DACT_FUNC_SWITCH(
Act_type, OP,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(
input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(
output_type, OutputType,
runTestCase<InputType, OutputType>(processing_method, matrix_size, rowwise, colwise,
fill_case, q_opts););););
}
TEST_P(FusedCastFloat8VectorwiseTestSuite, TestFusedCastFloat8Vectorwise) {
if (getDeviceComputeCapability() < hopperComputeCapability) {
GTEST_SKIP();
}
using namespace transformer_engine;
using namespace test;
const ProcessingMethod processing_method = std::get<0>(GetParam());
const ActivationType Act_type = std::get<1>(GetParam());
const auto matrix_size = std::get<2>(GetParam());
const DType input_type = std::get<3>(GetParam());
const DType output_type = std::get<4>(GetParam());
const InputsFillCase fill_case = std::get<5>(GetParam());
const bool colwise = std::get<6>(GetParam());
const bool rowwise = true;
const float eps = std::get<7>(GetParam());
const bool force_pow_2 = std::get<8>(GetParam());
QuantizationOptions q_opts;
q_opts.force_pow_2_scales = force_pow_2;
q_opts.amax_epsilon = eps;
q_opts.block_scaling_dim = 1u;
if (colwise && matrix_size.size() < 2) {
// test_common Tensor initialization code does not
// handle this case.
GTEST_SKIP();
}
// Skip non-activation tests when the activation type is not Identity.
if ( // (processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS)
(processing_method == ProcessingMethod::CAST_ONLY) && Act_type != ActivationType::Identity) {
GTEST_SKIP();
}
// Skip activation tests when the activation is Identity.
// if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT
// || processing_method == ProcessingMethod::CAST_DACT
// || processing_method == ProcessingMethod::CAST_ACT) && (Act_type == ActivationType::Identity)) {
// GTEST_SKIP();
// }
DACT_FUNC_SWITCH(
Act_type, OP,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(
input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(
output_type, OutputType,
runTestCaseOneDimensionalBlocks<InputType, OutputType>(
processing_method, matrix_size, rowwise, colwise, fill_case, q_opts););););
}
std::string to_string(const ProcessingMethod method) {
switch (method) {
case ProcessingMethod::CAST_ONLY:
return "CAST_ONLY";
// case ProcessingMethod::CAST_DBIAS: return "CAST_DBIAS";
// case ProcessingMethod::CAST_DBIAS_DACT: return "CAST_DBIAS_DACT";
// case ProcessingMethod::CAST_DACT: return "CAST_DACT";
// case ProcessingMethod::CAST_ACT: return "CAST_ACT";
default:
return "";
}
}
std::string to_string(const ActivationType Act_type) {
switch (Act_type) {
case ActivationType::Identity:
return "Identity";
// case ActivationType::GeLU: return "GeLU";
// case ActivationType::SiLU: return "SiLU";
// case ActivationType::ReLU: return "ReLU";
// case ActivationType::QGeLU: return "QGeLU";
// case ActivationType::SReLU: return "SReLU";
default:
return "";
}
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest, FusedCastFloat8BlockwiseTestSuite,
::testing::Combine(::testing::ValuesIn(processing_methods),
::testing::ValuesIn(Activation_types), ::testing::ValuesIn(matrix_sizes),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(input_scenarios), ::testing::Values(true, false),
::testing::ValuesIn(amax_epsilons), ::testing::Values(true, false)),
[](const testing::TestParamInfo<FusedCastFloat8BlockwiseTestSuite::ParamType>& info) {
std::string name =
to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param));
const auto& shape = std::get<2>(info.param);
for (const auto& s : shape) {
name += "X" + std::to_string(s);
}
name += "X" + test::typeName(std::get<3>(info.param)) + "X" +
test::typeName(std::get<4>(info.param)) + "X" +
test::caseName(std::get<5>(info.param)) + "X" +
std::to_string(std::get<6>(info.param)) + "X" +
std::to_string(std::get<7>(info.param) != 0.0f) + "X" +
std::to_string(std::get<8>(info.param));
return name;
});
INSTANTIATE_TEST_SUITE_P(
OperatorTest, FusedCastFloat8VectorwiseTestSuite,
::testing::Combine(::testing::ValuesIn(processing_methods),
::testing::ValuesIn(Activation_types), ::testing::ValuesIn(matrix_sizes),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(input_scenarios), ::testing::Values(true, false),
::testing::ValuesIn(amax_epsilons), ::testing::Values(true, false)),
[](const testing::TestParamInfo<FusedCastFloat8VectorwiseTestSuite::ParamType>& info) {
std::string name =
to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param));
const auto& shape = std::get<2>(info.param);
for (const auto& s : shape) {
name += "X" + std::to_string(s);
}
name += "X" + test::typeName(std::get<3>(info.param)) + "X" +
test::typeName(std::get<4>(info.param)) + "X" +
test::caseName(std::get<5>(info.param)) + "X" +
std::to_string(std::get<6>(info.param)) + "X" +
std::to_string(std::get<7>(info.param) != 0.0f) + "X" +
std::to_string(std::get<8>(info.param));
return name;
});
@@ -19,169 +19,16 @@
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"
#include "test_normalization.h"
using namespace transformer_engine;
using namespace test;
namespace {
enum NormType {
LayerNorm,
RMSNorm
};
std::map<NormType, std::string> normToString = {
{NormType::LayerNorm, "LayerNorm"},
{NormType::RMSNorm, "RmsNorm"}
};
template <typename InputType>
void compute_ref_stats(NormType norm_type,
const InputType *data, float *mu, float *rsigma,
const size_t N, const size_t H, const double epsilon){
using compute_t = float;
compute_t current, m;
for (size_t i = 0; i < N; ++i) {
compute_t sum = 0;
for (size_t j = 0; j < H; ++j) {
sum += static_cast<compute_t>(data[i * H + j]);
}
if (norm_type == LayerNorm){
mu[i] = sum / H;
m = mu[i];
} else { m = 0;}
compute_t sum_sq = 0;
for (size_t j = 0; j < H; ++j) {
current = static_cast<compute_t>(data[i * H + j]);
sum_sq += (current - m) * (current - m);
}
#ifdef __HIP_PLATFORM_AMD__
rsigma[i] = 1.0/sqrtf((sum_sq / H) + epsilon);
#else
rsigma[i] = rsqrtf((sum_sq / H) + epsilon);
#endif
}
}
// For now, cuDNN does static_cast<compute_t>(gamma + static_cast<input_t>(1.0)).
// This will be changed in a future release.
template <typename InputType>
inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn){
using compute_t = float;
if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3>){
compute_t g = static_cast<compute_t>(gamma);
if (zero_centered_gamma) {
g += static_cast<compute_t>(1.f);
}
return g;
} else {
if (use_cudnn){
compute_t g = static_cast<compute_t>(0.f);
#ifndef __HIP_PLATFORM_AMD__
InputType gi = gamma;
if (zero_centered_gamma) {
gi = gi + static_cast<InputType>(1.f);
}
g = static_cast<compute_t>(gi);
#else
if (zero_centered_gamma) {
g += static_cast<compute_t>(1.f);
}
#endif
return g;
} else {
compute_t g = static_cast<compute_t>(gamma);
if (zero_centered_gamma) {
g += static_cast<compute_t>(1.f);
}
return g;
}
}
}
template <typename InputType, typename OutputType>
void compute_ref_output(NormType norm_type,
const InputType *data, const InputType *gamma, const InputType *beta,
OutputType* output,
const float *mu, const float *rsigma,
const size_t N, const size_t H,
float *amax, float scale, const bool zero_centered_gamma, const bool use_cudnn) {
using compute_t = float;
compute_t current_max = -1e100;
for (size_t i = 0; i < N; ++i) {
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn);
compute_t tmp;
if (norm_type == LayerNorm) {
tmp = (current - mu[i]) * rsigma[i] * g + static_cast<compute_t>(beta[j]);
} else { // RMSNorm
tmp = current * rsigma[i] * g;
}
output[i * H + j] = static_cast<OutputType>(tmp * scale);
current_max = fmaxf(current_max, fabsf(tmp));
}
}
*amax = current_max;
}
template <typename InputType, typename OutputType>
void compute_ref_backward(const NormType norm_type, const OutputType *output_grad, const InputType *data,
const float *mu, const float *rsigma,
const InputType *gamma,
InputType *data_grad,
InputType *gamma_grad, InputType *beta_grad,
const size_t N, const size_t H,
const bool zero_centered_gamma, const bool use_cudnn) {
using compute_t = float;
std::vector<compute_t> dgamma(H, 0.f);
std::vector<compute_t> dbeta(H, 0.f);
for (size_t i = 0 ; i < N; ++i) {
// Reductions
auto local_mu = (norm_type == LayerNorm) ? mu[i] : 0.;
compute_t mdy = 0, mdyy = 0;
for (size_t j = 0; j < H; ++j) {
const compute_t x = static_cast<compute_t>(data[i * H + j]);
const compute_t y = (x - local_mu) * rsigma[i];
compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn);
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz;
dgamma[j] += y * dz;
if (norm_type == LayerNorm) {
dbeta[j] += dz;
mdy += dy;
}
mdyy += dy * y;
}
mdy /= H;
mdyy /= H;
// Input grads
for (size_t j = 0; j < H; ++j) {
const compute_t x = static_cast<compute_t>(data[i * H + j]);
const compute_t y = (x - local_mu) * rsigma[i];
compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn);
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz;
const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy);
data_grad[i * H + j] = static_cast<InputType>(dx);
}
}
// Weight grads
for (size_t j = 0; j < H; ++j) gamma_grad[j] = static_cast<InputType>(dgamma[j]);
if (norm_type == LayerNorm) for (size_t j = 0; j < H; ++j) beta_grad[j] = static_cast<InputType>(dbeta[j]);
}
template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
NormType norm_type, bool use_cudnn, const bool zero_centered_gamma_in_weight_dtype) {
if (sizeof(InputType) < sizeof(OutputType)) {
GTEST_SKIP() << "LN kernel does not support OutputType > InputType";
return;
@@ -230,9 +77,22 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
if ((!use_cudnn || !zero_centered_gamma) && zero_centered_gamma_in_weight_dtype) {
// Skip duplicate cases: zero_centered_gamma_in_weight_dtype only changes the
// implementation for the cuDNN backend with zero-centered gamma.
GTEST_SKIP() << "Zero-centered gamma in weight dtype is only supported with cuDNN backend";
}
if (use_cudnn){
nvte_enable_cudnn_norm_fwd(true);
nvte_enable_cudnn_norm_bwd(true);
// Zero-centered gamma in weight dtype is currently only supported by the cuDNN backend.
if (zero_centered_gamma_in_weight_dtype) {
nvte_enable_zero_centered_gamma_in_weight_dtype(true);
} else {
nvte_enable_zero_centered_gamma_in_weight_dtype(false);
}
}
// Forward kernel
@@ -280,6 +140,11 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
if (use_cudnn){
nvte_enable_cudnn_norm_fwd(false);
nvte_enable_cudnn_norm_bwd(false);
// Restore the default; zero-centered gamma in weight dtype is currently only
// supported by the cuDNN backend.
if (zero_centered_gamma_in_weight_dtype) {
nvte_enable_zero_centered_gamma_in_weight_dtype(false);
}
}
// Reference implementations
@@ -300,14 +165,16 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
&ref_amax,
ref_scale,
zero_centered_gamma,
use_cudnn,
zero_centered_gamma_in_weight_dtype);
compute_ref_backward(norm_type, dz.rowwise_cpu_dptr<WeightType>(),
input.rowwise_cpu_dptr<InputType>(),
mu.rowwise_cpu_dptr<float>(), rsigma.rowwise_cpu_dptr<float>(),
gamma.rowwise_cpu_dptr<WeightType>(),
ref_dx.get(), ref_dgamma.get(), ref_dbeta.get(),
N, H, zero_centered_gamma,
use_cudnn,
zero_centered_gamma_in_weight_dtype);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
@@ -352,6 +219,7 @@ NormType,
transformer_engine::DType,
transformer_engine::DType,
std::pair<size_t, size_t>,
bool,
bool>> {};
TEST_P(NormTestSuite, TestNorm) {
@@ -364,10 +232,11 @@ TEST_P(NormTestSuite, TestNorm) {
const DType output_type = std::get<3>(GetParam());
const auto size = std::get<4>(GetParam());
const bool zero_centered_gamma = std::get<5>(GetParam());
const bool cudnn_zero_centered_gamma_in_weight_dtype = std::get<6>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma, norm_type, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype);
);
);
}
@@ -381,6 +250,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3),
::testing::ValuesIn(test_cases),
::testing::Values(false, true),
::testing::Values(false, true)),
[](const testing::TestParamInfo<NormTestSuite::ParamType>& info) {
auto backend = std::get<0>(info.param) == false ? "Te" : "Cudnn";
@@ -391,6 +261,7 @@ INSTANTIATE_TEST_SUITE_P(
test::typeName(std::get<3>(info.param)) + "X" +
std::to_string(std::get<4>(info.param).first) + "X" +
std::to_string(std::get<4>(info.param).second) + "X" +
std::to_string(std::get<5>(info.param)) + "X" +
std::to_string(std::get<6>(info.param));
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#pragma once
#include <cmath>
#include <cstring>
#include <limits>
#include <map>
#include <memory>
#include <vector>
#include <iomanip>
#include <iostream>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"
namespace test {
namespace {
enum NormType {
LayerNorm,
RMSNorm
};
std::map<NormType, std::string> normToString = {
{NormType::LayerNorm, "LayerNorm"},
{NormType::RMSNorm, "RmsNorm"}
};
template <typename InputType>
void compute_ref_stats(NormType norm_type,
const InputType *data, float *mu, float *rsigma,
const size_t N, const size_t H, const double epsilon){
using compute_t = float;
compute_t current, m;
for (size_t i = 0; i < N; ++i) {
compute_t sum = 0;
for (size_t j = 0; j < H; ++j) {
sum += static_cast<compute_t>(data[i * H + j]);
}
if (norm_type == LayerNorm){
mu[i] = sum / H;
m = mu[i];
} else { m = 0;}
compute_t sum_sq = 0;
for (size_t j = 0; j < H; ++j) {
current = static_cast<compute_t>(data[i * H + j]);
sum_sq += (current - m) * (current - m);
}
#ifdef __HIP_PLATFORM_AMD__
rsigma[i] = 1.0/sqrtf((sum_sq / H) + epsilon);
#else
rsigma[i] = rsqrtf((sum_sq / H) + epsilon);
#endif
}
}
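// In formulas, for row i of an N x H input x (a restatement of the loops above):
//   mu[i]     = (1/H) * sum_j x[i][j]                        (LayerNorm only)
//   rsigma[i] = 1 / sqrt((1/H) * sum_j (x[i][j] - m)^2 + epsilon)
// with m = mu[i] for LayerNorm and m = 0 for RMSNorm.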
template <typename InputType>
inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype) {
using compute_t = float;
// Zero-centered gamma in weight dtype is only supported in CuDNN backend currently
// Remove the use_cudnn check here when it is supported by both backends.
const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype;
if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3>){
compute_t g = static_cast<compute_t>(gamma);
if (zero_centered_gamma) {
g += static_cast<compute_t>(1.f);
}
return g;
} else {
if (zero_centered_gamma_in_weight_dtype){
compute_t g = static_cast<compute_t>(0.f);
#ifndef __HIP_PLATFORM_AMD__
InputType gi = gamma;
if (zero_centered_gamma) {
gi = gi + static_cast<InputType>(1.f);
}
g = static_cast<compute_t>(gi);
#else
if (zero_centered_gamma) {
g += static_cast<compute_t>(1.f);
}
#endif
return g;
} else {
compute_t g = static_cast<compute_t>(gamma);
if (zero_centered_gamma) {
g += static_cast<compute_t>(1.f);
}
return g;
}
}
}
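// Illustrative example of the difference (assuming bf16 weights): for
// gamma = 2^-8 = 0.00390625, which bf16 represents exactly, the weight-dtype
// path computes bf16(gamma + 1); 1 + 2^-8 lies halfway between the adjacent
// bf16 values 1 and 1 + 2^-7, so round-to-nearest-even snaps it to exactly
// 1.0f, while the fp32 path keeps 1.00390625f. The reference must mirror
// whichever rounding the backend under test applies.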
template <typename InputType, typename OutputType>
void compute_ref_output(NormType norm_type,
const InputType *data, const InputType *gamma, const InputType *beta,
OutputType* output,
const float *mu, const float *rsigma,
const size_t N, const size_t H,
float *amax, float scale, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype) {
using compute_t = float;
compute_t current_max = -std::numeric_limits<compute_t>::infinity();
for (size_t i = 0; i < N; ++i) {
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype);
compute_t tmp;
if (norm_type == LayerNorm) {
tmp = (current - mu[i]) * rsigma[i] * g + static_cast<compute_t>(beta[j]);
} else { // RMSNorm
tmp = current * rsigma[i] * g;
}
output[i * H + j] = static_cast<OutputType>(tmp * scale);
current_max = fmaxf(current_max, fabsf(tmp));
}
}
if (amax) {
*amax = current_max;
}
}
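// In formulas (a restatement of the loop above): output[i][j] is
// cast<OutputType>(t * scale) with
//   t = (x[i][j] - mu[i]) * rsigma[i] * g[j] + beta[j]   (LayerNorm)
//   t = x[i][j] * rsigma[i] * g[j]                       (RMSNorm)
// and amax, when requested, is max |t| taken before the scale multiply.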
template <typename InputType, typename OutputType>
void compute_ref_backward(const NormType norm_type, const OutputType *output_grad, const InputType *data,
const float *mu, const float *rsigma,
const InputType *gamma,
InputType *data_grad,
InputType *gamma_grad, InputType *beta_grad,
const size_t N, const size_t H,
const bool zero_centered_gamma, const bool use_cudnn,
const bool cudnn_zero_centered_gamma_in_weight_dtype) {
using compute_t = float;
std::vector<compute_t> dgamma(H, 0.f);
std::vector<compute_t> dbeta(H, 0.f);
for (size_t i = 0 ; i < N; ++i) {
// Reductions
auto local_mu = (norm_type == LayerNorm) ? mu[i] : 0.;
compute_t mdy = 0, mdyy = 0;
for (size_t j = 0; j < H; ++j) {
const compute_t x = static_cast<compute_t>(data[i * H + j]);
const compute_t y = (x - local_mu) * rsigma[i];
compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype);
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz;
dgamma[j] += y * dz;
if (norm_type == LayerNorm) {
dbeta[j] += dz;
mdy += dy;
}
mdyy += dy * y;
}
mdy /= H;
mdyy /= H;
// Input grads
for (size_t j = 0; j < H; ++j) {
const compute_t x = static_cast<compute_t>(data[i * H + j]);
const compute_t y = (x - local_mu) * rsigma[i];
compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype);
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz;
const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy);
data_grad[i * H + j] = static_cast<InputType>(dx);
}
}
// Weight grads
for (size_t j = 0; j < H; ++j) gamma_grad[j] = static_cast<InputType>(dgamma[j]);
if (norm_type == LayerNorm) for (size_t j = 0; j < H; ++j) beta_grad[j] = static_cast<InputType>(dbeta[j]);
}
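// In formulas, with y[i][j] = (x[i][j] - mu[i]) * rsigma[i] and dy = g * dz
// (a restatement of the loops above):
//   dgamma[j] = sum_i y[i][j] * dz[i][j]
//   dbeta[j]  = sum_i dz[i][j]                            (LayerNorm only)
//   dx[i][j]  = rsigma[i] * (dy[i][j] - mean_j(dy * y) * y[i][j] - mean_j(dy))
// where the mean_j(dy) term is only subtracted for LayerNorm (mdy stays 0
// for RMSNorm).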
} // namespace
} // namespace test
@@ -19,6 +19,7 @@
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"
#include "test_normalization.h"
using namespace transformer_engine;
using namespace test;
@@ -27,16 +28,6 @@ namespace {
using fp8e8m0 = byte;
enum NormType {
LayerNorm,
RMSNorm
};
std::map<NormType, std::string> normToString = {
{NormType::LayerNorm, "LayerNorm"},
{NormType::RMSNorm, "RMSNorm"}
};
template <typename InputType, typename ScaleType, typename OutputType>
void dequantize_1x_kernel(InputType* input_ptr, ScaleType* scale_ptr, OutputType* output_ptr,
size_t rows, size_t cols, size_t scaling_mode_x, size_t scaling_mode_y){
@@ -110,69 +101,8 @@ void dequantize_2x(Tensor& input, Tensor& output, bool is_training)
32, 1);
}
template <typename InputType>
void compute_ref_stats(NormType norm_type,
const InputType *data, float *mu, float *rsigma,
const size_t N, const size_t H, const double epsilon){
using compute_t = float;
#pragma omp parallel for proc_bind(spread)
for (size_t i = 0; i < N; ++i) {
compute_t sum = 0;
for (size_t j = 0; j < H; ++j) {
sum += static_cast<compute_t>(data[i * H + j]);
}
compute_t m;
if (norm_type == LayerNorm){
mu[i] = sum / H;
m = mu[i];
} else { m = 0;}
compute_t sum_sq = 0;
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
sum_sq += (current - m) * (current - m);
}
#ifdef __HIP_PLATFORM_AMD__
rsigma[i] = 1.0/sqrtf((sum_sq / H) + epsilon);
#else
rsigma[i] = rsqrtf((sum_sq / H) + epsilon);
#endif
}
}
template <typename InputType, typename OutputType>
void compute_ref_output(NormType norm_type,
const InputType *data, const InputType *gamma, const InputType *beta,
const float *mu, const float *rsigma,
const size_t N, const size_t H,
OutputType* output,
const bool zero_centered_gamma){
using compute_t = float;
#pragma omp parallel for proc_bind(spread)
for (size_t i = 0; i < N; ++i) {
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
compute_t g = static_cast<compute_t>(gamma[j]);
if (zero_centered_gamma) {
g += 1.0;
}
compute_t tmp;
if (norm_type == LayerNorm) {
tmp = (current - mu[i]) * rsigma[i] * g + static_cast<compute_t>(beta[j]);
} else { // RMSNorm
tmp = current * rsigma[i] * g;
}
output[i * H + j] = tmp;
}
}
}
template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, NormType norm_type, bool is_training, const bool zero_centered_gamma_in_weight_dtype) {
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
@@ -199,6 +129,12 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
fillUniform(&gamma);
fillUniform(&beta);
if (zero_centered_gamma_in_weight_dtype) {
nvte_enable_zero_centered_gamma_in_weight_dtype(true);
} else {
nvte_enable_zero_centered_gamma_in_weight_dtype(false);
}
// Forward kernel
float epsilon = 1e-5;
if (norm_type == NormType::LayerNorm){
@@ -224,6 +160,10 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
0);
}
if (zero_centered_gamma_in_weight_dtype) {
nvte_enable_zero_centered_gamma_in_weight_dtype(false);
}
Tensor dequantized_output("dequantized_output", { N, H }, DType::kFloat32, true, true);
dequantize_2x<OutputType, fp8e8m0>(z, dequantized_output, is_training);
@@ -250,11 +190,15 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
compute_ref_output(norm_type, input.rowwise_cpu_dptr<InputType>(),
gamma.rowwise_cpu_dptr<WeightType>(),
beta.rowwise_cpu_dptr<WeightType>(),
ref_output.get(),
ref_mu_ptr,
ref_rsigma_ptr,
N, H,
nullptr, // amax
1.f, // scale
zero_centered_gamma,
true, // cuDNN is currently the only MXFP8 backend
zero_centered_gamma_in_weight_dtype);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
@@ -302,7 +246,7 @@ class MxNormTestSuite : public ::testing::TestWithParam< std::tuple<NormType,
transformer_engine::DType,
transformer_engine::DType,
std::pair<size_t, size_t>,
bool, bool, bool>> {};
TEST_P(MxNormTestSuite, TestMxNorm) {
using namespace transformer_engine;
@@ -314,10 +258,11 @@ TEST_P(MxNormTestSuite, TestMxNorm) {
const auto size = std::get<3>(GetParam());
const bool zero_centered_gamma = std::get<4>(GetParam());
const bool is_training = std::get<5>(GetParam());
const bool zero_centered_gamma_in_weight_dtype = std::get<6>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma, norm_type, is_training, zero_centered_gamma_in_weight_dtype);
);
);
}
@@ -331,6 +276,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(DType::kFloat8E5M2, DType::kFloat8E4M3),
::testing::ValuesIn(test_cases),
::testing::Values(true, false),
::testing::Values(true, false),
::testing::Values(true, false)),
[](const testing::TestParamInfo<MxNormTestSuite::ParamType>& info) {
std::string name = normToString.at(std::get<0>(info.param)) + "_" +
@@ -339,6 +285,7 @@ INSTANTIATE_TEST_SUITE_P(
std::to_string(std::get<3>(info.param).first) + "X" +
std::to_string(std::get<3>(info.param).second) + "X" +
std::to_string(std::get<4>(info.param)) + "out" +
std::to_string(int(std::get<5>(info.param)) + 1) + "x" +
std::to_string(std::get<6>(info.param));
return name;
});
@@ -10,6 +10,7 @@
#include <algorithm>
#include <memory>
#include <random>
#include <iostream>
#include <cassert>
#include <cmath>
#include <string>
@@ -111,8 +112,8 @@ struct scale_inv_meta {
size_t type_size;
};
NVTEShape convertShape(const std::vector<size_t>& s) {
return nvte_make_shape(s.data(), s.size());
}
std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
@@ -134,27 +135,19 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
scale_inv_meta ret_rowwise, ret_colwise;
auto block_alignment = std::vector<size_t>{128ul, 4ul};
{
auto alignment = block_alignment[0];
auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(1)), alignment) * alignment;
alignment = block_alignment[1];
auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(32)), alignment) * alignment;
ret_rowwise.shape = {scale_dim_0, scale_dim_1};
}
{
auto alignment = block_alignment[1];
auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(32)), alignment) * alignment;
alignment = block_alignment[0];
auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(1)), alignment) * alignment;
ret_colwise.shape = {scale_dim_0, scale_dim_1};
}
ret_rowwise.type = DType::kFloat8E8M0;
@@ -164,6 +157,58 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
return {ret_rowwise, ret_colwise};
}
if (scaling_mode == NVTE_BLOCK_SCALING_2D) {
std::vector<size_t> shape_vec;
for (size_t i = 0; i < shape.ndim; ++i) {
shape_vec.push_back(shape.data[i]);
}
size_t first_dim = first_dimension(shape_vec);
size_t last_dim = last_dimension(shape_vec);
scale_inv_meta ret_rowwise, ret_colwise;
{
auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(128));
auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(128)), 4) * 4;
ret_rowwise.shape = {scale_dim_0, scale_dim_1};
}
{
auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(128));
auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast<size_t>(128)), 4) * 4;
ret_colwise.shape = {scale_dim_0, scale_dim_1};
}
ret_rowwise.type = DType::kFloat32;
ret_colwise.type = DType::kFloat32;
ret_rowwise.type_size = sizeof(float);
ret_colwise.type_size = sizeof(float);
return {ret_rowwise, ret_colwise};
}
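// Worked example (sketch): a 256 x 300 tensor under NVTE_BLOCK_SCALING_2D
// gives rowwise scales of shape {DIVUP(256,128), DIVUP(DIVUP(300,128),4)*4}
// = {2, 4} and columnwise scales of {DIVUP(300,128), DIVUP(DIVUP(256,128),4)*4}
// = {3, 4}; the inner dimension is padded to a multiple of four fp32 values.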
if (scaling_mode == NVTE_BLOCK_SCALING_1D) {
std::vector<size_t> shape_vec;
for (size_t i = 0; i < shape.ndim; ++i) {
shape_vec.push_back(shape.data[i]);
}
size_t first_dim = first_dimension(shape_vec);
size_t last_dim = last_dimension(shape_vec);
scale_inv_meta ret_rowwise, ret_colwise;
{
auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(128));
auto scale_dim_1 = DIVUP(first_dim, 4) * 4;
ret_rowwise.shape = {scale_dim_0, scale_dim_1};
}
{
auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(128));
auto scale_dim_1 = DIVUP(last_dim, 4) * 4;
ret_colwise.shape = {scale_dim_0, scale_dim_1};
}
ret_rowwise.type = DType::kFloat32;
ret_colwise.type = DType::kFloat32;
ret_rowwise.type_size = sizeof(float);
ret_colwise.type_size = sizeof(float);
return {ret_rowwise, ret_colwise};
}
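// Worked example (sketch): the same 256 x 300 tensor under
// NVTE_BLOCK_SCALING_1D gives rowwise scales of {DIVUP(300,128),
// DIVUP(256,4)*4} = {3, 256} and columnwise scales of {DIVUP(256,128),
// DIVUP(300,4)*4} = {2, 300}: one scale per row for each 128-element block,
// stored with the block index outermost.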
NVTE_ERROR("Invalid scaling mode!");
}
@@ -195,10 +240,10 @@ Tensor::Tensor(const std::string& name,
std::vector<size_t> normalized_shape_v = {product(shape, 0, shape.ndim - 1),
shape.data[shape.ndim - 1]};
NVTEShape normalized_shape = convertShape(normalized_shape_v);
NVTEShape columnwise_shape = {};
std::vector<size_t> columnwise_shape_vec;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) {
// Transpose when tensor scaling
columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]);
for (size_t i = 0; i < shape.ndim - 1; ++i) {
@@ -212,8 +257,7 @@ Tensor::Tensor(const std::string& name,
}
if (columnwise) {
columnwise_shape = nvte_make_shape(columnwise_shape_vec.data(), columnwise_shape_vec.size());
}
tensor_ = TensorWrapper(scaling_mode);
@@ -259,25 +303,27 @@ Tensor::Tensor(const std::string& name,
std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
}
} else {
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(normalized_shape, tensor_.scaling_mode());
auto rowwise_scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
auto columnwise_scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size;
auto scale_shape = rowwise_scale_meta.shape;
auto columnwise_scale_shape = colwise_scale_meta.shape;
if (rowwise) {
cudaMalloc((void**)&rowwise_scale_inv, rowwise_scale_size); // NOLINT(*)
cudaMemset(rowwise_scale_inv, 0, rowwise_scale_size);
rowwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(rowwise_scale_size);
std::fill_n(rowwise_scale_inv_cpu_data_.get(), rowwise_scale_size, 0);
auto scale_dtype = rowwise_scale_meta.type;
tensor_.set_rowwise_scale_inv(rowwise_scale_inv, scale_dtype, scale_shape);
}
if (columnwise) {
cudaMalloc((void**)&columnwise_scale_inv, columnwise_scale_size); // NOLINT(*)
cudaMemset(columnwise_scale_inv, 0, columnwise_scale_size);
columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(columnwise_scale_size);
std::fill_n(columnwise_scale_inv_cpu_data_.get(), columnwise_scale_size, 0);
auto scale_dtype = colwise_scale_meta.type;
tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape);
}
}
}
@@ -311,7 +357,8 @@ void Tensor::to_cpu() const {
sizeof(float),
cudaMemcpyDeviceToHost);
}
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(s, tensor_.scaling_mode());
if (rowwise_) {
auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
cudaMemcpy(rowwise_scale_inv_cpu_data_.get(),
@@ -349,7 +396,8 @@ void Tensor::from_cpu() const {
cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float),
cudaMemcpyHostToDevice);
}
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(s, tensor_.scaling_mode());
if (rowwise_) {
auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr,
...
@@ -383,27 +431,29 @@ void Tensor::set_scale_inv(float scale_inv) {
  if (columnwise_) {
    NVTE_CHECK(columnwise_scale_inv_cpu_data_);
  }
  auto [rowwise_scale_meta, colwise_scale_meta] =
      get_scales(tensor_.shape(), tensor_.scaling_mode());
  if (rowwise_) {
    auto num_scales = product(rowwise_scale_meta.shape);
    if (num_scales == 1) {
      rowwise_cpu_scale_inv_ptr<float>()[0] = scale_inv;
    } else {
      std::uniform_int_distribution<uint8_t> dis(0, 127);
      auto *scale_inv_ptr = rowwise_cpu_scale_inv_ptr<uint8_t>();
      for (size_t i = 0; i < num_scales; i++) {
        scale_inv_ptr[i] = dis(gen_);
      }
    }
  }
  if (columnwise_) {
    auto num_scales = product(colwise_scale_meta.shape);
    if (num_scales == 1) {
      columnwise_cpu_scale_inv_ptr<float>()[0] = scale_inv;
    } else {
      std::uniform_int_distribution<uint8_t> dis(0, 127);
      auto *scale_inv_ptr = columnwise_cpu_scale_inv_ptr<uint8_t>();
      for (size_t i = 0; i < num_scales; i++) {
        scale_inv_ptr[i] = dis(gen_);
      }
    }
...
@@ -413,23 +463,20 @@ void Tensor::set_scale_inv(float scale_inv) {
}
void Tensor::shareFP8Meta(const Tensor &other) {
  if (isFp8Type(dtype()) && isFp8Type(other.dtype())) {
    auto new_tensor = TensorWrapper(other.tensor_.scaling_mode());
    auto my_rowwise_data = tensor_.get_rowwise_data();
    new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast<DType>(my_rowwise_data.dtype),
                                my_rowwise_data.shape);
    auto my_columnwise_data = tensor_.get_columnwise_data();
    new_tensor.set_columnwise_data(my_columnwise_data.data_ptr,
                                   static_cast<DType>(my_columnwise_data.dtype),
                                   my_columnwise_data.shape);
    auto other_amax = other.tensor_.get_amax();
    new_tensor.set_amax(other_amax.data_ptr, static_cast<DType>(other_amax.dtype),
                        other_amax.shape);
    auto other_scale = other.tensor_.get_scale();
    new_tensor.set_scale(other_scale.data_ptr, static_cast<DType>(other_scale.dtype),
                         other_scale.shape);
    auto other_row_scale_inv = other.tensor_.get_rowwise_scale_inv();
    new_tensor.set_rowwise_scale_inv(other_row_scale_inv.data_ptr,
...
@@ -460,9 +507,7 @@ std::string to_string(const std::vector<T> &v) {
std::vector<size_t> unravel(const size_t i, const NVTEShape &shape) {
  std::vector<size_t> ret;
  size_t current_i = i;
  for (size_t current = shape.ndim - 1; current > 0; --current) {
    ret.push_back(current_i % shape.data[current]);
    current_i /= shape.data[current];
  }
...
@@ -812,8 +857,7 @@ bool isFp8Type(DType type) {
  return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0;
}

int32_t getDeviceComputeCapability() {
  cudaDeviceProp deviceProp;
  cudaGetDeviceProperties(&deviceProp, 0);
  return 10 * deviceProp.major + deviceProp.minor;
...
...
@@ -121,7 +121,7 @@ class Tensor {
         const bool rowwise = true,
         const bool columnwise = false,
         const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) :
    Tensor(name, nvte_make_shape(shape.data(), shape.size()), type, rowwise, columnwise, mode) {}

  Tensor() {}
...
@@ -148,25 +148,19 @@ class Tensor {
    if (scale_inv != nullptr) {
      cudaFree(scale_inv);
    }
    if (columnwise_data_ptr != nullptr) {
      cudaFree(columnwise_data_ptr);
    }
    if (columnwise_scale_inv != nullptr) {
      cudaFree(columnwise_scale_inv);
    }
  }

  NVTETensor data() const noexcept { return tensor_.data(); }

  NVTEShape rowwise_shape() const noexcept { return tensor_.get_rowwise_data().shape; }

  NVTEShape columnwise_shape() const noexcept { return tensor_.get_columnwise_data().shape; }

  NVTEShape rowwise_scale_inv_shape() const {
    NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!");
...
@@ -233,6 +227,8 @@ class Tensor {
  T *rowwise_cpu_scale_inv_ptr() {
    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
      NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
    } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D ||
               tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
      NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
    } else {
      NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
    }
...
@@ -244,6 +240,8 @@ class Tensor {
  T *columnwise_cpu_scale_inv_ptr() {
    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
      NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
    } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D ||
               tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
      NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
    } else {
      NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
    }
...
@@ -475,6 +473,7 @@ extern std::vector<DType> all_fp_types;
bool isFp8Type(DType type);

int32_t getDeviceComputeCapability();
constexpr int32_t hopperComputeCapability = 90;
constexpr int32_t blackwellComputeCapability = 100;

}  // namespace test
...
...
@@ -25,3 +25,5 @@ filterwarnings=
    ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning
    ignore:The host_callback APIs are deprecated .*:DeprecationWarning
    ignore:Scan loop is disabled for fused ring attention.*:UserWarning
    ignore:jax.extend.ffi.register_ffi_target is deprecated
    ignore:jax.extend.ffi.ffi_lowering is deprecated
...
@@ -29,7 +29,7 @@ from transformer_engine.jax.quantize import (
    ScaledTensor,
    ScalingMode,
    QuantizerFactory,
    QuantizeLayout,
)
from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation
...
@@ -48,21 +48,21 @@ FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES = [(256, 128), (128, 256)]
DTYPES = [jnp.bfloat16, jnp.float32]

is_fp8_supported, reason = helper.is_fp8_available()
is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

supported_scaling_modes = []
"""Find supported scaling modes"""
if is_fp8_supported:
    supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING)
if is_mxfp8_supported:
    supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING)


def is_shape_supported_by_mxfp8(input_shape):
    try:
        if isinstance(input_shape, type(pytest.param(0))):
            input_shape = input_shape.values[0]
        ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x(input_shape)
        return True
    except:
        # get_scale_shapes will raise an exception if the shape is not supported
...
@@ -82,8 +82,9 @@ def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor):
def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray):
    if isinstance(a, ScaledTensor1x):
        if a.data_layout == "T":
            flatten_axis = a.data.ndim - a.flatten_axis
            b_transpose = jnp.transpose(b, (*range(flatten_axis, b.ndim), *range(flatten_axis)))
            assert_allclose(a.dequantize(), b_transpose, dtype=a.data.dtype)
        else:
            assert_allclose(a.dequantize(), b, dtype=a.data.dtype)
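The reference transpose above can be checked in isolation. A minimal sketch in plain jax.numpy, assuming flatten_axis has already been normalized to a positive index: the trailing dims from flatten_axis onward are rotated to the front, which is exactly the layout a colwise ("T") tensor stores.

import jax.numpy as jnp

b = jnp.zeros((2, 3, 4, 5))
flatten_axis = 2  # hypothetical normalized value for flatten_axis=-2
b_t = jnp.transpose(b, (*range(flatten_axis, b.ndim), *range(flatten_axis)))
assert b_t.shape == (4, 5, 2, 3)  # axes (2, 3) moved in front of (0, 1)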
...
@@ -141,7 +142,8 @@ class TestActivation:
    def test_act_grad(self, shape, activation_type):
        key = jax.random.PRNGKey(0)
        x = jax.random.uniform(key, shape, jnp.float32)
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)

        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
...
@@ -159,7 +161,8 @@ class TestActivation:
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, output_type):
        x = random_inputs
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        value_n_grad_primitive_func = jit(
...
@@ -167,9 +170,9 @@ class TestActivation:
        )

        quantizer = QuantizerFactory.create(
            scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
            q_dtype=output_type,
            q_layout=QuantizeLayout.ROWWISE,
        )

        prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer)
...
@@ -182,19 +185,22 @@ class TestActivation:
    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_act_forward_with_delayed_scaling_fp8(
        self, random_inputs, activation_type, output_type, q_layout
    ):
        x = random_inputs
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2,
            scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
            q_dtype=output_type,
            q_layout=q_layout,
        )

        te_output = tex.act_lu(x, activation_type, te_quantizer)
...
@@ -203,19 +209,21 @@ class TestActivation:
        assert_bitwise_scaled_tensors(te_output, jax_output)

    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
    @pytest_parametrize_wrapper("shape", [(2, 64, 1, 256)])
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_act_forward_with_block_scaling_fp8(
        self, random_inputs, activation_type, output_type, q_layout
    ):
        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        quantizer = QuantizerFactory.create(
            scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
        )

        output = tex.act_lu(x, activation_type, quantizer)
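The input shaping above is the pattern these activation tests now use throughout: gate branches live on a dedicated -2 axis instead of being concatenated along the hidden dim. A minimal sketch with hypothetical shapes:

import jax
import jax.numpy as jnp

activation_type = ("gelu", "linear")  # two branches
x = jax.random.uniform(jax.random.PRNGKey(0), (4, 8), jnp.float32)
x = jnp.expand_dims(x, axis=-2)                   # (4, 1, 8)
x = jnp.repeat(x, len(activation_type), axis=-2)  # (4, 2, 8)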
...
@@ -324,9 +332,11 @@ class TestNorm:
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_norm_grad_with_delayed_scaling_fp8(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout
    ):
        """
        Test transformer_engine.jax.layernorm.layernorm
...
@@ -335,7 +345,9 @@ class TestNorm:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        quantizer = QuantizerFactory.create(
            scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
            q_dtype=out_dtype,
            q_layout=q_layout,
        )
        self._test_norm_grad(
            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
...
@@ -351,7 +363,7 @@ class TestNorm:
        inp_dtype,
        out_dtype,
        scaling_mode,
        q_layout,
    ):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 3)
...
@@ -363,7 +375,7 @@ class TestNorm:
        gamma = jnp.asarray(gamma, inp_dtype)

        quantizer, ref_quantizer = QuantizerFactory.create(
            n_quantizers=2, scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
        )

        if norm_type == "layernorm":
            beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
...
@@ -391,9 +403,11 @@ class TestNorm:
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_norm_forward_with_delayed_scaling_fp8(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout
    ):
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")
...
@@ -406,8 +420,8 @@ class TestNorm:
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
            q_layout=q_layout,
        )

    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
...
@@ -423,8 +437,8 @@ class TestNorm:
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.MXFP8_1D_SCALING,
            q_layout=QuantizeLayout.ROWWISE_COLWISE,
        )
...
@@ -434,14 +448,14 @@ QUANTIZE_OUTPUT_DTYPES = {
}

ALL_QUANTIZE_TEST_SHAPES = [
    (32, 64),
    (2, 64, 32),
]

QUANTIZE_TEST_SHAPES = {
    "L0": [
        (32, 256, 128),
        (64, 32, 32, 256),
    ],
    "L2": ALL_QUANTIZE_TEST_SHAPES,
}
...
@@ -457,48 +471,52 @@ QUANTIZATION_INPUT_DTYPE = {
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("input_shape", ALL_QUANTIZE_TEST_SHAPES)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("flatten_axis", [-1, -2])
@pytest_parametrize_wrapper(
    "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
)
class TestQuantize:
    """
    Purely quantization related tests that will always test on a wider set of types and shapes
    """

    def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
        key = jax.random.PRNGKey(0)

        # Quantizer is created once as some quantization approaches use state from previous iterations (e.g. delayed scaling)
        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=q_dtype,
            q_layout=q_layout,
        )

        # Add a dimension to test that padding is done correctly when flattening 3D to 2D
        if flatten_axis == -2:
            input_shape = input_shape[:-1] + (2,) + input_shape[-1:]

        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            x = jax.random.uniform(key, input_shape, in_dtype)

            scaled_tensor = quantizer.quantize(x, flatten_axis=flatten_axis)
            assert_dequantized_scaled_tensor(scaled_tensor, x)
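Schematically, for a single-scale (delayed-scaling-style) quantizer, test_qdq amounts to the round trip below, written here in plain jax.numpy with no TE API; the fp8 max and tolerance are illustrative, not the library's constants.

import jax.numpy as jnp

x = jnp.linspace(-2.0, 2.0, 16, dtype=jnp.float32)
fp8_max = 448.0  # finite max of float8_e4m3fn
scale = fp8_max / jnp.max(jnp.abs(x))
q = (x * scale).astype(jnp.float8_e4m3fn)  # quantize
dq = q.astype(jnp.float32) / scale         # dequantize
assert jnp.allclose(dq, x, atol=0.1)       # fp8-level tolerance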
    def test_quantize_bitwise(
        self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
    ):
        key = jax.random.PRNGKey(0)
        if flatten_axis == -2:
            input_shape = input_shape[:-1] + (2,) + input_shape[-1:]
        input = jax.random.uniform(key, input_shape, in_dtype)

        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )

        jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)
        te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis)

        assert_bitwise_scaled_tensors(te_output, jax_output)
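flatten_axis here is, by assumption, the split point at which an N-D input is viewed as a 2-D matrix for scaling and transpose purposes: dims before it form the rows, dims from it onward form the columns. A sketch of that view:

import numpy as np

shape = (2, 64, 2, 32)  # the test's input_shape with the extra (2,) inserted
flatten_axis = -2
rows = int(np.prod(shape[:flatten_axis]))  # 2 * 64 = 128
cols = int(np.prod(shape[flatten_axis:]))  # 2 * 32 = 64
x2d = np.zeros(shape).reshape(rows, cols)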
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
...@@ -508,10 +526,14 @@ class TestFusedQuantize: ...@@ -508,10 +526,14 @@ class TestFusedQuantize:
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("input_shape", QUANTIZE_TEST_SHAPES) @pytest_parametrize_wrapper("input_shape", QUANTIZE_TEST_SHAPES)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE]) @pytest_parametrize_wrapper(
def test_quantize_dbias(self, in_dtype, input_shape, out_dtype, scaling_mode, q_axis): "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
transpose_axis = -1 )
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8( @pytest_parametrize_wrapper("flatten_axis", [-1, -2])
def test_quantize_dbias(
self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis
):
if scaling_mode == ScalingMode.MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8(
input_shape input_shape
): ):
pytest.skip(f"Input shape {input_shape} is not supported by MXFP8") pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")
...@@ -520,35 +542,37 @@ class TestFusedQuantize: ...@@ -520,35 +542,37 @@ class TestFusedQuantize:
input = jax.random.uniform(key, input_shape, in_dtype) input = jax.random.uniform(key, input_shape, in_dtype)
jax_quantizer, te_quantizer = QuantizerFactory.create( jax_quantizer, te_quantizer = QuantizerFactory.create(
n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_axis=q_axis n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
) )
te_output, te_dbias = jit(lambda input: tex.quantize_dbias(input, quantizer=te_quantizer))( te_output, te_dbias = jit(
input lambda input: tex.quantize_dbias(
input, quantizer=te_quantizer, flatten_axis=flatten_axis
) )
)(input)
jax_output, jax_dbias = jit( jax_output, jax_dbias = jit(
lambda input: _jax_quantize_dbias( lambda input: _jax_quantize_dbias(
input, input, quantizer=jax_quantizer, flatten_axis=flatten_axis
quantizer=jax_quantizer,
) )
)(input) )(input)
assert_bitwise_scaled_tensors(jax_output, te_output) assert_bitwise_scaled_tensors(te_output, jax_output)
assert_allclose(jax_dbias, te_dbias) assert_allclose(te_dbias, jax_dbias)
    def _test_quantize_dact_dbias(
        self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout
    ):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1)
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        dz = jax.random.uniform(subkeys[1], input_shape, in_dtype, -1, 1)

        jax_quantizer, te_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )
        is_casted_output = te_quantizer is not None
...
@@ -573,12 +597,12 @@ class TestFusedQuantize:
        )(dz, x)

        if is_casted_output:
            assert_bitwise_scaled_tensors(te_output, jax_output)
        else:
            assert_allclose(te_output, jax_output)
        if is_dbias:
            assert_allclose(te_dbias, jax_dbias)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES) @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
...@@ -594,10 +618,10 @@ class TestFusedQuantize: ...@@ -594,10 +618,10 @@ class TestFusedQuantize:
in_dtype=in_dtype, in_dtype=in_dtype,
input_shape=input_shape, input_shape=input_shape,
out_dtype=in_dtype, out_dtype=in_dtype,
scaling_mode=ScalingMode.NVTE_NO_SCALING, scaling_mode=ScalingMode.NO_SCALING,
activation_type=activation_type, activation_type=activation_type,
is_dbias=is_dbias, is_dbias=is_dbias,
q_axis=QuantizeAxis.ROWWISE, q_layout=QuantizeLayout.ROWWISE,
) )
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
...@@ -605,18 +629,20 @@ class TestFusedQuantize: ...@@ -605,18 +629,20 @@ class TestFusedQuantize:
@pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES) @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@pytest_parametrize_wrapper("is_dbias", [True, False]) @pytest_parametrize_wrapper("is_dbias", [True, False])
@pytest_parametrize_wrapper("q_axis", [QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE]) @pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_quantize_dact_dbias_delayed_scaling( def test_quantize_dact_dbias_delayed_scaling(
self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_axis self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout
): ):
self._test_quantize_dact_dbias( self._test_quantize_dact_dbias(
in_dtype=in_dtype, in_dtype=in_dtype,
input_shape=input_shape, input_shape=input_shape,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
activation_type=activation_type, activation_type=activation_type,
is_dbias=is_dbias, is_dbias=is_dbias,
q_axis=q_axis, q_layout=q_layout,
) )
@pytest.mark.skipif(not is_mxfp8_supported, reason=reason) @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
...@@ -626,9 +652,11 @@ class TestFusedQuantize: ...@@ -626,9 +652,11 @@ class TestFusedQuantize:
) )
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@pytest_parametrize_wrapper("is_dbias", [True, False]) @pytest_parametrize_wrapper("is_dbias", [True, False])
@pytest_parametrize_wrapper("q_axis", [QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE]) @pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_quantize_dact_dbias_mxfp8_scaling( def test_quantize_dact_dbias_mxfp8_scaling(
self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_axis self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout
): ):
if reduce(operator.mul, input_shape[:-1]) % 128 != 0 or input_shape[-1] % 128 != 0: if reduce(operator.mul, input_shape[:-1]) % 128 != 0 or input_shape[-1] % 128 != 0:
# TODO(Jeremy): Remove this if pulling in newer TE branch supports non-full-tile shapes. # TODO(Jeremy): Remove this if pulling in newer TE branch supports non-full-tile shapes.
...@@ -642,78 +670,78 @@ class TestFusedQuantize: ...@@ -642,78 +670,78 @@ class TestFusedQuantize:
in_dtype=in_dtype, in_dtype=in_dtype,
input_shape=input_shape, input_shape=input_shape,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, scaling_mode=ScalingMode.MXFP8_1D_SCALING,
activation_type=activation_type, activation_type=activation_type,
is_dbias=is_dbias, is_dbias=is_dbias,
q_axis=q_axis, q_layout=q_layout,
) )
class TestDense:
    def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
        if data_layout[0] == "T":
            a = jnp.swapaxes(a, -1, -2)
        if data_layout[1] == "T":
            b = jnp.swapaxes(b, -1, -2)
        return jnp.dot(a, b)

    def _generate_gemm_input(self, m, n, k, data_layout):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(
            subkeys[0],
            (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(k)
        w = jax.random.uniform(
            subkeys[1],
            (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(n)
        lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
        rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)

        return (x, w, contracting_dims)
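The data_layout string encodes whether each operand is stored transposed ("T") or not ("N"), which in turn picks the contracting dimension. A minimal sketch of the "TN" case using plain lax.dot_general (shapes are the test's own 64/32/64):

import jax
import jax.numpy as jnp

m, n, k = 64, 32, 64
x = jnp.zeros((k, m))  # lhs stored transposed ("T") -> contract dim 0
w = jnp.zeros((k, n))  # rhs stored as-is ("N")      -> contract dim 0
out = jax.lax.dot_general(x, w, dimension_numbers=(((0,), (0,)), ((), ())))
assert out.shape == (m, n)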
@pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)]) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("layout", ["TN", "NT", "NN", "TT"]) @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
def test_gemm_bf16(self, m, n, k, layout): def test_gemm_bf16(self, m, n, k, data_layout):
x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout) x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
primitive_out = tex.gemm(x, w, contracting_dims) primitive_out = tex.gemm(x, w, contracting_dims)
ref_out = self._ref_gemm_with_jnp_dot(x, w, layout) ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)]) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("layout", ["TN", "NT", "NN", "TT"]) @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, layout): def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, data_layout):
x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout) x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
quantizer_set = QuantizerFactory.create_set( quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=False scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=False
) )
primitive_out = tex.gemm( primitive_out = tex.gemm(
x, w, contracting_dims=contracting_dims, quantizer_set=quantizer_set x, w, contracting_dims=contracting_dims, quantizer_set=quantizer_set
) )
ref_out = self._ref_gemm_with_jnp_dot(x, w, layout) ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
assert_allclose(primitive_out, ref_out, dtype=q_dtype) assert_allclose(primitive_out, ref_out, dtype=q_dtype)
@pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)]) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
def test_dense_grad_bf16(self, m, n, k): def test_dense_grad_bf16(self, m, n, k):
layout = "NN" data_layout = "NN"
x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout) x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
def primitive_func(x, w, contracting_dims): def primitive_func(x, w, contracting_dims):
primitive_out = dense(x, w, contracting_dims=contracting_dims) primitive_out = dense(x, w, contracting_dims=contracting_dims)
return jnp.mean(primitive_out) return jnp.mean(primitive_out)
def ref_func(x, w, layout): def ref_func(x, w, data_layout):
return jnp.mean(self._ref_gemm_with_jnp_dot(x, w, layout)) return jnp.mean(self._ref_gemm_with_jnp_dot(x, w, data_layout))
value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1)) value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))
...@@ -722,19 +750,19 @@ class TestDense: ...@@ -722,19 +750,19 @@ class TestDense:
primitive_out, (primitive_x_grad, primitive_w_grad) = value_n_grad_primitive_func( primitive_out, (primitive_x_grad, primitive_w_grad) = value_n_grad_primitive_func(
x, w, contracting_dims x, w, contracting_dims
) )
ref_out, (ref_x_grad, ref_w_grad) = value_n_grad_ref_func(x, w, layout) ref_out, (ref_x_grad, ref_w_grad) = value_n_grad_ref_func(x, w, data_layout)
assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16) assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16)
assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16) assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    def test_dense_grad_fp8(self, m, n, k, q_dtype, scaling_mode):
        data_layout = "NN"
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        key = jax.random.PRNGKey(1)
        bias = jax.random.uniform(key, n, dtype=jnp.bfloat16)
...
@@ -745,9 +773,9 @@ class TestDense:
            )
            return jnp.mean(primitive_out)

        def ref_func(x, w, bias, data_layout):
            return jnp.mean(
                self._ref_gemm_with_jnp_dot(x, w, data_layout) + jnp.expand_dims(bias, axis=0)
            )

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
...
@@ -757,13 +785,15 @@ class TestDense:
            scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True
        )

        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
                value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)
            )
            ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func(
                x, w, bias, data_layout
            )

            assert_allclose(primitive_out, ref_out, dtype=q_dtype)
            assert_allclose(primitive_x_grad, ref_x_grad, dtype=q_dtype)
...
@@ -791,7 +821,7 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan
class TestFusedDense:
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
    @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
    @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
...
@@ -800,7 +830,7 @@ class TestFusedDense:
        Test layernorm_dense VJP Rule
        """
        # No Norm FWD E5M2 in TE backend
        if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
            pytest.skip("E5M2 is not supported in normalization with TE Backend!")

        # zero_centered_gamma is already tested in TestNorm
...
@@ -856,7 +886,7 @@ class TestFusedDense:
            x, w, gamma, beta
        )

        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            prim_out, (
                prim_x_grad,
...
@@ -873,7 +903,7 @@ class TestFusedDense:
            assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
    @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
...
@@ -886,7 +916,7 @@ class TestFusedDense:
        Test layernorm_mlp VJP Rule
        """
        # No Norm FWD E5M2 in TE backend
        if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
            pytest.skip("E5M2 is not supported in normalization with TE Backend!")

        # zero_centered_gamma is already tested in TestNorm
...
@@ -898,13 +928,13 @@ class TestFusedDense:
        x = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        kernel_1 = jax.random.normal(
            subkeys[1], (k, len(activation_type), n), jnp.bfloat16
        ) / jnp.sqrt(k)
        kernel_2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
        gamma = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
        beta = None  # was tested in TestNorm
        if use_bias:
            bias_1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
            bias_2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
        else:
            bias_1 = None
...
@@ -963,7 +993,7 @@ class TestFusedDense:
        value_n_grad_prim_func = value_and_grad(prim_func, range(6))
        value_n_grad_ref_func = value_and_grad(ref_func, range(6))

        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            prim_out, (
                prim_x_grad,
...
@@ -1039,19 +1069,19 @@ class TestGroupedDense:
        subkeys = jax.random.split(key, len(shape_list) * 2)

        lhs_list, rhs_list, contracting_dims_list = [], [], []
        for i, ((m, n, k), data_layout) in enumerate(zip(shape_list, layout_list)):
            lhs = jax.random.uniform(
                subkeys[2 * i],
                (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
                dtype=dtype,
            )
            rhs = jax.random.uniform(
                subkeys[2 * i + 1],
                (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
                dtype=dtype,
            )
            lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
            rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
            contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
            lhs_list.append(lhs)
...
...
@@ -48,31 +48,7 @@ class TestDistributedSelfAttn:
        # for loss and dbias
        return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0)

    def impl_test_self_attn(
        self,
        device_count,
        mesh_shape,
...
@@ -83,7 +59,9 @@ class TestDistributedSelfAttn:
        bias_shape,
        attn_mask_type,
        dtype,
        use_shardy,
    ):
        jax.config.update("jax_use_shardy_partitioner", use_shardy)
        dropout_prob = 0.0
        is_training = True
...
@@ -137,6 +115,80 @@ class TestDistributedSelfAttn:
        )
        runner.test_backward()
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize(
"data_shape",
[
pytest.param((32, 512, 12, 64), id="32-512-12-64"),
pytest.param((32, 1024, 16, 128), id="32-1024-16-128"),
],
)
@pytest.mark.parametrize(
"attn_bias_type, bias_shape",
[
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
],
)
@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"),
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
],
)
@pytest.mark.parametrize("dtype", DTYPES)
def test_self_attn(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
attn_bias_type,
bias_shape,
attn_mask_type,
dtype,
):
self.impl_test_self_attn(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
attn_bias_type,
bias_shape,
attn_mask_type,
dtype,
use_shardy=False,
)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize(
"attn_bias_type, bias_shape",
[
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
],
)
def test_self_attn_shardy(
self, device_count, mesh_shape, mesh_axes, mesh_resource, attn_bias_type, bias_shape
):
data_shape = (32, 512, 12, 64)
self.impl_test_self_attn(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
attn_bias_type,
bias_shape,
AttnMaskType.PADDING_MASK,
jnp.bfloat16,
use_shardy=True,
)
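Both parametrized entry points funnel into impl_test_self_attn, whose first action is to flip the partitioner flag. That flag is a process-global jax.config option, so a sketch of the toggle on its own (outside any TE API) is simply:

import jax

jax.config.update("jax_use_shardy_partitioner", True)   # run under Shardy
# ... build the mesh, shard the inputs, run the test body ...
jax.config.update("jax_use_shardy_partitioner", False)  # restore the default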
class TestDistributedCrossAttn:
...
@@ -203,37 +255,23 @@ class TestDistributedCrossAttn:
        runner.test_backward()


DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS = [
    pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"),
    pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"),
    pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="HD_KVPACKED-NO_MASK"),
    pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"),
    pytest.param(
        QKVLayout.THD_THD_THD, AttnMaskType.PADDING_CAUSAL_MASK, id="THD_SEPARATE-PADDING_CAUSAL"
    ),
]

DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES = [
    # Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
    pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"),
    pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
]


class TestDistributedContextParallelSelfAttn:
    def impl_test_context_parallel_attn(
...
@@ -249,7 +287,23 @@ class TestDistributedContextParallelSelfAttn:
        qkv_layout,
        load_balanced,
        cp_strategy,
        use_shardy,
        use_scan_ring=False,
    ):
        if qkv_layout.is_thd():
            if cp_strategy == CPStrategy.ALL_GATHER:
                pytest.skip("THD doesn't support all gather context parallelism.")
            if not load_balanced and cp_strategy == CPStrategy.RING:
                pytest.skip("THD + ring doesn't support unbalanced context parallelism.")

        assert not use_scan_ring or cp_strategy == CPStrategy.RING
        if use_scan_ring:
            os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1"
        else:
            os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0"

        jax.config.update("jax_use_shardy_partitioner", use_shardy)
        attn_bias_type = AttnBiasType.NO_BIAS
        bias_shape = None
        dropout_prob = 0.0
...
@@ -324,7 +378,58 @@ class TestDistributedContextParallelSelfAttn:
            pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}")

        runner.test_backward()
        del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"]
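Since NVTE_FUSED_RING_ATTENTION_USE_SCAN is a process-wide environment variable, the implementation sets it on entry and deletes it after the runner finishes. A pytest-native alternative (an assumption, not what this file does) would scope it with monkeypatch so it is restored even when the runner raises:

def test_ring_attn_with_scan(monkeypatch):  # hypothetical test
    monkeypatch.setenv("NVTE_FUSED_RING_ATTENTION_USE_SCAN", "1")
    # ... invoke the runner; the variable is undone automatically at teardown ...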
@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
)
def test_context_parallel_allgather_attn_shardy(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
attn_mask_type,
dtype,
qkv_layout,
):
kv_groups = 8
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced=True,
cp_strategy=CPStrategy.ALL_GATHER,
use_shardy=True,
)
@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
@pytest.mark.parametrize("kv_groups", [1, 8])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
)
def test_context_parallel_allgather_attn( def test_context_parallel_allgather_attn(
self, self,
device_count, device_count,
...@@ -338,9 +443,7 @@ class TestDistributedContextParallelSelfAttn: ...@@ -338,9 +443,7 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout, qkv_layout,
load_balanced, load_balanced,
): ):
if qkv_layout.is_thd(): self.impl_test_context_parallel_attn(
pytest.skip("THD doesn't support all gather context parallelism.")
return self.impl_test_context_parallel_attn(
device_count, device_count,
mesh_shape, mesh_shape,
mesh_axes, mesh_axes,
...@@ -352,8 +455,23 @@ class TestDistributedContextParallelSelfAttn: ...@@ -352,8 +455,23 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout, qkv_layout,
load_balanced, load_balanced,
CPStrategy.ALL_GATHER, CPStrategy.ALL_GATHER,
use_shardy=False,
) )
@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
@pytest.mark.parametrize("kv_groups", [1, 8])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"use_scan", "use_scan",
[pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")], [pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")],
...@@ -372,14 +490,6 @@ class TestDistributedContextParallelSelfAttn: ...@@ -372,14 +490,6 @@ class TestDistributedContextParallelSelfAttn:
load_balanced, load_balanced,
use_scan, use_scan,
): ):
if use_scan:
os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1"
else:
os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0"
if qkv_layout.is_thd() and not load_balanced:
pytest.skip("THD + ring doesn't support unbalanced context parallelism.")
self.impl_test_context_parallel_attn( self.impl_test_context_parallel_attn(
device_count, device_count,
mesh_shape, mesh_shape,
...@@ -392,9 +502,46 @@ class TestDistributedContextParallelSelfAttn: ...@@ -392,9 +502,46 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout, qkv_layout,
load_balanced, load_balanced,
CPStrategy.RING, CPStrategy.RING,
use_shardy=False,
use_scan_ring=use_scan,
)
@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
)
def test_context_parallel_ring_attn_shardy(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
attn_mask_type,
dtype,
qkv_layout,
):
kv_groups = 8
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced=True,
cp_strategy=CPStrategy.RING,
use_shardy=False,
use_scan_ring=True,
) )
del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"]
return
class TestReorderCausalLoadBalancing: class TestReorderCausalLoadBalancing:
......
...@@ -29,7 +29,7 @@ NORM_INPUT_SHAPES = { ...@@ -29,7 +29,7 @@ NORM_INPUT_SHAPES = {
} }
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
SUPPORTED_RECIPES = [] SUPPORTED_RECIPES = []
if is_fp8_supported: if is_fp8_supported:
...@@ -86,6 +86,7 @@ class TestDistributedLayernorm: ...@@ -86,6 +86,7 @@ class TestDistributedLayernorm:
@pytest_parametrize_wrapper("zero_centered_gamma", [False, True]) @pytest_parametrize_wrapper("zero_centered_gamma", [False, True])
@pytest_parametrize_wrapper("shard_weights", [False, True]) @pytest_parametrize_wrapper("shard_weights", [False, True])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("use_shardy", [False, True])
def test_layernorm( def test_layernorm(
self, self,
device_count, device_count,
...@@ -97,7 +98,9 @@ class TestDistributedLayernorm: ...@@ -97,7 +98,9 @@ class TestDistributedLayernorm:
zero_centered_gamma, zero_centered_gamma,
shard_weights, shard_weights,
fp8_recipe, fp8_recipe,
use_shardy,
): ):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
epsilon = 1e-6 epsilon = 1e-6
ln_type = "layernorm" ln_type = "layernorm"
q_dtype = jnp.float8_e4m3fn q_dtype = jnp.float8_e4m3fn
...@@ -168,6 +171,7 @@ class TestDistributedLayernorm: ...@@ -168,6 +171,7 @@ class TestDistributedLayernorm:
@pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("shard_weights", [False, True]) @pytest_parametrize_wrapper("shard_weights", [False, True])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("use_shardy", [False, True])
def test_rmsnorm( def test_rmsnorm(
self, self,
device_count, device_count,
...@@ -178,7 +182,9 @@ class TestDistributedLayernorm: ...@@ -178,7 +182,9 @@ class TestDistributedLayernorm:
dtype, dtype,
shard_weights, shard_weights,
fp8_recipe, fp8_recipe,
use_shardy,
): ):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
epsilon = 1e-6 epsilon = 1e-6
ln_type = "rmsnorm" ln_type = "rmsnorm"
q_dtype = jnp.float8_e4m3fn q_dtype = jnp.float8_e4m3fn
......
...@@ -36,7 +36,7 @@ from transformer_engine.jax.quantize import QuantizerFactory ...@@ -36,7 +36,7 @@ from transformer_engine.jax.quantize import QuantizerFactory
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
SUPPORTED_RECIPES = [] SUPPORTED_RECIPES = []
if is_fp8_supported: if is_fp8_supported:
...@@ -45,11 +45,17 @@ if is_mxfp8_supported: ...@@ -45,11 +45,17 @@ if is_mxfp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling")) SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))
DTYPES = [jnp.bfloat16, jnp.float16] DTYPES = [jnp.bfloat16, jnp.float16]
INPUT_SHAPE = [[2, 64, 64]] # [batch, seqlen, hidden_in] INPUT_SHAPE = [[4, 64, 128]] # [batch, seqlen, hidden_in]
LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES) LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES) DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES) DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
KERNEL_1_AXES = (W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES)
KERNEL_2_AXES = (W_TP_AXES, W_FSDP_AXES)
LN_SCALE_AXES = (W_NO_SHARD_AXES,)
LN_BIAS_AXES = (W_NO_SHARD_AXES,)
BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES)
BIAS_2_AXES = (W_NO_SHARD_AXES,)
INTERMEDIATE = 64 INTERMEDIATE = 64
...@@ -60,7 +66,6 @@ def generate_fsdp_and_tp_configs(): ...@@ -60,7 +66,6 @@ def generate_fsdp_and_tp_configs():
configs.append( configs.append(
[2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")] [2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
) )
if is_devices_enough(4): if is_devices_enough(4):
configs.append( configs.append(
[4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")] [4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
...@@ -80,13 +85,13 @@ class TestDistributedLayernormMLP: ...@@ -80,13 +85,13 @@ class TestDistributedLayernormMLP:
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype) x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype) gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
k1 = jax.random.normal( k1 = jax.random.normal(
subkeys[1], (hidden_in, len(activation_type) * INTERMEDIATE), dtype subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype
) / jnp.sqrt(hidden_in) ) / jnp.sqrt(hidden_in)
k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt( k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(
INTERMEDIATE INTERMEDIATE
) )
if use_bias: if use_bias:
b1 = jax.random.normal(subkeys[3], (len(activation_type) * INTERMEDIATE), dtype) b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype)
b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype) b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
else: else:
b1 = None b1 = None
...@@ -111,10 +116,12 @@ class TestDistributedLayernormMLP: ...@@ -111,10 +116,12 @@ class TestDistributedLayernormMLP:
layernorm_input_axes = LAYERNORM_INPUT_AXES layernorm_input_axes = LAYERNORM_INPUT_AXES
dot_1_input_axes = DOT_1_INPUT_AXES dot_1_input_axes = DOT_1_INPUT_AXES
dot_2_input_axes = DOT_2_INPUT_AXES dot_2_input_axes = DOT_2_INPUT_AXES
kernel_1_axes = KERNEL_1_AXES
kernel_2_axes = KERNEL_2_AXES
else: else:
layernorm_input_axes = None layernorm_input_axes = None
dot_1_input_axes = None dot_1_input_axes = dot_2_input_axes = None
dot_2_input_axes = None kernel_1_axes = kernel_2_axes = None
quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2) quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2)
...@@ -130,21 +137,17 @@ class TestDistributedLayernormMLP: ...@@ -130,21 +137,17 @@ class TestDistributedLayernormMLP:
norm_input_axes=layernorm_input_axes, norm_input_axes=layernorm_input_axes,
dot_1_input_axes=dot_1_input_axes, dot_1_input_axes=dot_1_input_axes,
dot_2_input_axes=dot_2_input_axes, dot_2_input_axes=dot_2_input_axes,
kernel_1_axes=kernel_1_axes,
kernel_2_axes=kernel_2_axes,
activation_type=activation_type, activation_type=activation_type,
quantizer_sets=quantizer_sets, quantizer_sets=quantizer_sets,
) )
) )
@pytest.mark.skipif(not is_fp8_supported, reason=reason) def _test_layernorm_mlp_grad(
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, use_shardy
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
def test_layernorm_fp8_mlp_primitive(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
): ):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
layernorm_type = "rmsnorm" layernorm_type = "rmsnorm"
...@@ -168,12 +171,12 @@ class TestDistributedLayernormMLP: ...@@ -168,12 +171,12 @@ class TestDistributedLayernormMLP:
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource): with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", "tp")) k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp")) k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
k1_ = jax.device_put(k1, k1_sharding) k1_ = jax.device_put(k1, k1_sharding)
k2_ = jax.device_put(k2, k2_sharding) k2_ = jax.device_put(k2, k2_sharding)
if use_bias: if use_bias:
b1_sharding = NamedSharding(mesh, PartitionSpec("tp")) b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
b1_ = jax.device_put(b1, b1_sharding) b1_ = jax.device_put(b1, b1_sharding)
else: else:
b1_sharding = b1_ = None b1_sharding = b1_ = None
...@@ -248,9 +251,59 @@ class TestDistributedLayernormMLP: ...@@ -248,9 +251,59 @@ class TestDistributedLayernormMLP:
err_msg=f"multi_grads[{i}] is not close", err_msg=f"multi_grads[{i}] is not close",
) )
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
def test_layernorm_mlp_grad(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
):
self._test_layernorm_mlp_grad(
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
fp8_recipe,
use_shardy=False,
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
def test_layernorm_mlp_grad_shardy(
self, mesh_config, activation_type, use_bias, input_shape, dtype
):
# We don't test block scaling with Shardy because at the time of writing,
# it is not supported in JAX's scaled_matmul_stablehlo.
self._test_layernorm_mlp_grad(
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
fp8_recipe=recipe.DelayedScaling(),
use_shardy=True,
)
def _test_layernorm_mlp( def _test_layernorm_mlp(
self, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8, fp8_recipe=None self,
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
use_fp8,
fp8_recipe,
use_shardy,
): ):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
batch, seqlen, hidden_in = input_shape batch, seqlen, hidden_in = input_shape
layernorm_type = "rmsnorm" layernorm_type = "rmsnorm"
...@@ -269,7 +322,7 @@ class TestDistributedLayernormMLP: ...@@ -269,7 +322,7 @@ class TestDistributedLayernormMLP:
activations=activation_type, activations=activation_type,
use_bias=use_bias, use_bias=use_bias,
) )
params_single = ln_mlp_single.init(init_rngs, x) params_single = ln_mlp_single.init(init_rngs, x, deterministic=True)
mlp_out_single, ln_out_single = ln_mlp_single.apply( mlp_out_single, ln_out_single = ln_mlp_single.apply(
params_single, x, deterministic=True params_single, x, deterministic=True
) )
...@@ -286,19 +339,19 @@ class TestDistributedLayernormMLP: ...@@ -286,19 +339,19 @@ class TestDistributedLayernormMLP:
transpose_batch_sequence=False, transpose_batch_sequence=False,
intermediate_dim=INTERMEDIATE, intermediate_dim=INTERMEDIATE,
activations=activation_type, activations=activation_type,
scale_axes=(W_NO_SHARD_AXES,), scale_axes=LN_SCALE_AXES,
ln_bias_axes=(W_NO_SHARD_AXES,), ln_bias_axes=LN_BIAS_AXES,
kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), kernel_axes_1=KERNEL_1_AXES,
kernel_axes_2=(W_TP_AXES, W_FSDP_AXES), kernel_axes_2=KERNEL_2_AXES,
use_bias=use_bias, use_bias=use_bias,
bias_axes_1=(W_JOINED_AXES, W_TP_AXES), bias_axes_1=BIAS_1_AXES,
bias_axes_2=(W_NO_SHARD_AXES,), bias_axes_2=BIAS_2_AXES,
layernorm_input_axes=LAYERNORM_INPUT_AXES, layernorm_input_axes=LAYERNORM_INPUT_AXES,
dot_1_input_axes=DOT_1_INPUT_AXES, dot_1_input_axes=DOT_1_INPUT_AXES,
dot_2_input_axes=DOT_2_INPUT_AXES, dot_2_input_axes=DOT_2_INPUT_AXES,
name="mlp", name="mlp",
) )
params_sharded = ln_mlp_sharded.init(init_rngs, x) params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True)
mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply( mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
params_sharded, x, deterministic=True params_sharded, x, deterministic=True
) )
...@@ -313,25 +366,38 @@ class TestDistributedLayernormMLP: ...@@ -313,25 +366,38 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")]) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("use_bias", [True, False])
def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, dtype): @pytest_parametrize_wrapper("use_shardy", [False, True])
def test_layernorm_mlp_layer(
self, mesh_config, activation_type, use_bias, input_shape, dtype, use_shardy
):
self._test_layernorm_mlp( self._test_layernorm_mlp(
mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=False mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
use_fp8=False,
fp8_recipe=None,
use_shardy=use_shardy,
) )
# TODO: debug @pytest.mark.skipif(not is_fp8_supported, reason=reason)
# @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
# @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
# @pytest_parametrize_wrapper( @pytest_parametrize_wrapper("use_bias", [True, False])
# "activation_type", [("gelu",), ("gelu", "linear")] @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
# ) @pytest_parametrize_wrapper("dtype", DTYPES)
# @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
# @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) def test_layernorm_mlp_layer_fp8(
# @pytest_parametrize_wrapper("dtype", DTYPES) self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
# @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) ):
# def test_layernorm_fp8_mlp_layer( self._test_layernorm_mlp(
# self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe mesh_config,
# ): activation_type,
# self._test_layernorm_mlp( use_bias,
# mesh_config, activation_type, use_bias, input_shape, dtype, input_shape,
# use_fp8=True, fp8_recipe=fp8_recipe dtype,
# ) use_fp8=True,
fp8_recipe=fp8_recipe,
use_shardy=False,
)
...@@ -28,14 +28,16 @@ class TestDistributedSoftmax: ...@@ -28,14 +28,16 @@ class TestDistributedSoftmax:
all_reduce_loss_bytes = 4 # 1 * FP32 all_reduce_loss_bytes = 4 # 1 * FP32
return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0) return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)
def generate_inputs(self, shape, mesh_resource, softmax_type, dtype, bad_sharding): def generate_inputs(
self, shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask
):
batch, _, sqelen, _ = shape batch, _, sqelen, _ = shape
x = random.normal(random.PRNGKey(1124), shape, dtype=dtype) x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
if softmax_type == SoftmaxType.SCALED_UPPER_TRIANG_MASKED: if softmax_type == SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
mask = make_causal_mask(batch, sqelen) mask = make_causal_mask(batch, sqelen)
else: else:
mask = make_self_mask(batch, sqelen) mask = make_self_mask(1 if broadcast_batch_mask else batch, sqelen)
if not bad_sharding: if not bad_sharding:
x_pspec = PartitionSpec( x_pspec = PartitionSpec(
...@@ -45,6 +47,10 @@ class TestDistributedSoftmax: ...@@ -45,6 +47,10 @@ class TestDistributedSoftmax:
x_pspec = PartitionSpec( x_pspec = PartitionSpec(
mesh_resource.dp_resource, None, None, mesh_resource.tp_resource mesh_resource.dp_resource, None, None, mesh_resource.tp_resource
) )
if broadcast_batch_mask:
mask_pspec = PartitionSpec(None, None, None, None)
else:
mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None)
return (x, mask), (x_pspec, mask_pspec) return (x, mask), (x_pspec, mask_pspec)
...@@ -67,16 +73,7 @@ class TestDistributedSoftmax: ...@@ -67,16 +73,7 @@ class TestDistributedSoftmax:
output = jax.nn.softmax(x * scale_factor) output = jax.nn.softmax(x * scale_factor)
return jnp.mean(output) return jnp.mean(output)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) def impl_test_softmax(
@pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]])
@pytest.mark.parametrize(
"softmax_type",
[SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],
)
@pytest.mark.parametrize("scale_factor", [1.0, 3.0])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("bad_sharding", [False, True])
def test_softmax(
self, self,
device_count, device_count,
mesh_shape, mesh_shape,
...@@ -87,15 +84,20 @@ class TestDistributedSoftmax: ...@@ -87,15 +84,20 @@ class TestDistributedSoftmax:
scale_factor, scale_factor,
dtype, dtype,
bad_sharding, bad_sharding,
broadcast_batch_mask,
use_shardy,
): ):
if broadcast_batch_mask and softmax_type != SoftmaxType.SCALED_MASKED:
pytest.skip("Softmax type has no mask.")
jax.config.update("jax_use_shardy_partitioner", use_shardy)
target_func = partial( target_func = partial(
self.target_func, scale_factor=scale_factor, softmax_type=softmax_type self.target_func, scale_factor=scale_factor, softmax_type=softmax_type
) )
ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype) ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype)
(x, mask), (x_pspec, mask_pspec) = self.generate_inputs( (x, mask), (x_pspec, mask_pspec) = self.generate_inputs(
data_shape, mesh_resource, softmax_type, dtype, bad_sharding data_shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask
) )
collective_count_ref = self.generate_collectives_count_ref() collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
...@@ -129,4 +131,70 @@ class TestDistributedSoftmax: ...@@ -129,4 +131,70 @@ class TestDistributedSoftmax:
assert "Sharding the hidden dimension is not supported" in str(w), ( assert "Sharding the hidden dimension is not supported" in str(w), (
"Softmax primitive did not raise the correct warning for " "Softmax primitive did not raise the correct warning for "
"unsupported sharding in the hidden dimension." "unsupported sharding in the hidden dimension."
f"{str(w)}"
)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]])
@pytest.mark.parametrize(
"softmax_type",
[SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],
)
@pytest.mark.parametrize("scale_factor", [1.0, 3.0])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("bad_sharding", [False, True])
@pytest.mark.parametrize("broadcast_batch_mask", [False, True])
def test_softmax(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
softmax_type,
scale_factor,
dtype,
bad_sharding,
broadcast_batch_mask,
):
self.impl_test_softmax(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
softmax_type,
scale_factor,
dtype,
bad_sharding,
broadcast_batch_mask,
use_shardy=False,
)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("softmax_type", [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED])
@pytest.mark.parametrize("bad_sharding", [False, True])
@pytest.mark.parametrize("broadcast_batch_mask", [False, True])
def test_softmax_shardy(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
softmax_type,
bad_sharding,
broadcast_batch_mask,
):
self.impl_test_softmax(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape=[32, 12, 128, 128],
softmax_type=softmax_type,
scale_factor=1.0,
dtype=DTYPES[0],
bad_sharding=bad_sharding,
broadcast_batch_mask=broadcast_batch_mask,
use_shardy=True,
) )
...@@ -39,7 +39,7 @@ def enable_fused_attn(): ...@@ -39,7 +39,7 @@ def enable_fused_attn():
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
QUANTIZE_RECIPES = [] QUANTIZE_RECIPES = []
""" Find supported scaling modes""" """ Find supported scaling modes"""
...@@ -215,12 +215,53 @@ ATTRS = [ ...@@ -215,12 +215,53 @@ ATTRS = [
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True, _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
}, },
# attrs22 # attrs22
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
_KEY_OF_WINDOW_SIZE: None,
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
},
# attrs23
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
},
# attrs24
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_MASK_TYPE: "no_mask",
},
# attrs25
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_MASK_TYPE: "no_mask",
_KEY_OF_WINDOW_SIZE: (2, 2),
},
# attrs26
{ {
_KEY_OF_TRANSPOSE_BS: False, _KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False, _KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding", _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
_KEY_OF_WINDOW_SIZE: (2, 2), _KEY_OF_WINDOW_SIZE: (2, 2),
}, },
# attrs27
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
_KEY_OF_WINDOW_SIZE: None,
},
# attrs28
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_WINDOW_SIZE: (2, 2),
},
] ]
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS] ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
...@@ -313,7 +354,7 @@ class BaseRunner: ...@@ -313,7 +354,7 @@ class BaseRunner:
test_others, test_others,
test_layer, test_layer,
) )
if QuantizeConfig.SCALING_MODE == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
_, updated_quantize_meta = flax.core.pop( _, updated_quantize_meta = flax.core.pop(
updated_state[0], QuantizeConfig.COLLECTION_NAME updated_state[0], QuantizeConfig.COLLECTION_NAME
) )
...@@ -370,13 +411,13 @@ class EncoderRunner(BaseRunner): ...@@ -370,13 +411,13 @@ class EncoderRunner(BaseRunner):
data_rng = jax.random.PRNGKey(2024) data_rng = jax.random.PRNGKey(2024)
inputs = (jax.random.normal(data_rng, data_shape, dtype),) inputs = (jax.random.normal(data_rng, data_shape, dtype),)
padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8) mask_shape = (batch, 1, seqlen, seqlen)
causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1) padded_mask = jnp.zeros(mask_shape, dtype=jnp.uint8)
causal_mask = jnp.triu(jnp.ones(mask_shape, dtype=jnp.uint8), k=1)
if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]: if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
mask = causal_mask mask = causal_mask
else: else:
mask = padded_mask mask = padded_mask
ref_masks = (1 - mask,) ref_masks = (1 - mask,)
test_masks = (None, mask) # The second arg of Transformer is encoded tokens. test_masks = (None, mask) # The second arg of Transformer is encoded tokens.
......
...@@ -18,6 +18,7 @@ from utils import assert_allclose ...@@ -18,6 +18,7 @@ from utils import assert_allclose
from transformer_engine.jax.cpp_extensions import is_softmax_kernel_available from transformer_engine.jax.cpp_extensions import is_softmax_kernel_available
from transformer_engine.jax.softmax import SoftmaxType, softmax from transformer_engine.jax.softmax import SoftmaxType, softmax
from transformer_engine.jax.flax.module import Softmax
def catch_unsupported(method): def catch_unsupported(method):
...@@ -94,7 +95,6 @@ class SoftmaxRunner: ...@@ -94,7 +95,6 @@ class SoftmaxRunner:
case _: case _:
raise ValueError(f"Unknown {self.softmax_type=}") raise ValueError(f"Unknown {self.softmax_type=}")
@catch_unsupported
def test_forward(self): def test_forward(self):
""" """
Test transformer_engine.jax.softmax.softmax fwd rule Test transformer_engine.jax.softmax.softmax fwd rule
...@@ -104,7 +104,6 @@ class SoftmaxRunner: ...@@ -104,7 +104,6 @@ class SoftmaxRunner:
reference_out = __class__.reference_softmax(self.logits, self.mask, self.scale_factor) reference_out = __class__.reference_softmax(self.logits, self.mask, self.scale_factor)
assert_allclose(primitive_out, reference_out, dtype=self.dtype) assert_allclose(primitive_out, reference_out, dtype=self.dtype)
@catch_unsupported
def test_backward(self): def test_backward(self):
""" """
Test transformer_engine.jax.softmax.softmax bwd rule Test transformer_engine.jax.softmax.softmax bwd rule
...@@ -141,6 +140,50 @@ class SoftmaxRunner: ...@@ -141,6 +140,50 @@ class SoftmaxRunner:
assert_allclose(primitive_grad_logits, reference_grad_logits, dtype=self.dtype) assert_allclose(primitive_grad_logits, reference_grad_logits, dtype=self.dtype)
class SoftmaxPrimitivesRunner(SoftmaxRunner):
"""
Jax Softmax Primitives runner
"""
@catch_unsupported
def test_forward(self):
return super().test_forward()
@catch_unsupported
def test_backward(self):
return super().test_backward()
class SoftmaxModuleRunner:
"""
Jax Softmax Module runner
"""
module_runner: SoftmaxRunner
bias: None
def __init__(self, module_runner, bias):
self.module_runner = module_runner
self.bias = bias
def test_forward(self):
"""
Test transformer_engine.jax.flax.module.Softmax fwd rule
"""
runner = self.module_runner
runner._setup_inputs()
rng = jax.random.PRNGKey(0)
softmax_module = Softmax(
scale_factor=runner.scale_factor,
softmax_type=runner.softmax_type,
)
softmax_vars = softmax_module.init(rng, runner.logits, runner.mask)
module_out = softmax_module.apply(softmax_vars, runner.logits, runner.mask)
reference_out = runner.reference_softmax(runner.logits, runner.mask, runner.scale_factor)
assert_allclose(module_out, reference_out, dtype=runner.dtype)
# Run softmax primitives test
@pytest.mark.parametrize( @pytest.mark.parametrize(
"b, s_q, s_kv, h", "b, s_q, s_kv, h",
[ [
...@@ -165,7 +208,7 @@ class SoftmaxRunner: ...@@ -165,7 +208,7 @@ class SoftmaxRunner:
pytest.param(jnp.float16, id="FP16"), pytest.param(jnp.float16, id="FP16"),
], ],
) )
class TestSoftmax: class TestSoftmaxPrimitives:
""" """
Test transformer_engine.jax.softmax.softmax Test transformer_engine.jax.softmax.softmax
""" """
...@@ -175,7 +218,7 @@ class TestSoftmax: ...@@ -175,7 +218,7 @@ class TestSoftmax:
""" """
Test forward with parameterized configs Test forward with parameterized configs
""" """
runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype) runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
runner.test_forward() runner.test_forward()
@staticmethod @staticmethod
...@@ -183,5 +226,48 @@ class TestSoftmax: ...@@ -183,5 +226,48 @@ class TestSoftmax:
""" """
Test forward with parameterized configs Test forward with parameterized configs
""" """
runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype) runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
runner.test_backward() runner.test_backward()
# Run Softmax module test
@pytest.mark.parametrize(
"b, s_q, s_kv, h",
[
pytest.param(8, 16, 16, 16, id="8-16-16-16"),
pytest.param(8, 512, 512, 16, id="8-512-512-16"),
pytest.param(2, 8, 16384, 8, id="2-8-16384-8"),
# triggers backup framework implementation due to (s_q % 4) != 0
pytest.param(8, 511, 512, 16, id="8-511-512-16"),
],
)
@pytest.mark.parametrize("scale_factor", [0.125])
@pytest.mark.parametrize(
"softmax_type",
[
pytest.param(SoftmaxType.SCALED, id="SCALED"),
pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"),
pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
],
)
@pytest.mark.parametrize(
"dtype",
[
pytest.param(jnp.bfloat16, id="BF16"),
pytest.param(jnp.float16, id="FP16"),
],
)
class TestSoftmaxModule:
"""
Test transformer_engine.jax.flax.module.Softmax
"""
@staticmethod
def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
"""
Test forward with parameterized configs
"""
module_runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
bias = None
runner = SoftmaxModuleRunner(module_runner, bias)
runner.test_forward()
...@@ -21,7 +21,11 @@ from transformer_engine.common.recipe import ( ...@@ -21,7 +21,11 @@ from transformer_engine.common.recipe import (
) )
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor import QuantizedTensor, cast_master_weights_to_fp8 from transformer_engine.pytorch.tensor import QuantizedTensor, cast_master_weights_to_fp8
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Tensor,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.utils import replace_raw_data
def _get_raw_data(quantized_tensor): def _get_raw_data(quantized_tensor):
...@@ -228,6 +232,273 @@ class MiniOptimizer: ...@@ -228,6 +232,273 @@ class MiniOptimizer:
weight.data.copy_(master_weight) weight.data.copy_(master_weight)
class MiniFSDP:
def __init__(self, weights, lr, dp_group):
rank = dist.get_rank(dp_group)
world_size = dist.get_world_size(dp_group)
self.weights = weights
self.lr = lr
self.dp_group = dp_group
# Flatten the weights and pad to align with world size
raw_data_list = [
_get_raw_data(w).view(-1) if isinstance(w, QuantizedTensor) else w.view(-1)
for w in weights
]
if isinstance(weights[0], QuantizedTensor):
raw_data_list = [_get_raw_data(w).view(-1) for w in weights]
else:
raw_data_list = [w.view(-1) for w in weights]
self.flatten_weight, original_length = self._flatten_tensors_with_pad(raw_data_list)
# Split flattened weights into shards
self.local_weight_shard = torch.chunk(self.flatten_weight, world_size)[rank]
self.local_main_grad_shard = torch.zeros_like(self.local_weight_shard)
shard_size = self.flatten_weight.size(0) // world_size
# Map original tensors to flattened indices
tensor_indices = []
cumulative_length = 0
for tensor in raw_data_list:
length = tensor.size(0)
tensor_indices.append((cumulative_length, cumulative_length + length))
cumulative_length += length
# Build shard index mappings
self.weight_indices = []
self.shard_indices = []
for idx, (start, end) in enumerate(tensor_indices):
shard_start = rank * shard_size
shard_end = shard_start + shard_size
adjusted_end = min(shard_end, original_length)
if start <= adjusted_end and end >= shard_start:
start_idx = max(start, shard_start)
end_idx = min(end, adjusted_end)
self.weight_indices.append((start_idx - start, end_idx - start))
self.shard_indices.append((start_idx - shard_start, end_idx - shard_start))
else:
self.weight_indices.append((None, None))
self.shard_indices.append((None, None))
if isinstance(weights[idx], QuantizedTensor):
replace_raw_data(
weights[idx], self.flatten_weight[start:end].view(weights[idx].shape)
)
else:
weights[idx].data = self.flatten_weight[start:end].view(weights[idx].shape)
# Initialize local model weights and high-precision master weights
self.local_weights = []
self.master_weights = []
for i, weight in enumerate(self.weights):
weight_start, weight_end = self.weight_indices[i]
shard_start, shard_end = self.shard_indices[i]
if shard_start is not None and shard_end is not None:
local_weight_shard = self.local_weight_shard[shard_start:shard_end]
self.local_weights.append(local_weight_shard)
if isinstance(weight, QuantizedTensor):
high_precision_init_val = weight.get_high_precision_init_val().view(-1)
master_weight_shard = high_precision_init_val.to(weight.device).float()[
weight_start:weight_end
]
else:
master_weight_shard = weight.detach().view(-1).float()[weight_start:weight_end]
self.master_weights.append(master_weight_shard)
else:
self.local_weights.append(None)
self.master_weights.append(None)
setattr(
weight, "main_grad", torch.zeros_like(weight, dtype=torch.float32, device="cuda")
)
def _flatten_tensors_with_pad(self, tensors):
"""
Flatten the list of tensors and pad them to align with the world size.
Args:
tensors (list): List of tensors to flatten.
Returns:
tuple: Flattened tensor and its original length before padding.
"""
world_size = dist.get_world_size(self.dp_group)
flatten_tensor = torch.cat(tensors)
original_length = flatten_tensor.size(0)
padding_needed = (world_size - original_length % world_size) % world_size
if padding_needed > 0:
flatten_tensor = torch.cat(
[flatten_tensor, torch.zeros(padding_needed, dtype=flatten_tensor.dtype)]
)
return flatten_tensor, original_length
def zero_grad(self):
for weight in self.weights:
weight.grad = None
weight.main_grad.zero_()
def step(self):
"""
Perform an optimization step for the distributed sharded model.
This method includes:
1. Gradient reduce-scatter: Synchronize gradients across all processes.
2. Master weight update: Update high-precision master weights using local gradients.
3. Precision casting: Cast updated master weights to FP8 or BF16 precision.
4. Weight synchronization: All-gather updated weights across all processes.
Returns:
None
"""
# Step 1: Reduce-scatter the gradients
main_grad_buffer, _ = self._flatten_tensors_with_pad(
[weight.main_grad.view(-1) for weight in self.weights]
)
main_grad_buffer = main_grad_buffer.to(self.local_main_grad_shard.dtype)
dist.reduce_scatter_tensor(
self.local_main_grad_shard, main_grad_buffer, group=self.dp_group
)
# Step 2: Update the master weights
for weight, master_weight, (shard_start, shard_end) in zip(
self.weights, self.master_weights, self.shard_indices
):
if master_weight is None:
continue
# Extract the local gradient shard for this weight
grad = self.local_main_grad_shard[shard_start:shard_end]
# Update the master weight using gradient descent
master_weight -= grad * self.lr
# Step 3: Cast master weights to FP8 or BF16 precision
if isinstance(self.weights[0], QuantizedTensor):
local_weights = []
for local_weight in self.local_weights:
if local_weight is None:
local_weights.append(None)
continue
local_weights.append(local_weight)
cast_master_weights_to_fp8(
self.weights,
self.master_weights,
[idx[0] for idx in self.weight_indices],
self.dp_group,
local_weights,
)
else:
for weight, master_weight in zip(self.local_weights, self.master_weights):
if master_weight is None:
continue
# Copy updated master weights to local weights
weight.data.copy_(master_weight)
# Step 4: All-gather updated weights across processes
dist.all_gather_into_tensor(
self.flatten_weight, self.local_weight_shard, group=self.dp_group
)
def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
rank = dist.get_rank(dp_group)
world_size = dist.get_world_size(dp_group)
# Configuration constants
NUM_STEPS = 100
SEED = 12345
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)]
mock_group = mock_groups[rank]
linear_kwargs = {
"params_dtype": torch.bfloat16,
"bias": False,
"fuse_wgrad_accumulation": False,
}
# Create model with FP8 weights
with te.fp8.fp8_model_init(
enabled=quantization is not None,
recipe=quantization_recipe(quantization),
preserve_high_precision_init_val=True,
):
model_fp8 = nn.Sequential(
te.Linear(128, 256, **linear_kwargs),
te.Linear(256, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
# Create model with BF16 weights
model = nn.Sequential(
te.Linear(128, 256, **linear_kwargs),
te.Linear(256, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
# Make sure the BF16 model and FP8 model have the same initial weights
for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
high_precision_init_val = w_fp8.get_high_precision_init_val()
w.data.copy_(high_precision_init_val)
optimizer_fp8 = MiniFSDP([w for w in model_fp8.parameters()], 10.0, dp_group)
optimizer = MiniFSDP([w for w in model.parameters()], 10.0, dp_group)
for _ in range(100):
optimizer_fp8.zero_grad()
optimizer.zero_grad()
inputs = [
torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
]
# Choose based on rank to make sure the inputs of different ranks are different.
x = inputs[rank]
with te.fp8.fp8_autocast(
enabled=quantization is not None,
fp8_recipe=quantization_recipe(quantization),
fp8_group=mock_group,
):
y_fp8 = model_fp8(x)
with te.fp8_autocast(
enabled=quantization is not None,
fp8_recipe=quantization_recipe(quantization),
fp8_group=mock_group,
):
y = model(x)
targets = [torch.randn_like(y) for _ in range(world_size)]
# Choose based on rank to make sure the targets of different ranks are different.
target = targets[rank]
loss_fp8 = nn.MSELoss()(y_fp8, target)
loss = nn.MSELoss()(y, target)
loss_fp8.backward()
loss.backward()
optimizer_fp8.step()
optimizer.step()
torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)
print(
f"✅ Successfully validated FSDP {NUM_STEPS} training steps with"
f" {quantization} quantization"
)
def _test_zero_1(dp_group): def _test_zero_1(dp_group):
"""Make sure the implementation of zero-1 optimizer is correct""" """Make sure the implementation of zero-1 optimizer is correct"""
rank = dist.get_rank(dp_group) rank = dist.get_rank(dp_group)
...@@ -389,6 +660,7 @@ def main(argv=None, namespace=None): ...@@ -389,6 +660,7 @@ def main(argv=None, namespace=None):
dp_group = dist.new_group(backend="nccl") dp_group = dist.new_group(backend="nccl")
_test_zero_1(dp_group) _test_zero_1(dp_group)
_test_cast_master_weights_to_fp8(args.quantization, dp_group) _test_cast_master_weights_to_fp8(args.quantization, dp_group)
_test_fsdp_cast_master_weights_to_fp8(args.quantization, dp_group)
dist.destroy_process_group() dist.destroy_process_group()
return 0 return 0
......
...@@ -19,6 +19,7 @@ from transformer_engine.common.recipe import ( ...@@ -19,6 +19,7 @@ from transformer_engine.common.recipe import (
MXFP8BlockScaling, MXFP8BlockScaling,
DelayedScaling, DelayedScaling,
Float8CurrentScaling, Float8CurrentScaling,
Float8BlockScaling,
Format, Format,
Recipe, Recipe,
) )
...@@ -50,6 +51,8 @@ def quantization_recipe() -> Recipe: ...@@ -50,6 +51,8 @@ def quantization_recipe() -> Recipe:
return MXFP8BlockScaling() return MXFP8BlockScaling()
if QUANTIZATION == "fp8_cs": if QUANTIZATION == "fp8_cs":
return Float8CurrentScaling() return Float8CurrentScaling()
if QUANTIZATION == "fp8_block_scaling":
return Float8BlockScaling()
return te.fp8.get_default_fp8_recipe() return te.fp8.get_default_fp8_recipe()
...@@ -86,7 +89,7 @@ def main(argv=None, namespace=None): ...@@ -86,7 +89,7 @@ def main(argv=None, namespace=None):
# Quantization scheme # Quantization scheme
QUANTIZATION = args.quantization QUANTIZATION = args.quantization
if QUANTIZATION in ("fp8", "mxfp8"): if QUANTIZATION in ("fp8", "mxfp8", "fp8_block_scaling"):
global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE
SEQ_LEN = 32 SEQ_LEN = 32
BATCH_SIZE = 32 BATCH_SIZE = 32
...@@ -298,6 +301,11 @@ def _loss_backward(output_single_node, output_distributed): ...@@ -298,6 +301,11 @@ def _loss_backward(output_single_node, output_distributed):
LOSS_FN(output_distributed, target).backward() LOSS_FN(output_distributed, target).backward()
def _loss_backward_dw(model_single_node, model_distributed):
model_single_node.backward_dw()
model_distributed.backward_dw()
def _alloc_main_grad(model_single_node, model_distributed): def _alloc_main_grad(model_single_node, model_distributed):
for model in [model_single_node, model_distributed]: for model in [model_single_node, model_distributed]:
for param in model.parameters(): for param in model.parameters():
...@@ -471,6 +479,10 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs): ...@@ -471,6 +479,10 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
# Compute loss and backpropagate # Compute loss and backpropagate
_loss_backward(output_single_node, output_distributed) _loss_backward(output_single_node, output_distributed)
# Compute delayed weight gradient
if "delay_wgrad_compute" in kwargs:
_loss_backward_dw(model_single_node, model_distributed)
# Validate outputs and gradients # Validate outputs and gradients
_check_outputs(output_single_node, output_distributed) _check_outputs(output_single_node, output_distributed)
...@@ -492,6 +504,7 @@ def test_linear(): ...@@ -492,6 +504,7 @@ def test_linear():
{"fuse_wgrad_accumulation": True}, {"fuse_wgrad_accumulation": True},
{"return_bias": True}, {"return_bias": True},
{"params_dtype": torch.float16}, {"params_dtype": torch.float16},
{"delay_wgrad_compute": True},
] ]
for kwargs in kwargs_list: for kwargs in kwargs_list:
for parallel_mode in ["column", "row"]: for parallel_mode in ["column", "row"]:
...@@ -643,6 +656,10 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs ...@@ -643,6 +656,10 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
# Compute loss and backpropagate # Compute loss and backpropagate
_loss_backward(output_single_node, output_distributed) _loss_backward(output_single_node, output_distributed)
# Compute delayed weight gradient
if "delay_wgrad_compute" in kwargs:
_loss_backward_dw(model_single_node, model_distributed)
# Validate outputs and gradients # Validate outputs and gradients
_check_outputs(output_single_node, output_distributed) _check_outputs(output_single_node, output_distributed)
...@@ -665,6 +682,7 @@ def test_layernorm_linear(): ...@@ -665,6 +682,7 @@ def test_layernorm_linear():
{"params_dtype": torch.float16}, {"params_dtype": torch.float16},
{"zero_centered_gamma": False}, {"zero_centered_gamma": False},
{"return_layernorm_output": True}, {"return_layernorm_output": True},
{"delay_wgrad_compute": True},
] ]
for kwargs in kwargs_list: for kwargs in kwargs_list:
for parallel_mode in ["column"]: for parallel_mode in ["column"]:
...@@ -744,6 +762,9 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg ...@@ -744,6 +762,9 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg
# Compute loss and backpropagate # Compute loss and backpropagate
_loss_backward(output_single_node, output_distributed) _loss_backward(output_single_node, output_distributed)
if "delay_wgrad_compute" in kwargs:
_loss_backward_dw(model_single_node, model_distributed)
# Validate outputs and gradients # Validate outputs and gradients
_check_outputs(output_single_node, output_distributed) _check_outputs(output_single_node, output_distributed)
...@@ -769,6 +790,7 @@ def test_layernorm_mlp(): ...@@ -769,6 +790,7 @@ def test_layernorm_mlp():
{"fuse_wgrad_accumulation": True}, {"fuse_wgrad_accumulation": True},
{"return_bias": True}, {"return_bias": True},
{"return_layernorm_output": True}, {"return_layernorm_output": True},
{"delay_wgrad_compute": True},
] ]
for kwargs in kwargs_list: for kwargs in kwargs_list:
......
...@@ -28,6 +28,9 @@ if torch.cuda.device_count() < 2: ...@@ -28,6 +28,9 @@ if torch.cuda.device_count() < 2:
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
TEST_ROOT = Path(__file__).parent.resolve() TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(4, torch.cuda.device_count()) NUM_PROCS: int = min(4, torch.cuda.device_count())
...@@ -48,7 +51,7 @@ def _run_test(quantization): ...@@ -48,7 +51,7 @@ def _run_test(quantization):
all_boolean = [True, False] all_boolean = [True, False]
@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs"]) @pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"])
def test_distributed(quantization): def test_distributed(quantization):
if quantization == "fp8" and not fp8_available: if quantization == "fp8" and not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
...@@ -56,4 +59,6 @@ def test_distributed(quantization): ...@@ -56,4 +59,6 @@ def test_distributed(quantization):
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if quantization == "mxfp8" and not mxfp8_available: if quantization == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8) pytest.skip(reason_for_no_mxfp8)
if quantization == "fp8_block_scaling" and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
_run_test(quantization) _run_test(quantization)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Tuple
import torch
import triton
import triton.language as tl
@triton.jit
def fused_fma_kernel(y_ptr, x_ptr, s_ptr, M, N, y_str0, y_str1, BLOCK: tl.constexpr = 128):
pid = tl.program_id(0)
idx = pid * BLOCK + tl.arange(0, BLOCK)
mask = idx < M * N
row = idx // N
col = idx % N
y_offset = row * y_str0 + col * y_str1
x_offset = row * N + col
s_offset = row * N + col
y = tl.load(y_ptr + y_offset, mask=mask)
x = tl.load(x_ptr + x_offset, mask=mask)
s = tl.load(s_ptr + s_offset, mask=mask)
tl.store(y_ptr + y_offset, tl.fma(x, s, y), mask=mask)
def fused_fma(y, x, s, BLOCK=128):
"""
Fused multiply-add operation (y = y + x * s).
PyTorch does not provide a direct FMA equivalent (torch.addcmul is not bitwise equivalent to this operation).
This function also supports cases where 'y' is non-contiguous in memory.
"""
assert (
y.shape == x.shape == s.shape and y.dim() == 2
), "All tensors must be 2D with the same shape"
assert x.is_contiguous() and s.is_contiguous(), "x and s must be contiguous"
M, N = y.shape
grid = ((M * N + BLOCK - 1) // BLOCK,)
fused_fma_kernel[grid](y, x, s, M, N, *y.stride(), BLOCK)
return y
class CuBLASRefBlockwiseGemm:
"""
A cuBLAS compatible reference implementation of subchannel GEMM.
"""
def qgemm(
self,
qx: torch.Tensor,
qw: torch.Tensor,
out_dtype: torch.dtype,
demunged_sx: torch.Tensor,
demunged_sw: torch.Tensor,
quant_tile_shape_x: Tuple[int, int],
quant_tile_shape_w: Tuple[int, int],
bias: torch.Tensor | None = None,
out: torch.Tensor | None = None,
accumulate: bool = False,
use_split_accumulator: bool = False,
) -> torch.Tensor:
# demunge scale shapes for cuBLAS
is_a_1d_scaled = quant_tile_shape_x[0] == 1
is_b_1d_scaled = quant_tile_shape_w[0] == 1
M, K = qx.shape
N, K = qw.shape
# mm_tile_shape = (tile_m, tile_n, tile_k)
mm_tile_shape = (
quant_tile_shape_x[0],
quant_tile_shape_w[0],
quant_tile_shape_w[1],
)
if bias is not None and bias.numel():
# To match cuBLAS more closely when bias is applied,
# the reference accumulates into float32, and cast to
# bfloat16 is deferred until after the GEMM.
out_dtype_for_ref = torch.float32
else:
out_dtype_for_ref = out_dtype
y = self.qgemm_blockwise_2d(
qx,
qw,
out_dtype_for_ref,
demunged_sx,
demunged_sw,
mm_tile_shape,
use_split_accumulator,
is_a_1d_scaled,
is_b_1d_scaled,
)
if bias is not None and bias.numel():
y += bias
y = y.to(dtype=out_dtype)
# cublas accumulation first convert to output dtype, then accumulate.
if accumulate:
assert out is not None
y = y + out
else:
assert out is None, "Output tensor should be None when accumulate is False."
return y
@classmethod
def qgemm_blockwise_2d(
cls,
qx: torch.Tensor,
qw: torch.Tensor,
out_dtype: torch.dtype,
sx: torch.Tensor,
sw: torch.Tensor,
mm_tile_shape: Tuple[int, int, int],
use_split_accumulator: bool,
is_a_1d_scaled: bool,
is_b_1d_scaled: bool,
) -> torch.Tensor:
"""
Difference between cuBLAS and CUTLASS GEMM implementations:
- cuBLAS accumulation equation: use different equation for each scaling mode.
- For accumulation C in epiloge, it first convert C to output dtype, then accumulate.
"""
M, K = qx.shape
N, K_w = qw.shape
assert K == K_w, "K dimension mismatch between qx and qw"
tile_len = 128
# Calculate grid sizes without padding
grid_m = (M + tile_len - 1) // tile_len
grid_n = (N + tile_len - 1) // tile_len
grid_k = (K + tile_len - 1) // tile_len
block_m, block_n, block_k = mm_tile_shape
scale_m_per_tile = tile_len // block_m
scale_n_per_tile = tile_len // block_n
assert block_k == tile_len, "block_k must be equal to tile_len"
# Notes on making the reference implementation numerically equivalent to Cast Blockwise FP8 GEMM:
# 1) When using split_accumulate in FP8 GEMM, every 4 QMMA partial accumulation results are accumulated into float32 registers.
# 2) Partial accumulation results are accumulated using FMA (Fused Multiply-Add) instructions to apply scaling factors, as in: y += partial_y * scale
y = torch.zeros(M, N, dtype=torch.float32, device=qx.device)
# Validate shapes of sx and sw
scale_m_per_tensor = (M + block_m - 1) // block_m
scale_n_per_tensor = (N + block_n - 1) // block_n
assert sx.shape == (
scale_m_per_tensor,
grid_k,
), f"sx shape mismatch: expected ({scale_m_per_tensor}, {grid_k}), got {sx.shape}"
assert sw.shape == (
scale_n_per_tensor,
grid_k,
), f"sw shape mismatch: expected ({scale_n_per_tensor}, {grid_k}), got {sw.shape}"
for i in range(grid_m):
m_start = i * tile_len
m_end = min(m_start + tile_len, M)
m_size = m_end - m_start
for j in range(grid_n):
n_start = j * tile_len
n_end = min(n_start + tile_len, N)
n_size = n_end - n_start
y_block = y[m_start:m_end, n_start:n_end]
for k in range(grid_k):
k_start = k * tile_len
k_end = min(k_start + tile_len, K)
k_size = k_end - k_start
qx_block = (
qx[m_start:m_end, k_start:k_end].clone().contiguous()
) # Shape: [m_size, k_size]
qw_block = (
qw[n_start:n_end, k_start:k_end].clone().contiguous()
) # Shape: [n_size, k_size]
# Extract scaling factors for the current blocks
sx_block = sx[i * scale_m_per_tile : (i + 1) * scale_m_per_tile, k].unsqueeze(
-1
)
sw_block = sw[j * scale_n_per_tile : (j + 1) * scale_n_per_tile, k].unsqueeze(0)
# Perform qgemm with scaling factors fused in the GEMM
# Accumulate should be in float32 format, which aligns with the split_accumulate in FP8 GEMM
one = torch.tensor(1.0, dtype=torch.float32, device=qx.device)
y_partial = torch._scaled_mm(
qx_block,
qw_block.t(),
scale_a=one,
scale_b=one,
out_dtype=torch.float32,
use_fast_accum=not use_split_accumulator,
)
# Accumulate the partial result
if is_a_1d_scaled and is_b_1d_scaled:
# 1Dx1D
# CuBLAS accumulation equation: y += (y * scale_a) * scale_b
y_partial = y_partial * sx_block
# Fuse multiplication and addition to align with the split_accumulate in FP8 GEMM
# y_block.add_(y_partial, alpha=scale.item())
fused_fma(
y_block,
y_partial,
sw_block.expand_as(y_partial).contiguous(),
)
elif not is_a_1d_scaled and is_b_1d_scaled:
# 2Dx1D
# CuBLAS accumulation equation: y += (y * scale_b) * scale_a
y_partial = y_partial * sw_block
fused_fma(
y_block,
y_partial,
sx_block.expand_as(y_partial).contiguous(),
)
elif is_a_1d_scaled and not is_b_1d_scaled:
# 1Dx2D
# CuBLAS accumulation equation: y += (y * scale_a) * scale_b
y_partial = y_partial * sx_block
fused_fma(
y_block,
y_partial,
sw_block.expand_as(y_partial).contiguous(),
)
else:
scale = sx_block * sw_block
fused_fma(y_block, y_partial, scale.expand_as(y_partial).contiguous())
y = y.to(out_dtype)
return y
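# The fused_fma helper used above is not shown in this excerpt; the sketch
# below is an assumption about its behavior, not the actual implementation.
# Per the split-accumulate notes above (y += partial_y * scale), it can be
# modeled as an in-place float32 fused multiply-add:
def _fused_fma_sketch(acc: torch.Tensor, partial: torch.Tensor, scale: torch.Tensor) -> None:
    # acc += partial * scale, computed elementwise in float32 in one fused op
    acc.addcmul_(partial, scale)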
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import dataclasses
import math
import torch
from typing import Optional, Protocol, Tuple
from references.quantize_scale_calc import scale_from_amax_tensor
@dataclasses.dataclass()
class QuantizeResult:
data: torch.Tensor
scale: torch.Tensor
data_t: Optional[torch.Tensor]
scale_t: Optional[torch.Tensor]
@dataclasses.dataclass()
class CuBLASScaleMunger:
def munge_scale_shapes_for_backend(
self,
unmunged: QuantizeResult,
tile_shape: Tuple[int, int],
) -> QuantizeResult:
"""
cuBLAS GEMMs require 1x128 quantized tensors to have their scales transposed,
so that for an (M, N) tensor the scales have shape (RoundUpDiv(N, 128), RoundUp(M, 4)).
For 128x128 quantized tensors, the GEMM expects scales of shape
(RoundUpDiv(M, 128), PadToAlign(RoundUpDiv(N, 128), 4)). If RoundUpDiv(N, 128)
is not divisible by 4, a transformation is required.
"""
def _pad_inner_to_align(s: torch.Tensor, transpose: bool) -> torch.Tensor:
if transpose:
s = s.transpose(-1, -2).contiguous()
M, K = s.shape
if K % 4 == 0:
return s
k_pad = 4 - (K % 4)
return torch.nn.functional.pad(s, (0, k_pad), mode="constant", value=0).contiguous()
s = _pad_inner_to_align(unmunged.scale, transpose=tile_shape[0] == 1)
if unmunged.scale_t is None:
s_t = None
else:
s_t = _pad_inner_to_align(unmunged.scale_t, transpose=tile_shape[0] == 1)
return QuantizeResult(unmunged.data, s, unmunged.data_t, s_t)
@classmethod
def demunge_scale_shape_from_backend(
cls,
qtensor_shape: Tuple[int, int],
scales: torch.Tensor,
tile_shape: Tuple[int, int],
) -> torch.Tensor:
"""
Inverse operation of munge_scale_shapes_for_backend
"""
if tile_shape[0] != 1:
# 2D block quantized tensor may need padding stripped off
derived_scale_k_shape = math.ceil(qtensor_shape[1] / tile_shape[1])
else:
derived_scale_k_shape = qtensor_shape[0]
M, K = scales.shape
if derived_scale_k_shape != K:
scales = scales[:, :derived_scale_k_shape].contiguous()
if tile_shape[0] == 1:
return scales.transpose(-1, -2).contiguous()
else:
return scales
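# Illustrative round trip for the 1x128 path (a hypothetical usage sketch,
# not part of the reference API): row-wise scales of shape (6, 2) for a
# (6, 256) tensor are munged to (RoundUpDiv(256, 128), RoundUp(6, 4)) == (2, 8)
# and demunged back to (6, 2).
#
#   munger = CuBLASScaleMunger()
#   qr = QuantizeResult(
#       data=torch.zeros(6, 256, dtype=torch.float8_e4m3fn),
#       scale=torch.ones(6, 2), data_t=None, scale_t=None,
#   )
#   munged = munger.munge_scale_shapes_for_backend(qr, (1, 128))
#   assert munged.scale.shape == (2, 8)
#   restored = CuBLASScaleMunger.demunge_scale_shape_from_backend(
#       (6, 256), munged.scale, (1, 128)
#   )
#   assert restored.shape == (6, 2)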
@dataclasses.dataclass()
class BlockwiseQuantizerReference:
"""
A reference QuantizeOp for subchannel/block hybrid quantization.
Defers to reference GEMMs and quantization formatting based on the backend.
"""
def __init__(self) -> None:
self.scale_munger = CuBLASScaleMunger()
@classmethod
def _quantize_square_block_tiling(
cls,
x: torch.Tensor,
quant_dtype: torch.dtype,
tile_len: int,
*,
return_transpose: bool,
pow_2_scales: bool,
eps: float,
) -> QuantizeResult:
M, K = x.shape
pad_m_k = [0, 0]
if K % tile_len != 0:
pad_m_k[1] = tile_len - (K % tile_len)
if M % tile_len != 0:
pad_m_k[0] = tile_len - (M % tile_len)
unpadded_m, unpadded_k = M, K
if pad_m_k[0] != 0 or pad_m_k[1] != 0:
x = torch.nn.functional.pad(
x, (0, pad_m_k[1], 0, pad_m_k[0]), mode="constant", value=0
).contiguous()
M, K = x.shape
x_tiled = x.reshape(M // tile_len, tile_len, K // tile_len, tile_len)
amax_grid = (
torch.abs(x_tiled.transpose(-3, -2))
.reshape(M // tile_len, K // tile_len, tile_len**2)
.amax(dim=-1)
).float()
dtype_max = torch.finfo(quant_dtype).max
scale, scale_inv, _ = scale_from_amax_tensor(
x_dtype=x.dtype,
amax=amax_grid,
quant_dtype=quant_dtype,
pow_2_scales=pow_2_scales,
eps=eps,
)
qx = x_tiled * scale.reshape(M // tile_len, 1, K // tile_len, 1)
qx = torch.clamp(qx, min=-dtype_max, max=dtype_max)
qx = qx.to(dtype=quant_dtype)
qx = qx.reshape(M, K)
if unpadded_k != K or unpadded_m != M:
qx = qx[:unpadded_m, :unpadded_k].contiguous()
if return_transpose:
# Valid because of square block sizes
qx_t = qx.transpose(-1, -2).contiguous()
scale_inv_t = scale_inv.transpose(-1, -2).contiguous()
else:
qx_t = None
scale_inv_t = None
return QuantizeResult(data=qx, scale=scale_inv, data_t=qx_t, scale_t=scale_inv_t)
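# Shape sketch for the square-block path above (illustrative numbers):
# a (300, 200) input with tile_len=128 is zero-padded to (384, 256),
# giving a (3, 2) scale_inv grid; the quantized data is cropped back to
# (300, 200) before returning, while the scales keep their tiled shape.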
@classmethod
def _quantize_vectorwise_reference(
cls,
x: torch.Tensor,
quant_dtype: torch.dtype,
tile_len: int,
*,
pow_2_scales: bool,
eps: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
M, K = x.shape
dtype_max = torch.finfo(quant_dtype).max
x_tiled = x.reshape(M, K // tile_len, tile_len)
amax_grid = torch.abs(x_tiled).amax(dim=-1).float()
scale, scale_inv, _ = scale_from_amax_tensor(
x_dtype=x.dtype,
amax=amax_grid,
quant_dtype=quant_dtype,
pow_2_scales=pow_2_scales,
eps=eps,
)
qx = x_tiled * scale.reshape(M, K // tile_len, 1)
qx = torch.clamp(qx, min=-dtype_max, max=dtype_max)
qx = qx.to(dtype=quant_dtype)
qx = qx.reshape(M, K)
return qx, scale_inv
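# Shape sketch for the row-wise helper above: a (4, 256) input with
# tile_len=128 is viewed as (4, 2, 128); the amax over the last dim gives
# a (4, 2) grid, so scale_inv is (4, 2) and qx is returned as (4, 256).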
@classmethod
def _quantize_vector_tiling(
cls,
x: torch.Tensor,
quant_dtype: torch.dtype,
tile_len: int,
*,
return_transpose: bool,
pow_2_scales: bool,
eps: float,
) -> QuantizeResult:
M, K = x.shape
if K % tile_len == 0:
qref_input = x
else:
pad_amount = tile_len - (K % tile_len)
pad = (0, pad_amount)
qref_input = torch.nn.functional.pad(x, pad, mode="constant", value=0)
qout_padded, scale_inv = cls._quantize_vectorwise_reference(
qref_input,
quant_dtype,
tile_len=tile_len,
pow_2_scales=pow_2_scales,
eps=eps,
)
if K % tile_len == 0:
qout = qout_padded
else:
qout = qout_padded[:, :K].contiguous()
if return_transpose:
if M % tile_len == 0:
qref_input = x.transpose(-1, -2).contiguous()
else:
amount_to_pad = tile_len - (M % tile_len)
pad = (0, amount_to_pad)
qref_input = torch.nn.functional.pad(
x.transpose(-1, -2), pad, mode="constant", value=0
).contiguous()
qout_t_padded, scale_inv_t = cls._quantize_vectorwise_reference(
qref_input,
quant_dtype,
tile_len=tile_len,
pow_2_scales=pow_2_scales,
eps=eps,
)
if M % tile_len == 0:
qout_t = qout_t_padded
else:
qout_t = qout_t_padded[:, :M].contiguous()
else:
qout_t, scale_inv_t = None, None
return QuantizeResult(data=qout, scale=scale_inv, data_t=qout_t, scale_t=scale_inv_t)
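# Note: unlike the square-block path, the transpose here is re-quantized
# from the high-precision input with fresh 1x128 tiles along the other
# dimension, so data_t is generally not a bitwise transpose of data and
# scale_t is computed independently of scale.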
def ref_dequantize_rowwise(
self,
q: torch.Tensor,
quant_tile_shape: Tuple[int, int],
s: torch.Tensor,
dtype: torch.dtype,
) -> torch.Tensor:
assert q.dim() == 2
q_M, q_K = q.shape
s = self.scale_munger.demunge_scale_shape_from_backend((q_M, q_K), s, quant_tile_shape)
assert len(s.shape) == 2
m_tiles, k_tiles = s.shape
M, K = q.shape
unpadded_m, unpadded_k = M, K
if M % quant_tile_shape[0] != 0 or K % quant_tile_shape[1] != 0:
m_pad_amount = (quant_tile_shape[0] - (M % quant_tile_shape[0])) % quant_tile_shape[0]
k_pad_amount = (quant_tile_shape[1] - (K % quant_tile_shape[1])) % quant_tile_shape[1]
q = torch.nn.functional.pad(
q, (0, k_pad_amount, 0, m_pad_amount), mode="constant", value=0
).contiguous()
M, K = q.shape
q_tiled = q.reshape(m_tiles, quant_tile_shape[0], k_tiles, quant_tile_shape[1])
result = q_tiled.to(dtype) * s.reshape(m_tiles, 1, k_tiles, 1)
result = result.view(M, K).to(dtype)
if M != unpadded_m or K != unpadded_k:
result = result[:unpadded_m, :unpadded_k].contiguous()
return result
def quantize(
self,
x: torch.Tensor,
quant_dtype: torch.dtype,
return_transpose: bool = False,
eps: float = 0.0,
pow_2_scales: bool = False,
quant_tile_shape: Tuple[int, int] = (128, 128),
) -> QuantizeResult:
# sanity checks
assert x.dim() == 2
assert x.dtype in (
torch.float32,
torch.float16,
torch.bfloat16,
), "Unsupported input dtype."
assert quant_dtype in (
torch.float8_e4m3fn,
torch.float8_e5m2,
), "Unsupported quant dtype."
assert quant_tile_shape in ((1, 128), (128, 128))
if quant_tile_shape[0] == 1:
# Quantize row-wise
return self.scale_munger.munge_scale_shapes_for_backend(
self._quantize_vector_tiling(
x,
quant_dtype,
tile_len=quant_tile_shape[1],
return_transpose=return_transpose,
pow_2_scales=pow_2_scales,
eps=eps,
),
quant_tile_shape,
)
else:
# Quantize block-wise
return self.scale_munger.munge_scale_shapes_for_backend(
self._quantize_square_block_tiling(
x,
quant_dtype,
tile_len=quant_tile_shape[0],
return_transpose=return_transpose,
pow_2_scales=pow_2_scales,
eps=eps,
),
quant_tile_shape,
)
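# Hypothetical end-to-end usage of the reference quantizer (a sketch, not
# part of the test suite): quantize with 128x128 tiles, then dequantize
# row-wise and inspect the round-trip error.
#
#   quantizer = BlockwiseQuantizerReference()
#   x = torch.randn(256, 512, dtype=torch.bfloat16)
#   qr = quantizer.quantize(x, torch.float8_e4m3fn, quant_tile_shape=(128, 128))
#   x_dq = quantizer.ref_dequantize_rowwise(
#       qr.data, (128, 128), qr.scale, torch.bfloat16
#   )
#   print((x_dq.float() - x.float()).abs().max())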