Unverified Commit 61f1bf6f authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

Support computing zero-centered gamma in compute dtype for CuDNN (#1690)



* Add a flag to support computing zero-centered gamma in weight dtype or compute dtype for CuDNN
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Address comments
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
parent a0cabb71
...@@ -18,159 +18,16 @@ ...@@ -18,159 +18,16 @@
#include <transformer_engine/normalization.h> #include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h> #include <transformer_engine/transformer_engine.h>
#include "../test_common.h" #include "../test_common.h"
#include "test_normalization.h"
using namespace transformer_engine; using namespace transformer_engine;
using namespace test; using namespace test;
namespace { namespace {
// Normalization flavors exercised by these reference tests.
enum NormType {
  LayerNorm,
  RMSNorm
};

// Human-readable names used when composing GTest case names.
std::map<NormType, std::string> normToString = {
  {NormType::LayerNorm, "LayerNorm"},
  {NormType::RMSNorm, "RmsNorm"}
};
// CPU reference for per-row normalization statistics.
//
// For each of the N rows of `data` (row length H):
//   * LayerNorm: mu[i] = row mean, rsigma[i] = 1/sqrt(variance + epsilon)
//   * RMSNorm:   mu is not written; rsigma[i] = 1/sqrt(mean(x^2) + epsilon)
// Accumulation is done in float regardless of InputType.
template <typename InputType>
void compute_ref_stats(NormType norm_type,
                       const InputType *data, float *mu, float *rsigma,
                       const size_t N, const size_t H, const double epsilon){
  using compute_t = float;
  compute_t current, m;
  for (size_t i = 0; i < N; ++i) {
    compute_t sum = 0;
    for (size_t j = 0; j < H; ++j) {
      sum += static_cast<compute_t>(data[i * H + j]);
    }
    if (norm_type == LayerNorm){
      mu[i] = sum / H;
      m = mu[i];
    } else { m = 0;}  // RMSNorm: second moment is not mean-centered
    compute_t sum_sq = 0;
    for (size_t j = 0; j < H; ++j) {
      current = static_cast<compute_t>(data[i * H + j]);
      sum_sq += (current - m) * (current - m);
    }
    rsigma[i] = rsqrtf((sum_sq / H) + epsilon);
  }
}
// For now, cudnn does static_cast<compute_t>(gamma + static_cast<input_t>(1.0))
// This will be changed in the future release
//
// Returns the effective gamma used by the reference path, widened to float.
// The cuDNN backend applies the zero-centered-gamma "+1" in the weight dtype
// (InputType) *before* widening, so the reference must round the same way to
// match the kernel; the TE backend widens first and adds 1 in float.
// FP8 gamma is always widened to float before the add (both backends).
template <typename InputType>
inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn){
  using compute_t = float;
  if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3>){
    // FP8: widen first, then add 1 in float.
    compute_t g = static_cast<compute_t>(gamma);
    if (zero_centered_gamma) {
      g += static_cast<compute_t>(1.f);
    }
    return g;
  } else {
    if (use_cudnn){
      // cuDNN: add 1 in the weight dtype, then widen (matches kernel rounding
      // for reduced-precision weights such as fp16/bf16).
      compute_t g = static_cast<compute_t>(0.f);
      InputType gi = gamma;
      if (zero_centered_gamma) {
        gi = gi + static_cast<InputType>(1.f);
      }
      g = static_cast<compute_t>(gi);
      return g;
    } else {
      // TE backend: widen first, then add 1 in float.
      compute_t g = static_cast<compute_t>(gamma);
      if (zero_centered_gamma) {
        g += static_cast<compute_t>(1.f);
      }
      return g;
    }
  }
}
// CPU reference forward pass for LayerNorm / RMSNorm.
//
// output[i,j] = norm(data[i,j]) * gamma'[j] (+ beta[j] for LayerNorm),
// multiplied by `scale` and cast to OutputType. gamma' comes from
// compute_gamma so zero-centered-gamma rounding matches the backend under
// test. *amax receives max |unscaled output| over all elements.
template <typename InputType, typename OutputType>
void compute_ref_output(NormType norm_type,
                        const InputType *data, const InputType *gamma, const InputType *beta,
                        OutputType* output,
                        const float *mu, const float *rsigma,
                        const size_t N, const size_t H,
                        float *amax, float scale, const bool zero_centered_gamma, const bool use_cudnn) {
  using compute_t = float;
  // NOTE(review): -1e100 is outside float's range, so this conversion is not
  // well-defined; fabsf() is non-negative, so 0.f would be a safe identity.
  compute_t current_max = -1e100;
  for (size_t i = 0; i < N; ++i) {
    for (size_t j = 0; j < H; ++j) {
      compute_t current = static_cast<compute_t>(data[i * H + j]);
      compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn);
      compute_t tmp;
      if (norm_type == LayerNorm) {
        tmp = (current - mu[i]) * rsigma[i] * g + static_cast<compute_t>(beta[j]);
      } else { // RMSNorm: no mean subtraction, no beta
        tmp = current * rsigma[i] * g;
      }
      output[i * H + j] = static_cast<OutputType>(tmp * scale);
      current_max = fmaxf(current_max, fabsf(tmp));
    }
  }
  *amax = current_max;
}
// CPU reference backward pass for LayerNorm / RMSNorm.
//
// Given the upstream gradient `output_grad` (dz) and the saved forward stats
// (mu, rsigma), produces:
//   data_grad  - gradient w.r.t. the input
//   gamma_grad - per-column reduction of y * dz (y = normalized input)
//   beta_grad  - per-column reduction of dz (LayerNorm only)
// gamma passes through compute_gamma so zero-centered-gamma rounding matches
// the backend under test. All accumulation is in float.
template <typename InputType, typename OutputType>
void compute_ref_backward(const NormType norm_type, const OutputType *output_grad, const InputType *data,
                          const float *mu, const float *rsigma,
                          const InputType *gamma,
                          InputType *data_grad,
                          InputType *gamma_grad, InputType *beta_grad,
                          const size_t N, const size_t H,
                          const bool zero_centered_gamma, const bool use_cudnn) {
  using compute_t = float;
  std::vector<compute_t> dgamma(H, 0.f);
  std::vector<compute_t> dbeta(H, 0.f);
  for (size_t i = 0 ; i < N; ++i) {
    // Per-row reductions: mean(dy) and mean(dy * y).
    auto local_mu = (norm_type == LayerNorm) ? mu[i] : 0.;
    compute_t mdy = 0, mdyy = 0;
    for (size_t j = 0; j < H; ++j) {
      const compute_t x = static_cast<compute_t>(data[i * H + j]);
      const compute_t y = (x - local_mu) * rsigma[i];
      compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn);
      const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
      const compute_t dy = g * dz;
      dgamma[j] += y * dz;
      if (norm_type == LayerNorm) {
        dbeta[j] += dz;
        mdy += dy;  // mean(dy) term only applies when the mean was subtracted
      }
      mdyy += dy * y;
    }
    mdy /= H;
    mdyy /= H;
    // Input grads: dx = rsigma * (dy - mean(dy*y) * y - mean(dy)).
    for (size_t j = 0; j < H; ++j) {
      const compute_t x = static_cast<compute_t>(data[i * H + j]);
      const compute_t y = (x - local_mu) * rsigma[i];
      compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn);
      const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
      const compute_t dy = g * dz;
      const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy);
      data_grad[i * H + j] = static_cast<InputType>(dx);
    }
  }
  // Weight grads: cast float accumulators back to the weight dtype.
  for (size_t j = 0; j < H; ++j) gamma_grad[j] = static_cast<InputType>(dgamma[j]);
  if (norm_type == LayerNorm) for (size_t j = 0; j < H; ++j) beta_grad[j] = static_cast<InputType>(dbeta[j]);
}
template <typename InputType, typename OutputType> template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
NormType norm_type, bool use_cudnn) { NormType norm_type, bool use_cudnn, const bool zero_centered_gamma_in_weight_dtype) {
if (sizeof(InputType) < sizeof(OutputType)) { if (sizeof(InputType) < sizeof(OutputType)) {
GTEST_SKIP() << "LN kernel does not support OutputType > InputType"; GTEST_SKIP() << "LN kernel does not support OutputType > InputType";
return; return;
...@@ -219,9 +76,22 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, ...@@ -219,9 +76,22 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
cudaDeviceProp prop; cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0); cudaGetDeviceProperties(&prop, 0);
if ((!use_cudnn || !zero_centered_gamma) && zero_centered_gamma_in_weight_dtype) {
// Skip duplicate tests when zero_centered_gamma_in_weight_dtype is true and won't affect the implementation
GTEST_SKIP() << "Zero-centered gamma in weight dtype is only supported with cuDNN backend";
}
if (use_cudnn){ if (use_cudnn){
nvte_enable_cudnn_norm_fwd(true); nvte_enable_cudnn_norm_fwd(true);
nvte_enable_cudnn_norm_bwd(true); nvte_enable_cudnn_norm_bwd(true);
// Zero-centered gamma in weight dtype only supported by CuDNN backend currently
if (zero_centered_gamma_in_weight_dtype) {
nvte_enable_zero_centered_gamma_in_weight_dtype(true);
} else {
nvte_enable_zero_centered_gamma_in_weight_dtype(false);
}
} }
// Forward kernel // Forward kernel
...@@ -269,6 +139,11 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, ...@@ -269,6 +139,11 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
if (use_cudnn){ if (use_cudnn){
nvte_enable_cudnn_norm_fwd(false); nvte_enable_cudnn_norm_fwd(false);
nvte_enable_cudnn_norm_bwd(false); nvte_enable_cudnn_norm_bwd(false);
// Zero-centered gamma in weight dtype only supported by CuDNN backend currently
if (zero_centered_gamma_in_weight_dtype) {
nvte_enable_zero_centered_gamma_in_weight_dtype(false);
}
} }
// Reference implementations // Reference implementations
...@@ -289,14 +164,16 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, ...@@ -289,14 +164,16 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
&ref_amax, &ref_amax,
ref_scale, ref_scale,
zero_centered_gamma, zero_centered_gamma,
use_cudnn); use_cudnn,
zero_centered_gamma_in_weight_dtype);
compute_ref_backward(norm_type, dz.rowwise_cpu_dptr<WeightType>(), compute_ref_backward(norm_type, dz.rowwise_cpu_dptr<WeightType>(),
input.rowwise_cpu_dptr<InputType>(), input.rowwise_cpu_dptr<InputType>(),
mu.rowwise_cpu_dptr<float>(), rsigma.rowwise_cpu_dptr<float>(), mu.rowwise_cpu_dptr<float>(), rsigma.rowwise_cpu_dptr<float>(),
gamma.rowwise_cpu_dptr<WeightType>(), gamma.rowwise_cpu_dptr<WeightType>(),
ref_dx.get(), ref_dgamma.get(), ref_dbeta.get(), ref_dx.get(), ref_dgamma.get(), ref_dbeta.get(),
N, H, zero_centered_gamma, N, H, zero_centered_gamma,
use_cudnn); use_cudnn,
zero_centered_gamma_in_weight_dtype);
cudaDeviceSynchronize(); cudaDeviceSynchronize();
auto err = cudaGetLastError(); auto err = cudaGetLastError();
...@@ -341,6 +218,7 @@ NormType, ...@@ -341,6 +218,7 @@ NormType,
transformer_engine::DType, transformer_engine::DType,
transformer_engine::DType, transformer_engine::DType,
std::pair<size_t, size_t>, std::pair<size_t, size_t>,
bool,
bool>> {}; bool>> {};
TEST_P(NormTestSuite, TestNorm) { TEST_P(NormTestSuite, TestNorm) {
...@@ -353,10 +231,11 @@ TEST_P(NormTestSuite, TestNorm) { ...@@ -353,10 +231,11 @@ TEST_P(NormTestSuite, TestNorm) {
const DType output_type = std::get<3>(GetParam()); const DType output_type = std::get<3>(GetParam());
const auto size = std::get<4>(GetParam()); const auto size = std::get<4>(GetParam());
const bool zero_centered_gamma = std::get<5>(GetParam()); const bool zero_centered_gamma = std::get<5>(GetParam());
const bool cudnn_zero_centered_gamm_in_weight_dtype = std::get<6>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma, norm_type, use_cudnn); performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma, norm_type, use_cudnn, cudnn_zero_centered_gamm_in_weight_dtype);
); );
); );
} }
...@@ -370,6 +249,7 @@ INSTANTIATE_TEST_SUITE_P( ...@@ -370,6 +249,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3), ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3),
::testing::ValuesIn(test_cases), ::testing::ValuesIn(test_cases),
::testing::Values(false, true),
::testing::Values(false, true)), ::testing::Values(false, true)),
[](const testing::TestParamInfo<NormTestSuite::ParamType>& info) { [](const testing::TestParamInfo<NormTestSuite::ParamType>& info) {
auto backend = std::get<0>(info.param) == false ? "Te" : "Cudnn"; auto backend = std::get<0>(info.param) == false ? "Te" : "Cudnn";
...@@ -380,6 +260,7 @@ INSTANTIATE_TEST_SUITE_P( ...@@ -380,6 +260,7 @@ INSTANTIATE_TEST_SUITE_P(
test::typeName(std::get<3>(info.param)) + "X" + test::typeName(std::get<3>(info.param)) + "X" +
std::to_string(std::get<4>(info.param).first) + "X" + std::to_string(std::get<4>(info.param).first) + "X" +
std::to_string(std::get<4>(info.param).second) + "X" + std::to_string(std::get<4>(info.param).second) + "X" +
std::to_string(std::get<5>(info.param)); std::to_string(std::get<5>(info.param)) + "X" +
std::to_string(std::get<6>(info.param));
return name; return name;
}); });
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#pragma once
#include <cmath>
#include <cstring>
#include <memory>
#include <iomanip>
#include <iostream>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"
namespace test {
namespace {
// Normalization flavors exercised by the normalization tests.
enum NormType {
  LayerNorm,
  RMSNorm
};

// Human-readable names used when composing GTest case names.
std::map<NormType, std::string> normToString = {
  {NormType::LayerNorm, "LayerNorm"},
  {NormType::RMSNorm, "RmsNorm"}
};
// CPU reference for per-row normalization statistics.
//
// LayerNorm writes the row mean into mu[i] and 1/sqrt(variance + epsilon)
// into rsigma[i]; RMSNorm leaves mu untouched and uses the uncentered second
// moment. Accumulation happens in float for every InputType.
template <typename InputType>
void compute_ref_stats(NormType norm_type,
                       const InputType *data, float *mu, float *rsigma,
                       const size_t N, const size_t H, const double epsilon){
  using compute_t = float;
  for (size_t row = 0; row < N; ++row) {
    const InputType *x = data + row * H;
    compute_t center = 0;
    if (norm_type == LayerNorm) {
      compute_t acc = 0;
      for (size_t col = 0; col < H; ++col) {
        acc += static_cast<compute_t>(x[col]);
      }
      mu[row] = acc / H;
      center = mu[row];
    }
    // Second moment about `center` (the mean for LayerNorm, 0 for RMSNorm).
    compute_t sq_acc = 0;
    for (size_t col = 0; col < H; ++col) {
      const compute_t dev = static_cast<compute_t>(x[col]) - center;
      sq_acc += dev * dev;
    }
    rsigma[row] = rsqrtf((sq_acc / H) + epsilon);
  }
}
// Returns the effective gamma used by the reference path, widened to float.
//
// When the zero-centered-gamma offset is applied in the weight dtype (cuDNN
// only for now), the "+1" happens in InputType *before* the widening cast so
// the reference rounds exactly like the kernel; otherwise gamma is widened
// first and the 1 is added in float. FP8 gammas always widen first.
template <typename InputType>
inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype) {
  using compute_t = float;
  // Zero-centered gamma in weight dtype is only supported by the CuDNN
  // backend currently; drop the use_cudnn term once TE supports it as well.
  const bool add_one_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype;
  if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3>){
    const compute_t widened = static_cast<compute_t>(gamma);
    return zero_centered_gamma ? widened + 1.f : widened;
  } else {
    if (add_one_in_weight_dtype) {
      // Offset in the (possibly reduced-precision) weight dtype, then widen.
      InputType shifted = gamma;
      if (zero_centered_gamma) {
        shifted = shifted + static_cast<InputType>(1.f);
      }
      return static_cast<compute_t>(shifted);
    }
    // Offset in float after widening.
    compute_t widened = static_cast<compute_t>(gamma);
    if (zero_centered_gamma) {
      widened += 1.f;
    }
    return widened;
  }
}
// CPU reference forward pass for LayerNorm / RMSNorm.
//
// output[i,j] = norm(data[i,j]) * gamma'[j] (+ beta[j] for LayerNorm),
// multiplied by `scale` and cast to OutputType. gamma' is produced by
// compute_gamma so the zero-centered-gamma rounding matches the selected
// backend. If `amax` is non-null it receives max |unscaled output|.
template <typename InputType, typename OutputType>
void compute_ref_output(NormType norm_type,
                        const InputType *data, const InputType *gamma, const InputType *beta,
                        OutputType* output,
                        const float *mu, const float *rsigma,
                        const size_t N, const size_t H,
                        float *amax, float scale, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype) {
  using compute_t = float;
  // fabsf() is non-negative, so 0 is a correct identity for the running max.
  // (The previous -1e100 literal does not fit in a float; that conversion is
  // undefined behavior and in practice produced -inf, which also leaked into
  // *amax for empty inputs.)
  compute_t current_max = 0.f;
  for (size_t i = 0; i < N; ++i) {
    for (size_t j = 0; j < H; ++j) {
      compute_t current = static_cast<compute_t>(data[i * H + j]);
      compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype);
      compute_t tmp;
      if (norm_type == LayerNorm) {
        tmp = (current - mu[i]) * rsigma[i] * g + static_cast<compute_t>(beta[j]);
      } else { // RMSNorm: no mean subtraction, no beta
        tmp = current * rsigma[i] * g;
      }
      output[i * H + j] = static_cast<OutputType>(tmp * scale);
      current_max = fmaxf(current_max, fabsf(tmp));
    }
  }
  if (amax) {
    *amax = current_max;
  }
}
// CPU reference backward pass for LayerNorm / RMSNorm.
//
// Given the upstream gradient `output_grad` (dz) and the saved forward stats
// (mu, rsigma), produces:
//   data_grad  - gradient w.r.t. the input
//   gamma_grad - per-column reduction of y * dz (y = normalized input)
//   beta_grad  - per-column reduction of dz (LayerNorm only)
// gamma passes through compute_gamma so zero-centered-gamma rounding matches
// the backend/flag combination under test. Accumulation is in float.
template <typename InputType, typename OutputType>
void compute_ref_backward(const NormType norm_type, const OutputType *output_grad, const InputType *data,
                          const float *mu, const float *rsigma,
                          const InputType *gamma,
                          InputType *data_grad,
                          InputType *gamma_grad, InputType *beta_grad,
                          const size_t N, const size_t H,
                          const bool zero_centered_gamma, const bool use_cudnn,
                          const bool cudnn_zero_centered_gamma_in_weight_dtype) {
  using compute_t = float;
  std::vector<compute_t> dgamma(H, 0.f);
  std::vector<compute_t> dbeta(H, 0.f);
  for (size_t i = 0 ; i < N; ++i) {
    // Per-row reductions: mean(dy) and mean(dy * y).
    auto local_mu = (norm_type == LayerNorm) ? mu[i] : 0.;
    compute_t mdy = 0, mdyy = 0;
    for (size_t j = 0; j < H; ++j) {
      const compute_t x = static_cast<compute_t>(data[i * H + j]);
      const compute_t y = (x - local_mu) * rsigma[i];
      compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype);
      const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
      const compute_t dy = g * dz;
      dgamma[j] += y * dz;
      if (norm_type == LayerNorm) {
        dbeta[j] += dz;
        mdy += dy;  // mean(dy) term only applies when the mean was subtracted
      }
      mdyy += dy * y;
    }
    mdy /= H;
    mdyy /= H;
    // Input grads: dx = rsigma * (dy - mean(dy*y) * y - mean(dy)).
    for (size_t j = 0; j < H; ++j) {
      const compute_t x = static_cast<compute_t>(data[i * H + j]);
      const compute_t y = (x - local_mu) * rsigma[i];
      compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype);
      const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
      const compute_t dy = g * dz;
      const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy);
      data_grad[i * H + j] = static_cast<InputType>(dx);
    }
  }
  // Weight grads: cast float accumulators back to the weight dtype.
  for (size_t j = 0; j < H; ++j) gamma_grad[j] = static_cast<InputType>(dgamma[j]);
  if (norm_type == LayerNorm) for (size_t j = 0; j < H; ++j) beta_grad[j] = static_cast<InputType>(dbeta[j]);
}
} // namespace
} // namespace test
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <transformer_engine/normalization.h> #include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h> #include <transformer_engine/transformer_engine.h>
#include "../test_common.h" #include "../test_common.h"
#include "test_normalization.h"
using namespace transformer_engine; using namespace transformer_engine;
using namespace test; using namespace test;
...@@ -27,16 +28,6 @@ namespace { ...@@ -27,16 +28,6 @@ namespace {
using fp8e8m0 = byte; using fp8e8m0 = byte;
enum NormType {
LayerNorm,
RMSNorm
};
std::map<NormType, std::string> normToString = {
{NormType::LayerNorm, "LayerNorm"},
{NormType::RMSNorm, "RMSNorm"}
};
template <typename InputType, typename ScaleType, typename OutputType> template <typename InputType, typename ScaleType, typename OutputType>
void dequantize_1x_kernel(InputType* input_ptr, ScaleType* scale_ptr, OutputType* output_ptr, void dequantize_1x_kernel(InputType* input_ptr, ScaleType* scale_ptr, OutputType* output_ptr,
size_t rows, size_t cols, size_t scaling_mode_x, size_t scaling_mode_y){ size_t rows, size_t cols, size_t scaling_mode_x, size_t scaling_mode_y){
...@@ -110,65 +101,8 @@ void dequantize_2x(Tensor& input, Tensor& output, bool is_training) ...@@ -110,65 +101,8 @@ void dequantize_2x(Tensor& input, Tensor& output, bool is_training)
32, 1); 32, 1);
} }
template <typename InputType>
void compute_ref_stats(NormType norm_type,
const InputType *data, float *mu, float *rsigma,
const size_t N, const size_t H, const double epsilon){
using compute_t = float;
#pragma omp parallel for proc_bind(spread)
for (size_t i = 0; i < N; ++i) {
compute_t sum = 0;
for (size_t j = 0; j < H; ++j) {
sum += static_cast<compute_t>(data[i * H + j]);
}
compute_t m;
if (norm_type == LayerNorm){
mu[i] = sum / H;
m = mu[i];
} else { m = 0;}
compute_t sum_sq = 0;
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
sum_sq += (current - m) * (current - m);
}
rsigma[i] = rsqrtf((sum_sq / H) + epsilon);
}
}
template <typename InputType, typename OutputType>
void compute_ref_output(NormType norm_type,
const InputType *data, const InputType *gamma, const InputType *beta,
const float *mu, const float *rsigma,
const size_t N, const size_t H,
OutputType* output,
const bool zero_centered_gamma){
using compute_t = float;
#pragma omp parallel for proc_bind(spread)
for (size_t i = 0; i < N; ++i) {
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
compute_t g = static_cast<compute_t>(gamma[j]);
if (zero_centered_gamma) {
g += 1.0;
}
compute_t tmp;
if (norm_type == LayerNorm) {
tmp = (current - mu[i]) * rsigma[i] * g + static_cast<compute_t>(beta[j]);
} else { // RMSNorm
tmp = current * rsigma[i] * g;
}
output[i * H + j] = tmp;
}
}
}
template <typename InputType, typename OutputType> template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, NormType norm_type, bool is_training) { void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, NormType norm_type, bool is_training, const bool zero_centered_gamma_in_weight_dtype) {
cudaDeviceProp prop; cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0); cudaGetDeviceProperties(&prop, 0);
...@@ -195,6 +129,12 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, ...@@ -195,6 +129,12 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
fillUniform(&gamma); fillUniform(&gamma);
fillUniform(&beta); fillUniform(&beta);
if (zero_centered_gamma_in_weight_dtype) {
nvte_enable_zero_centered_gamma_in_weight_dtype(true);
} else {
nvte_enable_zero_centered_gamma_in_weight_dtype(false);
}
// Forward kernel // Forward kernel
float epsilon = 1e-5; float epsilon = 1e-5;
if (norm_type == NormType::LayerNorm){ if (norm_type == NormType::LayerNorm){
...@@ -220,6 +160,10 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, ...@@ -220,6 +160,10 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
0); 0);
} }
if (zero_centered_gamma_in_weight_dtype) {
nvte_enable_zero_centered_gamma_in_weight_dtype(false);
}
Tensor dequantized_output("dequantized_output", { N, H }, DType::kFloat32, true, true); Tensor dequantized_output("dequantized_output", { N, H }, DType::kFloat32, true, true);
dequantize_2x<OutputType, fp8e8m0>(z, dequantized_output, is_training); dequantize_2x<OutputType, fp8e8m0>(z, dequantized_output, is_training);
...@@ -246,11 +190,15 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, ...@@ -246,11 +190,15 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
compute_ref_output(norm_type, input.rowwise_cpu_dptr<InputType>(), compute_ref_output(norm_type, input.rowwise_cpu_dptr<InputType>(),
gamma.rowwise_cpu_dptr<WeightType>(), gamma.rowwise_cpu_dptr<WeightType>(),
beta.rowwise_cpu_dptr<WeightType>(), beta.rowwise_cpu_dptr<WeightType>(),
ref_output.get(),
ref_mu_ptr, ref_mu_ptr,
ref_rsigma_ptr, ref_rsigma_ptr,
N, H, N, H,
ref_output.get(), nullptr, // amax
zero_centered_gamma); 1.f, // scale
zero_centered_gamma,
true, // CuDNN is the only MXFP8 backend currently
zero_centered_gamma_in_weight_dtype);
cudaDeviceSynchronize(); cudaDeviceSynchronize();
auto err = cudaGetLastError(); auto err = cudaGetLastError();
...@@ -298,7 +246,7 @@ class MxNormTestSuite : public ::testing::TestWithParam< std::tuple<NormType, ...@@ -298,7 +246,7 @@ class MxNormTestSuite : public ::testing::TestWithParam< std::tuple<NormType,
transformer_engine::DType, transformer_engine::DType,
transformer_engine::DType, transformer_engine::DType,
std::pair<size_t, size_t>, std::pair<size_t, size_t>,
bool, bool>> {}; bool, bool, bool>> {};
TEST_P(MxNormTestSuite, TestMxNorm) { TEST_P(MxNormTestSuite, TestMxNorm) {
using namespace transformer_engine; using namespace transformer_engine;
...@@ -310,10 +258,11 @@ TEST_P(MxNormTestSuite, TestMxNorm) { ...@@ -310,10 +258,11 @@ TEST_P(MxNormTestSuite, TestMxNorm) {
const auto size = std::get<3>(GetParam()); const auto size = std::get<3>(GetParam());
const bool zero_centered_gamma = std::get<4>(GetParam()); const bool zero_centered_gamma = std::get<4>(GetParam());
const bool is_training = std::get<5>(GetParam()); const bool is_training = std::get<5>(GetParam());
const bool zero_centered_gamma_in_weight_dtype = std::get<6>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma, norm_type, is_training); performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma, norm_type, is_training, zero_centered_gamma_in_weight_dtype);
); );
); );
} }
...@@ -327,6 +276,7 @@ INSTANTIATE_TEST_SUITE_P( ...@@ -327,6 +276,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(DType::kFloat8E5M2, DType::kFloat8E4M3), ::testing::Values(DType::kFloat8E5M2, DType::kFloat8E4M3),
::testing::ValuesIn(test_cases), ::testing::ValuesIn(test_cases),
::testing::Values(true, false), ::testing::Values(true, false),
::testing::Values(true, false),
::testing::Values(true, false)), ::testing::Values(true, false)),
[](const testing::TestParamInfo<MxNormTestSuite::ParamType>& info) { [](const testing::TestParamInfo<MxNormTestSuite::ParamType>& info) {
std::string name = normToString.at(std::get<0>(info.param)) + "_" + std::string name = normToString.at(std::get<0>(info.param)) + "_" +
...@@ -335,6 +285,7 @@ INSTANTIATE_TEST_SUITE_P( ...@@ -335,6 +285,7 @@ INSTANTIATE_TEST_SUITE_P(
std::to_string(std::get<3>(info.param).first) + "X" + std::to_string(std::get<3>(info.param).first) + "X" +
std::to_string(std::get<3>(info.param).second) + "X" + std::to_string(std::get<3>(info.param).second) + "X" +
std::to_string(std::get<4>(info.param)) + "out" + std::to_string(std::get<4>(info.param)) + "out" +
std::to_string(int(std::get<5>(info.param)) + 1) + "x"; std::to_string(int(std::get<5>(info.param)) + 1) + "x" +
std::to_string(std::get<6>(info.param));
return name; return name;
}); });
...@@ -149,6 +149,16 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor ...@@ -149,6 +149,16 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor
void nvte_enable_cudnn_norm_fwd(bool enable); void nvte_enable_cudnn_norm_fwd(bool enable);
void nvte_enable_cudnn_norm_bwd(bool enable); void nvte_enable_cudnn_norm_bwd(bool enable);
/*! \brief Control whether norm computes `gamma += 1.0` for zero-centered gamma
* in weight dtype. If set to false, it will compute in compute dtype.
*
* Currently this only applies to the CuDNN backend. If CuDNN is not used,
* this setting has no effect.
*
 * \param[in] enable If true, apply the zero-centered-gamma offset in the weight dtype; if false, in the compute dtype.
*/
void nvte_enable_zero_centered_gamma_in_weight_dtype(bool enable);
enum class NVTE_Norm_Type { LayerNorm, RMSNorm }; enum class NVTE_Norm_Type { LayerNorm, RMSNorm };
#ifdef __cplusplus #ifdef __cplusplus
......
...@@ -39,6 +39,8 @@ Compute always in FP32 ...@@ -39,6 +39,8 @@ Compute always in FP32
namespace transformer_engine { namespace transformer_engine {
namespace normalization { namespace normalization {
bool& use_zero_centered_gamma_in_weight_dtype();
cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) { cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) {
return training ? cudnn_frontend::NormFwdPhase_t::TRAINING return training ? cudnn_frontend::NormFwdPhase_t::TRAINING
: cudnn_frontend::NormFwdPhase_t::INFERENCE; : cudnn_frontend::NormFwdPhase_t::INFERENCE;
...@@ -207,9 +209,12 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor ...@@ -207,9 +209,12 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
_ndim_scale_block = 1; _ndim_scale_block = 1;
} }
_scalar_dptr = std::make_unique<char[]>(typeToSize(wtype)); const auto gamma_dtype = use_zero_centered_gamma_in_weight_dtype() ? wtype : ctype;
_scalar_dptr = std::make_unique<char[]>(typeToSize(gamma_dtype));
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
wtype, cpp_dtype, *(reinterpret_cast<cpp_dtype*>(_scalar_dptr.get())) = (cpp_dtype)1.0f;); gamma_dtype, cpp_dtype,
*(reinterpret_cast<cpp_dtype*>(_scalar_dptr.get())) = (cpp_dtype)1.0f;);
_handle = cudnnExecutionPlanManager::Instance().GetHandle(); _handle = cudnnExecutionPlanManager::Instance().GetHandle();
...@@ -239,13 +244,13 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor ...@@ -239,13 +244,13 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
.set_name("one") .set_name("one")
.set_dim({1, 1, 1, 1}) .set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1}) .set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(wtype)) .set_data_type(get_cudnn_fe_dtype(gamma_dtype))
.set_is_pass_by_value(true)); .set_is_pass_by_value(true));
auto centered_options = fe::graph::Pointwise_attributes() auto centered_options = fe::graph::Pointwise_attributes()
.set_mode(fe::PointwiseMode_t::ADD) .set_mode(fe::PointwiseMode_t::ADD)
.set_compute_data_type(get_cudnn_fe_dtype(ctype)); .set_compute_data_type(get_cudnn_fe_dtype(ctype));
_gamma = _graph.pointwise(_gamma_zero, _scalar_offset, centered_options); _gamma = _graph.pointwise(_gamma_zero, _scalar_offset, centered_options);
_gamma->set_output(false).set_data_type(get_cudnn_fe_dtype(wtype)); _gamma->set_output(false).set_data_type(get_cudnn_fe_dtype(gamma_dtype));
} else { } else {
_gamma = _gamma_zero; _gamma = _gamma_zero;
} }
...@@ -503,6 +508,13 @@ bool& _cudnn_norm_bwd_flag() { ...@@ -503,6 +508,13 @@ bool& _cudnn_norm_bwd_flag() {
bool use_cudnn_norm_fwd() { return _cudnn_norm_fwd_flag(); } bool use_cudnn_norm_fwd() { return _cudnn_norm_fwd_flag(); }
bool use_cudnn_norm_bwd() { return _cudnn_norm_bwd_flag(); } bool use_cudnn_norm_bwd() { return _cudnn_norm_bwd_flag(); }
// Process-wide flag backing nvte_enable_zero_centered_gamma_in_weight_dtype().
// Initialized from the NVTE_ZERO_CENTERED_GAMMA_IN_WTYPE environment variable
// on first use; returned by reference so the API setter can mutate it.
bool& _zero_centered_gamma_in_weight_dtype() {
  static bool flag = transformer_engine::getenv<bool>("NVTE_ZERO_CENTERED_GAMMA_IN_WTYPE");
  return flag;
}

// Read accessor consulted when choosing the dtype in which the
// zero-centered-gamma "+1" is applied.
bool& use_zero_centered_gamma_in_weight_dtype() { return _zero_centered_gamma_in_weight_dtype(); }
} // namespace normalization } // namespace normalization
} // namespace transformer_engine } // namespace transformer_engine
...@@ -515,3 +527,8 @@ void nvte_enable_cudnn_norm_bwd(bool enable) { ...@@ -515,3 +527,8 @@ void nvte_enable_cudnn_norm_bwd(bool enable) {
NVTE_API_CALL(nvte_enable_cudnn_norm_bwd); NVTE_API_CALL(nvte_enable_cudnn_norm_bwd);
transformer_engine::normalization::_cudnn_norm_bwd_flag() = enable; transformer_engine::normalization::_cudnn_norm_bwd_flag() = enable;
} }
// Public API: toggle whether the zero-centered-gamma "+1" is computed in the
// weight dtype (true) or the compute dtype (false). Per the header comment,
// this currently affects only the cuDNN backend.
void nvte_enable_zero_centered_gamma_in_weight_dtype(bool enable) {
  NVTE_API_CALL(nvte_enable_zero_centered_gamma_in_weight_dtype);
  transformer_engine::normalization::_zero_centered_gamma_in_weight_dtype() = enable;
}
...@@ -31,19 +31,24 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size ...@@ -31,19 +31,24 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
} }
NVTE_CHECK(x.data.shape.size() == 2); NVTE_CHECK(x.data.shape.size() == 2, "x must be 2D tensor.");
NVTE_CHECK(gamma.data.shape == beta.data.shape); NVTE_CHECK(gamma.data.shape == beta.data.shape, "Gamma and Beta must have the same shape.");
NVTE_CHECK(x.data.shape[1] == gamma.data.shape[0]); NVTE_CHECK(gamma.data.dtype == beta.data.dtype,
"Gamma and Beta must have the same dtype. Gamma dtype: " +
to_string(gamma.data.dtype) + ", Beta dtype: " + to_string(beta.data.dtype));
NVTE_CHECK(x.data.shape[1] == gamma.data.shape[0], "Gamma must have the same hidden size.");
NVTE_CHECK(epsilon >= 0.f); NVTE_CHECK(epsilon >= 0.f, "Epsilon must be non-negative.");
NVTE_CHECK(z->data.shape == x.data.shape); NVTE_CHECK(z->data.shape == x.data.shape, "Output tensor must have the same shape as x.");
NVTE_CHECK(mu->data.shape == std::vector<size_t>{x.data.shape[0]}); NVTE_CHECK(mu->data.shape == std::vector<size_t>{x.data.shape[0]},
NVTE_CHECK(mu->data.dtype == DType::kFloat32); "Mu must be 1D tensor with shape (x.shape[0],).");
NVTE_CHECK(mu->data.dtype == DType::kFloat32, "Mu must be a float32 tensor.");
NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{x.data.shape[0]}); NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{x.data.shape[0]},
NVTE_CHECK(rsigma->data.dtype == DType::kFloat32); "RSigma must be 1D tensor with shape (x.shape[0],).");
NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor.");
if (!workspace->data.shape.empty()) { if (!workspace->data.shape.empty()) {
CheckInputTensor(x, "x"); CheckInputTensor(x, "x");
......
...@@ -27,15 +27,16 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens ...@@ -27,15 +27,16 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
} }
NVTE_CHECK(x.data.shape.size() == 2); NVTE_CHECK(x.data.shape.size() == 2, "x must be 2D tensor.");
NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]); NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1], "Gamma must have the same hidden size.");
NVTE_CHECK(epsilon >= 0.f); NVTE_CHECK(epsilon >= 0.f, "Epsilon must be non-negative.");
NVTE_CHECK(z->data.shape == x.data.shape); NVTE_CHECK(z->data.shape == x.data.shape, "Output tensor must have the same shape as x.");
NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{x.data.shape[0]}); NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{x.data.shape[0]},
NVTE_CHECK(rsigma->data.dtype == DType::kFloat32); "RSigma must be 1D tensor with shape (x.shape[0],).");
NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor.");
if (!workspace->data.shape.empty()) { if (!workspace->data.shape.empty()) {
CheckInputTensor(x, "x"); CheckInputTensor(x, "x");
......
...@@ -64,6 +64,27 @@ def get_backward_sm_margin(): ...@@ -64,6 +64,27 @@ def get_backward_sm_margin():
return int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) return int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
@cache
def is_norm_fwd_cudnn_enabled(scaling_mode: ScalingMode) -> bool:
    """Retrieves whether CuDNN norm fwd is enabled.

    Cached per scaling mode, so NVTE_NORM_FWD_USE_CUDNN is read at most once
    per process for each mode; later changes to the env var are not observed.
    """
    # MXFP8_1D_SCALING always uses CuDNN currently
    return (
        int(os.getenv("NVTE_NORM_FWD_USE_CUDNN", "0")) == 1
        or scaling_mode == ScalingMode.MXFP8_1D_SCALING
    )
@cache
def is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode: ScalingMode) -> bool:
    """Retrieves whether norm should compute `gamma += 1.0` for zero-centered gamma
    in weight dtype as opposed to compute dtype.

    Cached per scaling mode; NVTE_ZERO_CENTERED_GAMMA_IN_WTYPE is read at most
    once per process for each mode.
    """
    if not is_norm_fwd_cudnn_enabled(scaling_mode):
        # If CuDNN is not enabled, we use the TE backend which uses the compute dtype not weight dtype
        # Remove this when TE supports gamma += 1.0 in weight dtype
        return False
    return int(os.getenv("NVTE_ZERO_CENTERED_GAMMA_IN_WTYPE", "0")) == 1
class NormFwdPrimitive(BasePrimitive): class NormFwdPrimitive(BasePrimitive):
""" """
Layer Normalization Forward FP8 Primitive Layer Normalization Forward FP8 Primitive
...@@ -788,6 +809,10 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None) ...@@ -788,6 +809,10 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None)
JAX native layernorm implementation JAX native layernorm implementation
""" """
x_ = jnp.asarray(x, jnp.float32) x_ = jnp.asarray(x, jnp.float32)
if not is_norm_zero_centered_gamma_in_weight_dtype(
quantizer.scaling_mode if quantizer else ScalingMode.NO_SCALING
):
gamma = gamma.astype(jnp.float32)
mean = jnp.mean(x_, axis=-1, keepdims=True) mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True) var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
rsigma = jax.lax.rsqrt(var + epsilon) rsigma = jax.lax.rsqrt(var + epsilon)
...@@ -809,6 +834,10 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None): ...@@ -809,6 +834,10 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None):
JAX native rmsnorm implementation JAX native rmsnorm implementation
""" """
x_ = jnp.asarray(x, jnp.float32) x_ = jnp.asarray(x, jnp.float32)
if not is_norm_zero_centered_gamma_in_weight_dtype(
quantizer.scaling_mode if quantizer else ScalingMode.NO_SCALING
):
gamma = gamma.astype(jnp.float32)
var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True) var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True)
rsigma = jax.lax.rsqrt(var + epsilon) rsigma = jax.lax.rsqrt(var + epsilon)
normed_input = x_ * rsigma normed_input = x_ * rsigma
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment