"...git@developer.sourcefind.cn:modelzoo/solov2-pytorch.git" did not exist on "618dca0895f5f4ede19f5feebb064648e128e12e"
Commit 5195dbbb authored by Rostyslav Geyyer

Move type_convert to elemwise op, update the op

parent f86e4436
@@ -56,6 +56,12 @@ struct PassThrough
         y = type_convert<bhalf_t>(x);
     }
 
+    template <>
+    __host__ __device__ void operator()<bhalf_t, half_t>(bhalf_t& y, const half_t& x) const
+    {
+        y = type_convert<bhalf_t>(x);
+    }
+
     template <>
     __host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
     {
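The specialization added here follows the op(dst, src) convention used by these elementwise ops: the (Y, X) overload that gets picked decides whether a type conversion happens inside the op, so the caller never calls type_convert itself. Below is a minimal standalone sketch of that pattern, using hypothetical half_stub/bhalf_stub stand-ins for ck::half_t/ck::bhalf_t and a plain truncating cast that is only an assumption about what type_convert does for bf16, not CK's actual implementation.

#include <cstdint>
#include <cstring>
#include <iostream>

// Stand-ins so the sketch compiles without the CK headers (hypothetical types).
struct half_stub  { float v; };            // models a 16-bit half value
struct bhalf_stub { std::uint16_t bits; }; // models bf16 as the top half of an fp32

struct PassThroughSketch
{
    // same-type overload: plain copy, no conversion
    void operator()(float& y, const float& x) const { y = x; }

    // mixed-type overload: the op itself converts, as in the new specialization
    void operator()(bhalf_stub& y, const half_stub& x) const
    {
        std::uint32_t f32_bits;
        const float f = x.v;
        std::memcpy(&f32_bits, &f, sizeof(f32_bits));
        y.bits = static_cast<std::uint16_t>(f32_bits >> 16); // truncate fp32 -> bf16
    }
};

int main()
{
    PassThroughSketch op;
    half_stub  h{1.5f};
    bhalf_stub b{};
    op(b, h);                                // conversion happens inside the op
    std::cout << std::hex << b.bits << '\n'; // prints 3fc0 (bf16 encoding of 1.5)
}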
@@ -86,51 +92,27 @@ struct UnaryConvert
     }
 };
 
-struct UnaryConvertPrecision
+struct ConvertBF16RTN
 {
     template <typename Y, typename X>
     __host__ __device__ void operator()(Y& y, const X& x) const;
 
-    template <>
-    __host__ __device__ void operator()<float, float>(float& y, const float& x) const
-    {
-        y = type_convert_precision<float>(x);
-    }
-
-    template <>
-    __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
-    {
-        y = type_convert_precision<half_t>(x);
-    }
-
     template <>
-    __host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
+    __host__ __device__ void operator()<bhalf_t, half_t>(bhalf_t& y, const half_t& x) const
     {
         y = type_convert_precision<bhalf_t>(x);
     }
 
-    template <>
-    __host__ __device__ void operator()<double, double>(double& y, const double& x) const
-    {
-        y = type_convert_precision<double>(x);
-    }
-
-    template <>
-    __host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
-    {
-        y = type_convert_precision<int8_t>(x);
-    }
-
     template <>
-    __host__ __device__ void operator()<bhalf_t, half_t>(bhalf_t& y, const half_t& x) const
+    __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
     {
         y = type_convert_precision<bhalf_t>(x);
     }
 
     template <>
-    __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
+    __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
     {
-        y = type_convert_precision<bhalf_t>(x);
+        y = type_convert_precision<half_t>(x);
     }
 };
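The rename from UnaryConvertPrecision to ConvertBF16RTN narrows the op to bf16 outputs and signals round-to-nearest (RTN) behaviour. The exact rounding done by type_convert_precision is not visible in this diff, so the following is only an illustration of the round-to-nearest-even fp32 -> bf16 routine such an op typically performs, not CK's implementation.

#include <cstdint>
#include <cstring>

// Typical round-to-nearest-even float -> bf16 conversion (illustrative only).
inline std::uint16_t float_to_bf16_rtne(float f)
{
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));

    // Propagate NaN: force a mantissa bit so truncation cannot turn NaN into Inf.
    if ((bits & 0x7fffffffu) > 0x7f800000u)
        return static_cast<std::uint16_t>((bits >> 16) | 0x0040u);

    // Round to nearest, ties to even: bias by 0x7fff plus the LSB of the kept part.
    const std::uint32_t lsb = (bits >> 16) & 1u;
    bits += 0x7fffu + lsb;
    return static_cast<std::uint16_t>(bits >> 16);
}

Compared with plain truncation (bits >> 16), rounding to nearest halves the worst-case representation error, which is the usual motivation for keeping a dedicated RTN conversion op.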
@@ -339,23 +339,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         }
 
         static_ford<SliceLengths>{}([&](auto idx) {
-            // if elementwise op does conversion, use the op instead of type_convert
-            if constexpr(is_same<SrcElementwiseOperation,
-                                 ck::tensor_operation::element_wise::UnaryConvert>::value ||
-                         is_same<SrcElementwiseOperation,
-                                 ck::tensor_operation::element_wise::UnaryConvertPrecision>::value)
-            {
-                DstData dst_v;
-                src_element_op_(dst_v, src_thread_scratch_tuple_[thread_scratch_id][idx]);
-                dst_thread_scratch_(idx) = dst_v;
-            }
-            // else apply elementwise op and use type_convert for conversion
-            else
-            {
-                SrcData src_v;
-                src_element_op_(src_v, src_thread_scratch_tuple_[thread_scratch_id][idx]);
-                dst_thread_scratch_(idx) = type_convert<DstData>(src_v);
-            }
+            // apply the src elementwise op and convert under the hood if needed
+            DstData dst_v;
+            src_element_op_(dst_v, src_thread_scratch_tuple_[thread_scratch_id][idx]);
+            dst_thread_scratch_(idx) = dst_v;
         });
 #endif
     }
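After this change the transfer no longer special-cases UnaryConvert/UnaryConvertPrecision: every source elementwise op receives a DstData& output, and the op's own (DstData, SrcData) overload performs any conversion. A simplified, hypothetical sketch of that code path follows (copy_with_op and Scale2x are made-up names, not CK APIs).

#include <cstddef>
#include <iostream>

// Simplified model of the new path: the op always writes a DstData value,
// so a converting overload does the cast and the transfer never calls
// type_convert itself.
template <typename DstData, typename SrcData, typename ElementOp>
void copy_with_op(DstData* dst, const SrcData* src, std::size_t n, const ElementOp& op)
{
    for (std::size_t i = 0; i < n; ++i)
    {
        DstData dst_v;
        op(dst_v, src[i]); // converts "under the hood" when DstData != SrcData
        dst[i] = dst_v;
    }
}

struct Scale2x
{
    // same-type overload: no conversion involved
    void operator()(float& y, const float& x) const { y = 2.0f * x; }
    // mixed-type overload: the op produces the destination type directly
    void operator()(double& y, const float& x) const { y = 2.0 * static_cast<double>(x); }
};

int main()
{
    const float src[3] = {1.0f, 2.0f, 3.0f};
    double      dst[3] = {};
    copy_with_op(dst, src, 3, Scale2x{});
    std::cout << dst[0] << ' ' << dst[1] << ' ' << dst[2] << '\n'; // 2 4 6
}

Note that this relies on each elementwise op providing an overload or specialization for the (DstData, SrcData) pair actually used; the PassThrough specialization added in the first hunk supplies exactly that for the bhalf_t/half_t case.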