Commit 3b0c7dc8 authored by Umang Yadav
Browse files

Works with GCC

parent a298c926
...@@ -22,12 +22,13 @@ ...@@ -22,12 +22,13 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
#define MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP #define MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
#if defined(__clang__) and !defined(__GNUC__)
#pragma clang diagnostic push #pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast" #pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wfloat-equal" #pragma clang diagnostic ignored "-Wfloat-equal"
#pragma clang diagnostic ignored "-Wmacro-redefined" #pragma clang diagnostic ignored "-Wmacro-redefined"
#pragma clang diagnostic ignored "-Wc++20-extensions" #pragma clang diagnostic ignored "-Wc++20-extensions"
#endif
#if(defined(__HIP_PLATFORM_HCC__) || defined(__HIP_PLATFORM_AMD__)) #if(defined(__HIP_PLATFORM_HCC__) || defined(__HIP_PLATFORM_AMD__))
// need to include hip_runtime.h otherwise it complains about __host__ and __device__ // need to include hip_runtime.h otherwise it complains about __host__ and __device__
...@@ -428,6 +429,7 @@ inline std::ostream& operator<<(std::ostream& os, const migraphx_fp8::hip_f8<T>& ...@@ -428,6 +429,7 @@ inline std::ostream& operator<<(std::ostream& os, const migraphx_fp8::hip_f8<T>&
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \ return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
} }
// TODO: these binary operators should return float instead of converting the result back to hip_f8<T>
MIGRAPHX_FP8_BINARY_OP(*, migraphx_fp8::hip_f8<T>) MIGRAPHX_FP8_BINARY_OP(*, migraphx_fp8::hip_f8<T>)
MIGRAPHX_FP8_BINARY_OP(-, migraphx_fp8::hip_f8<T>) MIGRAPHX_FP8_BINARY_OP(-, migraphx_fp8::hip_f8<T>)
MIGRAPHX_FP8_BINARY_OP(/, migraphx_fp8::hip_f8<T>) MIGRAPHX_FP8_BINARY_OP(/, migraphx_fp8::hip_f8<T>)
...@@ -602,5 +604,7 @@ struct common_type<migraphx_fp8::fp8e4m3fnuz, migraphx_fp8::fp8e4m3fnuz> ...@@ -602,5 +604,7 @@ struct common_type<migraphx_fp8::fp8e4m3fnuz, migraphx_fp8::fp8e4m3fnuz>
} // namespace std } // namespace std
// ================================================================================================= // =================================================================================================
#if defined(__clang__) and !defined(__GNUC__)
#pragma clang diagnostic pop #pragma clang diagnostic pop
#endif
#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP #endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
...@@ -22,11 +22,12 @@ ...@@ -22,11 +22,12 @@
#ifndef MIGRAPHX_HIP_FP8_IMPL_HPP #ifndef MIGRAPHX_HIP_FP8_IMPL_HPP
#define MIGRAPHX_HIP_FP8_IMPL_HPP #define MIGRAPHX_HIP_FP8_IMPL_HPP
#if !defined(__GNUC__)
#pragma clang diagnostic push #pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
#pragma clang diagnostic ignored "-Wreserved-identifier" #pragma clang diagnostic ignored "-Wreserved-identifier"
#endif
#define CONST_FOLD(x) (__builtin_constant_p(x) ? (x) : (x))
namespace migraphx_hip_f8_impl { namespace migraphx_hip_f8_impl {
namespace detail { namespace detail {
template <bool B, class T, class F> template <bool B, class T, class F>
...@@ -55,17 +56,25 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t ...@@ -55,17 +56,25 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t
static_assert(wm + we == 7, "wm+we==7"); static_assert(wm + we == 7, "wm+we==7");
const int mfmt = (sizeof(T) == 4) ? 23 : 10; const int mfmt = (sizeof(T) == 4) ? 23 : 10;
uint32_t x; typename detail::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type x;
if(sizeof(T) == 4)
x = reinterpret_cast<uint32_t&>(_x); #if defined(__GNUC__) and !defined(__clang__)
if constexpr(sizeof(T) == 4)
x = CONST_FOLD(*reinterpret_cast<uint32_t*>(&_x));
else else
x = reinterpret_cast<uint16_t&>(_x); x = CONST_FOLD(*reinterpret_cast<uint16_t*>(&_x));
#else
if constexpr(sizeof(T) == 4)
x = __builtin_bit_cast(uint32_t, _x);
else
x = __builtin_bit_cast(uint16_t, _x);
#endif
uint32_t head, mantissa; uint32_t head, mantissa;
int exponent, bias; int exponent, bias;
uint32_t sign; uint32_t sign;
if(sizeof(T) == 4) if constexpr(sizeof(T) == 4)
{ {
head = x & 0xFF800000; head = x & 0xFF800000;
mantissa = x & 0x7FFFFF; mantissa = x & 0x7FFFFF;
...@@ -233,14 +242,22 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x) ...@@ -233,14 +242,22 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x)
constexpr int wmo = 23; constexpr int wmo = 23;
T fInf, fNegInf, fNaN, fNeg0; T fInf, fNegInf, fNaN, fNeg0;
const uint32_t ifInf = 0x7F800000; uint32_t ifInf = 0x7F800000;
const uint32_t ifNegInf = 0xFF800000; uint32_t ifNegInf = 0xFF800000;
const uint32_t ifNaN = 0x7F800001; uint32_t ifNaN = 0x7F800001;
const uint32_t ifNeg0 = 0x80000000; uint32_t ifNeg0 = 0x80000000;
fInf = reinterpret_cast<const float&>(ifInf); #if defined(__GNUC__) and !defined(__clang__)
fNegInf = reinterpret_cast<const float&>(ifNegInf); fInf = CONST_FOLD(*(reinterpret_cast<float*>(&ifInf)));
fNaN = reinterpret_cast<const float&>(ifNaN); fNegInf = CONST_FOLD(*(reinterpret_cast<float*>(&ifNegInf)));
fNeg0 = reinterpret_cast<const float&>(ifNeg0); fNaN = CONST_FOLD(*(reinterpret_cast<float*>(&ifNaN)));
fNeg0 = CONST_FOLD(*(reinterpret_cast<float*>(&ifNeg0)));
#else
// TODO: need to change T for half but right now it would never be called with half
fInf = __builtin_bit_cast(float, ifInf);
fNegInf = __builtin_bit_cast(float, ifNegInf);
fNaN = __builtin_bit_cast(float, ifNaN);
fNeg0 = __builtin_bit_cast(float, ifNeg0);
#endif
if(x == 0) if(x == 0)
return 0; return 0;
...@@ -288,9 +305,15 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x) ...@@ -288,9 +305,15 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x)
retval = (sign << 15) | (exponent << 10) | mantissa; retval = (sign << 15) | (exponent << 10) | mantissa;
else else
retval = (sign << 31) | (exponent << 23) | mantissa; retval = (sign << 31) | (exponent << 23) | mantissa;
return reinterpret_cast<const T&>(retval); #if defined(__GNUC__) and !defined(__clang__)
return CONST_FOLD(*reinterpret_cast<T*>(&retval));
#else
return __builtin_bit_cast(T, retval);
#endif
} }
} // namespace migraphx_hip_f8_impl } // namespace migraphx_hip_f8_impl
#if !defined(__GNUC__)
#pragma clang diagnostic pop #pragma clang diagnostic pop
#endif
#endif // MIGRAPHX_HIP_FP8_IMPL_HPP #endif // MIGRAPHX_HIP_FP8_IMPL_HPP
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment