Commit 18d2bb1b authored by Chao Liu's avatar Chao Liu
Browse files

ad gelu and fast_gelu

parent 9f71ff48
......@@ -27,6 +27,44 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
struct Gelu
{
__host__ __device__ void operator()(float& y, const float& x) const
{
// Y=0.5*X*(1+tanh(0.797885*X+0.035677*X*X*X))
const float a = float(0.035677) * x * x;
const float b = float(0.797885) + a;
const float c = b * x;
const float d = tanh(c);
const float e = float(1.0) + d;
y = float(0.5) * x * e;
}
};
struct FastGelu
{
__host__ void operator()(float& y, const float& x) const
{
// Y=0.5*X*(1+tanh(0.797885*X+0.035677*X*X*X))
const float a = float(0.035677) * x * x;
const float b = float(0.797885) + a;
const float c = b * x;
const float d = tanh(c);
const float e = float(1.0) + d;
y = float(0.5) * x * e;
}
__device__ void operator()(float& y, const float& x) const
{
// const T cdf = a + a * _Tanh(in * (c * in * in + b));
const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885));
const float emu = exp(-u);
const float cdf = float(0.5) + float(0.5) * (float(2)/(float(1) + emu) - float(1));
y = x * cdf;
}
};
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
......@@ -38,7 +76,11 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
#if 0
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
#else
using CElementOp = FastGelu;
#endif
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment