"...gpu/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "68c17b1b0e618fd64bd909b2f30759d0fe608145"
Commit d7339e8a authored by Umang Yadav's avatar Umang Yadav
Browse files

use bit cast

parent a9dd42f7
......@@ -22,12 +22,12 @@
#ifndef MIGRAPHX_GUARD_KERNELS_FP8_IMPL_HPP
#define MIGRAPHX_GUARD_KERNELS_FP8_IMPL_HPP
#include <migraphx/kernels/bit_cast.hpp>
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreserved-identifier"
#endif
#define CONST_FOLD(x) (__builtin_constant_p(x) ? (x) : (x))
namespace migraphx {
namespace detail {
template <bool B, class T, class F>
......@@ -42,26 +42,10 @@ struct conditional<false, T, F>
using type = F;
};
template <typename To, typename From>
inline constexpr To bit_cast(From fr) noexcept
{
static_assert(sizeof(To) == sizeof(From));
#if defined(__GNUC__) and !defined(__clang__)
To x = CONST_FOLD(*reinterpret_cast<To*>(&fr));
#else
To x = __builtin_bit_cast(To, fr);
#endif
return x;
}
} // namespace detail
namespace fp8 {
namespace impl {
// #ifdef __HIP_PLATFORM_HCC__
// __device__ inline int clz(uint32_t x) { return __clz(x); }
// #else
// __host__ inline int clz(uint32_t x) { return __builtin_clz(x); }
// #endif
template <int wm, int we, typename T, bool negative_zero_nan, bool clip>
MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
......@@ -73,9 +57,9 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t
typename migraphx::detail::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type x;
if constexpr(sizeof(T) == 4)
x = migraphx::detail::bit_cast<uint32_t>(_x);
x = migraphx::bit_cast<uint32_t>(_x);
else
x = migraphx::detail::bit_cast<uint16_t>(_x);
x = migraphx::bit_cast<uint16_t>(_x);
uint32_t head, mantissa;
int exponent, bias;
......@@ -267,10 +251,10 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x)
uint32_t ifNaN = 0x7F800001;
uint32_t ifNeg0 = 0x80000000;
// TODO: need to change T for half but right now it would never called with half
fInf = migraphx::detail::bit_cast<float>(ifInf);
fNegInf = migraphx::detail::bit_cast<float>(ifNegInf);
fNaN = migraphx::detail::bit_cast<float>(ifNaN);
fNeg0 = migraphx::detail::bit_cast<float>(ifNeg0);
fInf = migraphx::bit_cast<float>(ifInf);
fNegInf = migraphx::bit_cast<float>(ifNegInf);
fNaN = migraphx::bit_cast<float>(ifNaN);
fNeg0 = migraphx::bit_cast<float>(ifNeg0);
if(x == 0)
return 0;
......@@ -318,7 +302,7 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x)
retval = (sign << 15) | (exponent << 10) | mantissa;
else
retval = (sign << 31) | (exponent << 23) | mantissa;
return migraphx::detail::bit_cast<T>(retval);
return migraphx::bit_cast<T>(retval);
}
} // namespace impl
} // namespace fp8
......
......@@ -144,7 +144,7 @@ extern "C" {
__global__ void kernel(${type}* p)
{
auto x = *p;
*p = implicit_conversion(migraphx::${invoke});
*p = migraphx::implicit_conversion(migraphx::${invoke});
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment