Commit 8319e01f authored by Umang Yadav
Browse files

Fix tidy

parent ab653aff
...@@ -21,8 +21,13 @@ ...@@ -21,8 +21,13 @@
* ************************************************************************ */ * ************************************************************************ */
#ifndef MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#define MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP #define MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define MIGRAPHX_CONST_FOLD(x) (__builtin_constant_p(x) ? (x) : (x)) #define MIGRAPHX_CONST_FOLD(x) (__builtin_constant_p(x) ? (x) : (x))
namespace migraphx { namespace migraphx {
...@@ -39,4 +44,7 @@ inline constexpr To bit_cast(From fr) noexcept ...@@ -39,4 +44,7 @@ inline constexpr To bit_cast(From fr) noexcept
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif
#endif // MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP #endif // MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
// We are clipping/saturation in down conversion by default. Unclipped version is not tested and // We are clipping/saturation in down conversion by default. Unclipped version is not tested and
// shouldn't be used without having enough tests. // shouldn't be used without having enough tests.
// logic is based on clipping table from here : https://onnx.ai/onnx/technical/float8.html#cast // logic is based on clipping table from here : https://onnx.ai/onnx/technical/float8.html#cast
// NOLINTNEXTLINE
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 #define MIGRAPHX_F8_DOWNCAST_CLIPPING 1
#include <cmath> #include <cmath>
...@@ -173,6 +174,7 @@ struct float8 ...@@ -173,6 +174,7 @@ struct float8
} }
} }
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_UNARY_OP(unary_op, binary_op) \ #define MIGRAPHX_FP8_UNARY_OP(unary_op, binary_op) \
constexpr float8& operator unary_op(const float8& rhs) \ constexpr float8& operator unary_op(const float8& rhs) \
{ \ { \
...@@ -192,8 +194,8 @@ struct float8 ...@@ -192,8 +194,8 @@ struct float8
MIGRAPHX_FP8_UNARY_OP(+=, +) MIGRAPHX_FP8_UNARY_OP(+=, +)
MIGRAPHX_FP8_UNARY_OP(/=, /) MIGRAPHX_FP8_UNARY_OP(/=, /)
inline constexpr float8& operator=(const float8& rhs) = default; inline constexpr float8& operator=(const float8& rhs) = default;
inline constexpr float8& operator=(float8&& rhs) = default; inline constexpr float8& operator=(float8&& rhs) noexcept = default;
inline constexpr float8& operator=(float rhs) inline constexpr float8& operator=(float rhs)
{ {
...@@ -203,11 +205,9 @@ struct float8 ...@@ -203,11 +205,9 @@ struct float8
inline constexpr bool operator==(const float8& rhs) const inline constexpr bool operator==(const float8& rhs) const
{ {
if(rhs.is_zero() and this->is_zero()) if(rhs.is_nan() or rhs.is_inf() or this->is_nan() or this->is_inf())
return true;
else if(rhs.is_nan() or rhs.is_inf() or this->is_nan() or this->is_inf())
return false; return false;
else if(this->data == rhs.data) else if((rhs.is_zero() and this->is_zero()) or (this->data == rhs.data))
return true; return true;
return false; return false;
} }
...@@ -260,7 +260,7 @@ MIGRAPHX_FP8_BINARY_OP(!=, bool) ...@@ -260,7 +260,7 @@ MIGRAPHX_FP8_BINARY_OP(!=, bool)
template <migraphx::fp8::f8_type T> template <migraphx::fp8::f8_type T>
inline migraphx::fp8::float8<T> fabs(migraphx::fp8::float8<T> v) inline migraphx::fp8::float8<T> fabs(migraphx::fp8::float8<T> v)
{ {
v.data = v.data & 0x7f; v.data = v.data & 0x7f; // NOLINT
return v; return v;
} }
...@@ -277,7 +277,7 @@ class numeric_limits<fp8e4m3fnuz> ...@@ -277,7 +277,7 @@ class numeric_limits<fp8e4m3fnuz>
public: public:
static constexpr fp8e4m3fnuz epsilon() { return fp8e4m3fnuz(0x28, fp8e4m3fnuz::from_bits()); } static constexpr fp8e4m3fnuz epsilon() { return fp8e4m3fnuz(0x28, fp8e4m3fnuz::from_bits()); }
// NOLINTNEXTLINE
static constexpr fp8e4m3fnuz quiet_NaN() { return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits()); } static constexpr fp8e4m3fnuz quiet_NaN() { return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits()); }
static constexpr fp8e4m3fnuz max() { return fp8e4m3fnuz(0x7F, fp8e4m3fnuz::from_bits()); } static constexpr fp8e4m3fnuz max() { return fp8e4m3fnuz(0x7F, fp8e4m3fnuz::from_bits()); }
...@@ -294,7 +294,7 @@ class numeric_limits<fp8e4m3fn> ...@@ -294,7 +294,7 @@ class numeric_limits<fp8e4m3fn>
public: public:
static constexpr fp8e4m3fn epsilon() { return fp8e4m3fn(0x20, fp8e4m3fn::from_bits()); } static constexpr fp8e4m3fn epsilon() { return fp8e4m3fn(0x20, fp8e4m3fn::from_bits()); }
// NOLINTNEXTLINE
static constexpr fp8e4m3fn quiet_NaN() { return fp8e4m3fn(0x7F, fp8e4m3fn::from_bits()); } static constexpr fp8e4m3fn quiet_NaN() { return fp8e4m3fn(0x7F, fp8e4m3fn::from_bits()); }
static constexpr fp8e4m3fn max() { return fp8e4m3fn(0x7E, fp8e4m3fn::from_bits()); } static constexpr fp8e4m3fn max() { return fp8e4m3fn(0x7E, fp8e4m3fn::from_bits()); }
...@@ -312,7 +312,10 @@ class numeric_limits<fp8e5m2fnuz> ...@@ -312,7 +312,10 @@ class numeric_limits<fp8e5m2fnuz>
public: public:
static constexpr fp8e5m2fnuz epsilon() { return fp8e5m2fnuz(0x34, fp8e5m2fnuz::from_bits()); } static constexpr fp8e5m2fnuz epsilon() { return fp8e5m2fnuz(0x34, fp8e5m2fnuz::from_bits()); }
static constexpr fp8e5m2fnuz quiet_NaN() { return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits()); } static constexpr fp8e5m2fnuz quiet_NaN() // NOLINT
{
return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits());
}
static constexpr fp8e5m2fnuz max() { return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits()); } static constexpr fp8e5m2fnuz max() { return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits()); }
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make // this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
...@@ -328,7 +331,7 @@ class numeric_limits<fp8e5m2> ...@@ -328,7 +331,7 @@ class numeric_limits<fp8e5m2>
public: public:
static constexpr fp8e5m2 epsilon() { return fp8e5m2(0x34, fp8e5m2::from_bits()); } static constexpr fp8e5m2 epsilon() { return fp8e5m2(0x34, fp8e5m2::from_bits()); }
// 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs // 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs
static constexpr fp8e5m2 quiet_NaN() { return fp8e5m2(0xFF, fp8e5m2::from_bits()); } static constexpr fp8e5m2 quiet_NaN() { return fp8e5m2(0xFF, fp8e5m2::from_bits()); } // NOLINT
static constexpr fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); } static constexpr fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); }
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make // this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
...@@ -345,8 +348,8 @@ class numeric_limits<fp8e5m2> ...@@ -345,8 +348,8 @@ class numeric_limits<fp8e5m2>
// ================================================================================================= // =================================================================================================
// define numeric limits for the new data type // define numeric limits for the new data type
// NOLINTBEGIN
namespace std { namespace std {
#define MIGRAPHX_FP8_STD_OVERLOADS(T) \ #define MIGRAPHX_FP8_STD_OVERLOADS(T) \
inline bool isfinite(T x) { return x.is_inf(); } \ inline bool isfinite(T x) { return x.is_inf(); } \
inline bool isnan(T x) { return x.is_nan(); } \ inline bool isnan(T x) { return x.is_nan(); } \
...@@ -372,8 +375,8 @@ MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fn) ...@@ -372,8 +375,8 @@ MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fn)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2) MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fnuz) MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2fnuz) MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2fnuz)
} // namespace std } // namespace std
// NOLINTEND
// ================================================================================================= // =================================================================================================
#if defined(__clang__) #if defined(__clang__)
#pragma clang diagnostic pop #pragma clang diagnostic pop
......
...@@ -30,111 +30,91 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -30,111 +30,91 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace fp8 { namespace fp8 {
namespace impl { namespace impl {
template <int wm, int we, typename T, bool negative_zero_nan, bool clip> // NOLINTBEGIN
constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0) template <uint32_t Wm, uint32_t We, typename T, bool NegativeZeroNan, bool Clip>
constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0)
{ {
constexpr bool is_float = std::is_same<T, float>::value; constexpr bool is_float = std::is_same<T, float>::value;
// half is not supported for now // half is not supported for now
constexpr bool is_half = false; constexpr bool is_half = false;
static_assert(wm + we == 7, "wm+we==7"); static_assert(Wm + We == 7, "Wm+We==7");
static_assert(is_float or is_half, "Only float can be cast to f8"); static_assert(is_float or is_half, "Only float can be cast to f8");
const int mfmt = (sizeof(T) == 4) ? 23 : 10; const uint32_t mfmt = (sizeof(T) == 4) ? 23 : 10;
typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type x; typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type x;
if constexpr(sizeof(T) == 4) if constexpr(sizeof(T) == 4)
x = migraphx::bit_cast<uint32_t>(_x); x = migraphx::bit_cast<uint32_t>(f_x);
else else
x = migraphx::bit_cast<uint16_t>(_x); x = migraphx::bit_cast<uint16_t>(f_x);
uint32_t head, mantissa;
int exponent, bias;
uint32_t sign;
uint32_t head = 0;
uint32_t mantissa = 0;
int exponent = 0;
uint32_t bias = 0;
uint32_t sign = 0;
if constexpr(sizeof(T) == 4) if constexpr(sizeof(T) == 4)
{ {
head = x & 0xFF800000; head = x & 0xFF800000; // NOLINT
mantissa = x & 0x7FFFFF; mantissa = x & 0x7FFFFF; // NOLINT
exponent = (head >> 23) & 0xFF; exponent = (head >> 23) & 0xFF; // NOLINT
sign = head >> 31; sign = head >> 31; // NOLINT
bias = 127; bias = 127;
} }
else else
{ {
head = x & 0xFC00; head = x & 0xFC00; // NOLINT
mantissa = x & 0x3FF; mantissa = x & 0x3FF; // NOLINT
exponent = (head >> 10) & 0x1F; exponent = (head >> 10) & 0x1F; // NOLINT
sign = head >> 15; sign = head >> 15; // NOLINT
bias = 15; bias = 15;
} }
uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); uint32_t signed_inf = (sign << 7) + (((1 << We) - 1) << Wm); // NOLINT
uint32_t signed_all_ones = (sign << 7) + ((((1 << we) - 1) << wm) + ((1 << wm) - 1)); uint32_t signed_all_ones = (sign << 7) + ((((1 << We) - 1) << Wm) + ((1 << Wm) - 1)); // NOLINT
// Calculate maximum signed value FLT_MAX, FLT_MIN // Calculate maximum signed value FLT_MAX, FLT_MIN
uint32_t signed_max = signed_all_ones; uint32_t signed_max = signed_all_ones;
if(not negative_zero_nan) if(not NegativeZeroNan)
{ signed_max = (Wm == 2) ? (signed_max - 4) : (signed_max - 1);
signed_max = (wm == 2) ? (signed_max - 4) : (signed_max - 1);
}
// Deal with inf and NaNs // Deal with inf and NaNs
if(negative_zero_nan) // For the FNUZ cases, it is simple just return NaNs if(NegativeZeroNan) // For the FNUZ cases, it is simple just return NaNs
{ {
if(sizeof(T) == 4) if((sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000)) or // NOLINT
{ (sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00))) // NOLINT
if((x & 0x7F800000) == 0x7F800000) return 0x80;
return 0x80;
}
else
{
if((x & 0x7C00) == 0x7C00)
return 0x80;
}
} }
else else
{ {
// calculate most common NaN mantissa for FP8, which is all Ones in binary // calculate most common NaN mantissa for FP8, which is all Ones in binary
uint32_t nan_mantissa = 1; uint32_t nan_mantissa = 1;
for(auto i = 1; i < wm; ++i) for(auto i = 1; i < Wm; ++i)
{ {
nan_mantissa |= (nan_mantissa << 1); nan_mantissa |= (nan_mantissa << 1); // NOLINT
} }
if((sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000)) or if((sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000)) or // NOLINT
(sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00))) (sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00))) // NOLINT
{ {
// infinity // infinity
if(mantissa == 0) if(mantissa == 0)
{ {
if(sign == 0) if(sign == 0)
{ return (Wm == 2) ? 0x7B : 0x7E;
return (wm == 2) ? 0x7B : 0x7E;
}
else else
{ return (Wm == 2) ? 0xFB : 0xFE;
return (wm == 2) ? 0xFB : 0xFE;
}
} }
else else // NaNs
{ // NaNs
return signed_inf + nan_mantissa; return signed_inf + nan_mantissa;
}
} }
} }
// handle positive zero // handle positive zero
if(x == 0) if(x == 0)
return 0; return 0;
// handle negative zero // handle negative zero
if((sizeof(T) == 4 and x == 0x80000000) or (sizeof(T) == 2 and x == 0x8000)) else if((sizeof(T) == 4 and x == 0x80000000) or (sizeof(T) == 2 and x == 0x8000))
{ {
if(negative_zero_nan) // For FNUZ types neg zero is just positive zero return NegativeZeroNan ? 0 : 0x80; // For FNUZ types neg zero is just positive zero
{
return 0;
}
else
{
return 0x80;
}
} }
/* First need to check if it is normal or denorm as there is a difference of implicit 1 /* First need to check if it is normal or denorm as there is a difference of implicit 1
...@@ -144,13 +124,15 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0) ...@@ -144,13 +124,15 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0)
exponent and mantissa again*/ exponent and mantissa again*/
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits
const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0); const int f8_bias = (1 << (We - 1u)) - 1 + (NegativeZeroNan ? 1 : 0); // NOLINT
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
/* act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) /* act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
f8_exponent is the converted f8 exponent with bias encoding f8_exponent is the converted f8 exponent with bias encoding
exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
the difference needs to be adjusted and mantissa shifted*/ the difference needs to be adjusted and mantissa shifted*/
int act_exponent, f8_exponent, exponent_diff; int act_exponent = 0;
int f8_exponent = 0;
int exponent_diff = 0;
if(exponent == 0) if(exponent == 0)
{ // fp32/fp16 is in denormal. { // fp32/fp16 is in denormal.
...@@ -182,11 +164,11 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0) ...@@ -182,11 +164,11 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0)
0; // exponent_diff=0 does not mean there is no difference for this case, 0; // exponent_diff=0 does not mean there is no difference for this case,
// act_exponent could be larger. Just that it does not need shift mantissa // act_exponent could be larger. Just that it does not need shift mantissa
} }
mantissa += (1 << mfmt); // Add the implicit 1 into mantissa mantissa += (1u << mfmt); // Add the implicit 1 into mantissa
} }
// NOLINTNEXTLINE
bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) == bool midpoint = (mantissa & ((1 << (mfmt - Wm + exponent_diff)) - 1)) ==
(1 << (mfmt - wm + exponent_diff - 1)); (1 << (mfmt - Wm + exponent_diff - 1)); // NOLINT
/* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
shift right as shift right could rip off some residual part and make something not midpoint look shift right as shift right could rip off some residual part and make something not midpoint look
like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
...@@ -194,64 +176,58 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0) ...@@ -194,64 +176,58 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0)
*/ */
if(exponent_diff > 0) if(exponent_diff > 0)
mantissa >>= exponent_diff; mantissa >>= exponent_diff; // NOLINT
else if(exponent_diff == -1) else if(exponent_diff == -1)
mantissa <<= -exponent_diff; mantissa <<= -exponent_diff; // NOLINT
bool implicit_one = mantissa & (1 << mfmt); bool implicit_one = mantissa & (1 << mfmt); // NOLINT
// if there is no implicit 1, it means the f8 is denormal and needs to adjust to denorm exponent // if there is no implicit 1, it means the f8 is denormal and needs to adjust to denorm exponent
f8_exponent = f8_exponent =
(act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1); (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
// Now we have the exponent and mantissa adjusted // Now we have the exponent and mantissa adjusted
uint32_t drop_mask = (1 << (mfmt - wm)) - 1; uint32_t drop_mask = (1u << (mfmt - Wm)) - 1; // NOLINT
bool odd = bool odd =
mantissa & (1 << (mfmt - wm)); // if the least significant bit that is not truncated is 1 mantissa & (1u << (mfmt - Wm)); // if the least significant bit that is not truncated is 1
mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask; mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & // NOLINT
drop_mask; // NOLINT
// Now we deal with overflow // Now we deal with overflow
if(f8_exponent == 0) if(f8_exponent == 0 and ((1 << mfmt) & mantissa)) // NOLINT
{ {
if((1 << mfmt) & mantissa) f8_exponent = 1; // denormal overflow to become normal, promote exponent
{
f8_exponent = 1; // denormal overflow to become normal, promote exponent
}
} }
else else if((1 << (mfmt + 1)) & mantissa) // NOLINT
{ {
if((1 << (mfmt + 1)) & mantissa) mantissa >>= 1; // NOLINT
{ f8_exponent++;
mantissa >>= 1;
f8_exponent++;
}
} }
mantissa >>= (mfmt - wm); mantissa >>= (mfmt - Wm); // NOLINT
// above range: quantize to maximum possible float of the same sign // above range: quantize to maximum possible float of the same sign
const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); const int max_exp = (1 << We) - (NegativeZeroNan ? 1 : 2); // NOLINT
if(f8_exponent > max_exp) if(f8_exponent > max_exp)
{ {
if(clip) if(Clip)
{
return signed_max; return signed_max;
}
else else
{ {
// https://onnx.ai/onnx/technical/float8.html#cast // https://onnx.ai/onnx/technical/float8.html#cast
if(negative_zero_nan) if(NegativeZeroNan)
return 0x80; return 0x80;
else else
return (wm == 2) ? signed_inf : signed_all_ones; return (Wm == 2) ? signed_inf : signed_all_ones;
} }
} }
if(f8_exponent == 0 and mantissa == 0) if(f8_exponent == 0 and mantissa == 0)
return negative_zero_nan ? 0 : (sign << 7); return NegativeZeroNan ? 0 : (sign << 7); // NOLINT
mantissa &= (1 << wm) - 1; mantissa &= (1 << Wm) - 1; // NOLINT
return (sign << 7) | (f8_exponent << wm) | mantissa; return (sign << 7) | (f8_exponent << Wm) | mantissa; // NOLINT
} }
// NOLINTEND
template <int wm, int we, typename T, bool negative_zero_nan> template <uint32_t Wm, uint32_t We, typename T, bool NegativeZeroNan>
constexpr T cast_from_f8(uint8_t x) constexpr T cast_from_f8(uint8_t x)
{ {
// half is not supported for now // half is not supported for now
...@@ -261,69 +237,70 @@ constexpr T cast_from_f8(uint8_t x) ...@@ -261,69 +237,70 @@ constexpr T cast_from_f8(uint8_t x)
constexpr int weo = is_half ? 5 : 8; constexpr int weo = is_half ? 5 : 8;
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7); constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
// NOLINTNEXTLINE
T fInf, fNegInf, fNaN, fNeg0; T f_inf, f_neg_inf, f_nan, f_neg0;
if constexpr(is_float) if constexpr(is_float)
{ {
const uint32_t ifInf = 0x7F800000; const uint32_t if_inf = 0x7F800000;
const uint32_t ifNegInf = 0xFF800000; const uint32_t if_neg_inf = 0xFF800000;
const uint32_t ifNaN = 0x7F800001; const uint32_t if_nan = 0x7F800001;
const uint32_t ifNeg0 = 0x80000000; const uint32_t if_neg0 = 0x80000000;
fInf = migraphx::bit_cast<float>(ifInf); f_inf = migraphx::bit_cast<float>(if_inf);
fNegInf = migraphx::bit_cast<float>(ifNegInf); f_neg_inf = migraphx::bit_cast<float>(if_neg_inf);
fNaN = migraphx::bit_cast<float>(ifNaN); f_nan = migraphx::bit_cast<float>(if_nan);
fNeg0 = migraphx::bit_cast<float>(ifNeg0); f_neg0 = migraphx::bit_cast<float>(if_neg0);
} }
if(x == 0) if(x == 0)
return 0; return 0;
uint32_t sign = x >> 7; uint32_t sign = x >> 7; // NOLINT
uint32_t mantissa = x & ((1 << wm) - 1); uint32_t mantissa = x & ((1 << Wm) - 1); // NOLINT
int exponent = (x & 0x7F) >> wm; int exponent = (x & 0x7F) >> Wm; // NOLINT
if(negative_zero_nan) if(NegativeZeroNan)
{ {
if(x == 0x80) if(x == 0x80)
return fNaN; return f_nan;
} }
else else
{ {
if(x == 0x80) if(x == 0x80)
return fNeg0; return f_neg0;
if(exponent == ((1 << we) - 1) and wm == 2) if(exponent == ((1 << We) - 1) and Wm == 2) // NOLINT
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; return (mantissa == 0) ? (sign ? f_neg_inf : f_inf) : f_nan;
else if(wm == 3 and (x == 0x7F or x == 0xFF)) else if(Wm == 3 and (x == 0x7F or x == 0xFF))
return fNaN; return f_nan;
} }
typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval; typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); const int exp_low_cutoff =
(1 << (weo - 1)) - (1 << (We - 1)) + 1 - (NegativeZeroNan ? 1 : 0); // NOLINT
// subnormal input // subnormal input
if(exponent == 0) if(exponent == 0)
{ {
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + __builtin_clz(mantissa) - (32 - wm); int sh = 1 + __builtin_clz(mantissa) - (32 - Wm);
mantissa <<= sh; mantissa <<= sh; // NOLINT
exponent += 1 - sh; exponent += 1 - sh;
mantissa &= ((1 << wm) - 1); mantissa &= ((1 << Wm) - 1); // NOLINT
} }
exponent += exp_low_cutoff - 1; exponent += exp_low_cutoff - 1;
mantissa <<= wmo - wm; mantissa <<= wmo - Wm; // NOLINT
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true) // subnormal output (occurs when T=half, We=5, negative_zero_nan=true)
if(exponent <= 0) if(exponent <= 0)
{ {
mantissa |= 1 << wmo; mantissa |= 1 << wmo; // NOLINT
mantissa >>= 1 - exponent; mantissa >>= 1 - exponent; // NOLINT
exponent = 0; exponent = 0;
} }
if(sizeof(T) == 2) if(sizeof(T) == 2)
retval = (sign << 15) | (exponent << 10) | mantissa; retval = (sign << 15) | (exponent << 10) | mantissa; // NOLINT
else else
retval = (sign << 31) | (exponent << 23) | mantissa; retval = (sign << 31) | (exponent << 23) | mantissa; // NOLINT
return migraphx::bit_cast<T>(retval); return migraphx::bit_cast<T>(retval);
} }
......
...@@ -46,6 +46,7 @@ rocblas_datatype get_type(shape::type_t type) ...@@ -46,6 +46,7 @@ rocblas_datatype get_type(shape::type_t type)
case shape::uint8_type: return rocblas_datatype_u8_r; case shape::uint8_type: return rocblas_datatype_u8_r;
case shape::int32_type: return rocblas_datatype_i32_r; case shape::int32_type: return rocblas_datatype_i32_r;
case shape::uint32_type: return rocblas_datatype_u32_r; case shape::uint32_type: return rocblas_datatype_u32_r;
case shape::fp8e4m3fnuz_type:
case shape::tuple_type: case shape::tuple_type:
case shape::bool_type: case shape::bool_type:
case shape::uint16_type: case shape::uint16_type:
......
...@@ -134,6 +134,15 @@ TEST_CASE(test_negative_zero) ...@@ -134,6 +134,15 @@ TEST_CASE(test_negative_zero)
EXPECT(migraphx::float_equal(nzero, float(fp8_nzero))); EXPECT(migraphx::float_equal(nzero, float(fp8_nzero)));
} }
// Converts float -0.0 and +0.0 to fp8e5m2 and expects the two fp8 values to
// compare equal through operator==.
TEST_CASE(test_pos_zero_eq_neg_zero)
{
float nzero = -0.0;
float pzero = 0.0;
// Build one fp8e5m2 value from each signed zero.
migraphx::fp8::fp8e5m2 fp8_nzero(nzero);
migraphx::fp8::fp8e5m2 fp8_pzero(pzero);
// The test passes only if the fp8 equality operator reports the two as equal.
EXPECT(fp8_nzero == fp8_pzero);
}
TEST_CASE(test_nan_1) TEST_CASE(test_nan_1)
{ {
float fnan = std::numeric_limits<float>::quiet_NaN(); float fnan = std::numeric_limits<float>::quiet_NaN();
......
...@@ -331,6 +331,15 @@ TEST_CASE(test_negative_zero) ...@@ -331,6 +331,15 @@ TEST_CASE(test_negative_zero)
EXPECT(migraphx::float_equal(nzero, float(fp8_nzero))); EXPECT(migraphx::float_equal(nzero, float(fp8_nzero)));
} }
// Converts float -0.0 and +0.0 to fp8e5m2 and expects the two fp8 values to
// compare equal through operator==.
TEST_CASE(test_pos_zero_eq_neg_zero)
{
float nzero = -0.0;
float pzero = 0.0;
// Build one fp8e5m2 value from each signed zero.
migraphx::fp8::fp8e5m2 fp8_nzero(nzero);
migraphx::fp8::fp8e5m2 fp8_pzero(pzero);
// The test passes only if the fp8 equality operator reports the two as equal.
EXPECT(fp8_nzero == fp8_pzero);
}
TEST_CASE(test_nan_1) TEST_CASE(test_nan_1)
{ {
float fnan = std::numeric_limits<float>::quiet_NaN(); float fnan = std::numeric_limits<float>::quiet_NaN();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment