Commit a4772890 authored by Chao Liu
Browse files

clean

parent d8552699
...@@ -212,10 +212,7 @@ struct Relu ...@@ -212,10 +212,7 @@ struct Relu
struct FastGelu struct FastGelu
{ {
template <typename Y, typename X> template <typename Y, typename X>
__host__ void operator()(Y& y, const X& x) const; __host__ __device__ void operator()(Y& y, const X& x) const;
template <typename Y, typename X>
__device__ void operator()(Y& y, const X& x) const;
template <> template <>
__host__ void operator()<float, float>(float& y, const float& x) const __host__ void operator()<float, float>(float& y, const float& x) const
...@@ -227,16 +224,6 @@ struct FastGelu ...@@ -227,16 +224,6 @@ struct FastGelu
y = x * cdf; y = x * cdf;
} }
template <>
__host__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<half_t>(y_f);
}
// device code, use lower precision "__expf" and "rcp" // device code, use lower precision "__expf" and "rcp"
template <> template <>
__device__ void operator()<float, float>(float& y, const float& x) const __device__ void operator()<float, float>(float& y, const float& x) const
...@@ -254,7 +241,7 @@ struct FastGelu ...@@ -254,7 +241,7 @@ struct FastGelu
} }
template <> template <>
__device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{ {
float y_f; float y_f;
...@@ -264,7 +251,7 @@ struct FastGelu ...@@ -264,7 +251,7 @@ struct FastGelu
} }
template <> template <>
__device__ void operator()<half_t, float>(half_t& y, const float& x) const __host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const
{ {
float y_f; float y_f;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment