Commit 5749aec6 authored by wenjh's avatar wenjh
Browse files

[DTK-25.04] Workaround compiler bugs.



Ref params of rmsnorm will make program corruption with 'nil' error.
Signed-off-by: wenjh's avatarwenjh <wenjh@sugon.com>
parent 8de7a1ce
......@@ -13,8 +13,9 @@ using namespace transformer_engine::normalization;
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
int BYTES_PER_LDG_MAIN, int BYTES_PER_LDG_FINAL>
void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
void launch_tuned_(LaunchParams<BackwardKernelParams>* plaunch_params,
const bool configure_params) { // NOLINT(*)
LaunchParams<BackwardKernelParams>& launch_params = *plaunch_params;
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
auto kernel = &rmsnorm_bwd_tuned_kernel<Kernel_traits>;
......@@ -78,8 +79,9 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG_MAIN,
int BYTES_PER_LDG_FINAL>
void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
void launch_general_(LaunchParams<BackwardKernelParams>* plaunch_params,
const bool configure_params) { // NOLINT(*)
LaunchParams<BackwardKernelParams>& launch_params = *plaunch_params;
auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };
// Instantiate kernel
......@@ -144,7 +146,7 @@ void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<NORM_STAGE##KernelParams> &launch_params, const bool configure_params) { \
launch_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, __VA_ARGS__>( \
launch_params, configure_params); \
&launch_params, configure_params); \
} \
REGISTER_NORM_BASE( \
NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \
......
......@@ -13,8 +13,9 @@ using namespace transformer_engine::normalization;
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
int BYTES_PER_LDG>
void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
void launch_tuned_(LaunchParams<ForwardKernelParams>* plaunch_params,
const bool configure_params) { // NOLINT(*)
LaunchParams<ForwardKernelParams>& launch_params = *plaunch_params;
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>;
auto kernel = &rmsnorm_fwd_tuned_kernel<Kernel_traits>;
......@@ -65,8 +66,9 @@ void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG>
void launch_general_(LaunchParams<ForwardKernelParams> &launch_params,
void launch_general_(LaunchParams<ForwardKernelParams>* plaunch_params,
const bool configure_params) { // NOLINT(*)
LaunchParams<ForwardKernelParams>& launch_params = *plaunch_params;
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
1, WARPS_M, WARPS_N, BYTES_PER_LDG>;
auto kernel = &rmsnorm_fwd_general_kernel<Kernel_traits>;
......@@ -114,7 +116,7 @@ void launch_general_(LaunchParams<ForwardKernelParams> &launch_params,
norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<NORM_STAGE##KernelParams> &launch_params, const bool configure_params) { \
launch_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, __VA_ARGS__>( \
launch_params, configure_params); \
&launch_params, configure_params); \
} \
REGISTER_NORM_BASE( \
NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment