Commit c390260d authored by Chao Liu's avatar Chao Liu
Browse files

clean

parent 428f6fd2
......@@ -304,6 +304,20 @@ struct AddFastGelu
ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
}
template <>
__host__ __device__ constexpr void
operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
{
const float x0_f = c + d;
float x1_f;
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
x0_f);
e = type_convert<half_t>(x1_f);
}
};
} // namespace element_wise
......
......@@ -221,6 +221,20 @@ struct AddAddFastGelu
ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
}
template <>
__host__ __device__ constexpr void operator()<half_t, float, half_t, half_t>(
half_t& e, const float& c, const half_t& d0, const half_t& d1) const
{
const float x0_f = c + d0 + d1;
float x1_f;
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
x0_f);
e = type_convert<half_t>(x1_f);
}
};
struct Normalize
......
......@@ -253,7 +253,6 @@ struct FastGelu
y = x * cdf;
}
// device code, use lower precision "__expf" and "rcp"
template <>
__device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment