Commit e38e61b6 authored by Chao Liu
Browse files

update fastgelu

parent b4f96931
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#pragma once #pragma once
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -225,43 +226,26 @@ struct AddHardswish ...@@ -225,43 +226,26 @@ struct AddHardswish
}; };
}; };
// C = A * B
// E = FastGelu(C + D) // E = FastGelu(C + D)
struct AddFastGelu struct AddFastGelu
{ {
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__ __device__ static constexpr float GetFastGeLU(float x)
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u);
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
return x * cdf;
}
template <typename T>
static inline constexpr bool is_valid_param_type_v =
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;
template <typename E, typename C, typename D> template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
{
static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
is_valid_param_type_v<D>);
const float y = GetFastGeLU(type_convert<float>(c) + type_convert<float>(d)); template<>
__host__ __device__ constexpr void operator<float, float, float>()(float& e, const float& c, const float& d) const
{
const float x = c + d;
e = type_convert<E>(y); FastGelu{}.template operator()<float, float>(e, x);
} }
template <typename D> template<>
__host__ __device__ constexpr void operator()(float& e, const float& c, const D& d) const __host__ __device__ constexpr void operator<half_t, half_t, half_t>()(half_t& e, const half_t& c, const half_t& d) const
{ {
static_assert(is_valid_param_type_v<D>); const half_t x = c + d;
e = GetFastGeLU(c + type_convert<float>(d)); ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
} }
}; };
......
...@@ -197,6 +197,8 @@ struct Relu ...@@ -197,6 +197,8 @@ struct Relu
// https://paperswithcode.com/method/gelu // https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
// host code uses the higher-accuracy "exp" and "div"
// device code uses the lower-accuracy "__expf" and "rcp" functions
struct FastGelu struct FastGelu
{ {
template <typename Y, typename X> template <typename Y, typename X>
...@@ -205,7 +207,6 @@ struct FastGelu ...@@ -205,7 +207,6 @@ struct FastGelu
template <typename Y, typename X> template <typename Y, typename X>
__device__ void operator()(Y& y, const X& x) const; __device__ void operator()(Y& y, const X& x) const;
// host code, use higher precision "exp" and "div"
template <> template <>
__host__ void operator()<float, float>(float& y, const float& x) const __host__ void operator()<float, float>(float& y, const float& x) const
{ {
...@@ -216,23 +217,36 @@ struct FastGelu ...@@ -216,23 +217,36 @@ struct FastGelu
y = x * cdf; y = x * cdf;
} }
template <>
__host__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<half_t>(y_f);
}
// device code, use lower precision "__expf" and "rcp" // device code, use lower precision "__expf" and "rcp"
template <> template <>
__device__ void operator()<float, float>(float& y, const float& x) const __device__ void operator()<float, float>(float& y, const float& x) const
{ {
#if 0
const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885)); const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885));
const float emu = __expf(-u); const float emu = __expf(-u);
const float cdf = float(0.5) + float(0.5) * (float(2) * __ocml_native_recip_f32(float(1) + emu) - float(1)); const float cdf = float(0.5) + float(0.5) * (float(2) * __ocml_native_recip_f32(float(1) + emu) - float(1));
y = x * cdf; y = x * cdf;
#else }
const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885));
const float emu = exp(-u);
const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
y = x * cdf; // device code, use lower precision "__expf" and "rcp"
#endif template <>
__device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<half_t>(y_f);
} }
}; };
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment