Unverified Commit 1f306024 authored by Lakhinder Walia's avatar Lakhinder Walia Committed by GitHub
Browse files

fast_gelu: minor code reorg to enhance ref & gpu performance (#1162)

parent 1b0fbaeb
......@@ -458,27 +458,29 @@ struct FastGelu
template <>
__host__ void operator()<float, float>(float& y, const float& x) const
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u);
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
y = x * cdf;
// const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
const float c1 = -2.0 * 0.035677f;
const float c2 = -2.0 * 0.797885f;
const float u = x * (c1 * x * x + c2);
const float emu = exp(u);
y = x / (1.f + emu);
}
// device code, use lower precision "__expf" and "rcp"
template <>
__device__ void operator()<float, float>(float& y, const float& x) const
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = __expf(-u);
// const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float c1 = -2.0 * 0.035677f;
const float c2 = -2.0 * 0.797885f;
const float u = x * (c1 * x * x + c2);
const float emu = __expf(u);
#if !CK_WORKAROUND_SWDEV_383542
const float cdf = 0.5f + 0.5f * (2.f * __frcp_rn(1.f + emu) - 1.f);
y = x * __frcp_rn(1.f + emu);
#else
const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f);
y = x * __ocml_native_recip_f32(1.f + emu);
#endif
y = x * cdf;
}
template <>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment