"vscode:/vscode.git/clone" did not exist on "cfb3b75d63bf097fec2efb0835d7bf62a0bd3492"
Commit 8df1b6b8 authored by hubertlu-tw
Browse files

Fix NaN issues in FusedRMSNorm

parent 28c5638d
...@@ -712,7 +712,7 @@ void cuComputeGradInput( ...@@ -712,7 +712,7 @@ void cuComputeGradInput(
#ifndef __HIP_PLATFORM_HCC__ #ifndef __HIP_PLATFORM_HCC__
int l = 4*thrx; int l = 4*thrx;
for (; l+3 < n2; l+=4*numx) { for (; l+3 < n2; l+=4*numx) {
for (int k = 0; k < 4; ++k) { for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l+k]); const U c_h = static_cast<U>(k_input[l+k]);
const U c_loss = static_cast<U>(k_dout[l+k]); const U c_loss = static_cast<U>(k_dout[l+k]);
if (!rms_only) { if (!rms_only) {
...@@ -741,8 +741,12 @@ void cuComputeGradInput( ...@@ -741,8 +741,12 @@ void cuComputeGradInput(
const U gamma_idx = static_cast<U>((idx<n2) ? gamma[idx] : V(0)); const U gamma_idx = static_cast<U>((idx<n2) ? gamma[idx] : V(0));
const U c_h = static_cast<U>((idx<n2) ? k_input[idx] : T(0)); const U c_h = static_cast<U>((idx<n2) ? k_input[idx] : T(0));
const U c_loss = static_cast<U>((idx<n2) ? k_dout[idx] : V(0)); const U c_loss = static_cast<U>((idx<n2) ? k_dout[idx] : V(0));
sum_loss1 += c_loss * gamma_idx; if (!rms_only) {
sum_loss2 += c_loss * gamma_idx * (c_h - c_mean) * c_invvar; sum_loss1 += c_loss * gamma_idx;
sum_loss2 += c_loss * gamma_idx * (c_h - c_mean) * c_invvar;
} else {
sum_loss2 += c_loss * gamma_idx * (c_h) * c_invvar;
}
} }
#endif #endif
} else { } else {
...@@ -775,8 +779,12 @@ void cuComputeGradInput( ...@@ -775,8 +779,12 @@ void cuComputeGradInput(
int idx = l + thrx; int idx = l + thrx;
const U c_h = static_cast<U>((idx<n2) ? k_input[idx] : T(0)); const U c_h = static_cast<U>((idx<n2) ? k_input[idx] : T(0));
const U c_loss = static_cast<U>((idx<n2) ? k_dout[idx] : V(0)); const U c_loss = static_cast<U>((idx<n2) ? k_dout[idx] : V(0));
sum_loss1 += c_loss; if (!rms_only) {
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
} else {
sum_loss2 += c_loss * (c_h) * c_invvar;
}
} }
#endif #endif
} }
...@@ -895,7 +903,7 @@ void HostApplyLayerNorm( ...@@ -895,7 +903,7 @@ void HostApplyLayerNorm(
output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta, warp_size); output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta, warp_size);
} }
// TODO: Optimize HostRMSNormGradient for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files // Optimize HostRMSNormGradient for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files
template<typename T, typename U, typename V=T> template<typename T, typename U, typename V=T>
void HostApplyRMSNorm( void HostApplyRMSNorm(
V* output, V* output,
...@@ -1070,7 +1078,7 @@ void HostLayerNormGradient( ...@@ -1070,7 +1078,7 @@ void HostLayerNormGradient(
grad_input, grad_input,
false); false);
} }
// TODO: Optimize HostRMSNormGradient for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files // Optimize HostRMSNormGradient for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files
template<typename T, typename U=float, typename V=T> template<typename T, typename U=float, typename V=T>
void HostRMSNormGradient( void HostRMSNormGradient(
const V* dout, const V* dout,
...@@ -1220,3 +1228,4 @@ void cuda_rms_norm_gradient( ...@@ -1220,3 +1228,4 @@ void cuda_rms_norm_gradient(
) )
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment