Commit dc9c9784 authored by Umang Yadav's avatar Umang Yadav
Browse files

make bit_cast a function

parent 770b632d
......@@ -41,6 +41,18 @@ struct conditional<false, T, F>
{
using type = F;
};
template <typename To, typename From>
inline constexpr To bit_cast(From fr) noexcept
{
static_assert(sizeof(To) == sizeof(From));
#if defined(__GNUC__) and !defined(__clang__)
To x = CONST_FOLD(*reinterpret_cast<To*>(&fr));
#else
To x = __builtin_bit_cast(To, fr);
#endif
return x;
}
} // namespace detail
// #ifdef __HIP_PLATFORM_HCC__
......@@ -58,17 +70,10 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t
const int mfmt = (sizeof(T) == 4) ? 23 : 10;
typename detail::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type x;
#if defined(__GNUC__) and !defined(__clang__)
if constexpr(sizeof(T) == 4)
x = CONST_FOLD(*reinterpret_cast<uint32_t*>(&_x));
x = detail::bit_cast<uint32_t>(_x);
else
x = CONST_FOLD(*reinterpret_cast<uint16_t*>(&_x));
#else
if constexpr(sizeof(T) == 4)
x = __builtin_bit_cast(uint32_t, _x);
else
x = __builtin_bit_cast(uint16_t, _x);
#endif
x = detail::bit_cast<uint16_t>(_x);
uint32_t head, mantissa;
int exponent, bias;
......@@ -246,18 +251,11 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x)
uint32_t ifNegInf = 0xFF800000;
uint32_t ifNaN = 0x7F800001;
uint32_t ifNeg0 = 0x80000000;
#if defined(__GNUC__) and !defined(__clang__)
fInf = CONST_FOLD(*(reinterpret_cast<float*>(&ifInf)));
fNegInf = CONST_FOLD(*(reinterpret_cast<float*>(&ifNegInf)));
fNaN = CONST_FOLD(*(reinterpret_cast<float*>(&ifNaN)));
fNeg0 = CONST_FOLD(*(reinterpret_cast<float*>(&ifNeg0)));
#else
// TODO: need to change T for half but right now it would never be called with half
fInf = __builtin_bit_cast(float, ifInf);
fNegInf = __builtin_bit_cast(float, ifNegInf);
fNaN = __builtin_bit_cast(float, ifNaN);
fNeg0 = __builtin_bit_cast(float, ifNeg0);
#endif
fInf = detail::bit_cast<float>(ifInf);
fNegInf = detail::bit_cast<float>(ifNegInf);
fNaN = detail::bit_cast<float>(ifNaN);
fNeg0 = detail::bit_cast<float>(ifNeg0);
if(x == 0)
return 0;
......@@ -305,11 +303,7 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x)
retval = (sign << 15) | (exponent << 10) | mantissa;
else
retval = (sign << 31) | (exponent << 23) | mantissa;
#if defined(__GNUC__) and !defined(__clang__)
return CONST_FOLD(*reinterpret_cast<T*>(&retval));
#else
return __builtin_bit_cast(T, retval);
#endif
return detail::bit_cast<T>(retval);
}
} // namespace migraphx_hip_f8_impl
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment