Commit fe32b124 authored by Chao Liu's avatar Chao Liu
Browse files

clean element wise op

parent d1335c43
......@@ -24,11 +24,11 @@
*
*******************************************************************************/
#pragma once
#include "data_type.hpp"
namespace ck {
namespace tensor_operation {
namespace element_wise {
struct Add
......@@ -211,6 +211,5 @@ struct AddHardswish
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
......@@ -9,18 +9,56 @@ namespace ck {
namespace tensor_operation {
namespace element_wise {
// Need to ensure compiler will fail if there is no matching candidate, instead of compiler
// siliently do implicit type conversion
//
// Method 1:
//
// struct ExampleElementwiseOp
// {
// template<typename Y, typename X>
// __host__ __device__ constexpr void
// operator()(Y&, const X) const;
//
// template<>
// __host__ __device__ constexpr void
// operator()<half_t, half_t>(half_t& y, const half_t& x) const
// {
// }
// };
//
// Method 2:
//
// template <typename Y, typename X>
// struct ExampleElementwiseOp;
//
// template <>
// struct ExampleElementwiseOp<float, ck::bhalf_t>
// {
// __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
// {
// }
// };
struct AddReluAdd
{
__host__ __device__ constexpr void
operator()(half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const
template <typename Y, typename X0, typename X1, typename X2>
__host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const;
template <>
__host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const
{
half_t a = x0 + x1;
half_t b = a > 0 ? a : 0;
y = b + x2;
}
__host__ __device__ constexpr void
operator()(float& y, const float& x0, const float& x1, const float& x2) const
template <>
__host__ __device__ constexpr void operator()<float, float, float, float>(float& y,
const float& x0,
const float& x1,
const float& x2) const
{
float a = x0 + x1;
float b = a > 0 ? a : 0;
......@@ -28,8 +66,9 @@ struct AddReluAdd
y = c;
}
__host__ __device__ constexpr void
operator()(half_t& y, const float& x0, const half_t& x1, const half_t& x2) const
template <>
__host__ __device__ constexpr void operator()<half_t, float, half_t, half_t>(
half_t& y, const float& x0, const half_t& x1, const half_t& x2) const
{
float a = x0 + x1;
float b = a > 0 ? a : 0;
......@@ -40,8 +79,14 @@ struct AddReluAdd
struct AddHardswishAdd
{
__host__ __device__ constexpr void
operator()(float& y, const float& x0, const float& x1, const float& x2) const
template <typename Y, typename X0, typename X1, typename X2>
__host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const;
template <>
__host__ __device__ constexpr void operator()<float, float, float, float>(float& y,
const float& x0,
const float& x1,
const float& x2) const
{
float a = x0 + x1;
float b = a + float{3};
......@@ -50,8 +95,9 @@ struct AddHardswishAdd
y = d;
}
__host__ __device__ constexpr void
operator()(half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const
template <>
__host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const
{
float a = x0 + x1;
float b = a + float{3};
......@@ -66,7 +112,7 @@ struct AddHardswishAdd
struct AddAddFastGelu
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
__host__ __device__ void operator()(E&, const C&, const D0&, const D1&) const;
template <>
__host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
......@@ -92,6 +138,7 @@ struct AddAddFastGelu
struct Normalize
{
// FIXME: is double absolutely necessary?
Normalize(double epsilon = 1e-4) : epsilon_(epsilon) {}
template <typename T>
......@@ -126,6 +173,7 @@ struct Normalize
y = ((x - mean) / sqrt(variance + epsilon_)) * gamma + beta;
};
// FIXME: is double absolutely necessary?
double epsilon_;
};
......@@ -138,7 +186,7 @@ struct UnaryTypeConvert<float, ck::bhalf_t>
__host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
{
y = ck::type_convert<float, ck::bhalf_t>(x);
};
}
};
template <>
......@@ -147,7 +195,7 @@ struct UnaryTypeConvert<ck::bhalf_t, float>
__host__ __device__ void operator()(ck::bhalf_t& y, float& x) const
{
y = ck::type_convert<ck::bhalf_t, float>(x);
};
}
};
} // namespace element_wise
......
#pragma once
#include "data_type.hpp"
#include "math_v2.hpp"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment