Commit 3b0c7dc8 authored by Umang Yadav's avatar Umang Yadav
Browse files

Works with GCC

parent a298c926
......@@ -22,12 +22,13 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
#define MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
#if defined(__clang__) and !defined(__GNUC__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wfloat-equal"
#pragma clang diagnostic ignored "-Wmacro-redefined"
#pragma clang diagnostic ignored "-Wc++20-extensions"
#endif
#if(defined(__HIP_PLATFORM_HCC__) || defined(__HIP_PLATFORM_AMD__))
// need to include hip_runtime.h otherwise it complains about __host__ and __device__
......@@ -428,6 +429,7 @@ inline std::ostream& operator<<(std::ostream& os, const migraphx_fp8::hip_f8<T>&
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
}
// TODO: these should return floats
MIGRAPHX_FP8_BINARY_OP(*, migraphx_fp8::hip_f8<T>)
MIGRAPHX_FP8_BINARY_OP(-, migraphx_fp8::hip_f8<T>)
MIGRAPHX_FP8_BINARY_OP(/, migraphx_fp8::hip_f8<T>)
......@@ -602,5 +604,7 @@ struct common_type<migraphx_fp8::fp8e4m3fnuz, migraphx_fp8::fp8e4m3fnuz>
} // namespace std
// =================================================================================================
#if defined(__clang__) and !defined(__GNUC__)
#pragma clang diagnostic pop
#endif
#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
......@@ -22,11 +22,12 @@
#ifndef MIGRAPHX_HIP_FP8_IMPL_HPP
#define MIGRAPHX_HIP_FP8_IMPL_HPP
#if !defined(__GNUC__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
#pragma clang diagnostic ignored "-Wreserved-identifier"
#endif
#define CONST_FOLD(x) (__builtin_constant_p(x) ? (x) : (x))
namespace migraphx_hip_f8_impl {
namespace detail {
template <bool B, class T, class F>
......@@ -55,17 +56,25 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t
static_assert(wm + we == 7, "wm+we==7");
const int mfmt = (sizeof(T) == 4) ? 23 : 10;
uint32_t x;
if(sizeof(T) == 4)
x = reinterpret_cast<uint32_t&>(_x);
typename detail::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type x;
#if defined(__GNUC__) and !defined(__clang__)
if constexpr(sizeof(T) == 4)
x = CONST_FOLD(*reinterpret_cast<uint32_t*>(&_x));
else
x = CONST_FOLD(*reinterpret_cast<uint16_t*>(&_x));
#else
if constexpr(sizeof(T) == 4)
x = __builtin_bit_cast(uint32_t, _x);
else
x = reinterpret_cast<uint16_t&>(_x);
x = __builtin_bit_cast(uint16_t, _x);
#endif
uint32_t head, mantissa;
int exponent, bias;
uint32_t sign;
if(sizeof(T) == 4)
if constexpr(sizeof(T) == 4)
{
head = x & 0xFF800000;
mantissa = x & 0x7FFFFF;
......@@ -233,14 +242,22 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x)
constexpr int wmo = 23;
T fInf, fNegInf, fNaN, fNeg0;
const uint32_t ifInf = 0x7F800000;
const uint32_t ifNegInf = 0xFF800000;
const uint32_t ifNaN = 0x7F800001;
const uint32_t ifNeg0 = 0x80000000;
fInf = reinterpret_cast<const float&>(ifInf);
fNegInf = reinterpret_cast<const float&>(ifNegInf);
fNaN = reinterpret_cast<const float&>(ifNaN);
fNeg0 = reinterpret_cast<const float&>(ifNeg0);
uint32_t ifInf = 0x7F800000;
uint32_t ifNegInf = 0xFF800000;
uint32_t ifNaN = 0x7F800001;
uint32_t ifNeg0 = 0x80000000;
#if defined(__GNUC__) and !defined(__clang__)
fInf = CONST_FOLD(*(reinterpret_cast<float*>(&ifInf)));
fNegInf = CONST_FOLD(*(reinterpret_cast<float*>(&ifNegInf)));
fNaN = CONST_FOLD(*(reinterpret_cast<float*>(&ifNaN)));
fNeg0 = CONST_FOLD(*(reinterpret_cast<float*>(&ifNeg0)));
#else
// TODO: need to change T for half but right now it would never be called with half
fInf = __builtin_bit_cast(float, ifInf);
fNegInf = __builtin_bit_cast(float, ifNegInf);
fNaN = __builtin_bit_cast(float, ifNaN);
fNeg0 = __builtin_bit_cast(float, ifNeg0);
#endif
if(x == 0)
return 0;
......@@ -288,9 +305,15 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x)
retval = (sign << 15) | (exponent << 10) | mantissa;
else
retval = (sign << 31) | (exponent << 23) | mantissa;
return reinterpret_cast<const T&>(retval);
#if defined(__GNUC__) and !defined(__clang__)
return CONST_FOLD(*reinterpret_cast<T*>(&retval));
#else
return __builtin_bit_cast(T, retval);
#endif
}
} // namespace migraphx_hip_f8_impl
#if !defined(__GNUC__)
#pragma clang diagnostic pop
#endif
#endif // MIGRAPHX_HIP_FP8_IMPL_HPP
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment