Unverified Commit 3102fdd1 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[C] Normalization Refactor + Adding CUDNN backend (#1315)



* cuDNN normalization integration
* TE Norm refactor
* TE Norm APIs changes.

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent d8b13cb0
......@@ -4,20 +4,19 @@
* See LICENSE for license information.
************************************************************************/
#include "rmsnorm.h"
#include "../common.h"
#include "../kernel_traits.h"
#include "rmsnorm_bwd_kernels.cuh"
#include "rmsnorm_kernel_traits.h"
using namespace transformer_engine::rmsnorm;
using namespace transformer_engine::normalization;
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
int BYTES_PER_LDG_MAIN, int BYTES_PER_LDG_FINAL>
void launch_tuned_(LaunchParams<BwdParams> &launch_params,
void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
const bool configure_params) { // NOLINT(*)
using Kernel_traits =
rmsnorm::Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
auto kernel = &rmsnorm_bwd_tuned_kernel<Kernel_traits>;
if (configure_params) {
......@@ -27,14 +26,14 @@ void launch_tuned_(LaunchParams<BwdParams> &launch_params,
launch_params.params.ctas_per_row = CTAS_PER_ROW;
launch_params.params.ctas_per_col =
launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if (Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.barrier_bytes = 2 * launch_params.params.ctas_per_col * sizeof(index_t);
launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M *
Kernel_traits::CTAS_PER_ROW *
sizeof(typename Kernel_traits::reduce_t) * 2;
}
launch_params.dgamma_part_bytes =
launch_params.params.ctas_per_col * launch_params.params.cols * sizeof(compute_t);
return;
}
......@@ -63,7 +62,7 @@ void launch_tuned_(LaunchParams<BwdParams> &launch_params,
32 * 32, // THREADS_PER_CTA
BYTES_PER_LDG_FINAL>;
auto kernel_f = &rmsnorm::rmsnorm_bwd_finalize_tuned_kernel<Kernel_traits_f>;
auto kernel_f = &rmsnorm_bwd_finalize_tuned_kernel<Kernel_traits_f>;
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(
launch_params.params);
}
......@@ -71,7 +70,7 @@ void launch_tuned_(LaunchParams<BwdParams> &launch_params,
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG_MAIN,
int BYTES_PER_LDG_FINAL>
void launch_general_(LaunchParams<BwdParams> &launch_params,
void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
const bool configure_params) { // NOLINT(*)
auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };
......@@ -95,13 +94,13 @@ void launch_general_(LaunchParams<BwdParams> &launch_params,
launch_params.params.ctas_per_row = ctas_per_row;
launch_params.params.ctas_per_col = ctas_per_col;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if (launch_params.params.ctas_per_row > 1) {
launch_params.barrier_size = 2 * ctas_per_col;
launch_params.barrier_bytes = 2 * ctas_per_col * sizeof(index_t);
launch_params.workspace_bytes =
(ctas_per_col * WARPS_M * ctas_per_row * sizeof(typename Kernel_traits::reduce_t) * 2);
}
launch_params.dgamma_part_bytes =
launch_params.params.ctas_per_col * launch_params.params.cols * sizeof(compute_t);
return;
}
......@@ -130,91 +129,78 @@ void launch_general_(LaunchParams<BwdParams> &launch_params,
kernel_final<<<grid_final, block_final, 0, stream>>>(launch_params.params);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_BWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \
WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
void rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<BwdParams> &launch_params, const bool configure_params) { \
launch_tuned_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, \
WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE>(launch_params, \
configure_params); \
} \
static BwdTunedRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
#define REGISTER_BWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \
BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
void rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<BwdParams> &launch_params, const bool configure_params) { \
launch_general_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, WARPS_M, WARPS_N, \
BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
} \
static BwdGeneralRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \
OTYPE, CTYPE, ...) \
namespace { \
void \
norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<NORM_STAGE##KernelParams> &launch_params, const bool configure_params) { \
launch_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, __VA_ARGS__>( \
launch_params, configure_params); \
} \
REGISTER_NORM_BASE( \
NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \
norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE); \
} // namespace
// Create rmsnorm tuned launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ...
// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_TUNED_LAUNCHER(512, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(512, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(512, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
// Create rmsnorm general launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ...
// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_GENERAL_LAUNCHER(128, fp32, fp32, fp32, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(128, fp16, fp16, fp16, fp32, 4, 1, 8, 4);
REGISTER_BWD_GENERAL_LAUNCHER(128, fp16, fp32, fp16, fp32, 4, 1, 8, 4);
REGISTER_BWD_GENERAL_LAUNCHER(128, bf16, bf16, bf16, fp32, 4, 1, 8, 4);
REGISTER_BWD_GENERAL_LAUNCHER(128, bf16, fp32, bf16, fp32, 4, 1, 8, 4);
REGISTER_BWD_GENERAL_LAUNCHER(512, fp32, fp32, fp32, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(512, fp16, fp16, fp16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(512, fp16, fp32, fp16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(512, bf16, bf16, bf16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(512, bf16, fp32, bf16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 128, fp32, fp32, fp32, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 128, fp16, fp16, fp16, fp32, 4, 1, 8, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 128, fp16, fp32, fp16, fp32, 4, 1, 8, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 128, bf16, bf16, bf16, fp32, 4, 1, 8, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 128, bf16, fp32, bf16, fp32, 4, 1, 8, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 512, fp32, fp32, fp32, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 512, fp16, fp16, fp16, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 512, fp16, fp32, fp16, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 512, bf16, bf16, bf16, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 512, bf16, fp32, bf16, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, fp32, fp32, fp32, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4);
......@@ -10,15 +10,15 @@
#include <cfloat>
#include <cstdio>
#include "../utils.cuh"
#include "../../utils.cuh"
#include "../common.h"
namespace transformer_engine {
namespace rmsnorm {
using namespace transformer_engine;
namespace normalization {
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_kernel(
FwdParams params) {
ForwardKernelParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_N = Ktraits::WARPS_N };
enum { WARPS_M = Ktraits::WARPS_M };
......@@ -143,7 +143,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_kernel(
FwdParams params) {
ForwardKernelParams params) {
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::NUM_ELTS };
enum { WARPS_M = Ktraits::WARPS_M };
......@@ -291,7 +291,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
}
}
} // namespace rmsnorm
} // namespace normalization
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_FWD_KERNELS_CUH_
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_H_
#define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_H_
#include <transformer_engine/transformer_engine.h>
#include <functional>
#include <map>
#include <stdexcept>
#include <unordered_map>
#include <vector>
#include "../common.h"
#include "../layer_norm/ln.h"
namespace transformer_engine {
namespace rmsnorm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Params>
struct LaunchParams : public transformer_engine::layer_norm::LaunchParams<Params> {};
struct FwdParams : public transformer_engine::layer_norm::FwdParams {};
struct BwdParams : public transformer_engine::layer_norm::BwdParams {};
////////////////////////////////////////////////////////////////////////////////////////////////////
using FwdFunction = std::function<void(LaunchParams<FwdParams> &, const bool)>;
using BwdFunction = std::function<void(LaunchParams<BwdParams> &, const bool)>;
using FunctionKey = uint64_t;
using FwdTunedRegistry = std::unordered_map<FunctionKey, FwdFunction>;
using BwdTunedRegistry = std::unordered_map<FunctionKey, BwdFunction>;
using FwdGeneralRegistry = std::unordered_map<FunctionKey, std::map<uint64_t, FwdFunction>>;
using BwdGeneralRegistry = std::unordered_map<FunctionKey, std::map<uint64_t, BwdFunction>>;
extern FwdTunedRegistry FWD_TUNED_FUNCS;
extern BwdTunedRegistry BWD_TUNED_FUNCS;
extern FwdGeneralRegistry FWD_GENERAL_FUNCS;
extern BwdGeneralRegistry BWD_GENERAL_FUNCS;
//////////////////////////////////////////////////////////////////////////////////////////////////
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdTunedRegistrar {
explicit FwdTunedRegistrar(FwdFunction f) {
uint64_t key = layer_norm::Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
FWD_TUNED_FUNCS.insert({key, f});
}
};
//////////////////////////////////////////////////////////////////////////////////////////////////
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdGeneralRegistrar {
explicit FwdGeneralRegistrar(FwdFunction f) {
uint64_t key = layer_norm::Types2Key<W, I, O, C>::get(0);
FWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f});
}
};
//////////////////////////////////////////////////////////////////////////////////////////////////
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdTunedRegistrar {
explicit BwdTunedRegistrar(BwdFunction f) {
uint64_t key = layer_norm::Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
BWD_TUNED_FUNCS.insert({key, f});
}
};
//////////////////////////////////////////////////////////////////////////////////////////////////
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdGeneralRegistrar {
explicit BwdGeneralRegistrar(BwdFunction f) {
uint64_t key = layer_norm::Types2Key<W, I, O, C>::get(0);
BWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f});
}
};
} // namespace rmsnorm
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_H_
This diff is collapsed.
......@@ -81,25 +81,18 @@ struct CustomCallNormDescriptor {
size_t batch_size;
size_t hidden_size;
size_t wkspace_size;
size_t barrier_size;
Shape dgamma_part_shape;
Shape dbeta_part_shape;
DType x_dtype;
DType w_dtype;
DType wkspace_dtype;
DType barrier_dtype;
DType dgamma_part_dtype;
DType dbeta_part_dtype;
bool zero_centered_gamma;
float eps;
int sm_margin;
};
pybind11::bytes PackCustomCallNormDescriptor(
size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size,
const std::vector<size_t> &dgamma_part_shape, const std::vector<size_t> &dbeta_part_shape,
DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype,
DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin);
pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
size_t wkspace_size, DType x_dtype, DType w_dtype,
DType wkspace_dtype, bool zero_centered_gamma,
float eps, int sm_margin);
struct SoftmaxDescriptor {
size_t batch_size;
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment