Commit dc9c9784 authored by Umang Yadav's avatar Umang Yadav
Browse files

make bit_cast a function

parent 770b632d
...@@ -41,6 +41,18 @@ struct conditional<false, T, F> ...@@ -41,6 +41,18 @@ struct conditional<false, T, F>
{ {
using type = F; using type = F;
}; };
template <typename To, typename From>
inline constexpr To bit_cast(From fr) noexcept
{
static_assert(sizeof(To) == sizeof(From));
#if defined(__GNUC__) and !defined(__clang__)
To x = CONST_FOLD(*reinterpret_cast<To*>(&fr));
#else
To x = __builtin_bit_cast(To, fr);
#endif
return x;
}
} // namespace detail } // namespace detail
// #ifdef __HIP_PLATFORM_HCC__ // #ifdef __HIP_PLATFORM_HCC__
...@@ -58,17 +70,10 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t ...@@ -58,17 +70,10 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t
const int mfmt = (sizeof(T) == 4) ? 23 : 10; const int mfmt = (sizeof(T) == 4) ? 23 : 10;
typename detail::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type x; typename detail::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type x;
#if defined(__GNUC__) and !defined(__clang__)
if constexpr(sizeof(T) == 4) if constexpr(sizeof(T) == 4)
x = CONST_FOLD(*reinterpret_cast<uint32_t*>(&_x)); x = detail::bit_cast<uint32_t>(_x);
else else
x = CONST_FOLD(*reinterpret_cast<uint16_t*>(&_x)); x = detail::bit_cast<uint16_t>(_x);
#else
if constexpr(sizeof(T) == 4)
x = __builtin_bit_cast(uint32_t, _x);
else
x = __builtin_bit_cast(uint16_t, _x);
#endif
uint32_t head, mantissa; uint32_t head, mantissa;
int exponent, bias; int exponent, bias;
...@@ -246,18 +251,11 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x) ...@@ -246,18 +251,11 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x)
uint32_t ifNegInf = 0xFF800000; uint32_t ifNegInf = 0xFF800000;
uint32_t ifNaN = 0x7F800001; uint32_t ifNaN = 0x7F800001;
uint32_t ifNeg0 = 0x80000000; uint32_t ifNeg0 = 0x80000000;
#if defined(__GNUC__) and !defined(__clang__)
fInf = CONST_FOLD(*(reinterpret_cast<float*>(&ifInf)));
fNegInf = CONST_FOLD(*(reinterpret_cast<float*>(&ifNegInf)));
fNaN = CONST_FOLD(*(reinterpret_cast<float*>(&ifNaN)));
fNeg0 = CONST_FOLD(*(reinterpret_cast<float*>(&ifNeg0)));
#else
// TODO: need to change T for half but right now it would never be called with half // TODO: need to change T for half but right now it would never be called with half
fInf = __builtin_bit_cast(float, ifInf); fInf = detail::bit_cast<float>(ifInf);
fNegInf = __builtin_bit_cast(float, ifNegInf); fNegInf = detail::bit_cast<float>(ifNegInf);
fNaN = __builtin_bit_cast(float, ifNaN); fNaN = detail::bit_cast<float>(ifNaN);
fNeg0 = __builtin_bit_cast(float, ifNeg0); fNeg0 = detail::bit_cast<float>(ifNeg0);
#endif
if(x == 0) if(x == 0)
return 0; return 0;
...@@ -305,11 +303,7 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x) ...@@ -305,11 +303,7 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x)
retval = (sign << 15) | (exponent << 10) | mantissa; retval = (sign << 15) | (exponent << 10) | mantissa;
else else
retval = (sign << 31) | (exponent << 23) | mantissa; retval = (sign << 31) | (exponent << 23) | mantissa;
#if defined(__GNUC__) and !defined(__clang__) return detail::bit_cast<T>(retval);
return CONST_FOLD(*reinterpret_cast<T*>(&retval));
#else
return __builtin_bit_cast(T, retval);
#endif
} }
} // namespace migraphx_hip_f8_impl } // namespace migraphx_hip_f8_impl
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment