Commit fc1f07ac authored by Chao Liu

update fastgelu

parent e38e61b6
@@ -33,54 +33,13 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// E = Relu(C + D);
struct AddRelu
{
template <typename E, typename C, typename D>
__host__ __device__ void operator()(E& e, const C& c, const D& d) const;
#if 0
template <>
__host__ __device__ void
operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d) const
{
const ck::half_t x = c + d;
e = x > 0 ? x : 0;
}
#else
// AddFastGeLU (placeholder: this branch currently also computes Add+Relu)
template <>
__host__ __device__ void operator()<ck::half_t, ck::half_t, ck::half_t>(
ck::half_t& e, const ck::half_t& c, const ck::half_t& d) const
{
const ck::half_t x = c + d;
e = x > 0 ? x : 0;
}
#endif
};
struct FastGelu
{
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885));
const float emu = exp(-u);
const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
y = x * cdf;
}
};
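// Illustrative sanity check (FastGeluMatchesTanhForm is a hypothetical helper,
// assuming <cmath> is included): the sigmoid form used by FastGelu above is
// algebraically identical to the tanh form
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))),
// since tanh(v) = 2/(1+exp(-2v)) - 1 with u = 2v and 0.035677 ~= 0.797885 * 0.044715.
__host__ inline bool FastGeluMatchesTanhForm(float x)
{
    float y_sigmoid = 0.f;
    FastGelu{}.template operator()<float, float>(y_sigmoid, x);
    const float y_tanh =
        0.5f * x * (1.f + std::tanh(0.797885f * (x + 0.044715f * x * x * x)));
    return std::abs(y_sigmoid - y_tanh) < 1e-5f;
}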
using ADataType = F16;
......
@@ -3,12 +3,11 @@
#include "common.hpp"
extern "C" __device__ float __ocml_native_recip_f32(float);
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using CDataType = F16; // C matrix doesn't exist; it is used only for verification
using D0DataType = F16;
using D1DataType = F16;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
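// The entries of DsDataType map positionally onto the trailing arguments of the
// CDE elementwise op below: the epilogue calls cde_op(e, c, d0, d1) per element.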
@@ -21,67 +20,9 @@ using D1Layout = Row;
using DsLayout = ck::Tuple<D0Layout, D1Layout>;
using ELayout = Row;
// C = A * B
// E = FastGelu(C + D0 + D1)
struct EleFastGeLU
{
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__ static constexpr float GetFastGeLU(float x)
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u);
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
return x * cdf;
}
#if 0
__device__ static constexpr float GetFastGeLU(float x)
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = __expf(-u);
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
return x * cdf;
}
#elif 0
__device__ static constexpr float GetFastGeLU(float x)
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = __expf(-u);
const float cdf = 0.5f + 0.5f * (2.f * __frcp_rn(1.f + emu) - 1.f);
return x * cdf;
}
#else
__device__ static constexpr float GetFastGeLU(float x)
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = __expf(-u);
const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f);
return x * cdf;
}
#endif
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const
{
#if 0
const float y =
GetFastGeLU(ck::type_convert<float>(c) + ck::type_convert<float>(d0) + ck::type_convert<float>(d1));
#else
const float a =
ck::type_convert<float>(c) + ck::type_convert<float>(d0) + ck::type_convert<float>(d1);
const float y = a > 0 ? a : 0;
#endif
e = ck::type_convert<E>(y);
}
};
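// Usage sketch (ApplyCdeOp is a hypothetical helper): the device GEMM epilogue
// invokes the CDE elementwise op once per output element, passing the float
// accumulator value together with the two F16 D-tensor values:
__device__ inline F16 ApplyCdeOp(float c, F16 d0, F16 d1)
{
    F16 e;
    EleFastGeLU{}(e, c, d0, d1);
    return e;
}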
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = AddAddFastGelu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
@@ -96,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
AccDataType,
AccDataType,
AElementOp,
BElementOp,
......
@@ -41,7 +41,7 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(config.init_method)
{
case 0: break;
case 1:
@@ -124,7 +124,7 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC
if(config.do_verification)
{
Tensor<CDataType> c_m_n(HostTensorDescriptor{M, N});
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
......
@@ -105,7 +105,7 @@ struct AddAddFastGelu
using A0ElementOp = PassThrough;
using B0ElementOp = PassThrough;
using CDE0ElementOp = AddAddRelu;
using A1ElementOp = PassThrough;
using B1ElementOp = PassThrough;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
......
@@ -232,16 +232,18 @@ struct AddFastGelu
template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
template <>
__host__ __device__ constexpr void
operator()<float, float, float>(float& e, const float& c, const float& d) const
{
const float x = c + d;
FastGelu{}.template operator()<float, float>(e, x);
}
template <>
__host__ __device__ constexpr void
operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
{
const half_t x = c + d;
......
@@ -15,7 +15,7 @@ namespace element_wise {
// Need to ensure the compiler fails when there is no matching candidate, instead of
// silently doing an implicit type conversion
//
// Example:
//
// struct ExampleElementwiseOp
// {
@@ -29,19 +29,6 @@ namespace element_wise {
// {
// }
// };
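//
// With only that specialization defined, an exact-type call compiles while any
// other combination fails instead of silently converting (illustrative):
//
//   float y;
//   ck::bhalf_t x;
//   ExampleElementwiseOp{}(y, x);  // ok: matches the float/bhalf_t specialization
//
//   double yd;
//   ExampleElementwiseOp{}(yd, x); // fails: the primary template is declared but
//                                  // never defined, so this instantiation has no body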
struct AddReluAdd
{
@@ -142,7 +129,6 @@ struct AddHardswishAdd
}
};
// E = C + D0 + D1
struct AddAdd
{
@@ -171,41 +157,33 @@ struct AddAdd
}
};
// C = A * B
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ constexpr void operator()<float, float, float>(float& e,
const float& c,
const float& d0,
const float& d1) const
{
const float x = c + d0 + d1;
FastGelu{}.template operator()<float, float>(e, x);
}
template <>
__host__ __device__ constexpr void operator()<half_t, half_t, half_t>(half_t& e,
const half_t& c,
const half_t& d0,
const half_t& d1) const
{
const half_t x = c + d0 + d1;
ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
}
};
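// Usage sketch (ApplyAddAddFastGelu is a hypothetical helper): fused GEMM
// epilogues call the op element-by-element, so both specializations compute
// e = FastGelu(c + d0 + d1):
__host__ __device__ inline half_t ApplyAddAddFastGelu(half_t c, half_t d0, half_t d1)
{
    half_t e;
    AddAddFastGelu{}(e, c, d0, d1);
    return e;
}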
......
@@ -210,9 +210,9 @@ struct FastGelu
template <>
__host__ void operator()<float, float>(float& y, const float& x) const
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u);
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
y = x * cdf;
}
@@ -231,11 +231,19 @@ struct FastGelu
template <>
__device__ void operator()<float, float>(float& y, const float& x) const
{
#if 0
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u);
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
y = x * cdf;
#else
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = __expf(-u);
const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f);
y = x * cdf;
#endif
}
// device code, use lower precision "__expf" and "rcp"
......