Commit 8df1b6b8 authored by hubertlu-tw's avatar hubertlu-tw
Browse files

Fix NaN issues in FusedRMSNorm

parent 28c5638d
......@@ -741,8 +741,12 @@ void cuComputeGradInput(
const U gamma_idx = static_cast<U>((idx<n2) ? gamma[idx] : V(0));
const U c_h = static_cast<U>((idx<n2) ? k_input[idx] : T(0));
const U c_loss = static_cast<U>((idx<n2) ? k_dout[idx] : V(0));
if (!rms_only) {
sum_loss1 += c_loss * gamma_idx;
sum_loss2 += c_loss * gamma_idx * (c_h - c_mean) * c_invvar;
} else {
sum_loss2 += c_loss * gamma_idx * (c_h) * c_invvar;
}
}
#endif
} else {
......@@ -775,8 +779,12 @@ void cuComputeGradInput(
int idx = l + thrx;
const U c_h = static_cast<U>((idx<n2) ? k_input[idx] : T(0));
const U c_loss = static_cast<U>((idx<n2) ? k_dout[idx] : V(0));
if (!rms_only) {
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
} else {
sum_loss2 += c_loss * (c_h) * c_invvar;
}
}
#endif
}
......@@ -895,7 +903,7 @@ void HostApplyLayerNorm(
output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta, warp_size);
}
// TODO: Optimize HostRMSNormGradient for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files
// Optimize HostRMSNormGradient for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files
template<typename T, typename U, typename V=T>
void HostApplyRMSNorm(
V* output,
......@@ -1070,7 +1078,7 @@ void HostLayerNormGradient(
grad_input,
false);
}
// TODO: Optimize HostRMSNormGradient for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files
// Optimize HostRMSNormGradient for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files
template<typename T, typename U=float, typename V=T>
void HostRMSNormGradient(
const V* dout,
......@@ -1220,3 +1228,4 @@ void cuda_rms_norm_gradient(
)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment