Commit f86e4436 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Merge elementwise op with type conversion

parent 845efff7
...@@ -208,15 +208,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -208,15 +208,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
auto src_vector_container = src_vector_type{ auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)}; src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
// apply SrcElementwiseOperation on src_vector_container
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
SrcData src_v;
src_element_op_(src_v, src_vector_container.template AsType<SrcData>()[i]);
src_vector_container.template AsType<SrcData>()(i) = src_v;
});
// copy data from src_vector_container into src_thread_scratch_ // copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_tuple_(thread_scratch_id) src_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<src_vector_t>( .template SetAsType<src_vector_t>(
...@@ -346,16 +337,25 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -346,16 +337,25 @@ struct ThreadwiseTensorSliceTransfer_v3r1
src_vector_refs, dst_vector_refs); src_vector_refs, dst_vector_refs);
}); });
} }
static_ford<SliceLengths>{}([&](auto idx) { static_ford<SliceLengths>{}([&](auto idx) {
// pick the right conversion method // if elementwise op does conversion, use the op instead of type_convert
#if CK_EXPERIMENTAL_CONVERT_PRECISION if constexpr(is_same<SrcElementwiseOperation,
using UnaryConvert = ck::tensor_operation::element_wise::UnaryConvertPrecision; ck::tensor_operation::element_wise::UnaryConvert>::value ||
#else is_same<SrcElementwiseOperation,
using UnaryConvert = ck::tensor_operation::element_wise::UnaryConvert; ck::tensor_operation::element_wise::UnaryConvertPrecision>::value)
#endif {
// convert from SrcData to DstData here DstData dst_v;
UnaryConvert{}(dst_thread_scratch_(idx), src_element_op_(dst_v, src_thread_scratch_tuple_[thread_scratch_id][idx]);
src_thread_scratch_tuple_[thread_scratch_id][idx]); dst_thread_scratch_(idx) = dst_v;
}
// else apply elementwise op and use type_convert for conversion
else
{
SrcData src_v;
src_element_op_(src_v, src_thread_scratch_tuple_[thread_scratch_id][idx]);
dst_thread_scratch_(idx) = type_convert<DstData>(src_v);
}
}); });
#endif #endif
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment