Commit ab3e5a92 authored by yuguo's avatar yuguo

Merge commit '04c730c0' of...

Merge commit '04c730c0' of https://github.com/NVIDIA/TransformerEngine
parents a8d19fd9 04c730c0
......@@ -11,6 +11,7 @@ list(APPEND test_cuda_sources
test_cast_mxfp8_gated_swiglu.cu
test_qdq.cu
test_cast_mxfp8.cu
# test_cast_float8blockwise.cu
test_dequantize_mxfp8.cu
test_transpose.cu
test_cast_transpose.cu
......
......@@ -19,169 +19,16 @@
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"
#include "test_normalization.h"
using namespace transformer_engine;
using namespace test;
namespace {
enum NormType {
LayerNorm,
RMSNorm
};
std::map<NormType, std::string> normToString = {
{NormType::LayerNorm, "LayerNorm"},
{NormType::RMSNorm, "RmsNorm"}
};
template <typename InputType>
void compute_ref_stats(NormType norm_type,
const InputType *data, float *mu, float *rsigma,
const size_t N, const size_t H, const double epsilon){
using compute_t = float;
compute_t current, m;
for (size_t i = 0; i < N; ++i) {
compute_t sum = 0;
for (size_t j = 0; j < H; ++j) {
sum += static_cast<compute_t>(data[i * H + j]);
}
if (norm_type == LayerNorm){
mu[i] = sum / H;
m = mu[i];
} else { m = 0;}
compute_t sum_sq = 0;
for (size_t j = 0; j < H; ++j) {
current = static_cast<compute_t>(data[i * H + j]);
sum_sq += (current - m) * (current - m);
}
#ifdef __HIP_PLATFORM_AMD__
rsigma[i] = 1.0/sqrtf((sum_sq / H) + epsilon);
#else
rsigma[i] = rsqrtf((sum_sq / H) + epsilon);
#endif
}
}
// For now, cuDNN computes static_cast<compute_t>(gamma + static_cast<input_t>(1.0)).
// This will be changed in a future release.
template <typename InputType>
inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn){
using compute_t = float;
if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3>){
compute_t g = static_cast<compute_t>(gamma);
if (zero_centered_gamma) {
g += static_cast<compute_t>(1.f);
}
return g;
} else {
if (use_cudnn){
compute_t g = static_cast<compute_t>(0.f);
#ifndef __HIP_PLATFORM_AMD__
InputType gi = gamma;
if (zero_centered_gamma) {
gi = gi + static_cast<InputType>(1.f);
}
g = static_cast<compute_t>(gi);
#else
if (zero_centered_gamma) {
g += static_cast<compute_t>(1.f);
}
#endif
return g;
} else {
compute_t g = static_cast<compute_t>(gamma);
if (zero_centered_gamma) {
g += static_cast<compute_t>(1.f);
}
return g;
}
}
}
template <typename InputType, typename OutputType>
void compute_ref_output(NormType norm_type,
const InputType *data, const InputType *gamma, const InputType *beta,
OutputType* output,
const float *mu, const float *rsigma,
const size_t N, const size_t H,
float *amax, float scale, const bool zero_centered_gamma, const bool use_cudnn) {
using compute_t = float;
compute_t current_max = -1e100;
for (size_t i = 0; i < N; ++i) {
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn);
compute_t tmp;
if (norm_type == LayerNorm) {
tmp = (current - mu[i]) * rsigma[i] * g + static_cast<compute_t>(beta[j]);
} else { // RMSNorm
tmp = current * rsigma[i] * g;
}
output[i * H + j] = static_cast<OutputType>(tmp * scale);
current_max = fmaxf(current_max, fabsf(tmp));
}
}
*amax = current_max;
}
template <typename InputType, typename OutputType>
void compute_ref_backward(const NormType norm_type, const OutputType *output_grad, const InputType *data,
const float *mu, const float *rsigma,
const InputType *gamma,
InputType *data_grad,
InputType *gamma_grad, InputType *beta_grad,
const size_t N, const size_t H,
const bool zero_centered_gamma, const bool use_cudnn) {
using compute_t = float;
std::vector<compute_t> dgamma(H, 0.f);
std::vector<compute_t> dbeta(H, 0.f);
for (size_t i = 0 ; i < N; ++i) {
// Reductions
auto local_mu = (norm_type == LayerNorm) ? mu[i] : 0.;
compute_t mdy = 0, mdyy = 0;
for (size_t j = 0; j < H; ++j) {
const compute_t x = static_cast<compute_t>(data[i * H + j]);
const compute_t y = (x - local_mu) * rsigma[i];
compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn);
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz;
dgamma[j] += y * dz;
if (norm_type == LayerNorm) {
dbeta[j] += dz;
mdy += dy;
}
mdyy += dy * y;
}
mdy /= H;
mdyy /= H;
// Input grads
for (size_t j = 0; j < H; ++j) {
const compute_t x = static_cast<compute_t>(data[i * H + j]);
const compute_t y = (x - local_mu) * rsigma[i];
compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn);
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz;
const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy);
data_grad[i * H + j] = static_cast<InputType>(dx);
}
}
// Weight grads
for (size_t j = 0; j < H; ++j) gamma_grad[j] = static_cast<InputType>(dgamma[j]);
if (norm_type == LayerNorm) for (size_t j = 0; j < H; ++j) beta_grad[j] = static_cast<InputType>(dbeta[j]);
}
template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
NormType norm_type, bool use_cudnn) {
NormType norm_type, bool use_cudnn, const bool zero_centered_gamma_in_weight_dtype) {
if (sizeof(InputType) < sizeof(OutputType)) {
GTEST_SKIP() << "LN kernel does not support OutputType > InputType";
return;
......@@ -230,9 +77,22 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
if ((!use_cudnn || !zero_centered_gamma) && zero_centered_gamma_in_weight_dtype) {
// Skip duplicate tests: zero_centered_gamma_in_weight_dtype only affects the cuDNN backend with zero-centered gamma.
GTEST_SKIP() << "Zero-centered gamma in weight dtype is only supported with cuDNN backend";
}
if (use_cudnn){
nvte_enable_cudnn_norm_fwd(true);
nvte_enable_cudnn_norm_bwd(true);
// Zero-centered gamma in weight dtype is currently only supported by the cuDNN backend.
nvte_enable_zero_centered_gamma_in_weight_dtype(zero_centered_gamma_in_weight_dtype);
}
// Forward kernel
......@@ -280,6 +140,11 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
if (use_cudnn){
nvte_enable_cudnn_norm_fwd(false);
nvte_enable_cudnn_norm_bwd(false);
// Zero-centered gamma in weight dtype is currently only supported by the cuDNN backend.
if (zero_centered_gamma_in_weight_dtype) {
nvte_enable_zero_centered_gamma_in_weight_dtype(false);
}
}
// Reference implementations
......@@ -300,14 +165,16 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
&ref_amax,
ref_scale,
zero_centered_gamma,
use_cudnn);
use_cudnn,
zero_centered_gamma_in_weight_dtype);
compute_ref_backward(norm_type, dz.rowwise_cpu_dptr<WeightType>(),
input.rowwise_cpu_dptr<InputType>(),
mu.rowwise_cpu_dptr<float>(), rsigma.rowwise_cpu_dptr<float>(),
gamma.rowwise_cpu_dptr<WeightType>(),
ref_dx.get(), ref_dgamma.get(), ref_dbeta.get(),
N, H, zero_centered_gamma,
use_cudnn);
use_cudnn,
zero_centered_gamma_in_weight_dtype);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
......@@ -352,6 +219,7 @@ NormType,
transformer_engine::DType,
transformer_engine::DType,
std::pair<size_t, size_t>,
bool,
bool>> {};
TEST_P(NormTestSuite, TestNorm) {
......@@ -364,10 +232,11 @@ TEST_P(NormTestSuite, TestNorm) {
const DType output_type = std::get<3>(GetParam());
const auto size = std::get<4>(GetParam());
const bool zero_centered_gamma = std::get<5>(GetParam());
const bool cudnn_zero_centered_gamma_in_weight_dtype = std::get<6>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma, norm_type, use_cudnn);
performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma, norm_type, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype);
);
);
}
......@@ -381,6 +250,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3),
::testing::ValuesIn(test_cases),
::testing::Values(false, true),
::testing::Values(false, true)),
[](const testing::TestParamInfo<NormTestSuite::ParamType>& info) {
auto backend = std::get<0>(info.param) == false ? "Te" : "Cudnn";
......@@ -391,6 +261,7 @@ INSTANTIATE_TEST_SUITE_P(
test::typeName(std::get<3>(info.param)) + "X" +
std::to_string(std::get<4>(info.param).first) + "X" +
std::to_string(std::get<4>(info.param).second) + "X" +
std::to_string(std::get<5>(info.param));
std::to_string(std::get<5>(info.param)) + "X" +
std::to_string(std::get<6>(info.param));
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#pragma once
#include <cmath>
#include <cstring>
#include <memory>
#include <iomanip>
#include <iostream>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"
namespace test {
namespace {
enum NormType {
LayerNorm,
RMSNorm
};
std::map<NormType, std::string> normToString = {
{NormType::LayerNorm, "LayerNorm"},
{NormType::RMSNorm, "RmsNorm"}
};
template <typename InputType>
void compute_ref_stats(NormType norm_type,
const InputType *data, float *mu, float *rsigma,
const size_t N, const size_t H, const double epsilon){
using compute_t = float;
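// Per-row reference statistics: for LayerNorm, mu[i] = mean(x_i) and
// rsigma[i] = 1 / sqrt(var(x_i) + epsilon); for RMSNorm the mean is treated as zero,
// so rsigma[i] = 1 / sqrt(mean(x_i^2) + epsilon).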
compute_t current, m;
for (size_t i = 0; i < N; ++i) {
compute_t sum = 0;
for (size_t j = 0; j < H; ++j) {
sum += static_cast<compute_t>(data[i * H + j]);
}
if (norm_type == LayerNorm){
mu[i] = sum / H;
m = mu[i];
} else { m = 0;}
compute_t sum_sq = 0;
for (size_t j = 0; j < H; ++j) {
current = static_cast<compute_t>(data[i * H + j]);
sum_sq += (current - m) * (current - m);
}
#ifdef __HIP_PLATFORM_AMD__
rsigma[i] = 1.0/sqrtf((sum_sq / H) + epsilon);
#else
rsigma[i] = rsqrtf((sum_sq / H) + epsilon);
#endif
}
}
template <typename InputType>
inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype) {
using compute_t = float;
// Zero-centered gamma in weight dtype is only supported in CuDNN backend currently
// Remove the use_cudnn check here when it is supported by both backends.
const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype;
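// With zero-centered gamma, the stored weight holds (gamma - 1), so the reference adds 1 back.
// In weight-dtype mode (non-HIP builds), the +1 is applied in InputType precision before the
// cast to float, mirroring the cuDNN backend; otherwise (and for FP8 gamma) the weight is cast
// to float first.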
if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3>){
compute_t g = static_cast<compute_t>(gamma);
if (zero_centered_gamma) {
g += static_cast<compute_t>(1.f);
}
return g;
} else {
if (zero_centered_gamma_in_weight_dtype){
compute_t g = static_cast<compute_t>(0.f);
#ifndef __HIP_PLATFORM_AMD__
InputType gi = gamma;
if (zero_centered_gamma) {
gi = gi + static_cast<InputType>(1.f);
}
g = static_cast<compute_t>(gi);
#else
if (zero_centered_gamma) {
g += static_cast<compute_t>(1.f);
}
#endif
return g;
} else {
compute_t g = static_cast<compute_t>(gamma);
if (zero_centered_gamma) {
g += static_cast<compute_t>(1.f);
}
return g;
}
}
}
template <typename InputType, typename OutputType>
void compute_ref_output(NormType norm_type,
const InputType *data, const InputType *gamma, const InputType *beta,
OutputType* output,
const float *mu, const float *rsigma,
const size_t N, const size_t H,
float *amax, float scale, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype) {
using compute_t = float;
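// Reference forward pass: out = ((x - mu[i]) * rsigma[i] * gamma' + beta) for LayerNorm and
// x * rsigma[i] * gamma' for RMSNorm, quantized by `scale`; amax tracks the maximum |value|
// seen before scaling.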
compute_t current_max = -1e100;
for (size_t i = 0; i < N; ++i) {
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype);
compute_t tmp;
if (norm_type == LayerNorm) {
tmp = (current - mu[i]) * rsigma[i] * g + static_cast<compute_t>(beta[j]);
} else { // RMSNorm
tmp = current * rsigma[i] * g;
}
output[i * H + j] = static_cast<OutputType>(tmp * scale);
current_max = fmaxf(current_max, fabsf(tmp));
}
}
if (amax) {
*amax = current_max;
}
}
template <typename InputType, typename OutputType>
void compute_ref_backward(const NormType norm_type, const OutputType *output_grad, const InputType *data,
const float *mu, const float *rsigma,
const InputType *gamma,
InputType *data_grad,
InputType *gamma_grad, InputType *beta_grad,
const size_t N, const size_t H,
const bool zero_centered_gamma, const bool use_cudnn,
const bool cudnn_zero_centered_gamma_in_weight_dtype) {
using compute_t = float;
std::vector<compute_t> dgamma(H, 0.f);
std::vector<compute_t> dbeta(H, 0.f);
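// Reference backward pass (per row): with y = (x - mu) * rsigma and dy = gamma' * dz,
// dx = rsigma * (dy - mean(dy * y) * y - mean(dy)); the mean(dy) term and dbeta apply only to
// LayerNorm, while dgamma accumulates y * dz over all rows.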
for (size_t i = 0 ; i < N; ++i) {
// Reductions
auto local_mu = (norm_type == LayerNorm) ? mu[i] : 0.;
compute_t mdy = 0, mdyy = 0;
for (size_t j = 0; j < H; ++j) {
const compute_t x = static_cast<compute_t>(data[i * H + j]);
const compute_t y = (x - local_mu) * rsigma[i];
compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype);
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz;
dgamma[j] += y * dz;
if (norm_type == LayerNorm) {
dbeta[j] += dz;
mdy += dy;
}
mdyy += dy * y;
}
mdy /= H;
mdyy /= H;
// Input grads
for (size_t j = 0; j < H; ++j) {
const compute_t x = static_cast<compute_t>(data[i * H + j]);
const compute_t y = (x - local_mu) * rsigma[i];
compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype);
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz;
const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy);
data_grad[i * H + j] = static_cast<InputType>(dx);
}
}
// Weight grads
for (size_t j = 0; j < H; ++j) gamma_grad[j] = static_cast<InputType>(dgamma[j]);
if (norm_type == LayerNorm) for (size_t j = 0; j < H; ++j) beta_grad[j] = static_cast<InputType>(dbeta[j]);
}
} // namespace
} // namespace test
......@@ -19,6 +19,7 @@
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"
#include "test_normalization.h"
using namespace transformer_engine;
using namespace test;
......@@ -27,16 +28,6 @@ namespace {
using fp8e8m0 = byte;
enum NormType {
LayerNorm,
RMSNorm
};
std::map<NormType, std::string> normToString = {
{NormType::LayerNorm, "LayerNorm"},
{NormType::RMSNorm, "RMSNorm"}
};
template <typename InputType, typename ScaleType, typename OutputType>
void dequantize_1x_kernel(InputType* input_ptr, ScaleType* scale_ptr, OutputType* output_ptr,
size_t rows, size_t cols, size_t scaling_mode_x, size_t scaling_mode_y){
......@@ -110,69 +101,8 @@ void dequantize_2x(Tensor& input, Tensor& output, bool is_training)
32, 1);
}
template <typename InputType>
void compute_ref_stats(NormType norm_type,
const InputType *data, float *mu, float *rsigma,
const size_t N, const size_t H, const double epsilon){
using compute_t = float;
#pragma omp parallel for proc_bind(spread)
for (size_t i = 0; i < N; ++i) {
compute_t sum = 0;
for (size_t j = 0; j < H; ++j) {
sum += static_cast<compute_t>(data[i * H + j]);
}
compute_t m;
if (norm_type == LayerNorm){
mu[i] = sum / H;
m = mu[i];
} else { m = 0;}
compute_t sum_sq = 0;
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
sum_sq += (current - m) * (current - m);
}
#ifdef __HIP_PLATFORM_AMD__
rsigma[i] = 1.0/sqrtf((sum_sq / H) + epsilon);
#else
rsigma[i] = rsqrtf((sum_sq / H) + epsilon);
#endif
}
}
template <typename InputType, typename OutputType>
void compute_ref_output(NormType norm_type,
const InputType *data, const InputType *gamma, const InputType *beta,
const float *mu, const float *rsigma,
const size_t N, const size_t H,
OutputType* output,
const bool zero_centered_gamma){
using compute_t = float;
#pragma omp parallel for proc_bind(spread)
for (size_t i = 0; i < N; ++i) {
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
compute_t g = static_cast<compute_t>(gamma[j]);
if (zero_centered_gamma) {
g += 1.0;
}
compute_t tmp;
if (norm_type == LayerNorm) {
tmp = (current - mu[i]) * rsigma[i] * g + static_cast<compute_t>(beta[j]);
} else { // RMSNorm
tmp = current * rsigma[i] * g;
}
output[i * H + j] = tmp;
}
}
}
template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, NormType norm_type, bool is_training) {
void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, NormType norm_type, bool is_training, const bool zero_centered_gamma_in_weight_dtype) {
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
......@@ -199,6 +129,12 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
fillUniform(&gamma);
fillUniform(&beta);
nvte_enable_zero_centered_gamma_in_weight_dtype(zero_centered_gamma_in_weight_dtype);
// Forward kernel
float epsilon = 1e-5;
if (norm_type == NormType::LayerNorm){
......@@ -224,6 +160,10 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
0);
}
if (zero_centered_gamma_in_weight_dtype) {
nvte_enable_zero_centered_gamma_in_weight_dtype(false);
}
Tensor dequantized_output("dequantized_output", { N, H }, DType::kFloat32, true, true);
dequantize_2x<OutputType, fp8e8m0>(z, dequantized_output, is_training);
......@@ -250,11 +190,15 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
compute_ref_output(norm_type, input.rowwise_cpu_dptr<InputType>(),
gamma.rowwise_cpu_dptr<WeightType>(),
beta.rowwise_cpu_dptr<WeightType>(),
ref_output.get(),
ref_mu_ptr,
ref_rsigma_ptr,
N, H,
ref_output.get(),
zero_centered_gamma);
nullptr, // amax
1.f, // scale
zero_centered_gamma,
true, // CuDNN is the only MXFP8 backend currently
zero_centered_gamma_in_weight_dtype);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
......@@ -302,7 +246,7 @@ class MxNormTestSuite : public ::testing::TestWithParam< std::tuple<NormType,
transformer_engine::DType,
transformer_engine::DType,
std::pair<size_t, size_t>,
bool, bool>> {};
bool, bool, bool>> {};
TEST_P(MxNormTestSuite, TestMxNorm) {
using namespace transformer_engine;
......@@ -314,10 +258,11 @@ TEST_P(MxNormTestSuite, TestMxNorm) {
const auto size = std::get<3>(GetParam());
const bool zero_centered_gamma = std::get<4>(GetParam());
const bool is_training = std::get<5>(GetParam());
const bool zero_centered_gamma_in_weight_dtype = std::get<6>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma, norm_type, is_training);
performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma, norm_type, is_training, zero_centered_gamma_in_weight_dtype);
);
);
}
......@@ -331,6 +276,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(DType::kFloat8E5M2, DType::kFloat8E4M3),
::testing::ValuesIn(test_cases),
::testing::Values(true, false),
::testing::Values(true, false),
::testing::Values(true, false)),
[](const testing::TestParamInfo<MxNormTestSuite::ParamType>& info) {
std::string name = normToString.at(std::get<0>(info.param)) + "_" +
......@@ -339,6 +285,7 @@ INSTANTIATE_TEST_SUITE_P(
std::to_string(std::get<3>(info.param).first) + "X" +
std::to_string(std::get<3>(info.param).second) + "X" +
std::to_string(std::get<4>(info.param)) + "out" +
std::to_string(int(std::get<5>(info.param)) + 1) + "x";
std::to_string(int(std::get<5>(info.param)) + 1) + "x" +
std::to_string(std::get<6>(info.param));
return name;
});
......@@ -10,6 +10,7 @@
#include <algorithm>
#include <memory>
#include <random>
#include <iostream>
#include <cassert>
#include <cmath>
#include <string>
......@@ -111,8 +112,8 @@ struct scale_inv_meta {
size_t type_size;
};
NVTEShape convertShape(const std::vector<size_t>& shape) {
return {shape.data(), shape.size()};
NVTEShape convertShape(const std::vector<size_t>& s) {
return nvte_make_shape(s.data(), s.size());
}
std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
......@@ -134,27 +135,19 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
scale_inv_meta ret_rowwise, ret_colwise;
auto block_alignment = std::vector<size_t>{128ul,4ul};
auto block_alignment = std::vector<size_t>{128ul, 4ul};
{
auto alignment = block_alignment[0];
auto scale_dim_0 = DIVUP(DIVUP(first_dim,
static_cast<size_t>(1)),
alignment) * alignment;
auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(1)), alignment) * alignment;
alignment = block_alignment[1];
auto scale_dim_1 = DIVUP(DIVUP(last_dim,
static_cast<size_t>(32)),
alignment) * alignment;
auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(32)), alignment) * alignment;
ret_rowwise.shape = {scale_dim_0, scale_dim_1};
}
{
auto alignment = block_alignment[1];
auto scale_dim_0 = DIVUP(DIVUP(first_dim,
static_cast<size_t>(32)),
alignment) * alignment;
auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(32)), alignment) * alignment;
alignment = block_alignment[0];
auto scale_dim_1 = DIVUP(DIVUP(last_dim,
static_cast<size_t>(1)),
alignment) * alignment;
auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(1)), alignment) * alignment;
ret_colwise.shape = {scale_dim_0, scale_dim_1};
}
ret_rowwise.type = DType::kFloat8E8M0;
......@@ -164,6 +157,58 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
return {ret_rowwise, ret_colwise};
}
if (scaling_mode == NVTE_BLOCK_SCALING_2D) {
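// NVTE_BLOCK_SCALING_2D: FP32 scales, one per 128x128 block. Row-wise scales are shaped
// [ceil(first_dim / 128), ceil(last_dim / 128) rounded up to a multiple of 4]; the
// column-wise view swaps the roles of the two dimensions.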
std::vector<size_t> shape_vec;
for (size_t i = 0; i < shape.ndim; ++i) {
shape_vec.push_back(shape.data[i]);
}
size_t first_dim = first_dimension(shape_vec);
size_t last_dim = last_dimension(shape_vec);
scale_inv_meta ret_rowwise, ret_colwise;
{
auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(128));
auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(128)), 4) * 4;
ret_rowwise.shape = {scale_dim_0, scale_dim_1};
}
{
auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(128));
auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast<size_t>(128)), 4) * 4;
ret_colwise.shape = {scale_dim_0, scale_dim_1};
}
ret_rowwise.type = DType::kFloat32;
ret_colwise.type = DType::kFloat32;
ret_rowwise.type_size = sizeof(float);
ret_colwise.type_size = sizeof(float);
return {ret_rowwise, ret_colwise};
}
if (scaling_mode == NVTE_BLOCK_SCALING_1D) {
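// NVTE_BLOCK_SCALING_1D: FP32 scales, one per 1x128 block. Row-wise scales are shaped
// [ceil(last_dim / 128), first_dim rounded up to a multiple of 4]; column-wise scales are
// shaped [ceil(first_dim / 128), last_dim rounded up to a multiple of 4].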
std::vector<size_t> shape_vec;
for (size_t i = 0; i < shape.ndim; ++i) {
shape_vec.push_back(shape.data[i]);
}
size_t first_dim = first_dimension(shape_vec);
size_t last_dim = last_dimension(shape_vec);
scale_inv_meta ret_rowwise, ret_colwise;
{
auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(128));
auto scale_dim_1 = DIVUP(first_dim, 4) * 4;
ret_rowwise.shape = {scale_dim_0, scale_dim_1};
}
{
auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(128));
auto scale_dim_1 = DIVUP(last_dim, 4) * 4;
ret_colwise.shape = {scale_dim_0, scale_dim_1};
}
ret_rowwise.type = DType::kFloat32;
ret_colwise.type = DType::kFloat32;
ret_rowwise.type_size = sizeof(float);
ret_colwise.type_size = sizeof(float);
return {ret_rowwise, ret_colwise};
}
NVTE_ERROR("Invalid scaling mode!");
}
......@@ -195,10 +240,10 @@ Tensor::Tensor(const std::string& name,
std::vector<size_t> normalized_shape_v = {product(shape, 0, shape.ndim - 1),
shape.data[shape.ndim - 1]};
NVTEShape normalized_shape = convertShape(normalized_shape_v);
NVTEShape columnwise_shape{nullptr, 0};
NVTEShape columnwise_shape = {};
std::vector<size_t> columnwise_shape_vec;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) {
// Column-wise data uses the transposed shape for tensor scaling and 1D/2D block scaling
columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]);
for (size_t i = 0; i < shape.ndim - 1; ++i) {
......@@ -212,8 +257,7 @@ Tensor::Tensor(const std::string& name,
}
if (columnwise) {
columnwise_shape.data = columnwise_shape_vec.data();
columnwise_shape.ndim = columnwise_shape_vec.size();
columnwise_shape = nvte_make_shape(columnwise_shape_vec.data(), columnwise_shape_vec.size());
}
tensor_ = TensorWrapper(scaling_mode);
......@@ -259,25 +303,27 @@ Tensor::Tensor(const std::string& name,
std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
}
} else {
auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape,
tensor_.scaling_mode());
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(normalized_shape, tensor_.scaling_mode());
auto rowwise_scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
auto columnwise_scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size;
auto scale_shape = rowwise_scale_meta.shape;
auto columnwise_scale_shape = colwise_scale_meta.shape;
if (rowwise) {
cudaMalloc((void**)&rowwise_scale_inv, rowwise_scale_size); // NOLINT(*)
cudaMalloc((void **)&rowwise_scale_inv, rowwise_scale_size); // NOLINT(*)
cudaMemset(rowwise_scale_inv, 0, rowwise_scale_size);
rowwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(rowwise_scale_size);
std::fill_n(rowwise_scale_inv_cpu_data_.get(), rowwise_scale_size, 0);
tensor_.set_rowwise_scale_inv(rowwise_scale_inv, DType::kFloat8E8M0, scale_shape);
auto scale_dtype = rowwise_scale_meta.type;
tensor_.set_rowwise_scale_inv(rowwise_scale_inv, scale_dtype, scale_shape);
}
if (columnwise) {
cudaMalloc((void**)&columnwise_scale_inv, columnwise_scale_size); // NOLINT(*)
cudaMemset(columnwise_scale_inv, 0, columnwise_scale_size);
columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(columnwise_scale_size);
std::fill_n(columnwise_scale_inv_cpu_data_.get(), columnwise_scale_size, 0);
tensor_.set_columnwise_scale_inv(columnwise_scale_inv, DType::kFloat8E8M0, columnwise_scale_shape);
auto scale_dtype = colwise_scale_meta.type;
tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape);
}
}
}
......@@ -311,7 +357,8 @@ void Tensor::to_cpu() const {
sizeof(float),
cudaMemcpyDeviceToHost);
}
auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(s, tensor_.scaling_mode());
if (rowwise_) {
auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
cudaMemcpy(rowwise_scale_inv_cpu_data_.get(),
......@@ -349,7 +396,8 @@ void Tensor::from_cpu() const {
cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float),
cudaMemcpyHostToDevice);
}
auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(s, tensor_.scaling_mode());
if (rowwise_) {
auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr,
......@@ -383,27 +431,29 @@ void Tensor::set_scale_inv(float scale_inv) {
if (columnwise_) {
NVTE_CHECK(columnwise_scale_inv_cpu_data_);
}
auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode());
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(tensor_.shape(), tensor_.scaling_mode());
if (rowwise_) {
auto num_scales = product(rowwise_scale_meta.shape);
if (num_scales == 1){
if (num_scales == 1) {
rowwise_cpu_scale_inv_ptr<float>()[0] = scale_inv;
} else{
} else {
std::uniform_int_distribution<uint8_t> dis(0, 127);
auto* scale_inv_ptr = rowwise_cpu_scale_inv_ptr<uint8_t>();
for (size_t i = 0; i < num_scales; i++){
auto *scale_inv_ptr = rowwise_cpu_scale_inv_ptr<uint8_t>();
for (size_t i = 0; i < num_scales; i++) {
scale_inv_ptr[i] = dis(gen_);
}
}
}
if (columnwise_) {
auto num_scales = product(colwise_scale_meta.shape);
if (num_scales == 1){
if (num_scales == 1) {
columnwise_cpu_scale_inv_ptr<float>()[0] = scale_inv;
} else{
} else {
std::uniform_int_distribution<uint8_t> dis(0, 127);
auto* scale_inv_ptr = columnwise_cpu_scale_inv_ptr<uint8_t>();
for (size_t i = 0; i < num_scales; i++){
auto *scale_inv_ptr = columnwise_cpu_scale_inv_ptr<uint8_t>();
for (size_t i = 0; i < num_scales; i++) {
scale_inv_ptr[i] = dis(gen_);
}
}
......@@ -413,23 +463,20 @@ void Tensor::set_scale_inv(float scale_inv) {
}
void Tensor::shareFP8Meta(const Tensor &other) {
if(isFp8Type(dtype()) && isFp8Type(other.dtype())) {
if (isFp8Type(dtype()) && isFp8Type(other.dtype())) {
auto new_tensor = TensorWrapper(other.tensor_.scaling_mode());
auto my_rowwise_data = tensor_.get_rowwise_data();
new_tensor.set_rowwise_data(my_rowwise_data.data_ptr,
static_cast<DType>(my_rowwise_data.dtype),
new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast<DType>(my_rowwise_data.dtype),
my_rowwise_data.shape);
auto my_columnwise_data = tensor_.get_columnwise_data();
new_tensor.set_columnwise_data(my_columnwise_data.data_ptr,
static_cast<DType>(my_columnwise_data.dtype),
my_columnwise_data.shape);
auto other_amax = other.tensor_.get_amax();
new_tensor.set_amax(other_amax.data_ptr,
static_cast<DType>(other_amax.dtype),
new_tensor.set_amax(other_amax.data_ptr, static_cast<DType>(other_amax.dtype),
other_amax.shape);
auto other_scale = other.tensor_.get_scale();
new_tensor.set_scale(other_scale.data_ptr,
static_cast<DType>(other_scale.dtype),
new_tensor.set_scale(other_scale.data_ptr, static_cast<DType>(other_scale.dtype),
other_scale.shape);
auto other_row_scale_inv = other.tensor_.get_rowwise_scale_inv();
new_tensor.set_rowwise_scale_inv(other_row_scale_inv.data_ptr,
......@@ -460,9 +507,7 @@ std::string to_string(const std::vector<T> &v) {
std::vector<size_t> unravel(const size_t i, const NVTEShape &shape) {
std::vector<size_t> ret;
size_t current_i = i;
for (size_t current = shape.ndim - 1;
current > 0;
--current) {
for (size_t current = shape.ndim - 1; current > 0; --current) {
ret.push_back(current_i % shape.data[current]);
current_i /= shape.data[current];
}
......@@ -812,8 +857,7 @@ bool isFp8Type(DType type) {
return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0;
}
int32_t getDeviceComputeCapability()
{
int32_t getDeviceComputeCapability() {
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, 0);
return 10 * deviceProp.major + deviceProp.minor;
......
......@@ -121,7 +121,7 @@ class Tensor {
const bool rowwise = true,
const bool columnwise = false,
const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) :
Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode) {}
Tensor(name, nvte_make_shape(shape.data(), shape.size()), type, rowwise, columnwise, mode) {}
Tensor() {}
......@@ -148,25 +148,19 @@ class Tensor {
if (scale_inv != nullptr) {
cudaFree(scale_inv);
}
if (columnwise_data_ptr != nullptr){
if (columnwise_data_ptr != nullptr) {
cudaFree(columnwise_data_ptr);
}
if (columnwise_scale_inv != nullptr){
if (columnwise_scale_inv != nullptr) {
cudaFree(columnwise_scale_inv);
}
}
NVTETensor data() const noexcept {
return tensor_.data();
}
NVTETensor data() const noexcept { return tensor_.data(); }
NVTEShape rowwise_shape() const noexcept {
return tensor_.get_rowwise_data().shape;
}
NVTEShape rowwise_shape() const noexcept { return tensor_.get_rowwise_data().shape; }
NVTEShape columnwise_shape() const noexcept {
return tensor_.get_columnwise_data().shape;
}
NVTEShape columnwise_shape() const noexcept { return tensor_.get_columnwise_data().shape; }
NVTEShape rowwise_scale_inv_shape() const {
NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!");
......@@ -233,6 +227,8 @@ class Tensor {
T *rowwise_cpu_scale_inv_ptr(){
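// Delayed tensor scaling and 1D/2D block scaling expose FP32 scale-inverses on the host;
// all other modes (e.g. MXFP8) use byte-sized scales.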
if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
}
......@@ -244,6 +240,8 @@ class Tensor {
T *columnwise_cpu_scale_inv_ptr(){
if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
}
......@@ -475,6 +473,7 @@ extern std::vector<DType> all_fp_types;
bool isFp8Type(DType type);
int32_t getDeviceComputeCapability();
constexpr int32_t hopperComputeCapability = 90;
constexpr int32_t blackwellComputeCapability = 100;
} // namespace test
......
......@@ -25,3 +25,5 @@ filterwarnings=
ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning
ignore:The host_callback APIs are deprecated .*:DeprecationWarning
ignore:Scan loop is disabled for fused ring attention.*:UserWarning
ignore:jax.extend.ffi.register_ffi_target is deprecated
ignore:jax.extend.ffi.ffi_lowering is deprecated
......@@ -48,31 +48,7 @@ class TestDistributedSelfAttn:
# for loss and dbias
return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize(
"data_shape",
[
pytest.param((32, 512, 12, 64), id="32-512-12-64"),
pytest.param((32, 1024, 16, 128), id="32-1024-16-128"),
],
)
@pytest.mark.parametrize(
"attn_bias_type, bias_shape",
[
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
],
)
@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"),
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
],
)
@pytest.mark.parametrize("dtype", DTYPES)
def test_self_attn(
def impl_test_self_attn(
self,
device_count,
mesh_shape,
......@@ -83,7 +59,9 @@ class TestDistributedSelfAttn:
bias_shape,
attn_mask_type,
dtype,
use_shardy,
):
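# jax_use_shardy_partitioner switches JAX between the Shardy partitioner (True) and the
# default GSPMD partitioner (False).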
jax.config.update("jax_use_shardy_partitioner", use_shardy)
dropout_prob = 0.0
is_training = True
......@@ -137,6 +115,80 @@ class TestDistributedSelfAttn:
)
runner.test_backward()
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize(
"data_shape",
[
pytest.param((32, 512, 12, 64), id="32-512-12-64"),
pytest.param((32, 1024, 16, 128), id="32-1024-16-128"),
],
)
@pytest.mark.parametrize(
"attn_bias_type, bias_shape",
[
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
],
)
@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"),
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
],
)
@pytest.mark.parametrize("dtype", DTYPES)
def test_self_attn(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
attn_bias_type,
bias_shape,
attn_mask_type,
dtype,
):
self.impl_test_self_attn(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
attn_bias_type,
bias_shape,
attn_mask_type,
dtype,
use_shardy=False,
)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize(
"attn_bias_type, bias_shape",
[
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
],
)
def test_self_attn_shardy(
self, device_count, mesh_shape, mesh_axes, mesh_resource, attn_bias_type, bias_shape
):
data_shape = (32, 512, 12, 64)
self.impl_test_self_attn(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
attn_bias_type,
bias_shape,
AttnMaskType.PADDING_MASK,
jnp.bfloat16,
use_shardy=True,
)
class TestDistributedCrossAttn:
......@@ -203,37 +255,23 @@ class TestDistributedCrossAttn:
runner.test_backward()
@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize(
"data_shape",
[
# Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"),
pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
],
)
@pytest.mark.parametrize("kv_groups", [1, 8])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
[
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS = [
pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"),
pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="BSHD_KVPACKED-NO_MASK"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"),
pytest.param(
QKVLayout.THD_THD_THD,
AttnMaskType.PADDING_CAUSAL_MASK,
id="THD_SEPARATE-PADDING_CAUSAL",
QKVLayout.THD_THD_THD, AttnMaskType.PADDING_CAUSAL_MASK, id="THD_SEPARATE-PADDING_CAUSAL"
),
],
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
)
]
DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES = [
# Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"),
pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
]
class TestDistributedContextParallelSelfAttn:
def impl_test_context_parallel_attn(
......@@ -249,7 +287,23 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout,
load_balanced,
cp_strategy,
use_shardy,
use_scan_ring=False,
):
if qkv_layout.is_thd():
if cp_strategy == CPStrategy.ALL_GATHER:
pytest.skip("THD doesn't support all gather context parallelism.")
if not load_balanced and cp_strategy == CPStrategy.RING:
pytest.skip("THD + ring doesn't support unbalanced context parallelism.")
assert not use_scan_ring or cp_strategy == CPStrategy.RING
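# NVTE_FUSED_RING_ATTENTION_USE_SCAN chooses whether fused ring attention uses the scan-based
# loop ("1") or not ("0"); the variable is removed again once the test body completes.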
if use_scan_ring:
os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1"
else:
os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0"
jax.config.update("jax_use_shardy_partitioner", use_shardy)
attn_bias_type = AttnBiasType.NO_BIAS
bias_shape = None
dropout_prob = 0.0
......@@ -324,7 +378,58 @@ class TestDistributedContextParallelSelfAttn:
pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}")
runner.test_backward()
del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"]
@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
)
def test_context_parallel_allgather_attn_shardy(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
attn_mask_type,
dtype,
qkv_layout,
):
kv_groups = 8
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced=True,
cp_strategy=CPStrategy.ALL_GATHER,
use_shardy=True,
)
@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
@pytest.mark.parametrize("kv_groups", [1, 8])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
)
def test_context_parallel_allgather_attn(
self,
device_count,
......@@ -338,9 +443,7 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout,
load_balanced,
):
if qkv_layout.is_thd():
pytest.skip("THD doesn't support all gather context parallelism.")
return self.impl_test_context_parallel_attn(
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
......@@ -352,8 +455,23 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout,
load_balanced,
CPStrategy.ALL_GATHER,
use_shardy=False,
)
@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
@pytest.mark.parametrize("kv_groups", [1, 8])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
)
@pytest.mark.parametrize(
"use_scan",
[pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")],
......@@ -372,14 +490,6 @@ class TestDistributedContextParallelSelfAttn:
load_balanced,
use_scan,
):
if use_scan:
os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1"
else:
os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0"
if qkv_layout.is_thd() and not load_balanced:
pytest.skip("THD + ring doesn't support unbalanced context parallelism.")
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
......@@ -392,9 +502,46 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout,
load_balanced,
CPStrategy.RING,
use_shardy=False,
use_scan_ring=use_scan,
)
@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
)
def test_context_parallel_ring_attn_shardy(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
attn_mask_type,
dtype,
qkv_layout,
):
kv_groups = 8
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced=True,
cp_strategy=CPStrategy.RING,
use_shardy=False,
use_scan_ring=True,
)
del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"]
return
class TestReorderCausalLoadBalancing:
......
......@@ -29,7 +29,7 @@ NORM_INPUT_SHAPES = {
}
is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
SUPPORTED_RECIPES = []
if is_fp8_supported:
......@@ -86,6 +86,7 @@ class TestDistributedLayernorm:
@pytest_parametrize_wrapper("zero_centered_gamma", [False, True])
@pytest_parametrize_wrapper("shard_weights", [False, True])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("use_shardy", [False, True])
def test_layernorm(
self,
device_count,
......@@ -97,7 +98,9 @@ class TestDistributedLayernorm:
zero_centered_gamma,
shard_weights,
fp8_recipe,
use_shardy,
):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
epsilon = 1e-6
ln_type = "layernorm"
q_dtype = jnp.float8_e4m3fn
......@@ -168,6 +171,7 @@ class TestDistributedLayernorm:
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("shard_weights", [False, True])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("use_shardy", [False, True])
def test_rmsnorm(
self,
device_count,
......@@ -178,7 +182,9 @@ class TestDistributedLayernorm:
dtype,
shard_weights,
fp8_recipe,
use_shardy,
):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
epsilon = 1e-6
ln_type = "rmsnorm"
q_dtype = jnp.float8_e4m3fn
......
......@@ -36,7 +36,7 @@ from transformer_engine.jax.quantize import QuantizerFactory
is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
SUPPORTED_RECIPES = []
if is_fp8_supported:
......@@ -45,11 +45,17 @@ if is_mxfp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))
DTYPES = [jnp.bfloat16, jnp.float16]
INPUT_SHAPE = [[2, 64, 64]] # [batch, seqlen, hidden_in]
INPUT_SHAPE = [[4, 64, 128]] # [batch, seqlen, hidden_in]
LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
KERNEL_1_AXES = (W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES)
KERNEL_2_AXES = (W_TP_AXES, W_FSDP_AXES)
LN_SCALE_AXES = (W_NO_SHARD_AXES,)
LN_BIAS_AXES = (W_NO_SHARD_AXES,)
BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES)
BIAS_2_AXES = (W_NO_SHARD_AXES,)
INTERMEDIATE = 64
......@@ -60,7 +66,6 @@ def generate_fsdp_and_tp_configs():
configs.append(
[2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
)
if is_devices_enough(4):
configs.append(
[4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
......@@ -80,13 +85,13 @@ class TestDistributedLayernormMLP:
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
k1 = jax.random.normal(
subkeys[1], (hidden_in, len(activation_type) * INTERMEDIATE), dtype
subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype
) / jnp.sqrt(hidden_in)
k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(
INTERMEDIATE
)
if use_bias:
b1 = jax.random.normal(subkeys[3], (len(activation_type) * INTERMEDIATE), dtype)
b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype)
b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
else:
b1 = None
......@@ -111,10 +116,12 @@ class TestDistributedLayernormMLP:
layernorm_input_axes = LAYERNORM_INPUT_AXES
dot_1_input_axes = DOT_1_INPUT_AXES
dot_2_input_axes = DOT_2_INPUT_AXES
kernel_1_axes = KERNEL_1_AXES
kernel_2_axes = KERNEL_2_AXES
else:
layernorm_input_axes = None
dot_1_input_axes = None
dot_2_input_axes = None
dot_1_input_axes = dot_2_input_axes = None
kernel_1_axes = kernel_2_axes = None
quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2)
......@@ -130,21 +137,17 @@ class TestDistributedLayernormMLP:
norm_input_axes=layernorm_input_axes,
dot_1_input_axes=dot_1_input_axes,
dot_2_input_axes=dot_2_input_axes,
kernel_1_axes=kernel_1_axes,
kernel_2_axes=kernel_2_axes,
activation_type=activation_type,
quantizer_sets=quantizer_sets,
)
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
def test_layernorm_fp8_mlp_primitive(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
def _test_layernorm_mlp_grad(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, use_shardy
):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
layernorm_type = "rmsnorm"
......@@ -168,12 +171,12 @@ class TestDistributedLayernormMLP:
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", "tp"))
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
k1_ = jax.device_put(k1, k1_sharding)
k2_ = jax.device_put(k2, k2_sharding)
if use_bias:
b1_sharding = NamedSharding(mesh, PartitionSpec("tp"))
b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
b1_ = jax.device_put(b1, b1_sharding)
else:
b1_sharding = b1_ = None
......@@ -248,9 +251,59 @@ class TestDistributedLayernormMLP:
err_msg=f"multi_grads[{i}] is not close",
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
def test_layernorm_mlp_grad(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
):
self._test_layernorm_mlp_grad(
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
fp8_recipe,
use_shardy=False,
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
def test_layernorm_mlp_grad_shardy(
self, mesh_config, activation_type, use_bias, input_shape, dtype
):
# We don't test block scaling with Shardy because at the time of writing,
# it is not supported in JAX's scaled_matmul_stablehlo.
self._test_layernorm_mlp_grad(
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
fp8_recipe=recipe.DelayedScaling(),
use_shardy=True,
)
def _test_layernorm_mlp(
self, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8, fp8_recipe=None
self,
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
use_fp8,
fp8_recipe,
use_shardy,
):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
batch, seqlen, hidden_in = input_shape
layernorm_type = "rmsnorm"
......@@ -269,7 +322,7 @@ class TestDistributedLayernormMLP:
activations=activation_type,
use_bias=use_bias,
)
params_single = ln_mlp_single.init(init_rngs, x)
params_single = ln_mlp_single.init(init_rngs, x, deterministic=True)
mlp_out_single, ln_out_single = ln_mlp_single.apply(
params_single, x, deterministic=True
)
......@@ -286,19 +339,19 @@ class TestDistributedLayernormMLP:
transpose_batch_sequence=False,
intermediate_dim=INTERMEDIATE,
activations=activation_type,
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
scale_axes=LN_SCALE_AXES,
ln_bias_axes=LN_BIAS_AXES,
kernel_axes_1=KERNEL_1_AXES,
kernel_axes_2=KERNEL_2_AXES,
use_bias=use_bias,
bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
bias_axes_2=(W_NO_SHARD_AXES,),
bias_axes_1=BIAS_1_AXES,
bias_axes_2=BIAS_2_AXES,
layernorm_input_axes=LAYERNORM_INPUT_AXES,
dot_1_input_axes=DOT_1_INPUT_AXES,
dot_2_input_axes=DOT_2_INPUT_AXES,
name="mlp",
)
params_sharded = ln_mlp_sharded.init(init_rngs, x)
params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True)
mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
params_sharded, x, deterministic=True
)
......@@ -313,25 +366,38 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, dtype):
@pytest_parametrize_wrapper("use_shardy", [False, True])
def test_layernorm_mlp_layer(
self, mesh_config, activation_type, use_bias, input_shape, dtype, use_shardy
):
self._test_layernorm_mlp(
mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=False
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
use_fp8=False,
fp8_recipe=None,
use_shardy=use_shardy,
)
# TODO: debug
# @pytest.mark.skipif(not is_fp8_supported, reason=reason)
# @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
# @pytest_parametrize_wrapper(
# "activation_type", [("gelu",), ("gelu", "linear")]
# )
# @pytest_parametrize_wrapper("use_bias", [True, False])
# @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
# @pytest_parametrize_wrapper("dtype", DTYPES)
# @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
# def test_layernorm_fp8_mlp_layer(
# self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
# ):
# self._test_layernorm_mlp(
# mesh_config, activation_type, use_bias, input_shape, dtype,
# use_fp8=True, fp8_recipe=fp8_recipe
# )
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
def test_layernorm_mlp_layer_fp8(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
):
self._test_layernorm_mlp(
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
use_fp8=True,
fp8_recipe=fp8_recipe,
use_shardy=False,
)
......@@ -28,14 +28,16 @@ class TestDistributedSoftmax:
all_reduce_loss_bytes = 4 # 1 * FP32
return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)
def generate_inputs(self, shape, mesh_resource, softmax_type, dtype, bad_sharding):
def generate_inputs(
self, shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask
):
batch, _, seqlen, _ = shape
x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
if softmax_type == SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
mask = make_causal_mask(batch, seqlen)
else:
mask = make_self_mask(batch, seqlen)
mask = make_self_mask(1 if broadcast_batch_mask else batch, seqlen)
if not bad_sharding:
x_pspec = PartitionSpec(
......@@ -45,6 +47,10 @@ class TestDistributedSoftmax:
x_pspec = PartitionSpec(
mesh_resource.dp_resource, None, None, mesh_resource.tp_resource
)
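# A broadcast mask (batch dimension of 1) is fully replicated; otherwise the mask is sharded
# along the data-parallel axis, matching the input's batch sharding.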
if broadcast_batch_mask:
mask_pspec = PartitionSpec(None, None, None, None)
else:
mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None)
return (x, mask), (x_pspec, mask_pspec)
......@@ -67,16 +73,7 @@ class TestDistributedSoftmax:
output = jax.nn.softmax(x * scale_factor)
return jnp.mean(output)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]])
@pytest.mark.parametrize(
"softmax_type",
[SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],
)
@pytest.mark.parametrize("scale_factor", [1.0, 3.0])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("bad_sharding", [False, True])
def test_softmax(
def impl_test_softmax(
self,
device_count,
mesh_shape,
......@@ -87,15 +84,20 @@ class TestDistributedSoftmax:
scale_factor,
dtype,
bad_sharding,
broadcast_batch_mask,
use_shardy,
):
if broadcast_batch_mask and softmax_type != SoftmaxType.SCALED_MASKED:
pytest.skip("Softmax type has no mask.")
jax.config.update("jax_use_shardy_partitioner", use_shardy)
target_func = partial(
self.target_func, scale_factor=scale_factor, softmax_type=softmax_type
)
ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype)
(x, mask), (x_pspec, mask_pspec) = self.generate_inputs(
data_shape, mesh_resource, softmax_type, dtype, bad_sharding
data_shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask
)
collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
......@@ -129,4 +131,70 @@ class TestDistributedSoftmax:
assert "Sharding the hidden dimension is not supported" in str(w), (
"Softmax primitive did not raise the correct warning for "
"unsupported sharding in the hidden dimension."
f"{str(w)}"
)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]])
@pytest.mark.parametrize(
"softmax_type",
[SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],
)
@pytest.mark.parametrize("scale_factor", [1.0, 3.0])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("bad_sharding", [False, True])
@pytest.mark.parametrize("broadcast_batch_mask", [False, True])
def test_softmax(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
softmax_type,
scale_factor,
dtype,
bad_sharding,
broadcast_batch_mask,
):
self.impl_test_softmax(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
softmax_type,
scale_factor,
dtype,
bad_sharding,
broadcast_batch_mask,
use_shardy=False,
)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("softmax_type", [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED])
@pytest.mark.parametrize("bad_sharding", [False, True])
@pytest.mark.parametrize("broadcast_batch_mask", [False, True])
def test_softmax_shardy(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
softmax_type,
bad_sharding,
broadcast_batch_mask,
):
self.impl_test_softmax(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape=[32, 12, 128, 128],
softmax_type=softmax_type,
scale_factor=1.0,
dtype=DTYPES[0],
bad_sharding=bad_sharding,
broadcast_batch_mask=broadcast_batch_mask,
use_shardy=True,
)
......@@ -39,7 +39,7 @@ def enable_fused_attn():
is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
QUANTIZE_RECIPES = []
""" Find supported scaling modes"""
......@@ -215,12 +215,53 @@ ATTRS = [
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
},
# attrs22
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
_KEY_OF_WINDOW_SIZE: None,
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
},
# attrs23
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
},
# attrs24
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_MASK_TYPE: "no_mask",
},
# attrs25
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_MASK_TYPE: "no_mask",
_KEY_OF_WINDOW_SIZE: (2, 2),
},
# attrs26
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
_KEY_OF_WINDOW_SIZE: (2, 2),
},
# attrs27
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
_KEY_OF_WINDOW_SIZE: None,
},
# attrs28
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_WINDOW_SIZE: (2, 2),
},
]
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
......@@ -313,7 +354,7 @@ class BaseRunner:
test_others,
test_layer,
)
if QuantizeConfig.SCALING_MODE == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
_, updated_quantize_meta = flax.core.pop(
updated_state[0], QuantizeConfig.COLLECTION_NAME
)
......@@ -370,13 +411,13 @@ class EncoderRunner(BaseRunner):
data_rng = jax.random.PRNGKey(2024)
inputs = (jax.random.normal(data_rng, data_shape, dtype),)
padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
mask_shape = (batch, 1, seqlen, seqlen)
padded_mask = jnp.zeros(mask_shape, dtype=jnp.uint8)
causal_mask = jnp.triu(jnp.ones(mask_shape, dtype=jnp.uint8), k=1)
if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
mask = causal_mask
else:
mask = padded_mask
ref_masks = (1 - mask,)
test_masks = (None, mask) # The second arg of Transformer is encoded tokens.
......
......@@ -18,6 +18,7 @@ from utils import assert_allclose
from transformer_engine.jax.cpp_extensions import is_softmax_kernel_available
from transformer_engine.jax.softmax import SoftmaxType, softmax
from transformer_engine.jax.flax.module import Softmax
def catch_unsupported(method):
......@@ -94,7 +95,6 @@ class SoftmaxRunner:
case _:
raise ValueError(f"Unknown {self.softmax_type=}")
@catch_unsupported
def test_forward(self):
"""
Test transformer_engine.jax.softmax.softmax fwd rule
......@@ -104,7 +104,6 @@ class SoftmaxRunner:
reference_out = __class__.reference_softmax(self.logits, self.mask, self.scale_factor)
assert_allclose(primitive_out, reference_out, dtype=self.dtype)
@catch_unsupported
def test_backward(self):
"""
Test transformer_engine.jax.softmax.softmax bwd rule
......@@ -141,6 +140,50 @@ class SoftmaxRunner:
assert_allclose(primitive_grad_logits, reference_grad_logits, dtype=self.dtype)
class SoftmaxPrimitivesRunner(SoftmaxRunner):
"""
Jax Softmax Primitives runner
"""
@catch_unsupported
def test_forward(self):
return super().test_forward()
@catch_unsupported
def test_backward(self):
return super().test_backward()
class SoftmaxModuleRunner:
"""
Jax Softmax Module runner
"""
module_runner: SoftmaxRunner
bias: None
def __init__(self, module_runner, bias):
self.module_runner = module_runner
self.bias = bias
def test_forward(self):
"""
Test transformer_engine.jax.flax.module.Softmax fwd rule
"""
runner = self.module_runner
runner._setup_inputs()
rng = jax.random.PRNGKey(0)
softmax_module = Softmax(
scale_factor=runner.scale_factor,
softmax_type=runner.softmax_type,
)
softmax_vars = softmax_module.init(rng, runner.logits, runner.mask)
module_out = softmax_module.apply(softmax_vars, runner.logits, runner.mask)
reference_out = runner.reference_softmax(runner.logits, runner.mask, runner.scale_factor)
assert_allclose(module_out, reference_out, dtype=runner.dtype)
# Run softmax primitives test
@pytest.mark.parametrize(
"b, s_q, s_kv, h",
[
......@@ -165,7 +208,7 @@ class SoftmaxRunner:
pytest.param(jnp.float16, id="FP16"),
],
)
class TestSoftmax:
class TestSoftmaxPrimitives:
"""
Test transformer_engine.jax.softmax.softmax
"""
......@@ -175,7 +218,7 @@ class TestSoftmax:
"""
Test forward with parameterized configs
"""
runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
runner.test_forward()
@staticmethod
......@@ -183,5 +226,48 @@ class TestSoftmax:
"""
Test forward with parameterized configs
"""
runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
runner.test_backward()
# Run Softmax module test
@pytest.mark.parametrize(
"b, s_q, s_kv, h",
[
pytest.param(8, 16, 16, 16, id="8-16-16-16"),
pytest.param(8, 512, 512, 16, id="8-512-512-16"),
pytest.param(2, 8, 16384, 8, id="2-8-16384-8"),
# triggers backup framework implementation due to (s_q % 4) != 0
pytest.param(8, 511, 512, 16, id="8-511-512-16"),
],
)
@pytest.mark.parametrize("scale_factor", [0.125])
@pytest.mark.parametrize(
"softmax_type",
[
pytest.param(SoftmaxType.SCALED, id="SCALED"),
pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"),
pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
],
)
@pytest.mark.parametrize(
"dtype",
[
pytest.param(jnp.bfloat16, id="BF16"),
pytest.param(jnp.float16, id="FP16"),
],
)
class TestSoftmaxModule:
"""
Test transformer_engine.jax.flax.module.Softmax
"""
@staticmethod
def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
"""
Test forward with parameterized configs
"""
module_runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
bias = None
runner = SoftmaxModuleRunner(module_runner, bias)
runner.test_forward()
......@@ -21,7 +21,11 @@ from transformer_engine.common.recipe import (
)
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor import QuantizedTensor, cast_master_weights_to_fp8
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Tensor,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.utils import replace_raw_data
def _get_raw_data(quantized_tensor):
......@@ -228,6 +232,273 @@ class MiniOptimizer:
weight.data.copy_(master_weight)
class MiniFSDP:
def __init__(self, weights, lr, dp_group):
rank = dist.get_rank(dp_group)
world_size = dist.get_world_size(dp_group)
self.weights = weights
self.lr = lr
self.dp_group = dp_group
# Flatten the weights and pad to align with world size
raw_data_list = [
_get_raw_data(w).view(-1) if isinstance(w, QuantizedTensor) else w.view(-1)
for w in weights
]
if isinstance(weights[0], QuantizedTensor):
raw_data_list = [_get_raw_data(w).view(-1) for w in weights]
else:
raw_data_list = [w.view(-1) for w in weights]
self.flatten_weight, original_length = self._flatten_tensors_with_pad(raw_data_list)
# Split flattened weights into shards
self.local_weight_shard = torch.chunk(self.flatten_weight, world_size)[rank]
self.local_main_grad_shard = torch.zeros_like(self.local_weight_shard)
shard_size = self.flatten_weight.size(0) // world_size
# Map original tensors to flattened indices
tensor_indices = []
cumulative_length = 0
for tensor in raw_data_list:
length = tensor.size(0)
tensor_indices.append((cumulative_length, cumulative_length + length))
cumulative_length += length
# Build shard index mappings
self.weight_indices = []
self.shard_indices = []
for idx, (start, end) in enumerate(tensor_indices):
shard_start = rank * shard_size
shard_end = shard_start + shard_size
adjusted_end = min(shard_end, original_length)
if start <= adjusted_end and end >= shard_start:
start_idx = max(start, shard_start)
end_idx = min(end, adjusted_end)
self.weight_indices.append((start_idx - start, end_idx - start))
self.shard_indices.append((start_idx - shard_start, end_idx - shard_start))
else:
self.weight_indices.append((None, None))
self.shard_indices.append((None, None))
if isinstance(weights[idx], QuantizedTensor):
replace_raw_data(
weights[idx], self.flatten_weight[start:end].view(weights[idx].shape)
)
else:
weights[idx].data = self.flatten_weight[start:end].view(weights[idx].shape)
# Initialize local model weights and high-precision master weights
self.local_weights = []
self.master_weights = []
for i, weight in enumerate(self.weights):
weight_start, weight_end = self.weight_indices[i]
shard_start, shard_end = self.shard_indices[i]
if shard_start is not None and shard_end is not None:
local_weight_shard = self.local_weight_shard[shard_start:shard_end]
self.local_weights.append(local_weight_shard)
if isinstance(weight, QuantizedTensor):
high_precision_init_val = weight.get_high_precision_init_val().view(-1)
master_weight_shard = high_precision_init_val.to(weight.device).float()[
weight_start:weight_end
]
else:
master_weight_shard = weight.detach().view(-1).float()[weight_start:weight_end]
self.master_weights.append(master_weight_shard)
else:
self.local_weights.append(None)
self.master_weights.append(None)
setattr(
weight, "main_grad", torch.zeros_like(weight, dtype=torch.float32, device="cuda")
)
def _flatten_tensors_with_pad(self, tensors):
"""
Flatten the list of tensors and pad them to align with the world size.
Args:
tensors (list): List of tensors to flatten.
Returns:
tuple: Flattened tensor and its original length before padding.
"""
world_size = dist.get_world_size(self.dp_group)
flatten_tensor = torch.cat(tensors)
original_length = flatten_tensor.size(0)
padding_needed = (world_size - original_length % world_size) % world_size
if padding_needed > 0:
flatten_tensor = torch.cat(
[flatten_tensor, torch.zeros(padding_needed, dtype=flatten_tensor.dtype)]
)
return flatten_tensor, original_length
def zero_grad(self):
for weight in self.weights:
weight.grad = None
weight.main_grad.zero_()
def step(self):
"""
Perform an optimization step for the distributed sharded model.
This method includes:
1. Gradient reduce-scatter: Synchronize gradients across all processes.
2. Master weight update: Update high-precision master weights using local gradients.
3. Precision casting: Cast updated master weights to FP8 or BF16 precision.
4. Weight synchronization: All-gather updated weights across all processes.
Returns:
None
"""
# Step 1: Reduce-scatter the gradients
main_grad_buffer, _ = self._flatten_tensors_with_pad(
[weight.main_grad.view(-1) for weight in self.weights]
)
main_grad_buffer = main_grad_buffer.to(self.local_main_grad_shard.dtype)
dist.reduce_scatter_tensor(
self.local_main_grad_shard, main_grad_buffer, group=self.dp_group
)
# Step 2: Update the master weights
for weight, master_weight, (shard_start, shard_end) in zip(
self.weights, self.master_weights, self.shard_indices
):
if master_weight is None:
continue
# Extract the local gradient shard for this weight
grad = self.local_main_grad_shard[shard_start:shard_end]
# Update the master weight using gradient descent
master_weight -= grad * self.lr
# Step 3: Cast master weights to FP8 or BF16 precision
if isinstance(self.weights[0], QuantizedTensor):
local_weights = []
for local_weight in self.local_weights:
if local_weight is None:
local_weights.append(None)
continue
local_weights.append(local_weight)
cast_master_weights_to_fp8(
self.weights,
self.master_weights,
[idx[0] for idx in self.weight_indices],
self.dp_group,
local_weights,
)
else:
for weight, master_weight in zip(self.local_weights, self.master_weights):
if master_weight is None:
continue
# Copy updated master weights to local weights
weight.data.copy_(master_weight)
# Step 4: All-gather updated weights across processes
dist.all_gather_into_tensor(
self.flatten_weight, self.local_weight_shard, group=self.dp_group
)
def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
rank = dist.get_rank(dp_group)
world_size = dist.get_world_size(dp_group)
# Configuration constants
NUM_STEPS = 100
SEED = 12345
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)]
mock_group = mock_groups[rank]
linear_kwargs = {
"params_dtype": torch.bfloat16,
"bias": False,
"fuse_wgrad_accumulation": False,
}
# Create model with FP8 weights
with te.fp8.fp8_model_init(
enabled=quantization is not None,
recipe=quantization_recipe(quantization),
preserve_high_precision_init_val=True,
):
model_fp8 = nn.Sequential(
te.Linear(128, 256, **linear_kwargs),
te.Linear(256, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
# Create model with BF16 weights
model = nn.Sequential(
te.Linear(128, 256, **linear_kwargs),
te.Linear(256, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
# Make sure the BF16 model and FP8 model have the same initial weights
for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
high_precision_init_val = w_fp8.get_high_precision_init_val()
w.data.copy_(high_precision_init_val)
optimizer_fp8 = MiniFSDP([w for w in model_fp8.parameters()], 10.0, dp_group)
optimizer = MiniFSDP([w for w in model.parameters()], 10.0, dp_group)
for _ in range(NUM_STEPS):
optimizer_fp8.zero_grad()
optimizer.zero_grad()
inputs = [
torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
]
# Choose based on rank to make sure the inputs of different ranks are different.
x = inputs[rank]
with te.fp8.fp8_autocast(
enabled=quantization is not None,
fp8_recipe=quantization_recipe(quantization),
fp8_group=mock_group,
):
y_fp8 = model_fp8(x)
with te.fp8_autocast(
enabled=quantization is not None,
fp8_recipe=quantization_recipe(quantization),
fp8_group=mock_group,
):
y = model(x)
targets = [torch.randn_like(y) for _ in range(world_size)]
# Choose based on rank to make sure the targets of different ranks are different.
target = targets[rank]
loss_fp8 = nn.MSELoss()(y_fp8, target)
loss = nn.MSELoss()(y, target)
loss_fp8.backward()
loss.backward()
optimizer_fp8.step()
optimizer.step()
torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)
print(
f"✅ Successfully validated FSDP {NUM_STEPS} training steps with"
f" {quantization} quantization"
)
def _test_zero_1(dp_group):
"""Make sure the implementation of zero-1 optimizer is correct"""
rank = dist.get_rank(dp_group)
......@@ -389,6 +660,7 @@ def main(argv=None, namespace=None):
dp_group = dist.new_group(backend="nccl")
_test_zero_1(dp_group)
_test_cast_master_weights_to_fp8(args.quantization, dp_group)
_test_fsdp_cast_master_weights_to_fp8(args.quantization, dp_group)
dist.destroy_process_group()
return 0
......
......@@ -19,6 +19,7 @@ from transformer_engine.common.recipe import (
MXFP8BlockScaling,
DelayedScaling,
Float8CurrentScaling,
Float8BlockScaling,
Format,
Recipe,
)
......@@ -50,6 +51,8 @@ def quantization_recipe() -> Recipe:
return MXFP8BlockScaling()
if QUANTIZATION == "fp8_cs":
return Float8CurrentScaling()
if QUANTIZATION == "fp8_block_scaling":
return Float8BlockScaling()
return te.fp8.get_default_fp8_recipe()
......@@ -86,7 +89,7 @@ def main(argv=None, namespace=None):
# Quantization scheme
QUANTIZATION = args.quantization
if QUANTIZATION in ("fp8", "mxfp8"):
if QUANTIZATION in ("fp8", "mxfp8", "fp8_block_scaling"):
global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE
SEQ_LEN = 32
BATCH_SIZE = 32
......@@ -298,6 +301,11 @@ def _loss_backward(output_single_node, output_distributed):
LOSS_FN(output_distributed, target).backward()
def _loss_backward_dw(model_single_node, model_distributed):
model_single_node.backward_dw()
model_distributed.backward_dw()
def _alloc_main_grad(model_single_node, model_distributed):
for model in [model_single_node, model_distributed]:
for param in model.parameters():
......@@ -471,6 +479,10 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
# Compute loss and backpropagate
_loss_backward(output_single_node, output_distributed)
# Compute delayed weight gradient
if "delay_wgrad_compute" in kwargs:
_loss_backward_dw(model_single_node, model_distributed)
# Validate outputs and gradients
_check_outputs(output_single_node, output_distributed)
......@@ -492,6 +504,7 @@ def test_linear():
{"fuse_wgrad_accumulation": True},
{"return_bias": True},
{"params_dtype": torch.float16},
{"delay_wgrad_compute": True},
]
for kwargs in kwargs_list:
for parallel_mode in ["column", "row"]:
......@@ -643,6 +656,10 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
# Compute loss and backpropagate
_loss_backward(output_single_node, output_distributed)
# Compute delayed weight gradient
if "delay_wgrad_compute" in kwargs:
_loss_backward_dw(model_single_node, model_distributed)
# Validate outputs and gradients
_check_outputs(output_single_node, output_distributed)
......@@ -665,6 +682,7 @@ def test_layernorm_linear():
{"params_dtype": torch.float16},
{"zero_centered_gamma": False},
{"return_layernorm_output": True},
{"delay_wgrad_compute": True},
]
for kwargs in kwargs_list:
for parallel_mode in ["column"]:
......@@ -744,6 +762,9 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg
# Compute loss and backpropagate
_loss_backward(output_single_node, output_distributed)
if "delay_wgrad_compute" in kwargs:
_loss_backward_dw(model_single_node, model_distributed)
# Validate outputs and gradients
_check_outputs(output_single_node, output_distributed)
......@@ -769,6 +790,7 @@ def test_layernorm_mlp():
{"fuse_wgrad_accumulation": True},
{"return_bias": True},
{"return_layernorm_output": True},
{"delay_wgrad_compute": True},
]
for kwargs in kwargs_list:
......
......@@ -28,6 +28,9 @@ if torch.cuda.device_count() < 2:
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(4, torch.cuda.device_count())
......@@ -48,7 +51,7 @@ def _run_test(quantization):
all_boolean = [True, False]
@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs"])
@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"])
def test_distributed(quantization):
if quantization == "fp8" and not fp8_available:
pytest.skip(reason_for_no_fp8)
......@@ -56,4 +59,6 @@ def test_distributed(quantization):
pytest.skip(reason_for_no_fp8)
if quantization == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if quantization == "fp8_block_scaling" and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
_run_test(quantization)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Tuple
import torch
import triton
import triton.language as tl
@triton.jit
def fused_fma_kernel(y_ptr, x_ptr, s_ptr, M, N, y_str0, y_str1, BLOCK: tl.constexpr = 128):
pid = tl.program_id(0)
idx = pid * BLOCK + tl.arange(0, BLOCK)
mask = idx < M * N
row = idx // N
col = idx % N
y_offset = row * y_str0 + col * y_str1
x_offset = row * N + col
s_offset = row * N + col
y = tl.load(y_ptr + y_offset, mask=mask)
x = tl.load(x_ptr + x_offset, mask=mask)
s = tl.load(s_ptr + s_offset, mask=mask)
tl.store(y_ptr + y_offset, tl.fma(x, s, y), mask=mask)
def fused_fma(y, x, s, BLOCK=128):
"""
Fused multiply-add operation (y = y + x * s), performed in place on y.
PyTorch does not provide a direct FMA equivalent (torch.addcmul is not bitwise equivalent to this operation).
This function also supports cases where 'y' is non-contiguous in memory.
"""
assert (
y.shape == x.shape == s.shape and y.dim() == 2
), "All tensors must be 2D with the same shape"
assert x.is_contiguous() and s.is_contiguous(), "x and s must be contiguous"
M, N = y.shape
grid = ((M * N + BLOCK - 1) // BLOCK,)
fused_fma_kernel[grid](y, x, s, M, N, *y.stride(), BLOCK)
return y
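# Minimal usage sketch for fused_fma (illustrative only, not part of this test file;
# assumes a CUDA device and a working Triton install):
#
#   y = torch.randn(4, 8, device="cuda")
#   x = torch.randn(4, 8, device="cuda")
#   s = torch.randn(4, 8, device="cuda")
#   ref = y + x * s        # naive reference; may differ in the last ULP from the fused op
#   fused_fma(y, x, s)     # updates y in place via a single tl.fma per element
#   torch.testing.assert_close(y, ref)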
class CuBLASRefBlockwiseGemm:
"""
A cuBLAS compatible reference implementation of subchannel GEMM.
"""
def qgemm(
self,
qx: torch.Tensor,
qw: torch.Tensor,
out_dtype: torch.dtype,
demunged_sx: torch.Tensor,
demunged_sw: torch.Tensor,
quant_tile_shape_x: Tuple[int, int],
quant_tile_shape_w: Tuple[int, int],
bias: torch.Tensor | None = None,
out: torch.Tensor | None = None,
accumulate: bool = False,
use_split_accumulator: bool = False,
) -> torch.Tensor:
# demunge scale shapes for cuBLAS
is_a_1d_scaled = quant_tile_shape_x[0] == 1
is_b_1d_scaled = quant_tile_shape_w[0] == 1
M, K = qx.shape
N, K = qw.shape
# mm_tile_shape = (tile_m, tile_n, tile_k)
mm_tile_shape = (
quant_tile_shape_x[0],
quant_tile_shape_w[0],
quant_tile_shape_w[1],
)
if bias is not None and bias.numel():
# To match cuBLAS more closely when bias is applied,
# the reference accumulates into float32, and the cast to
# the output dtype is deferred until after the GEMM.
out_dtype_for_ref = torch.float32
else:
out_dtype_for_ref = out_dtype
y = self.qgemm_blockwise_2d(
qx,
qw,
out_dtype_for_ref,
demunged_sx,
demunged_sw,
mm_tile_shape,
use_split_accumulator,
is_a_1d_scaled,
is_b_1d_scaled,
)
if bias is not None and bias.numel():
y += bias
y = y.to(dtype=out_dtype)
# cuBLAS accumulation first converts the GEMM result to the output dtype, then accumulates.
if accumulate:
assert out is not None
y = y + out
else:
assert out is None, "Output tensor should be None when accumulate is False."
return y
@classmethod
def qgemm_blockwise_2d(
cls,
qx: torch.Tensor,
qw: torch.Tensor,
out_dtype: torch.dtype,
sx: torch.Tensor,
sw: torch.Tensor,
mm_tile_shape: Tuple[int, int, int],
use_split_accumulator: bool,
is_a_1d_scaled: bool,
is_b_1d_scaled: bool,
) -> torch.Tensor:
"""
Differences between the cuBLAS and CUTLASS GEMM implementations:
- cuBLAS accumulation: a different accumulation equation is used for each scaling mode.
- When accumulating C in the epilogue, cuBLAS first converts C to the output dtype, then accumulates.
"""
M, K = qx.shape
N, K_w = qw.shape
assert K == K_w, "K dimension mismatch between qx and qw"
tile_len = 128
# Calculate grid sizes without padding
grid_m = (M + tile_len - 1) // tile_len
grid_n = (N + tile_len - 1) // tile_len
grid_k = (K + tile_len - 1) // tile_len
block_m, block_n, block_k = mm_tile_shape
scale_m_per_tile = tile_len // block_m
scale_n_per_tile = tile_len // block_n
assert block_k == tile_len, "block_k must be equal to tile_len"
# Notes on making the reference implementation numerically equivalent to the Cast Blockwise FP8 GEMM:
# 1) When the split accumulator is used in the FP8 GEMM, every 4 QMMA partial accumulation results are accumulated into float32 registers.
# 2) Partial results are accumulated with FMA (fused multiply-add) instructions that apply the scaling factors, as in: y += partial_y * scale
y = torch.zeros(M, N, dtype=torch.float32, device=qx.device)
# Validate shapes of sx and sw
scale_m_per_tensor = (M + block_m - 1) // block_m
scale_n_per_tensor = (N + block_n - 1) // block_n
assert sx.shape == (
scale_m_per_tensor,
grid_k,
), f"sx shape mismatch: expected ({scale_m_per_tensor}, {grid_k}), got {sx.shape}"
assert sw.shape == (
scale_n_per_tensor,
grid_k,
), f"sw shape mismatch: expected ({scale_n_per_tensor}, {grid_k}), got {sw.shape}"
for i in range(grid_m):
m_start = i * tile_len
m_end = min(m_start + tile_len, M)
m_size = m_end - m_start
for j in range(grid_n):
n_start = j * tile_len
n_end = min(n_start + tile_len, N)
n_size = n_end - n_start
y_block = y[m_start:m_end, n_start:n_end]
for k in range(grid_k):
k_start = k * tile_len
k_end = min(k_start + tile_len, K)
k_size = k_end - k_start
qx_block = (
qx[m_start:m_end, k_start:k_end].clone().contiguous()
) # Shape: [m_size, k_size]
qw_block = (
qw[n_start:n_end, k_start:k_end].clone().contiguous()
) # Shape: [n_size, k_size]
# Extract scaling factors for the current blocks
sx_block = sx[i * scale_m_per_tile : (i + 1) * scale_m_per_tile, k].unsqueeze(
-1
)
sw_block = sw[j * scale_n_per_tile : (j + 1) * scale_n_per_tile, k].unsqueeze(0)
# Perform the quantized GEMM with the scaling factors fused in
# Accumulation is done in float32, which aligns with the split accumulator in the FP8 GEMM
one = torch.tensor(1.0, dtype=torch.float32, device=qx.device)
y_partial = torch._scaled_mm(
qx_block,
qw_block.t(),
scale_a=one,
scale_b=one,
out_dtype=torch.float32,
use_fast_accum=not use_split_accumulator,
)
# Accumulate the partial result
if is_a_1d_scaled and is_b_1d_scaled:
# 1Dx1D
# cuBLAS accumulation equation: y += (y_partial * scale_a) * scale_b
y_partial = y_partial * sx_block
# Fuse multiplication and addition to align with the split accumulator in the FP8 GEMM
# y_block.add_(y_partial, alpha=scale.item())
fused_fma(
y_block,
y_partial,
sw_block.expand_as(y_partial).contiguous(),
)
elif not is_a_1d_scaled and is_b_1d_scaled:
# 2Dx1D
# cuBLAS accumulation equation: y += (y_partial * scale_b) * scale_a
y_partial = y_partial * sw_block
fused_fma(
y_block,
y_partial,
sx_block.expand_as(y_partial).contiguous(),
)
elif is_a_1d_scaled and not is_b_1d_scaled:
# 1Dx2D
# cuBLAS accumulation equation: y += (y_partial * scale_a) * scale_b
y_partial = y_partial * sx_block
fused_fma(
y_block,
y_partial,
sw_block.expand_as(y_partial).contiguous(),
)
else:
scale = sx_block * sw_block
fused_fma(y_block, y_partial, scale.expand_as(y_partial).contiguous())
y = y.to(out_dtype)
return y