Unverified commit 37cc3625, authored by zlsh80826, committed by GitHub

Add RMSNorm (#45)



* Add rmsnorm kernels
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add rmsnorm cpp unit test
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Apply new Tensor struct
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Move scale/scale_inv/amax into the TE Tensor struct
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add documentation
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Separate rmsnorm kernels from the layernorm
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix indent
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Update rmsnorm test cases
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Update copyright year
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix the support matrix in the documentation
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Move register macro out of utils.cuh
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent 0c9c0ba1
@@ -10,6 +10,7 @@ add_executable(test_operator
    test_cast_transpose_dbias_dgelu.cu
    test_gelu.cu
    test_layernorm.cu
    test_rmsnorm.cu
    test_multi_cast_transpose.cu
    ../test_common.cu)
...
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/rmsnorm.h>
#include <transformer_engine/transformer_engine.h>
#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include "../test_common.h"
using namespace transformer_engine;
using namespace test;
namespace {
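// Reference math implemented by the helpers below (a restatement of what the
// kernels compute):
//   rsigma_i = 1 / sqrt( (1/H) * sum_j x_ij^2 + epsilon )
//   z_ij     = scale * gamma_j * x_ij * rsigma_i
// where scale != 1 only for FP8 outputs, and amax tracks max_ij of the
// unscaled |gamma_j * x_ij * rsigma_i|.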
template <typename InputType>
void compute_ref_stats(const InputType *data, float *rsigma, const size_t N, const size_t H,
const double epsilon) {
using compute_t = float;
for (size_t i = 0; i < N; ++i) {
compute_t sum = 0;
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
sum += (current) * (current);
}
sum = sum / H;
compute_t rs = rsqrtf(sum + epsilon);
rsigma[i] = rs;
}
}
template <typename InputType, typename OutputType>
void compute_ref_output(const InputType *data, const InputType *gamma, OutputType *output,
const float *rsigma, const size_t N, const size_t H, float *amax,
float scale) {
using compute_t = float;
compute_t current_max = -1e100;
for (size_t i = 0; i < N; ++i) {
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
compute_t tmp = current * rsigma[i] * static_cast<compute_t>(gamma[j]);
output[i * H + j] = static_cast<OutputType>(tmp * scale);
current_max = fmaxf(current_max, fabsf(tmp));
}
}
*amax = current_max;
}
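// Backward reference. With y = x * rsigma and dy = gamma * dz, differentiating
// z = gamma * x * rsigma gives, per row i:
//   dgamma_j = sum_i dz_ij * y_ij
//   dx_ij    = rsigma_i * ( dy_ij - y_ij * (1/H) * sum_k dy_ik * y_ik )
// The (1/H) * sum_k dy_ik * y_ik reduction is the `mdyy` variable below.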
template <typename InputType, typename OutputType>
void compute_ref_backward(const OutputType *output_grad, const InputType *data, const float *rsigma,
const InputType *gamma, InputType *data_grad, InputType *gamma_grad,
const size_t N, const size_t H) {
using compute_t = float;
std::vector<compute_t> dgamma(H, 0.f);
for (size_t i = 0; i < N; ++i) {
// Reductions
compute_t mdyy = 0;
for (size_t j = 0; j < H; ++j) {
const compute_t x = static_cast<compute_t>(data[i * H + j]);
const compute_t y = x * rsigma[i];
const compute_t g = static_cast<compute_t>(gamma[j]);
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz;
dgamma[j] += y * dz;
mdyy += dy * y;
}
mdyy /= H;
// Input grads
for (size_t j = 0; j < H; ++j) {
const compute_t x = static_cast<compute_t>(data[i * H + j]);
const compute_t y = x * rsigma[i];
const compute_t g = static_cast<compute_t>(gamma[j]);
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz;
const compute_t dx = rsigma[i] * (dy - mdyy * y);
data_grad[i * H + j] = static_cast<InputType>(dx);
}
}
// Weight grads
for (size_t j = 0; j < H; ++j) {
gamma_grad[j] = static_cast<InputType>(dgamma[j]);
}
}
template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H) {
if (sizeof(InputType) < sizeof(OutputType)) {
GTEST_SKIP() << "RMSNorm kernel does not support OutputType > InputType";
return;
}
using WeightType = InputType;
DType itype = TypeInfo<InputType>::dtype;
DType wtype = TypeInfo<WeightType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
if ((itype == DType::kBFloat16 && otype == DType::kFloat16) ||
(itype == DType::kFloat16 && otype == DType::kBFloat16)) {
GTEST_SKIP() << "RMSNorm kernel does not support mixing Float16 and BFloat16";
return;
}
Tensor input({N, H}, itype);
Tensor z({N, H}, otype);
Tensor gamma({H}, wtype);
Tensor rsigma({N}, DType::kFloat32);
Tensor dz({N, H}, wtype);
Tensor dx({N, H}, itype);
Tensor dgamma({H}, wtype);
Tensor workspace, barrier, dgamma_part;
fillUniform(&input);
fillUniform(&gamma);
fillUniform(&dz);
setRandomScale(&z);
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N * H);
std::unique_ptr<float[]> ref_rsigma = std::make_unique<float[]>(N);
std::unique_ptr<InputType[]> ref_dx = std::make_unique<InputType[]>(N * H);
std::unique_ptr<WeightType[]> ref_dgamma = std::make_unique<WeightType[]>(H);
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
// Forward kernel
float epsilon = 1e-5;
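// The first nvte_rmsnorm_fwd call runs with empty workspace/barrier tensors and
// therefore only queries their required shapes/dtypes; after reallocating them,
// the second call performs the actual computation. The backward pass below uses
// the same protocol (with dgamma_part as an additional queried tensor).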
nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), 0,
prop.multiProcessorCount, workspace.data(), barrier.data());
workspace = Tensor(workspace.shape(), workspace.dtype());
barrier = Tensor(barrier.shape(), barrier.dtype());
nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), 0,
prop.multiProcessorCount, workspace.data(), barrier.data());
// Backward kernel
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(),
dgamma_part.data(), 0, prop.multiProcessorCount, workspace.data(),
barrier.data());
workspace = Tensor(workspace.shape(), workspace.dtype());
barrier = Tensor(barrier.shape(), barrier.dtype());
dgamma_part = Tensor(dgamma_part.shape(), dgamma_part.dtype());
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(),
dgamma_part.data(), 0, prop.multiProcessorCount, workspace.data(),
barrier.data());
// Reference implementations
// use the GPU stats to tighten the tolerances
rsigma.to_cpu();
float ref_amax;
compute_ref_stats(input.cpu_dptr<InputType>(), ref_rsigma.get(), N, H, epsilon);
float ref_scale = isFp8Type(otype) ? z.scale() : 1.f;
compute_ref_output(input.cpu_dptr<InputType>(), gamma.cpu_dptr<WeightType>(), ref_output.get(),
rsigma.cpu_dptr<float>(), N, H, &ref_amax, ref_scale);
compute_ref_backward(dz.cpu_dptr<WeightType>(), input.cpu_dptr<InputType>(),
rsigma.cpu_dptr<float>(), gamma.cpu_dptr<WeightType>(), ref_dx.get(),
ref_dgamma.get(), N, H);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
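// For FP8 outputs the kernel writes scaled values to z but records the amax of
// the unscaled results, and stores 1/scale as scale_inv; check both.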
if (isFp8Type(otype)) {
compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / z.scale();
compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
rtol_stats = 5e-5;
compareResults("rsigma", rsigma, ref_rsigma.get(), atol_stats, rtol_stats);
auto [atol, rtol] = getTolerances(otype);
atol = 1e-8;
compareResults("output", z, ref_output.get(), atol, rtol);
double atol_bwd = 5e-6;
double rtol_bwd = 1e-4;
compareResults("dx", dx, ref_dx.get(), atol_bwd, rtol_bwd);
compareResults("dgamma", dgamma, ref_dgamma.get(), atol_bwd, rtol_bwd);
}
std::vector<std::pair<size_t, size_t>> test_cases = {
    {2048, 4096}, {768, 2048}, {256, 1024}, {128, 768}, {64, 512},
    {173, 409},   // Primes 40, 80
    {71, 3571},   // Primes 20, 500
    {29, 17389},  // Primes 10, 2000
};
} // namespace
class RMSNormTestSuite
: public ::testing::TestWithParam<std::tuple<
transformer_engine::DType, transformer_engine::DType, std::pair<size_t, size_t>>> {};
TEST_P(RMSNormTestSuite, TestRMSNorm) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
output_type, OutputType, performTest<InputType, OutputType>(size.first, size.second);););
}
INSTANTIATE_TEST_SUITE_P(OperatorTest, RMSNormTestSuite,
::testing::Combine(::testing::Values(DType::kFloat32, DType::kBFloat16,
DType::kFloat16),
::testing::Values(DType::kFloat32, DType::kBFloat16,
DType::kFloat16, DType::kFloat8E4M3),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<RMSNormTestSuite::ParamType> &info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param)) + "X" +
std::to_string(std::get<2>(info.param).first) + "X" +
std::to_string(std::get<2>(info.param).second);
return name;
});
@@ -31,6 +31,9 @@ add_library(transformer_engine SHARED
    layer_norm/ln_api.cpp
    layer_norm/ln_bwd_semi_cuda_kernel.cu
    layer_norm/ln_fwd_cuda_kernel.cu
    rmsnorm/rmsnorm_api.cpp
    rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
    rmsnorm/rmsnorm_fwd_cuda_kernel.cu
    util/cast.cu
    fused_softmax/scaled_masked_softmax.cu
    fused_softmax/scaled_upper_triang_masked_softmax.cu)
...
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file rmsnorm.h
* \brief RMSNorm functions.
*/
#ifndef TRANSFORMER_ENGINE_RMSNORM_H_
#define TRANSFORMER_ENGINE_RMSNORM_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief Compute RMSNorm on the input.
*
* Calling this function with workspace and barrier set to empty tensor will not
* perform the operation, but instead set the shape and type of the workspace
* and barrier tensors to the required values.
*
* \param[in] x Input tensor of shape [N, H].
* \param[in] gamma Gamma tensor of shape [H].
* \param[in] epsilon Value added to denominator for numerical stability.
* \param[in,out] z Output tensor of shape [N, H].
* \param[out] rsigma Reciprocal of the root mean square of the input
* calculated over the last dimension. Shape: [N].
* \param[in] stream CUDA stream used for the operation.
* \param[in] multiprocessorCount Number of SMs in the device.
* \param[out] workspace Workspace tensor.
* \param[out] barrier Barrier tensor.
*/
void nvte_rmsnorm_fwd(const NVTETensor x,
const NVTETensor gamma,
const float epsilon,
NVTETensor z,
NVTETensor rsigma,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier);
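/*
 * Usage sketch of the shape-query protocol (how the workspace/barrier tensors
 * are created depends on the caller's framework):
 *
 *   // 1. Call with workspace and barrier as empty tensors: nothing is
 *   //    computed, but their required shape and dtype are filled in.
 *   nvte_rmsnorm_fwd(x, gamma, eps, z, rsigma, stream, sm_count,
 *                    workspace, barrier);
 *   // 2. Allocate workspace and barrier to the reported shape/dtype.
 *   // 3. Call again with the allocated tensors to run the kernel.
 *   nvte_rmsnorm_fwd(x, gamma, eps, z, rsigma, stream, sm_count,
 *                    workspace, barrier);
 */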
/*! \brief Compute backward of RMSNorm.
*
* Calling this function with workspace, barrier, dgamma_part set
* to empty tensor will not perform the operation, but instead set the shape and type
* of these tensors to the required values.
*
* \param[in] dz Incoming gradient tensor of shape [N, H].
* \param[in] x Forward input tensor of shape [N, H].
* \param[in] rsigma Reciprocal of the root mean square of the input
* calculated over the last dimension. Shape: [N].
* \param[in] gamma Gamma tensor of shape [H].
* \param[out] dx Output gradient of shape [N, H].
* \param[out] dgamma Gradient for gamma tensor of shape [H].
* \param[out] dgamma_part Storage for partial gamma gradient.
* \param[in] stream CUDA stream used for the operation.
* \param[in] multiprocessorCount Number of SMs in the device.
* \param[out] workspace Workspace tensor.
* \param[out] barrier Barrier tensor.
*/
void nvte_rmsnorm_bwd(const NVTETensor dz,
const NVTETensor x,
const NVTETensor rsigma,
const NVTETensor gamma,
NVTETensor dx,
NVTETensor dgamma,
NVTETensor dgamma_part,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier
);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_RMSNORM_H_
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
************************************************************************/ ************************************************************************/
#include "ln.h" #include "ln.h"
#include "../utils.cuh"
#include "ln_kernel_traits.h" #include "ln_kernel_traits.h"
#include "ln_bwd_kernels.cuh" #include "ln_bwd_kernels.cuh"
...@@ -187,6 +186,55 @@ void launch_general_(LaunchParams<BwdParams> &launch_params, const bool configur ...@@ -187,6 +186,55 @@ void launch_general_(LaunchParams<BwdParams> &launch_params, const bool configur
kernel_final<<<grid_final, block_final, 0, stream>>>(launch_params.params); kernel_final<<<grid_final, block_final, 0, stream>>>(launch_params.params);
} }
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_BWD_TUNED_LAUNCHER( \
HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE) \
void ln_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<BwdParams> \
&launch_params, \
const bool configure_params) { \
launch_tuned_<WTYPE, \
ITYPE, \
OTYPE, \
CTYPE, \
uint32_t, \
HIDDEN_SIZE, \
CTAS_PER_ROW, \
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
} \
static BwdTunedRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
ln_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
#define REGISTER_BWD_GENERAL_LAUNCHER( \
HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE) \
void ln_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<BwdParams> \
&launch_params, \
const bool configure_params) { \
launch_general_<WTYPE, \
ITYPE, \
OTYPE, \
CTYPE, \
uint32_t, \
HIDDEN_SIZE, \
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
} \
static BwdGeneralRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
ln_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
// Create tuned launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ...
// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
...
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
************************************************************************/ ************************************************************************/
#include "ln.h" #include "ln.h"
#include "../utils.cuh"
#include "ln_kernel_traits.h" #include "ln_kernel_traits.h"
#include "ln_fwd_kernels.cuh" #include "ln_fwd_kernels.cuh"
...@@ -149,6 +148,36 @@ void launch_general_(LaunchParams<FwdParams> &launch_params, const bool configur ...@@ -149,6 +148,36 @@ void launch_general_(LaunchParams<FwdParams> &launch_params, const bool configur
} }
} }
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \
CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \
void ln_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<FwdParams> &launch_params, \
const bool configure_params) { \
launch_tuned_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, \
WARPS_M, WARPS_N, BYTES_PER_LDG>( \
launch_params, configure_params); \
} \
static FwdTunedRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
ln_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
#define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \
WARPS_M, WARPS_N, BYTES_PER_LDG) \
void ln_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<FwdParams> &launch_params, \
const bool configure_params) { \
launch_general_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, \
WARPS_M, WARPS_N, BYTES_PER_LDG>( \
launch_params, configure_params); \
} \
static FwdGeneralRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
ln_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
// Create tuned launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
...
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_H_
#define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_H_
#include <transformer_engine/transformer_engine.h>
#include <functional>
#include <map>
#include <stdexcept>
#include <unordered_map>
#include <vector>
#include "../common.h"
#include "../layer_norm/ln.h"
namespace transformer_engine {
namespace rmsnorm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Params>
struct LaunchParams : public transformer_engine::layer_norm::LaunchParams<Params> {};
struct FwdParams : public transformer_engine::layer_norm::FwdParams {};
struct BwdParams : public transformer_engine::layer_norm::BwdParams {};
////////////////////////////////////////////////////////////////////////////////////////////////////
using FwdFunction = std::function<void(LaunchParams<FwdParams> &, const bool)>;
using BwdFunction = std::function<void(LaunchParams<BwdParams> &, const bool)>;
using FunctionKey = uint64_t;
using FwdTunedRegistry = std::unordered_map<FunctionKey, FwdFunction>;
using BwdTunedRegistry = std::unordered_map<FunctionKey, BwdFunction>;
using FwdGeneralRegistry = std::unordered_map<FunctionKey, std::map<uint64_t, FwdFunction>>;
using BwdGeneralRegistry = std::unordered_map<FunctionKey, std::map<uint64_t, BwdFunction>>;
extern FwdTunedRegistry FWD_TUNED_FUNCS;
extern BwdTunedRegistry BWD_TUNED_FUNCS;
extern FwdGeneralRegistry FWD_GENERAL_FUNCS;
extern BwdGeneralRegistry BWD_GENERAL_FUNCS;
//////////////////////////////////////////////////////////////////////////////////////////////////
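// The REGISTER_*_LAUNCHER macros in the rmsnorm_*_semi_cuda_kernel.cu files
// instantiate one of the registrar types below as a file-scope static, so each
// launcher is inserted into its registry during static initialization, keyed by
// layer_norm::Types2Key over the (weight, input, output, compute) types and,
// for tuned kernels, the hidden size.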
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdTunedRegistrar {
explicit FwdTunedRegistrar(FwdFunction f) {
uint64_t key = layer_norm::Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
FWD_TUNED_FUNCS.insert({key, f});
}
};
//////////////////////////////////////////////////////////////////////////////////////////////////
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdGeneralRegistrar {
explicit FwdGeneralRegistrar(FwdFunction f) {
uint64_t key = layer_norm::Types2Key<W, I, O, C>::get(0);
FWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f});
}
};
//////////////////////////////////////////////////////////////////////////////////////////////////
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdTunedRegistrar {
explicit BwdTunedRegistrar(BwdFunction f) {
uint64_t key = layer_norm::Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
BWD_TUNED_FUNCS.insert({key, f});
}
};
//////////////////////////////////////////////////////////////////////////////////////////////////
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdGeneralRegistrar {
explicit BwdGeneralRegistrar(BwdFunction f) {
uint64_t key = layer_norm::Types2Key<W, I, O, C>::get(0);
BWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f});
}
};
} // namespace rmsnorm
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_H_
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <numeric>
#include <vector>
#include "../common.h"
#include "rmsnorm.h"
#include "transformer_engine/rmsnorm.h"
/*
Supported Type combinations:
input    compute    weights    output
=======================================
fp32     fp32       fp32       fp32
fp16     fp32       fp16       fp16
bf16     fp32       bf16       bf16
fp32     fp32       fp32       fp16
fp32     fp32       fp32       bf16
fp32     fp32       fp32       fp8
fp16     fp32       fp16       fp8
bf16     fp32       bf16       fp8
Remarks:
Input type = Weight type
Compute always in FP32
*/
namespace transformer_engine {
namespace layer_norm {
uint64_t get_key(DType wtype, DType itype, DType otype, DType ctype, uint64_t hidden_size);
}
namespace rmsnorm {
using namespace transformer_engine;
FwdTunedRegistry FWD_TUNED_FUNCS;
BwdTunedRegistry BWD_TUNED_FUNCS;
FwdGeneralRegistry FWD_GENERAL_FUNCS;
BwdGeneralRegistry BWD_GENERAL_FUNCS;
FwdFunction &get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype,
uint32_t hidden_size, uint32_t batch_size) {
// Look for tuned kernel
auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, hidden_size);
if (batch_size % 4 == 0 && FWD_TUNED_FUNCS.count(tuned_key) > 0) {
return FWD_TUNED_FUNCS.at(tuned_key);
}
// Pick general kernel
auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0);
if (FWD_GENERAL_FUNCS.count(general_key) == 0) {
NVTE_ERROR("FWD: Unsupported types.");
}
auto &general_func_map = FWD_GENERAL_FUNCS.at(general_key);
auto func_iter = general_func_map.lower_bound(hidden_size);
if (func_iter == general_func_map.end()) {
// Hidden size is too big, need to use multi-CTA
return general_func_map.rbegin()->second;
} else {
return func_iter->second;
}
}
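// Dispatch example: a request with hidden_size = 3571 matches no tuned key, so
// the general registry is consulted; lower_bound picks the smallest registered
// HIDDEN_SIZE >= 3571 (e.g. 4096 with the sizes registered in the launcher
// files below). The general kernels guard every column access with
// `col < params.cols`, so the larger template width safely covers the
// smaller row.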
////////////////////////////////////////////////////////////////////////////////////////////////////
BwdFunction &get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype,
uint32_t hidden_size, uint32_t batch_size) {
// Look for tuned kernel
auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, hidden_size);
if (batch_size % 4 == 0 && BWD_TUNED_FUNCS.count(tuned_key) > 0) {
return BWD_TUNED_FUNCS.at(tuned_key);
}
// Pick general kernel
auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0);
if (BWD_GENERAL_FUNCS.count(general_key) == 0) {
NVTE_ERROR("BWD: Unsupported types.");
}
auto &general_func_map = BWD_GENERAL_FUNCS.at(general_key);
auto func_iter = general_func_map.lower_bound(hidden_size);
if (func_iter == general_func_map.end()) {
// Hidden size is too big, need to use multi-CTA
return general_func_map.rbegin()->second;
} else {
return func_iter->second;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline size_t product(const std::vector<size_t> &shape) {
return std::reduce(shape.cbegin(), shape.cend(), size_t{1}, std::multiplies<>());
}
} // namespace rmsnorm
////////////////////////////////////////////////////////////////////////////////////////////////////
void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tensor *z,
Tensor *rsigma, cudaStream_t stream, const int multiprocessorCount,
Tensor *workspace, Tensor *barrier) {
auto itype = x.data.dtype;
auto wtype = gamma.data.dtype;
auto otype = z->data.dtype;
const bool fp8_out = is_fp8_dtype(otype);
auto ctype = DType::kFloat32;
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*z, "z");
CheckOutputTensor(*rsigma, "rsigma");
NVTE_CHECK(x.data.shape.size() == 2);
const size_t rows = x.data.shape[0];
const size_t cols = x.data.shape[1];
const auto hidden_size = gamma.data.shape[0];
NVTE_CHECK(hidden_size == cols);
NVTE_CHECK(epsilon >= 0.f);
NVTE_CHECK(z->data.shape == x.data.shape);
NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{rows});
NVTE_CHECK(rsigma->data.dtype == ctype);
rmsnorm::LaunchParams<rmsnorm::FwdParams> launch_params;
launch_params.multiprocessorCount = multiprocessorCount;
launch_params.stream = stream;
// Request the kernel launcher.
auto launcher = rmsnorm::get_fwd_launcher(wtype, itype, otype, ctype, hidden_size, rows);
// Set the kernel runtime parameters.
rmsnorm::FwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data.dptr;
params.mu = nullptr;
params.rs = rsigma->data.dptr;
params.gamma = gamma.data.dptr;
params.beta = nullptr;
params.z = z->data.dptr;
params.epsilon = epsilon;
params.amax = z->amax.dptr;
params.scale = z->scale.dptr;
params.scale_inv = z->scale_inv.dptr;
params.fp8_out = fp8_out;
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
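// Configuration pass: if the caller passed empty tensors, report the required
// workspace/barrier shapes and return without launching anything (see the
// shape-query protocol documented in transformer_engine/rmsnorm.h).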
if (workspace->data.dptr == nullptr) {
NVTE_CHECK(barrier->data.dptr == nullptr);
workspace->data.dtype = DType::kByte;
if (launch_params.workspace_bytes == 0) {
launch_params.workspace_bytes = 1;
}
workspace->data.shape = {launch_params.workspace_bytes};
barrier->data.dtype = DType::kInt32;
barrier->data.shape = {launch_params.barrier_size};
return;
}
if (launch_params.barrier_size > 0) {
params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int *>(barrier->data.dptr);
}
// Clear buffers
if (params.fp8_out) {
cudaMemsetAsync(params.amax, 0, rmsnorm::product(z->amax.shape) * typeToSize(z->amax.dtype),
stream);
}
if (launch_params.barrier_size > 0) {
cudaMemsetAsync(params.barrier, 0,
rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype),
stream);
}
// Launch the kernel.
launcher(launch_params, false);
return;
}
void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const Tensor &gamma,
Tensor *dx, Tensor *dgamma, Tensor *dgamma_part, cudaStream_t stream,
const int multiprocessorCount, Tensor *workspace, Tensor *barrier) {
using namespace transformer_engine;
auto itype = x.data.dtype;
auto wtype = gamma.data.dtype;
auto otype = wtype;
auto ctype = DType::kFloat32;
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(rsigma, "rsigma");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*dx, "dx");
CheckOutputTensor(*dgamma, "dgamma");
NVTE_CHECK(dz.data.dtype == otype);
NVTE_CHECK(rsigma.data.dtype == ctype);
NVTE_CHECK(x.data.shape.size() == 2);
NVTE_CHECK(dz.data.shape == x.data.shape);
const auto rows = x.data.shape[0];
const auto cols = x.data.shape[1];
const auto hidden_size = gamma.data.shape[0];
NVTE_CHECK(gamma.data.shape[0] == cols);
NVTE_CHECK(dx->data.shape == x.data.shape);
NVTE_CHECK(dx->data.dtype == x.data.dtype);
NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
rmsnorm::LaunchParams<rmsnorm::BwdParams> launch_params;
launch_params.stream = stream;
launch_params.multiprocessorCount = multiprocessorCount;
auto launcher = rmsnorm::get_bwd_launcher(wtype, itype, otype, ctype, hidden_size, rows);
// Set the kernel runtime parameters.
rmsnorm::BwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data.dptr;
params.mu = nullptr;
params.rs = rsigma.data.dptr;
params.gamma = gamma.data.dptr;
params.dz = dz.data.dptr;
params.dx = dx->data.dptr;
params.dbeta = nullptr;
params.dgamma = dgamma->data.dptr;
params.dbeta_part = nullptr;
params.dgamma_part = dgamma_part->data.dptr;
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
// Populate shapes and dtypes so the framework can allocate the buffers
if (dgamma_part->data.dptr == nullptr) {
dgamma_part->data.dtype = ctype;
dgamma_part->data.shape = {static_cast<uint64_t>(launch_params.params.ctas_per_col),
hidden_size};
workspace->data.dtype = DType::kByte;
workspace->data.shape = {launch_params.workspace_bytes};
barrier->data.dtype = DType::kInt32;
barrier->data.shape = {launch_params.barrier_size};
return;
}
if (launch_params.barrier_size > 0) {
params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int *>(barrier->data.dptr);
cudaMemsetAsync(params.barrier, 0,
rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype),
stream);
}
// Launch the kernel.
launcher(launch_params, false);
}
} // namespace transformer_engine
void nvte_rmsnorm_fwd(const NVTETensor x, // Nxhidden_size
const NVTETensor gamma, // hidden_size
const float epsilon, NVTETensor z, NVTETensor rsigma, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) {
using namespace transformer_engine;
rmsnorm_fwd(*reinterpret_cast<const Tensor *>(x), *reinterpret_cast<const Tensor *>(gamma),
epsilon, reinterpret_cast<Tensor *>(z), reinterpret_cast<Tensor *>(rsigma), stream,
multiprocessorCount, reinterpret_cast<Tensor *>(workspace),
reinterpret_cast<Tensor *>(barrier));
}
void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size
const NVTETensor x, // Nxhidden_size
const NVTETensor rsigma, // N, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) {
using namespace transformer_engine;
rmsnorm_bwd(*reinterpret_cast<const Tensor *>(dz), *reinterpret_cast<const Tensor *>(x),
*reinterpret_cast<const Tensor *>(rsigma), *reinterpret_cast<const Tensor *>(gamma),
reinterpret_cast<Tensor *>(dx), reinterpret_cast<Tensor *>(dgamma),
reinterpret_cast<Tensor *>(dgamma_part), stream, multiprocessorCount,
reinterpret_cast<Tensor *>(workspace), reinterpret_cast<Tensor *>(barrier));
}
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_BWD_KERNELS_CUH_
#define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_BWD_KERNELS_CUH_
#include "../utils.cuh"
namespace transformer_engine {
namespace rmsnorm {
using namespace transformer_engine;
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_kernel(
BwdParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { COLS = Ktraits::COLS };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
using compute_t = typename Ktraits::compute_t;
using index_t = typename Ktraits::index_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
using Reducer = typename Ktraits::Reducer;
using reduce_t = typename Reducer::Type;
extern __shared__ char smem_[];
const index_t tidx = threadIdx.x;
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / Ktraits::WARPS_N;
const index_t warp_n = warp % Ktraits::WARPS_N;
const index_t tid_r = warp_n * THREADS_PER_WARP + lane;
const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);
Cvec dzy_sum[LDGS];
memset(dzy_sum, 0, sizeof(dzy_sum));
compute_t *smem_wgrad = reinterpret_cast<compute_t *>(smem_);
char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad);
Sum<reduce_t> sum;
constexpr float rn = 1.f / static_cast<float>(COLS);
Wvec gamma[LDGS];
index_t idx = c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
gamma[it].load_from(params.gamma, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
// TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
// last blocks with syncthreads!
// grid stride over rows
#pragma unroll 1
for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {
const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
Ivec x[LDGS];
Ovec dz[LDGS];
index_t idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
dz[it].load_from(params.dz, idx);
x[it].load_from(params.x, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
compute_t dy[LDGS * NUM_ELTS];
compute_t y[LDGS * NUM_ELTS];
compute_t mdyy_local = 0.f;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t x_tmp = x[it].data.elt[jt];
compute_t y_tmp = rs_r * (x_tmp);
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]);
dy_tmp *= compute_t(dz[it].data.elt[jt]);
compute_t dz_tmp = dz[it].data.elt[jt];
mdyy_local += dy_tmp * y_tmp;
dy[it * NUM_ELTS + jt] = dy_tmp;
y[it * NUM_ELTS + jt] = y_tmp;
dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
}
}
reduce_t result = reducer.allreduce({0, mdyy_local}, sum);
mdyy_local = Get<1>::of<reduce_t, compute_t>(result) * rn;
Ivec dx[LDGS];
idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
compute_t y_tmp = y[it * NUM_ELTS + jt];
compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp));
dx[it].data.elt[jt] = dx_tmp;
}
dx[it].store_to(params.dx, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} // end: grid stride loop
if (WARPS_M == 1) {
idx = r * Ktraits::VEC_COLS + c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
dzy_sum[it].store_to(params.dgamma_part, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} else {
static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1,
"Multiple rows per CTA not supported for Multi-CTA.");
// Finalize reduction of part dgamma and dbeta for this CTA
// by reducing over the rows held across the WARPS_M warps
// Assumption: blockSize divides hidden size.
enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");
idx = warp_m * Ktraits::VEC_COLS + tid_r;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
dzy_sum[it].store_to(smem_wgrad, idx);
idx += THREADS_PER_ROW;
}
__syncthreads();
compute_t cta_dzy_sum[NUM_RES];
memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES);
for (int it = 0; it < ROWS_PER_CTA; it++) {
for (int jt = 0; jt < NUM_RES; jt++) {
cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * COLS + tidx;
for (int jt = 0; jt < NUM_RES; jt++) {
*dgamma_part = cta_dzy_sum[jt];
dgamma_part += Ktraits::THREADS_PER_CTA;
}
}
}
template <typename Kernel_traits>
__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void rmsnorm_bwd_finalize_tuned_kernel(
BwdParams params) {
using compute_t = typename Kernel_traits::compute_t;
using weight_t = typename Kernel_traits::weight_t;
using index_t = typename Kernel_traits::index_t;
using Reducer = typename Kernel_traits::Reducer;
using reduce_t = typename Reducer::Type;
Sum<reduce_t> sum;
enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };
enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };
__shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA];
constexpr uint32_t bidm = 0;
const uint32_t bidn = blockIdx.x;
const uint32_t tidx = threadIdx.x;
const uint32_t warp = tidx / THREADS_PER_WARP;
const uint32_t lane = tidx % THREADS_PER_WARP;
Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);
const uint32_t c = bidn * THREADS_PER_WARP + lane;
const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;
constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
for (uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS;
col += COL_STRIDE, col_out += COL_STRIDE / 2) {
// Each thread sums over NUM_ELT columns.
Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
memset(&dgamma_local, 0, sizeof(dgamma_local));
for (uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA) {
index_t idx = row * Kernel_traits::COLS + col;
Vec<compute_t, NUM_ELT> dgamma_part;
dgamma_part.load_from(params.dgamma_part, idx);
#pragma unroll
for (int it = 0; it < NUM_ELT; it++) {
dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
}
}
void *smem_gamma = smem_;
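// XOR swizzle: offsetting the column by the row index spreads the transposed
// reads below across shared-memory banks, avoiding bank conflicts.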
const int write_row = warp;
const int write_col = lane ^ write_row;
const int write_idx = write_row * THREADS_PER_WARP + write_col;
dgamma_local.store_to(smem_gamma, write_idx);
__syncthreads();
// It would probably be safe to reuse the first row of smem_gamma
void *smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
// More than one iter iff ROWS_PER_CTA < 32.
for (int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA) {
const int read_row = lane;
const int read_col = w ^ read_row;
const int read_idx = read_row * THREADS_PER_WARP + read_col;
memset(&dgamma_local, 0, sizeof(dgamma_local));
// Load gamma transposed
if (read_row < Kernel_traits::ROWS_PER_CTA) {
dgamma_local.load_from(smem_gamma, read_idx);
}
// Call reducer on the loaded value(s) and convert.
#pragma unroll
for (int it = 0; it < NUM_ELT; it++) {
compute_t g_i = dgamma_local.data.elt[it];
g_i = reducer.allreduce(g_i, sum);
dgamma_local.data.elt[it] = g_i;
}
// Leader stores the result at the current column.
if (lane == 0) {
dgamma_local.store_to(smem_gamma_out, w);
}
}
// All writes done.
__syncthreads();
// Pack and store: 2-wide stores with half the threads.
if (warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2) {
using src_t = typename TypeToVec2<compute_t>::Type;
using dst_t = typename TypeToVec2<weight_t>::Type;
Vec<src_t, NUM_ELT> dgamma_vec2;
Vec<dst_t, NUM_ELT> dgamma_out2;
dgamma_vec2.load_from(smem_gamma_out, lane);
#pragma unroll
for (int it = 0; it < NUM_ELT; it++) {
dgamma_out2.data.elt[it] =
Converter<src_t, dst_t>::convert(dgamma_vec2.data.elt[it]);
}
dgamma_out2.store_to(params.dgamma, col_out);
}
}
}
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_kernel(
BwdParams params) {
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
using input_t = typename Ktraits::input_t;
using weight_t = typename Ktraits::weight_t;
using compute_t = typename Ktraits::compute_t;
using output_t = typename Ktraits::output_t;
using index_t = typename Ktraits::index_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
const index_t tidx = threadIdx.x;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / WARPS_N;
const index_t warp_n = warp % WARPS_N;
const index_t bdimm = WARPS_M;
const index_t bdimn = WARPS_N * THREADS_PER_WARP;
const index_t bidm = blockIdx.x / params.ctas_per_row;
const index_t bidn = blockIdx.x % params.ctas_per_row;
const index_t gdimm = bdimm * params.ctas_per_col;
const index_t gdimn = bdimn * params.ctas_per_row;
const index_t gidm = bidm * bdimm + warp_m;
const index_t gidn =
(bidn * THREADS_PER_WARP + warp_n * params.ctas_per_row * THREADS_PER_WARP +
lane); // Order threads by warp x cta x lane
// Objects for weight grads
Cvec dzy_sum[LDGS];
memset(dzy_sum, 0, sizeof(dzy_sum));
// Objects for stats reductions
using reduce_t = typename Ktraits::Reducer::Type;
using Reducer = DynamicReducer<reduce_t, WARPS_M, WARPS_N>;
constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? Reducer::SMEM_BYTES : 1;
__shared__ char smem_[SMEM_BYTES];
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_);
Sum<reduce_t> sum;
const compute_t rn = 1.f / static_cast<compute_t>(params.cols);
// Load weights
Cvec gamma[LDGS];
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
Wvec gamma_in;
gamma_in.load_from_elts(params.gamma, col, params.cols - col);
gamma_in.to(gamma[it]);
}
for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) {
const int row = cta_row + warp_m;
compute_t rs = 0.f;
if (row < params.rows) {
rs = static_cast<const compute_t *>(params.rs)[row];
}
Cvec dy[LDGS];
Cvec y[LDGS];
compute_t mdy = 0.f;
compute_t mdyy = 0.f;
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
Ivec x;
Ovec dz;
x.load_from_elts(params.x, row * params.cols + col, params.cols - col);
dz.load_from_elts(params.dz, row * params.cols + col, params.cols - col);
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t x_ij = x.data.elt[jt];
compute_t y_ij = rs * (x_ij);
compute_t g_ij = gamma[it].data.elt[jt];
compute_t dz_ij = dz.data.elt[jt];
compute_t dy_ij = g_ij * dz_ij;
y[it].data.elt[jt] = y_ij;
dy[it].data.elt[jt] = dy_ij;
mdy += dy_ij;
mdyy += dy_ij * y_ij;
dzy_sum[it].data.elt[jt] += dz_ij * y_ij;
}
}
// Reduce over row
reduce_t result = reducer.allreduce({mdy, mdyy}, sum);
mdy = Get<0>::of<reduce_t, compute_t>(result) * rn;
mdyy = Get<1>::of<reduce_t, compute_t>(result) * rn;
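// Note: mdy is carried through the reduction alongside mdyy, but the RMSNorm
// dx formula below uses only mdyy; unlike LayerNorm there is no mean term.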
// Compute dx
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
Ivec dx;
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t dy_ij = dy[it].data.elt[jt];
compute_t y_ij = y[it].data.elt[jt];
dx.data.elt[jt] = rs * (dy_ij - (mdyy * y_ij));
}
dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col);
}
}
if constexpr (WARPS_M == 1) {
// Write out local weight grad contributions
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
dzy_sum[it].store_to_elts(params.dgamma_part, bidm * params.cols + col,
params.cols - col);
}
} else {
// Reduce weight grad contributions within CTA before writing
__shared__ Cvec vecs_shared[LDGS][WARPS_M][WARPS_N][THREADS_PER_WARP + 1];
// Reduce dzy
__syncthreads();
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
if (it != warp_m) {
dzy_sum[it].store_to(&vecs_shared[it][warp_m][warp_n][lane]);
}
}
__syncthreads();
#pragma unroll
for (int it = warp_m, col = (gidn + it * gdimn) * NUM_ELTS; it < LDGS && col < params.cols;
it += WARPS_M, col += WARPS_M * gdimn * NUM_ELTS) {
#pragma unroll
for (int kt = 0; kt < WARPS_M; kt++) {
if (kt != warp_m) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
dzy_sum[it].data.elt[jt] += vecs_shared[it][kt][warp_n][lane].data.elt[jt];
}
}
}
dzy_sum[it].store_to_elts(params.dgamma_part, bidm * params.cols + col,
params.cols - col);
}
}
}
template <typename weight_t, typename compute_t, uint32_t WARPS_M, uint32_t WARPS_N,
uint32_t BYTES_PER_LDG, uint32_t THREADS_PER_WARP>
__global__ __launch_bounds__(
WARPS_M *WARPS_N *THREADS_PER_WARP) void rmsnorm_bwd_finalize_general_kernel(BwdParams params) {
enum { NUM_ELTS = BYTES_PER_LDG / sizeof(compute_t) };
using Wvec = Vec<weight_t, NUM_ELTS>;
using Cvec = Vec<compute_t, NUM_ELTS>;
const int lane = threadIdx.x % THREADS_PER_WARP;
const int warp_m = threadIdx.y;
const int warp_n = threadIdx.x / THREADS_PER_WARP;
const int col = blockIdx.x * blockDim.x + threadIdx.x;
// Load grad contributions and accumulate locally
Cvec dgamma;
dgamma.clear();
for (int row = warp_m; row < params.ctas_per_col && col < params.cols; row += WARPS_M) {
Cvec dgamma_part;
dgamma_part.load_from_elts(params.dgamma_part, row * params.cols + col, params.cols - col);
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
dgamma.data.elt[jt] += dgamma_part.data.elt[jt];
}
}
// Reduce dgamma within CTA
__shared__ Cvec vecs_shared[WARPS_M][WARPS_N][THREADS_PER_WARP + 1];
dgamma.store_to(&vecs_shared[warp_m][warp_n][lane]);
#pragma unroll
for (int nrows = WARPS_M / 2; nrows > 0; nrows /= 2) {
__syncthreads();
if (warp_m < nrows) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
vecs_shared[warp_m][warp_n][lane].data.elt[jt] +=
vecs_shared[warp_m + nrows][warp_n][lane].data.elt[jt];
}
}
}
if (warp_m == 0 && col < params.cols) {
Wvec dgamma_out;
vecs_shared[warp_m][warp_n][lane].to(dgamma_out);
dgamma_out.store_to_elts(params.dgamma, col, params.cols - col);
}
}
} // namespace rmsnorm
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_BWD_KERNELS_CUH_
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "rmsnorm.h"
#include "rmsnorm_bwd_kernels.cuh"
#include "rmsnorm_kernel_traits.h"
using namespace transformer_engine::rmsnorm;
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
int BYTES_PER_LDG_MAIN, int BYTES_PER_LDG_FINAL>
void launch_tuned_(LaunchParams<BwdParams> &launch_params, const bool configure_params) { // NOLINT(*)
using Kernel_traits =
rmsnorm::Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
auto kernel = &rmsnorm_bwd_tuned_kernel<Kernel_traits>;
if (configure_params) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES);
launch_params.params.ctas_per_row = CTAS_PER_ROW;
launch_params.params.ctas_per_col =
launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
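// Multi-CTA rows need an inter-CTA barrier (two int counters per cooperating
// CTA group) plus workspace to stage each CTA's partial reduction values.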
if (Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col *
Kernel_traits::WARPS_M * Kernel_traits::CTAS_PER_ROW *
sizeof(typename Kernel_traits::reduce_t) * 2;
}
return;
}
if (Kernel_traits::SMEM_BYTES >= 48 * 1024) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
auto ctas_per_row = launch_params.params.ctas_per_row;
if (ctas_per_row == 1) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(
launch_params.params);
} else {
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), Kernel_traits::SMEM_BYTES,
stream);
}
using Kernel_traits_f =
Kernel_traits_finalize<HIDDEN_SIZE, weight_t, input_t, output_t, compute_t, index_t,
32 * 32, // THREADS_PER_CTA
BYTES_PER_LDG_FINAL>;
auto kernel_f = &rmsnorm::rmsnorm_bwd_finalize_tuned_kernel<Kernel_traits_f>;
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(
launch_params.params);
}
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG_MAIN,
int BYTES_PER_LDG_FINAL>
void launch_general_(LaunchParams<BwdParams> &launch_params, const bool configure_params) { // NOLINT(*)
auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };
// Instantiate kernel
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t,
HIDDEN_SIZE, 1, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
auto kernel = &rmsnorm_bwd_general_kernel<Kernel_traits>;
// Configure kernel params
const int rows = launch_params.params.rows;
const int cols = launch_params.params.cols;
int ctas_per_col = launch_params.params.ctas_per_col;
int ctas_per_row = launch_params.params.ctas_per_row;
if (configure_params) {
int ctas_per_sm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel,
Kernel_traits::THREADS_PER_CTA, 0);
const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm;
ctas_per_row = ceil_div(cols, HIDDEN_SIZE);
ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row);
launch_params.params.ctas_per_row = ctas_per_row;
launch_params.params.ctas_per_col = ctas_per_col;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if (launch_params.params.ctas_per_row > 1) {
launch_params.barrier_size = 2 * ctas_per_col;
launch_params.workspace_bytes = (ctas_per_col * WARPS_M * ctas_per_row *
sizeof(typename Kernel_traits::reduce_t) * 2);
}
return;
}
// Launch kernel
auto stream = launch_params.stream;
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
if (ctas_per_row == 1) {
kernel<<<grid, block, 0, stream>>>(launch_params.params);
} else {
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), 0, stream);
}
// Launch finalization kernel
constexpr uint32_t WARPS_M_FINAL = 4;
constexpr uint32_t WARPS_N_FINAL = 1;
constexpr uint32_t ELTS_N_PER_CTA_FINAL =
(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL / sizeof(compute_t));
auto kernel_final =
&rmsnorm_bwd_finalize_general_kernel<weight_t, compute_t, WARPS_M_FINAL, WARPS_N_FINAL,
BYTES_PER_LDG_FINAL, Kernel_traits::THREADS_PER_WARP>;
dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL);
dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1);
kernel_final<<<grid_final, block_final, 0, stream>>>(launch_params.params);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_BWD_TUNED_LAUNCHER( \
HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE) \
void rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<BwdParams> \
&launch_params, \
const bool configure_params) { \
launch_tuned_<WTYPE, \
ITYPE, \
OTYPE, \
CTYPE, \
uint32_t, \
HIDDEN_SIZE, \
CTAS_PER_ROW, \
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
} \
static BwdTunedRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
#define REGISTER_BWD_GENERAL_LAUNCHER( \
HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE) \
void rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<BwdParams> \
&launch_params, \
const bool configure_params) { \
launch_general_<WTYPE, \
ITYPE, \
OTYPE, \
CTYPE, \
uint32_t, \
HIDDEN_SIZE, \
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
} \
static BwdGeneralRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
// Create rmsnorm tuned launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ...
// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_TUNED_LAUNCHER(512, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(512, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(512, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
// Create rmsnorm general launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ...
// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_GENERAL_LAUNCHER(128, fp32, fp32, fp32, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(128, fp16, fp16, fp16, fp32, 4, 1, 8, 4);
REGISTER_BWD_GENERAL_LAUNCHER(128, fp16, fp32, fp16, fp32, 4, 1, 8, 4);
REGISTER_BWD_GENERAL_LAUNCHER(128, bf16, bf16, bf16, fp32, 4, 1, 8, 4);
REGISTER_BWD_GENERAL_LAUNCHER(128, bf16, fp32, bf16, fp32, 4, 1, 8, 4);
REGISTER_BWD_GENERAL_LAUNCHER(512, fp32, fp32, fp32, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(512, fp16, fp16, fp16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(512, fp16, fp32, fp16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(512, bf16, bf16, bf16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(512, bf16, fp32, bf16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4);
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "rmsnorm.h"
#include "rmsnorm_fwd_kernels.cuh"
#include "rmsnorm_kernel_traits.h"
using namespace transformer_engine::rmsnorm;
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
int BYTES_PER_LDG>
void launch_tuned_(LaunchParams<FwdParams> &launch_params, const bool configure_params) { // NOLINT(*)
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t,
HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>;
auto kernel = &rmsnorm_fwd_tuned_kernel<Kernel_traits>;
if (configure_params) {
        int ctas_per_sm;
        NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
launch_params.params.ctas_per_row = CTAS_PER_ROW;
launch_params.params.ctas_per_col =
launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if (Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col *
Kernel_traits::WARPS_M * Kernel_traits::CTAS_PER_ROW *
sizeof(typename Kernel_traits::Stats::stats_t) * 2;
}
return;
}
if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES_FWD));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
auto ctas_per_row = launch_params.params.ctas_per_row;
if (ctas_per_row == 1) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD,
stream>>>(launch_params.params);
} else {
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = reinterpret_cast<void *>(&launch_params.params);
        NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
                                                    reinterpret_cast<void **>(&params_),
                                                    Kernel_traits::SMEM_BYTES_FWD, stream));
}
}
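// General launcher: the traits fix CTAS_PER_ROW to 1; the actual CTA tiling is
// derived at configure time from the runtime rows/cols (ceil_div below), which
// lets one instantiation cover hidden sizes that have no tuned kernel.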
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG>
void launch_general_(LaunchParams<FwdParams> &launch_params, const bool configure_params) { // NOLINT(*)
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t,
HIDDEN_SIZE, 1, WARPS_M, WARPS_N, BYTES_PER_LDG>;
auto kernel = &rmsnorm_fwd_general_kernel<Kernel_traits>;
auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };
// Configure kernel params
const int rows = launch_params.params.rows;
const int cols = launch_params.params.cols;
int ctas_per_col = launch_params.params.ctas_per_col;
int ctas_per_row = launch_params.params.ctas_per_row;
if (configure_params) {
        int ctas_per_sm;
        NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0));
const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm;
ctas_per_row = ceil_div(cols, HIDDEN_SIZE);
ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row);
launch_params.params.ctas_per_row = ctas_per_row;
launch_params.params.ctas_per_col = ctas_per_col;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if (launch_params.params.ctas_per_row > 1) {
launch_params.barrier_size = 2 * ctas_per_col;
launch_params.workspace_bytes =
(ctas_per_col * WARPS_M * ctas_per_row * sizeof(compute_t) * 2);
}
return;
}
// Launch kernel
auto stream = launch_params.stream;
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
if (ctas_per_row == 1) {
kernel<<<grid, block, 0, stream>>>(launch_params.params);
} else {
void *params_ = reinterpret_cast<void *>(&launch_params.params);
        NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
                                                    reinterpret_cast<void **>(&params_), 0, stream));
}
}
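// Illustrative sketch (not part of this diff): the launchers above are meant
// to be driven in two passes -- once with configure_params=true to size the
// barrier/workspace buffers, then again to launch. Assuming the caller owns
// allocation and has looked up a registered launcher:
//
//     void run_fwd(void (*launcher)(LaunchParams<FwdParams> &, bool),
//                  LaunchParams<FwdParams> &lp) {
//         launcher(lp, /*configure_params=*/true);   // fills sizes, no launch
//         // ... allocate lp barrier/workspace buffers of the reported sizes ...
//         launcher(lp, /*configure_params=*/false);  // actual kernel launch
//     }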
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \
CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \
void rmsnorm_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<FwdParams> &launch_params, \
const bool configure_params) { \
launch_tuned_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, \
WARPS_M, WARPS_N, BYTES_PER_LDG>( \
launch_params, configure_params); \
} \
static FwdTunedRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
rmsnorm_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
#define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \
WARPS_M, WARPS_N, BYTES_PER_LDG) \
void rmsnorm_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<FwdParams> &launch_params, \
const bool configure_params) { \
launch_general_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, \
WARPS_M, WARPS_N, BYTES_PER_LDG>( \
launch_params, configure_params); \
} \
static FwdGeneralRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
rmsnorm_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
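// For reference, REGISTER_FWD_TUNED_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16)
// expands (modulo whitespace) to a function
// rmsnorm_fwd_tuned_1024_fp16_fp16_fp16_fp32() that forwards to
// launch_tuned_<fp16, fp16, fp16, fp32, uint32_t, 1024, 1, 4, 1, 16>, plus a
// static FwdTunedRegistrar whose constructor records that function in the
// launcher registry during static initialization.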
////////////////////////////////////////////////////////////////////////////////////////////////////
// Create rmsnorm tuned launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_TUNED_LAUNCHER(512, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(512, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(512, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(512, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(512, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(512, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(768, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(768, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1024, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1024, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(4096, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// Create rmsnorm general launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_GENERAL_LAUNCHER(128, bf16, bf16, fp8e4m3, fp32, 4, 1, 8);
REGISTER_FWD_GENERAL_LAUNCHER(512, bf16, bf16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(1024, bf16, bf16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(2048, bf16, bf16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(8192, bf16, bf16, fp8e4m3, fp32, 1, 4, 16);
REGISTER_FWD_GENERAL_LAUNCHER(128, fp16, fp16, fp8e4m3, fp32, 4, 1, 8);
REGISTER_FWD_GENERAL_LAUNCHER(512, fp16, fp16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(8192, fp16, fp16, fp8e4m3, fp32, 1, 4, 16);
REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp8e4m3, fp32, 1, 4, 16);
REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp32, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(128, fp16, fp16, fp16, fp32, 4, 1, 8);
REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp16, fp32, 4, 1, 8);
REGISTER_FWD_GENERAL_LAUNCHER(128, bf16, bf16, bf16, fp32, 4, 1, 8);
REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, bf16, fp32, 4, 1, 8);
REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp32, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(512, fp16, fp16, fp16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(512, bf16, bf16, bf16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, bf16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp32, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(1024, bf16, bf16, bf16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, bf16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp32, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(2048, bf16, bf16, bf16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, bf16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 4, 16);
REGISTER_FWD_GENERAL_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 4, 16);
REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp16, fp32, 1, 4, 16);
REGISTER_FWD_GENERAL_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 4, 16);
REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, bf16, fp32, 1, 4, 16);
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_FWD_KERNELS_CUH_
#define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_FWD_KERNELS_CUH_
#include <cfloat>
#include <cstdio>
#include "../utils.cuh"
namespace transformer_engine {
namespace rmsnorm {
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_kernel(
FwdParams params) {
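    // Computes y = gamma * x * rsqrt(mean(x^2) + epsilon) one row tile at a
    // time, optionally scaling the output and tracking amax for FP8.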
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_N = Ktraits::WARPS_N };
enum { WARPS_M = Ktraits::WARPS_M };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::NUM_ELTS };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
using output_t = typename Ktraits::output_t;
using index_t = typename Ktraits::index_t;
using compute_t = typename Ktraits::compute_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Stats = typename Ktraits::Stats;
using stats_t = typename Stats::stats_t;
extern __shared__ char smem_[];
const index_t tidx = threadIdx.x;
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / WARPS_N;
const index_t warp_n = warp % WARPS_N;
const index_t r = bidm * ROWS_PER_CTA + warp_m;
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
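    // Thread mapping: each warp row (warp_m) owns matrix row r; lanes across
    // the WARPS_N warps and CTAS_PER_ROW CTAs tile that row's columns from c.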
Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);
compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
Wvec gamma[LDGS];
index_t idx = c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
gamma[it].load_from(params.gamma, idx);
idx += VEC_COLS_PER_LDG;
}
constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);
compute_t scale;
if (params.fp8_out) {
scale = *reinterpret_cast<compute_t *>(params.scale);
}
compute_t amax = 0;
for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {
Ivec x[LDGS];
index_t idx = row * Ktraits::VEC_COLS + c;
compute_t xf[LDGS * NUM_ELTS];
#pragma unroll
for (int it = 0; it < LDGS; it++) {
x[it].load_from(params.x, idx);
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t x_ij = compute_t(x[it].data.elt[jt]);
xf[it * NUM_ELTS + jt] = x_ij;
}
idx += VEC_COLS_PER_LDG;
}
stats_t s = stats.compute(xf, rn);
compute_t mu = Get<0>::of<stats_t, compute_t>(s);
compute_t m2 = Get<1>::of<stats_t, compute_t>(s);
        // Reciprocal of the root mean square: m2 is the sum of squared deviations
        // from the mean, so rn * m2 + mu * mu recovers the mean square E[x^2].
        // (Accumulating the sum of squares directly would avoid computing the mean.)
compute_t rs = rsqrtf(rn * m2 + mu * mu + params.epsilon);
if (bidn == 0 && warp_n == 0 && lane == 0) {
rs_ptr[row] = rs;
}
Ovec z[LDGS];
idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t y_ij = rs * (xf[it * NUM_ELTS + jt]);
compute_t g_ij = gamma[it].data.elt[jt];
compute_t temp_output = g_ij * y_ij;
if (params.fp8_out) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(temp_output));
temp_output = temp_output * scale;
}
z[it].data.elt[jt] = output_t(temp_output);
}
z[it].store_to(params.z, idx);
idx += VEC_COLS_PER_LDG;
}
}
if (params.fp8_out) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
        if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
}
}
}
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_kernel(
FwdParams params) {
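    // Same computation as the tuned kernel, but the column count is a runtime
    // value: loads are strided across CTAs and tail elements are masked.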
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::NUM_ELTS };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
using input_t = typename Ktraits::input_t;
using weight_t = typename Ktraits::weight_t;
using output_t = typename Ktraits::output_t;
using index_t = typename Ktraits::index_t;
using compute_t = typename Ktraits::compute_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
const index_t tidx = threadIdx.x;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / WARPS_N;
const index_t warp_n = warp % WARPS_N;
const index_t bdimm = WARPS_M;
const index_t bdimn = WARPS_N * THREADS_PER_WARP;
const index_t bidm = blockIdx.x / params.ctas_per_row;
const index_t bidn = blockIdx.x % params.ctas_per_row;
const index_t gdimm = bdimm * params.ctas_per_col;
const index_t gdimn = bdimn * params.ctas_per_row;
const index_t gidm = bidm * bdimm + warp_m;
const index_t gidn =
(bidn * THREADS_PER_WARP + warp_n * params.ctas_per_row * THREADS_PER_WARP +
lane); // Order threads by warp x cta x lane
// Objects for stats reductions
using Reducer = DynamicReducer<compute_t, WARPS_M, WARPS_N>;
constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? Reducer::SMEM_BYTES : 1;
__shared__ char smem_[SMEM_BYTES];
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_);
Sum<compute_t> sum;
const compute_t rn = 1.f / static_cast<compute_t>(params.cols);
// Load weights
Cvec gamma[LDGS];
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
Wvec gamma_in;
gamma_in.load_from_elts(params.gamma, col, params.cols - col);
gamma_in.to(gamma[it]);
}
// fp8 factors
compute_t scale;
if (params.fp8_out) {
scale = *reinterpret_cast<compute_t *>(params.scale);
}
compute_t amax = 0;
for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) {
const int row = cta_row + warp_m;
// Load input
Cvec x[LDGS];
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
Ivec x_in;
x_in.load_from_elts(params.x, row * params.cols + col, params.cols - col);
x_in.to(x[it]);
}
        // Compute the mean of squares (RMSNorm does not subtract the mean)
compute_t sqsigma = 0.f;
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
if (col + jt < params.cols) {
                    compute_t x_ij = x[it].data.elt[jt];
                    sqsigma += x_ij * x_ij;
}
}
}
sqsigma = reducer.allreduce(sqsigma, sum) * rn;
compute_t rs = rsqrtf(sqsigma + params.epsilon);
// Write statistics
if (gidn == 0 && row < params.rows) {
compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
rs_ptr[row] = rs;
}
// Compute output
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
// Compute output values
Cvec z;
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t y_ij = rs * (x[it].data.elt[jt]);
compute_t g_ij = gamma[it].data.elt[jt];
z.data.elt[jt] = g_ij * y_ij;
}
// Apply fp8 factors
if (params.fp8_out) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
if (col + jt < params.cols) {
compute_t z_ij = z.data.elt[jt];
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(z_ij));
z.data.elt[jt] = z_ij * scale;
}
}
}
// Store output
Ovec z_out;
z.to(z_out);
z_out.store_to_elts(params.z, row * params.cols + col, params.cols - col);
}
}
// Finalize fp8 factors
if (params.fp8_out) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
}
}
}
} // namespace rmsnorm
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_FWD_KERNELS_CUH_
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_KERNEL_TRAITS_H_
#define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_KERNEL_TRAITS_H_
#include "../common.h"
#include "../layer_norm/ln_kernel_traits.h"
#include "../utils.cuh"
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace transformer_engine {
namespace rmsnorm {
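// RMSNorm reuses the LayerNorm kernel traits unchanged; these aliases exist so
// the rmsnorm kernels can refer to them from within this namespace.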
template <
uint32_t HIDDEN_SIZE_, typename weight_t_, typename input_t_, typename output_t_,
typename compute_t_, typename index_t_, uint32_t THREADS_PER_CTA_, uint32_t BYTES_PER_LDG_,
typename Base =
layer_norm::Kernel_traits_finalize<HIDDEN_SIZE_, weight_t_, input_t_, output_t_, compute_t_,
index_t_, THREADS_PER_CTA_, BYTES_PER_LDG_> >
struct Kernel_traits_finalize : public Base {};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename weight_t_, typename input_t_, typename output_t_, typename compute_t_,
typename index_t_, uint32_t HIDDEN_SIZE_, uint32_t CTAS_PER_ROW_, uint32_t WARPS_M_,
uint32_t WARPS_N_, uint32_t BYTES_PER_LDG_ = 16,
typename Base = layer_norm::Kernel_traits<weight_t_, input_t_, output_t_, compute_t_,
index_t_, HIDDEN_SIZE_, CTAS_PER_ROW_, WARPS_M_,
WARPS_N_, BYTES_PER_LDG_> >
struct Kernel_traits : public Base {};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace rmsnorm
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_KERNEL_TRAITS_H_
...@@ -19,83 +19,6 @@ constexpr uint32_t THREADS_PER_WARP = 32;
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \
CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \
void ln_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<FwdParams> &launch_params, \
const bool configure_params) { \
launch_tuned_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, \
WARPS_M, WARPS_N, BYTES_PER_LDG>( \
launch_params, configure_params); \
} \
static FwdTunedRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
ln_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
#define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \
WARPS_M, WARPS_N, BYTES_PER_LDG) \
void ln_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<FwdParams> &launch_params, \
const bool configure_params) { \
launch_general_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, \
WARPS_M, WARPS_N, BYTES_PER_LDG>( \
launch_params, configure_params); \
} \
static FwdGeneralRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
ln_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
// NOLINTBEGIN
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_BWD_TUNED_LAUNCHER( \
HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE) \
void ln_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<BwdParams> \
&launch_params, \
const bool configure_params) { \
launch_tuned_<WTYPE, \
ITYPE, \
OTYPE, \
CTYPE, \
uint32_t, \
HIDDEN_SIZE, \
CTAS_PER_ROW, \
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
} \
static BwdTunedRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
ln_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
#define REGISTER_BWD_GENERAL_LAUNCHER( \
HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE) \
void ln_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<BwdParams> \
&launch_params, \
const bool configure_params) { \
launch_general_<WTYPE, \
ITYPE, \
OTYPE, \
CTYPE, \
uint32_t, \
HIDDEN_SIZE, \
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
} \
static BwdGeneralRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
ln_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
// NOLINTEND
inline __device__ float2 operator+(const float2 & a, const float2 & b) { // NOLINT(*)
    return {a.x + b.x, a.y + b.y};
}
...