Unverified commit 3102fdd1, authored by Phuong Nguyen, committed by GitHub
Browse files

[C] Normalization Refactor + Adding CUDNN backend (#1315)



* cuDNN normalization integration
* TE Norm refactor
* TE Norm APIs changes.

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent d8b13cb0
......@@ -4,19 +4,18 @@
* See LICENSE for license information.
************************************************************************/
#include "rmsnorm.h"
#include "../common.h"
#include "../kernel_traits.h"
#include "rmsnorm_bwd_kernels.cuh"
#include "rmsnorm_kernel_traits.h"
using namespace transformer_engine::rmsnorm;
using namespace transformer_engine::normalization;
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
int BYTES_PER_LDG_MAIN, int BYTES_PER_LDG_FINAL>
void launch_tuned_(LaunchParams<BwdParams> &launch_params,
void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
const bool configure_params) { // NOLINT(*)
using Kernel_traits =
rmsnorm::Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
auto kernel = &rmsnorm_bwd_tuned_kernel<Kernel_traits>;
......@@ -27,14 +26,14 @@ void launch_tuned_(LaunchParams<BwdParams> &launch_params,
launch_params.params.ctas_per_row = CTAS_PER_ROW;
launch_params.params.ctas_per_col =
launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if (Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.barrier_bytes = 2 * launch_params.params.ctas_per_col * sizeof(index_t);
launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M *
Kernel_traits::CTAS_PER_ROW *
sizeof(typename Kernel_traits::reduce_t) * 2;
}
launch_params.dgamma_part_bytes =
launch_params.params.ctas_per_col * launch_params.params.cols * sizeof(compute_t);
return;
}
......@@ -63,7 +62,7 @@ void launch_tuned_(LaunchParams<BwdParams> &launch_params,
32 * 32, // THREADS_PER_CTA
BYTES_PER_LDG_FINAL>;
auto kernel_f = &rmsnorm::rmsnorm_bwd_finalize_tuned_kernel<Kernel_traits_f>;
auto kernel_f = &rmsnorm_bwd_finalize_tuned_kernel<Kernel_traits_f>;
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(
launch_params.params);
}
......@@ -71,7 +70,7 @@ void launch_tuned_(LaunchParams<BwdParams> &launch_params,
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG_MAIN,
int BYTES_PER_LDG_FINAL>
void launch_general_(LaunchParams<BwdParams> &launch_params,
void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
const bool configure_params) { // NOLINT(*)
auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };
......@@ -95,13 +94,13 @@ void launch_general_(LaunchParams<BwdParams> &launch_params,
launch_params.params.ctas_per_row = ctas_per_row;
launch_params.params.ctas_per_col = ctas_per_col;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if (launch_params.params.ctas_per_row > 1) {
launch_params.barrier_size = 2 * ctas_per_col;
launch_params.barrier_bytes = 2 * ctas_per_col * sizeof(index_t);
launch_params.workspace_bytes =
(ctas_per_col * WARPS_M * ctas_per_row * sizeof(typename Kernel_traits::reduce_t) * 2);
}
launch_params.dgamma_part_bytes =
launch_params.params.ctas_per_col * launch_params.params.cols * sizeof(compute_t);
return;
}
......@@ -130,91 +129,78 @@ void launch_general_(LaunchParams<BwdParams> &launch_params,
kernel_final<<<grid_final, block_final, 0, stream>>>(launch_params.params);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_BWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \
WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
void rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<BwdParams> &launch_params, const bool configure_params) { \
launch_tuned_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, \
WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE>(launch_params, \
configure_params); \
#define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \
OTYPE, CTYPE, ...) \
namespace { \
void \
norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<NORM_STAGE##KernelParams> &launch_params, const bool configure_params) { \
launch_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, __VA_ARGS__>( \
launch_params, configure_params); \
} \
static BwdTunedRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
#define REGISTER_BWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \
BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
void rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<BwdParams> &launch_params, const bool configure_params) { \
launch_general_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, WARPS_M, WARPS_N, \
BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
} \
static BwdGeneralRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
REGISTER_NORM_BASE( \
NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \
norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE); \
} // namespace
// Create rmsnorm tuned launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ...
// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_TUNED_LAUNCHER(512, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(512, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(512, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
// Create rmsnorm general launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ...
// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_GENERAL_LAUNCHER(128, fp32, fp32, fp32, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(128, fp16, fp16, fp16, fp32, 4, 1, 8, 4);
REGISTER_BWD_GENERAL_LAUNCHER(128, fp16, fp32, fp16, fp32, 4, 1, 8, 4);
REGISTER_BWD_GENERAL_LAUNCHER(128, bf16, bf16, bf16, fp32, 4, 1, 8, 4);
REGISTER_BWD_GENERAL_LAUNCHER(128, bf16, fp32, bf16, fp32, 4, 1, 8, 4);
REGISTER_BWD_GENERAL_LAUNCHER(512, fp32, fp32, fp32, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(512, fp16, fp16, fp16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(512, fp16, fp32, fp16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(512, bf16, bf16, bf16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(512, bf16, fp32, bf16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 128, fp32, fp32, fp32, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 128, fp16, fp16, fp16, fp32, 4, 1, 8, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 128, fp16, fp32, fp16, fp32, 4, 1, 8, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 128, bf16, bf16, bf16, fp32, 4, 1, 8, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 128, bf16, fp32, bf16, fp32, 4, 1, 8, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 512, fp32, fp32, fp32, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 512, fp16, fp16, fp16, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 512, fp16, fp32, fp16, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 512, bf16, bf16, bf16, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 512, bf16, fp32, bf16, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, fp32, fp32, fp32, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4);
......@@ -4,16 +4,16 @@
* See LICENSE for license information.
************************************************************************/
#include "rmsnorm.h"
#include "../common.h"
#include "../kernel_traits.h"
#include "rmsnorm_fwd_kernels.cuh"
#include "rmsnorm_kernel_traits.h"
using namespace transformer_engine::rmsnorm;
using namespace transformer_engine::normalization;
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
int BYTES_PER_LDG>
void launch_tuned_(LaunchParams<FwdParams> &launch_params,
void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
const bool configure_params) { // NOLINT(*)
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>;
......@@ -26,10 +26,8 @@ void launch_tuned_(LaunchParams<FwdParams> &launch_params,
launch_params.params.ctas_per_row = CTAS_PER_ROW;
launch_params.params.ctas_per_col =
launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if (Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.barrier_bytes = 2 * launch_params.params.ctas_per_col * sizeof(index_t);
launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M *
Kernel_traits::CTAS_PER_ROW *
sizeof(typename Kernel_traits::Stats::stats_t) * 2;
......@@ -59,7 +57,7 @@ void launch_tuned_(LaunchParams<FwdParams> &launch_params,
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG>
void launch_general_(LaunchParams<FwdParams> &launch_params,
void launch_general_(LaunchParams<ForwardKernelParams> &launch_params,
const bool configure_params) { // NOLINT(*)
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
1, WARPS_M, WARPS_N, BYTES_PER_LDG>;
......@@ -80,11 +78,8 @@ void launch_general_(LaunchParams<FwdParams> &launch_params,
ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row);
launch_params.params.ctas_per_row = ctas_per_row;
launch_params.params.ctas_per_col = ctas_per_col;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if (launch_params.params.ctas_per_row > 1) {
launch_params.barrier_size = 2 * ctas_per_col;
launch_params.barrier_bytes = 2 * ctas_per_col * sizeof(index_t);
launch_params.workspace_bytes =
(ctas_per_col * WARPS_M * ctas_per_row * sizeof(compute_t) * 2);
}
......@@ -104,124 +99,112 @@ void launch_general_(LaunchParams<FwdParams> &launch_params,
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \
WARPS_M, WARPS_N, BYTES_PER_LDG) \
void rmsnorm_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<FwdParams> &launch_params, const bool configure_params) { \
launch_tuned_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, \
WARPS_N, BYTES_PER_LDG>(launch_params, configure_params); \
#define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \
OTYPE, CTYPE, ...) \
namespace { \
void \
norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<NORM_STAGE##KernelParams> &launch_params, const bool configure_params) { \
launch_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, __VA_ARGS__>( \
launch_params, configure_params); \
} \
static FwdTunedRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
rmsnorm_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
#define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \
BYTES_PER_LDG) \
void rmsnorm_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<FwdParams> &launch_params, const bool configure_params) { \
launch_general_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, WARPS_M, WARPS_N, \
BYTES_PER_LDG>(launch_params, configure_params); \
} \
static FwdGeneralRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
rmsnorm_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
REGISTER_NORM_BASE( \
NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \
norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE); \
} // namespace
// Create rmsnorm tuned launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_TUNED_LAUNCHER(512, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(512, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(512, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(512, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(512, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(512, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(768, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(768, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1024, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1024, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(4096, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 768, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 768, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 768, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// Create rmsnorm general launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_GENERAL_LAUNCHER(128, bf16, bf16, fp8e4m3, fp32, 4, 1, 8);
REGISTER_FWD_GENERAL_LAUNCHER(512, bf16, bf16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(1024, bf16, bf16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(2048, bf16, bf16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(8192, bf16, bf16, fp8e4m3, fp32, 1, 4, 16);
REGISTER_FWD_GENERAL_LAUNCHER(128, fp16, fp16, fp8e4m3, fp32, 4, 1, 8);
REGISTER_FWD_GENERAL_LAUNCHER(512, fp16, fp16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(8192, fp16, fp16, fp8e4m3, fp32, 1, 4, 16);
REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp8e4m3, fp32, 1, 4, 16);
REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp32, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(128, fp16, fp16, fp16, fp32, 4, 1, 8);
REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp16, fp32, 4, 1, 8);
REGISTER_FWD_GENERAL_LAUNCHER(128, bf16, bf16, bf16, fp32, 4, 1, 8);
REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, bf16, fp32, 4, 1, 8);
REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp32, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(512, fp16, fp16, fp16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(512, bf16, bf16, bf16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, bf16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp32, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(1024, bf16, bf16, bf16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, bf16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp32, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(2048, bf16, bf16, bf16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, bf16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 4, 16);
REGISTER_FWD_GENERAL_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 4, 16);
REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp16, fp32, 1, 4, 16);
REGISTER_FWD_GENERAL_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 4, 16);
REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, bf16, fp32, 1, 4, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, bf16, bf16, fp8e4m3, fp32, 4, 1, 8);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, bf16, bf16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, bf16, bf16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, bf16, bf16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, bf16, bf16, fp8e4m3, fp32, 1, 4, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, fp16, fp16, fp8e4m3, fp32, 4, 1, 8);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, fp16, fp16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, fp16, fp16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, fp16, fp16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp16, fp16, fp8e4m3, fp32, 1, 4, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, fp32, fp32, fp8e4m3, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, fp32, fp32, fp8e4m3, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, fp32, fp32, fp8e4m3, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, fp32, fp32, fp8e4m3, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp32, fp32, fp8e4m3, fp32, 1, 4, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, fp32, fp32, fp32, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, fp16, fp16, fp16, fp32, 4, 1, 8);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, fp32, fp32, fp16, fp32, 4, 1, 8);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, bf16, bf16, bf16, fp32, 4, 1, 8);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, fp32, fp32, bf16, fp32, 4, 1, 8);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, fp32, fp32, fp32, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, fp16, fp16, fp16, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, fp32, fp32, fp16, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, bf16, bf16, bf16, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, fp32, fp32, bf16, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, fp32, fp32, fp16, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, fp32, fp32, bf16, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, fp32, fp32, fp32, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, fp16, fp16, fp16, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, fp32, fp32, fp16, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, bf16, bf16, bf16, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, fp32, fp32, bf16, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp32, fp32, fp32, fp32, 1, 4, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp16, fp16, fp16, fp32, 1, 4, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp32, fp32, fp16, fp32, 1, 4, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, bf16, bf16, bf16, fp32, 1, 4, 16);
REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp32, fp32, bf16, fp32, 1, 4, 16);
......@@ -10,15 +10,15 @@
#include <cfloat>
#include <cstdio>
#include "../utils.cuh"
#include "../../utils.cuh"
#include "../common.h"
namespace transformer_engine {
namespace rmsnorm {
using namespace transformer_engine;
namespace normalization {
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_kernel(
FwdParams params) {
ForwardKernelParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_N = Ktraits::WARPS_N };
enum { WARPS_M = Ktraits::WARPS_M };
......@@ -143,7 +143,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_kernel(
FwdParams params) {
ForwardKernelParams params) {
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::NUM_ELTS };
enum { WARPS_M = Ktraits::WARPS_M };
......@@ -291,7 +291,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
}
}
} // namespace rmsnorm
} // namespace normalization
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_FWD_KERNELS_CUH_
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_H_
#define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_H_
#include <transformer_engine/transformer_engine.h>
#include <functional>
#include <map>
#include <stdexcept>
#include <unordered_map>
#include <vector>
#include "../common.h"
#include "../layer_norm/ln.h"
namespace transformer_engine {
namespace rmsnorm {
////////////////////////////////////////////////////////////////////////////////////////////////////
// RMSNorm reuses the LayerNorm launch/parameter structures unchanged; these
// distinct derived types exist so RMSNorm kernels and registries keep their
// own type identity.
template <typename Params>
struct LaunchParams : public transformer_engine::layer_norm::LaunchParams<Params> {};

// Forward-pass kernel parameters (inherited fields from layer_norm::FwdParams).
struct FwdParams : public transformer_engine::layer_norm::FwdParams {};

// Backward-pass kernel parameters (inherited fields from layer_norm::BwdParams).
struct BwdParams : public transformer_engine::layer_norm::BwdParams {};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Launcher signature: (launch_params, configure_params).  When the bool is
// true the launcher only fills in launch parameters (workspace/barrier sizes);
// when false it launches the kernel.
using FwdFunction = std::function<void(LaunchParams<FwdParams> &, const bool)>;
using BwdFunction = std::function<void(LaunchParams<BwdParams> &, const bool)>;

// Key encodes the (weight, input, output, compute) dtype combination and,
// for tuned kernels, the hidden size (see layer_norm::Types2Key).
using FunctionKey = uint64_t;

// Tuned registries: exact hidden-size match.  General registries: per-type
// map keyed by maximum supported hidden size, searched with lower_bound.
using FwdTunedRegistry = std::unordered_map<FunctionKey, FwdFunction>;
using BwdTunedRegistry = std::unordered_map<FunctionKey, BwdFunction>;
using FwdGeneralRegistry = std::unordered_map<FunctionKey, std::map<uint64_t, FwdFunction>>;
using BwdGeneralRegistry = std::unordered_map<FunctionKey, std::map<uint64_t, BwdFunction>>;

// Global registries, defined in rmsnorm_api.cpp and populated at static
// initialization time by the registrar structs below.
extern FwdTunedRegistry FWD_TUNED_FUNCS;
extern BwdTunedRegistry BWD_TUNED_FUNCS;
extern FwdGeneralRegistry FWD_GENERAL_FUNCS;
extern BwdGeneralRegistry BWD_GENERAL_FUNCS;
//////////////////////////////////////////////////////////////////////////////////////////////////
// Registers a tuned forward launcher under the (W, I, O, C, HIDDEN_SIZE) key.
// One static instance per kernel variant populates FWD_TUNED_FUNCS before main().
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdTunedRegistrar {
  explicit FwdTunedRegistrar(FwdFunction f) {
    const uint64_t registry_key = layer_norm::Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
    FWD_TUNED_FUNCS.emplace(registry_key, f);
  }
};
//////////////////////////////////////////////////////////////////////////////////////////////////
// Registers a general forward launcher: keyed by dtypes only (hidden size 0),
// with HIDDEN_SIZE acting as the per-type map key used for size-based lookup.
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdGeneralRegistrar {
  explicit FwdGeneralRegistrar(FwdFunction f) {
    const uint64_t registry_key = layer_norm::Types2Key<W, I, O, C>::get(0);
    FWD_GENERAL_FUNCS[registry_key].emplace(HIDDEN_SIZE, f);
  }
};
//////////////////////////////////////////////////////////////////////////////////////////////////
// Registers a tuned backward launcher under the (W, I, O, C, HIDDEN_SIZE) key.
// One static instance per kernel variant populates BWD_TUNED_FUNCS before main().
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdTunedRegistrar {
  explicit BwdTunedRegistrar(BwdFunction f) {
    const uint64_t registry_key = layer_norm::Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
    BWD_TUNED_FUNCS.emplace(registry_key, f);
  }
};
//////////////////////////////////////////////////////////////////////////////////////////////////
// Registers a general backward launcher: keyed by dtypes only (hidden size 0),
// with HIDDEN_SIZE acting as the per-type map key used for size-based lookup.
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdGeneralRegistrar {
  explicit BwdGeneralRegistrar(BwdFunction f) {
    const uint64_t registry_key = layer_norm::Types2Key<W, I, O, C>::get(0);
    BWD_GENERAL_FUNCS[registry_key].emplace(HIDDEN_SIZE, f);
  }
};
} // namespace rmsnorm
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_H_
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cstdint>
#include <numeric>
#include <vector>
#include "../common.h"
#include "rmsnorm.h"
#include "transformer_engine/rmsnorm.h"
/*
Supported Type combinations:
input compute weights output
=======================================
fp32 fp32 fp32 fp32
fp16 fp32 fp16 fp16
bf16 fp32 bf16 bf16
fp32 fp32 fp32 fp16
fp32 fp32 fp32 bf16
fp32 fp32 fp32 fp8
fp16 fp32 fp16 fp8
bf16 fp32 bf16 fp8
Remarks:
Input type = Weight type
Compute always in FP32
*/
namespace transformer_engine {
namespace layer_norm {
uint64_t get_key(DType wtype, DType itype, DType otype, DType ctype, uint64_t hidden_size);
}
namespace rmsnorm {
using namespace transformer_engine;
// Process-wide launcher registries.  Populated during static initialization by
// the Fwd/Bwd registrar objects declared in rmsnorm.h; read by the
// get_*_launcher dispatchers below.
FwdTunedRegistry FWD_TUNED_FUNCS;
BwdTunedRegistry BWD_TUNED_FUNCS;
FwdGeneralRegistry FWD_GENERAL_FUNCS;
BwdGeneralRegistry BWD_GENERAL_FUNCS;
// Selects the forward kernel launcher for the given dtype combination.
// Prefers a hand-tuned kernel (exact hidden-size match) when the row count and
// pointer alignment permit vectorized access; otherwise falls back to the
// smallest general kernel whose supported hidden size covers params.cols.
FwdFunction &get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype,
                              const layer_norm::FwdParams &params) {
  // Vectorized memory accesses are assumed to be at most 16 bytes wide.
  const auto aligned16 = [](const void *p) -> bool {
    return reinterpret_cast<uintptr_t>(p) % 16 == 0;
  };

  const auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols);
  const bool tuned_usable = (params.rows % 4 == 0) && aligned16(params.x) &&
                            aligned16(params.rs) && aligned16(params.gamma) && aligned16(params.z);
  if (tuned_usable && FWD_TUNED_FUNCS.count(tuned_key) > 0) {
    return FWD_TUNED_FUNCS.at(tuned_key);
  }

  // General kernels are keyed by dtypes alone (hidden size 0).
  const auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0);
  if (FWD_GENERAL_FUNCS.count(general_key) == 0) {
    NVTE_ERROR("FWD: Unsupported types.");
  }
  auto &candidates = FWD_GENERAL_FUNCS.at(general_key);
  auto match = candidates.lower_bound(params.cols);
  if (match == candidates.end()) {
    // Hidden size exceeds every single-pass kernel: use the largest (multi-CTA).
    return candidates.rbegin()->second;
  }
  return match->second;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Selects the backward kernel launcher for the given dtype combination.
// Mirrors get_fwd_launcher, but the tuned path additionally requires the
// gradient buffers (dz, dx, dgamma, dgamma_part) to be 16-byte aligned.
BwdFunction &get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype,
                              const layer_norm::BwdParams &params) {
  // Vectorized memory accesses are assumed to be at most 16 bytes wide.
  const auto aligned16 = [](const void *p) -> bool {
    return reinterpret_cast<uintptr_t>(p) % 16 == 0;
  };

  const auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols);
  const bool tuned_usable = (params.rows % 4 == 0) && aligned16(params.x) &&
                            aligned16(params.rs) && aligned16(params.gamma) &&
                            aligned16(params.dz) && aligned16(params.dx) &&
                            aligned16(params.dgamma) && aligned16(params.dgamma_part);
  if (tuned_usable && BWD_TUNED_FUNCS.count(tuned_key) > 0) {
    return BWD_TUNED_FUNCS.at(tuned_key);
  }

  // General kernels are keyed by dtypes alone (hidden size 0).
  const auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0);
  if (BWD_GENERAL_FUNCS.count(general_key) == 0) {
    NVTE_ERROR("BWD: Unsupported types.");
  }
  auto &candidates = BWD_GENERAL_FUNCS.at(general_key);
  auto match = candidates.lower_bound(params.cols);
  if (match == candidates.end()) {
    // Hidden size exceeds every single-pass kernel: use the largest (multi-CTA).
    return candidates.rbegin()->second;
  }
  return match->second;
}
// ////////////////////////////////////////////////////////////////////////////////////////////////////
// Number of elements implied by a shape: the product of all extents.
// An empty shape yields 1 (a scalar).
inline size_t product(const std::vector<size_t> &shape) {
  size_t total = 1;
  for (const size_t extent : shape) {
    total *= extent;
  }
  return total;
}
} // namespace rmsnorm
////////////////////////////////////////////////////////////////////////////////////////////////////
/* RMSNorm forward dispatch.
 *
 * Two-phase protocol: when workspace->data.dptr is null, this call only
 * queries the selected launcher and writes the required workspace/barrier
 * shapes and dtypes into *workspace and *barrier, then returns without
 * launching.  A second call with allocated buffers validates them and runs
 * the kernel on `stream`.
 *
 * x:      2D input [rows, cols]
 * gamma:  scale weights, shape [cols]
 * z:      output, same shape as x (may be FP8, in which case amax is reset)
 * rsigma: fp32 per-row statistic, shape [rows]
 */
void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tensor *z,
                 Tensor *rsigma, cudaStream_t stream, const int multiprocessorCount,
                 Tensor *workspace, Tensor *barrier, const bool zero_centered_gamma) {
  auto itype = x.data.dtype;
  auto wtype = gamma.data.dtype;
  auto otype = z->data.dtype;
  const bool fp8_out = is_fp8_dtype(otype);
  auto ctype = DType::kFloat32;  // compute is always fp32

  NVTE_CHECK(x.data.shape.size() == 2);

  const size_t rows = x.data.shape[0];
  const size_t cols = x.data.shape[1];
  const auto hidden_size = gamma.data.shape[0];

  NVTE_CHECK(hidden_size == cols);
  NVTE_CHECK(epsilon >= 0.f);

  NVTE_CHECK(z->data.shape == x.data.shape);

  NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{rows});
  NVTE_CHECK(rsigma->data.dtype == ctype);

  rmsnorm::LaunchParams<rmsnorm::FwdParams> launch_params;

  launch_params.multiprocessorCount = multiprocessorCount;
  launch_params.stream = stream;

  // Set the kernel runtime parameters.  mu/beta are unused by RMSNorm.
  rmsnorm::FwdParams &params = launch_params.params;
  params.rows = rows;
  params.cols = cols;
  params.x = x.data.dptr;
  params.mu = nullptr;
  params.rs = rsigma->data.dptr;
  params.gamma = gamma.data.dptr;
  params.beta = nullptr;
  params.z = z->data.dptr;
  params.epsilon = epsilon;
  params.amax = z->amax.dptr;
  params.scale = z->scale.dptr;
  params.scale_inv = z->scale_inv.dptr;
  params.fp8_out = fp8_out;
  params.zero_centered_gamma = zero_centered_gamma;

  // Request the kernel launcher (tuned if shape/alignment allow, else general).
  auto launcher = rmsnorm::get_fwd_launcher(wtype, itype, otype, ctype, params);

  // Query the kernel-specific launch parameters (configure-only call).
  launcher(launch_params, true);
  if (launch_params.workspace_bytes == 0) {
    // Never report a zero-sized workspace so callers always allocate something.
    launch_params.workspace_bytes = 1;
  }

  if (workspace->data.dptr == nullptr) {
    // Query phase: report required buffer shapes/dtypes and return.
    NVTE_CHECK(barrier->data.dptr == nullptr);

    workspace->data.dtype = DType::kByte;
    workspace->data.shape = {launch_params.workspace_bytes};

    barrier->data.dtype = DType::kInt32;
    barrier->data.shape = {launch_params.barrier_size};

    return;
  } else {
    NVTE_CHECK(workspace->data.dtype == DType::kByte);
    NVTE_CHECK(workspace->data.shape == std::vector<size_t>{launch_params.workspace_bytes});
  }
  if (launch_params.barrier_size > 0) {
    NVTE_CHECK(barrier->data.dptr != nullptr);
    NVTE_CHECK(barrier->data.dtype == DType::kInt32);
    NVTE_CHECK(barrier->data.shape == std::vector<size_t>{launch_params.barrier_size});
  }

  // Tensor checks are delayed here in order to recover workspace sizes with null data
  CheckInputTensor(x, "x");
  CheckInputTensor(gamma, "gamma");
  CheckOutputTensor(*z, "z");
  CheckOutputTensor(*rsigma, "rsigma");

  // Workspace/barrier are only consumed by multi-CTA kernels.
  if (launch_params.barrier_size > 0) {
    params.workspace = workspace->data.dptr;
    params.barrier = reinterpret_cast<int *>(barrier->data.dptr);
  }

  // Clear buffers
  // NOTE(review): cudaMemsetAsync return codes are ignored here and below --
  // consider wrapping with NVTE error checking.
  if (params.fp8_out) {
    cudaMemsetAsync(params.amax, 0, rmsnorm::product(z->amax.shape) * typeToSize(z->amax.dtype),
                    stream);
  }
  if (launch_params.barrier_size > 0) {
    cudaMemsetAsync(params.barrier, 0,
                    rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype),
                    stream);
  }

  // Launch the kernel.
  launcher(launch_params, false);

  return;
}
/* RMSNorm backward dispatch.
 *
 * Two-phase protocol: when dgamma_part->data.dptr is null, this call only
 * queries the selected launcher and writes the required dgamma_part /
 * workspace / barrier shapes and dtypes, then returns without launching.
 * A second call with allocated buffers validates them and runs the kernel.
 *
 * dz/x:        2D, same shape [rows, cols]
 * rsigma:      fp32 per-row statistic from the forward pass
 * gamma:       [cols]; dx/dgamma mirror x/gamma shapes and dtypes
 * dgamma_part: fp32 partial reductions, shape [ctas_per_col, hidden_size]
 *
 * NOTE(review): unlike rmsnorm_fwd, no minimum workspace size of 1 byte is
 * enforced here -- confirm whether that asymmetry is intentional.
 */
void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const Tensor &gamma,
                 Tensor *dx, Tensor *dgamma, Tensor *dgamma_part, cudaStream_t stream,
                 const int multiprocessorCount, Tensor *workspace, Tensor *barrier,
                 const bool zero_centered_gamma) {
  using namespace transformer_engine;
  auto itype = x.data.dtype;
  auto wtype = gamma.data.dtype;
  auto otype = wtype;  // gradient dtype matches weight dtype
  auto ctype = DType::kFloat32;  // compute is always fp32

  NVTE_CHECK(dz.data.dtype == otype);
  NVTE_CHECK(rsigma.data.dtype == ctype);

  NVTE_CHECK(x.data.shape.size() == 2);
  NVTE_CHECK(dz.data.shape == x.data.shape);

  const auto rows = x.data.shape[0];
  const auto cols = x.data.shape[1];
  const auto hidden_size = gamma.data.shape[0];

  NVTE_CHECK(gamma.data.shape[0] == cols);

  NVTE_CHECK(dx->data.shape == x.data.shape);
  NVTE_CHECK(dx->data.dtype == x.data.dtype);

  NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
  NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);

  rmsnorm::LaunchParams<rmsnorm::BwdParams> launch_params;
  launch_params.stream = stream;
  launch_params.multiprocessorCount = multiprocessorCount;

  // Set the kernel runtime parameters.  mu/dbeta/dbeta_part are unused by RMSNorm.
  rmsnorm::BwdParams &params = launch_params.params;
  params.rows = rows;
  params.cols = cols;
  params.x = x.data.dptr;
  params.mu = nullptr;
  params.rs = rsigma.data.dptr;
  params.gamma = gamma.data.dptr;
  params.dz = dz.data.dptr;
  params.dx = dx->data.dptr;
  params.dbeta = nullptr;
  params.dgamma = dgamma->data.dptr;
  params.dbeta_part = nullptr;
  params.dgamma_part = dgamma_part->data.dptr;
  params.zero_centered_gamma = zero_centered_gamma;

  // Request the kernel launcher (tuned if shape/alignment allow, else general).
  auto launcher = rmsnorm::get_bwd_launcher(wtype, itype, otype, ctype, params);

  // Query the kernel-specific launch parameters (configure-only call).
  launcher(launch_params, true);

  // Populate shape and dtypes for FW to allocate memory
  if (dgamma_part->data.dptr == nullptr) {
    // Query phase: report required buffer shapes/dtypes and return.
    dgamma_part->data.dtype = ctype;
    dgamma_part->data.shape = {static_cast<uint64_t>(launch_params.params.ctas_per_col),
                               hidden_size};
    workspace->data.dtype = DType::kByte;
    workspace->data.shape = {launch_params.workspace_bytes};
    barrier->data.dtype = DType::kInt32;
    barrier->data.shape = {launch_params.barrier_size};
    return;
  } else {
    auto pdw_shape =
        std::vector<size_t>{static_cast<uint64_t>(launch_params.params.ctas_per_col), hidden_size};
    NVTE_CHECK(dgamma_part->data.dtype == ctype);
    NVTE_CHECK(dgamma_part->data.shape == pdw_shape);
  }

  if (launch_params.barrier_size > 0) {
    NVTE_CHECK(barrier->data.dptr != nullptr);
    NVTE_CHECK(barrier->data.dtype == DType::kInt32);
    NVTE_CHECK(barrier->data.shape == std::vector<size_t>{launch_params.barrier_size});
  }

  if (launch_params.workspace_bytes > 0) {
    NVTE_CHECK(workspace->data.dptr != nullptr);
    NVTE_CHECK(workspace->data.dtype == DType::kByte);
    NVTE_CHECK(workspace->data.shape == std::vector<size_t>{launch_params.workspace_bytes});
  }

  // Tensor checks are delayed here in order to recover workspace sizes with null data
  CheckInputTensor(dz, "dz");
  CheckInputTensor(x, "x");
  CheckInputTensor(rsigma, "rsigma");
  CheckInputTensor(gamma, "gamma");
  CheckOutputTensor(*dx, "dx");
  CheckOutputTensor(*dgamma, "dgamma");

  // Workspace/barrier are only consumed by multi-CTA kernels; the barrier
  // must start zeroed.  NOTE(review): cudaMemsetAsync return code is ignored.
  if (launch_params.barrier_size > 0) {
    params.workspace = workspace->data.dptr;
    params.barrier = reinterpret_cast<int *>(barrier->data.dptr);
    cudaMemsetAsync(params.barrier, 0,
                    rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype),
                    stream);
  }

  // Launch the kernel.
  launcher(launch_params, false);
}
} // namespace transformer_engine
void nvte_rmsnorm_fwd(const NVTETensor x, // Nxhidden_size
const NVTETensor gamma, // hidden_size
const float epsilon, NVTETensor z, NVTETensor rsigma, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) {
NVTE_API_CALL(nvte_rmsnorm_fwd);
using namespace transformer_engine;
rmsnorm_fwd(*reinterpret_cast<const Tensor *>(x), *reinterpret_cast<const Tensor *>(gamma),
epsilon, reinterpret_cast<Tensor *>(z), reinterpret_cast<Tensor *>(rsigma), stream,
multiprocessorCount, reinterpret_cast<Tensor *>(workspace),
reinterpret_cast<Tensor *>(barrier), false);
}
void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size
const NVTETensor x, // Nxhidden_size
const NVTETensor rsigma, // N, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) {
NVTE_API_CALL(nvte_rmsnorm_bwd);
using namespace transformer_engine;
rmsnorm_bwd(*reinterpret_cast<const Tensor *>(dz), *reinterpret_cast<const Tensor *>(x),
*reinterpret_cast<const Tensor *>(rsigma), *reinterpret_cast<const Tensor *>(gamma),
reinterpret_cast<Tensor *>(dx), reinterpret_cast<Tensor *>(dgamma),
reinterpret_cast<Tensor *>(dgamma_part), stream, multiprocessorCount,
reinterpret_cast<Tensor *>(workspace), reinterpret_cast<Tensor *>(barrier), false);
}
void nvte_rmsnorm1p_fwd(const NVTETensor x, // Nxhidden_size
const NVTETensor gamma, // hidden_size
const float epsilon, NVTETensor z, NVTETensor rsigma, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) {
NVTE_API_CALL(nvte_rmsnorm1p_fwd);
using namespace transformer_engine;
rmsnorm_fwd(*reinterpret_cast<const Tensor *>(x), *reinterpret_cast<const Tensor *>(gamma),
epsilon, reinterpret_cast<Tensor *>(z), reinterpret_cast<Tensor *>(rsigma), stream,
multiprocessorCount, reinterpret_cast<Tensor *>(workspace),
reinterpret_cast<Tensor *>(barrier), true);
}
void nvte_rmsnorm1p_bwd(const NVTETensor dz, // Nxhidden_size
const NVTETensor x, // Nxhidden_size
const NVTETensor rsigma, // N, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part,
cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace,
NVTETensor barrier) {
NVTE_API_CALL(nvte_rmsnorm1p_bwd);
using namespace transformer_engine;
rmsnorm_bwd(*reinterpret_cast<const Tensor *>(dz), *reinterpret_cast<const Tensor *>(x),
*reinterpret_cast<const Tensor *>(rsigma), *reinterpret_cast<const Tensor *>(gamma),
reinterpret_cast<Tensor *>(dx), reinterpret_cast<Tensor *>(dgamma),
reinterpret_cast<Tensor *>(dgamma_part), stream, multiprocessorCount,
reinterpret_cast<Tensor *>(workspace), reinterpret_cast<Tensor *>(barrier), true);
}
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_KERNEL_TRAITS_H_
#define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_KERNEL_TRAITS_H_
#include "../common.h"
#include "../layer_norm/ln_kernel_traits.h"
#include "../utils.cuh"
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace transformer_engine {
namespace rmsnorm {
// RMSNorm reuses the LayerNorm finalize-kernel traits unchanged; this derived
// type exists only to give RMSNorm its own name in this namespace.
template <
    uint32_t HIDDEN_SIZE_, typename weight_t_, typename input_t_, typename output_t_,
    typename compute_t_, typename index_t_, uint32_t THREADS_PER_CTA_, uint32_t BYTES_PER_LDG_,
    typename Base =
        layer_norm::Kernel_traits_finalize<HIDDEN_SIZE_, weight_t_, input_t_, output_t_, compute_t_,
                                           index_t_, THREADS_PER_CTA_, BYTES_PER_LDG_> >
struct Kernel_traits_finalize : public Base {};
////////////////////////////////////////////////////////////////////////////////////////////////////
// RMSNorm reuses the LayerNorm main-kernel traits (tiling, vector widths,
// CTA layout) unchanged; this derived type exists only to give RMSNorm its
// own name in this namespace.
template <typename weight_t_, typename input_t_, typename output_t_, typename compute_t_,
          typename index_t_, uint32_t HIDDEN_SIZE_, uint32_t CTAS_PER_ROW_, uint32_t WARPS_M_,
          uint32_t WARPS_N_, uint32_t BYTES_PER_LDG_ = 16,
          typename Base = layer_norm::Kernel_traits<weight_t_, input_t_, output_t_, compute_t_,
                                                    index_t_, HIDDEN_SIZE_, CTAS_PER_ROW_, WARPS_M_,
                                                    WARPS_N_, BYTES_PER_LDG_> >
struct Kernel_traits : public Base {};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace rmsnorm
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_KERNEL_TRAITS_H_
......@@ -16,7 +16,6 @@ from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import DType as TEDType
from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
......@@ -82,7 +81,7 @@ class LayerNormFwdPrimitive(BasePrimitive):
hidden_size = gamma_aval.size
assert x_aval.size % hidden_size == 0
wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
(wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
x_aval.size // hidden_size, # batch size
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
......@@ -96,18 +95,15 @@ class LayerNormFwdPrimitive(BasePrimitive):
wkspace_aval = out_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)
barrier_aval = out_aval.update(
shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1])
)
return out_aval, mu_aval, rsigma_aval, wkspace_aval, barrier_aval
return out_aval, mu_aval, rsigma_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
LayerNorm fwd outer primitive abstract
"""
out_aval, mu_aval, rsigma_aval, _, _ = LayerNormFwdPrimitive.abstract(*args, **kwargs)
out_aval, mu_aval, rsigma_aval, _ = LayerNormFwdPrimitive.abstract(*args, **kwargs)
return out_aval, mu_aval, rsigma_aval
@staticmethod
......@@ -151,7 +147,7 @@ class LayerNormFwdPrimitive(BasePrimitive):
batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval, barrier_aval = ctx.avals_out[-2:]
wkspace_aval = ctx.avals_out[-2:]
out_types = [
ir.RankedTensorType.get(out_shape, output_type),
......@@ -160,9 +156,6 @@ class LayerNormFwdPrimitive(BasePrimitive):
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
ir.RankedTensorType.get(
barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
),
]
operands = [x, gamma, beta]
operand_shapes = [x_shape, g_shape, b_shape]
......@@ -174,15 +167,9 @@ class LayerNormFwdPrimitive(BasePrimitive):
batch_size,
hidden_size,
wkspace_aval.size,
barrier_aval.size,
(0,), # no dgamma_part in FWD pass
(0,), # no dbeta_part in BWD pass
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype),
TEDType.kByte, # dummy dgamma_part te_dtype
TEDType.kByte, # dummy dbeta_part te_dtype
zero_centered_gamma,
epsilon,
sm_margin,
......@@ -198,7 +185,7 @@ class LayerNormFwdPrimitive(BasePrimitive):
to describe implementation
"""
assert LayerNormFwdPrimitive.inner_primitive is not None
out, mu, rsigma, _, _ = LayerNormFwdPrimitive.inner_primitive.bind(
out, mu, rsigma, _ = LayerNormFwdPrimitive.inner_primitive.bind(
x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
)
return out, mu, rsigma
......@@ -377,8 +364,7 @@ class LayerNormBwdPrimitive(BasePrimitive):
dx_aval = core.raise_to_shaped(dz_aval)
dgamma_aval = dbeta_aval = core.raise_to_shaped(gamma_aval)
wkspace_info, barrier_info, dgamma_part_info, dbeta_part_info = (
transformer_engine_jax.get_layernorm_bwd_workspace_sizes(
(wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes(
x_aval.size // gamma_aval.size, # batch size
gamma_aval.size, # hidden size
jax_dtype_to_te_dtype(x_aval.dtype), # input te_dtype
......@@ -388,28 +374,15 @@ class LayerNormBwdPrimitive(BasePrimitive):
kwargs["epsilon"],
get_backward_sm_margin(),
)
)
wkspace_aval = dx_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)
barrier_aval = dx_aval.update(
shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1])
)
dgamma_part_aval = dgamma_aval.update(
shape=dgamma_part_info[0], dtype=te_dtype_to_jax_dtype(dgamma_part_info[1])
)
dbeta_part_aval = dbeta_aval.update(
shape=dbeta_part_info[0], dtype=te_dtype_to_jax_dtype(dbeta_part_info[1])
)
return (
dx_aval,
dgamma_aval,
dbeta_aval,
wkspace_aval,
barrier_aval,
dgamma_part_aval,
dbeta_part_aval,
)
@staticmethod
......@@ -417,9 +390,7 @@ class LayerNormBwdPrimitive(BasePrimitive):
"""
LayerNorm bwd outer primitive abstract
"""
dx_aval, dgamma_aval, dbeta_aval, _, _, _, _ = LayerNormBwdPrimitive.abstract(
*args, **kwargs
)
dx_aval, dgamma_aval, dbeta_aval, _ = LayerNormBwdPrimitive.abstract(*args, **kwargs)
return dx_aval, dgamma_aval, dbeta_aval
@staticmethod
......@@ -470,20 +441,14 @@ class LayerNormBwdPrimitive(BasePrimitive):
sm_margin = get_backward_sm_margin()
wkspace_aval, barrier_aval, dgamma_part_aval, dbeta_part_aval = ctx.avals_out[-4:]
wkspace_aval = ctx.avals_out[-4:]
opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size,
hidden_size,
wkspace_aval.size,
barrier_aval.size,
dgamma_part_aval.shape,
dbeta_part_aval.shape,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype),
jax_dtype_to_te_dtype(dgamma_part_aval.dtype),
jax_dtype_to_te_dtype(dbeta_part_aval.dtype),
zero_centered_gamma,
epsilon,
sm_margin,
......@@ -496,7 +461,7 @@ class LayerNormBwdPrimitive(BasePrimitive):
@staticmethod
def impl(dz, x, mu, rsigma, gamma, zero_centered_gamma, epsilon):
assert LayerNormBwdPrimitive.inner_primitive is not None
dx, dgamma, dbeta, _, _, _, _ = LayerNormBwdPrimitive.inner_primitive.bind(
dx, dgamma, dbeta, _ = LayerNormBwdPrimitive.inner_primitive.bind(
dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
)
return dx, dgamma, dbeta
......@@ -630,7 +595,7 @@ class RmsNormFwdPrimitive(BasePrimitive):
hidden_size = gamma_aval.size
assert x_aval.size % hidden_size == 0
wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
(wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
x_aval.size // hidden_size, # batch size
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
......@@ -644,18 +609,15 @@ class RmsNormFwdPrimitive(BasePrimitive):
wkspace_aval = out_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)
barrier_aval = out_aval.update(
shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1])
)
return out_aval, rsigma_aval, wkspace_aval, barrier_aval
return out_aval, rsigma_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
RMSNorm fwd outer primitive abstract
"""
out_aval, rsigma_aval, _, _ = RmsNormFwdPrimitive.abstract(*args, **kwargs)
out_aval, rsigma_aval, _ = RmsNormFwdPrimitive.abstract(*args, **kwargs)
return out_aval, rsigma_aval
@staticmethod
......@@ -688,7 +650,7 @@ class RmsNormFwdPrimitive(BasePrimitive):
batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval, barrier_aval = ctx.avals_out[-2:]
wkspace_aval = ctx.avals_out[-2:]
out_types = [
ir.RankedTensorType.get(out_shape, x_type.element_type),
......@@ -696,9 +658,6 @@ class RmsNormFwdPrimitive(BasePrimitive):
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
ir.RankedTensorType.get(
barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
),
]
operands = [x, gamma]
operand_shapes = [x_shape, g_shape]
......@@ -710,15 +669,9 @@ class RmsNormFwdPrimitive(BasePrimitive):
batch_size,
hidden_size,
wkspace_aval.size,
barrier_aval.size,
(0,), # no dgamma_part in FWD pass
(0,), # no dbeta_part in BWD pass
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype),
TEDType.kByte, # dummy dgamma_part te_dtype
TEDType.kByte, # dummy dbeta_part te_dtype
False, # RMSNorm doesn't support zero_centered_gamma
epsilon,
sm_margin,
......@@ -734,7 +687,7 @@ class RmsNormFwdPrimitive(BasePrimitive):
to describe implementation
"""
assert RmsNormFwdPrimitive.inner_primitive is not None
out, rsigma, _, _ = RmsNormFwdPrimitive.inner_primitive.bind(x, gamma, epsilon=epsilon)
out, rsigma, _ = RmsNormFwdPrimitive.inner_primitive.bind(x, gamma, epsilon=epsilon)
return out, rsigma
@staticmethod
......@@ -833,8 +786,7 @@ class RmsNormBwdPrimitive(BasePrimitive):
dx_aval = core.raise_to_shaped(dz_aval)
dgamma_aval = core.raise_to_shaped(gamma_aval)
wkspace_info, barrier_info, dgamma_part_info, _ = (
transformer_engine_jax.get_layernorm_bwd_workspace_sizes(
(wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes(
x_aval.size // gamma_aval.size, # batch size
gamma_aval.size, # hidden size
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
......@@ -844,25 +796,18 @@ class RmsNormBwdPrimitive(BasePrimitive):
kwargs["epsilon"],
get_backward_sm_margin(),
)
)
wkspace_aval = dx_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)
barrier_aval = dx_aval.update(
shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1])
)
dgamma_part_aval = dgamma_aval.update(
shape=dgamma_part_info[0], dtype=te_dtype_to_jax_dtype(dgamma_part_info[1])
)
return dx_aval, dgamma_aval, wkspace_aval, barrier_aval, dgamma_part_aval
return dx_aval, dgamma_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
RMSNorm bwd outer primitive abstract
"""
dx_aval, dgamma_aval, _, _, _ = RmsNormBwdPrimitive.abstract(*args, **kwargs)
dx_aval, dgamma_aval, _ = RmsNormBwdPrimitive.abstract(*args, **kwargs)
return dx_aval, dgamma_aval
@staticmethod
......@@ -896,7 +841,7 @@ class RmsNormBwdPrimitive(BasePrimitive):
hidden_size = reduce(operator.mul, g_shape)
batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval, barrier_aval, dgamma_part_aval = ctx.avals_out[-3:]
wkspace_aval = ctx.avals_out[-3:]
out_types = [
ir.RankedTensorType.get(x_shape, x_type.element_type),
......@@ -904,12 +849,6 @@ class RmsNormBwdPrimitive(BasePrimitive):
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
ir.RankedTensorType.get(
barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
),
ir.RankedTensorType.get(
dgamma_part_aval.shape, jax_dtype_to_ir_dtype(dgamma_part_aval.dtype)
),
]
operands = [dz, rsigma, x, gamma]
operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape]
......@@ -921,15 +860,9 @@ class RmsNormBwdPrimitive(BasePrimitive):
batch_size,
hidden_size,
wkspace_aval.size,
barrier_aval.size,
dgamma_part_aval.shape,
(0,), # no dbeta_part for RMSnorm
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype),
jax_dtype_to_te_dtype(dgamma_part_aval.dtype),
TEDType.kByte, # dummy dbeta_part te_dtype
False, # RMSNorm doesn't support zero_centered_gamma
epsilon,
sm_margin,
......@@ -942,7 +875,7 @@ class RmsNormBwdPrimitive(BasePrimitive):
@staticmethod
def impl(dz, x, rsigma, gamma, epsilon):
assert RmsNormBwdPrimitive.inner_primitive is not None
dx, dgamma, _, _, _ = RmsNormBwdPrimitive.inner_primitive.bind(
dx, dgamma, _ = RmsNormBwdPrimitive.inner_primitive.bind(
dz, x, rsigma, gamma, epsilon=epsilon
)
return dx, dgamma
......@@ -1066,7 +999,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
assert gamma_aval.size == beta_aval.size
wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
(wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
x_aval.size // gamma_aval.size, # batch size
gamma_aval.size, # hidden size
jax_dtype_to_te_dtype(x_aval.dtype), # in type
......@@ -1084,18 +1017,15 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
wkspace_aval = x_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)
barrier_aval = x_aval.update(
shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1])
)
return out_aval, mu_aval, rsigma_aval, updated_amax_aval, wkspace_aval, barrier_aval
return out_aval, mu_aval, rsigma_aval, updated_amax_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
LayerNorm fwd (fp8 out) outer primitive abstract
"""
out_aval, mu_aval, rsigma_aval, updated_amax_aval, _, _ = LayerNormFwdFp8Primitive.abstract(
out_aval, mu_aval, rsigma_aval, updated_amax_aval, _ = LayerNormFwdFp8Primitive.abstract(
*args, **kwargs
)
return out_aval, mu_aval, rsigma_aval, updated_amax_aval
......@@ -1158,7 +1088,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval, barrier_aval = ctx.avals_out[-2:]
wkspace_aval = ctx.avals_out[-2:]
out_types = [
ir.RankedTensorType.get(out_shape, ir_out_dtype),
......@@ -1168,9 +1098,6 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
ir.RankedTensorType.get(
barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
),
]
operands = [x, gamma, beta, amax, scale, scale_inv]
operand_shapes = [
......@@ -1189,15 +1116,9 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
batch_size,
hidden_size,
wkspace_aval.size,
barrier_aval.size,
(0,), # no dgamma_part in FWD pass
(0,), # no dbeta_part in BWD pass
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype),
TEDType.kByte, # dummy dgamma_part te_dtype
TEDType.kByte, # dummy dbeta_part te_dtype
zero_centered_gamma,
epsilon,
sm_margin,
......@@ -1215,7 +1136,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
to describe implementation
"""
assert LayerNormFwdFp8Primitive.inner_primitive is not None
out, mu, rsigma, updated_amax, _, _ = LayerNormFwdFp8Primitive.inner_primitive.bind(
out, mu, rsigma, updated_amax, _ = LayerNormFwdFp8Primitive.inner_primitive.bind(
x,
gamma,
beta,
......@@ -1394,7 +1315,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
rsigama_dtype = jnp.float32
wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
(wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
x_aval.size // hidden_size, # batch_size
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
......@@ -1412,18 +1333,15 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
wkspace_aval = x_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)
barrier_aval = x_aval.update(
shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1])
)
return out_aval, rsigma_aval, amax_aval, wkspace_aval, barrier_aval
return out_aval, rsigma_aval, amax_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
RMSNorm fwd (fp8 out) outer primitive abstract
"""
out_aval, rsigma_aval, amax_aval, _, _ = RmsNormFwdFp8Primitive.abstract(*args, **kwargs)
out_aval, rsigma_aval, amax_aval, _ = RmsNormFwdFp8Primitive.abstract(*args, **kwargs)
return out_aval, rsigma_aval, amax_aval
@staticmethod
......@@ -1476,7 +1394,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval, barrier_aval = ctx.avals_out[-2:]
wkspace_aval = ctx.avals_out[-2:]
out_types = [
ir.RankedTensorType.get(out_shape, ir_out_dtype),
......@@ -1485,9 +1403,6 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
ir.RankedTensorType.get(
barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
),
]
operands = [x, gamma, amax, scale, scale_inv]
operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
......@@ -1499,15 +1414,9 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
batch_size,
hidden_size,
wkspace_aval.size,
barrier_aval.size,
(0,), # no dgamma_part in FWD pass
(0,), # no dbeta_part in BWD pass
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype),
TEDType.kByte, # dummy dgamma_part te_dtype
TEDType.kByte, # dummy dbeta_part te_dtype
False, # RMSNorm doesn't support zero_centered_gamma
epsilon,
sm_margin,
......@@ -1525,7 +1434,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
to describe implementation
"""
assert RmsNormFwdFp8Primitive.inner_primitive is not None
out, rsigma, amax, _, _ = RmsNormFwdFp8Primitive.inner_primitive.bind(
out, rsigma, amax, _ = RmsNormFwdFp8Primitive.inner_primitive.bind(
x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon
)
return out, rsigma, amax
......
......@@ -81,25 +81,18 @@ struct CustomCallNormDescriptor {
size_t batch_size;
size_t hidden_size;
size_t wkspace_size;
size_t barrier_size;
Shape dgamma_part_shape;
Shape dbeta_part_shape;
DType x_dtype;
DType w_dtype;
DType wkspace_dtype;
DType barrier_dtype;
DType dgamma_part_dtype;
DType dbeta_part_dtype;
bool zero_centered_gamma;
float eps;
int sm_margin;
};
pybind11::bytes PackCustomCallNormDescriptor(
size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size,
const std::vector<size_t> &dgamma_part_shape, const std::vector<size_t> &dbeta_part_shape,
DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype,
DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin);
pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
size_t wkspace_size, DType x_dtype, DType w_dtype,
DType wkspace_dtype, bool zero_centered_gamma,
float eps, int sm_margin);
struct SoftmaxDescriptor {
size_t batch_size;
......
......@@ -3,9 +3,9 @@
*
* See LICENSE for license information.
************************************************************************/
#include "transformer_engine/normalization.h"
#include "extensions.h"
#include "transformer_engine/layer_norm.h"
#include "transformer_engine/rmsnorm.h"
namespace transformer_engine {
namespace jax {
......@@ -25,40 +25,36 @@ pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidd
auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32);
// dummy tensor wrappers that will carry workspace size info later
TensorWrapper dummy_work_tensor, dummy_barrier_tensor;
TensorWrapper dummy_work_tensor;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin;
auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
if (is_layer_norm) {
auto beta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32);
layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), nullptr,
num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data());
nvte_layernorm_fwd(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(),
dummy_work_tensor.data(), num_sm, zero_centered_gamma, nullptr);
} else {
// TODO(Phuong): Verify and remove this check
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(),
rsigma_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(),
dummy_barrier_tensor.data());
rsigma_tensor.data(), dummy_work_tensor.data(), num_sm, zero_centered_gamma,
nullptr);
}
auto work_shape = MakeShapeVector(dummy_work_tensor.shape());
auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape());
return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()),
std::make_pair(barrier_shape, dummy_barrier_tensor.dtype()));
return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()));
}
void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspace_size,
size_t barrier_size, bool zero_centered_gamma, float eps, void *input,
DType in_dtype, void *weight, DType w_dtype, void *bias, void *output,
DType out_dtype, void *workspace, DType work_dtype, void *barrier,
DType barrier_dtype, void *mu, void *rsigma, float *amax, float *scale,
float *scale_inv, int sm_margin, cudaStream_t stream) {
bool zero_centered_gamma, float eps, void *input, DType in_dtype,
void *weight, DType w_dtype, void *bias, void *output, DType out_dtype,
void *workspace, DType work_dtype, void *mu, void *rsigma, float *amax,
float *scale, float *scale_inv, int sm_margin, cudaStream_t stream) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto weight_shape = std::vector<size_t>{hidden_size};
auto intermediates_shape = std::vector<size_t>{batch_size};
auto workspace_shape = std::vector<size_t>{workspace_size};
auto barrier_shape = std::vector<size_t>{barrier_size};
auto is_layer_norm = (bias) ? true : false;
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
......@@ -71,23 +67,21 @@ void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspac
auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32);
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin;
auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, work_dtype);
auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype);
if (is_layer_norm) {
auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype);
auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32);
layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream, num_sm,
workspace_tensor.data(), barrier_tensor.data());
nvte_layernorm_fwd(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(),
workspace_tensor.data(), num_sm, zero_centered_gamma, stream);
} else {
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(),
rsigma_tensor.data(), stream, num_sm, workspace_tensor.data(),
barrier_tensor.data());
rsigma_tensor.data(), workspace_tensor.data(), num_sm, zero_centered_gamma,
stream);
}
}
......@@ -96,20 +90,17 @@ Error_Type LayerNormForwardImplFFI(cudaStream_t stream, Buffer_Type *x_buf, Buff
Buffer_Type *scale_buf, Buffer_Type *scale_inv_buf,
Result_Type *output_buf, Result_Type *mu_buf,
Result_Type *rsigma_buf, Result_Type *amax_out_buf,
Result_Type *wkspace_buf, Result_Type *barrier_buf,
bool zero_centered_gamma, double eps_, int64_t sm_margin_,
bool is_layer_norm, bool is_fp8) {
Result_Type *wkspace_buf, bool zero_centered_gamma, double eps_,
int64_t sm_margin_, bool is_layer_norm, bool is_fp8) {
auto in_dtype = convert_ffi_datatype_to_te_dtype((*x_buf).element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype((*gamma_buf).element_type());
auto wkspace_dtype = convert_ffi_datatype_to_te_dtype((*wkspace_buf)->element_type());
auto barrier_dtype = convert_ffi_datatype_to_te_dtype((*barrier_buf)->element_type());
auto *input = x_buf->untyped_data();
auto *weight = gamma_buf->untyped_data();
auto *output = (*output_buf)->untyped_data();
auto *rsigma = (*rsigma_buf)->untyped_data();
auto *workspace = (*wkspace_buf)->untyped_data();
auto *barrier = (*barrier_buf)->untyped_data();
void *bias = nullptr;
void *mu = nullptr;
......@@ -135,17 +126,15 @@ Error_Type LayerNormForwardImplFFI(cudaStream_t stream, Buffer_Type *x_buf, Buff
auto x_size = product(x_buf->dimensions());
auto gamma_size = product(gamma_buf->dimensions());
auto wkspace_size = product((*wkspace_buf)->dimensions());
auto barrier_size = product((*barrier_buf)->dimensions());
auto hidden_size = gamma_size;
auto batch_size = x_size / gamma_size;
float eps = static_cast<float>(eps_);
int sm_margin = static_cast<int>(sm_margin_);
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
sm_margin, stream);
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input,
in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype,
mu, rsigma, amax, scale, scale_inv, sm_margin, stream);
return ffi_with_cuda_error_check();
}
......@@ -154,11 +143,10 @@ Error_Type LayerNormForwardFP8FFI(cudaStream_t stream, Buffer_Type x_buf, Buffer
Buffer_Type scale_inv_buf, Result_Type output_buf,
Result_Type mu_buf, Result_Type rsigma_buf,
Result_Type amax_out_buf, Result_Type wkspace_buf,
Result_Type barrier_buf, bool zero_centered_gamma, double eps_,
int64_t sm_margin_) {
bool zero_centered_gamma, double eps_, int64_t sm_margin_) {
return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf, &beta_buf, &amax_buf, &scale_buf,
&scale_inv_buf, &output_buf, &mu_buf, &rsigma_buf, &amax_out_buf,
&wkspace_buf, &barrier_buf, zero_centered_gamma, eps_, sm_margin_,
&wkspace_buf, zero_centered_gamma, eps_, sm_margin_,
true, // is_layer_norm
true // is_fp8
);
......@@ -178,7 +166,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardFP8Handler, LayerNormForwardFP8FFI
.Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // amax_out
.Ret<Buffer_Type>() // wkspace
.Ret<Buffer_Type>() // barrier
.Attr<bool>("zero_centered_gamma")
.Attr<double>("eps")
.Attr<int64_t>("sm_margin"),
......@@ -187,15 +174,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardFP8Handler, LayerNormForwardFP8FFI
Error_Type LayerNormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf,
Buffer_Type beta_buf, Result_Type output_buf, Result_Type mu_buf,
Result_Type rsigma_buf, Result_Type wkspace_buf,
Result_Type barrier_buf, bool zero_centered_gamma, double eps_,
int64_t sm_margin_) {
bool zero_centered_gamma, double eps_, int64_t sm_margin_) {
return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf, &beta_buf,
nullptr, // amax_buf
nullptr, // scale_buf,
nullptr, // scale_inv_buf,
&output_buf, &mu_buf, &rsigma_buf,
nullptr, // amax_out_buf,
&wkspace_buf, &barrier_buf, zero_centered_gamma, eps_, sm_margin_,
&wkspace_buf, zero_centered_gamma, eps_, sm_margin_,
true, // is_layer_norm
false // is_fp8
);
......@@ -211,7 +197,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardHandler, LayerNormForwardFFI,
.Ret<Buffer_Type>() // mu
.Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // wkspace
.Ret<Buffer_Type>() // barrier
.Attr<bool>("zero_centered_gamma")
.Attr<double>("eps")
.Attr<int64_t>("sm_margin"),
......@@ -221,14 +206,14 @@ Error_Type RMSNormForwardFP8FFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_T
Buffer_Type amax_buf, Buffer_Type scale_buf,
Buffer_Type scale_inv_buf, Result_Type output_buf,
Result_Type rsigma_buf, Result_Type amax_out_buf,
Result_Type wkspace_buf, Result_Type barrier_buf,
bool zero_centered_gamma, double eps_, int64_t sm_margin_) {
Result_Type wkspace_buf, bool zero_centered_gamma, double eps_,
int64_t sm_margin_) {
return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf,
nullptr, // beta_buf,
&amax_buf, &scale_buf, &scale_inv_buf, &output_buf,
nullptr, // mu_buf,
&rsigma_buf, &amax_out_buf, &wkspace_buf, &barrier_buf,
zero_centered_gamma, eps_, sm_margin_,
&rsigma_buf, &amax_out_buf, &wkspace_buf, zero_centered_gamma,
eps_, sm_margin_,
false, // is_layer_norm
true // is_fp8
);
......@@ -246,7 +231,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormForwardFP8Handler, RMSNormForwardFP8FFI,
.Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // amax_out
.Ret<Buffer_Type>() // wkspace
.Ret<Buffer_Type>() // barrier
.Attr<bool>("zero_centered_gamma")
.Attr<double>("eps")
.Attr<int64_t>("sm_margin"),
......@@ -254,8 +238,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormForwardFP8Handler, RMSNormForwardFP8FFI,
Error_Type RMSNormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf,
Result_Type output_buf, Result_Type rsigma_buf,
Result_Type wkspace_buf, Result_Type barrier_buf,
bool zero_centered_gamma, double eps_, int64_t sm_margin_) {
Result_Type wkspace_buf, bool zero_centered_gamma, double eps_,
int64_t sm_margin_) {
return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf,
nullptr, // beta_buf,
nullptr, // amax_buf,
......@@ -265,7 +249,7 @@ Error_Type RMSNormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type
nullptr, // mu_buf,
&rsigma_buf,
nullptr, // amax_out_buf,
&wkspace_buf, &barrier_buf, zero_centered_gamma, eps_, sm_margin_,
&wkspace_buf, zero_centered_gamma, eps_, sm_margin_,
false, // is_layer_norm
false // is_fp8
);
......@@ -279,7 +263,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormForwardHandler, RMSNormForwardFFI,
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // wkspace
.Ret<Buffer_Type>() // barrier
.Attr<bool>("zero_centered_gamma")
.Attr<double>("eps")
.Attr<int64_t>("sm_margin"),
......@@ -303,50 +286,34 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hid
auto wgrad_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
// dummy tensor wrappers that will carry workspace size info later
TensorWrapper dummy_work_tensor, dummy_barrier_tensor;
TensorWrapper dummy_dgamma_part_tensor, dummy_dbeta_part_tensor;
TensorWrapper dummy_work_tensor;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin;
auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
// initialize dBeta information here -- layernorm will modify but RMSnorm will not
std::vector<size_t> dbeta_part_shape;
if (is_layer_norm) {
auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype);
auto dbeta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), rsigma_tensor.data(),
nvte_layernorm_bwd(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), rsigma_tensor.data(),
gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
dbeta_tensor.data(), dummy_dgamma_part_tensor.data(),
dummy_dbeta_part_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(),
dummy_barrier_tensor.data());
dbeta_tensor.data(), dummy_work_tensor.data(), num_sm, zero_centered_gamma,
nullptr);
dbeta_part_shape = MakeShapeVector(dummy_dbeta_part_tensor.shape());
} else {
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(), gamma_tensor.data(),
xgrad_tensor.data(), wgrad_tensor.data(), dummy_dgamma_part_tensor.data(),
nullptr, num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data());
dbeta_part_shape = std::vector<size_t>{0, 0};
xgrad_tensor.data(), wgrad_tensor.data(), dummy_work_tensor.data(), num_sm,
zero_centered_gamma, nullptr);
}
auto work_shape = MakeShapeVector(dummy_work_tensor.shape());
auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape());
auto dgamma_part_shape = MakeShapeVector(dummy_dgamma_part_tensor.shape());
return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()),
std::make_pair(barrier_shape, dummy_barrier_tensor.dtype()),
std::make_pair(dgamma_part_shape, dummy_dgamma_part_tensor.dtype()),
std::make_pair(dbeta_part_shape, dummy_dbeta_part_tensor.dtype()));
return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()));
}
void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace_size,
size_t barrier_size, Shape dgamma_part_shape, Shape dbeta_part_shape,
bool zero_centered_gamma, float eps, void *input, DType in_dtype,
void *weight, DType w_dtype, void *ograd, void *workspace,
DType wkspace_dtype, void *barrier, DType barrier_dtype, void *mu,
void *rsigma, void *xgrad, void *wgrad, void *dbeta, void *dgamma_part,
DType dgamma_dtype, void *dbeta_part, DType dbeta_dtype, int sm_margin,
cudaStream_t stream) {
DType wkspace_dtype, void *mu, void *rsigma, void *xgrad, void *wgrad,
void *dbeta, int sm_margin, cudaStream_t stream) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto weight_shape = std::vector<size_t>{hidden_size};
auto intermediates_shape = std::vector<size_t>{batch_size};
......@@ -368,28 +335,23 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace
auto wgrad_tensor = TensorWrapper(wgrad, weight_shape, w_dtype);
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin;
auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
auto workspace_shape = std::vector<size_t>{wkspace_size};
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype);
auto barrier_shape = std::vector<size_t>{barrier_size};
auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype);
auto dgamma_part_tensor = TensorWrapper(dgamma_part, dgamma_part_shape.to_vector(), dgamma_dtype);
if (is_layer_norm) {
auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype);
auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype);
auto dbeta_part_tensor = TensorWrapper(dbeta_part, dbeta_part_shape.to_vector(), dbeta_dtype);
layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), rsigma_tensor.data(),
nvte_layernorm_bwd(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), rsigma_tensor.data(),
gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
dbeta_tensor.data(), dgamma_part_tensor.data(), dbeta_part_tensor.data(),
stream, num_sm, workspace_tensor.data(), barrier_tensor.data());
dbeta_tensor.data(), workspace_tensor.data(), num_sm, zero_centered_gamma,
stream);
} else {
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(), gamma_tensor.data(),
xgrad_tensor.data(), wgrad_tensor.data(), dgamma_part_tensor.data(), stream,
num_sm, workspace_tensor.data(), barrier_tensor.data());
xgrad_tensor.data(), wgrad_tensor.data(), workspace_tensor.data(), num_sm,
zero_centered_gamma, stream);
}
}
......@@ -397,15 +359,11 @@ Error_Type LayerNormBackwardImplFFI(cudaStream_t stream, Buffer_Type *dz_buf, Bu
Buffer_Type *mu_buf, Buffer_Type *rsigma_buf,
Buffer_Type *gamma_buf, Result_Type *xgrad_buf,
Result_Type *wgrad_buf, Result_Type *dbeta_buf,
Result_Type *wkspace_buf, Result_Type *barrier_buf,
Result_Type *dgamma_part_buf, Result_Type *dbeta_part_buf,
bool zero_centered_gamma, double eps_, int64_t sm_margin_,
bool is_layer_norm) {
Result_Type *wkspace_buf, bool zero_centered_gamma, double eps_,
int64_t sm_margin_, bool is_layer_norm) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf->element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf->element_type());
auto wkspace_dtype = convert_ffi_datatype_to_te_dtype((*wkspace_buf)->element_type());
auto barrier_dtype = convert_ffi_datatype_to_te_dtype((*barrier_buf)->element_type());
auto dgamma_part_dtype = convert_ffi_datatype_to_te_dtype((*dgamma_part_buf)->element_type());
auto *ograd = dz_buf->untyped_data();
auto *rsigma = rsigma_buf->untyped_data();
......@@ -414,62 +372,37 @@ Error_Type LayerNormBackwardImplFFI(cudaStream_t stream, Buffer_Type *dz_buf, Bu
auto *xgrad = (*xgrad_buf)->untyped_data();
auto *wgrad = (*wgrad_buf)->untyped_data();
auto *workspace = (*wkspace_buf)->untyped_data();
auto *barrier = (*barrier_buf)->untyped_data();
auto *dgamma_part = (*dgamma_part_buf)->untyped_data();
void *mu = nullptr;
void *dbeta = nullptr;
void *dbeta_part = nullptr;
auto dbeta_part_dtype = DType::kByte;
if (is_layer_norm) {
mu = (*mu_buf).untyped_data();
dbeta = (*dbeta_buf)->untyped_data();
dbeta_part = (*dbeta_part_buf)->untyped_data();
dbeta_part_dtype = convert_ffi_datatype_to_te_dtype((*dbeta_part_buf)->element_type());
}
auto x_size = product(x_buf->dimensions());
auto gamma_size = product(gamma_buf->dimensions());
auto wkspace_size = product((*wkspace_buf)->dimensions());
auto barrier_size = product((*barrier_buf)->dimensions());
auto hidden_size = gamma_size;
auto batch_size = x_size / gamma_size;
Shape dgamma_part_shape;
auto dgamma_part_dims = (*dgamma_part_buf)->dimensions();
std::vector<size_t> dgamma_parts_dims_vector(dgamma_part_dims.begin(), dgamma_part_dims.end());
dgamma_part_shape.from_vector(dgamma_parts_dims_vector);
Shape dbeta_part_shape;
if (is_layer_norm) {
auto dbeta_part_dims = (*dbeta_part_buf)->dimensions();
std::vector<size_t> dbeta_parts_dims_vector(dbeta_part_dims.begin(), dbeta_part_dims.end());
dbeta_part_shape.from_vector(dbeta_parts_dims_vector);
} else {
dbeta_part_shape.from_vector({0, 0});
}
float eps = static_cast<float>(eps_);
int sm_margin = static_cast<int>(sm_margin_);
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
dbeta_part_dtype, sm_margin, stream);
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input,
in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype, mu, rsigma,
xgrad, wgrad, dbeta, sm_margin, stream);
return ffi_with_cuda_error_check();
}
Error_Type LayerNormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf,
Buffer_Type mu_buf, Buffer_Type rsigma_buf, Buffer_Type gamma_buf,
Result_Type xgrad_buf, Result_Type wgrad_buf, Result_Type dbeta_buf,
Result_Type wkspace_buf, Result_Type barrier_buf,
Result_Type dgamma_part_buf, Result_Type dbeta_part_buf,
bool zero_centered_gamma, double eps_, int64_t sm_margin_) {
Result_Type wkspace_buf, bool zero_centered_gamma, double eps_,
int64_t sm_margin_) {
return LayerNormBackwardImplFFI(stream, &dz_buf, &x_buf, &mu_buf, &rsigma_buf, &gamma_buf,
&xgrad_buf, &wgrad_buf, &dbeta_buf, &wkspace_buf, &barrier_buf,
&dgamma_part_buf, &dbeta_part_buf, zero_centered_gamma, eps_,
sm_margin_,
&xgrad_buf, &wgrad_buf, &dbeta_buf, &wkspace_buf,
zero_centered_gamma, eps_, sm_margin_,
true // is_layer_norm
);
}
......@@ -486,9 +419,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormBackwardHandler, LayerNormBackwardFFI,
.Ret<Buffer_Type>() // wgrad
.Ret<Buffer_Type>() // dbeta
.Ret<Buffer_Type>() // wkspace
.Ret<Buffer_Type>() // barrier
.Ret<Buffer_Type>() // dgamma_part
.Ret<Buffer_Type>() // dbeta_part
.Attr<bool>("zero_centered_gamma")
.Attr<double>("eps")
.Attr<int64_t>("sm_margin"),
......@@ -497,15 +427,12 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormBackwardHandler, LayerNormBackwardFFI,
Error_Type RMSNormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf,
Buffer_Type rsigma_buf, Buffer_Type gamma_buf, Result_Type xgrad_buf,
Result_Type wgrad_buf, Result_Type wkspace_buf,
Result_Type barrier_buf, Result_Type dgamma_part_buf,
bool zero_centered_gamma, double eps_, int64_t sm_margin_) {
return LayerNormBackwardImplFFI(stream, &dz_buf, &x_buf,
nullptr, // mu_buf
&rsigma_buf, &gamma_buf, &xgrad_buf, &wgrad_buf,
nullptr, // dbeta_buf,
&wkspace_buf, &barrier_buf, &dgamma_part_buf,
nullptr, // dbeta_part_buf,
zero_centered_gamma, eps_, sm_margin_,
&wkspace_buf, zero_centered_gamma, eps_, sm_margin_,
false // is_layer_norm
);
}
......@@ -520,8 +447,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormBackwardHandler, RMSNormBackwardFFI,
.Ret<Buffer_Type>() // xgrad
.Ret<Buffer_Type>() // wgrad
.Ret<Buffer_Type>() // wkspace
.Ret<Buffer_Type>() // barrier
.Ret<Buffer_Type>() // dgamma_part
.Attr<bool>("zero_centered_gamma")
.Attr<double>("eps")
.Attr<int64_t>("sm_margin"),
......@@ -540,7 +465,6 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque
auto *rsigma = buffers[8];
auto *amax_out = buffers[9];
auto *workspace = buffers[10];
auto *barrier = buffers[11];
NVTE_CHECK(amax_out == amax,
"amax not bound to amax_out in TE/JAX LayerNormForwardFP8 primitive");
......@@ -548,21 +472,18 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque
auto batch_size = desc.batch_size;
auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
sm_margin, stream);
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input,
in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype,
mu, rsigma, amax, scale, scale_inv, sm_margin, stream);
}
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
......@@ -573,7 +494,6 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto *mu = buffers[4];
auto *rsigma = buffers[5];
auto *workspace = buffers[6];
auto *barrier = buffers[7];
float *amax = nullptr;
float *scale = nullptr;
......@@ -583,20 +503,17 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto batch_size = desc.batch_size;
auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps;
auto out_dtype = in_dtype;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
sm_margin, stream);
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input,
in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype,
mu, rsigma, amax, scale, scale_inv, sm_margin, stream);
}
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
......@@ -605,15 +522,9 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto batch_size = desc.batch_size;
auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto dgamma_part_shape = desc.dgamma_part_shape;
auto dbeta_part_shape = desc.dbeta_part_shape;
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
auto barrier_dtype = desc.barrier_dtype;
auto dgamma_part_dtype = desc.dgamma_part_dtype;
auto dbeta_part_dtype = desc.dbeta_part_dtype;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
......@@ -627,15 +538,10 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto *wgrad = buffers[6];
auto *dbeta = buffers[7];
auto *workspace = buffers[8];
auto *barrier = buffers[9];
auto *dgamma_part = buffers[10];
auto *dbeta_part = buffers[11];
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
dbeta_part_dtype, sm_margin, stream);
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input,
in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype, mu, rsigma,
xgrad, wgrad, dbeta, sm_margin, stream);
}
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
......@@ -648,7 +554,6 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
auto *rsigma = buffers[6];
auto *amax_out = buffers[7];
auto *workspace = buffers[8];
auto *barrier = buffers[9];
NVTE_CHECK(amax_out == amax, "amax not bound to amax_out in TE/JAX RSMNormForwardFP8 primitive.");
void *bias = nullptr;
......@@ -658,20 +563,17 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
auto batch_size = desc.batch_size;
auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
sm_margin, stream);
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input,
in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype,
mu, rsigma, amax, scale, scale_inv, sm_margin, stream);
}
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
......@@ -680,7 +582,6 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz
auto *output = buffers[2];
auto *rsigma = buffers[3];
auto *workspace = buffers[4];
auto *barrier = buffers[5];
void *bias = nullptr;
void *mu = nullptr;
......@@ -692,20 +593,17 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz
auto batch_size = desc.batch_size;
auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto out_dtype = in_dtype;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
sm_margin, stream);
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input,
in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype,
mu, rsigma, amax, scale, scale_inv, sm_margin, stream);
}
void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
......@@ -716,36 +614,24 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si
auto *xgrad = buffers[4];
auto *wgrad = buffers[5];
auto *workspace = buffers[6];
auto *barrier = buffers[7];
auto *dgamma_part = buffers[8];
void *mu = nullptr;
void *dbeta = nullptr;
void *dbeta_part = nullptr;
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto batch_size = desc.batch_size;
auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto dgamma_part_shape = desc.dgamma_part_shape;
Shape dbeta_part_shape;
dbeta_part_shape.from_vector({0, 0});
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
auto barrier_dtype = desc.barrier_dtype;
auto dgamma_part_dtype = desc.dgamma_part_dtype;
auto dbeta_part_dtype = DType::kByte;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
dbeta_part_dtype, sm_margin, stream);
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input,
in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype, mu, rsigma,
xgrad, wgrad, dbeta, sm_margin, stream);
}
} // namespace jax
......
......@@ -32,24 +32,17 @@ pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shap
return PackOpaque(desc);
}
pybind11::bytes PackCustomCallNormDescriptor(
size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size,
const std::vector<size_t> &dgamma_part_shape, const std::vector<size_t> &dbeta_part_shape,
DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype,
DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin) {
pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
size_t wkspace_size, DType x_dtype, DType w_dtype,
DType wkspace_dtype, bool zero_centered_gamma,
float eps, int sm_margin) {
CustomCallNormDescriptor desc{};
desc.batch_size = batch_size;
desc.hidden_size = hidden_size;
desc.wkspace_size = wkspace_size;
desc.barrier_size = barrier_size;
desc.dgamma_part_shape.from_vector(dgamma_part_shape);
desc.dbeta_part_shape.from_vector(dbeta_part_shape);
desc.x_dtype = x_dtype;
desc.w_dtype = w_dtype;
desc.wkspace_dtype = wkspace_dtype;
desc.barrier_dtype = barrier_dtype;
desc.dgamma_part_dtype = dgamma_part_dtype;
desc.dbeta_part_dtype = dbeta_part_dtype;
desc.zero_centered_gamma = zero_centered_gamma;
desc.eps = eps;
desc.sm_margin = sm_margin;
......
......@@ -10,9 +10,8 @@
#include <transformer_engine/cast.h>
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/layer_norm.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/rmsnorm.h>
#include <transformer_engine/softmax.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transpose.h>
......
......@@ -353,24 +353,23 @@ std::vector<paddle::Tensor> te_layernorm_fwd_fp8(const paddle::Tensor &input,
const_cast<void *>(GetDataPtr<float>(scale, index)), GetDataPtr<float>(scale_inv, index));
auto mu_cu = MakeNvteTensor(mu);
auto rsigma_cu = MakeNvteTensor(rsigma);
TensorWrapper workspace, barrier;
TensorWrapper workspace;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
// This call populates workspace and barrier tensors with the required config
const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(),
rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data());
// This call populates workspace tensor with the required config
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sm - sm_margin,
zero_centered_gamma, input.stream());
// Fill workspace and barrier
// Fill workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place());
auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true);
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype());
// Actual call to fwd kernel
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(),
rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data());
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sm - sm_margin,
zero_centered_gamma, input.stream());
return {ln_out, mu, rsigma};
}
......@@ -394,24 +393,23 @@ std::vector<paddle::Tensor> te_layernorm_fwd(const paddle::Tensor &input,
auto z_cu = MakeNvteTensor(ln_out.data(), {N, H}, Int2NvteDType(otype));
auto mu_cu = MakeNvteTensor(mu);
auto rsigma_cu = MakeNvteTensor(rsigma);
TensorWrapper workspace, barrier;
TensorWrapper workspace;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
// This call populates workspace and barrier tensors with the required config
const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(),
rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data());
// This call populates workspace tensor with the required config
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sm - sm_margin,
zero_centered_gamma, input.stream());
// Fill workspace and barrier
// Fill workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place());
auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true);
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype());
// Actual call to fwd kernel
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(),
rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data());
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sm - sm_margin,
zero_centered_gamma, input.stream());
return {ln_out, mu, rsigma};
}
......@@ -424,7 +422,7 @@ std::vector<paddle::Tensor> te_layernorm_bwd(const paddle::Tensor &dz, const pad
auto dgamma = paddle::empty_like(gamma, gamma.dtype(), gamma.place());
auto dbeta = paddle::empty_like(gamma, gamma.dtype(), gamma.place());
TensorWrapper workspace, barrier, dgamma_part, dbeta_part;
TensorWrapper workspace;
auto dz_cu = MakeNvteTensor(dz);
auto x_cu = MakeNvteTensor(x);
......@@ -438,25 +436,18 @@ std::vector<paddle::Tensor> te_layernorm_bwd(const paddle::Tensor &dz, const pad
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
// This call populates tensors with the required config.
const auto bwd_fun = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(),
dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(), dz.stream(),
num_sm - sm_margin, workspace.data(), barrier.data());
nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(),
num_sm - sm_margin, zero_centered_gamma, dz.stream());
// Alloc space for Tensors.
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), x.place());
auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), x.place(), true);
auto dgamma_part_data = AllocateSpace(dgamma_part.shape(), dgamma_part.dtype(), x.place());
auto dbeta_part_data = AllocateSpace(dbeta_part.shape(), dbeta_part.dtype(), x.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype());
dgamma_part = MakeNvteTensor(dgamma_part_data.data(), dgamma_part.shape(), dgamma_part.dtype());
dbeta_part = MakeNvteTensor(dbeta_part_data.data(), dbeta_part.shape(), dbeta_part.dtype());
// Actual call to bwd kernel.
bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(),
dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(), dz.stream(),
num_sm - sm_margin, workspace.data(), barrier.data());
nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(),
num_sm - sm_margin, zero_centered_gamma, dz.stream());
return {dx, dgamma, dbeta};
}
......@@ -477,24 +468,21 @@ std::vector<paddle::Tensor> te_rmsnorm_fwd(const paddle::Tensor &input,
auto gamma_cu = MakeNvteTensor(weight);
auto z_cu = MakeNvteTensor(ln_out.data(), {N, H}, Int2NvteDType(otype));
auto rsigma_cu = MakeNvteTensor(rsigma);
TensorWrapper workspace, barrier;
TensorWrapper workspace;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
// This call populates workspace and barrier tensors with the required config
// This call populates workspace tensor with the required config
nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(),
input.stream(), num_sm - sm_margin, workspace.data(), barrier.data());
workspace.data(), num_sm - sm_margin, zero_centered_gamma, input.stream());
// Fill workspace and barrier
// Fill workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place());
auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true);
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype());
// Actual call to fwd kernel
nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(),
input.stream(), num_sm - sm_margin, workspace.data(), barrier.data());
workspace.data(), num_sm - sm_margin, zero_centered_gamma, input.stream());
return {ln_out, rsigma};
}
......@@ -521,23 +509,21 @@ std::vector<paddle::Tensor> te_rmsnorm_fwd_fp8(const paddle::Tensor &input,
ln_out.data(), {N, H}, Int2NvteDType(otype), GetDataPtr<float>(amax, index),
const_cast<void *>(GetDataPtr<float>(scale, index)), GetDataPtr<float>(scale_inv, index));
auto rsigma_cu = MakeNvteTensor(rsigma);
TensorWrapper workspace, barrier;
TensorWrapper workspace;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
// This call populates workspace and barrier tensors with the required config
// This call populates workspace tensor with the required config
nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(),
input.stream(), num_sm - sm_margin, workspace.data(), barrier.data());
workspace.data(), num_sm - sm_margin, zero_centered_gamma, input.stream());
// Fill workspace and barrier
// Fill workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place());
auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true);
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype());
// Actual call to fwd kernel
nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(),
input.stream(), num_sm - sm_margin, workspace.data(), barrier.data());
workspace.data(), num_sm - sm_margin, zero_centered_gamma, input.stream());
return {ln_out, rsigma};
}
......@@ -550,7 +536,7 @@ std::vector<paddle::Tensor> te_rmsnorm_bwd(const paddle::Tensor &dz, const paddl
auto dx = paddle::empty_like(x, x.dtype(), x.place());
auto dgamma = paddle::empty_like(gamma, gamma.dtype(), gamma.place());
TensorWrapper workspace, barrier, dgamma_part;
TensorWrapper workspace;
auto dz_cu = MakeNvteTensor(dz);
auto x_cu = MakeNvteTensor(x);
......@@ -563,21 +549,17 @@ std::vector<paddle::Tensor> te_rmsnorm_bwd(const paddle::Tensor &dz, const paddl
// This call populates tensors with the required config.
nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(),
dgamma_cu.data(), dgamma_part.data(), dz.stream(), num_sm - sm_margin,
workspace.data(), barrier.data());
dgamma_cu.data(), workspace.data(), num_sm - sm_margin, zero_centered_gamma,
dz.stream());
// Alloc space for Tensors.
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), x.place());
auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), x.place(), true);
auto dgamma_part_data = AllocateSpace(dgamma_part.shape(), dgamma_part.dtype(), x.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype());
dgamma_part = MakeNvteTensor(dgamma_part_data.data(), dgamma_part.shape(), dgamma_part.dtype());
// Actual call to bwd kernel.
nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(),
dgamma_cu.data(), dgamma_part.data(), dz.stream(), num_sm - sm_margin,
workspace.data(), barrier.data());
dgamma_cu.data(), workspace.data(), num_sm - sm_margin, zero_centered_gamma,
dz.stream());
return {dx, dgamma};
}
......
......@@ -28,11 +28,10 @@
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/fused_rope.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/layer_norm.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/padding.h>
#include <transformer_engine/permutation.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/rmsnorm.h>
#include <transformer_engine/softmax.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transpose.h>
......
......@@ -19,7 +19,7 @@ std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
auto dx = at::empty_like(x_);
auto dgamma = at::empty_like(gamma_);
auto dbeta = at::empty_like(gamma_);
transformer_engine::TensorWrapper workspace, barrier, dgamma_part, dbeta_part;
transformer_engine::TensorWrapper workspace;
auto dz_cu = makeTransformerEngineTensor(dz_);
auto x_cu = makeTransformerEngineTensor(x_);
......@@ -31,32 +31,21 @@ std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
auto dbeta_cu = makeTransformerEngineTensor(dbeta);
// This call populates tensors with the required config.
const auto bwd_fun = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(),
dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(),
at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(),
barrier.data());
nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
// Alloc space for Tensors.
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true);
auto dgamma_part_data = allocateSpace(dgamma_part.shape(), dgamma_part.dtype());
auto dbeta_part_data = allocateSpace(dbeta_part.shape(), dbeta_part.dtype());
workspace =
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), barrier.shape(), barrier.dtype());
dgamma_part = makeTransformerEngineTensor(dgamma_part_data.data_ptr(), dgamma_part.shape(),
dgamma_part.dtype());
dbeta_part = makeTransformerEngineTensor(dbeta_part_data.data_ptr(), dbeta_part.shape(),
dbeta_part.dtype());
// Actual call to bwd kernel.
bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(),
dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(),
at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(),
barrier.data());
nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
return {dx, dgamma, dbeta};
}
......@@ -88,9 +77,6 @@ std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(
const auto &weight_ = weight.contiguous();
const auto &bias_ = bias.contiguous();
// Choose kernel implementation
const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
// Tensor dimensions
size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1));
......@@ -113,24 +99,22 @@ std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(
auto rsigma_cu = makeTransformerEngineTensor(rsigma);
// Query workspace sizes
transformer_engine::TensorWrapper workspace, barrier;
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(),
rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(),
barrier.data());
transformer_engine::TensorWrapper workspace;
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
// Allocate workspaces
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true);
workspace =
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), barrier.shape(), barrier.dtype());
// Launch kernel
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(),
rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(),
barrier.data());
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
return {ln_out, mu, rsigma};
}
......@@ -194,7 +178,7 @@ std::vector<at::Tensor> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
auto dx = at::empty_like(x_);
auto dgamma = at::empty_like(gamma_);
transformer_engine::TensorWrapper workspace, barrier, dgamma_part;
transformer_engine::TensorWrapper workspace;
auto dz_cu = makeTransformerEngineTensor(dz_);
auto x_cu = makeTransformerEngineTensor(x_);
......@@ -204,27 +188,21 @@ std::vector<at::Tensor> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
auto dgamma_cu = makeTransformerEngineTensor(dgamma);
// This call populates tensors with the required config.
const auto bwd_fun = zero_centered_gamma ? nvte_rmsnorm1p_bwd : nvte_rmsnorm_bwd;
bwd_fun(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(),
dgamma_cu.data(), dgamma_part.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(),
barrier.data());
nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(),
dgamma_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
// Alloc space for Tensors.
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true);
auto dgamma_part_data = allocateSpace(dgamma_part.shape(), dgamma_part.dtype());
workspace =
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), barrier.shape(), barrier.dtype());
dgamma_part = makeTransformerEngineTensor(dgamma_part_data.data_ptr(), dgamma_part.shape(),
dgamma_part.dtype());
// Actual call to bwd kernel.
bwd_fun(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(),
dgamma_cu.data(), dgamma_part.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(),
barrier.data());
nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(),
dgamma_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
return {dx, dgamma};
}
......@@ -255,9 +233,6 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, const a
const int scale_inv_offset) {
using namespace transformer_engine;
// Choose kernel implementation
const auto func = zero_centered_gamma ? nvte_rmsnorm1p_fwd : nvte_rmsnorm_fwd;
// Tensor dimensions
size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1));
......@@ -277,24 +252,22 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, const a
auto rsigma_cu = makeTransformerEngineTensor(rsigma);
// Query workspace sizes
transformer_engine::TensorWrapper workspace, barrier;
func(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(),
at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(),
barrier.data());
transformer_engine::TensorWrapper workspace;
nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(),
workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
// Allocate workspaces
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true);
workspace =
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), barrier.shape(), barrier.dtype());
// Launch kernel
func(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(),
at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(),
barrier.data());
nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(),
workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
return {ln_out, rsigma};
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment