"benchmarks/git@developer.sourcefind.cn:SIYIXNI/vllm.git" did not exist on "e0c6f556e85053059c74ab6b5cee396baf3b4316"
Commit 4e9d51f0 authored by Umang Yadav's avatar Umang Yadav
Browse files

Working FNUZ and FN

parent d9f11e31
...@@ -86,6 +86,11 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) ...@@ -86,6 +86,11 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
} }
uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);
uint32_t signed_max = (sign << 7) + ((((1 << we) - 1) << wm) + ((1 << wm) - 1));
if(not negative_zero_nan)
{
signed_max = (wm == 2) ? (signed_max - 4) : (signed_max - 1);
}
// Deal with inf and NaNs // Deal with inf and NaNs
if(negative_zero_nan) if(negative_zero_nan)
...@@ -103,15 +108,50 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) ...@@ -103,15 +108,50 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
} }
else else
{ {
if(sizeof(T) == 4) // calculate most common NaN mantissa for FP8, which is all Ones in binary
uint32_t nan_mantissa = 1;
for(auto i = 1; i < wm; ++i)
{ {
if((x & 0x7F800000) == 0x7F800000) nan_mantissa |= (nan_mantissa << 1);
return signed_inf + (mantissa != 0 ? 1 : 0); // cppcheck-suppress InvertedLogic }
// TODO: abstract duplicate branches
if(sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000))
{
// infinity
if(mantissa == 0)
{
if(sign == 0)
{
return (wm == 2) ? 0x7B : 0x7E;
} }
else else
{ {
if((x & 0x7C00) == 0x7C00) return (wm == 2) ? 0xFB : 0xFE;
return signed_inf + (mantissa != 0 ? 1 : 0); // cppcheck-suppress InvertedLogic }
}
else
{ // NaNs
return signed_inf + nan_mantissa;
}
}
else if(sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00))
{
// infinity
if(mantissa == 0)
{
if(sign == 0)
{
return (wm == 2) ? 0x7B : 0x7E;
}
else
{
return (wm == 2) ? 0xFB : 0xFE;
}
}
else
{ // NaNs
return signed_inf + nan_mantissa;
}
} }
} }
// handle positive zero // handle positive zero
...@@ -222,16 +262,24 @@ this case, the fp16 mantissa should be shift left by 1 */ ...@@ -222,16 +262,24 @@ this case, the fp16 mantissa should be shift left by 1 */
// above range: quantize to maximum possible float of the same sign // above range: quantize to maximum possible float of the same sign
const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
// TODO: this is ugly, need better way to handle out of range values
if(f8_exponent > max_exp) if(f8_exponent > max_exp)
{ {
if(clip) if(clip)
{ {
mantissa = (1 << wm) - 1; return signed_max;
f8_exponent = max_exp;
} }
else else
{ {
return signed_inf; if(negative_zero_nan)
{
return 0x80;
}
else
{
uint32_t tmp_signed_max = (sign << 7) + ((((1 << we) - 1) << wm) + ((1 << wm) - 1));
return (wm == 2) ? signed_inf : tmp_signed_max;
}
} }
} }
...@@ -273,8 +321,10 @@ constexpr T cast_from_f8(uint8_t x) ...@@ -273,8 +321,10 @@ constexpr T cast_from_f8(uint8_t x)
{ {
if(x == 0x80) if(x == 0x80)
return fNeg0; return fNeg0;
if(exponent == ((1 << we) - 1)) if(exponent == ((1 << we) - 1) and wm == 2)
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
else if(wm == 3 and (x == 0x7F or x == 0xFF))
return fNaN;
} }
typename detail::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval; typename detail::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
......
...@@ -79,7 +79,7 @@ struct float8 ...@@ -79,7 +79,7 @@ struct float8
// default constructor // default constructor
constexpr float8() = default; constexpr float8() = default;
// default copy constructor // default copy constructor
constexpr float8(const float8<T>& y) = default; constexpr float8(const float8& y) = default;
struct from_bits_t struct from_bits_t
{ {
}; };
...@@ -149,15 +149,12 @@ struct float8 ...@@ -149,15 +149,12 @@ struct float8
{ {
if(T == migraphx_fp8::f8_type::bf8) if(T == migraphx_fp8::f8_type::bf8)
{ {
return (data == 0x7d) or (data == 0x7e) or (data == 0x7f) or (data == 0xfd) or return (data == 0x7D) or (data == 0x7E) or (data == 0x7F) or (data == 0xFD) or
(data == 0xfe) or (data == 0xff); (data == 0xFE) or (data == 0xFF);
} }
else else
{ {
return (data == 0x79) or (data == 0x7a) or (data == 0x7b) or (data == 0x7c) or return (data == 0x7F) or (data == 0xFF);
(data == 0x7d) or (data == 0x7e) or (data == 0x7f) or (data == 0xf9) or
(data == 0xfa) or (data == 0xfb) or (data == 0xfc) or (data == 0xfd) or
(data == 0xfe) or (data == 0xff);
} }
} }
} }
...@@ -172,11 +169,12 @@ struct float8 ...@@ -172,11 +169,12 @@ struct float8
{ {
if(T == migraphx_fp8::f8_type::bf8) if(T == migraphx_fp8::f8_type::bf8)
{ {
return (data == 0x7c) or (data == 0xfc); return (data == 0x7C) or (data == 0xFC);
} }
else else
{ {
return (data == 0x78) or (data == 0xf8); // no infinities in e4m3fn, represent them as NaNs
return (data == 0x7F) or (data == 0xFF);
} }
} }
} }
...@@ -211,12 +209,12 @@ struct float8 ...@@ -211,12 +209,12 @@ struct float8
inline constexpr bool operator==(const float8& rhs) const inline constexpr bool operator==(const float8& rhs) const
{ {
if((rhs.is_zero() and this->is_zero()) or if(rhs.is_zero() and this->is_zero())
(fabs(rhs - *this) < migraphx_fp8::numeric_limits<float8<T, FNUZ>>::epsilon()))
return true; return true;
else if(rhs.is_nan() or rhs.is_inf() or this->is_nan() or this->is_inf()) else if(rhs.is_nan() or rhs.is_inf() or this->is_nan() or this->is_inf())
return false; return false;
else if(this->data == rhs.data)
return true;
return false; return false;
} }
...@@ -272,8 +270,6 @@ inline migraphx_fp8::float8<T> fabs(migraphx_fp8::float8<T> v) ...@@ -272,8 +270,6 @@ inline migraphx_fp8::float8<T> fabs(migraphx_fp8::float8<T> v)
} }
// https://onnx.ai/onnx/technical/float8.html // https://onnx.ai/onnx/technical/float8.html
// these types are not exactly same as GraphCore's FNUZ types. GraphCore's FNUZ types assumes
// exponent bias of 8 and 16 for the FNUZ types, ONNX spec
using fp8e4m3fn = float8<migraphx_fp8::f8_type::fp8, false>; using fp8e4m3fn = float8<migraphx_fp8::f8_type::fp8, false>;
using fp8e5m2 = float8<migraphx_fp8::f8_type::bf8, false>; using fp8e5m2 = float8<migraphx_fp8::f8_type::bf8, false>;
using fp8e4m3fnuz = float8<migraphx_fp8::f8_type::fp8, true>; using fp8e4m3fnuz = float8<migraphx_fp8::f8_type::fp8, true>;
...@@ -282,6 +278,8 @@ using fp8e5m2fnuz = float8<migraphx_fp8::f8_type::bf8, true>; ...@@ -282,6 +278,8 @@ using fp8e5m2fnuz = float8<migraphx_fp8::f8_type::bf8, true>;
template <> template <>
class numeric_limits<fp8e4m3fnuz> class numeric_limits<fp8e4m3fnuz>
{ {
static constexpr bool has_infinity = false;
public: public:
static constexpr fp8e4m3fnuz epsilon() { return fp8e4m3fnuz(0x28, fp8e4m3fnuz::from_bits()); } static constexpr fp8e4m3fnuz epsilon() { return fp8e4m3fnuz(0x28, fp8e4m3fnuz::from_bits()); }
...@@ -292,13 +290,30 @@ class numeric_limits<fp8e4m3fnuz> ...@@ -292,13 +290,30 @@ class numeric_limits<fp8e4m3fnuz>
static constexpr fp8e4m3fnuz min() { return fp8e4m3fnuz(0x08, fp8e4m3fnuz::from_bits()); } static constexpr fp8e4m3fnuz min() { return fp8e4m3fnuz(0x08, fp8e4m3fnuz::from_bits()); }
static constexpr fp8e4m3fnuz lowest() { return fp8e4m3fnuz(0xFF, fp8e4m3fnuz::from_bits()); } static constexpr fp8e4m3fnuz lowest() { return fp8e4m3fnuz(0xFF, fp8e4m3fnuz::from_bits()); }
};
template <>
class numeric_limits<fp8e4m3fn>
{
static constexpr bool has_infinity = false;
static constexpr fp8e4m3fnuz infinity() { return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits()); } public:
static constexpr fp8e4m3fn epsilon() { return fp8e4m3fn(0x20, fp8e4m3fn::from_bits()); }
static constexpr fp8e4m3fn quiet_NaN() { return fp8e4m3fn(0x7F, fp8e4m3fn::from_bits()); }
static constexpr fp8e4m3fn max() { return fp8e4m3fn(0x7E, fp8e4m3fn::from_bits()); }
// this is min value that is not DeNorm. DeNorm min is 0x01
static constexpr fp8e4m3fn min() { return fp8e4m3fn(0x08, fp8e4m3fn::from_bits()); }
static constexpr fp8e4m3fn lowest() { return fp8e4m3fn(0xFE, fp8e4m3fn::from_bits()); }
}; };
template <> template <>
class numeric_limits<fp8e5m2fnuz> class numeric_limits<fp8e5m2fnuz>
{ {
static constexpr bool has_infinity = false;
public: public:
static constexpr fp8e5m2fnuz epsilon() { return fp8e5m2fnuz(0x34, fp8e5m2fnuz::from_bits()); } static constexpr fp8e5m2fnuz epsilon() { return fp8e5m2fnuz(0x34, fp8e5m2fnuz::from_bits()); }
...@@ -310,62 +325,56 @@ class numeric_limits<fp8e5m2fnuz> ...@@ -310,62 +325,56 @@ class numeric_limits<fp8e5m2fnuz>
static constexpr fp8e5m2fnuz min() { return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits()); } static constexpr fp8e5m2fnuz min() { return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits()); }
static constexpr fp8e5m2fnuz lowest() { return fp8e5m2fnuz(0xFF, fp8e5m2fnuz::from_bits()); } static constexpr fp8e5m2fnuz lowest() { return fp8e5m2fnuz(0xFF, fp8e5m2fnuz::from_bits()); }
static constexpr fp8e5m2fnuz infinity() { return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits()); }
}; };
} // namespace migraphx_fp8
// =================================================================================================
// define numeric limits for the new data type
namespace std {
inline bool isfinite(migraphx_fp8::fp8e4m3fnuz x) // NOLINT
{
return x.is_inf();
}
inline bool isfinite(migraphx_fp8::fp8e5m2fnuz x) // NOLINT
{
return x.is_inf();
}
inline bool isnan(migraphx_fp8::fp8e4m3fnuz x) // NOLINT
{
return x.is_nan();
}
inline bool isnan(migraphx_fp8::fp8e5m2fnuz x) // NOLINT
{
return x.is_nan();
}
template <> template <>
class numeric_limits<migraphx_fp8::fp8e4m3fnuz> class numeric_limits<fp8e5m2>
: public migraphx_fp8::numeric_limits<migraphx_fp8::fp8e4m3fnuz>
{ {
}; public:
static constexpr fp8e5m2 epsilon() { return fp8e5m2(0x34, fp8e5m2::from_bits()); }
// 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs
static constexpr fp8e5m2 quiet_NaN() { return fp8e5m2(0xFF, fp8e5m2::from_bits()); }
template <> static constexpr fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); }
class numeric_limits<migraphx_fp8::fp8e5m2fnuz> // this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
: public migraphx_fp8::numeric_limits<migraphx_fp8::fp8e5m2fnuz> // this distinction. For the floating points we would end up using lowest most of the times.
{ static constexpr fp8e5m2 min() { return fp8e5m2(0x4, fp8e5m2::from_bits()); }
};
template <class T> static constexpr fp8e5m2 lowest() { return fp8e5m2(0xFB, fp8e5m2::from_bits()); }
struct common_type<migraphx_fp8::fp8e4m3fnuz, T> : std::common_type<float, T> // NOLINT // 7C and FC both are infinity
{ static constexpr fp8e5m2 infinity() { return fp8e5m2(0x7C, fp8e5m2::from_bits()); }
}; };
} // namespace migraphx_fp8
template <class T> // =================================================================================================
struct common_type<T, migraphx_fp8::fp8e4m3fnuz> : std::common_type<float, T> // NOLINT // define numeric limits for the new data type
{ namespace std {
};
template <> #define MIGRAPHX_FP8_STD_OVERLOADS(T) \
struct common_type<migraphx_fp8::fp8e4m3fnuz, migraphx_fp8::fp8e4m3fnuz> inline bool isfinite(T x) { return x.is_inf(); } \
{ inline bool isnan(T x) { return x.is_nan(); } \
using type = float; template <> \
}; class numeric_limits<T> : public migraphx_fp8::numeric_limits<T> \
{ \
}; \
template <class U> \
struct common_type<T, U> : std::common_type<float, U> \
{ \
}; \
template <class U> \
struct common_type<U, T> : std::common_type<float, U> \
{ \
}; \
template <> \
struct common_type<T, T> \
{ \
using type = T; \
};
MIGRAPHX_FP8_STD_OVERLOADS(migraphx_fp8::fp8e4m3fn)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx_fp8::fp8e5m2)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx_fp8::fp8e4m3fnuz)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx_fp8::fp8e5m2fnuz)
} // namespace std } // namespace std
// ================================================================================================= // =================================================================================================
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <limits>
#include <numeric>
#include <vector>
#include <migraphx/float_equal.hpp>
#include <migraphx/migraphx_float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
// Reference decoder for the fp8 E4M3FN format, used as the ground truth by the tests below.
//
// `input` is the raw 8-bit pattern. The return value is the float stored for that pattern in a
// hand-written 256-entry table: entries 0x00-0x7E hold the non-negative values (0 .. 448),
// 0x80-0xFE hold the same magnitudes negated (0x80 is -0.0), and 0x7F / 0xFF are NaN.
// The table is independent of the conversion code under test.
float fp8e4m3fn_to_fp32_value(uint8_t input)
{
    // `static` so the table is built once instead of being re-materialised on every call.
    // Named after the format it actually describes (E4M3FN, not the FNUZ variant).
    static constexpr std::array<float, 256> e4m3fn_lut = {
        0.0,         0.001953125,  0.00390625,  0.005859375,
        0.0078125,   0.009765625,  0.01171875,  0.013671875,
        0.015625,    0.017578125,  0.01953125,  0.021484375,
        0.0234375,   0.025390625,  0.02734375,  0.029296875,
        0.03125,     0.03515625,   0.0390625,   0.04296875,
        0.046875,    0.05078125,   0.0546875,   0.05859375,
        0.0625,      0.0703125,    0.078125,    0.0859375,
        0.09375,     0.1015625,    0.109375,    0.1171875,
        0.125,       0.140625,     0.15625,     0.171875,
        0.1875,      0.203125,     0.21875,     0.234375,
        0.25,        0.28125,      0.3125,      0.34375,
        0.375,       0.40625,      0.4375,      0.46875,
        0.5,         0.5625,       0.625,       0.6875,
        0.75,        0.8125,       0.875,       0.9375,
        1.0,         1.125,        1.25,        1.375,
        1.5,         1.625,        1.75,        1.875,
        2.0,         2.25,         2.5,         2.75,
        3.0,         3.25,         3.5,         3.75,
        4.0,         4.5,          5.0,         5.5,
        6.0,         6.5,          7.0,         7.5,
        8.0,         9.0,          10.0,        11.0,
        12.0,        13.0,         14.0,        15.0,
        16.0,        18.0,         20.0,        22.0,
        24.0,        26.0,         28.0,        30.0,
        32.0,        36.0,         40.0,        44.0,
        48.0,        52.0,         56.0,        60.0,
        64.0,        72.0,         80.0,        88.0,
        96.0,        104.0,        112.0,       120.0,
        128.0,       144.0,        160.0,       176.0,
        192.0,       208.0,        224.0,       240.0,
        256.0,       288.0,        320.0,       352.0,
        384.0,       416.0,        448.0,       std::numeric_limits<float>::quiet_NaN(),
        -0.0,        -0.001953125, -0.00390625, -0.005859375,
        -0.0078125,  -0.009765625, -0.01171875, -0.013671875,
        -0.015625,   -0.017578125, -0.01953125, -0.021484375,
        -0.0234375,  -0.025390625, -0.02734375, -0.029296875,
        -0.03125,    -0.03515625,  -0.0390625,  -0.04296875,
        -0.046875,   -0.05078125,  -0.0546875,  -0.05859375,
        -0.0625,     -0.0703125,   -0.078125,   -0.0859375,
        -0.09375,    -0.1015625,   -0.109375,   -0.1171875,
        -0.125,      -0.140625,    -0.15625,    -0.171875,
        -0.1875,     -0.203125,    -0.21875,    -0.234375,
        -0.25,       -0.28125,     -0.3125,     -0.34375,
        -0.375,      -0.40625,     -0.4375,     -0.46875,
        -0.5,        -0.5625,      -0.625,      -0.6875,
        -0.75,       -0.8125,      -0.875,      -0.9375,
        -1.0,        -1.125,       -1.25,       -1.375,
        -1.5,        -1.625,       -1.75,       -1.875,
        -2.0,        -2.25,        -2.5,        -2.75,
        -3.0,        -3.25,        -3.5,        -3.75,
        -4.0,        -4.5,         -5.0,        -5.5,
        -6.0,        -6.5,         -7.0,        -7.5,
        -8.0,        -9.0,         -10.0,       -11.0,
        -12.0,       -13.0,        -14.0,       -15.0,
        -16.0,       -18.0,        -20.0,       -22.0,
        -24.0,       -26.0,        -28.0,       -30.0,
        -32.0,       -36.0,        -40.0,       -44.0,
        -48.0,       -52.0,        -56.0,       -60.0,
        -64.0,       -72.0,        -80.0,       -88.0,
        -96.0,       -104.0,       -112.0,      -120.0,
        -128.0,      -144.0,       -160.0,      -176.0,
        -192.0,      -208.0,       -224.0,      -240.0,
        -256.0,      -288.0,       -320.0,      -352.0,
        -384.0,      -416.0,       -448.0,      std::numeric_limits<float>::quiet_NaN(),
    };
    // uint8_t can only hold 0..255, so the index is always within the 256-entry table.
    return e4m3fn_lut[input];
}
// Exhaustive decode check: every one of the 256 bit patterns, converted to float through
// fp8e4m3fn, must agree with the reference table in fp8e4m3fn_to_fp32_value.
TEST_CASE(test_fp8_cast_to_float)
{
    using fp8 = migraphx_fp8::fp8e4m3fn;

    std::vector<uint8_t> all_patterns(256);
    std::iota(all_patterns.begin(), all_patterns.end(), 0);

    const auto decodes_as_expected = [](uint8_t pattern) {
        fp8 decoded(pattern, fp8::from_bits());
        const float actual   = float(decoded);
        const float expected = fp8e4m3fn_to_fp32_value(pattern);
        // A NaN on both sides counts as a match; it is checked before the value comparison.
        if(std::isnan(actual) and std::isnan(expected))
        {
            return true;
        }
        return migraphx::float_equal(actual, expected);
    };

    EXPECT(bool{std::all_of(all_patterns.begin(), all_patterns.end(), decodes_as_expected)});
}
// +0.0f must convert to an fp8e4m3fn zero and convert back to a float equal to zero.
TEST_CASE(test_positive_zero)
{
    const float pos_zero = 0.0;
    migraphx_fp8::fp8e4m3fn converted(pos_zero);
    EXPECT(converted.is_zero());
    EXPECT(migraphx::float_equal(pos_zero, float(converted)));
}
// -0.0f must convert to an fp8e4m3fn value that reports is_zero() and round-trips to float.
TEST_CASE(test_negative_zero)
{
    const float neg_zero = -0.0;
    migraphx_fp8::fp8e4m3fn converted(neg_zero);
    EXPECT(converted.is_zero());
    // negative zero is preserved for fp8e4m3fn
    EXPECT(migraphx::float_equal(neg_zero, float(converted)));
}
// A float quiet NaN must stay a NaN after conversion, as seen by both is_nan() and std::isnan.
TEST_CASE(test_nan_1)
{
    const float float_nan = std::numeric_limits<float>::quiet_NaN();
    migraphx_fp8::fp8e4m3fn converted(float_nan);
    EXPECT(converted.is_nan());
    EXPECT(std::isnan(converted));
}
// The bit pattern of numeric_limits<fp8e4m3fn>::quiet_NaN() must be recognised as NaN when
// rebuilt from raw bits, and must decode to a float NaN.
TEST_CASE(test_nan_2)
{
    using fp8 = migraphx_fp8::fp8e4m3fn;

    auto limits_nan = std::numeric_limits<fp8>::quiet_NaN();
    fp8 rebuilt(limits_nan.data, fp8::from_bits());
    EXPECT(rebuilt.is_nan());
    EXPECT(std::isnan(rebuilt));
    EXPECT(std::isnan(float(rebuilt)));
}
// Positive float infinity has no fp8e4m3fn counterpart; the conversion is expected to
// saturate to numeric_limits<fp8e4m3fn>::max().
TEST_CASE(test_infinity_1)
{
    const float pos_inf = std::numeric_limits<float>::infinity();
    // no inf in fp8e4m3fn, it gets clipped to max()
    migraphx_fp8::fp8e4m3fn saturated(pos_inf);
    EXPECT(saturated == std::numeric_limits<migraphx_fp8::fp8e4m3fn>::max());
}
// Negative float infinity has no fp8e4m3fn counterpart; the conversion is expected to
// saturate to numeric_limits<fp8e4m3fn>::lowest().
TEST_CASE(test_infinity_2)
{
    // neg inf
    const float neg_inf = -1.0 * std::numeric_limits<float>::infinity();
    // no inf in fp8e4m3fn, it gets clipped to lowest
    migraphx_fp8::fp8e4m3fn saturated(neg_inf);
    EXPECT(bool{saturated == std::numeric_limits<migraphx_fp8::fp8e4m3fn>::lowest()});
}
// The largest finite float is far out of range and is expected to saturate to the fp8 max().
TEST_CASE(test_numeric_max_1)
{
    const float float_max = std::numeric_limits<float>::max();
    migraphx_fp8::fp8e4m3fn saturated(float_max);
    EXPECT(saturated == std::numeric_limits<migraphx_fp8::fp8e4m3fn>::max());
}
// Twice the fp8 max() is just out of range and is expected to saturate back to max().
TEST_CASE(test_numeric_max_2)
{
    // gets clipped to max
    const float twice_max = 2 * std::numeric_limits<migraphx_fp8::fp8e4m3fn>::max();
    migraphx_fp8::fp8e4m3fn saturated(twice_max);
    EXPECT(saturated == std::numeric_limits<migraphx_fp8::fp8e4m3fn>::max());
}
// The most negative finite float is far out of range and is expected to saturate to the
// fp8 lowest().
TEST_CASE(test_numeric_lowest_1)
{
    const float float_lowest = std::numeric_limits<float>::lowest();
    migraphx_fp8::fp8e4m3fn saturated(float_lowest);
    EXPECT(saturated == std::numeric_limits<migraphx_fp8::fp8e4m3fn>::lowest());
}
// Twice the fp8 lowest() is just out of range and is expected to saturate back to lowest().
TEST_CASE(test_numeric_lowest_2)
{
    // gets clipped to lowest
    const float twice_lowest = 2.0 * std::numeric_limits<migraphx_fp8::fp8e4m3fn>::lowest();
    migraphx_fp8::fp8e4m3fn saturated(twice_lowest);
    EXPECT(saturated == std::numeric_limits<migraphx_fp8::fp8e4m3fn>::lowest());
}
// The finite range of fp8e4m3fn is symmetric: lowest() must be the negation of max().
// This test previously had an empty body and therefore verified nothing.
TEST_CASE(test_max_eq_lowest)
{
    using fp8 = migraphx_fp8::fp8e4m3fn;

    const float lowest_as_float = float(std::numeric_limits<fp8>::lowest());
    const float max_as_float    = float(std::numeric_limits<fp8>::max());
    EXPECT(migraphx::float_equal(lowest_as_float, -1.0f * max_as_float));
}
// Entry point of the test binary: forwards the command line to the test framework's runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
...@@ -176,25 +176,17 @@ TEST_CASE(test_nan_2) ...@@ -176,25 +176,17 @@ TEST_CASE(test_nan_2)
TEST_CASE(test_infinity_1) TEST_CASE(test_infinity_1)
{ {
float finf = std::numeric_limits<float>::infinity(); float finf = std::numeric_limits<float>::infinity();
// no inf in fp8e4m3fnuz // no inf in fp8e4m3fnuz it gets clipped to Nans
migraphx_fp8::fp8e4m3fnuz fp8_nan(finf); migraphx_fp8::fp8e4m3fnuz fp8_nan(finf);
EXPECT(fp8_nan.is_nan()); EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(float(fp8_nan))); EXPECT(std::isnan(float(fp8_nan)));
} }
TEST_CASE(test_infinity_2) TEST_CASE(test_infinity_2)
{
// no inf in fp8e4m3fnuz, it gets converted to NaNs
migraphx_fp8::fp8e4m3fnuz fp8_nan(std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::infinity());
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(float(fp8_nan)));
}
TEST_CASE(test_infinity_3)
{ {
// neg inf // neg inf
float finf = -1.0 * std::numeric_limits<float>::infinity(); float finf = -1.0 * std::numeric_limits<float>::infinity();
// no inf in fp8e4m3fnuz // no inf in fp8e4m3fnuz it gets clipped to NaNs
migraphx_fp8::fp8e4m3fnuz fp8_nan(finf); migraphx_fp8::fp8e4m3fnuz fp8_nan(finf);
EXPECT(fp8_nan.is_nan()); EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(float(fp8_nan))); EXPECT(std::isnan(float(fp8_nan)));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment