Commit fe32b124 authored by Chao Liu's avatar Chao Liu
Browse files

clean element wise op

parent d1335c43
...@@ -24,11 +24,11 @@ ...@@ -24,11 +24,11 @@
* *
*******************************************************************************/ *******************************************************************************/
#pragma once #pragma once
#include "data_type.hpp" #include "data_type.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace element_wise { namespace element_wise {
struct Add struct Add
...@@ -211,6 +211,5 @@ struct AddHardswish ...@@ -211,6 +211,5 @@ struct AddHardswish
}; };
} // namespace element_wise } // namespace element_wise
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -9,18 +9,56 @@ namespace ck { ...@@ -9,18 +9,56 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace element_wise { namespace element_wise {
// Need to ensure compiler will fail if there is no matching candidate, instead of compiler
// siliently do implicit type conversion
//
// Method 1:
//
// struct ExampleElementwiseOp
// {
// template<typename Y, typename X>
// __host__ __device__ constexpr void
// operator()(Y&, const X) const;
//
// template<>
// __host__ __device__ constexpr void
// operator()<half_t, half_t>(half_t& y, const half_t& x) const
// {
// }
// };
//
// Method 2:
//
// template <typename Y, typename X>
// struct ExampleElementwiseOp;
//
// template <>
// struct ExampleElementwiseOp<float, ck::bhalf_t>
// {
// __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
// {
// }
// };
struct AddReluAdd struct AddReluAdd
{ {
__host__ __device__ constexpr void template <typename Y, typename X0, typename X1, typename X2>
operator()(half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const;
template <>
__host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const
{ {
half_t a = x0 + x1; half_t a = x0 + x1;
half_t b = a > 0 ? a : 0; half_t b = a > 0 ? a : 0;
y = b + x2; y = b + x2;
} }
__host__ __device__ constexpr void template <>
operator()(float& y, const float& x0, const float& x1, const float& x2) const __host__ __device__ constexpr void operator()<float, float, float, float>(float& y,
const float& x0,
const float& x1,
const float& x2) const
{ {
float a = x0 + x1; float a = x0 + x1;
float b = a > 0 ? a : 0; float b = a > 0 ? a : 0;
...@@ -28,8 +66,9 @@ struct AddReluAdd ...@@ -28,8 +66,9 @@ struct AddReluAdd
y = c; y = c;
} }
__host__ __device__ constexpr void template <>
operator()(half_t& y, const float& x0, const half_t& x1, const half_t& x2) const __host__ __device__ constexpr void operator()<half_t, float, half_t, half_t>(
half_t& y, const float& x0, const half_t& x1, const half_t& x2) const
{ {
float a = x0 + x1; float a = x0 + x1;
float b = a > 0 ? a : 0; float b = a > 0 ? a : 0;
...@@ -40,8 +79,14 @@ struct AddReluAdd ...@@ -40,8 +79,14 @@ struct AddReluAdd
struct AddHardswishAdd struct AddHardswishAdd
{ {
__host__ __device__ constexpr void template <typename Y, typename X0, typename X1, typename X2>
operator()(float& y, const float& x0, const float& x1, const float& x2) const __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const;
template <>
__host__ __device__ constexpr void operator()<float, float, float, float>(float& y,
const float& x0,
const float& x1,
const float& x2) const
{ {
float a = x0 + x1; float a = x0 + x1;
float b = a + float{3}; float b = a + float{3};
...@@ -50,8 +95,9 @@ struct AddHardswishAdd ...@@ -50,8 +95,9 @@ struct AddHardswishAdd
y = d; y = d;
} }
__host__ __device__ constexpr void template <>
operator()(half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const __host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const
{ {
float a = x0 + x1; float a = x0 + x1;
float b = a + float{3}; float b = a + float{3};
...@@ -66,7 +112,7 @@ struct AddHardswishAdd ...@@ -66,7 +112,7 @@ struct AddHardswishAdd
struct AddAddFastGelu struct AddAddFastGelu
{ {
template <typename E, typename C, typename D0, typename D1> template <typename E, typename C, typename D0, typename D1>
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const; __host__ __device__ void operator()(E&, const C&, const D0&, const D1&) const;
template <> template <>
__host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e, __host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
...@@ -92,6 +138,7 @@ struct AddAddFastGelu ...@@ -92,6 +138,7 @@ struct AddAddFastGelu
struct Normalize struct Normalize
{ {
// FIXME: is double absolutely necessary?
Normalize(double epsilon = 1e-4) : epsilon_(epsilon) {} Normalize(double epsilon = 1e-4) : epsilon_(epsilon) {}
template <typename T> template <typename T>
...@@ -126,6 +173,7 @@ struct Normalize ...@@ -126,6 +173,7 @@ struct Normalize
y = ((x - mean) / sqrt(variance + epsilon_)) * gamma + beta; y = ((x - mean) / sqrt(variance + epsilon_)) * gamma + beta;
}; };
// FIXME: is double absolutely necessary?
double epsilon_; double epsilon_;
}; };
...@@ -138,7 +186,7 @@ struct UnaryTypeConvert<float, ck::bhalf_t> ...@@ -138,7 +186,7 @@ struct UnaryTypeConvert<float, ck::bhalf_t>
__host__ __device__ void operator()(float& y, ck::bhalf_t& x) const __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
{ {
y = ck::type_convert<float, ck::bhalf_t>(x); y = ck::type_convert<float, ck::bhalf_t>(x);
}; }
}; };
template <> template <>
...@@ -147,7 +195,7 @@ struct UnaryTypeConvert<ck::bhalf_t, float> ...@@ -147,7 +195,7 @@ struct UnaryTypeConvert<ck::bhalf_t, float>
__host__ __device__ void operator()(ck::bhalf_t& y, float& x) const __host__ __device__ void operator()(ck::bhalf_t& y, float& x) const
{ {
y = ck::type_convert<ck::bhalf_t, float>(x); y = ck::type_convert<ck::bhalf_t, float>(x);
}; }
}; };
} // namespace element_wise } // namespace element_wise
......
#pragma once #pragma once
#include "data_type.hpp" #include "data_type.hpp"
#include "math_v2.hpp" #include "math_v2.hpp"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment