Commit 46eca0a1 authored by rocking's avatar rocking
Browse files

Refactor relu. using type_trait instead of overloading

parent 3e4d2752
......@@ -146,22 +146,23 @@ struct AddHardswishAdd
struct Relu
{
__host__ __device__ void operator()(float& y, const float& x) const { y = x > 0 ? x : 0; }
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x > 0 ? x : 0; }
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
y = x > 0 ? x : 0;
}
template <>
__host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
{
float x_f32 = ck::type_convert<float>(x);
float y_f32 = x_f32 > 0 ? x_f32 : 0;
y = ck::type_convert<bhalf_t>(y_f32);
}
__host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x > 0 ? x : 0; }
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x > 0 ? x : 0; }
__host__ __device__ void operator()(double& y, const double& x) const { y = x > 0 ? x : 0; }
};
struct Normalize
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment