Commit 5195dbbb authored by Rostyslav Geyyer
Browse files

Move type_convert to elemwise op, update the op

parent f86e4436
...@@ -56,6 +56,12 @@ struct PassThrough ...@@ -56,6 +56,12 @@ struct PassThrough
y = type_convert<bhalf_t>(x); y = type_convert<bhalf_t>(x);
} }
template <>
__host__ __device__ void operator()<bhalf_t, half_t>(bhalf_t& y, const half_t& x) const
{
y = type_convert<bhalf_t>(x);
}
template <> template <>
__host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const __host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{ {
...@@ -86,51 +92,27 @@ struct UnaryConvert ...@@ -86,51 +92,27 @@ struct UnaryConvert
} }
}; };
struct UnaryConvertPrecision struct ConvertBF16RTN
{ {
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const; __host__ __device__ void operator()(Y& y, const X& x) const;
template <> template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const __host__ __device__ void operator()<bhalf_t, half_t>(bhalf_t& y, const half_t& x) const
{
y = type_convert_precision<float>(x);
}
template <>
__host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{
y = type_convert_precision<half_t>(x);
}
template <>
__host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{ {
y = type_convert_precision<bhalf_t>(x); y = type_convert_precision<bhalf_t>(x);
} }
template <> template <>
__host__ __device__ void operator()<double, double>(double& y, const double& x) const __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{
y = type_convert_precision<double>(x);
}
template <>
__host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
y = type_convert_precision<int8_t>(x);
}
template <>
__host__ __device__ void operator()<bhalf_t, half_t>(bhalf_t& y, const half_t& x) const
{ {
y = type_convert_precision<bhalf_t>(x); y = type_convert_precision<bhalf_t>(x);
} }
template <> template <>
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{ {
y = type_convert_precision<bhalf_t>(x); y = type_convert_precision<half_t>(x);
} }
}; };
......
...@@ -339,23 +339,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -339,23 +339,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
} }
static_ford<SliceLengths>{}([&](auto idx) { static_ford<SliceLengths>{}([&](auto idx) {
// if elementwise op does conversion, use the op instead of type_convert // apply the src elementwise op and convert under the hood if needed
if constexpr(is_same<SrcElementwiseOperation, DstData dst_v;
ck::tensor_operation::element_wise::UnaryConvert>::value || src_element_op_(dst_v, src_thread_scratch_tuple_[thread_scratch_id][idx]);
is_same<SrcElementwiseOperation, dst_thread_scratch_(idx) = dst_v;
ck::tensor_operation::element_wise::UnaryConvertPrecision>::value)
{
DstData dst_v;
src_element_op_(dst_v, src_thread_scratch_tuple_[thread_scratch_id][idx]);
dst_thread_scratch_(idx) = dst_v;
}
// else apply elementwise op and use type_convert for conversion
else
{
SrcData src_v;
src_element_op_(src_v, src_thread_scratch_tuple_[thread_scratch_id][idx]);
dst_thread_scratch_(idx) = type_convert<DstData>(src_v);
}
}); });
#endif #endif
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment