"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "90dd631004072cbd129fe975958cc2b752c50ff1"
Commit 8a913c22 authored by Chao Liu's avatar Chao Liu
Browse files

added GeLU and fast GeLU

parent 18d2bb1b
...@@ -27,43 +27,6 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -27,43 +27,6 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Element-wise GELU activation (tanh formulation) for float operands.
struct Gelu
{
    // Computes y = 0.5 * x * (1 + tanh(0.797885 * x + 0.035677 * x * x * x)).
    //   y: output reference, overwritten with the activation value
    //   x: input value
    __host__ __device__ void operator()(float& y, const float& x) const
    {
        constexpr float kCubicCoeff  = static_cast<float>(0.035677);
        constexpr float kLinearCoeff = static_cast<float>(0.797885);
        constexpr float kOne         = static_cast<float>(1.0);
        constexpr float kHalf        = static_cast<float>(0.5);

        // Build the tanh argument as x * (kLinearCoeff + kCubicCoeff * x * x).
        const float scaled_square = kCubicCoeff * x * x;
        const float slope         = kLinearCoeff + scaled_square;
        const float tanh_arg      = slope * x;

        const float tanh_val      = tanh(tanh_arg);
        const float one_plus_tanh = kOne + tanh_val;

        y = kHalf * x * one_plus_tanh;
    }
};
// GELU activation evaluated with a single exponential instead of tanh.
//
// Host and device share one implementation so that a host-side reference run and a
// device run of this functor apply the same arithmetic.
struct FastGelu
{
    // Computes y = x / (1 + exp(-u)) with u = 2 * x * (0.035677 * x * x + 0.797885).
    // Because 0.5 * (1 + tanh(c)) == 1 / (1 + exp(-2 * c)), this is the same curve as
    // 0.5 * x * (1 + tanh(0.797885 * x + 0.035677 * x * x * x)).
    //   y: output reference, overwritten with the activation value
    //   x: input value
    __host__ __device__ void operator()(float& y, const float& x) const
    {
        const float u   = float(2) * x * (float(0.035677) * x * x + float(0.797885));
        const float emu = exp(-u);

        // Evaluate the logistic form directly. The expanded expression
        // 0.5 + 0.5 * (2 / (1 + emu) - 1) subtracts two nearly equal values when emu
        // is large (negative x) and loses most of the significant digits of the result.
        const float cdf = float(1) / (float(1) + emu);

        y = x * cdf;
    }
};
using ADataType = ck::half_t; using ADataType = ck::half_t;
using BDataType = ck::half_t; using BDataType = ck::half_t;
...@@ -76,11 +39,7 @@ using CLayout = ck::tensor_layout::gemm::RowMajor; ...@@ -76,11 +39,7 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough; using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough;
#if 0
using CElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough;
#else
using CElementOp = FastGelu;
#endif
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
......
...@@ -33,7 +33,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) ...@@ -33,7 +33,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
add_executable(${EXAMPLE_NAME} ${FILE_NAME}) add_executable(${EXAMPLE_NAME} ${FILE_NAME})
target_link_libraries(${EXAMPLE_NAME} PRIVATE host_tensor) target_link_libraries(${EXAMPLE_NAME} PRIVATE host_tensor)
add_dependencies(examples ${EXAMPLE_NAME}) add_dependencies(examples ${EXAMPLE_NAME})
endfunction(add_example_executable EXAMPLE_NAME) endfunction(add_example_executable_no_testing EXAMPLE_NAME)
add_subdirectory(01_gemm) add_subdirectory(01_gemm)
add_subdirectory(02_gemm_alpha_beta) add_subdirectory(02_gemm_alpha_beta)
......
...@@ -20,6 +20,32 @@ struct PassThrough ...@@ -20,6 +20,32 @@ struct PassThrough
__host__ __device__ void operator()(double& y, const double& x) const { y = x; } __host__ __device__ void operator()(double& y, const double& x) const { y = x; }
}; };
// GELU activation functor using the tanh formulation, float in / float out.
struct Gelu
{
    // Stores 0.5 * x * (1 + tanh(0.797885 * x + 0.035677 * x * x * x)) into y.
    //   y: destination, overwritten
    //   x: source value
    __host__ __device__ void operator()(float& y, const float& x) const
    {
        constexpr float cubic_scale  = static_cast<float>(0.035677);
        constexpr float linear_scale = static_cast<float>(0.797885);

        // Polynomial inside the tanh, factored as x * (linear_scale + cubic_scale * x * x).
        const float cubic_part = cubic_scale * x * x;
        const float factor     = linear_scale + cubic_part;
        const float inner      = factor * x;

        const float t       = tanh(inner);
        const float shifted = static_cast<float>(1.0) + t;

        y = static_cast<float>(0.5) * x * shifted;
    }
};
// GELU activation evaluated with a single exponential instead of tanh.
struct FastGelu
{
    // Computes y = x / (1 + exp(-u)) with u = 2 * x * (0.035677 * x * x + 0.797885).
    // Because 0.5 * (1 + tanh(c)) == 1 / (1 + exp(-2 * c)), this is the same curve as
    // 0.5 * x * (1 + tanh(0.797885 * x + 0.035677 * x * x * x)).
    //   y: output reference, overwritten with the activation value
    //   x: input value
    __host__ __device__ void operator()(float& y, const float& x) const
    {
        const float u   = float(2) * x * (float(0.035677) * x * x + float(0.797885));
        const float emu = exp(-u);

        // Evaluate the logistic form directly. The expanded expression
        // 0.5 + 0.5 * (2 / (1 + emu) - 1) subtracts two nearly equal values when emu
        // is large (negative x) and loses most of the significant digits of the result.
        const float cdf = float(1) / (float(1) + emu);

        y = x * cdf;
    }
};
struct Add struct Add
{ {
__host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const __host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment