/************************************************************************* * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ #pragma once #include #include #include #include #include #include #include #include #include #include #include "../test_common.h" namespace test { namespace { enum NormType { LayerNorm, RMSNorm }; std::map normToString = { {NormType::LayerNorm, "LayerNorm"}, {NormType::RMSNorm, "RmsNorm"} }; template void compute_ref_stats(NormType norm_type, const InputType *data, float *mu, float *rsigma, const size_t N, const size_t H, const double epsilon){ using compute_t = float; compute_t current, m; for (size_t i = 0; i < N; ++i) { compute_t sum = 0; for (size_t j = 0; j < H; ++j) { sum += static_cast(data[i * H + j]); } if (norm_type == LayerNorm){ mu[i] = sum / H; m = mu[i]; } else { m = 0;} compute_t sum_sq = 0; for (size_t j = 0; j < H; ++j) { current = static_cast(data[i * H + j]); sum_sq += (current - m) * (current - m); } #ifdef __HIP_PLATFORM_AMD__ rsigma[i] = 1.0/sqrtf((sum_sq / H) + epsilon); #else rsigma[i] = rsqrtf((sum_sq / H) + epsilon); #endif } } template inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype) { using compute_t = float; // Zero-centered gamma in weight dtype is only supported in CuDNN backend currently // Remove the use_cudnn check here when it is supported by both backends. const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype; #if FP4_TYPE_SUPPORTED if constexpr (std::is_same_v || std::is_same_v || std::is_same_v){ #else if constexpr (std::is_same_v || std::is_same_v){ #endif compute_t g = static_cast(gamma); if (zero_centered_gamma) { g += static_cast(1.f); } return g; } else { if (zero_centered_gamma_in_weight_dtype){ compute_t g = static_cast(0.f); #ifndef __HIP_PLATFORM_AMD__ InputType gi = gamma; if (zero_centered_gamma) { gi = gi + static_cast(1.f); } g = static_cast(gi); #else if (zero_centered_gamma) { g += static_cast(1.f); } #endif return g; } else { compute_t g = static_cast(gamma); if (zero_centered_gamma) { g += static_cast(1.f); } return g; } } } template void compute_ref_output(NormType norm_type, const InputType *data, const InputType *gamma, const InputType *beta, OutputType* output, const float *mu, const float *rsigma, const size_t N, const size_t H, float *amax, float scale, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype) { using compute_t = float; compute_t current_max = -1e100; for (size_t i = 0; i < N; ++i) { for (size_t j = 0; j < H; ++j) { compute_t current = static_cast(data[i * H + j]); compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype); compute_t tmp; if (norm_type == LayerNorm) { tmp = (current - mu[i]) * rsigma[i] * g + static_cast(beta[j]); } else { // RMSNorm tmp = current * rsigma[i] * g; } // Write output (scaled only for fp8 paths) output[i * H + j] = static_cast(tmp * scale); // amax semantics: // - fp8_out (scale != 1): amax on pre-scale compute value 'tmp' // - non-fp8_out (scale == 1): amax on value converted to OutputType (e.g., bf16) if (scale != 1.f) { current_max = fmaxf(current_max, fabsf(tmp)); } else { OutputType out_t_val = static_cast(tmp); current_max = fmaxf(current_max, fabsf(static_cast(out_t_val))); } } } if (amax) { *amax = current_max; } } template void compute_ref_backward(const NormType norm_type, const OutputType *output_grad, const OutputType *add, const InputType *data, const float *mu, const float *rsigma, const InputType *gamma, InputType *data_grad, InputType *gamma_grad, InputType *beta_grad, const size_t N, const size_t H, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype) { using compute_t = float; std::vector dgamma(H, 0.f); std::vector dbeta(H, 0.f); for (size_t i = 0 ; i < N; ++i) { // Reductions auto local_mu = (norm_type == LayerNorm) ? mu[i] : 0.; compute_t mdy = 0, mdyy = 0; for (size_t j = 0; j < H; ++j) { const compute_t x = static_cast(data[i * H + j]); const compute_t y = (x - local_mu) * rsigma[i]; compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype); const compute_t dz = static_cast(output_grad[i * H + j]); const compute_t dy = g * dz; dgamma[j] += y * dz; if (norm_type == LayerNorm) { dbeta[j] += dz; mdy += dy; } mdyy += dy * y; } mdy /= H; mdyy /= H; // Input grads for (size_t j = 0; j < H; ++j) { const compute_t x = static_cast(data[i * H + j]); const compute_t y = (x - local_mu) * rsigma[i]; compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype); const compute_t dz = static_cast(output_grad[i * H + j]); const compute_t dy = g * dz; const compute_t a = static_cast(add[i * H + j]); const compute_t dx = a + rsigma[i] * (dy - mdyy * y - mdy); data_grad[i * H + j] = static_cast(dx); } } // Weight grads for (size_t j = 0; j < H; ++j) gamma_grad[j] = static_cast(dgamma[j]); if (norm_type == LayerNorm) for (size_t j = 0; j < H; ++j) beta_grad[j] = static_cast(dbeta[j]); } } // namespace } // namespace test