Unverified Commit 3102fdd1 authored by Phuong Nguyen, committed by GitHub

[C] Normalization Refactor + Adding CUDNN backend (#1315)



* cuDNN normalization integration
* TE Norm refactor
* TE Norm API changes.
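
For orientation, a rough before/after sketch of the forward entry point as it appears in this diff (illustrative only; the "1p" zero-centered-gamma variants and the separate barrier tensor are folded into a single call):

// Before: variant selected by function name, workspace and barrier passed separately.
// nvte_layernorm1p_fwd(x, gamma, beta, epsilon, z, mu, rsigma,
//                      stream, multiprocessorCount, workspace, barrier);
// After: one entry point, zero_centered_gamma as an argument, single workspace tensor.
// nvte_layernorm_fwd(x, gamma, beta, epsilon, z, mu, rsigma,
//                    workspace, multiprocessorCount, zero_centered_gamma, stream);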

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent d8b13cb0
@@ -10,8 +10,7 @@ add_executable(test_operator
test_cast_transpose_dbias_dgelu.cu
test_cast_transpose_dgeglu.cu
test_act.cu
test_layernorm.cu
test_rmsnorm.cu
test_normalization.cu
test_multi_cast_transpose.cu
test_multi_padding.cu
test_causal_softmax.cu
@@ -10,12 +10,13 @@
#include <iomanip>
#include <iostream>
#include <random>
#include <stdlib.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/layer_norm.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"
@@ -24,44 +25,93 @@ using namespace test;
namespace {
enum NormType {
LayerNorm,
RMSNorm
};
std::map<NormType, std::string> normToString = {
{NormType::LayerNorm, "LayerNorm"},
{NormType::RMSNorm, "RmsNorm"}
};
template <typename InputType>
void compute_ref_stats(const InputType *data, float *mu, float *rsigma,
const size_t N, const size_t H, const double epsilon) {
void compute_ref_stats(NormType norm_type,
const InputType *data, float *mu, float *rsigma,
const size_t N, const size_t H, const double epsilon){
using compute_t = float;
for (size_t i = 0 ; i < N; ++i) {
compute_t current, m;
for (size_t i = 0; i < N; ++i) {
compute_t sum = 0;
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
sum += current;
sum += static_cast<compute_t>(data[i * H + j]);
}
if (norm_type == LayerNorm){
mu[i] = sum / H;
compute_t m = mu[i];
sum = 0;
m = mu[i];
} else { m = 0;}
compute_t sum_sq = 0;
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
sum += (current - m) * (current - m);
current = static_cast<compute_t>(data[i * H + j]);
sum_sq += (current - m) * (current - m);
}
rsigma[i] = rsqrtf((sum_sq / H) + epsilon);
}
}
// For now, cuDNN does static_cast<compute_t>(gamma + static_cast<input_t>(1.0))
// This will be changed in a future release
template <typename InputType>
inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn){
using compute_t = float;
if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3>){
compute_t g = static_cast<compute_t>(gamma);
if (zero_centered_gamma) {
g += static_cast<compute_t>(1.f);
}
return g;
} else {
if (use_cudnn){
compute_t g = static_cast<compute_t>(0.f);
InputType gi = gamma;
if (zero_centered_gamma) {
gi = gi + static_cast<InputType>(1.f);
}
g = static_cast<compute_t>(gi);
return g;
} else {
compute_t g = static_cast<compute_t>(gamma);
if (zero_centered_gamma) {
g += static_cast<compute_t>(1.f);
}
return g;
}
sum = sum / H;
compute_t rs = rsqrtf(sum + epsilon);
rsigma[i] = rs;
}
}
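To make the comment above compute_gamma concrete, here is an illustrative snippet (not part of the test) showing where the +1 rounds on each path for a non-FP8 weight type such as bf16:

// TE reference path: promote gamma to fp32 first, then add 1 in fp32.
float g_te = static_cast<float>(gamma) + 1.f;
// cuDNN reference path: add 1 in the weight dtype, then promote to fp32.
// The intermediate sum is rounded to InputType, so g_cudnn can differ from
// g_te by roughly one InputType ulp.
InputType gi = gamma + static_cast<InputType>(1.f);
float g_cudnn = static_cast<float>(gi);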
template <typename InputType, typename OutputType>
void compute_ref_output(const InputType *data, const InputType *gamma, const InputType *beta,
OutputType *output, const float *mu, const float *rsigma,
void compute_ref_output(NormType norm_type,
const InputType *data, const InputType *gamma, const InputType *beta,
OutputType* output,
const float *mu, const float *rsigma,
const size_t N, const size_t H,
float *amax, float scale, const bool zero_centered_gamma) {
float *amax, float scale, const bool zero_centered_gamma, const bool use_cudnn) {
using compute_t = float;
compute_t current_max = -1e100;
for (size_t i = 0 ; i < N; ++i) {
for (size_t i = 0; i < N; ++i) {
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
compute_t g = static_cast<compute_t>(gamma[j]);
if (zero_centered_gamma) {
g += 1;
compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn);
compute_t tmp;
if (norm_type == LayerNorm) {
tmp = (current - mu[i]) * rsigma[i] * g + static_cast<compute_t>(beta[j]);
} else { // RMSNorm
tmp = current * rsigma[i] * g;
}
compute_t tmp = (current - mu[i]) * rsigma[i] * g + static_cast<compute_t>(beta[j]);
output[i * H + j] = static_cast<OutputType>(tmp * scale);
current_max = fmaxf(current_max, fabsf(tmp));
}
@@ -69,33 +119,34 @@ void compute_ref_output(const InputType *data, const InputType *gamma, const Inp
*amax = current_max;
}
template <typename InputType, typename OutputType>
void compute_ref_backward(const OutputType *output_grad, const InputType *data,
void compute_ref_backward(const NormType norm_type, const OutputType *output_grad, const InputType *data,
const float *mu, const float *rsigma,
const InputType *gamma,
InputType *data_grad,
InputType *gamma_grad, InputType *beta_grad,
const size_t N, const size_t H,
const bool zero_centered_gamma) {
const bool zero_centered_gamma, const bool use_cudnn) {
using compute_t = float;
std::vector<compute_t> dgamma(H, 0.f);
std::vector<compute_t> dbeta(H, 0.f);
for (size_t i = 0 ; i < N; ++i) {
// Reductions
auto local_mu = (norm_type == LayerNorm) ? mu[i] : 0.;
compute_t mdy = 0, mdyy = 0;
for (size_t j = 0; j < H; ++j) {
const compute_t x = static_cast<compute_t>(data[i * H + j]);
const compute_t y = (x - mu[i]) * rsigma[i];
compute_t g = static_cast<compute_t>(gamma[j]);
if (zero_centered_gamma) {
g += 1;
}
const compute_t y = (x - local_mu) * rsigma[i];
compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn);
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz;
dgamma[j] += y * dz;
if (norm_type == LayerNorm) {
dbeta[j] += dz;
mdy += dy;
}
mdyy += dy * y;
}
mdy /= H;
@@ -104,11 +155,8 @@ void compute_ref_backward(const OutputType *output_grad, const InputType *data,
// Input grads
for (size_t j = 0; j < H; ++j) {
const compute_t x = static_cast<compute_t>(data[i * H + j]);
const compute_t y = (x - mu[i]) * rsigma[i];
compute_t g = static_cast<compute_t>(gamma[j]);
if (zero_centered_gamma) {
g += 1;
}
const compute_t y = (x - local_mu) * rsigma[i];
compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn);
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz;
const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy);
@@ -117,14 +165,13 @@ void compute_ref_backward(const OutputType *output_grad, const InputType *data,
}
// Weight grads
for (size_t j = 0; j < H; ++j) {
gamma_grad[j] = static_cast<InputType>(dgamma[j]);
beta_grad[j] = static_cast<InputType>(dbeta[j]);
}
for (size_t j = 0; j < H; ++j) gamma_grad[j] = static_cast<InputType>(dgamma[j]);
if (norm_type == LayerNorm) for (size_t j = 0; j < H; ++j) beta_grad[j] = static_cast<InputType>(dbeta[j]);
}
template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H, const bool zero_centered_gamma) {
void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
NormType norm_type, bool use_cudnn) {
if (sizeof(InputType) < sizeof(OutputType)) {
GTEST_SKIP() << "LN kernel does not support OutputType > InputType";
return;
@@ -150,7 +197,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma)
Tensor dx({ N, H }, itype);
Tensor dgamma({ H }, wtype);
Tensor dbeta({ H }, wtype);
Tensor workspace, barrier, dgamma_part, dbeta_part;
Tensor workspace_fwd, workspace_bwd;
fillUniform(&input);
fillUniform(&gamma);
@@ -168,46 +215,67 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma)
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
if (use_cudnn){
nvte_enable_cudnn_norm_fwd(true);
nvte_enable_cudnn_norm_bwd(true);
}
// Forward kernel
float epsilon = 1e-5;
auto fwd_function = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
fwd_function(input.data(), gamma.data(), beta.data(), epsilon,
z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount,
workspace.data(), barrier.data());
workspace = Tensor(workspace.shape(), workspace.dtype());
barrier = Tensor(barrier.shape(), barrier.dtype());
fwd_function(input.data(), gamma.data(), beta.data(), epsilon,
z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount,
workspace.data(), barrier.data());
// Backward kernel
auto bwd_function = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
bwd_function(dz.data(), input.data(),
if (norm_type == LayerNorm){
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon,
z.data(), mu.data(), rsigma.data(), workspace_fwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
workspace_fwd = Tensor(workspace_fwd.shape(), workspace_fwd.dtype());
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon,
z.data(), mu.data(), rsigma.data(), workspace_fwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
nvte_layernorm_bwd(dz.data(), input.data(),
mu.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), dbeta.data(),
dgamma_part.data(), dbeta_part.data(),
0, prop.multiProcessorCount,
workspace.data(), barrier.data());
workspace = Tensor(workspace.shape(), workspace.dtype());
barrier = Tensor(barrier.shape(), barrier.dtype());
dgamma_part = Tensor(dgamma_part.shape(), dgamma_part.dtype());
dbeta_part = Tensor(dbeta_part.shape(), dbeta_part.dtype());
bwd_function(dz.data(), input.data(),
workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
workspace_bwd = Tensor(workspace_bwd.shape(), workspace_bwd.dtype());
nvte_layernorm_bwd(dz.data(), input.data(),
mu.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), dbeta.data(),
dgamma_part.data(), dbeta_part.data(),
0, prop.multiProcessorCount,
workspace.data(), barrier.data());
workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
} else {
nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon,
z.data(), rsigma.data(), workspace_fwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
workspace_fwd = Tensor(workspace_fwd.shape(), workspace_fwd.dtype());
nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon,
z.data(), rsigma.data(), workspace_fwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(),
workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
workspace_bwd = Tensor(workspace_bwd.shape(), workspace_bwd.dtype());
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(),
workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
}
if (use_cudnn){
nvte_enable_cudnn_norm_fwd(false);
nvte_enable_cudnn_norm_bwd(false);
}
// Reference implementations
// use the GPU stats to tighten the tolerances
mu.to_cpu();
rsigma.to_cpu();
float ref_amax;
compute_ref_stats(input.cpu_dptr<InputType>(), ref_mu.get(),
compute_ref_stats(norm_type, input.cpu_dptr<InputType>(), ref_mu.get(),
ref_rsigma.get(), N, H, epsilon);
float ref_scale = isFp8Type(otype) ? z.scale() : 1.f;
compute_ref_output(input.cpu_dptr<InputType>(),
compute_ref_output(norm_type, input.cpu_dptr<InputType>(),
gamma.cpu_dptr<WeightType>(),
beta.cpu_dptr<WeightType>(),
ref_output.get(),
@@ -216,12 +284,14 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma)
N, H,
&ref_amax,
ref_scale,
zero_centered_gamma);
compute_ref_backward(dz.cpu_dptr<WeightType>(), input.cpu_dptr<InputType>(),
zero_centered_gamma,
use_cudnn);
compute_ref_backward(norm_type, dz.cpu_dptr<WeightType>(), input.cpu_dptr<InputType>(),
mu.cpu_dptr<float>(), rsigma.cpu_dptr<float>(),
gamma.cpu_dptr<WeightType>(),
ref_dx.get(), ref_dgamma.get(), ref_dbeta.get(),
N, H, zero_centered_gamma);
N, H, zero_centered_gamma,
use_cudnn);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
@@ -245,58 +315,66 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma)
}
compareResults("output", z, ref_output.get(), atol, rtol);
double atol_bwd = 1e-4;
double rtol_bwd = 1e-4;
double atol_bwd = 5e-4;
double rtol_bwd = 5e-4;
compareResults("dx", dx, ref_dx.get(), atol_bwd, rtol_bwd);
compareResults("dgamma", dgamma, ref_dgamma.get(), atol_bwd, rtol_bwd);
compareResults("dbeta", dbeta, ref_dbeta.get(), atol_bwd, rtol_bwd);
}
std::vector<std::pair<size_t, size_t>> test_cases = {{2048, 12288},
{768, 1024},
{256, 65536},
{128, 6144},
{64, 2304},
{229, 541}, // Primes 50, 100
{71, 3571}, // Primes 20, 500
{29, 17389}}; // Primes 10, 2000
std::vector<std::pair<size_t, size_t>> test_cases = {
{71, 229},
{29, 541},
{768, 6144},
{2048, 12288},
};
} // namespace
class LNTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
class NormTestSuite : public ::testing::TestWithParam<std::tuple<bool,
NormType,
transformer_engine::DType,
transformer_engine::DType,
std::pair<size_t, size_t>,
bool>> {};
TEST_P(LNTestSuite, TestLN) {
TEST_P(NormTestSuite, TestNorm) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
const bool zero_centered_gamma = std::get<3>(GetParam());
const bool use_cudnn = std::get<0>(GetParam());
const NormType norm_type = std::get<1>(GetParam());
const DType input_type = std::get<2>(GetParam());
const DType output_type = std::get<3>(GetParam());
const auto size = std::get<4>(GetParam());
const bool zero_centered_gamma = std::get<5>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma);
performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma, norm_type, use_cudnn);
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
LNTestSuite,
NormTestSuite,
::testing::Combine(
::testing::Values(false), // TODO: enable tests for the cuDNN backend
::testing::Values(NormType::LayerNorm, NormType::RMSNorm),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3),
::testing::ValuesIn(test_cases),
::testing::Values(false, true)),
[](const testing::TestParamInfo<LNTestSuite::ParamType>& info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param)) + "X" +
std::to_string(std::get<2>(info.param).first) + "X" +
std::to_string(std::get<2>(info.param).second) + "X" +
std::to_string(std::get<3>(info.param));
[](const testing::TestParamInfo<NormTestSuite::ParamType>& info) {
auto backend = std::get<0>(info.param) == false ? "Te" : "Cudnn";
std::string name =
backend +
normToString.at(std::get<1>(info.param)) + "_" +
test::typeName(std::get<2>(info.param)) + "X" +
test::typeName(std::get<3>(info.param)) + "X" +
std::to_string(std::get<4>(info.param).first) + "X" +
std::to_string(std::get<4>(info.param).second) + "X" +
std::to_string(std::get<5>(info.param));
return name;
});
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/rmsnorm.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"
using namespace transformer_engine;
using namespace test;
namespace {
template <typename InputType>
void compute_ref_stats(const InputType *data, float *rsigma, const size_t N, const size_t H,
const double epsilon) {
using compute_t = float;
for (size_t i = 0; i < N; ++i) {
compute_t sum = 0;
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
sum += (current) * (current);
}
sum = sum / H;
compute_t rs = rsqrtf(sum + epsilon);
rsigma[i] = rs;
}
}
template <typename InputType, typename OutputType>
void compute_ref_output(const InputType *data, const InputType *gamma, OutputType *output,
const float *rsigma, const size_t N, const size_t H, float *amax,
float scale, const bool zero_centered_gamma) {
using compute_t = float;
compute_t current_max = -1e100;
for (size_t i = 0; i < N; ++i) {
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
compute_t g = static_cast<compute_t>(gamma[j]);
if (zero_centered_gamma) {
g += 1;
}
compute_t tmp = current * rsigma[i] * g;
output[i * H + j] = static_cast<OutputType>(tmp * scale);
current_max = fmaxf(current_max, fabsf(tmp));
}
}
*amax = current_max;
}
template <typename InputType, typename OutputType>
void compute_ref_backward(const OutputType *output_grad, const InputType *data, const float *rsigma,
const InputType *gamma, InputType *data_grad, InputType *gamma_grad,
const size_t N, const size_t H, const bool zero_centered_gamma) {
using compute_t = float;
std::vector<compute_t> dgamma(H, 0.f);
for (size_t i = 0; i < N; ++i) {
// Reductions
compute_t mdyy = 0;
for (size_t j = 0; j < H; ++j) {
const compute_t x = static_cast<compute_t>(data[i * H + j]);
const compute_t y = x * rsigma[i];
compute_t g = static_cast<compute_t>(gamma[j]);
if (zero_centered_gamma) {
g += 1;
}
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz;
dgamma[j] += y * dz;
mdyy += dy * y;
}
mdyy /= H;
// Input grads
for (size_t j = 0; j < H; ++j) {
const compute_t x = static_cast<compute_t>(data[i * H + j]);
const compute_t y = x * rsigma[i];
compute_t g = static_cast<compute_t>(gamma[j]);
if (zero_centered_gamma) {
g += 1;
}
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz;
const compute_t dx = rsigma[i] * (dy - mdyy * y);
data_grad[i * H + j] = static_cast<InputType>(dx);
}
}
// Weight grads
for (size_t j = 0; j < H; ++j) {
gamma_grad[j] = static_cast<InputType>(dgamma[j]);
}
}
template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H, const bool zero_centered_gamma) {
if (sizeof(InputType) < sizeof(OutputType)) {
GTEST_SKIP() << "RMSNorm kernel does not support OutputType > InputType";
return;
}
using WeightType = InputType;
DType itype = TypeInfo<InputType>::dtype;
DType wtype = TypeInfo<WeightType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
if ((itype == DType::kBFloat16 && otype == DType::kFloat16) ||
(itype == DType::kFloat16 && otype == DType::kBFloat16)) {
GTEST_SKIP() << "RMSNorm kernel does not support mixing Float16 and BFloat16";
return;
}
Tensor input({N, H}, itype);
Tensor z({N, H}, otype);
Tensor gamma({H}, wtype);
Tensor rsigma({N}, DType::kFloat32);
Tensor dz({N, H}, wtype);
Tensor dx({N, H}, itype);
Tensor dgamma({H}, wtype);
Tensor workspace, barrier, dgamma_part;
fillUniform(&input);
fillUniform(&gamma);
fillUniform(&dz);
setRandomScale(&z);
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N * H);
std::unique_ptr<float[]> ref_rsigma = std::make_unique<float[]>(N);
std::unique_ptr<InputType[]> ref_dx = std::make_unique<InputType[]>(N * H);
std::unique_ptr<WeightType[]> ref_dgamma = std::make_unique<InputType[]>(H);
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
// Forward kernel
float epsilon = 1e-5;
auto fwd_function = zero_centered_gamma ? nvte_rmsnorm1p_fwd : nvte_rmsnorm_fwd;
fwd_function(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), 0,
prop.multiProcessorCount, workspace.data(), barrier.data());
workspace = Tensor(workspace.shape(), workspace.dtype());
barrier = Tensor(barrier.shape(), barrier.dtype());
fwd_function(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), 0,
prop.multiProcessorCount, workspace.data(), barrier.data());
// Backward kernel
auto bwd_function = zero_centered_gamma ? nvte_rmsnorm1p_bwd : nvte_rmsnorm_bwd;
bwd_function(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(),
dgamma_part.data(), 0, prop.multiProcessorCount, workspace.data(),
barrier.data());
workspace = Tensor(workspace.shape(), workspace.dtype());
barrier = Tensor(barrier.shape(), barrier.dtype());
dgamma_part = Tensor(dgamma_part.shape(), dgamma_part.dtype());
bwd_function(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(),
dgamma_part.data(), 0, prop.multiProcessorCount, workspace.data(),
barrier.data());
// Reference implementations
// use the GPU stats to tighten the tolerances
rsigma.to_cpu();
float ref_amax;
compute_ref_stats(input.cpu_dptr<InputType>(), ref_rsigma.get(), N, H, epsilon);
float ref_scale = isFp8Type(otype) ? z.scale() : 1.f;
compute_ref_output(input.cpu_dptr<InputType>(), gamma.cpu_dptr<WeightType>(), ref_output.get(),
rsigma.cpu_dptr<float>(), N, H, &ref_amax, ref_scale,
zero_centered_gamma);
compute_ref_backward(dz.cpu_dptr<WeightType>(), input.cpu_dptr<InputType>(),
rsigma.cpu_dptr<float>(), gamma.cpu_dptr<WeightType>(), ref_dx.get(),
ref_dgamma.get(), N, H, zero_centered_gamma);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
if (isFp8Type(otype)) {
compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / z.scale();
compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
rtol_stats = 5e-5;
compareResults("rsigma", rsigma, ref_rsigma.get(), atol_stats, rtol_stats);
auto [atol, rtol] = getTolerances(otype);
atol = 1e-8;
compareResults("output", z, ref_output.get(), atol, rtol);
double atol_bwd = 5e-6;
double rtol_bwd = 1e-4;
compareResults("dx", dx, ref_dx.get(), atol_bwd, rtol_bwd);
compareResults("dgamma", dgamma, ref_dgamma.get(), atol_bwd, rtol_bwd);
}
std::vector<std::pair<size_t, size_t>> test_cases = {
{2048, 4096}, {768, 2048}, {256, 1024}, {128, 768}, {64, 512}, {173, 409}, // Primes 40, 80
{71, 3571}, // Primes 20, 500
{29, 17389}}; // Primes 10, 2000
} // namespace
class RMSNormTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
transformer_engine::DType,
std::pair<size_t, size_t>,
bool>> {};
TEST_P(RMSNormTestSuite, TestRMSNorm) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
const bool zero_centered_gamma = std::get<3>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma);););
}
INSTANTIATE_TEST_SUITE_P(OperatorTest, RMSNormTestSuite,
::testing::Combine(::testing::Values(DType::kFloat32, DType::kBFloat16,
DType::kFloat16),
::testing::Values(DType::kFloat32, DType::kBFloat16,
DType::kFloat16, DType::kFloat8E4M3),
::testing::ValuesIn(test_cases),
::testing::Values(false, true)),
[](const testing::TestParamInfo<RMSNormTestSuite::ParamType> &info) {
std::string name =
test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param)) + "X" +
std::to_string(std::get<2>(info.param).first) + "X" +
std::to_string(std::get<2>(info.param).second) + "X" +
std::to_string(std::get<3>(info.param));
return name;
});
@@ -64,13 +64,14 @@ list(APPEND transformer_engine_SOURCES
fused_attn/thd_utils.cu
fused_attn/utils.cu
gemm/cublaslt_gemm.cu
layer_norm/ln_api.cpp
layer_norm/ln_bwd_semi_cuda_kernel.cu
layer_norm/ln_fwd_cuda_kernel.cu
normalization/common.cpp
normalization/layernorm/ln_api.cpp
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_api.cpp
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
permutation/permutation.cu
rmsnorm/rmsnorm_api.cpp
rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
rmsnorm/rmsnorm_fwd_cuda_kernel.cu
util/cast.cu
util/padding.cu
util/cuda_driver.cpp
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file layer_norm.h
* \brief LayerNorm functions.
*/
#ifndef TRANSFORMER_ENGINE_LAYER_NORM_H_
#define TRANSFORMER_ENGINE_LAYER_NORM_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief Compute LayerNorm on the input.
*
* The formula used:
* @f[
* y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}\gamma + \beta
* @f]
*
* Calling this function with workspace and barrier set to empty tensor will not
* perform the operation, but instead set the shape and type of the workspace
* and barrier tensors to the required values.
*
* \param[in] x Input tensor of shape [N, H].
* \param[in] gamma Gamma tensor of shape [H].
* \param[in] beta Beta tensor of shape [H].
* \param[in] epsilon Value added to denominator for numerical stability.
* \param[in,out] z Output tensor of shape [N, H].
* \param[out] mu Mean of the input calculated over the last dimension.
* Shape: [N].
* \param[out] rsigma Inverse of the variance of the input calculated over
* the last dimension. Shape: [N].
* \param[in] stream CUDA stream used for the operation.
* \param[in] multiprocessorCount Number of SMs in the device.
* \param[out] workspace Workspace tensor.
* \param[out] barrier Barrier tensor.
*/
void nvte_layernorm_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETensor beta,
const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma,
cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace,
NVTETensor barrier);
/*! \brief Compute LayerNorm with zero-centered gamma on the input.
*
* The formula used:
* @f[
* y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}(1 + \gamma) + \beta
* @f]
*
* Calling this function with workspace and barrier set to empty tensor will not
* perform the operation, but instead set the shape and type of the workspace
* and barrier tensors to the required values.
*
* \param[in] x Input tensor of shape [N, H].
* \param[in] gamma Gamma tensor of shape [H].
* \param[in] beta Beta tensor of shape [H].
* \param[in] epsilon Value added to denominator for numerical stability.
* \param[in,out] z Output tensor of shape [N, H].
* \param[out] mu Mean of the input calculated over the last dimension.
* Shape: [N].
* \param[out] rsigma Inverse of the variance of the input calculated over
* the last dimension. Shape: [N].
* \param[in] stream CUDA stream used for the operation.
* \param[in] multiprocessorCount Number of SMs in the device.
* \param[out] workspace Workspace tensor.
* \param[out] barrier Barrier tensor.
*/
void nvte_layernorm1p_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETensor beta,
const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma,
cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace,
NVTETensor barrier);
/*! \brief Compute backward of LayerNorm.
*
* This function computes the gradient of function:
* @f[
* y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}\gamma + \beta
* @f]
* with respect to \f$x\f$, \f$\gamma\f$ and \f$\beta\f$.
*
* Calling this function with workspace, barrier, dgamma_part and dbeta_part set
* to empty tensor will not perform the operation, but instead set the shape and type
* of these tensors to the required values.
*
* \param[in] dz Incoming gradient tensor of shape [N, H].
* \param[in] x Forward input tensor of shape [N, H].
* \param[in] mu Mean of the input calculated over the last dimension.
* Shape: [N].
* \param[in] rsigma Inverse of the variance of the input calculated over
* the last dimension. Shape: [N].
* \param[in] gamma Gamma tensor of shape [H].
* \param[out] dx Output gradient of shape [N, H].
* \param[out] dgamma Gradient for gamma tensor of shape [H].
* \param[out] dbeta Gradient for beta tensor of shape [H].
* \param[out] dgamma_part Storage for partial gamma gradient.
* \param[out] dbeta_part Storage for partial bias gradient.
* \param[in] stream CUDA stream used for the operation.
* \param[in] multiprocessorCount Number of SMs in the device.
* \param[out] workspace Workspace tensor.
* \param[out] barrier Barrier tensor.
*/
void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
const NVTETensor x, // BxSxhidden_size
const NVTETensor mu, // BxS, FP32!
const NVTETensor rsigma, // BxS, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, NVTETensor dgamma_part,
NVTETensor dbeta_part, cudaStream_t stream, const int multiprocessorCount,
NVTETensor workspace, NVTETensor barrier);
/*! \brief Compute backward of LayerNorm with zero-centered gamma.
*
* This function computes the gradient of function:
* @f[
* y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}(1 + \gamma) + \beta
* @f]
* with respect to \f$x\f$, \f$\gamma\f$ and \f$\beta\f$.
*
* Calling this function with workspace, barrier, dgamma_part and dbeta_part set
* to empty tensor will not perform the operation, but instead set the shape and type
* of these tensors to the required values.
*
* \param[in] dz Incoming gradient tensor of shape [N, H].
* \param[in] x Forward input tensor of shape [N, H].
* \param[in] mu Mean of the input calculated over the last dimension.
* Shape: [N].
* \param[in] rsigma Inverse of the variance of the input calculated over
* the last dimension. Shape: [N].
* \param[in] gamma Gamma tensor of shape [H].
* \param[out] dx Output gradient of shape [N, H].
* \param[out] dgamma Gradient for gamma tensor of shape [H].
* \param[out] dbeta Gradient for beta tensor of shape [H].
* \param[out] dgamma_part Storage for partial gamma gradient.
* \param[out] dbeta_part Storage for partial bias gradient.
* \param[in] stream CUDA stream used for the operation.
* \param[in] multiprocessorCount Number of SMs in the device.
* \param[out] workspace Workspace tensor.
* \param[out] barrier Barrier tensor.
*/
void nvte_layernorm1p_bwd(const NVTETensor dz, // BxSxhidden_size
const NVTETensor x, // BxSxhidden_size
const NVTETensor mu, // BxS, FP32!
const NVTETensor rsigma, // BxS, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta,
NVTETensor dgamma_part, NVTETensor dbeta_part, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_LAYER_NORM_H_
@@ -4,12 +4,12 @@
* See LICENSE for license information.
************************************************************************/
/*! \file rmsnorm.h
* \brief RMSNorm functions.
/*! \file normalization.h
* \brief LayerNorm and RMSNorm functions.
*/
#ifndef TRANSFORMER_ENGINE_RMSNORM_H_
#define TRANSFORMER_ENGINE_RMSNORM_H_
#ifndef TRANSFORMER_ENGINE_NORMALIZATION_H_
#define TRANSFORMER_ENGINE_NORMALIZATION_H_
#include "transformer_engine.h"
@@ -17,41 +17,73 @@
extern "C" {
#endif
/*! \brief Compute RMSNorm on the input.
/*! \brief Compute LayerNorm on the input.
*
* The formula used:
* @f[
* y = \frac{x}{RMS_\varepsilon(x)}\gamma
* @f]
* where
* @f[
* RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon}
* y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}} \gamma + \beta
* @f]
*
* Calling this function with workspace and barrier set to empty tensor will not
* perform the operation, but instead set the shape and type of the workspace
* and barrier tensors to the required values.
* Calling this function with workspace set to empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] x Input tensor of shape [N, H].
* \param[in] gamma Gamma tensor of shape [H].
* \param[in] beta Beta tensor of shape [H].
* \param[in] epsilon Value added to denominator for numerical stability.
* \param[in,out] z Output tensor of shape [N, H].
* \param[out] rsigma Reciprocal of the root mean square of the input
* calculated over the last dimension. Shape: [N].
* \param[in] stream CUDA stream used for the operation.
* \param[out] mu Mean of the input calculated over the last dimension.
* Shape: [N].
* \param[out] rsigma Inverse of the variance of the input calculated over
* the last dimension. Shape: [N].
* \param[out] workspace Workspace tensor.
* \param[in] multiprocessorCount Number of SMs in the device.
* \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_layernorm_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETensor beta,
const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma,
NVTETensor workspace, const int multiprocessorCount,
const bool zero_centered_gamma, cudaStream_t stream);
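A minimal usage sketch of the query-then-run convention described above, assuming the Tensor wrapper from the C++ tests (an empty workspace on the first call only records the required shape and dtype):

// First call: workspace is empty, so the kernel is not run; its required
// shape/dtype are written into the workspace tensor instead.
Tensor workspace;
nvte_layernorm_fwd(x.data(), gamma.data(), beta.data(), epsilon,
                   z.data(), mu.data(), rsigma.data(), workspace.data(),
                   multiprocessorCount, zero_centered_gamma, stream);
// Allocate the requested workspace and run the kernel for real.
workspace = Tensor(workspace.shape(), workspace.dtype());
nvte_layernorm_fwd(x.data(), gamma.data(), beta.data(), epsilon,
                   z.data(), mu.data(), rsigma.data(), workspace.data(),
                   multiprocessorCount, zero_centered_gamma, stream);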
/*! \brief Compute backward of LayerNorm.
*
* This function computes the gradient of function:
* @f[
* y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}\gamma + \beta
* @f]
* with respect to \f$x\f$, \f$\gamma\f$ and \f$\beta\f$.
*
* Calling this function with workspace set to empty tensor will not perform the operation,
* but instead set the shape and type of these tensors to the required values.
*
* \param[in] dz Incoming gradient tensor of shape [N, H].
* \param[in] x Forward input tensor of shape [N, H].
* \param[in] mu Mean of the input calculated over the last dimension.
* Shape: [N].
* \param[in] rsigma Inverse of the variance of the input calculated over
* the last dimension. Shape: [N].
* \param[in] gamma Gamma tensor of shape [H].
* \param[out] dx Output gradient of shape [N, H].
* \param[out] dgamma Gradient for gamma tensor of shape [H].
* \param[out] dbeta Gradient for beta tensor of shape [H].
* \param[out] workspace Workspace tensor.
* \param[out] barrier Barrier tensor.
* \param[in] multiprocessorCount Number of SMs in the device.
* \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float epsilon, NVTETensor z,
NVTETensor rsigma, cudaStream_t stream, const int multiprocessorCount,
NVTETensor workspace, NVTETensor barrier);
void nvte_layernorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor mu,
const NVTETensor rsigma, const NVTETensor gamma, NVTETensor dx,
NVTETensor dgamma, NVTETensor dbeta, NVTETensor workspace,
const int multiprocessorCount, const bool zero_centered_gamma,
cudaStream_t stream);
/*! \brief Compute RMSNorm with zero-centered gamma on the input.
/*! \brief Compute RMSNorm.
*
* The formula used:
* @f[
* y = \frac{x}{RMS_\varepsilon(x)}(1 + \gamma)
* y = \frac{x}{RMS_\varepsilon(x)}\gamma
* @f]
* where
* @f[
@@ -68,14 +100,14 @@ void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float ep
* \param[in,out] z Output tensor of shape [N, H].
* \param[out] rsigma Reciprocal of the root mean square of the input
* calculated over the last dimension. Shape: [N].
* \param[in] stream CUDA stream used for the operation.
* \param[in] multiprocessorCount Number of SMs in the device.
* \param[out] workspace Workspace tensor.
* \param[out] barrier Barrier tensor.
* \param[in] multiprocessorCount Number of SMs in the device.
* \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_rmsnorm1p_fwd(const NVTETensor x, const NVTETensor gamma, const float epsilon,
NVTETensor z, NVTETensor rsigma, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier);
void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float epsilon, NVTETensor z,
NVTETensor rsigma, NVTETensor workspace, const int multiprocessorCount,
const bool zero_centered_gamma, cudaStream_t stream);
/*! \brief Compute backward of RMSNorm.
*
@@ -100,53 +132,25 @@ void nvte_rmsnorm1p_fwd(const NVTETensor x, const NVTETensor gamma, const float
* \param[in] gamma Gamma tensor of shape [H].
* \param[out] dx Output gradient of shape [N, H].
* \param[out] dgamma Gradient for gamma tensor of shape [H].
* \param[out] dgamma_part Storage for partial gamma gradient.
* \param[in] stream CUDA stream used for the operation.
* \param[in] multiprocessorCount Number of SMs in the device.
* \param[out] workspace Workspace tensor.
* \param[out] barrier Barrier tensor.
* \param[in] multiprocessorCount Number of SMs in the device.
* \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor rsigma,
const NVTETensor gamma, NVTETensor dx, NVTETensor dgamma,
NVTETensor dgamma_part, cudaStream_t stream, const int multiprocessorCount,
NVTETensor workspace, NVTETensor barrier);
NVTETensor workspace, const int multiprocessorCount,
const bool zero_centered_gamma, cudaStream_t stream);
/*! \brief Compute backward of RMSNorm with zero-centered gamma.
*
* This function computes the gradient of function:
* @f[
* y = \frac{x}{RMS_\varepsilon(x)}(1 + \gamma)
* @f]
* where
* @f[
* RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon}
* @f]
* with respect to \f$x\f$ and \f$gamma\f$.
/*! \brief Helper to enable cuDNN backend for normalization
*
* Calling this function with workspace, barrier, dgamma_part set
* to empty tensor will not perform the operation, but instead set the shape and type
* of these tensors to the required values.
*
* \param[in] dz Incoming gradient tensor of shape [N, H].
* \param[in] x Forward input tensor of shape [N, H].
* \param[in] rsigma Reciprocal of the root mean square of the input
* calculated over the last dimension. Shape: [N].
* \param[in] gamma Gamma tensor of shape [H].
* \param[out] dx Output gradient of shape [N, H].
* \param[out] dgamma Gradient for gamma tensor of shape [H].
* \param[out] dgamma_part Storage for partial gamma gradient.
* \param[in] stream CUDA stream used for the operation.
* \param[in] multiprocessorCount Number of SMs in the device.
* \param[out] workspace Workspace tensor.
* \param[out] barrier Barrier tensor.
* \param[in] enable Enable the cuDNN backend if true.
*/
void nvte_rmsnorm1p_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor rsigma,
const NVTETensor gamma, NVTETensor dx, NVTETensor dgamma,
NVTETensor dgamma_part, cudaStream_t stream, const int multiprocessorCount,
NVTETensor workspace, NVTETensor barrier);
void nvte_enable_cudnn_norm_fwd(bool enable);
void nvte_enable_cudnn_norm_bwd(bool enable);
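A minimal sketch of how these toggles are used in the updated test (the surrounding forward/backward calls are the ones declared above; error handling omitted):

// Route the following normalization calls through the cuDNN backend.
nvte_enable_cudnn_norm_fwd(true);
nvte_enable_cudnn_norm_bwd(true);
// ... nvte_layernorm_fwd / nvte_layernorm_bwd (or the RMSNorm variants) as usual ...
// Restore the default TE kernels afterwards.
nvte_enable_cudnn_norm_fwd(false);
nvte_enable_cudnn_norm_bwd(false);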
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_RMSNORM_H_
#endif // TRANSFORMER_ENGINE_NORMALIZATION_H_
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_
#define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_
#include <transformer_engine/transformer_engine.h>
#include <functional>
#include <map>
#include <stdexcept>
#include <unordered_map>
#include <vector>
#include "../common.h"
namespace transformer_engine {
namespace layer_norm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Params>
struct LaunchParams {
size_t workspace_bytes;
size_t barrier_size;
int multiprocessorCount;
cudaStream_t stream;
Params params;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct ParamsBase {
ParamsBase()
: ctas_per_col(0),
rows(0),
cols(0),
x(nullptr),
mu(nullptr),
rs(nullptr),
gamma(nullptr),
workspace(nullptr),
barrier(nullptr),
zero_centered_gamma(false) {}
// For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
int ctas_per_col;
// Size of CTA group.
int ctas_per_row;
// Input is interpreted as matrix. We normalize across columns.
int rows;
int cols;
// Common data pointers.
void *x;
void *mu;
void *rs;
void *gamma;
// Multi-CTA workspace in gmem.
void *workspace;
// Multi-CTA sync barriers in gmem.
int *barrier;
// Whether gamma is centered around 0
bool zero_centered_gamma;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct FwdParams : public ParamsBase {
FwdParams() : ParamsBase(), z(nullptr), beta(nullptr), epsilon(0.f), fp8_out(false) {}
// Output of LN FWD.
void *z;
void *beta;
float epsilon;
// Scaling factor
void *scale;
// AMax output
void *amax;
// Inverse of scaling factor
void *scale_inv;
// Whether to compute scale and amax
bool fp8_out;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct BwdParams : public ParamsBase {
BwdParams()
: ParamsBase(),
dz(nullptr),
dbeta_part(nullptr),
dgamma_part(nullptr),
dx(nullptr),
dbeta(nullptr),
dgamma(nullptr) {}
// Input: gradient wrt. LN FWD output.
void *dz;
// Workspace for Wgrad pre-reduction.
void *dbeta_part;
void *dgamma_part;
// Output: Dgrad.
void *dx;
// Output: Wgrad.
void *dbeta;
void *dgamma;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
using FwdFunction = std::function<void(LaunchParams<FwdParams> &, const bool)>;
using BwdFunction = std::function<void(LaunchParams<BwdParams> &, const bool)>;
using FunctionKey = uint64_t;
using FwdTunedRegistry = std::unordered_map<FunctionKey, FwdFunction>;
using BwdTunedRegistry = std::unordered_map<FunctionKey, BwdFunction>;
using FwdGeneralRegistry = std::unordered_map<FunctionKey, std::map<uint64_t, FwdFunction>>;
using BwdGeneralRegistry = std::unordered_map<FunctionKey, std::map<uint64_t, BwdFunction>>;
extern FwdTunedRegistry FWD_TUNED_FUNCS;
extern BwdTunedRegistry BWD_TUNED_FUNCS;
extern FwdGeneralRegistry FWD_GENERAL_FUNCS;
extern BwdGeneralRegistry BWD_GENERAL_FUNCS;
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
struct TypeId {};
template <>
struct TypeId<fp16> {
constexpr static uint32_t Value = 0;
};
template <>
struct TypeId<bf16> {
constexpr static uint32_t Value = 1;
};
template <>
struct TypeId<fp32> {
constexpr static uint32_t Value = 2;
};
template <>
struct TypeId<fp8e4m3> {
constexpr static uint32_t Value = 3;
};
template <typename T, int S>
struct Type2Key {
constexpr static uint32_t Value = TypeId<T>::Value << S;
};
template <typename T>
struct WeightType2Key : public Type2Key<T, 0> {};
template <typename T>
struct InputType2Key : public Type2Key<T, 2> {};
template <typename T>
struct OutputType2Key : public Type2Key<T, 4> {};
template <typename T>
struct ComputeType2Key : public Type2Key<T, 6> {};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename W, typename I, typename O, typename C>
struct Types2Key {
constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value |
OutputType2Key<O>::Value | ComputeType2Key<C>::Value;
constexpr static inline uint64_t get(const uint64_t hidden_size) {
constexpr uint64_t type_key = Value;
return (type_key << 32) | hidden_size;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdTunedRegistrar {
explicit FwdTunedRegistrar(FwdFunction f) {
uint64_t key = Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
FWD_TUNED_FUNCS.insert({key, f});
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdGeneralRegistrar {
explicit FwdGeneralRegistrar(FwdFunction f) {
uint64_t key = Types2Key<W, I, O, C>::get(0);
FWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f});
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdTunedRegistrar {
explicit BwdTunedRegistrar(BwdFunction f) {
uint64_t key = Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
BWD_TUNED_FUNCS.insert({key, f});
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdGeneralRegistrar {
explicit BwdGeneralRegistrar(BwdFunction f) {
uint64_t key = Types2Key<W, I, O, C>::get(0);
BWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f});
}
};
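For context, a launcher ends up in these registries through a static registrar instance at namespace scope (a hedged sketch; the actual kernel .cu files wire this up through helper macros, and my_fwd_launch here is a hypothetical function):

// Hypothetical launcher: wraps the CUDA kernel launch for one tile configuration.
static void my_fwd_launch(LaunchParams<FwdParams> &launch_params, const bool configure_params) {
  // configure_params == true: only fill workspace_bytes / barrier_size / ctas_per_col.
  // configure_params == false: actually launch the kernel.
}

// Registers the launcher for bf16 weights/input/output, fp32 compute, hidden size 1024.
static FwdTunedRegistrar<bf16, bf16, bf16, fp32, 1024> reg_fwd_bf16_1024(my_fwd_launch);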
//////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layer_norm
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/layer_norm.h>
#include <cstdint>
#include <vector>
#include "../common.h"
#include "ln.h"
/*
Supported Type combinations:
input compute weights output
=======================================
fp32 fp32 fp32 fp32
fp16 fp32 fp16 fp16
bf16 fp32 bf16 bf16
fp32 fp32 fp16 fp16
fp32 fp32 bf16 bf16
bf16 fp32 bf16 fp8
Remarks:
Output type = Weight type
Compute always in FP32
*/
namespace transformer_engine {
namespace layer_norm {
using namespace transformer_engine;
// Create registries and provide runtime versions of config hash functions.
FwdTunedRegistry FWD_TUNED_FUNCS;
BwdTunedRegistry BWD_TUNED_FUNCS;
FwdGeneralRegistry FWD_GENERAL_FUNCS;
BwdGeneralRegistry BWD_GENERAL_FUNCS;
////////////////////////////////////////////////////////////////////////////////////////////////////
uint32_t get_type_id(DType dtype) {
if (dtype == DType::kFloat16) {
return TypeId<fp16>::Value;
} else if (dtype == DType::kBFloat16) {
return TypeId<bf16>::Value;
} else if (dtype == DType::kFloat32) {
return TypeId<fp32>::Value;
} else if (dtype == DType::kFloat8E4M3) {
return TypeId<fp8e4m3>::Value;
} else {
NVTE_ERROR("Type not supported.");
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
uint64_t get_key(DType wtype, DType itype, DType otype, DType ctype, uint64_t hidden_size) {
using namespace layer_norm;
uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(otype) << 4) |
(get_type_id(ctype) << 6);
uint64_t launcher_key = (type_key << 32) | hidden_size;
return launcher_key;
}
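To make the packing concrete, a worked example using the TypeId values defined in ln.h (fp16 = 0, fp32 = 2):

// type_key     = 0 | (0 << 2) | (0 << 4) | (2 << 6) = 128
// launcher_key = (128ull << 32) | 1024 = 0x8000000400
uint64_t key = get_key(DType::kFloat16, DType::kFloat16, DType::kFloat16,
                       DType::kFloat32, 1024);  // fp16 weights/input/output, fp32 compute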
////////////////////////////////////////////////////////////////////////////////////////////////////
layer_norm::FwdFunction& get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype,
const layer_norm::FwdParams& params) {
// Look for tuned kernel
auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols);
auto is_aligned = [](const void* ptr) -> bool {
// Assume vectorized memory accesses are <=16B
return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
};
if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.mu) &&
is_aligned(params.rs) && is_aligned(params.gamma) && is_aligned(params.beta) &&
is_aligned(params.z) && layer_norm::FWD_TUNED_FUNCS.count(tuned_key) > 0) {
return layer_norm::FWD_TUNED_FUNCS.at(tuned_key);
}
// Pick general kernel
auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0);
if (layer_norm::FWD_GENERAL_FUNCS.count(general_key) == 0) {
NVTE_ERROR("FWD: Unsupported types.");
}
auto& general_func_map = layer_norm::FWD_GENERAL_FUNCS.at(general_key);
auto func_iter = general_func_map.lower_bound(params.cols);
if (func_iter == general_func_map.end()) {
// Hidden size is too big, need to use multi-CTA
return general_func_map.rbegin()->second;
} else {
return func_iter->second;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
layer_norm::BwdFunction& get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype,
const layer_norm::BwdParams& params) {
// Look for tuned kernel
auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols);
auto is_aligned = [](const void* ptr) -> bool {
// Assume vectorized memory accesses are <=16B
return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
};
if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.mu) &&
is_aligned(params.rs) && is_aligned(params.gamma) && is_aligned(params.dz) &&
is_aligned(params.dx) && is_aligned(params.dbeta) && is_aligned(params.dgamma) &&
is_aligned(params.dbeta_part) && is_aligned(params.dgamma_part) &&
layer_norm::BWD_TUNED_FUNCS.count(tuned_key) > 0) {
return layer_norm::BWD_TUNED_FUNCS.at(tuned_key);
}
// Pick general kernel
auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0);
if (layer_norm::BWD_GENERAL_FUNCS.count(general_key) == 0) {
NVTE_ERROR("BWD: Unsupported types.");
}
auto& general_func_map = layer_norm::BWD_GENERAL_FUNCS.at(general_key);
auto func_iter = general_func_map.lower_bound(params.cols);
if (func_iter == general_func_map.end()) {
// Hidden size is too big, need to use multi-CTA
return general_func_map.rbegin()->second;
} else {
return func_iter->second;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
size_t product(const std::vector<size_t>& shape) {
size_t ret = 1;
for (auto s : shape) {
ret *= s;
}
return ret;
}
} // namespace layer_norm
////////////////////////////////////////////////////////////////////////////////////////////////////
void layernorm_fwd(const Tensor& x, // BxSxhidden_size
const Tensor& gamma, // hidden_size
const Tensor& beta, // hidden_size
const float epsilon, Tensor* z, Tensor* mu, Tensor* rsigma, cudaStream_t stream,
const int multiprocessorCount, Tensor* workspace, Tensor* barrier,
const bool zero_centered_gamma) {
const auto itype = x.data.dtype;
const auto wtype = gamma.data.dtype;
const auto otype = z->data.dtype;
const bool fp8_out = is_fp8_dtype(otype);
const auto ctype = layer_norm::DType::kFloat32;
NVTE_CHECK(x.data.shape.size() == 2);
const size_t rows = x.data.shape[0];
const size_t cols = x.data.shape[1];
const auto hidden_size = gamma.data.shape[0];
NVTE_CHECK(gamma.data.shape == beta.data.shape);
NVTE_CHECK(hidden_size == cols);
NVTE_CHECK(epsilon >= 0.f);
NVTE_CHECK(z->data.shape == x.data.shape);
NVTE_CHECK(mu->data.shape == std::vector<size_t>{rows});
NVTE_CHECK(mu->data.dtype == ctype);
NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{rows});
NVTE_CHECK(rsigma->data.dtype == ctype);
layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
launch_params.multiprocessorCount = multiprocessorCount;
launch_params.stream = stream;
// Set the kernel runtime parameters.
layer_norm::FwdParams& params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data.dptr;
params.mu = mu->data.dptr;
params.rs = rsigma->data.dptr;
params.gamma = gamma.data.dptr;
params.beta = beta.data.dptr;
params.z = z->data.dptr;
params.epsilon = epsilon;
params.amax = z->amax.dptr;
params.scale = z->scale.dptr;
params.scale_inv = z->scale_inv.dptr;
params.fp8_out = fp8_out;
params.zero_centered_gamma = zero_centered_gamma;
// Request the kernel launcher.
auto launcher = layer_norm::get_fwd_launcher(wtype, itype, otype, ctype, params);
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
if (launch_params.workspace_bytes == 0) {
launch_params.workspace_bytes = 1;
}
if (workspace->data.dptr == nullptr) {
NVTE_CHECK(barrier->data.dptr == nullptr);
workspace->data.dtype = layer_norm::DType::kByte;
workspace->data.shape = {launch_params.workspace_bytes};
barrier->data.dtype = layer_norm::DType::kInt32;
barrier->data.shape = {launch_params.barrier_size};
return;
} else {
NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte);
NVTE_CHECK(workspace->data.shape == std::vector<size_t>{launch_params.workspace_bytes});
}
if (launch_params.barrier_size > 0) {
NVTE_CHECK(barrier->data.dptr != nullptr);
NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32);
NVTE_CHECK(barrier->data.shape == std::vector<size_t>{launch_params.barrier_size});
}
// Tensor checks are delayed here in order to recover workspace sizes with null data
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
CheckInputTensor(beta, "beta");
CheckOutputTensor(*z, "z");
CheckOutputTensor(*mu, "mu");
CheckOutputTensor(*rsigma, "rsigma");
if (launch_params.barrier_size > 0) {
params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int*>(barrier->data.dptr);
}
// Clear buffers
if (params.fp8_out) {
cudaMemsetAsync(params.amax, 0, layer_norm::product(z->amax.shape) * typeToSize(z->amax.dtype),
stream);
}
if (launch_params.barrier_size > 0) {
cudaMemsetAsync(params.barrier, 0,
layer_norm::product(barrier->data.shape) * typeToSize(barrier->data.dtype),
stream);
}
// Launch the kernel.
launcher(launch_params, false);
return;
}
void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Tensor& rsigma,
const Tensor& gamma, Tensor* dx, Tensor* dgamma, Tensor* dbeta,
Tensor* dgamma_part, Tensor* dbeta_part, cudaStream_t stream,
const int multiprocessorCount, Tensor* workspace, Tensor* barrier,
const bool zero_centered_gamma) {
using namespace transformer_engine;
auto itype = x.data.dtype;
auto wtype = gamma.data.dtype;
auto otype = wtype;
auto ctype = DType::kFloat32;
NVTE_CHECK(dz.data.dtype == otype);
NVTE_CHECK(mu.data.dtype == ctype);
NVTE_CHECK(rsigma.data.dtype == ctype);
NVTE_CHECK(x.data.shape.size() == 2);
NVTE_CHECK(dz.data.shape == x.data.shape);
auto rows = x.data.shape[0];
auto cols = x.data.shape[1];
auto hidden_size = gamma.data.shape[0];
NVTE_CHECK(mu.data.shape[0] == rows);
NVTE_CHECK(mu.data.shape == rsigma.data.shape);
NVTE_CHECK(gamma.data.shape[0] == cols);
NVTE_CHECK(dx->data.shape == x.data.shape);
NVTE_CHECK(dx->data.dtype == x.data.dtype);
NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
NVTE_CHECK(dbeta->data.shape == gamma.data.shape);
NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype);
layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
launch_params.stream = stream;
launch_params.multiprocessorCount = multiprocessorCount;
// Set the kernel runtime parameters.
layer_norm::BwdParams& params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data.dptr;
params.mu = mu.data.dptr;
params.rs = rsigma.data.dptr;
params.gamma = gamma.data.dptr;
params.dz = dz.data.dptr;
params.dx = dx->data.dptr;
params.dbeta = dbeta->data.dptr;
params.dgamma = dgamma->data.dptr;
params.dbeta_part = dbeta_part->data.dptr;
params.dgamma_part = dgamma_part->data.dptr;
params.zero_centered_gamma = zero_centered_gamma;
auto launcher = layer_norm::get_bwd_launcher(wtype, itype, otype, ctype, params);
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
// Populate shape and dtypes for FW to allocate memory
if (dgamma_part->data.dptr == nullptr) {
NVTE_CHECK(dbeta_part->data.dptr == nullptr);
dgamma_part->data.dtype = ctype;
dgamma_part->data.shape = {static_cast<uint64_t>(launch_params.params.ctas_per_col),
hidden_size};
dbeta_part->data.dtype = ctype;
dbeta_part->data.shape = {static_cast<uint64_t>(launch_params.params.ctas_per_col),
hidden_size};
workspace->data.dtype = layer_norm::DType::kByte;
workspace->data.shape = {launch_params.workspace_bytes};
barrier->data.dtype = layer_norm::DType::kInt32;
barrier->data.shape = {launch_params.barrier_size};
return;
} else {
NVTE_CHECK(dbeta_part->data.dptr != nullptr);
auto pdw_shape =
std::vector<size_t>{static_cast<uint64_t>(launch_params.params.ctas_per_col), hidden_size};
NVTE_CHECK(dgamma_part->data.dtype == ctype);
NVTE_CHECK(dgamma_part->data.shape == pdw_shape);
NVTE_CHECK(dbeta_part->data.dtype == ctype);
NVTE_CHECK(dbeta_part->data.shape == pdw_shape);
}
if (launch_params.barrier_size > 0) {
NVTE_CHECK(barrier->data.dptr != nullptr);
NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32);
NVTE_CHECK(barrier->data.shape == std::vector<size_t>{launch_params.barrier_size});
}
if (launch_params.workspace_bytes > 0) {
NVTE_CHECK(workspace->data.dptr != nullptr);
NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte);
NVTE_CHECK(workspace->data.shape == std::vector<size_t>{launch_params.workspace_bytes});
}
// Tensor checks are delayed here in order to recover workspace sizes with null data
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(mu, "mu");
CheckInputTensor(rsigma, "rsigma");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*dx, "dx");
CheckOutputTensor(*dgamma, "dgamma");
CheckOutputTensor(*dbeta, "dbeta");
if (launch_params.barrier_size > 0) {
params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int*>(barrier->data.dptr);
cudaMemsetAsync(params.barrier, 0,
layer_norm::product(barrier->data.shape) * typeToSize(barrier->data.dtype),
stream);
}
// Launch the kernel.
launcher(launch_params, false);
}
} // namespace transformer_engine
void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size
const NVTETensor gamma, // hidden_size
const NVTETensor beta, // hidden_size
const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma,
cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace,
NVTETensor barrier) {
NVTE_API_CALL(nvte_layernorm_fwd);
using namespace transformer_engine;
layernorm_fwd(*reinterpret_cast<const Tensor*>(x), *reinterpret_cast<const Tensor*>(gamma),
*reinterpret_cast<const Tensor*>(beta), epsilon, reinterpret_cast<Tensor*>(z),
reinterpret_cast<Tensor*>(mu), reinterpret_cast<Tensor*>(rsigma), stream,
multiprocessorCount, reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(barrier), false);
}
void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
const NVTETensor x, // BxSxhidden_size
const NVTETensor mu, // BxS, FP32!
const NVTETensor rsigma, // BxS, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, NVTETensor dgamma_part,
NVTETensor dbeta_part, cudaStream_t stream, const int multiprocessorCount,
NVTETensor workspace, NVTETensor barrier) {
NVTE_API_CALL(nvte_layernorm_bwd);
using namespace transformer_engine;
layernorm_bwd(*reinterpret_cast<const Tensor*>(dz), *reinterpret_cast<const Tensor*>(x),
*reinterpret_cast<const Tensor*>(mu), *reinterpret_cast<const Tensor*>(rsigma),
*reinterpret_cast<const Tensor*>(gamma), reinterpret_cast<Tensor*>(dx),
reinterpret_cast<Tensor*>(dgamma), reinterpret_cast<Tensor*>(dbeta),
reinterpret_cast<Tensor*>(dgamma_part), reinterpret_cast<Tensor*>(dbeta_part),
stream, multiprocessorCount, reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(barrier), false);
}
void nvte_layernorm1p_fwd(const NVTETensor x, // BxSxhidden_size
const NVTETensor gamma, // hidden_size
const NVTETensor beta, // hidden_size
const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma,
cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace,
NVTETensor barrier) {
NVTE_API_CALL(nvte_layernorm1p_fwd);
using namespace transformer_engine;
layernorm_fwd(*reinterpret_cast<const Tensor*>(x), *reinterpret_cast<const Tensor*>(gamma),
*reinterpret_cast<const Tensor*>(beta), epsilon, reinterpret_cast<Tensor*>(z),
reinterpret_cast<Tensor*>(mu), reinterpret_cast<Tensor*>(rsigma), stream,
multiprocessorCount, reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(barrier), true);
}
void nvte_layernorm1p_bwd(const NVTETensor dz, // BxSxhidden_size
const NVTETensor x, // BxSxhidden_size
const NVTETensor mu, // BxS, FP32!
const NVTETensor rsigma, // BxS, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta,
NVTETensor dgamma_part, NVTETensor dbeta_part, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) {
NVTE_API_CALL(nvte_layernorm1p_bwd);
using namespace transformer_engine;
layernorm_bwd(*reinterpret_cast<const Tensor*>(dz), *reinterpret_cast<const Tensor*>(x),
*reinterpret_cast<const Tensor*>(mu), *reinterpret_cast<const Tensor*>(rsigma),
*reinterpret_cast<const Tensor*>(gamma), reinterpret_cast<Tensor*>(dx),
reinterpret_cast<Tensor*>(dgamma), reinterpret_cast<Tensor*>(dbeta),
reinterpret_cast<Tensor*>(dgamma_part), reinterpret_cast<Tensor*>(dbeta_part),
stream, multiprocessorCount, reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(barrier), true);
}
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/* #include <transformer_engine/layer_norm.h> */
#include "common.h"
#include <bitset>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include "transformer_engine/normalization.h"
/*
Supported Type combinations:
input compute weights output
=======================================
fp32 fp32 fp32 fp32
fp16 fp32 fp16 fp16
bf16 fp32 bf16 bf16
fp32 fp32 fp16 fp16
fp32 fp32 bf16 bf16
bf16 fp32 bf16 fp8
Remarks:
Output type = Weight type (except for the FP8-output case)
Compute always in FP32
*/
namespace transformer_engine {
namespace normalization {
TupleKeyType get_key(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype,
DType otype, DType ctype, uint64_t batch_size, uint64_t hidden_size,
bool zero_centered_gamma, bool is_tuned) {
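  // Pack the dtypes and norm attributes into a single integer:
  //   bits [0:2] itype, [3:5] otype, [6:8] ctype, [9:11] wtype,
  //   bits [12:13] NormType, [14:15] NormStage, bit 16 zero_centered_gamma.
  // batch_size, hidden_size, and is_tuned stay as separate tuple entries.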
uint64_t general_key = static_cast<uint32_t>(itype) | (static_cast<uint32_t>(otype) << 3) |
(static_cast<uint32_t>(ctype) << 6) | (static_cast<uint32_t>(wtype) << 9) |
                         (uint32_t(NormType) << 12) | (uint32_t(NormStage) << 14) |
(uint32_t(zero_centered_gamma) << 16);
return std::make_tuple(general_key, batch_size, hidden_size, is_tuned);
}
template <typename KernelParamsType>
TeNormalizationPlan<KernelParamsType>::TeNormalizationPlan(
NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype,
DType ctype, const size_t batch_size, const size_t hidden_size, const size_t sm_count,
const bool zero_centered_gamma, const bool is_tuned)
: _is_layernorm(NormType == NVTE_Norm_Type::LayerNorm) {
_launch_params.multiprocessorCount = sm_count;
auto& kernel_params = _launch_params.params;
kernel_params.rows = batch_size;
kernel_params.cols = hidden_size;
kernel_params.zero_centered_gamma = zero_centered_gamma;
if constexpr (std::is_same_v<KernelParamsType, ForwardKernelParams>) {
kernel_params.fp8_out = is_fp8_dtype(otype);
}
// TE kernels have no template for batch_size and zero_centered_gamma, thus zero out those
auto key =
get_key(NormType, NormStage, wtype, itype, otype, ctype, 0, hidden_size, false, is_tuned);
_kernel = KernelRegistry::getKernel(key);
this->_build();
}
template <>
void TeNormalizationPlan<ForwardKernelParams>::execute(Tensor* z, void* x_dptr, void* gamma_dptr,
void* beta_dptr, void* mean_dptr,
void* eps_dptr, void* rsigma_dptr,
void* workspace_dptr, cudaStream_t stream) {
_launch_params.stream = stream;
auto& kernel_params = _launch_params.params;
kernel_params.workspace = workspace_dptr;
kernel_params.x = x_dptr;
kernel_params.rs = rsigma_dptr;
kernel_params.gamma = gamma_dptr;
kernel_params.z = z->data.dptr;
kernel_params.epsilon = *reinterpret_cast<float*>(eps_dptr);
kernel_params.amax = z->amax.dptr;
kernel_params.scale = z->scale.dptr;
kernel_params.scale_inv = z->scale_inv.dptr;
if (_is_layernorm) {
kernel_params.mu = mean_dptr;
kernel_params.beta = beta_dptr;
}
_set_workspace();
_kernel(_launch_params, false);
}
template <>
void TeNormalizationPlan<BackwardKernelParams>::execute(Tensor* z, void* x_dptr, void* gamma_dptr,
void* beta_dptr, void* mean_dptr,
void* eps_dptr, void* rsigma_dptr,
void* workspace_dptr, cudaStream_t stream) {
NVTE_ERROR("Backward normalization should not call the forward execute function!");
}
template <typename KernelParamsType>
void TeNormalizationPlan<KernelParamsType>::_build() {
_kernel(_launch_params, true);
_launch_params.alignWorkspace();
}
template <typename KernelParamsType>
std::vector<size_t> TeNormalizationPlan<KernelParamsType>::getWorkspaceShape() const {
return {_launch_params.getTotalWorkspaceBytes(_is_layernorm)};
}
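// The TE workspace is a single byte buffer carved into consecutive 16-byte
// aligned regions: [kernel workspace | barrier | dgamma_part | dbeta_part],
// where the dbeta_part region only exists for LayerNorm. _set_workspace()
// derives the sub-pointers from params.workspace using these byte counts.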
template <typename KernelParamsType>
void TeNormalizationPlan<KernelParamsType>::_set_workspace() {
if (_launch_params.getTotalWorkspaceBytes() > 0) {
auto workspace_dptr = reinterpret_cast<byte*>(_launch_params.params.workspace);
if (_launch_params.barrier_bytes > 0) {
_launch_params.params.barrier =
reinterpret_cast<int*>(workspace_dptr + _launch_params.workspace_bytes);
cudaMemsetAsync(_launch_params.params.barrier, 0, _launch_params.barrier_bytes,
_launch_params.stream);
}
if constexpr (std::is_same_v<KernelParamsType, BackwardKernelParams>) {
_launch_params.params.dgamma_part =
workspace_dptr + _launch_params.workspace_bytes + _launch_params.barrier_bytes;
if (_is_layernorm) {
_launch_params.params.dbeta_part =
reinterpret_cast<byte*>(_launch_params.params.dgamma_part) +
_launch_params.dgamma_part_bytes;
}
}
}
}
template <>
void TeNormalizationPlan<ForwardKernelParams>::execute(void* x_dptr, void* gamma_dptr,
void* mean_dptr, void* rsigma_dptr,
void* dx_dptr, void* dz_dptr,
void* dbeta_dptr, void* dgamma_dptr,
void* workspace_dptr, cudaStream_t stream) {
NVTE_ERROR("Forward normalization should not call the backward execute function!");
}
template <>
void TeNormalizationPlan<BackwardKernelParams>::execute(void* x_dptr, void* gamma_dptr,
void* mean_dptr, void* rsigma_dptr,
void* dx_dptr, void* dz_dptr,
void* dbeta_dptr, void* dgamma_dptr,
void* workspace_dptr, cudaStream_t stream) {
_launch_params.stream = stream;
auto& kernel_params = _launch_params.params;
kernel_params.workspace = workspace_dptr;
kernel_params.x = x_dptr;
kernel_params.gamma = gamma_dptr;
kernel_params.rs = rsigma_dptr;
kernel_params.dx = dx_dptr;
kernel_params.dz = dz_dptr;
kernel_params.dgamma = dgamma_dptr;
if (_is_layernorm) {
kernel_params.mu = mean_dptr;
kernel_params.dbeta = dbeta_dptr;
}
_set_workspace();
_kernel(_launch_params, false);
}
CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage,
DType wtype, DType itype, DType otype, DType ctype,
const size_t batch_size, const size_t hidden_size,
const size_t sm_count,
const bool zero_centered_gamma)
: _fp8_out(is_fp8_dtype(otype)), _zero_centered(zero_centered_gamma) {
static_assert(CUDNN_FRONTEND_VERSION >= 10601,
"CUDNN_FRONTEND_VERSION should be at least 1.6.1!");
namespace fe = cudnn_frontend;
_scalar_dptr = std::make_unique<char[]>(typeToSize(wtype));
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
wtype, cpp_dtype, *(reinterpret_cast<cpp_dtype*>(_scalar_dptr.get())) = (cpp_dtype)1.0f;);
_handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
_graph.set_io_data_type(get_cudnn_fe_dtype(itype))
.set_intermediate_data_type(get_cudnn_fe_dtype(ctype))
.set_compute_data_type(get_cudnn_fe_dtype(ctype));
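  // The SM-count hint requires cuDNN 9.4+ (cudnnGetVersion() encodes 9.4.0 as 90400).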
if (cudnnGetVersion() >= 90400) _graph.set_sm_count(sm_count);
const auto batch_dim = static_cast<int32_t>(batch_size);
const auto hidden_dim = static_cast<int32_t>(hidden_size);
// Create graph tensors
_x = _graph.tensor(fe::graph::Tensor_attributes()
.set_name("X")
.set_dim({batch_dim, hidden_dim, 1, 1})
.set_stride({hidden_dim, 1, hidden_dim, hidden_dim})
.set_data_type(get_cudnn_fe_dtype(itype)));
_gamma_zero = _graph.tensor(fe::graph::Tensor_attributes()
.set_name("gamma_zero")
.set_dim({1, hidden_dim, 1, 1})
.set_stride({hidden_dim, 1, hidden_dim, hidden_dim})
.set_data_type(get_cudnn_fe_dtype(wtype)));
if (zero_centered_gamma) {
_scalar_offset = _graph.tensor(fe::graph::Tensor_attributes()
.set_name("one")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(wtype))
.set_is_pass_by_value(true));
auto centered_options = fe::graph::Pointwise_attributes()
.set_mode(fe::PointwiseMode_t::ADD)
.set_compute_data_type(get_cudnn_fe_dtype(ctype));
_gamma = _graph.pointwise(_gamma_zero, _scalar_offset, centered_options);
_gamma->set_output(false).set_data_type(get_cudnn_fe_dtype(wtype));
} else {
_gamma = _gamma_zero;
}
// Create graph computation nodes
if (NormStage == NVTE_Norm_Stage::Forward) {
_eps = _graph.tensor(fe::graph::Tensor_attributes()
.set_name("epsilon")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ctype))
.set_is_pass_by_value(true));
if (NormType == NVTE_Norm_Type::LayerNorm) {
_beta = _graph.tensor(fe::graph::Tensor_attributes()
.set_name("bias")
.set_dim({1, hidden_dim, 1, 1})
.set_stride({hidden_dim, 1, hidden_dim, hidden_dim})
.set_data_type(get_cudnn_fe_dtype(wtype)));
auto norm_options = fe::graph::Layernorm_attributes()
.set_forward_phase(fe::NormFwdPhase_t::TRAINING)
.set_epsilon(_eps)
.set_compute_data_type(get_cudnn_fe_dtype(ctype));
auto ret = _graph.layernorm(_x, _gamma, _beta, norm_options);
std::tie(_z, _mean, _rsigma) = std::make_tuple(ret[0], ret[1], ret[2]);
_mean->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype));
} else if (NormType == NVTE_Norm_Type::RMSNorm) {
auto norm_options = fe::graph::Rmsnorm_attributes()
.set_forward_phase(fe::NormFwdPhase_t::TRAINING)
.set_epsilon(_eps)
.set_compute_data_type(get_cudnn_fe_dtype(ctype));
auto ret = _graph.rmsnorm(_x, _gamma, norm_options);
std::tie(_z, _rsigma) = std::make_tuple(ret[0], ret[1]);
}
_rsigma->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype));
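    // For FP8 output, the normalized result stays virtual in the compute type;
    // only the scaled FP8 tensor and the amax reduction below are written out.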
const auto ZDtype = _fp8_out ? ctype : otype;
_z->set_output(!_fp8_out).set_data_type(get_cudnn_fe_dtype(ZDtype));
if (_fp8_out) {
// create a scale node
_z_scale = _graph.tensor(fe::graph::Tensor_attributes()
.set_name("z_scale")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ctype)));
auto z_scale_options = fe::graph::Pointwise_attributes()
.set_mode(fe::PointwiseMode_t::MUL)
.set_compute_data_type(get_cudnn_fe_dtype(ctype));
_z_fp8 = _graph.pointwise(_z, _z_scale, z_scale_options);
_z_fp8->set_output(true).set_data_type(get_cudnn_fe_dtype(otype));
// create an amax reduction node
_amax = _graph.reduction(_z, fe::graph::Reduction_attributes()
.set_mode(fe::ReductionMode_t::AMAX)
.set_compute_data_type(get_cudnn_fe_dtype(ctype)));
_amax->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)).set_dim({1, 1, 1, 1});
}
} else {
_dz = _graph.tensor(fe::graph::Tensor_attributes()
.set_name("dz")
.set_dim({batch_dim, hidden_dim, 1, 1})
.set_stride({hidden_dim, 1, hidden_dim, hidden_dim}));
_rsigma = _graph.tensor(fe::graph::Tensor_attributes()
.set_name("inv_var")
.set_dim({batch_dim, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ctype)));
_mean = _graph.tensor(fe::graph::Tensor_attributes()
.set_name("mean")
.set_dim({batch_dim, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ctype)));
if (NormType == NVTE_Norm_Type::LayerNorm) {
auto norm_options = fe::graph::Layernorm_backward_attributes()
.set_saved_mean_and_inv_variance(_mean, _rsigma)
.set_compute_data_type(get_cudnn_fe_dtype(ctype));
auto ret = _graph.layernorm_backward(_dz, _x, _gamma, norm_options);
std::tie(_dx, _dgamma, _dbeta) = std::make_tuple(ret[0], ret[1], ret[2]);
_dbeta->set_output(true).set_data_type(get_cudnn_fe_dtype(otype));
} else {
auto norm_options =
fe::graph::Rmsnorm_backward_attributes().has_dbias(false).set_compute_data_type(
get_cudnn_fe_dtype(ctype));
auto ret = _graph.rmsnorm_backward(_dz, _x, _gamma, _rsigma, norm_options);
std::tie(_dx, _dgamma, _dbeta) = std::make_tuple(ret[0], ret[1], ret[2]);
if (_dbeta != nullptr) NVTE_ERROR("cuDNN rmsnorm dbias incorrectly returned.");
}
_dx->set_output(true).set_data_type(get_cudnn_fe_dtype(otype));
_dgamma->set_output(true).set_data_type(get_cudnn_fe_dtype(otype));
}
// Build the graph
this->_build();
}
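// Standard cudnn_frontend graph finalization: validate the graph, lower it to a
// cuDNN operation graph, create execution plans via heuristics (with fallback),
// check support on this handle, and build the selected plan.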
void CudnnNormalizationPlan::_build() {
NVTE_CHECK(_graph.validate().is_good());
NVTE_CHECK(_graph.build_operation_graph(_handle).is_good());
NVTE_CHECK(_graph
.create_execution_plans(
{cudnn_frontend::HeurMode_t::A, cudnn_frontend::HeurMode_t::FALLBACK})
.is_good());
NVTE_CHECK(_graph.check_support(_handle).is_good());
NVTE_CHECK(
_graph.build_plans(_handle, cudnn_frontend::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good());
}
std::vector<size_t> CudnnNormalizationPlan::getWorkspaceShape() const {
return {static_cast<size_t>(_graph.get_workspace_size())};
}
void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr,
void* mean_dptr, void* eps_dptr, void* rsigma_dptr,
void* workspace_dptr, cudaStream_t stream) {
// Binding data pointers to graph tensors
_variant_pack = {{_x, x_dptr}, {_rsigma, rsigma_dptr}, {_eps, eps_dptr}};
// layernorm should have valid mean_dptr and beta_dptr
if (mean_dptr && beta_dptr) _variant_pack.insert({{_mean, mean_dptr}, {_beta, beta_dptr}});
if (_zero_centered)
_variant_pack.insert(
{{_scalar_offset, reinterpret_cast<void*>(_scalar_dptr.get())}, {_gamma_zero, gamma_dptr}});
else
_variant_pack.insert({{_gamma, gamma_dptr}});
if (_fp8_out)
_variant_pack.insert(
{{_z_scale, z->scale.dptr}, {_amax, z->amax.dptr}, {_z_fp8, z->data.dptr}});
else
_variant_pack.insert({{_z, z->data.dptr}});
// Execute the computation
NVTE_CHECK_CUDNN(cudnnSetStream(_handle, stream));
NVTE_CHECK(_graph.execute(_handle, _variant_pack, workspace_dptr).is_good());
if (_fp8_out) update_tensor_scale_inv(z, stream);
}
void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_dptr,
void* rsigma_dptr, void* dx_dptr, void* dz_dptr,
void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr,
cudaStream_t stream) {
// Binding data pointers to graph tensors
_variant_pack = {
{_x, x_dptr}, {_rsigma, rsigma_dptr}, {_dz, dz_dptr}, {_dgamma, dgamma_dptr}, {_dx, dx_dptr}};
if (_zero_centered)
_variant_pack.insert({{_scalar_offset, reinterpret_cast<void*>(this->_scalar_dptr.get())},
{_gamma_zero, gamma_dptr}});
else
_variant_pack.insert({{_gamma, gamma_dptr}});
// layernorm should have valid mean_dptr and beta_dptr
if (mean_dptr && dbeta_dptr) _variant_pack.insert({{_mean, mean_dptr}, {_dbeta, dbeta_dptr}});
// Execute the computation
NVTE_CHECK_CUDNN(cudnnSetStream(_handle, stream));
NVTE_CHECK(_graph.execute(_handle, _variant_pack, workspace_dptr).is_good());
}
NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan(
NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype,
DType itype, DType otype, const size_t batch_size, const size_t hidden_size,
const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned) {
const DType ctype = DType::kFloat32;
bool is_tuned = is_aligned && (batch_size % 4 == 0);
auto key = get_key(NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size,
zero_centered_gamma, is_tuned);
auto it = normalizationPlanMap.find(key);
if (it != normalizationPlanMap.end()) {
return it->second.get();
}
std::unique_ptr<NormalizationPlanBase> plan;
if (NormBackend == NVTE_Norm_Backend::Cudnn) {
plan = std::make_unique<CudnnNormalizationPlan>(NormType, NormStage, wtype, itype, otype, ctype,
batch_size, hidden_size, sm_count,
zero_centered_gamma);
} else if (NormStage == NVTE_Norm_Stage::Forward) {
plan = std::make_unique<TeNormalizationPlan<ForwardKernelParams>>(
NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count,
zero_centered_gamma, is_tuned);
} else {
plan = std::make_unique<TeNormalizationPlan<BackwardKernelParams>>(
NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count,
zero_centered_gamma, is_tuned);
}
normalizationPlanMap.insert({key, std::move(plan)});
return normalizationPlanMap[key].get();
}
bool& _cudnn_norm_fwd_flag() {
static bool flag = transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN");
return flag;
}
bool& _cudnn_norm_bwd_flag() {
static bool flag = transformer_engine::getenv<bool>("NVTE_NORM_BWD_USE_CUDNN");
return flag;
}
bool use_cudnn_norm_fwd() { return _cudnn_norm_fwd_flag(); }
bool use_cudnn_norm_bwd() { return _cudnn_norm_bwd_flag(); }
} // namespace normalization
} // namespace transformer_engine
void nvte_enable_cudnn_norm_fwd(bool enable) {
NVTE_API_CALL(nvte_enable_cudnn_norm_fwd);
transformer_engine::normalization::_cudnn_norm_fwd_flag() = enable;
}
void nvte_enable_cudnn_norm_bwd(bool enable) {
NVTE_API_CALL(nvte_enable_cudnn_norm_bwd);
transformer_engine::normalization::_cudnn_norm_bwd_flag() = enable;
}
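// Usage sketch: the cuDNN backend can be selected either through the environment
// variables read above (NVTE_NORM_FWD_USE_CUDNN / NVTE_NORM_BWD_USE_CUDNN) or
// programmatically, e.g.
//
//   nvte_enable_cudnn_norm_fwd(true);  // route forward LayerNorm/RMSNorm through cuDNN
//   nvte_enable_cudnn_norm_bwd(true);  // route the backward pass through cuDNN
//
// Subsequent nvte_layernorm_* / nvte_rmsnorm_* calls then request the Cudnn
// backend in getNormalizationPlan().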
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_NORM_COMMON_H_
#define TRANSFORMER_ENGINE_COMMON_NORM_COMMON_H_
#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>
#include <transformer_engine/transformer_engine.h>
#include <functional>
#include <map>
#include <stdexcept>
#include <tuple>
#include <typeindex>
#include <unordered_map>
#include <vector>
#include "../common.h"
#include "../cudnn_utils.h"
#include "../util/system.h"
namespace transformer_engine {
namespace normalization {
namespace fe = cudnn_frontend;
template <typename KernelParamsType>
struct LaunchParams {
size_t workspace_bytes = 0;
size_t barrier_bytes = 0;
size_t dgamma_part_bytes = 0;
int multiprocessorCount;
cudaStream_t stream;
KernelParamsType params;
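  // Total bytes = kernel workspace + barrier + wgrad partial buffers;
  // LayerNorm needs two partials (dgamma_part and dbeta_part), RMSNorm only one,
  // hence the (_is_layernorm + 1) factor below.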
size_t getTotalWorkspaceBytes(const bool _is_layernorm = true) const {
return (workspace_bytes + barrier_bytes + size_t(_is_layernorm + 1) * dgamma_part_bytes);
}
void alignWorkspace(size_t alignment = 16) {
workspace_bytes = DIVUP(workspace_bytes, alignment) * alignment;
barrier_bytes = DIVUP(barrier_bytes, alignment) * alignment;
dgamma_part_bytes = DIVUP(dgamma_part_bytes, alignment) * alignment;
}
};
struct KernelParamsBase {
KernelParamsBase()
: ctas_per_col(0),
rows(0),
cols(0),
x(nullptr),
mu(nullptr),
rs(nullptr),
gamma(nullptr),
workspace(nullptr),
barrier(nullptr),
zero_centered_gamma(false) {}
// For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
int ctas_per_col;
// Size of CTA group.
int ctas_per_row;
// Input is interpreted as matrix. We normalize across columns.
int rows;
int cols;
// Common data pointers.
void* x;
void* mu;
void* rs;
void* gamma;
// Multi-CTA workspace in gmem.
void* workspace;
// Multi-CTA sync barriers in gmem.
int* barrier;
// Whether gamma is centered around 0
bool zero_centered_gamma;
};
struct ForwardKernelParams : public KernelParamsBase {
ForwardKernelParams()
: KernelParamsBase(), z(nullptr), beta(nullptr), epsilon(0.f), fp8_out(false) {}
// Output of LN FWD.
void* z;
void* beta;
float epsilon;
// Scaling factor
void* scale;
int scale_byte_size;
// Inverse of scaling factor
void* scale_inv;
// AMax output
void* amax;
int amax_byte_size;
// Whether to compute scale and amax
bool fp8_out;
};
struct BackwardKernelParams : public KernelParamsBase {
BackwardKernelParams()
: KernelParamsBase(),
dz(nullptr),
dbeta_part(nullptr),
dgamma_part(nullptr),
dx(nullptr),
dbeta(nullptr),
dgamma(nullptr) {}
// Input: gradient wrt. LN FWD output.
void* dz;
// Workspace for Wgrad pre-reduction.
void* dbeta_part;
void* dgamma_part;
// Output: Dgrad.
void* dx;
// Output: Wgrad.
void* dbeta;
void* dgamma;
};
enum class NVTE_Norm_Backend { Te, Cudnn };
enum class NVTE_Norm_Type { LayerNorm, RMSNorm };
enum class NVTE_Norm_Stage { Forward, Backward };
using TupleKeyType = std::tuple<uint64_t, uint64_t, uint64_t, bool>;
struct TupleHash {
size_t operator()(const TupleKeyType& t) const {
// Generate a hash for a tuple by combining the hashes of its entries
// See: https://www.boost.org/doc/libs/1_55_0/doc/html/hash/reference.html#boost.hash_combine
size_t seed = 0;
std::hash<uint64_t> hasher;
seed ^= hasher(std::get<0>(t)) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
seed ^= hasher(std::get<1>(t)) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
seed ^= hasher(std::get<2>(t)) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
return seed;
}
};
TupleKeyType get_key(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype,
DType otype, DType ctype, uint64_t batch_size, uint64_t hidden_size,
bool zero_centered_gamma, bool is_tuned);
template <typename KernelParamsType>
class TeNormalizationRegistry {
private:
using Function = std::function<void(LaunchParams<KernelParamsType>&, const bool)>;
std::unordered_map<TupleKeyType, Function, TupleHash> tuned_function_map;
std::unordered_map<uint64_t, std::map<uint64_t, Function>> general_function_map;
TeNormalizationRegistry() = default;
static TeNormalizationRegistry& getInstance() {
static TeNormalizationRegistry registry;
return registry;
}
public:
static int registerFunction(TupleKeyType key,
void (*func)(LaunchParams<KernelParamsType>&, const bool)) {
auto [general_key, batch_size, hidden_size, is_tuned] = key;
if (is_tuned)
getInstance().tuned_function_map.emplace(key, Function(func));
else
getInstance().general_function_map[general_key].emplace(hidden_size, Function(func));
return 0;
}
static Function getKernel(TupleKeyType key) {
auto& instance = getInstance();
auto [general_key, batch_size, hidden_size, is_tuned] = key;
if (is_tuned) {
auto it = instance.tuned_function_map.find(key);
if (it != instance.tuned_function_map.end()) return it->second;
}
if (instance.general_function_map.count(general_key) == 0) {
NVTE_ERROR("Unavailable kernel for this normalization config.");
}
auto& general_func_map = instance.general_function_map.at(general_key);
auto func_iter = general_func_map.lower_bound(hidden_size);
if (func_iter == general_func_map.end()) {
return general_func_map.rbegin()->second; // Hidden size is too big, need to use multi-CTA
} else {
return func_iter->second;
}
}
TeNormalizationRegistry(const TeNormalizationRegistry&) = delete;
TeNormalizationRegistry& operator=(const TeNormalizationRegistry&) = delete;
TeNormalizationRegistry(TeNormalizationRegistry&&) = delete;
TeNormalizationRegistry& operator=(TeNormalizationRegistry&&) = delete;
};
class NormalizationPlanBase {
public:
virtual ~NormalizationPlanBase() = default;
virtual std::vector<size_t> getWorkspaceShape() const = 0;
virtual void execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr, void* mean_dptr,
void* eps_dptr, void* rsigma_dptr, void* workspace_dptr,
cudaStream_t stream) = 0;
virtual void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr,
void* dx_dptr, void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr,
void* workspace_dptr, cudaStream_t stream) = 0;
private:
virtual void _build() = 0;
};
template <typename KernelParamsType>
class TeNormalizationPlan : public NormalizationPlanBase {
public:
TeNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype,
DType otype, DType ctype, const size_t batch_size, const size_t hidden_size,
const size_t sm_count, const bool zero_centered_gamma, const bool is_tuned);
std::vector<size_t> getWorkspaceShape() const override;
void execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr, void* mean_dptr,
void* eps_dptr, void* rsigma_dptr, void* workspace_dptr,
cudaStream_t stream) override;
void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, void* dx_dptr,
void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr,
cudaStream_t stream) override;
private:
void _set_workspace();
void _build();
using KernelRegistry = TeNormalizationRegistry<KernelParamsType>;
LaunchParams<KernelParamsType> _launch_params;
std::function<void(LaunchParams<KernelParamsType>&, const bool)> _kernel;
const bool _is_layernorm;
};
class CudnnNormalizationPlan : public NormalizationPlanBase {
public:
CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype,
DType itype, DType otype, DType ctype, const size_t batch_size,
const size_t hidden_size, const size_t sm_count,
const bool zero_centered_gamma);
std::vector<size_t> getWorkspaceShape() const override;
void execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr, void* mean_dptr,
void* eps_dptr, void* rsigma_dptr, void* workspace_dptr,
cudaStream_t stream) override;
void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, void* dx_dptr,
void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr,
cudaStream_t stream) override;
private:
void _build() override;
  const bool _fp8_out, _zero_centered;
std::unique_ptr<char[]> _scalar_dptr;
// FWD
std::shared_ptr<fe::graph::Tensor_attributes> _x, _gamma_zero, _scalar_offset, _gamma, _beta,
_eps, _mean, _rsigma, _z, _z_scale, _amax, _z_fp8;
// BWD
std::shared_ptr<fe::graph::Tensor_attributes> _dz, _dx, _dgamma, _dbeta;
fe::graph::Graph _graph;
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> _variant_pack;
cudnnHandle_t _handle;
};
class NormalizationPlanRegistry {
public:
// TODO thread-safe
static NormalizationPlanRegistry& getInstance() {
static NormalizationPlanRegistry instance;
return instance;
}
NormalizationPlanBase* getNormalizationPlan(NVTE_Norm_Backend NormBackend,
NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage,
DType wtype, DType itype, DType otype,
const size_t batch_size, const size_t hidden_size,
const size_t sm_count, const bool zero_centered_gamma,
const bool is_aligned);
private:
NormalizationPlanRegistry() {}
NormalizationPlanRegistry(const NormalizationPlanRegistry&) = delete;
NormalizationPlanRegistry& operator=(const NormalizationPlanRegistry&) = delete;
std::unordered_map<TupleKeyType, std::unique_ptr<NormalizationPlanBase>, TupleHash>
normalizationPlanMap;
};
using byte = uint8_t;
using int32 = int32_t;
using fp32 = float;
using fp16 = half;
using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
template <typename T>
struct TypeToDType;
template <>
struct TypeToDType<fp32> {
static constexpr DType value = DType::kFloat32;
};
template <>
struct TypeToDType<fp16> {
static constexpr DType value = DType::kFloat16;
};
template <>
struct TypeToDType<bf16> {
static constexpr DType value = DType::kBFloat16;
};
template <>
struct TypeToDType<fp8e4m3> {
static constexpr DType value = DType::kFloat8E4M3;
};
template <>
struct TypeToDType<fp8e5m2> {
static constexpr DType value = DType::kFloat8E5M2;
};
template <>
struct TypeToDType<int32> {
static constexpr DType value = DType::kInt32;
};
template <>
struct TypeToDType<byte> {
static constexpr DType value = DType::kByte;
};
#define IS_TUNED(x) (strcmp(#x, "tuned") == 0 ? 1 : 0)
// TE kernels have no template for batch_size and zero_centered_gamma, thus zero out those
#define REGISTER_NORM_BASE(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, \
CTYPE, FUNC_NAME) \
static int \
register_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE = \
TeNormalizationRegistry<NORM_STAGE##KernelParams>::registerFunction( \
(get_key(NVTE_Norm_Type::NORM_TYPE, NVTE_Norm_Stage::NORM_STAGE, \
(TypeToDType<WTYPE>::value), (TypeToDType<ITYPE>::value), \
(TypeToDType<OTYPE>::value), (TypeToDType<CTYPE>::value), 0, HIDDEN_SIZE, \
0, IS_TUNED(LAUNCH_TYPE))), \
FUNC_NAME)
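// Illustrative (hypothetical) use of the macro above -- real registrations live
// next to the kernel definitions and use their actual launcher names:
//
//   REGISTER_NORM_BASE(LayerNorm, Forward, tuned, 768, fp16, fp16, fp16, fp32,
//                      launch_ln_fwd_tuned_768);  // launcher name made up here
//
// This expands to a static int whose initializer registers the launcher in
// TeNormalizationRegistry<ForwardKernelParams> under the corresponding key.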
// For FP8 only
void ComputeScaleInv(void* scale, void* scale_inv);
// Alignment check
template <size_t Alignment = 16, typename... Args>
bool is_ptr_aligned(const Args*... ptrs) {
return ((reinterpret_cast<uintptr_t>(ptrs) % Alignment == 0) && ...);
}
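// e.g. is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr) is true only if
// every pointer is 16-byte aligned; misaligned tensors fall back to the general
// (non-tuned) TE kernels.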
bool use_cudnn_norm_fwd();
bool use_cudnn_norm_bwd();
} // namespace normalization
} // namespace transformer_engine
#endif
......@@ -4,16 +4,15 @@
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_KERNEL_TRAITS_H_
#define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_KERNEL_TRAITS_H_
#ifndef TRANSFORMER_ENGINE_COMMON_NORM_KERNEL_TRAITS_H_
#define TRANSFORMER_ENGINE_COMMON_NORM_KERNEL_TRAITS_H_
#include "../common.h"
#include "../utils.cuh"
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace transformer_engine {
namespace layer_norm {
namespace normalization {
template <uint32_t HIDDEN_SIZE_, typename weight_t_, typename input_t_, typename output_t_,
typename compute_t_, typename index_t_, uint32_t THREADS_PER_CTA_>
struct Kernel_traits_base {
......@@ -28,8 +27,6 @@ struct Kernel_traits_base {
enum { THREADS_PER_WARP = 32 };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <uint32_t HIDDEN_SIZE_, typename weight_t_, typename input_t_, typename output_t_,
typename compute_t_, typename index_t_, uint32_t THREADS_PER_CTA_,
uint32_t BYTES_PER_LDG_,
......@@ -67,8 +64,6 @@ struct Kernel_traits_finalize : public Base {
enum { CTAS = COLS / Base::THREADS_PER_WARP };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename weight_t_, typename input_t_, typename output_t_, typename compute_t_,
typename index_t_, uint32_t HIDDEN_SIZE_, uint32_t CTAS_PER_ROW_, uint32_t WARPS_M_,
uint32_t WARPS_N_, uint32_t BYTES_PER_LDG_ = 16,
......@@ -129,9 +124,7 @@ struct Kernel_traits : public Base {
enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layer_norm
} // namespace normalization
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_KERNEL_TRAITS_H_
#endif // TRANSFORMER_ENGINE_COMMON_NORM_KERNEL_TRAITS_H_
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/normalization.h>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <vector>
#include "../../common.h"
#include "../common.h"
namespace transformer_engine {
using namespace normalization;
void layernorm_fwd(const Tensor& x, // BxSxhidden_size
const Tensor& gamma, // hidden_size
const Tensor& beta, // hidden_size
const float epsilon, Tensor* z, Tensor* mu, Tensor* rsigma, Tensor* workspace,
const int multiprocessorCount, const bool zero_centered_gamma,
cudaStream_t stream) {
NVTE_CHECK(x.data.shape.size() == 2);
NVTE_CHECK(gamma.data.shape == beta.data.shape);
NVTE_CHECK(x.data.shape[1] == gamma.data.shape[0]);
NVTE_CHECK(epsilon >= 0.f);
NVTE_CHECK(z->data.shape == x.data.shape);
NVTE_CHECK(mu->data.shape == std::vector<size_t>{x.data.shape[0]});
NVTE_CHECK(mu->data.dtype == DType::kFloat32);
NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{x.data.shape[0]});
NVTE_CHECK(rsigma->data.dtype == DType::kFloat32);
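  // Tensor checks are deferred to the execution pass so that the workspace-size
  // query below can be made with null data pointers.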
if (!workspace->data.shape.empty()) {
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
CheckInputTensor(beta, "beta");
CheckOutputTensor(*z, "z");
CheckOutputTensor(*mu, "mu");
CheckOutputTensor(*rsigma, "rsigma");
}
NVTE_Norm_Backend norm_backend;
bool is_aligned = true;
if (use_cudnn_norm_fwd()) {
// TODO: add check for GPU ARCH
norm_backend = NVTE_Norm_Backend::Cudnn;
} else {
norm_backend = NVTE_Norm_Backend::Te;
is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, beta.data.dptr,
mu->data.dptr, rsigma->data.dptr);
}
auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan(
norm_backend, NVTE_Norm_Type::LayerNorm, NVTE_Norm_Stage::Forward,
gamma.data.dtype, // wtype
x.data.dtype, // itype
z->data.dtype, // otype
x.data.shape[0], // batch_size
x.data.shape[1], // hidden_size
multiprocessorCount, zero_centered_gamma, is_aligned);
if (workspace->data.shape.empty()) {
workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape());
plan->execute(z, x.data.dptr, gamma.data.dptr, beta.data.dptr, mu->data.dptr,
reinterpret_cast<void*>(const_cast<float*>(&epsilon)), rsigma->data.dptr,
workspace->data.dptr, stream);
}
return;
}
void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Tensor& rsigma,
const Tensor& gamma, Tensor* dx, Tensor* dgamma, Tensor* dbeta,
Tensor* workspace, const int multiprocessorCount, const bool zero_centered_gamma,
cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(dz.data.dtype == gamma.data.dtype);
NVTE_CHECK(mu.data.dtype == DType::kFloat32);
NVTE_CHECK(rsigma.data.dtype == mu.data.dtype);
NVTE_CHECK(x.data.shape.size() == 2);
NVTE_CHECK(dz.data.shape == x.data.shape);
NVTE_CHECK(mu.data.shape[0] == x.data.shape[0]);
NVTE_CHECK(mu.data.shape == rsigma.data.shape);
NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]);
NVTE_CHECK(dx->data.shape == x.data.shape);
NVTE_CHECK(dx->data.dtype == x.data.dtype);
NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
NVTE_CHECK(dbeta->data.shape == gamma.data.shape);
NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype);
if (!workspace->data.shape.empty()) {
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(mu, "mu");
CheckInputTensor(rsigma, "rsigma");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*dx, "dx");
CheckOutputTensor(*dgamma, "dgamma");
CheckOutputTensor(*dbeta, "dbeta");
}
NVTE_Norm_Backend norm_backend;
bool is_aligned = true;
if (use_cudnn_norm_bwd()) {
// TODO: add check for GPU ARCH
norm_backend = NVTE_Norm_Backend::Cudnn;
} else {
norm_backend = NVTE_Norm_Backend::Te;
is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, mu.data.dptr, rsigma.data.dptr,
dx->data.dptr, dz.data.dptr, dbeta->data.dptr, dgamma->data.dptr);
}
auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan(
norm_backend, NVTE_Norm_Type::LayerNorm, NVTE_Norm_Stage::Backward,
gamma.data.dtype, // wtype
x.data.dtype, // itype
gamma.data.dtype, // otype
x.data.shape[0], // batch_size
x.data.shape[1], // hidden_size
multiprocessorCount, zero_centered_gamma, is_aligned);
if (workspace->data.shape.empty()) {
workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape());
plan->execute(x.data.dptr, gamma.data.dptr, mu.data.dptr, rsigma.data.dptr, dx->data.dptr,
dz.data.dptr, dbeta->data.dptr, dgamma->data.dptr, workspace->data.dptr, stream);
}
return;
}
} // namespace transformer_engine
void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size
const NVTETensor gamma, // hidden_size
const NVTETensor beta, // hidden_size
const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma,
NVTETensor workspace, const int multiprocessorCount,
const bool zero_centered_gamma, cudaStream_t stream) {
NVTE_API_CALL(nvte_layernorm_fwd);
using namespace transformer_engine;
layernorm_fwd(*reinterpret_cast<const Tensor*>(x), *reinterpret_cast<const Tensor*>(gamma),
*reinterpret_cast<const Tensor*>(beta), epsilon, reinterpret_cast<Tensor*>(z),
reinterpret_cast<Tensor*>(mu), reinterpret_cast<Tensor*>(rsigma),
reinterpret_cast<Tensor*>(workspace), multiprocessorCount, zero_centered_gamma,
stream);
}
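// Usage sketch (workspace query protocol): call nvte_layernorm_fwd first with a
// workspace tensor whose shape is empty; that call only fills in the required
// workspace shape and dtype (kByte) and returns. Allocate the buffer, then call
// again with the same arguments to run the normalization.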
void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
const NVTETensor x, // BxSxhidden_size
const NVTETensor mu, // BxS, FP32!
const NVTETensor rsigma, // BxS, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, NVTETensor workspace,
const int multiprocessorCount, const bool zero_centered_gamma,
cudaStream_t stream) {
NVTE_API_CALL(nvte_layernorm_bwd);
using namespace transformer_engine;
layernorm_bwd(*reinterpret_cast<const Tensor*>(dz), *reinterpret_cast<const Tensor*>(x),
*reinterpret_cast<const Tensor*>(mu), *reinterpret_cast<const Tensor*>(rsigma),
*reinterpret_cast<const Tensor*>(gamma), reinterpret_cast<Tensor*>(dx),
reinterpret_cast<Tensor*>(dgamma), reinterpret_cast<Tensor*>(dbeta),
reinterpret_cast<Tensor*>(workspace), multiprocessorCount, zero_centered_gamma,
stream);
}
......@@ -7,16 +7,15 @@
#ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_BWD_KERNELS_CUH_
#define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_BWD_KERNELS_CUH_
#include "../utils.cuh"
#include "ln.h"
#include "../../utils.cuh"
#include "../common.h"
namespace transformer_engine {
namespace layer_norm {
using namespace transformer_engine;
namespace normalization {
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_tuned_kernel(
layer_norm::BwdParams params) {
BackwardKernelParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
......@@ -119,8 +118,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_tuned_kernel(
}
reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);
mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * rn;
mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * rn;
mdy_local = Get<0>::of<reduce_t, compute_t>(result) * rn;
mdyy_local = Get<1>::of<reduce_t, compute_t>(result) * rn;
Ivec dx[LDGS];
idx = row * Ktraits::VEC_COLS + c;
......@@ -203,7 +202,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_tuned_kernel(
template <typename Kernel_traits>
__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void ln_bwd_finalize_tuned_kernel(
BwdParams params) {
BackwardKernelParams params) {
using compute_t = typename Kernel_traits::compute_t;
using weight_t = typename Kernel_traits::weight_t;
using index_t = typename Kernel_traits::index_t;
......@@ -323,7 +322,7 @@ __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void ln_bwd_finaliz
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_general_kernel(
layer_norm::BwdParams params) {
BackwardKernelParams params) {
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
enum { WARPS_M = Ktraits::WARPS_M };
......@@ -424,8 +423,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_general_kerne
// Reduce over row
reduce_t result = reducer.allreduce({mdy, mdyy}, sum);
mdy = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * rn;
mdyy = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * rn;
mdy = Get<0>::of<reduce_t, compute_t>(result) * rn;
mdyy = Get<1>::of<reduce_t, compute_t>(result) * rn;
// Compute dx
#pragma unroll
......@@ -507,7 +506,7 @@ template <typename weight_t, typename compute_t, uint32_t WARPS_M, uint32_t WARP
uint32_t BYTES_PER_LDG, uint32_t THREADS_PER_WARP>
__global__
__launch_bounds__(WARPS_M *WARPS_N *THREADS_PER_WARP) void ln_bwd_finalize_general_kernel(
layer_norm::BwdParams params) {
BackwardKernelParams params) {
enum { NUM_ELTS = BYTES_PER_LDG / sizeof(compute_t) };
using Wvec = Vec<weight_t, NUM_ELTS>;
using Cvec = Vec<compute_t, NUM_ELTS>;
......@@ -573,7 +572,7 @@ __launch_bounds__(WARPS_M *WARPS_N *THREADS_PER_WARP) void ln_bwd_finalize_gener
}
}
} // namespace layer_norm
} // namespace normalization
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_BWD_KERNELS_CUH_
......@@ -10,15 +10,16 @@
#include <cfloat>
#include <cstdio>
#include "../utils.cuh"
#include "ln.h"
#include "../../utils.cuh"
#include "../common.h"
namespace transformer_engine {
namespace layer_norm {
namespace normalization {
using namespace transformer_engine;
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(FwdParams params) {
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
ForwardKernelParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_N = Ktraits::WARPS_N };
enum { WARPS_M = Ktraits::WARPS_M };
......@@ -92,8 +93,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
stats_t s = stats.compute(xf, rn);
compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);
compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);
compute_t mu = Get<0>::of<stats_t, compute_t>(s);
compute_t m2 = Get<1>::of<stats_t, compute_t>(s);
if (bidn == 0 && warp_n == 0 && lane == 0) {
mu_ptr[row] = mu;
......@@ -150,7 +151,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kernel(
FwdParams params) {
ForwardKernelParams params) {
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::NUM_ELTS };
enum { WARPS_M = Ktraits::WARPS_M };
......@@ -315,7 +316,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne
}
}
} // namespace layer_norm
} // namespace normalization
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_FWD_KERNELS_CUH_
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <vector>
#include "../../common.h"
#include "../common.h"
#include "transformer_engine/normalization.h"
namespace transformer_engine {
using namespace normalization;
void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tensor *z,
Tensor *rsigma, Tensor *workspace, const int multiprocessorCount,
const bool zero_centered_gamma, cudaStream_t stream) {
NVTE_CHECK(x.data.shape.size() == 2);
NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]);
NVTE_CHECK(epsilon >= 0.f);
NVTE_CHECK(z->data.shape == x.data.shape);
NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{x.data.shape[0]});
NVTE_CHECK(rsigma->data.dtype == DType::kFloat32);
if (!workspace->data.shape.empty()) {
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*z, "z");
CheckOutputTensor(*rsigma, "rsigma");
}
Tensor empty;
NVTE_Norm_Backend norm_backend;
bool is_aligned = true;
if (use_cudnn_norm_fwd()) {
// TODO: add check for GPU ARCH
norm_backend = NVTE_Norm_Backend::Cudnn;
} else {
norm_backend = NVTE_Norm_Backend::Te;
is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, rsigma->data.dptr);
}
auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan(
norm_backend, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::Forward,
gamma.data.dtype, // wtype
x.data.dtype, // itype
z->data.dtype, // otype
x.data.shape[0], // batch_size
x.data.shape[1], // hidden_size
multiprocessorCount, zero_centered_gamma, is_aligned);
if (workspace->data.shape.empty()) {
workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape());
plan->execute(z, x.data.dptr, gamma.data.dptr, nullptr, nullptr,
reinterpret_cast<void *>(const_cast<float *>(&epsilon)), rsigma->data.dptr,
workspace->data.dptr, stream);
}
return;
}
void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const Tensor &gamma,
Tensor *dx, Tensor *dgamma, Tensor *workspace, const int multiprocessorCount,
const bool zero_centered_gamma, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(dz.data.dtype == gamma.data.dtype);
NVTE_CHECK(rsigma.data.dtype == DType::kFloat32);
NVTE_CHECK(x.data.shape.size() == 2);
NVTE_CHECK(dz.data.shape == x.data.shape);
NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]);
NVTE_CHECK(dx->data.shape == x.data.shape);
NVTE_CHECK(dx->data.dtype == x.data.dtype);
NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
if (!workspace->data.shape.empty()) {
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(rsigma, "rsigma");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*dx, "dx");
CheckOutputTensor(*dgamma, "dgamma");
}
Tensor empty;
NVTE_Norm_Backend norm_backend;
bool is_aligned = true;
if (use_cudnn_norm_bwd()) {
// TODO: add check for GPU ARCH
norm_backend = NVTE_Norm_Backend::Cudnn;
} else {
norm_backend = NVTE_Norm_Backend::Te;
is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, rsigma.data.dptr, dx->data.dptr,
dz.data.dptr, dgamma->data.dptr);
}
auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan(
norm_backend, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::Backward,
gamma.data.dtype, // wtype
x.data.dtype, // itype
gamma.data.dtype, // otype
x.data.shape[0], // batch_size
x.data.shape[1], // hidden_size
multiprocessorCount, zero_centered_gamma, is_aligned);
if (workspace->data.shape.empty()) {
workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape());
plan->execute(x.data.dptr, gamma.data.dptr, nullptr, rsigma.data.dptr, dx->data.dptr,
dz.data.dptr, nullptr, dgamma->data.dptr, workspace->data.dptr, stream);
}
return;
}
} // namespace transformer_engine
void nvte_rmsnorm_fwd(const NVTETensor x, // Nxhidden_size
const NVTETensor gamma, // hidden_size
const float epsilon, NVTETensor z, NVTETensor rsigma, NVTETensor workspace,
const int multiprocessorCount, const bool zero_centered_gamma,
cudaStream_t stream) {
NVTE_API_CALL(nvte_rmsnorm_fwd);
using namespace transformer_engine;
rmsnorm_fwd(*reinterpret_cast<const Tensor *>(x), *reinterpret_cast<const Tensor *>(gamma),
epsilon, reinterpret_cast<Tensor *>(z), reinterpret_cast<Tensor *>(rsigma),
reinterpret_cast<Tensor *>(workspace), multiprocessorCount, zero_centered_gamma,
stream);
}
void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size
const NVTETensor x, // Nxhidden_size
const NVTETensor rsigma, // N, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dgamma, NVTETensor workspace,
const int multiprocessorCount, const bool zero_centered_gamma,
cudaStream_t stream) {
NVTE_API_CALL(nvte_rmsnorm_bwd);
using namespace transformer_engine;
rmsnorm_bwd(*reinterpret_cast<const Tensor *>(dz), *reinterpret_cast<const Tensor *>(x),
*reinterpret_cast<const Tensor *>(rsigma), *reinterpret_cast<const Tensor *>(gamma),
reinterpret_cast<Tensor *>(dx), reinterpret_cast<Tensor *>(dgamma),
reinterpret_cast<Tensor *>(workspace), multiprocessorCount, zero_centered_gamma,
stream);
}
......@@ -7,15 +7,15 @@
#ifndef TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_BWD_KERNELS_CUH_
#define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_BWD_KERNELS_CUH_
#include "../utils.cuh"
#include "../../utils.cuh"
#include "../common.h"
namespace transformer_engine {
namespace rmsnorm {
using namespace transformer_engine;
namespace normalization {
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_kernel(
BwdParams params) {
BackwardKernelParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
......@@ -172,7 +172,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_ke
template <typename Kernel_traits>
__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void rmsnorm_bwd_finalize_tuned_kernel(
BwdParams params) {
BackwardKernelParams params) {
using compute_t = typename Kernel_traits::compute_t;
using weight_t = typename Kernel_traits::weight_t;
using index_t = typename Kernel_traits::index_t;
......@@ -276,7 +276,7 @@ __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void rmsnorm_bwd_fi
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_kernel(
BwdParams params) {
BackwardKernelParams params) {
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
enum { WARPS_M = Ktraits::WARPS_M };
......@@ -430,8 +430,9 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_
template <typename weight_t, typename compute_t, uint32_t WARPS_M, uint32_t WARPS_N,
uint32_t BYTES_PER_LDG, uint32_t THREADS_PER_WARP>
__global__ __launch_bounds__(
WARPS_M *WARPS_N *THREADS_PER_WARP) void rmsnorm_bwd_finalize_general_kernel(BwdParams params) {
__global__
__launch_bounds__(WARPS_M *WARPS_N *THREADS_PER_WARP) void rmsnorm_bwd_finalize_general_kernel(
BackwardKernelParams params) {
enum { NUM_ELTS = BYTES_PER_LDG / sizeof(compute_t) };
using Wvec = Vec<weight_t, NUM_ELTS>;
using Cvec = Vec<compute_t, NUM_ELTS>;
......@@ -474,7 +475,7 @@ __global__ __launch_bounds__(
}
}
} // namespace rmsnorm
} // namespace normalization
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_BWD_KERNELS_CUH_