"...composable_kernel.git" did not exist on "823657ed120144943b7db87c07fe3e647128db56"
Commit 4198de5a authored by Chao Liu

fast gelu using builtin function

parent 3a430165
@@ -6,8 +6,8 @@
 using ADataType = F16;
 using BDataType = F16;
 using AccDataType = F32;
-using CShuffleDataType = F32;
-using CDataType = F32; // C matrix doesn't exist in memory, this is used for host verification
+using CShuffleDataType = F16;
+using CDataType = F16; // C matrix doesn't exist in GPU memory, this is used for host verification
 using D0DataType = F16;
 using D1DataType = F16;
 using DsDataType = ck::Tuple<D0DataType, D1DataType>;
...
@@ -124,7 +124,7 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC
     if(config.do_verification)
     {
-        Tensor<AccDataType> c_m_n({M, N});
+        Tensor<CDataType> c_m_n({M, N});
         auto ref_gemm = ReferenceGemmInstance{};
         auto ref_invoker = ref_gemm.MakeInvoker();
...
@@ -168,6 +168,9 @@
 // tuning parameter
 #define CK_WORKAROUND_SWDEV_325164 0
+// workaround: compiler not emitting reciprocal instruction from __frcp_rn()
+#define CK_WORKAROUND_SWDEV_XXXXXX_FRCP_RN 1
 // flag to enable (1) or disable (0) the debugging output in some kernels
 #define DEBUG_LOG 0
...
@@ -280,6 +280,7 @@ struct AddHardswish
 };
 };

+#if 0
 // C = A * B
 // E = FastGelu(C + D)
 struct AddFastGelu
@@ -319,6 +320,32 @@ struct AddFastGelu
         e = GetFastGeLU(c + type_convert<float>(d));
     }
 };
+#else
+// E = FastGelu(C + D)
+struct AddFastGelu
+{
+    template <typename E, typename C, typename D>
+    __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<float, float, float>(float& e, const float& c, const float& d) const
+    {
+        const float x = c + d;
+
+        FastGelu{}.template operator()<float, float>(e, x);
+    }
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
+    {
+        const half_t x = c + d;
+
+        ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
+    }
+};
+#endif

 } // namespace element_wise
 } // namespace tensor_operation
...
@@ -16,7 +16,7 @@ namespace element_wise {
 // Need to ensure compiler will fail if there is no matching candidate, instead of compiler
 // silently doing implicit type conversion
 //
-// Method 1:
+// Example:
 //
 // struct ExampleElementwiseOp
 // {
@@ -30,19 +30,6 @@
 // {
 // }
 // };
-//
-// Method 2:
-//
-// template <typename Y, typename X>
-// struct ExampleElementwiseOp;
-//
-// template <>
-// struct ExampleElementwiseOp<float, ck::bhalf_t>
-// {
-//     __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
-//     {
-//     }
-// };

 struct AddReluAdd
 {
@@ -208,6 +195,7 @@ struct AddMultiply
     }
 };

+#if 0
 // C = A * B
 // E = FastGelu(C + D0 + D1)
 struct AddAddFastGelu
@@ -245,6 +233,35 @@ struct AddAddFastGelu
         e = type_convert<E>(y);
     }
 };
+#else
+// E = FastGelu(C + D0 + D1)
+struct AddAddFastGelu
+{
+    template <typename E, typename C, typename D0, typename D1>
+    __host__ __device__ constexpr void
+    operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
+
+    template <>
+    __host__ __device__ constexpr void operator()<float, float, float, float>(float& e,
+                                                                               const float& c,
+                                                                               const float& d0,
+                                                                               const float& d1) const
+    {
+        const float x = c + d0 + d1;
+
+        FastGelu{}.template operator()<float, float>(e, x);
+    }
+
+    template <>
+    __host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
+        half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const
+    {
+        const half_t x = c + d0 + d1;
+
+        ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
+    }
+};
+#endif

 struct Normalize
 {
...
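Note on the pattern above: the reworked AddFastGelu and AddAddFastGelu declare the primary operator() template without a definition and provide only explicit specializations for the supported type combinations, so an unsupported call can no longer be satisfied through silent implicit conversions (this is what the "Need to ensure compiler will fail if there is no matching candidate" comment refers to). Below is a minimal standalone sketch of the same idea; the Square functor is hypothetical (not part of composable_kernel), and the specializations are placed at namespace scope, which is the portable form of the in-class specializations used in the diff.

#include <cstdio>

// Primary template is declared but never defined: any type combination
// without an explicit specialization fails at link time instead of being
// silently handled through implicit conversions.
struct Square
{
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const;
};

// The only supported combination: float in, float out.
template <>
void Square::operator()<float, float>(float& y, const float& x) const
{
    y = x * x;
}

int main()
{
    float y = 0.f;
    Square{}(y, 3.0f);      // deduces <float, float>, uses the specialization
    std::printf("%f\n", y); // prints 9.000000

    // double d = 0.0;
    // Square{}(d, 3.0); // deduces <double, double>: declared but undefined -> link error
    return 0;
}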
@@ -11,6 +11,10 @@ namespace ck {
 namespace tensor_operation {
 namespace element_wise {

+#if CK_WORKAROUND_SWDEV_XXXXXX_FRCP_RN
+extern "C" __device__ float __ocml_native_recip_f32(float);
+#endif

 struct PassThrough
 {
     template <typename Y, typename X>
@@ -200,6 +204,7 @@ struct Relu
     }
 };

+#if 0
 // Y = FastGelu(X)
 struct FastGelu
 {
@@ -232,6 +237,76 @@
         y = type_convert<Y>(tmp_y);
     }
 };
+#else
+// Fast GeLU
+// https://paperswithcode.com/method/gelu
+// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
+// host code uses the higher-accuracy "exp" and "div"
+// GPU code uses the lower-accuracy "__expf" and "rcp" functions
+struct FastGelu
+{
+    template <typename Y, typename X>
+    __host__ void operator()(Y& y, const X& x) const;
+
+    template <typename Y, typename X>
+    __device__ void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ void operator()<float, float>(float& y, const float& x) const
+    {
+        const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
+        const float emu = exp(-u);
+        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
+
+        y = x * cdf;
+    }
+
+    template <>
+    __host__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
+    {
+        float y_f;
+
+        this->operator()<float, float>(y_f, type_convert<float>(x));
+
+        y = type_convert<half_t>(y_f);
+    }
+
+    // device code: use lower-precision "__expf" and "rcp"
+    template <>
+    __device__ void operator()<float, float>(float& y, const float& x) const
+    {
+#if 0
+        const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
+        const float emu = exp(-u);
+        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
+
+        y = x * cdf;
+#else
+        const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
+        const float emu = __expf(-u);
+#if !CK_WORKAROUND_SWDEV_XXXXXX_FRCP_RN
+        const float cdf = 0.5f + 0.5f * (2.f * __frcp_rn(1.f + emu) - 1.f);
+#else
+        const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f);
+#endif
+        y = x * cdf;
+#endif
+    }
+
+    // device code: use lower-precision "__expf" and "rcp"
+    template <>
+    __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
+    {
+        float y_f;
+
+        this->operator()<float, float>(y_f, type_convert<float>(x));
+
+        y = type_convert<half_t>(y_f);
+    }
+};
+#endif

 // https://paperswithcode.com/method/gelu
 // y = 0.5*x*(1+erf(x/sqrt(2)))
...
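For reference, the constants 0.797885 and 0.035677 in the new FastGelu come from folding sqrt(2/pi) into the tanh argument and then rewriting tanh through the logistic sigmoid, which is why the device path only needs one exponential (__expf) and one reciprocal (__frcp_rn, or __ocml_native_recip_f32 under the workaround flag). The following host-only sketch is not part of the commit; it evaluates the tanh form quoted in the code comment and the exp/reciprocal form used by the kernel side by side, so the algebra can be checked numerically.

#include <cmath>
#include <cstdio>

// Tanh form quoted in the FastGelu comment:
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
float gelu_tanh(float x)
{
    const float c = 0.7978845608f; // sqrt(2/pi)
    return 0.5f * x * (1.0f + std::tanh(c * (x + 0.044715f * x * x * x)));
}

// Exp/reciprocal form used by the kernel:
// with u = 2*x*(0.035677*x^2 + 0.797885), where 0.797885 ~ sqrt(2/pi)
// and 0.035677 ~ 0.044715*sqrt(2/pi), the identity tanh(u/2) = 2/(1+exp(-u)) - 1
// gives y = x*(0.5 + 0.5*(2/(1+exp(-u)) - 1)) = x/(1+exp(-u)).
float gelu_exp(float x)
{
    const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
    const float emu = std::exp(-u);
    const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
    return x * cdf;
}

int main()
{
    // The two forms agree to roughly float rounding error over a typical range.
    for(float x = -4.f; x <= 4.f; x += 1.f)
        std::printf("x=%5.1f  tanh form=% .6f  exp form=% .6f\n", x, gelu_tanh(x), gelu_exp(x));
    return 0;
}

The device specialization then replaces std::exp with __expf and the division by (1 + emu) with a hardware reciprocal, trading a small amount of accuracy for speed, exactly as the "host code uses ... / GPU code uses ..." comment in the diff describes.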