Commit 46eca0a1 authored by rocking's avatar rocking
Browse files

Refactor relu. using type_trait instead of overloading

parent 3e4d2752
...@@ -146,22 +146,23 @@ struct AddHardswishAdd ...@@ -146,22 +146,23 @@ struct AddHardswishAdd
struct Relu struct Relu
{ {
__host__ __device__ void operator()(float& y, const float& x) const { y = x > 0 ? x : 0; } template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x > 0 ? x : 0; } {
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
y = x > 0 ? x : 0;
}
template <>
__host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
{ {
float x_f32 = ck::type_convert<float>(x); float x_f32 = ck::type_convert<float>(x);
float y_f32 = x_f32 > 0 ? x_f32 : 0; float y_f32 = x_f32 > 0 ? x_f32 : 0;
y = ck::type_convert<bhalf_t>(y_f32); y = ck::type_convert<bhalf_t>(y_f32);
} }
__host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x > 0 ? x : 0; }
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x > 0 ? x : 0; }
__host__ __device__ void operator()(double& y, const double& x) const { y = x > 0 ? x : 0; }
}; };
struct Normalize struct Normalize
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment