Commit 34348bd1 authored by Rosty Geyyer's avatar Rosty Geyyer
Browse files

Add a flag to pick conversion method

parent 2797227f
...@@ -173,6 +173,12 @@ ...@@ -173,6 +173,12 @@
#define CK_WORKAROUND_DENORM_FIX 0 #define CK_WORKAROUND_DENORM_FIX 0
#endif #endif
// flag to enable high precision data conversion
// 0 - fast, 1 - high precision
#ifndef CK_EXPERIMENTAL_CONVERT_PRECISION
#define CK_EXPERIMENTAL_CONVERT_PRECISION 1
#endif
namespace ck { namespace ck {
enum struct InMemoryDataOperationEnum enum struct InMemoryDataOperationEnum
......
...@@ -89,15 +89,48 @@ struct UnaryConvert ...@@ -89,15 +89,48 @@ struct UnaryConvert
struct UnaryConvertPrecision struct UnaryConvertPrecision
{ {
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const __host__ __device__ void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{ {
y = type_convert<Y>(x); y = type_convert_precision<float>(x);
}
template <>
__host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{
y = type_convert_precision<half_t>(x);
}
template <>
__host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{
y = type_convert_precision<bhalf_t>(x);
}
template <>
__host__ __device__ void operator()<double, double>(double& y, const double& x) const
{
y = type_convert_precision<double>(x);
}
template <>
__host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
y = type_convert_precision<int8_t>(x);
}
template <>
__host__ __device__ void operator()<bhalf_t, half_t>(bhalf_t& y, const half_t& x) const
{
y = type_convert_precision<bhalf_t>(x);
} }
template <> template <>
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{ {
y = type_convert_bf16_rtn(x); y = type_convert_precision<bhalf_t>(x);
} }
}; };
......
...@@ -6,11 +6,10 @@ ...@@ -6,11 +6,10 @@
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp" #include "ck/tensor/static_tensor.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
namespace ck { namespace ck {
namespace detail { namespace detail {
...@@ -348,9 +347,14 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -348,9 +347,14 @@ struct ThreadwiseTensorSliceTransfer_v3r1
}); });
} }
static_ford<SliceLengths>{}([&](auto idx) { static_ford<SliceLengths>{}([&](auto idx) {
// pick the right conversion method
#if CK_EXPERIMENTAL_CONVERT_PRECISION
using UnaryConvert = ck::tensor_operation::element_wise::UnaryConvertPrecision;
#else
using UnaryConvert = ck::tensor_operation::element_wise::UnaryConvert;
#endif
// convert from SrcData to DstData here // convert from SrcData to DstData here
ck::tensor_operation::element_wise::UnaryConvert{}( UnaryConvert{}(dst_thread_scratch_(idx), src_thread_scratch_tuple_[thread_scratch_id][idx]);
dst_thread_scratch_(idx), src_thread_scratch_tuple_[thread_scratch_id][idx]);
}); });
#endif #endif
} }
......
...@@ -1031,8 +1031,18 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_ ...@@ -1031,8 +1031,18 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return type_convert<bhalf_t>(x_fp32); return type_convert<bhalf_t>(x_fp32);
} }
// Generic precision-oriented conversion from X to Y.
//
// This primary template simply performs a static_cast from the source value
// to the destination type. Type pairs that need special handling are expected
// to provide their own explicit specializations of this template.
//
// Template parameters:
//   Y - destination type (must not be a reference type)
//   X - source type (must not be a reference type)
// Parameters:
//   x - value to convert, taken by value
// Returns:
//   x converted to Y via static_cast.
template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert_precision(X x)
{
    // Reject reference types up front so the conversion always works on values.
    static_assert(!(std::is_reference_v<Y> || std::is_reference_v<X>));

    return static_cast<Y>(x);
}
// Convert fp32 to bf16 with RTN if higher precision is needed // Convert fp32 to bf16 with RTN if higher precision is needed
__host__ __device__ constexpr bhalf_t type_convert_bf16_rtn(float x) template <>
inline __host__ __device__ constexpr bhalf_t type_convert_precision<bhalf_t, float>(float x)
{ {
union union
{ {
...@@ -1074,6 +1084,15 @@ __host__ __device__ constexpr bhalf_t type_convert_bf16_rtn(float x) ...@@ -1074,6 +1084,15 @@ __host__ __device__ constexpr bhalf_t type_convert_bf16_rtn(float x)
return uint16_t(u.int32 >> 16); return uint16_t(u.int32 >> 16);
} }
// Convert fp16 to bf16 via fp32 with RTN if higher precision is needed.
//
// The half_t input is first widened to float with static_cast, and the
// resulting float is then handed to type_convert_precision<bhalf_t>, so the
// fp32 -> bf16 step is whatever that overload implements.
//
// Parameters:
//   x - fp16 value to convert
// Returns:
//   the bhalf_t produced by the float -> bhalf_t precision conversion.
template <>
inline __host__ __device__ constexpr bhalf_t type_convert_precision<bhalf_t, half_t>(half_t x)
{
    // Widen to fp32 first, then delegate the narrowing to the float overload.
    return type_convert_precision<bhalf_t>(static_cast<float>(x));
}
template <typename T> template <typename T>
struct NumericLimits struct NumericLimits
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment