Commit fc1f07ac authored by Chao Liu

update fastgelu

parent e38e61b6
@@ -33,54 +33,13 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 // E = Relu(C + D);
 struct AddRelu
 {
-    template <typename E, typename C, typename D>
-    __host__ __device__ void operator()(E& e, const C& c, const D& d) const;
-
-#if 0
-    template <>
     __host__ __device__ void
-    operator()<ck::half_t, ck::half_t, ck::half_t>(ck::half_t& e, const ck::half_t& c, const ck::half_t& d) const
+    operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d) const
     {
         const ck::half_t x = c + d;
         e = x > 0 ? x : 0;
     }
-#else
-    // AddFastGeLU
-    template <>
-    __host__ __device__ void operator()<ck::half_t, ck::half_t, ck::half_t>(
-        ck::half_t& e, const ck::half_t& c, const ck::half_t& d) const
-    {
-        const ck::half_t x = c + d;
-        e = x > 0 ? x : 0;
-    }
-#endif
-};
-
-struct FastGelu
-{
-    template <typename Y, typename X>
-    __host__ __device__ void operator()(Y& y, const X& x) const;
-
-    template <>
-    __host__ __device__ void operator()<float, float>(float& y, const float& x) const
-    {
-        const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885));
-        const float emu = exp(-u);
-        const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
-        y = x * cdf;
-    }
-
-    __host__ __device__ void operator()<float, float>(float& y, const float& x) const
-    {
-        const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885));
-        const float emu = exp(-u);
-        const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
-        y = x * cdf;
-    }
 };

 using ADataType = F16;
...
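Aside: the constants in `u = 2.f * x * (0.035677f * x * x + 0.797885f)` come from rewriting the tanh-based fast-GeLU approximation in sigmoid form. A worked derivation, using sqrt(2/pi) ≈ 0.797885, 0.044715·sqrt(2/pi) ≈ 0.035677, and the identity tanh(v) = 2/(1 + e^(-2v)) − 1:

```latex
v = \sqrt{2/\pi}\,(x + 0.044715\,x^3) = 0.797885\,x + 0.035677\,x^3, \qquad u = 2v

y = 0.5\,x\,\bigl(1 + \tanh(v)\bigr)
  = x\left(0.5 + 0.5\left(\frac{2}{1 + e^{-u}} - 1\right)\right)
  = \frac{x}{1 + e^{-u}}
```

which is exactly what the `emu`/`cdf` pair in the code computes.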
@@ -3,12 +3,11 @@
 #include "common.hpp"

-extern "C" __device__ float __ocml_native_recip_f32(float);
-
 using ADataType = F16;
 using BDataType = F16;
 using AccDataType = F32;
-using CShuffleDataType = F32;
+using CShuffleDataType = F16;
+using CDataType = F16; // C matrix doesn't exist; this is only used for verification
 using D0DataType = F16;
 using D1DataType = F16;
 using DsDataType = ck::Tuple<D0DataType, D1DataType>;
@@ -21,67 +20,9 @@ using D1Layout = Row;
 using DsLayout = ck::Tuple<D0Layout, D1Layout>;
 using ELayout = Row;

-// C = A * B
-// E = FastGelu(C + D0 + D1)
-struct EleFastGeLU
-{
-    // Fast GeLU
-    // https://paperswithcode.com/method/gelu
-    // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
-    __host__ static constexpr float GetFastGeLU(float x)
-    {
-        const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
-        const float emu = exp(-u);
-        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
-        return x * cdf;
-    }
-
-#if 0
-    __device__ static constexpr float GetFastGeLU(float x)
-    {
-        const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
-        const float emu = __expf(-u);
-        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
-        return x * cdf;
-    }
-#elif 0
-    __device__ static constexpr float GetFastGeLU(float x)
-    {
-        const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
-        const float emu = __expf(-u);
-        const float cdf = 0.5f + 0.5f * (2.f * __frcp_rn(1.f + emu) - 1.f);
-        return x * cdf;
-    }
-#else
-    __device__ static constexpr float GetFastGeLU(float x)
-    {
-        const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
-        const float emu = __expf(-u);
-        const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f);
-        return x * cdf;
-    }
-#endif
-
-    template <typename E, typename C, typename D0, typename D1>
-    __host__ __device__ constexpr void
-    operator()(E& e, const C& c, const D0& d0, const D1& d1) const
-    {
-#if 0
-        const float y =
-            GetFastGeLU(ck::type_convert<float>(c) + ck::type_convert<float>(d0) + ck::type_convert<float>(d1));
-#else
-        const float a =
-            ck::type_convert<float>(c) + ck::type_convert<float>(d0) + ck::type_convert<float>(d1);
-
-        const float y = a > 0 ? a : 0;
-#endif
-        e = ck::type_convert<E>(y);
-    }
-};
-
 using AElementOp = PassThrough;
 using BElementOp = PassThrough;
-using CDEElementOp = EleFastGeLU;
+using CDEElementOp = AddAddFastGelu;

 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
@@ -96,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
 using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                         BDataType,
-                                                                        AccDataType,
+                                                                        CDataType,
                                                                         AccDataType,
                                                                         AElementOp,
                                                                         BElementOp,
...
@@ -41,7 +41,7 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC
     std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
     std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;

-    switch(2)
+    switch(config.init_method)
     {
     case 0: break;
     case 1:
@@ -124,7 +124,7 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC
     if(config.do_verification)
     {
-        Tensor<AccDataType> c_m_n(HostTensorDescriptor{M, N});
+        Tensor<CDataType> c_m_n(HostTensorDescriptor{M, N});

         auto ref_gemm = ReferenceGemmInstance{};
         auto ref_invoker = ref_gemm.MakeInvoker();
...
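For context, the verification path runs a plain GEMM into `c_m_n` and then applies the same epilogue on the host. A minimal standalone sketch of that epilogue (plain C++ with a hypothetical `apply_epilogue`; not the CK host API):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Sigmoid-form fast GeLU, matching the constants used in the diff.
inline float fast_gelu(float x)
{
    const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
    const float emu = std::exp(-u);
    return x / (1.f + emu); // == x * (0.5f + 0.5f * (2.f / (1.f + emu) - 1.f))
}

// E = FastGelu(C + D0 + D1), applied elementwise to the reference GEMM output.
void apply_epilogue(std::vector<float>& e,
                    const std::vector<float>& c,
                    const std::vector<float>& d0,
                    const std::vector<float>& d1)
{
    for(std::size_t i = 0; i < e.size(); ++i)
        e[i] = fast_gelu(c[i] + d0[i] + d1[i]);
}
```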
@@ -105,7 +105,7 @@ struct AddAddFastGelu
 using A0ElementOp = PassThrough;
 using B0ElementOp = PassThrough;
-using CDE0ElementOp = AddAddFastGelu;
+using CDE0ElementOp = AddAddRelu;
 using A1ElementOp = PassThrough;
 using B1ElementOp = PassThrough;
 using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
...
@@ -232,16 +232,18 @@ struct AddFastGelu
     template <typename E, typename C, typename D>
     __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;

-    template<>
-    __host__ __device__ constexpr void operator<float, float, float>()(float& e, const float& c, const float& d) const
+    template <>
+    __host__ __device__ constexpr void
+    operator()<float, float, float>(float& e, const float& c, const float& d) const
     {
         const float x = c + d;

         FastGelu{}.template operator()<float, float>(e, x);
     }

-    template<>
-    __host__ __device__ constexpr void operator<half_t, half_t, half_t>()(half_t& e, const half_t& c, const half_t& d) const
+    template <>
+    __host__ __device__ constexpr void
+    operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
     {
         const half_t x = c + d;
...
@@ -15,7 +15,7 @@ namespace element_wise {
 // Need to ensure the compiler will fail if there is no matching candidate, instead of
 // silently doing an implicit type conversion
 //
-// Method 1:
+// Example:
 //
 // struct ExampleElementwiseOp
 // {
@@ -29,19 +29,6 @@ namespace element_wise {
 //     {
 //     }
 // };
-//
-// Method 2:
-//
-// template <typename Y, typename X>
-// struct ExampleElementwiseOp;
-//
-// template <>
-// struct ExampleElementwiseOp<float, ck::bhalf_t>
-// {
-//     __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
-//     {
-//     }
-// };

 struct AddReluAdd
 {
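The comment above describes the pattern these operators rely on: declare the primary template but never define it, and define only the supported specializations, so an unsupported type combination fails at build time instead of silently converting. A standalone sketch of that idea (hypothetical `scale` function, not CK code):

```cpp
// Primary template: declared, intentionally left undefined.
template <typename Y, typename X>
void scale(Y& y, const X& x);

// The only supported combination gets a definition.
template <>
void scale<float, float>(float& y, const float& x)
{
    y = 2.f * x;
}

int main()
{
    float y = 0.f;
    scale(y, 1.f); // OK: uses the float/float specialization
    // double d = 1.0;
    // scale(y, d); // deduces scale<float, double>: undefined -> link error
    return 0;
}
```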
@@ -142,7 +129,6 @@ struct AddHardswishAdd
     }
 };

-// C = A * B
 // E = C + D0 + D1
 struct AddAdd
 {
@@ -171,41 +157,33 @@ struct AddAdd
     }
 };

-// C = A * B
 // E = FastGelu(C + D0 + D1)
 struct AddAddFastGelu
 {
-    // Fast GeLU
-    // https://paperswithcode.com/method/gelu
-    // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
-    __host__ __device__ static constexpr float GetFastGeLU(float x)
-    {
-        const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
-        const float emu = exp(-u);
-        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
-        return x * cdf;
-    }
-
-    template <typename T>
-    static inline constexpr bool is_valid_param_type_v =
-        std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
-        std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
-#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
-        || std::is_same_v<T, ck::int4_t>
-#endif
-        ;
-
     template <typename E, typename C, typename D0, typename D1>
     __host__ __device__ constexpr void
-    operator()(E& e, const C& c, const D0& d0, const D1& d1) const
+    operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
+
+    template <>
+    __host__ __device__ constexpr void operator()<float, float, float>(float& e,
+                                                                       const float& c,
+                                                                       const float& d0,
+                                                                       const float& d1) const
     {
-        static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
-                      is_valid_param_type_v<D0> && is_valid_param_type_v<D1>);
-
-        const float y =
-            GetFastGeLU(type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1));
-
-        e = type_convert<E>(y);
+        const float x = c + d0 + d1;
+
+        FastGelu{}.template operator()<float, float>(e, x);
+    }
+
+    template <>
+    __host__ __device__ constexpr void operator()<half_t, half_t, half_t>(half_t& e,
+                                                                          const half_t& c,
+                                                                          const half_t& d0,
+                                                                          const half_t& d1) const
+    {
+        const half_t x = c + d0 + d1;
+
+        ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
     }
 };
...
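A note on the call syntax above: `FastGelu{}.template operator()<float, float>(e, x)` needs the `.template` disambiguator whenever the call appears in a dependent context, where a bare `<` would otherwise parse as less-than; spelling it this way uniformly is presumably for consistency. A tiny standalone illustration (hypothetical `Op`/`call_op`):

```cpp
struct Op
{
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const
    {
        y = x;
    }
};

// Inside this template, "op" is a dependent expression, so the call with
// explicit template arguments must be written with the .template keyword.
template <typename OpT>
void call_op(OpT op, float& y, const float& x)
{
    op.template operator()<float, float>(y, x);
}

int main()
{
    float y = 0.f;
    call_op(Op{}, y, 1.f);
    return 0;
}
```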
@@ -210,9 +210,9 @@ struct FastGelu
     template <>
     __host__ void operator()<float, float>(float& y, const float& x) const
     {
-        const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885));
+        const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
         const float emu = exp(-u);
-        const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
+        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
         y = x * cdf;
     }
@@ -231,11 +231,19 @@ struct FastGelu
     template <>
     __device__ void operator()<float, float>(float& y, const float& x) const
     {
-        const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885));
+#if 0
+        const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
+        const float emu = exp(-u);
+        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
+        y = x * cdf;
+#else
+        const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
         const float emu = __expf(-u);
-        const float cdf = float(0.5) + float(0.5) * (float(2) * __ocml_native_recip_f32(float(1) + emu) - float(1));
+        const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f);
         y = x * cdf;
+#endif
     }

     // device code, use lower precision "__expf" and "rcp"
...
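The `#if 0` branch keeps the precise `exp` version around for comparison; the live branch trades accuracy for speed via `__expf` and the OCML native reciprocal. A host-only sketch for gauging the baseline approximation error against exact GeLU (the device intrinsics cannot run on the host, so their additional, hardware-dependent error is not captured here):

```cpp
#include <cmath>
#include <cstdio>

// Exact GeLU via erf.
static float gelu_exact(float x) { return 0.5f * x * (1.f + std::erf(x / std::sqrt(2.f))); }

// Sigmoid-form fast GeLU with the constants from the diff.
static float gelu_fast(float x)
{
    const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
    return x / (1.f + std::exp(-u));
}

int main()
{
    float max_err = 0.f;
    for(float x = -6.f; x <= 6.f; x += 1e-3f)
        max_err = std::fmax(max_err, std::fabs(gelu_fast(x) - gelu_exact(x)));
    std::printf("max |fast - exact| on [-6, 6]: %g\n", max_err);
    return 0;
}
```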