Commit c390260d authored by Chao Liu's avatar Chao Liu
Browse files

clean

parent 428f6fd2
...@@ -304,6 +304,20 @@ struct AddFastGelu ...@@ -304,6 +304,20 @@ struct AddFastGelu
ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x); ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
} }
// Mixed-precision specialization: `c` is float, `d` is half_t, and the result `e` is half_t.
// The sum c + d is held in a float, FastGelu is applied through its <float, float>
// overload, and only the final value is converted to half_t.
//
// e: output, receives the converted FastGelu result.
// c: first addend (float).
// d: second addend (half_t).
template <>
__host__ __device__ constexpr void
operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
{
    const float x0_f = c + d;

    // Explicitly initialized: a constexpr function may not declare an uninitialized
    // variable before C++20. The value is overwritten below, where x1_f is passed as
    // the output argument of FastGelu.
    float x1_f = 0;

    ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
                                                                                     x0_f);

    e = type_convert<half_t>(x1_f);
}
}; };
} // namespace element_wise } // namespace element_wise
......
...@@ -221,6 +221,20 @@ struct AddAddFastGelu ...@@ -221,6 +221,20 @@ struct AddAddFastGelu
ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x); ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
} }
// Mixed-precision specialization: `c` is float, `d0` and `d1` are half_t, and the result
// `e` is half_t. The sum c + d0 + d1 is held in a float, FastGelu is applied through its
// <float, float> overload, and only the final value is converted to half_t.
//
// e : output, receives the converted FastGelu result.
// c : first addend (float).
// d0: second addend (half_t).
// d1: third addend (half_t).
template <>
__host__ __device__ constexpr void operator()<half_t, float, half_t, half_t>(
    half_t& e, const float& c, const half_t& d0, const half_t& d1) const
{
    const float x0_f = c + d0 + d1;

    // Explicitly initialized: a constexpr function may not declare an uninitialized
    // variable before C++20. The value is overwritten below, where x1_f is passed as
    // the output argument of FastGelu.
    float x1_f = 0;

    ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
                                                                                     x0_f);

    e = type_convert<half_t>(x1_f);
}
}; };
struct Normalize struct Normalize
......
...@@ -253,7 +253,6 @@ struct FastGelu ...@@ -253,7 +253,6 @@ struct FastGelu
y = x * cdf; y = x * cdf;
} }
// device code, use lower precision "__expf" and "rcp"
template <> template <>
__device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment