Commit d7339e8a authored by Umang Yadav
Browse files

use bit cast

parent a9dd42f7
...@@ -22,12 +22,12 @@ ...@@ -22,12 +22,12 @@
#ifndef MIGRAPHX_GUARD_KERNELS_FP8_IMPL_HPP #ifndef MIGRAPHX_GUARD_KERNELS_FP8_IMPL_HPP
#define MIGRAPHX_GUARD_KERNELS_FP8_IMPL_HPP #define MIGRAPHX_GUARD_KERNELS_FP8_IMPL_HPP
#include <migraphx/kernels/bit_cast.hpp>
#if defined(__clang__) #if defined(__clang__)
#pragma clang diagnostic push #pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreserved-identifier" #pragma clang diagnostic ignored "-Wreserved-identifier"
#endif #endif
#define CONST_FOLD(x) (__builtin_constant_p(x) ? (x) : (x))
namespace migraphx { namespace migraphx {
namespace detail { namespace detail {
template <bool B, class T, class F> template <bool B, class T, class F>
...@@ -42,26 +42,10 @@ struct conditional<false, T, F> ...@@ -42,26 +42,10 @@ struct conditional<false, T, F>
using type = F; using type = F;
}; };
// Reinterprets the object representation of `fr` as a value of type `To`.
//
// To and From must have the same size (enforced at compile time).
// The returned value has exactly the same bits as `fr`.
template <typename To, typename From>
inline constexpr To bit_cast(From fr) noexcept
{
    static_assert(sizeof(To) == sizeof(From));
#if defined(__GNUC__) and !defined(__clang__) and (__GNUC__ < 11)
    // GCC releases before 11 do not provide __builtin_bit_cast, so fall back to
    // pointer punning wrapped in CONST_FOLD to keep the expression foldable.
    // NOTE(review): this path still relies on GCC tolerating the aliasing cast.
    To x = CONST_FOLD(*reinterpret_cast<To*>(&fr));
#else
    // Clang and GCC >= 11 expose a well-defined, constexpr-capable builtin;
    // prefer it over reinterpret_cast, which violates strict aliasing.
    To x = __builtin_bit_cast(To, fr);
#endif
    return x;
}
} // namespace detail } // namespace detail
namespace fp8 { namespace fp8 {
namespace impl { namespace impl {
// #ifdef __HIP_PLATFORM_HCC__
// __device__ inline int clz(uint32_t x) { return __clz(x); }
// #else
// __host__ inline int clz(uint32_t x) { return __builtin_clz(x); }
// #endif
template <int wm, int we, typename T, bool negative_zero_nan, bool clip> template <int wm, int we, typename T, bool negative_zero_nan, bool clip>
MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
...@@ -73,9 +57,9 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t ...@@ -73,9 +57,9 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t
typename migraphx::detail::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type x; typename migraphx::detail::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type x;
if constexpr(sizeof(T) == 4) if constexpr(sizeof(T) == 4)
x = migraphx::detail::bit_cast<uint32_t>(_x); x = migraphx::bit_cast<uint32_t>(_x);
else else
x = migraphx::detail::bit_cast<uint16_t>(_x); x = migraphx::bit_cast<uint16_t>(_x);
uint32_t head, mantissa; uint32_t head, mantissa;
int exponent, bias; int exponent, bias;
...@@ -267,10 +251,10 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x) ...@@ -267,10 +251,10 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x)
uint32_t ifNaN = 0x7F800001; uint32_t ifNaN = 0x7F800001;
uint32_t ifNeg0 = 0x80000000; uint32_t ifNeg0 = 0x80000000;
// TODO: need to change T for half, but right now it is never called with half // TODO: need to change T for half, but right now it is never called with half
fInf = migraphx::detail::bit_cast<float>(ifInf); fInf = migraphx::bit_cast<float>(ifInf);
fNegInf = migraphx::detail::bit_cast<float>(ifNegInf); fNegInf = migraphx::bit_cast<float>(ifNegInf);
fNaN = migraphx::detail::bit_cast<float>(ifNaN); fNaN = migraphx::bit_cast<float>(ifNaN);
fNeg0 = migraphx::detail::bit_cast<float>(ifNeg0); fNeg0 = migraphx::bit_cast<float>(ifNeg0);
if(x == 0) if(x == 0)
return 0; return 0;
...@@ -318,7 +302,7 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x) ...@@ -318,7 +302,7 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x)
retval = (sign << 15) | (exponent << 10) | mantissa; retval = (sign << 15) | (exponent << 10) | mantissa;
else else
retval = (sign << 31) | (exponent << 23) | mantissa; retval = (sign << 31) | (exponent << 23) | mantissa;
return migraphx::detail::bit_cast<T>(retval); return migraphx::bit_cast<T>(retval);
} }
} // namespace impl } // namespace impl
} // namespace fp8 } // namespace fp8
......
...@@ -144,7 +144,7 @@ extern "C" { ...@@ -144,7 +144,7 @@ extern "C" {
__global__ void kernel(${type}* p) __global__ void kernel(${type}* p)
{ {
auto x = *p; auto x = *p;
*p = implicit_conversion(migraphx::${invoke}); *p = migraphx::implicit_conversion(migraphx::${invoke});
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment