Commit b4f96931 authored by Chao Liu

improve fastgelu

parent e9d4e893
...@@ -33,13 +33,54 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// E = Relu(C + D);
struct AddRelu
{
template <typename E, typename C, typename D>
__host__ __device__ void operator()(E& e, const C& c, const D& d) const;
#if 0
template <>
__host__ __device__ void
-operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d) const
+operator()<ck::half_t, ck::half_t, ck::half_t>(ck::half_t& e, const ck::half_t& c, const ck::half_t& d) const
{
const ck::half_t x = c + d;
e = x > 0 ? x : 0;
}
#else
// AddFastGeLU
template <>
__host__ __device__ void operator()<ck::half_t, ck::half_t, ck::half_t>(
ck::half_t& e, const ck::half_t& c, const ck::half_t& d) const
{
const ck::half_t x = c + d;
e = x > 0 ? x : 0;
}
#endif
};
struct FastGelu
{
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885));
const float emu = exp(-u);
const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
y = x * cdf;
}
};
using ADataType = F16;
...
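Side note (not part of the commit): the constants above come from rewriting the tanh-based GeLU approximation, y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))), with the identity tanh(t) = 2/(1+exp(-2t)) - 1, so 0.797885 ~= sqrt(2/pi) and 0.035677 ~= 0.044715*sqrt(2/pi). A minimal host-only sketch that checks the two forms agree:

#include <algorithm>
#include <cmath>
#include <cstdio>

// Tanh form: y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
float gelu_tanh_form(float x)
{
    const float k = 0.7978845608f; // sqrt(2/pi)
    return 0.5f * x * (1.f + std::tanh(k * (x + 0.044715f * x * x * x)));
}

// Sigmoid form used above, via tanh(t) = 2/(1+exp(-2t)) - 1 with u = 2*t
float gelu_sigmoid_form(float x)
{
    const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
    const float emu = std::exp(-u);
    const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
    return x * cdf;
}

int main()
{
    float max_diff = 0.f;
    for(float x = -8.f; x <= 8.f; x += 0.01f)
        max_diff = std::max(max_diff, std::fabs(gelu_tanh_form(x) - gelu_sigmoid_form(x)));
    std::printf("max |tanh form - sigmoid form| = %e\n", max_diff);
    return 0;
}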
...@@ -3,6 +3,8 @@
#include "common.hpp"
extern "C" __device__ float __ocml_native_recip_f32(float);
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
...@@ -19,9 +21,67 @@ using D1Layout = Row;
using DsLayout = ck::Tuple<D0Layout, D1Layout>;
using ELayout = Row;
// C = A * B
// E = FastGelu(C + D0 + D1)
struct EleFastGeLU
{
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__ static constexpr float GetFastGeLU(float x)
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u);
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
return x * cdf;
}
#if 0
__device__ static constexpr float GetFastGeLU(float x)
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = __expf(-u);
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
return x * cdf;
}
#elif 0
__device__ static constexpr float GetFastGeLU(float x)
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = __expf(-u);
const float cdf = 0.5f + 0.5f * (2.f * __frcp_rn(1.f + emu) - 1.f);
return x * cdf;
}
#else
__device__ static constexpr float GetFastGeLU(float x)
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = __expf(-u);
const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f);
return x * cdf;
}
#endif
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const
{
#if 0
const float y =
GetFastGeLU(ck::type_convert<float>(c) + ck::type_convert<float>(d0) + ck::type_convert<float>(d1));
#else
const float a =
ck::type_convert<float>(c) + ck::type_convert<float>(d0) + ck::type_convert<float>(d1);
const float y = a > 0 ? a : 0;
#endif
e = ck::type_convert<E>(y);
}
};
using AElementOp = PassThrough;
using BElementOp = PassThrough;
-using CDEElementOp = AddAddFastGelu;
+using CDEElementOp = EleFastGeLU;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
...
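For reference, the CDE element-wise op runs once per GEMM output element, i.e. e(m,n) = Op(c(m,n), d0(m,n), d1(m,n)). A stand-alone sketch of that pattern (the functor below is a float-only stand-in for EleFastGeLU, not the CK type, and the array names are illustrative):

#include <cmath>
#include <cstdio>
#include <vector>

// Float-only stand-in for the CDE element-wise op: e = FastGelu(c + d0 + d1).
struct CdeFastGelu
{
    void operator()(float& e, float c, float d0, float d1) const
    {
        const float a   = c + d0 + d1;
        const float u   = 2.f * a * (0.035677f * a * a + 0.797885f);
        const float cdf = 0.5f + 0.5f * (2.f / (1.f + std::exp(-u)) - 1.f);
        e = a * cdf;
    }
};

int main()
{
    const int M = 2, N = 4;
    std::vector<float> c(M * N, 1.f), d0(M * N, 0.5f), d1(M * N, -0.25f), e(M * N);

    CdeFastGelu op;
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            op(e[m * N + n], c[m * N + n], d0[m * N + n], d1[m * N + n]);

    std::printf("e(0,0) = %f\n", e[0]);
    return 0;
}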
...@@ -41,7 +41,7 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
-switch(config.init_method)
+switch(2)
{
case 0: break;
case 1:
...@@ -108,7 +108,7 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC
throw std::runtime_error("wrong! this device_op instance does not support this problem");
}
-float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
+float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
...
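The reported throughput is typically derived from flop and num_btype as in the hedged sketch below, assuming ave_time is the average kernel time in milliseconds and that all of A, B, D0, D1 and E are fp16 in this example; variable names and the problem size are illustrative only:

#include <cstddef>
#include <cstdio>

int main()
{
    // Illustrative problem size; ave_time stands in for the value returned by invoker.Run(...)
    const std::size_t M = 3840, N = 4096, K = 4096;
    const float ave_time = 1.0f; // ms

    const std::size_t bytes_per_f16 = 2;
    const std::size_t flop      = std::size_t(2) * M * N * K; // one multiply-add = 2 flops
    const std::size_t num_btype = bytes_per_f16 * (M * K + K * N + 2 * M * N + M * N); // A, B, D0+D1, E

    const float tflops     = static_cast<float>(flop) / 1.e9f / ave_time;      // flop per ms -> TFLOPS
    const float gb_per_sec = static_cast<float>(num_btype) / 1.e6f / ave_time; // byte per ms -> GB/s

    std::printf("%.3f TFLOPS, %.3f GB/s\n", tflops, gb_per_sec);
    return 0;
}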
...@@ -105,7 +105,7 @@ struct AddAddFastGelu
using A0ElementOp = PassThrough;
using B0ElementOp = PassThrough;
-using CDE0ElementOp = AddAddRelu;
+using CDE0ElementOp = AddAddFastGelu;
using A1ElementOp = PassThrough;
using B1ElementOp = PassThrough;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
...
...@@ -10,6 +10,8 @@ namespace ck {
namespace tensor_operation {
namespace element_wise {
extern "C" __device__ float __ocml_native_recip_f32(float);
struct PassThrough
{
template <typename Y, typename X>
...@@ -198,16 +200,39 @@ struct Relu
struct FastGelu
{
template <typename Y, typename X>
-__host__ __device__ void operator()(Y& y, const X& x) const;
+__host__ void operator()(Y& y, const X& x) const;
template <typename Y, typename X>
__device__ void operator()(Y& y, const X& x) const;
// host code, use higher precision "exp" and "div"
template <>
-__host__ __device__ void operator()<float, float>(float& y, const float& x) const
+__host__ void operator()<float, float>(float& y, const float& x) const
{
const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885));
const float emu = exp(-u);
const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
y = x * cdf;
}
// device code, use lower precision "__expf" and "rcp"
template <>
__device__ void operator()<float, float>(float& y, const float& x) const
{
#if 0
const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885));
const float emu = __expf(-u);
const float cdf = float(0.5) + float(0.5) * (float(2) * __ocml_native_recip_f32(float(1) + emu) - float(1));
y = x * cdf;
#else
const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885));
const float emu = exp(-u);
const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
y = x * cdf;
#endif
}
};
...
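A stand-alone HIP sketch (not part of this commit, build with hipcc) that compares the precise path (exp, division) used by the host specialization against the fast device path (__expf, __ocml_native_recip_f32) that the #if 0 branch above experiments with:

#include <hip/hip_runtime.h>
#include <cstdio>

extern "C" __device__ float __ocml_native_recip_f32(float);

// Precise path: exp + division.
__device__ float fast_gelu_precise(float x)
{
    const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
    const float emu = expf(-u);
    return x * (0.5f + 0.5f * (2.f / (1.f + emu) - 1.f));
}

// Fast path: __expf + OCML native reciprocal.
__device__ float fast_gelu_fast(float x)
{
    const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
    const float emu = __expf(-u);
    return x * (0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f));
}

__global__ void compare(const float* x, float* diff, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        diff[i] = fabsf(fast_gelu_precise(x[i]) - fast_gelu_fast(x[i]));
}

int main()
{
    const int n = 1024;
    float hx[n], hd[n];
    for(int i = 0; i < n; ++i)
        hx[i] = -8.f + 16.f * i / n;

    float *dx, *dd;
    hipMalloc(&dx, n * sizeof(float));
    hipMalloc(&dd, n * sizeof(float));
    hipMemcpy(dx, hx, n * sizeof(float), hipMemcpyHostToDevice);

    compare<<<(n + 255) / 256, 256>>>(dx, dd, n);
    hipMemcpy(hd, dd, n * sizeof(float), hipMemcpyDeviceToHost);

    float max_diff = 0.f;
    for(int i = 0; i < n; ++i)
        max_diff = max_diff < hd[i] ? hd[i] : max_diff;
    printf("max |precise - fast| = %e\n", max_diff);

    hipFree(dx);
    hipFree(dd);
    return 0;
}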