Unverified Commit 1f306024 authored by Lakhinder Walia's avatar Lakhinder Walia Committed by GitHub
Browse files

fast_gelu: minor code reorg to enhance ref & gpu performance (#1162)

parent 1b0fbaeb
...@@ -458,27 +458,29 @@ struct FastGelu ...@@ -458,27 +458,29 @@ struct FastGelu
template <> template <>
__host__ void operator()<float, float>(float& y, const float& x) const __host__ void operator()<float, float>(float& y, const float& x) const
{ {
const float u = 2.f * x * (0.035677f * x * x + 0.797885f); // const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u); const float c1 = -2.0 * 0.035677f;
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f); const float c2 = -2.0 * 0.797885f;
const float u = x * (c1 * x * x + c2);
y = x * cdf; const float emu = exp(u);
y = x / (1.f + emu);
} }
// device code, use lower precision "__expf" and "rcp" // device code, use lower precision "__expf" and "rcp"
template <> template <>
__device__ void operator()<float, float>(float& y, const float& x) const __device__ void operator()<float, float>(float& y, const float& x) const
{ {
const float u = 2.f * x * (0.035677f * x * x + 0.797885f); // const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = __expf(-u); const float c1 = -2.0 * 0.035677f;
const float c2 = -2.0 * 0.797885f;
const float u = x * (c1 * x * x + c2);
const float emu = __expf(u);
#if !CK_WORKAROUND_SWDEV_383542 #if !CK_WORKAROUND_SWDEV_383542
const float cdf = 0.5f + 0.5f * (2.f * __frcp_rn(1.f + emu) - 1.f); y = x * __frcp_rn(1.f + emu);
#else #else
const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f); y = x * __ocml_native_recip_f32(1.f + emu);
#endif #endif
y = x * cdf;
} }
template <> template <>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment