Unverified Commit 7f93a818 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Ref implementation of FP8 (#2438)

Handles all 4 Fp8 dtypes listed here : https://onnx.ai/onnx/technical/float8.html
Follows saturation/clipping logic from table there as well : https://onnx.ai/onnx/technical/float8.html#cast
Only adding fp8e4m3fnuz in MIGraphX IR for now.
parent 5488b443
......@@ -44,7 +44,8 @@
m(int32_type, int32_t) \
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t)
m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz)
// clang-format on
#ifdef __cplusplus
......
/* ************************************************************************
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
* ies of the Software, and to permit persons to whom the Software is furnished
* to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
* PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
* CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* ************************************************************************ */
#ifndef MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#define MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif
#include <migraphx/config.hpp>
// Both arms of the conditional evaluate to `x`; the __builtin_constant_p wrapper is
// presumably there so GCC will fold the pointer-punning expression below - TODO confirm.
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define MIGRAPHX_CONST_FOLD(x) (__builtin_constant_p(x) ? (x) : (x))
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Reinterprets the bytes of `fr` as a value of type `To`.
// `To` and `From` must have the same size; this is enforced at compile time.
template <typename To, typename From>
inline constexpr To bit_cast(From fr) noexcept
{
    static_assert(sizeof(From) == sizeof(To));
#if defined(__clang__) || !defined(__GNUC__)
    // clang and non-GNU compilers: use the compiler builtin
    return __builtin_bit_cast(To, fr);
#else
    // GCC proper: read the parameter through a pointer of the target type
    return MIGRAPHX_CONST_FOLD(*reinterpret_cast<To*>(&fr));
#endif
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif
#endif // MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
/* ************************************************************************
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
* ies of the Software, and to permit persons to whom the Software is furnished
* to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
* PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
* CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* ************************************************************************ */
#ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
#define MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
// Down-conversion clips (saturates) by default. The unclipped version is not tested and
// shouldn't be used without adding enough tests first.
// The logic is based on the clipping table here: https://onnx.ai/onnx/technical/float8.html#cast
// NOLINTNEXTLINE
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1
#include <cmath>
#include <cstdint>
#include <climits>
#include <cstring>
#include <iosfwd>
#include <limits>
#include <sstream>
#include <iostream>
#include <string>
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/float8_impl.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace fp8 {
// Rounding strategy requested when a float is narrowed to an 8-bit float.
enum class rounding_mode
{
standard, // standard rounding is doing RNE -- round to nearest even
// the float8 constructor forwards its `rng` argument to cast_to_f8 with the stochastic flag set
stochastic
};
// Selects the bit layout of a float8 value.
enum class f8_type
{
bf8 = 0, // s1e5m2: 1 sign bit, 5 exponent bits, 2 mantissa bits
fp8 = 1 // s1e4m3: 1 sign bit, 4 exponent bits, 3 mantissa bits
};
// Primary template for the fp8 numeric limits; only declared here.
// The explicit specializations for each fp8 type are defined later in this header.
template <typename T, bool FNUZ = true>
class numeric_limits;
// 8-bit floating-point value stored as a raw byte.
// T selects the layout (fp8 = 3 mantissa / 4 exponent bits, bf8 = 2 mantissa / 5 exponent bits)
// and FNUZ selects the encoding in which 0x80 is NaN and there is no negative zero.
template <migraphx::fp8::f8_type T = migraphx::fp8::f8_type::fp8, bool FNUZ = true>
struct float8
{
    // raw bit pattern of the value
    uint8_t data = 0x00;
    // default constructor
    constexpr float8() = default;
    // default copy constructor
    constexpr float8(const float8& y) = default;
    // tag type that selects the raw-bits constructor
    struct from_bits_t
    {
    };
    static constexpr from_bits_t from_bits() { return from_bits_t(); }
    // stores `bits` verbatim, no conversion is performed
    explicit constexpr float8(uint8_t bits, from_bits_t) : data(bits) {}
    // converts `v` to 8 bits; `rng` is only consulted for stochastic rounding
    explicit constexpr float8(
        float v,
        migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard,
        uint32_t rng                    = 0)
    {
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
        constexpr bool clip = true;
#else
        constexpr bool clip = false;
#endif
        const bool stochastic = (rm == migraphx::fp8::rounding_mode::stochastic);
        if constexpr(T == migraphx::fp8::f8_type::fp8)
            data = migraphx::fp8::impl::cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, clip>(
                v, stochastic, rng);
        else
            data = migraphx::fp8::impl::cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, clip>(
                v, stochastic, rng);
    }
    // widens the stored bits back to float
    inline constexpr operator float() const
    {
        if constexpr(T == migraphx::fp8::f8_type::fp8)
            return migraphx::fp8::impl::cast_from_f8<3, 4, float, FNUZ /*negative_zero_nan*/>(
                data);
        else
            return migraphx::fp8::impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(
                data);
    }
    // FNUZ has a single zero (0x00); otherwise both 0x00 and 0x80 are zero
    inline constexpr bool is_zero() const
    {
        return FNUZ ? (data == 0x00) : ((data & 0x7F) == 0x00);
    }
    inline constexpr bool is_nan() const
    {
        if constexpr(FNUZ)
            return data == 0x80;
        else if constexpr(T == migraphx::fp8::f8_type::bf8)
            return (data & 0x7F) >= 0x7D; // 0x7D, 0x7E, 0x7F with either sign
        else
            return (data & 0x7F) == 0x7F; // 0x7F or 0xFF
    }
    inline constexpr bool is_inf() const
    {
        if constexpr(FNUZ)
            return data == 0x80;
        else if constexpr(T == migraphx::fp8::f8_type::bf8)
            return (data & 0x7F) == 0x7C; // 0x7C or 0xFC
        else
            // no infinities in e4m3fn, represent them as NaNs
            return (data & 0x7F) == 0x7F;
    }
// Generates the compound-assignment operators: the arithmetic is done in float and the
// result is converted back to float8. The float8 overload forwards to the float overload.
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_UNARY_OP(unary_op, binary_op)                    \
    constexpr float8& operator unary_op(const float& rhs)             \
    {                                                                 \
        const float result = static_cast<float>(*this) binary_op rhs; \
        return *this       = static_cast<float8>(result);             \
    }                                                                 \
    constexpr float8& operator unary_op(const float8& rhs)            \
    {                                                                 \
        return *this unary_op static_cast<float>(rhs);                \
    }
    MIGRAPHX_FP8_UNARY_OP(*=, *)
    MIGRAPHX_FP8_UNARY_OP(-=, -)
    MIGRAPHX_FP8_UNARY_OP(+=, +)
    MIGRAPHX_FP8_UNARY_OP(/=, /)
    inline constexpr float8& operator=(const float8& rhs) = default;
    inline constexpr float8& operator=(float8&& rhs) noexcept = default;
    inline constexpr float8& operator=(float rhs) { return *this = static_cast<float8>(rhs); }
    // NaN or inf on either side never compares equal; zeros compare equal to each other
    inline constexpr bool operator==(const float8& rhs) const
    {
        const bool special = this->is_nan() or this->is_inf() or rhs.is_nan() or rhs.is_inf();
        if(special)
            return false;
        return (this->data == rhs.data) or (this->is_zero() and rhs.is_zero());
    }
    // ordering comparisons are carried out on the float values
    inline constexpr bool operator<(const float8& rhs) const
    {
        return static_cast<float>(*this) < static_cast<float>(rhs);
    }
    inline constexpr bool operator>(const float8& rhs) const
    {
        return static_cast<float>(*this) > static_cast<float>(rhs);
    }
};
// The four 8-bit float types described at https://onnx.ai/onnx/technical/float8.html
// The second template argument is the FNUZ flag of float8: when true, 0x80 is the NaN
// encoding and 0x00 is the only zero (see float8::is_nan / float8::is_zero).
using fp8e4m3fn = float8<migraphx::fp8::f8_type::fp8, false>;
using fp8e5m2 = float8<migraphx::fp8::f8_type::bf8, false>;
using fp8e4m3fnuz = float8<migraphx::fp8::f8_type::fp8, true>;
using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
/*
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(binary_op, T, U) \
inline constexpr U operator binary_op(const T& lhs, const T& rhs) \
{ \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
}
// TODO: these should return floats for binary ops
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP_GEN_FOR(T) \
MIGRAPHX_FP8_BINARY_OP(*, T, T) \
MIGRAPHX_FP8_BINARY_OP(-, T, T) \
MIGRAPHX_FP8_BINARY_OP(/, T, T) \
MIGRAPHX_FP8_BINARY_OP(+, T, T) \
MIGRAPHX_FP8_BINARY_OP(==, T, bool) \
MIGRAPHX_FP8_BINARY_OP(>=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(<=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(>, T, bool) \
MIGRAPHX_FP8_BINARY_OP(<, T, bool) \
MIGRAPHX_FP8_BINARY_OP(!=, T, bool)
MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e5m2)
MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e4m3fn)
MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e5m2fnuz)
MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e4m3fnuz)
*/
// Special operator overloading
// Prints the value through its float conversion.
inline std::ostream& operator<<(std::ostream& os, const fp8e4m3fnuz& rhs)
{
    const float as_float = static_cast<float>(rhs);
    os << as_float;
    return os;
}
// Absolute value of an fp8e4m3fnuz.
// In this format 0x80 is the NaN encoding rather than negative zero, so blindly clearing
// the top bit would turn NaN into +0. NaN is returned unchanged instead.
inline fp8e4m3fnuz fabs(fp8e4m3fnuz v)
{
    if(v.is_nan())
        return v;
    v.data = v.data & 0x7F; // NOLINT
    return v;
}
// Special operator overloading
// Prints the value through its float conversion.
inline std::ostream& operator<<(std::ostream& os, const fp8e4m3fn& rhs)
{
    const float as_float = static_cast<float>(rhs);
    os << as_float;
    return os;
}
// Absolute value: drops the sign bit (bit 7) and keeps the remaining seven bits.
inline fp8e4m3fn fabs(fp8e4m3fn v)
{
    fp8e4m3fn magnitude = v;
    magnitude.data &= 0x7F; // NOLINT
    return magnitude;
}
// Special operator overloading
// Prints the value through its float conversion.
inline std::ostream& operator<<(std::ostream& os, const fp8e5m2fnuz& rhs)
{
    const float as_float = static_cast<float>(rhs);
    os << as_float;
    return os;
}
// Absolute value of an fp8e5m2fnuz.
// In this format 0x80 is the NaN encoding rather than negative zero, so blindly clearing
// the top bit would turn NaN into +0. NaN is returned unchanged instead.
inline fp8e5m2fnuz fabs(fp8e5m2fnuz v)
{
    if(v.is_nan())
        return v;
    v.data = v.data & 0x7F; // NOLINT
    return v;
}
// Special operator overloading
// Prints the value through its float conversion.
inline std::ostream& operator<<(std::ostream& os, const fp8e5m2& rhs)
{
    const float as_float = static_cast<float>(rhs);
    os << as_float;
    return os;
}
// Absolute value: drops the sign bit (bit 7) and keeps the remaining seven bits.
inline fp8e5m2 fabs(fp8e5m2 v)
{
    fp8e5m2 magnitude = v;
    magnitude.data &= 0x7F; // NOLINT
    return magnitude;
}
// Limits of fp8e4m3fnuz, each given as a raw bit pattern.
template <>
class numeric_limits<fp8e4m3fnuz>
{
    // wraps a raw bit pattern without going through a float conversion
    static constexpr fp8e4m3fnuz bits(uint8_t raw)
    {
        return fp8e4m3fnuz(raw, fp8e4m3fnuz::from_bits());
    }

    public:
    static constexpr bool has_infinity = false;
    static constexpr fp8e4m3fnuz epsilon() { return bits(0x28); }
    // NOLINTNEXTLINE
    static constexpr fp8e4m3fnuz quiet_NaN() { return bits(0x80); }
    static constexpr fp8e4m3fnuz max() { return bits(0x7F); }
    // smallest normal value; the smallest denormal is 0x01
    static constexpr fp8e4m3fnuz min() { return bits(0x08); }
    static constexpr fp8e4m3fnuz lowest() { return bits(0xFF); }
};
// Limits of fp8e4m3fn, each given as a raw bit pattern.
template <>
class numeric_limits<fp8e4m3fn>
{
    // wraps a raw bit pattern without going through a float conversion
    static constexpr fp8e4m3fn bits(uint8_t raw) { return fp8e4m3fn(raw, fp8e4m3fn::from_bits()); }

    public:
    static constexpr bool has_infinity = false;
    static constexpr fp8e4m3fn epsilon() { return bits(0x20); }
    // NOLINTNEXTLINE
    static constexpr fp8e4m3fn quiet_NaN() { return bits(0x7F); }
    static constexpr fp8e4m3fn max() { return bits(0x7E); }
    // smallest normal value; the smallest denormal is 0x01
    static constexpr fp8e4m3fn min() { return bits(0x08); }
    static constexpr fp8e4m3fn lowest() { return bits(0xFE); }
};
// Limits of fp8e5m2fnuz, each given as a raw bit pattern.
// Layout: 1 sign bit, 5 exponent bits, 2 mantissa bits; cast_from_f8 decodes the exponent
// field with a bias of 16 for the FNUZ variant.
template <>
class numeric_limits<fp8e5m2fnuz>
{
    public:
    static constexpr bool has_infinity = false;
    // Distance between 1.0 (0x40) and the next value 1.25 (0x41), i.e. 2^-2.
    // With bias 16 that is exponent field 14, mantissa 0 => 0x38.
    // (0x34 is 2^-3 in this format; it is only 2^-2 in the bias-15 fp8e5m2 format.)
    static constexpr fp8e5m2fnuz epsilon() { return fp8e5m2fnuz(0x38, fp8e5m2fnuz::from_bits()); }
    static constexpr fp8e5m2fnuz quiet_NaN() // NOLINT
    {
        return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits());
    }
    static constexpr fp8e5m2fnuz max() { return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits()); }
    // this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
    // this distinction. For the floating points we would end up using lowest most of the times.
    static constexpr fp8e5m2fnuz min() { return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits()); }
    static constexpr fp8e5m2fnuz lowest() { return fp8e5m2fnuz(0xFF, fp8e5m2fnuz::from_bits()); }
};
// Limits of fp8e5m2, each given as a raw bit pattern.
template <>
class numeric_limits<fp8e5m2>
{
    // wraps a raw bit pattern without going through a float conversion
    static constexpr fp8e5m2 bits(uint8_t raw) { return fp8e5m2(raw, fp8e5m2::from_bits()); }

    public:
    static constexpr bool has_infinity = true;
    static constexpr fp8e5m2 epsilon() { return bits(0x34); }
    // 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs
    static constexpr fp8e5m2 quiet_NaN() { return bits(0xFF); } // NOLINT
    static constexpr fp8e5m2 max() { return bits(0x7B); }
    // smallest normal value; the smallest denormal is 0x01. The original author was unsure
    // whether this distinction is wanted, since lowest() is what is used most of the time.
    static constexpr fp8e5m2 min() { return bits(0x4); }
    static constexpr fp8e5m2 lowest() { return bits(0xFB); }
    // 7C and FC both are infinity
    static constexpr fp8e5m2 infinity() { return bits(0x7C); }
};
} // namespace fp8
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
// =================================================================================================
// define numeric limits for the new data type
// NOLINTBEGIN
namespace std {
// For one fp8 type T this macro generates, inside namespace std:
//  - isfinite(T) / isnan(T), implemented with T::is_inf() and T::is_nan()
//  - numeric_limits<T>, which simply inherits migraphx::fp8::numeric_limits<T>
//  - common_type<T, U> and common_type<U, T>, both resolving to common_type<float, U>
//  - common_type<T, T>, which is T itself
// NOTE(review): the macro is not #undef'd after its last use, so it stays visible to every
// file that includes this header.
#define MIGRAPHX_FP8_STD_OVERLOADS(T) \
inline bool isfinite(T x) { return not x.is_inf() and not x.is_nan(); } \
inline bool isnan(T x) { return x.is_nan(); } \
template <> \
class numeric_limits<T> : public migraphx::fp8::numeric_limits<T> \
{ \
}; \
template <class U> \
struct common_type<T, U> : std::common_type<float, U> \
{ \
}; \
template <class U> \
struct common_type<U, T> : std::common_type<float, U> \
{ \
}; \
template <> \
struct common_type<T, T> \
{ \
using type = T; \
};
// Instantiate the overloads for all four fp8 types.
MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fn)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2fnuz)
} // namespace std
// NOLINTEND
// =================================================================================================
#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
/* ************************************************************************
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
* ies of the Software, and to permit persons to whom the Software is furnished
* to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
* PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
* CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* ************************************************************************ */
#ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL_HPP
#include <algorithm>
#include <cstdint>
#include <type_traits>
#include <migraphx/config.hpp>
#include <migraphx/bit_cast.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace fp8 {
namespace impl {
// NOLINTBEGIN
// Converts a float to the bit pattern of an 8-bit float.
//
// Template parameters:
//   Wm              - number of mantissa bits of the target (Wm + We must be 7)
//   We              - number of exponent bits of the target
//   T               - source type; only float is accepted (half is not supported for now)
//   NegativeZeroNan - true for the FNUZ encodings, where 0x80 is NaN and there is no -0
//   Clip            - true to saturate out-of-range values to the largest finite value
// Parameters:
//   f_x   - value to convert
//   stoch - when true, `rng` is added to the bits that get dropped instead of doing
//           round-to-nearest-even
//   rng   - random bits used only for stochastic rounding
// Returns the 8-bit encoding.
template <uint32_t Wm, uint32_t We, typename T, bool NegativeZeroNan, bool Clip>
constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0)
{
    constexpr bool is_float = std::is_same<T, float>::value;
    // half is not supported for now
    constexpr bool is_half = false;
    static_assert(Wm + We == 7, "Wm+We==7");
    static_assert(is_float or is_half, "Only float can be cast to f8");
    const uint32_t mfmt = (sizeof(T) == 4) ? 23 : 10;
    typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type x;
    if constexpr(sizeof(T) == 4)
        x = migraphx::bit_cast<uint32_t>(f_x);
    else
        x = migraphx::bit_cast<uint16_t>(f_x);
    uint32_t head     = 0;
    uint32_t mantissa = 0;
    int exponent      = 0;
    uint32_t bias     = 0;
    uint32_t sign     = 0;
    if constexpr(sizeof(T) == 4)
    {
        head     = x & 0xFF800000;
        mantissa = x & 0x7FFFFF;
        exponent = (head >> 23) & 0xFF;
        sign     = head >> 31;
        bias     = 127;
    }
    else
    {
        head     = x & 0xFC00;
        mantissa = x & 0x3FF;
        exponent = (head >> 10) & 0x1F;
        sign     = head >> 15;
        bias     = 15;
    }
    uint32_t signed_inf      = (sign << 7) + (((1 << We) - 1) << Wm);
    uint32_t signed_all_ones = (sign << 7) + ((((1 << We) - 1) << Wm) + ((1 << Wm) - 1));
    // Calculate the maximum signed value FLT_MAX, FLT_MIN
    uint32_t signed_max = signed_all_ones;
    if(not NegativeZeroNan)
        signed_max = (Wm == 2) ? (signed_max - 4) : (signed_max - 1);
    // Deal with inf and NaNs
    if(NegativeZeroNan) // For the FNUZ cases, it is simple just return NaNs
    {
        if((sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000)) or
           (sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00)))
            return 0x80;
    }
    else
    {
        // calculate most common NaN mantissa for FP8, which is all Ones in binary
        uint32_t nan_mantissa = 1;
        for(uint32_t i = 1; i < Wm; ++i)
        {
            nan_mantissa |= (nan_mantissa << 1);
        }
        if((sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000)) or
           (sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00)))
        {
            // infinity
            if(mantissa == 0)
            {
                if(sign == 0)
                    return (Wm == 2) ? 0x7B : 0x7E;
                else
                    return (Wm == 2) ? 0xFB : 0xFE;
            }
            else // NaNs
                return signed_inf + nan_mantissa;
        }
    }
    // handle positive zero
    if(x == 0)
        return 0;
    // handle negative zero
    else if((sizeof(T) == 4 and x == 0x80000000) or (sizeof(T) == 2 and x == 0x8000))
    {
        return NegativeZeroNan ? 0 : 0x80; // For FNUZ types neg zero is just positive zero
    }
    /* First need to check if it is normal or denorm as there is a difference of implicit 1
    Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
    The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
    RNE, no need to add rng. Then probably need to check whether there is carry and adjust
    exponent and mantissa again*/
    // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits
    const int f8_bias                  = (1 << (We - 1u)) - 1 + (NegativeZeroNan ? 1 : 0);
    const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
    /* act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
    f8_exponent is the converted f8 exponent with bias encoding
    exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
    the difference needs to be adjusted and mantissa shifted*/
    int act_exponent  = 0;
    int f8_exponent   = 0;
    int exponent_diff = 0;
    if(exponent == 0 and mantissa != 0)
    { // fp32/fp16 is in denormal.
        /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
        here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal
        has exponent bias 15 while bf8 with FNUZ has exponent bias 16. It means that there are some
        numbers in fp16 denormal but they are bf8 (FNUZ) normals - smallest bf8 (FNUZ) normal is
        2^-15. fp16 numbers where exponent==0 (actual exponent -14) and highest bit of mantissa is 1
        are bf8 (FNUZ) normal. In this case, the fp16 mantissa should be shift left by 1 */
        act_exponent  = 1 - bias;
        exponent_diff = f8_denormal_act_exponent -
                        act_exponent; // actual exponent is exponent-bias+1 as it is denormal
    }
    else
    { // fp32/fp16 is normal with implicit 1
        act_exponent = exponent - bias;
        if(act_exponent <= f8_denormal_act_exponent)
        {
            /* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
            For example fp8 FNUZ mode, denormal exponent is -7, but if the fp32/fp16
            actual exponent is -7, it is actually larger due to the implicit 1,
            Therefore it needs to be adjust to -6 and mantissa shift right by 1.
            So for fp32/fp16, exponent -8 is the cut point to convert to fp8 FNUZ */
            exponent_diff = f8_denormal_act_exponent - act_exponent;
        }
        else
        { // both fp32/fp16 and f8 are in normal range
            exponent_diff =
                0; // exponent_diff=0 does not mean there is no difference for this case,
                   // act_exponent could be larger. Just that it does not need shift mantissa
        }
        mantissa += (1u << mfmt); // Add the implicit 1 into mantissa
    }
    // need to know whether the number is right in the middle of two adjacent fp8 numbers. use max
    // value of 31 to avoid undefined behaviour
    bool midpoint = (mantissa & ((1u << std::min(31u, mfmt - Wm + exponent_diff)) - 1)) ==
                    (1u << std::min(31u, mfmt - Wm + exponent_diff - 1));
    /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
    shift right as shift right could rip off some residual part and make something not midpoint look
    like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
    midpoint, but after shift right by 4 bits, it would look like midpoint.
    */
    if(exponent_diff > 0)
        mantissa >>= std::min(31u, uint32_t(exponent_diff));
    else if(exponent_diff == -1)
        mantissa <<= -exponent_diff;
    bool implicit_one = mantissa & (1 << mfmt);
    // if there is no implicit 1, it means the f8 is denormal and need to adjust to denorm exponent
    f8_exponent =
        (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
    // Now we have the exponent and mantissa adjusted
    uint32_t drop_mask = (1u << (mfmt - Wm)) - 1;
    bool odd =
        mantissa & (1u << (mfmt - Wm)); // if the least significant bit that is not truncated is 1
    /*
    This part is doing rounding by adding mantissa part that is going to get dropped.
    e.g. if the dropped part for less than 0.5 than it would round down.
    if the dropped part is more than 0.5 then it would round up by rolling carry to LSB of retained
    mantissa.
    For the mid point when bit pattern is like this for Odd: `xy1:10000000` for Odd and
    `xy0:10000000` for the Even. where `:` is delimiter for dropped v/s retained part.
    For the odd case :
    this will add xy1:10000000 + 000:10000000 which would roll over carry to LSB of retained
    part making it RNE.
    For the even case : this will add xy0:10000000 + 000:01111111 which would
    round down and keep number Even
    */
    mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;
    // Now we deal with overflow
    if(f8_exponent == 0 and ((1 << mfmt) & mantissa))
    {
        f8_exponent = 1; // denormal overflow to become normal, promote exponent
    }
    else if((1 << (mfmt + 1)) & mantissa)
    {
        mantissa >>= 1;
        f8_exponent++;
    }
    mantissa >>= (mfmt - Wm);
    // above range: quantize to maximum possible float of the same sign
    // for e5m2 case, max_exp is 14, since exp = 15 is reserved for Infs and Nans
    const int max_exp = (1 << We) - ((NegativeZeroNan or Wm == 3) ? 1 : 2);
    if(f8_exponent > max_exp)
    {
        if(Clip)
            return signed_max;
        else
        {
            // https://onnx.ai/onnx/technical/float8.html#cast
            if(NegativeZeroNan)
                return 0x80;
            else
                return (Wm == 2) ? signed_inf : signed_all_ones;
        }
    }
    if(f8_exponent == 0 and mantissa == 0)
        return NegativeZeroNan ? 0 : (sign << 7);
    mantissa &= (1 << Wm) - 1;
    // e4m3fn only: the top exponent is a regular exponent except for the all-ones mantissa,
    // which is the NaN encoding (0x7F / 0xFF). A finite input that lands on that pattern
    // (e.g. 470.0f or 480.0f) is out of range, so it must be treated like any other
    // overflow: saturate to +/-448 when clipping, NaN otherwise.
    if(not NegativeZeroNan and Wm == 3 and f8_exponent == max_exp and
       mantissa == ((1u << Wm) - 1))
        return Clip ? signed_max : signed_all_ones;
    return (sign << 7) | (f8_exponent << Wm) | mantissa;
}
// NOLINTEND
// Decodes an 8-bit float bit pattern into T.
// Wm / We are the mantissa / exponent widths of the 8-bit format, NegativeZeroNan selects the
// FNUZ encoding (0x80 is NaN). Only T = float passes the static_assert below.
// Returns the decoded value; special encodings map to NaN, +/-inf or -0 as handled below.
template <uint32_t Wm, uint32_t We, typename T, bool NegativeZeroNan>
constexpr T cast_from_f8(uint8_t x)
{
// half is not supported for now
constexpr bool is_half = false;
constexpr bool is_float = std::is_same<T, float>::value;
static_assert(is_float or is_half, "Only float are supported");
// exponent and mantissa widths of the output type
constexpr int weo = is_half ? 5 : 8;
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
// special values of the output type; only assigned in the float branch below
// NOLINTNEXTLINE
T f_inf, f_neg_inf, f_nan, f_neg0;
if constexpr(is_float)
{
const uint32_t if_inf = 0x7F800000;
const uint32_t if_neg_inf = 0xFF800000;
const uint32_t if_nan = 0x7F800001;
const uint32_t if_neg0 = 0x80000000;
f_inf = migraphx::bit_cast<float>(if_inf);
f_neg_inf = migraphx::bit_cast<float>(if_neg_inf);
f_nan = migraphx::bit_cast<float>(if_nan);
f_neg0 = migraphx::bit_cast<float>(if_neg0);
}
// positive zero
if(x == 0)
return 0;
// split the byte into sign / exponent / mantissa fields
uint32_t sign = x >> 7; // NOLINT
uint32_t mantissa = x & ((1 << Wm) - 1); // NOLINT
int exponent = (x & 0x7F) >> Wm; // NOLINT
if(NegativeZeroNan)
{
// FNUZ: 0x80 is the only special encoding and it means NaN
if(x == 0x80)
return f_nan;
}
else
{
// non-FNUZ: 0x80 is negative zero
if(x == 0x80)
return f_neg0;
// e5m2: all-ones exponent encodes inf (mantissa 0) or NaN (mantissa != 0)
if(exponent == ((1 << We) - 1) and Wm == 2) // NOLINT
return (mantissa == 0) ? (sign ? f_neg_inf : f_inf) : f_nan;
// e4m3: only the all-ones patterns 0x7F / 0xFF are NaN
else if(Wm == 3 and (x == 0x7F or x == 0xFF))
return f_nan;
}
typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
// offset used to re-bias the 8-bit exponent field into the output type's exponent field
const int exp_low_cutoff =
(1 << (weo - 1)) - (1 << (We - 1)) + 1 - (NegativeZeroNan ? 1 : 0); // NOLINT
// subnormal input
if(exponent == 0)
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
// shift the leading 1 of the mantissa up to bit Wm, lower the exponent by the same
// amount, then mask that bit away so only the Wm fraction bits remain
int sh = 1 + __builtin_clz(mantissa) - (32 - Wm);
mantissa <<= sh; // NOLINT
exponent += 1 - sh;
mantissa &= ((1 << Wm) - 1); // NOLINT
}
exponent += exp_low_cutoff - 1;
// move the fraction bits to the top of the output mantissa field
mantissa <<= wmo - Wm; // NOLINT
// subnormal output (occurs when T=half, We=5, negative_zero_nan=true)
if(exponent <= 0)
{
mantissa |= 1 << wmo; // NOLINT
mantissa >>= 1 - exponent; // NOLINT
exponent = 0;
}
// assemble the output bit pattern and reinterpret it as T
if(sizeof(T) == 2)
retval = (sign << 15) | (exponent << 10) | mantissa; // NOLINT
else
retval = (sign << 31) | (exponent << 23) | mantissa; // NOLINT
return migraphx::bit_cast<T>(retval);
}
} // namespace impl
} // namespace fp8
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL
......@@ -27,6 +27,7 @@
#include <half/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/float8.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -67,6 +68,18 @@ struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
{
};
// Mixing fp8e4m3fnuz with half resolves to float.
template <>
struct common_type<migraphx::fp8::fp8e4m3fnuz, migraphx::half>
{
using type = float;
};
// Same as above with the argument order swapped: half with fp8e4m3fnuz resolves to float.
template <>
struct common_type<migraphx::half, migraphx::fp8::fp8e4m3fnuz>
{
using type = float;
};
template <>
struct common_type<migraphx::half, migraphx::half>
{
......
......@@ -34,6 +34,7 @@
#include <migraphx/functional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/half.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/config.hpp>
......@@ -60,7 +61,8 @@ struct MIGRAPHX_EXPORT shape
m(int32_type, int32_t) \
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t)
m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz)
// clang-format on
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
......
......@@ -28,25 +28,35 @@
#include <type_traits>
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/float8.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Declares the primary template `trait<X>` in the current namespace; it inherits from the
// std trait of the same name, so by default it gives the same answer as std::trait<X>.
#define MIGRAPHX_DETAIL_DEFINE_TRAIT(trait) \
template <class X> \
struct trait : std::trait<X> \
{ \
};
// Adds an explicit specialization `trait<T>` deriving from std::true_type.
// The primary template must already exist (it is created by MIGRAPHX_DETAIL_DEFINE_TRAIT);
// this macro must not declare it again, otherwise using it for a second type, or after
// MIGRAPHX_DETAIL_DEFINE_TRAIT, redefines the primary template.
#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
    template <>                                    \
    struct trait<T> : std::true_type               \
    {                                              \
    };
// Primary templates that forward to the std traits of the same name.
MIGRAPHX_DETAIL_DEFINE_TRAIT(is_floating_point);
MIGRAPHX_DETAIL_DEFINE_TRAIT(is_arithmetic);
MIGRAPHX_DETAIL_DEFINE_TRAIT(is_signed);
// half is reported as a signed, arithmetic floating-point type.
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)
// fp8e4m3fnuz is reported as a signed, arithmetic floating-point type.
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx::fp8::fp8e4m3fnuz)
template <class T>
using accumulator_type =
std::conditional_t<is_floating_point<T>{},
......
......@@ -40,7 +40,7 @@
#include <migraphx/json.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/float8.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp>
#endif
......@@ -144,6 +144,18 @@ struct npy_format_descriptor<half>
static constexpr auto name() { return _("half"); }
};
// pybind11 format descriptor for fp8e4m3fnuz.
// NOTE(review): "z" is a placeholder format string (see the TODO below); the type name
// reported is "fp8e4m3fnuz".
template <>
struct npy_format_descriptor<migraphx::fp8::fp8e4m3fnuz>
{
static std::string format()
{
// following: https://docs.python.org/3/library/struct.html#format-characters
// TODO: need to figure out correct encoding
return "z";
}
static constexpr auto name() { return _("fp8e4m3fnuz"); }
};
} // namespace detail
} // namespace pybind11
......
......@@ -146,20 +146,20 @@ __device__ __host__ T to_hip_type(T x)
// HIP doesn't support __fp16, so a gpu_half argument is returned as float.
inline __device__ __host__ float to_hip_type(gpu_half x) { return x; }
#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \
struct trait : std::trait<X> \
{ \
}; \
\
template <> \
struct trait<T> : std::true_type \
{ \
#define MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \
struct trait : std::trait<X> \
{ \
}; \
\
template <> \
struct trait<T> : std::true_type \
{ \
};
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, __fp16)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, __fp16)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, __fp16)
MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, __fp16)
MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_signed, __fp16)
MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, __fp16)
} // namespace device
} // namespace gpu
......
......@@ -46,6 +46,7 @@ rocblas_datatype get_type(shape::type_t type)
case shape::uint8_type: return rocblas_datatype_u8_r;
case shape::int32_type: return rocblas_datatype_i32_r;
case shape::uint32_type: return rocblas_datatype_u32_r;
case shape::fp8e4m3fnuz_type:
case shape::tuple_type:
case shape::bool_type:
case shape::uint16_type:
......
......@@ -150,6 +150,7 @@ function(test_headers PREFIX)
list(REMOVE_ITEM HEADERS
${CMAKE_SOURCE_DIR}/src/targets/gpu/include/migraphx/gpu/ck.hpp)
endif()
list(REMOVE_ITEM HEADERS ${CMAKE_SOURCE_DIR}/src/include/migraphx/float8_impl.hpp)
foreach(HEADER ${HEADERS})
file(RELATIVE_PATH HEADER_REL ${CMAKE_SOURCE_DIR} ${HEADER})
string(MAKE_C_IDENTIFIER ${HEADER_REL} TEST_NAME)
......
......@@ -22,6 +22,7 @@
* THE SOFTWARE.
*/
#include <migraphx/float_equal.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include "test.hpp"
......@@ -53,7 +54,7 @@ auto test_float_equal(T x, U y)
template <class T, class U>
void test_equality()
{
auto x1 = T(0.1);
auto x1 = T(0.125);
auto x2 = U(0.0);
auto x3 = U(1.0);
EXPECT(test_float_equal(x1, x1));
......@@ -71,8 +72,12 @@ void test_equality()
TEST_CASE_REGISTER(test_equality<double, float>);
TEST_CASE_REGISTER(test_equality<double, int>);
TEST_CASE_REGISTER(test_equality<double, migraphx::half>);
TEST_CASE_REGISTER(test_equality<double, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_equality<float, int>);
TEST_CASE_REGISTER(test_equality<float, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_equality<migraphx::half, int>);
TEST_CASE_REGISTER(test_equality<migraphx::half, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_equality<migraphx::fp8::fp8e4m3fnuz, int>);
template <class T, class U>
void test_limits()
......@@ -110,8 +115,13 @@ void test_limits()
TEST_CASE_REGISTER(test_limits<double, float>);
TEST_CASE_REGISTER(test_limits<double, int>);
TEST_CASE_REGISTER(test_limits<double, migraphx::half>);
TEST_CASE_REGISTER(test_limits<double, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_limits<float, int>);
TEST_CASE_REGISTER(test_limits<float, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_limits<int, migraphx::half>);
TEST_CASE_REGISTER(test_limits<int, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_limits<migraphx::fp8::fp8e4m3fnuz, migraphx::half>);
#ifndef _WIN32
// On Windows, types int and long have the same min and max values.
TEST_CASE_REGISTER(test_limits<long, int>);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
#include <limits>
// Reference decoder used by the tests below: maps a raw fp8e4m3fn bit pattern to the float it
// encodes by indexing a hand-written 256-entry table (one row per exponent value, eight mantissa
// codes per row). Codes 0x7F and 0xFF are NaN; code 0x80 is negative zero.
//
// The table is `static constexpr` so it is built once at compile time instead of being
// re-materialised on every call, and it is named after the format it actually describes.
float fp8e4m3fn_to_fp32_value(uint8_t input)
{
    static constexpr std::array<float, 256> e4m3fn_lut = {
        // positive half: 0x00 - 0x7F
        0.0, 0.001953125, 0.00390625, 0.005859375, 0.0078125, 0.009765625, 0.01171875, 0.013671875, // 0x00
        0.015625, 0.017578125, 0.01953125, 0.021484375, 0.0234375, 0.025390625, 0.02734375, 0.029296875, // 0x08
        0.03125, 0.03515625, 0.0390625, 0.04296875, 0.046875, 0.05078125, 0.0546875, 0.05859375, // 0x10
        0.0625, 0.0703125, 0.078125, 0.0859375, 0.09375, 0.1015625, 0.109375, 0.1171875, // 0x18
        0.125, 0.140625, 0.15625, 0.171875, 0.1875, 0.203125, 0.21875, 0.234375, // 0x20
        0.25, 0.28125, 0.3125, 0.34375, 0.375, 0.40625, 0.4375, 0.46875, // 0x28
        0.5, 0.5625, 0.625, 0.6875, 0.75, 0.8125, 0.875, 0.9375, // 0x30
        1.0, 1.125, 1.25, 1.375, 1.5, 1.625, 1.75, 1.875, // 0x38
        2.0, 2.25, 2.5, 2.75, 3.0, 3.25, 3.5, 3.75, // 0x40
        4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5, // 0x48
        8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, // 0x50
        16.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0, // 0x58
        32.0, 36.0, 40.0, 44.0, 48.0, 52.0, 56.0, 60.0, // 0x60
        64.0, 72.0, 80.0, 88.0, 96.0, 104.0, 112.0, 120.0, // 0x68
        128.0, 144.0, 160.0, 176.0, 192.0, 208.0, 224.0, 240.0, // 0x70
        256.0, 288.0, 320.0, 352.0, 384.0, 416.0, 448.0, std::numeric_limits<float>::quiet_NaN(), // 0x78
        // negative half: 0x80 - 0xFF
        -0.0, -0.001953125, -0.00390625, -0.005859375, -0.0078125, -0.009765625, -0.01171875, -0.013671875, // 0x80
        -0.015625, -0.017578125, -0.01953125, -0.021484375, -0.0234375, -0.025390625, -0.02734375, -0.029296875, // 0x88
        -0.03125, -0.03515625, -0.0390625, -0.04296875, -0.046875, -0.05078125, -0.0546875, -0.05859375, // 0x90
        -0.0625, -0.0703125, -0.078125, -0.0859375, -0.09375, -0.1015625, -0.109375, -0.1171875, // 0x98
        -0.125, -0.140625, -0.15625, -0.171875, -0.1875, -0.203125, -0.21875, -0.234375, // 0xA0
        -0.25, -0.28125, -0.3125, -0.34375, -0.375, -0.40625, -0.4375, -0.46875, // 0xA8
        -0.5, -0.5625, -0.625, -0.6875, -0.75, -0.8125, -0.875, -0.9375, // 0xB0
        -1.0, -1.125, -1.25, -1.375, -1.5, -1.625, -1.75, -1.875, // 0xB8
        -2.0, -2.25, -2.5, -2.75, -3.0, -3.25, -3.5, -3.75, // 0xC0
        -4.0, -4.5, -5.0, -5.5, -6.0, -6.5, -7.0, -7.5, // 0xC8
        -8.0, -9.0, -10.0, -11.0, -12.0, -13.0, -14.0, -15.0, // 0xD0
        -16.0, -18.0, -20.0, -22.0, -24.0, -26.0, -28.0, -30.0, // 0xD8
        -32.0, -36.0, -40.0, -44.0, -48.0, -52.0, -56.0, -60.0, // 0xE0
        -64.0, -72.0, -80.0, -88.0, -96.0, -104.0, -112.0, -120.0, // 0xE8
        -128.0, -144.0, -160.0, -176.0, -192.0, -208.0, -224.0, -240.0, // 0xF0
        -256.0, -288.0, -320.0, -352.0, -384.0, -416.0, -448.0, std::numeric_limits<float>::quiet_NaN(), // 0xF8
    };
    return e4m3fn_lut[input];
}
// Decodes every 8-bit pattern as fp8e4m3fn and compares the float conversion with the reference
// table; a pattern passes outright when both sides are NaN.
TEST_CASE(test_fp8_cast_to_float)
{
    using fp8_t = migraphx::fp8::fp8e4m3fn;
    std::vector<uint8_t> all_bits(256);
    std::iota(all_bits.begin(), all_bits.end(), 0);
    const auto decodes_correctly = [](uint8_t bits) {
        const float actual   = float(fp8_t(bits, fp8_t::from_bits()));
        const float expected = fp8e4m3fn_to_fp32_value(bits);
        // NaN/NaN pairs are matched explicitly before the equality check
        if(std::isnan(actual) and std::isnan(expected))
            return true;
        return migraphx::float_equal(actual, expected);
    };
    EXPECT(bool{std::all_of(all_bits.begin(), all_bits.end(), decodes_correctly)});
}
// Converts a set of floats to fp8e4m3fn and checks each result against the fp8 value built
// directly from the expected bit pattern.
TEST_CASE(test_fp8_cast_from_float)
{
    using fp8_t = migraphx::fp8::fp8e4m3fn;
    // float input -> expected raw fp8e4m3fn bits
    std::unordered_map<float, uint8_t> expected_bits = {
        {512, 0x7e},       {-512, 0xfe},       {448, 0x7e},         {-448, 0xfe},
        {256, 0x78},       {-256, 0xf8},       {240, 0x77},         {-240, 0xf7},
        {1e-07, 0x0},      {1e+07, 0x7e},      {1, 0x38},           {-1, 0xb8},
        {0.1, 0x1d},       {0.11, 0x1e},       {0.111, 0x1e},       {0.1111, 0x1e},
        {-0.1, 0x9d},      {-0.11, 0x9e},      {-0.111, 0x9e},      {-0.1111, 0x9e},
        {0.2, 0x25},       {2, 0x40},          {20, 0x5a},          {200, 0x74},
        {-0.2, 0xa5},      {-2, 0xc0},         {-20, 0xda},         {-200, 0xf4},
        {0.5, 0x30},       {-0.5, 0xb0},       {1.17549e-38, 0x0},  {1.4013e-45, 0x0},
        {0.0078125, 0x4},  {-0.0078125, 0x84}, {0.000976562, 0x0},  {-0.000976562, 0x80},
        {0.000488281, 0x0}, {-0.000488281, 0x80}};
    const auto encodes_correctly = [](const auto& entry) {
        fp8_t from_float(entry.first);
        fp8_t from_raw(entry.second, fp8_t::from_bits());
        return migraphx::float_equal(from_float, from_raw);
    };
    EXPECT(bool{std::all_of(expected_bits.begin(), expected_bits.end(), encodes_correctly)});
}
// +0.0f converts to an fp8e4m3fn zero that converts back to a float equal to the input.
TEST_CASE(test_positive_zero)
{
    const float pos_zero = 0.0f;
    migraphx::fp8::fp8e4m3fn converted(pos_zero);
    EXPECT(converted.is_zero());
    const float round_trip = float(converted);
    EXPECT(migraphx::float_equal(pos_zero, round_trip));
}
// -0.0f converts to an fp8e4m3fn value that reports is_zero() and converts back to a float
// that float_equal treats as equal to -0.0f.
TEST_CASE(test_negative_zero)
{
    const float neg_zero = -0.0f;
    migraphx::fp8::fp8e4m3fn converted(neg_zero);
    EXPECT(converted.is_zero());
    // fp8e4m3fn keeps negative zero
    const float round_trip = float(converted);
    EXPECT(migraphx::float_equal(neg_zero, round_trip));
}
// +0 and -0 must compare equal once converted to fp8e4m3fn.
// This file tests fp8e4m3fn, so the values are built with that type (the previous version
// instantiated fp8e5m2 here, leaving fp8e4m3fn's zero comparison untested).
TEST_CASE(test_pos_zero_eq_neg_zero)
{
    float nzero = -0.0;
    float pzero = 0.0;
    migraphx::fp8::fp8e4m3fn fp8_nzero(nzero);
    migraphx::fp8::fp8e4m3fn fp8_pzero(pzero);
    EXPECT(fp8_nzero == fp8_pzero);
}
// A float quiet NaN converts to an fp8e4m3fn NaN, seen by both is_nan() and std::isnan.
TEST_CASE(test_nan_1)
{
    const auto float_nan = std::numeric_limits<float>::quiet_NaN();
    migraphx::fp8::fp8e4m3fn from_nan(float_nan);
    EXPECT(from_nan.is_nan());
    EXPECT(std::isnan(from_nan));
}
// Rebuilding a value from the raw bits of numeric_limits::quiet_NaN() yields a NaN, both as
// fp8e4m3fn and after conversion to float.
TEST_CASE(test_nan_2)
{
    using fp8_t = migraphx::fp8::fp8e4m3fn;
    const auto nan_bits = std::numeric_limits<fp8_t>::quiet_NaN().data;
    fp8_t rebuilt(nan_bits, fp8_t::from_bits());
    EXPECT(rebuilt.is_nan());
    EXPECT(std::isnan(rebuilt));
    EXPECT(std::isnan(float(rebuilt)));
}
// fp8e4m3fn has no infinity: +inf is clipped to max().
TEST_CASE(test_infinity_1)
{
    using fp8_t = migraphx::fp8::fp8e4m3fn;
    const float pos_inf = std::numeric_limits<float>::infinity();
    fp8_t clipped(pos_inf);
    EXPECT(clipped == std::numeric_limits<fp8_t>::max());
}
// fp8e4m3fn has no infinity: -inf is clipped to lowest().
TEST_CASE(test_infinity_2)
{
    using fp8_t = migraphx::fp8::fp8e4m3fn;
    const float neg_inf = -std::numeric_limits<float>::infinity();
    fp8_t clipped(neg_inf);
    EXPECT(bool{clipped == std::numeric_limits<fp8_t>::lowest()});
}
// The largest finite float converts to fp8e4m3fn max().
TEST_CASE(test_numeric_max_1)
{
    using fp8_t = migraphx::fp8::fp8e4m3fn;
    const float float_max = std::numeric_limits<float>::max();
    fp8_t converted(float_max);
    EXPECT(converted == std::numeric_limits<fp8_t>::max());
}
// Twice the fp8e4m3fn maximum is out of range and is clipped back to max().
TEST_CASE(test_numeric_max_2)
{
    using fp8_t = migraphx::fp8::fp8e4m3fn;
    const float twice_max = 2 * std::numeric_limits<fp8_t>::max();
    fp8_t converted(twice_max);
    EXPECT(converted == std::numeric_limits<fp8_t>::max());
}
// The most negative finite float converts to fp8e4m3fn lowest().
TEST_CASE(test_numeric_lowest_1)
{
    using fp8_t = migraphx::fp8::fp8e4m3fn;
    const float float_lowest = std::numeric_limits<float>::lowest();
    fp8_t converted(float_lowest);
    EXPECT(converted == std::numeric_limits<fp8_t>::lowest());
}
// Twice the fp8e4m3fn lowest value is out of range and is clipped back to lowest().
TEST_CASE(test_numeric_lowest_2)
{
    using fp8_t = migraphx::fp8::fp8e4m3fn;
    const float twice_lowest = 2.0 * std::numeric_limits<fp8_t>::lowest();
    fp8_t converted(twice_lowest);
    EXPECT(converted == std::numeric_limits<fp8_t>::lowest());
}
// lowest() is the negation of max() for fp8e4m3fn.
TEST_CASE(test_max_eq_lowest)
{
    using fp8_t = migraphx::fp8::fp8e4m3fn;
    const auto lowest_value = std::numeric_limits<fp8_t>::lowest();
    const auto negated_max  = -1 * std::numeric_limits<fp8_t>::max();
    EXPECT(migraphx::float_equal(lowest_value, negated_max));
}
// std::isfinite is true for both zeros and false for NaN.
TEST_CASE(test_isfinite)
{
    using fp8_t = migraphx::fp8::fp8e4m3fn;
    EXPECT(std::isfinite(fp8_t(0.0)));
    EXPECT(std::isfinite(fp8_t(-0.0)));
    const fp8_t nan_value(std::numeric_limits<fp8_t>::quiet_NaN());
    EXPECT(not std::isfinite(nan_value));
}
// numeric_limits must report that fp8e4m3fn has no infinity.
TEST_CASE(test_no_infinity)
{
    const bool has_inf = std::numeric_limits<migraphx::fp8::fp8e4m3fn>::has_infinity;
    EXPECT(not has_inf);
}
// Addition around zero and the ordering operators for fp8e4m3fn.
TEST_CASE(test_binary_ops)
{
    using fp8_t = migraphx::fp8::fp8e4m3fn;
    auto neg_one  = fp8_t(-1.0);
    auto pos_one  = fp8_t(1.0);
    auto pos_zero = fp8_t(0.0);
    auto neg_zero = fp8_t(-0.0);

    // sums that produce zero must equal both signed zeros
    EXPECT(migraphx::float_equal((pos_zero + neg_zero), pos_zero));
    EXPECT(migraphx::float_equal((pos_zero + neg_zero), neg_zero));
    EXPECT(migraphx::float_equal((neg_one + pos_one), pos_zero));
    EXPECT(migraphx::float_equal((neg_one + pos_one), neg_zero));

    auto ten     = fp8_t(10.0);
    auto neg_ten = fp8_t(-10.0);
    EXPECT(bool{ten > neg_ten});
    EXPECT(bool{neg_ten < ten});
    EXPECT(bool{neg_ten <= ten});
    EXPECT(bool{ten >= neg_ten});
    EXPECT(bool{ten <= ten});
    EXPECT(bool{neg_ten >= neg_ten});
    EXPECT(not migraphx::float_equal(neg_ten, ten));
}
// fabs of -1 equals +1 for fp8e4m3fn.
TEST_CASE(test_fabs)
{
    using fp8_t = migraphx::fp8::fp8e4m3fn;
    auto neg_one = fp8_t(-1.0);
    auto pos_one = fp8_t(1.0);
    auto abs_val = migraphx::fp8::fabs(neg_one);
    EXPECT(migraphx::float_equal(pos_one, abs_val));
}
// operator<< prints -1 as "-1" and NaN as "nan".
TEST_CASE(test_stream_op)
{
    using fp8_t = migraphx::fp8::fp8e4m3fn;

    auto neg_one = fp8_t(-1.0);
    std::stringstream value_out;
    value_out << neg_one;
    EXPECT(std::string("-1") == value_out.str());

    // a fresh stream replaces resetting the first one
    auto nan_value = std::numeric_limits<fp8_t>::quiet_NaN();
    std::stringstream nan_out;
    nan_out << nan_value;
    EXPECT(std::string("nan") == nan_out.str());
}
// Entry point of the fp8e4m3fn test binary: passes the command line to the test runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
    return 0;
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
#include <limits>
// Reference decoder for the tests below: looks up the float encoded by a raw fp8e4m3fnuz bit
// pattern in a hand-written 256-entry table. Each row holds the eight mantissa codes of one
// exponent value. Code 0x80 is the format's only NaN; there is no negative zero entry.
float fp8e4m3fnuz_to_fp32_value(uint8_t input)
{
    constexpr std::array<float, 256> lut = {
        // positive half: 0x00 - 0x7F
        0.0f, 0.0009765625f, 0.001953125f, 0.0029296875f, 0.00390625f, 0.0048828125f, 0.005859375f, 0.0068359375f, // 0x00
        0.0078125f, 0.0087890625f, 0.009765625f, 0.0107421875f, 0.01171875f, 0.0126953125f, 0.013671875f, 0.0146484375f, // 0x08
        0.015625f, 0.017578125f, 0.01953125f, 0.021484375f, 0.0234375f, 0.025390625f, 0.02734375f, 0.029296875f, // 0x10
        0.03125f, 0.03515625f, 0.0390625f, 0.04296875f, 0.046875f, 0.05078125f, 0.0546875f, 0.05859375f, // 0x18
        0.0625f, 0.0703125f, 0.078125f, 0.0859375f, 0.09375f, 0.1015625f, 0.109375f, 0.1171875f, // 0x20
        0.125f, 0.140625f, 0.15625f, 0.171875f, 0.1875f, 0.203125f, 0.21875f, 0.234375f, // 0x28
        0.25f, 0.28125f, 0.3125f, 0.34375f, 0.375f, 0.40625f, 0.4375f, 0.46875f, // 0x30
        0.5f, 0.5625f, 0.625f, 0.6875f, 0.75f, 0.8125f, 0.875f, 0.9375f, // 0x38
        1.0f, 1.125f, 1.25f, 1.375f, 1.5f, 1.625f, 1.75f, 1.875f, // 0x40
        2.0f, 2.25f, 2.5f, 2.75f, 3.0f, 3.25f, 3.5f, 3.75f, // 0x48
        4.0f, 4.5f, 5.0f, 5.5f, 6.0f, 6.5f, 7.0f, 7.5f, // 0x50
        8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, // 0x58
        16.0f, 18.0f, 20.0f, 22.0f, 24.0f, 26.0f, 28.0f, 30.0f, // 0x60
        32.0f, 36.0f, 40.0f, 44.0f, 48.0f, 52.0f, 56.0f, 60.0f, // 0x68
        64.0f, 72.0f, 80.0f, 88.0f, 96.0f, 104.0f, 112.0f, 120.0f, // 0x70
        128.0f, 144.0f, 160.0f, 176.0f, 192.0f, 208.0f, 224.0f, 240.0f, // 0x78
        // negative half: 0x80 (NaN) - 0xFF
        std::numeric_limits<float>::quiet_NaN(), -0.0009765625f, -0.001953125f, -0.0029296875f, -0.00390625f, -0.0048828125f, -0.005859375f, -0.0068359375f, // 0x80
        -0.0078125f, -0.0087890625f, -0.009765625f, -0.0107421875f, -0.01171875f, -0.0126953125f, -0.013671875f, -0.0146484375f, // 0x88
        -0.015625f, -0.017578125f, -0.01953125f, -0.021484375f, -0.0234375f, -0.025390625f, -0.02734375f, -0.029296875f, // 0x90
        -0.03125f, -0.03515625f, -0.0390625f, -0.04296875f, -0.046875f, -0.05078125f, -0.0546875f, -0.05859375f, // 0x98
        -0.0625f, -0.0703125f, -0.078125f, -0.0859375f, -0.09375f, -0.1015625f, -0.109375f, -0.1171875f, // 0xA0
        -0.125f, -0.140625f, -0.15625f, -0.171875f, -0.1875f, -0.203125f, -0.21875f, -0.234375f, // 0xA8
        -0.25f, -0.28125f, -0.3125f, -0.34375f, -0.375f, -0.40625f, -0.4375f, -0.46875f, // 0xB0
        -0.5f, -0.5625f, -0.625f, -0.6875f, -0.75f, -0.8125f, -0.875f, -0.9375f, // 0xB8
        -1.0f, -1.125f, -1.25f, -1.375f, -1.5f, -1.625f, -1.75f, -1.875f, // 0xC0
        -2.0f, -2.25f, -2.5f, -2.75f, -3.0f, -3.25f, -3.5f, -3.75f, // 0xC8
        -4.0f, -4.5f, -5.0f, -5.5f, -6.0f, -6.5f, -7.0f, -7.5f, // 0xD0
        -8.0f, -9.0f, -10.0f, -11.0f, -12.0f, -13.0f, -14.0f, -15.0f, // 0xD8
        -16.0f, -18.0f, -20.0f, -22.0f, -24.0f, -26.0f, -28.0f, -30.0f, // 0xE0
        -32.0f, -36.0f, -40.0f, -44.0f, -48.0f, -52.0f, -56.0f, -60.0f, // 0xE8
        -64.0f, -72.0f, -80.0f, -88.0f, -96.0f, -104.0f, -112.0f, -120.0f, // 0xF0
        -128.0f, -144.0f, -160.0f, -176.0f, -192.0f, -208.0f, -224.0f, -240.0f, // 0xF8
    };
    return lut[input];
}
// Decodes every 8-bit pattern as fp8e4m3fnuz and compares the float conversion with the
// reference table; a pattern passes outright when both sides are NaN.
TEST_CASE(test_fp8_cast_to_float)
{
    using fp8_t = migraphx::fp8::fp8e4m3fnuz;
    std::vector<uint8_t> all_bits(256);
    std::iota(all_bits.begin(), all_bits.end(), 0);
    const auto decodes_correctly = [](uint8_t bits) {
        const float actual   = float(fp8_t(bits, fp8_t::from_bits()));
        const float expected = fp8e4m3fnuz_to_fp32_value(bits);
        // NaN/NaN pairs are matched explicitly before the equality check
        if(std::isnan(actual) and std::isnan(expected))
            return true;
        return migraphx::float_equal(actual, expected);
    };
    EXPECT(bool{std::all_of(all_bits.begin(), all_bits.end(), decodes_correctly)});
}
// Converts a set of floats to fp8e4m3fnuz and checks each result against the fp8 value built
// directly from the expected bit pattern.
TEST_CASE(test_fp8_cast_from_float)
{
    using fp8_t = migraphx::fp8::fp8e4m3fnuz;
    // float input -> expected raw fp8e4m3fnuz bits
    std::unordered_map<float, uint8_t> expected_bits = {
        {256, 0x7f},         {-256, 0xff},        {240, 0x7f},        {-240, 0xff},
        {1e-07, 0x0},        {1e+07, 0x7f},       {1, 0x40},          {-1, 0xc0},
        {0.1, 0x25},         {0.11, 0x26},        {0.111, 0x26},      {0.1111, 0x26},
        {-0.1, 0xa5},        {-0.11, 0xa6},       {-0.111, 0xa6},     {-0.1111, 0xa6},
        {0.2, 0x2d},         {2, 0x48},           {20, 0x62},         {200, 0x7c},
        {-0.2, 0xad},        {-2, 0xc8},          {-20, 0xe2},        {-200, 0xfc},
        {0.5, 0x38},         {-0.5, 0xb8},        {1.17549e-38, 0x0}, {1.4013e-45, 0x0},
        {0.00390625, 0x4},   {-0.00390625, 0x84}, {0.00195312, 0x2},  {-0.00195312, 0x82},
        {0.000976562, 0x1},  {-0.000976562, 0x81}, {0.000488281, 0x0}, {-0.000488281, 0x0}};
    const auto encodes_correctly = [](const auto& entry) {
        fp8_t from_float(entry.first);
        fp8_t from_raw(entry.second, fp8_t::from_bits());
        return migraphx::float_equal(from_float, from_raw);
    };
    EXPECT(bool{std::all_of(expected_bits.begin(), expected_bits.end(), encodes_correctly)});
}
// +0.0f converts to an fp8e4m3fnuz zero that converts back to a float equal to the input.
TEST_CASE(test_positive_zero)
{
    const float pos_zero = 0.0f;
    migraphx::fp8::fp8e4m3fnuz converted(pos_zero);
    EXPECT(converted.is_zero());
    const float round_trip = float(converted);
    EXPECT(migraphx::float_equal(pos_zero, round_trip));
}
// -0.0f converts to an fp8e4m3fnuz zero whose float value is compared against +0.0f.
TEST_CASE(test_negative_zero)
{
    const float neg_zero = -0.0f;
    const float pos_zero = 0.0f;
    migraphx::fp8::fp8e4m3fnuz converted(neg_zero);
    EXPECT(converted.is_zero());
    // negative zero is turned into positive zero by the conversion
    const float round_trip = float(converted);
    EXPECT(migraphx::float_equal(pos_zero, round_trip));
}
// A float quiet NaN converts to an fp8e4m3fnuz NaN, seen by both is_nan() and std::isnan.
TEST_CASE(test_nan_1)
{
    const auto float_nan = std::numeric_limits<float>::quiet_NaN();
    migraphx::fp8::fp8e4m3fnuz from_nan(float_nan);
    EXPECT(from_nan.is_nan());
    EXPECT(std::isnan(from_nan));
}
// Rebuilding a value from the raw bits of numeric_limits::quiet_NaN() yields a NaN, both as
// fp8e4m3fnuz and after conversion to float.
TEST_CASE(test_nan_2)
{
    using fp8_t = migraphx::fp8::fp8e4m3fnuz;
    const auto nan_bits = std::numeric_limits<fp8_t>::quiet_NaN().data;
    fp8_t rebuilt(nan_bits, fp8_t::from_bits());
    EXPECT(rebuilt.is_nan());
    EXPECT(std::isnan(rebuilt));
    EXPECT(std::isnan(float(rebuilt)));
}
// fp8e4m3fnuz has no infinity: +inf is expected to convert to NaN.
TEST_CASE(test_infinity_1)
{
    const float pos_inf = std::numeric_limits<float>::infinity();
    migraphx::fp8::fp8e4m3fnuz converted(pos_inf);
    EXPECT(converted.is_nan());
    EXPECT(std::isnan(float(converted)));
}
// fp8e4m3fnuz has no infinity: -inf is expected to convert to NaN.
TEST_CASE(test_infinity_2)
{
    const float neg_inf = -std::numeric_limits<float>::infinity();
    migraphx::fp8::fp8e4m3fnuz converted(neg_inf);
    EXPECT(converted.is_nan());
    EXPECT(std::isnan(float(converted)));
}
// The largest finite float converts to fp8e4m3fnuz max().
TEST_CASE(test_numeric_max_1)
{
    using fp8_t = migraphx::fp8::fp8e4m3fnuz;
    const float float_max = std::numeric_limits<float>::max();
    fp8_t converted(float_max);
    EXPECT(converted == std::numeric_limits<fp8_t>::max());
}
// Twice the fp8e4m3fnuz maximum is out of range and is clipped back to max().
TEST_CASE(test_numeric_max_2)
{
    using fp8_t = migraphx::fp8::fp8e4m3fnuz;
    const float twice_max = 2 * std::numeric_limits<fp8_t>::max();
    fp8_t converted(twice_max);
    EXPECT(converted == std::numeric_limits<fp8_t>::max());
}
// The most negative finite float converts to fp8e4m3fnuz lowest().
TEST_CASE(test_numeric_lowest_1)
{
    using fp8_t = migraphx::fp8::fp8e4m3fnuz;
    const float float_lowest = std::numeric_limits<float>::lowest();
    fp8_t converted(float_lowest);
    EXPECT(converted == std::numeric_limits<fp8_t>::lowest());
}
// Twice the fp8e4m3fnuz lowest value is out of range and is clipped back to lowest().
TEST_CASE(test_numeric_lowest_2)
{
    using fp8_t = migraphx::fp8::fp8e4m3fnuz;
    const float twice_lowest = 2.0 * std::numeric_limits<fp8_t>::lowest();
    fp8_t converted(twice_lowest);
    EXPECT(converted == std::numeric_limits<fp8_t>::lowest());
}
// lowest() is the negation of max() for fp8e4m3fnuz.
TEST_CASE(test_max_eq_lowest)
{
    using fp8_t = migraphx::fp8::fp8e4m3fnuz;
    const auto lowest_value = std::numeric_limits<fp8_t>::lowest();
    const auto negated_max  = -1 * std::numeric_limits<fp8_t>::max();
    EXPECT(migraphx::float_equal(lowest_value, negated_max));
}
// std::isfinite is true for values built from 0.0 and -0.0 and false for NaN.
TEST_CASE(test_isfinite)
{
    using fp8_t = migraphx::fp8::fp8e4m3fnuz;
    EXPECT(std::isfinite(fp8_t(0.0)));
    EXPECT(std::isfinite(fp8_t(-0.0)));
    const fp8_t nan_value(std::numeric_limits<fp8_t>::quiet_NaN());
    EXPECT(not std::isfinite(nan_value));
}
// numeric_limits must report that fp8e4m3fnuz has no infinity.
TEST_CASE(test_no_infinity)
{
    const bool has_inf = std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::has_infinity;
    EXPECT(not has_inf);
}
// Addition around zero and the ordering operators for fp8e4m3fnuz.
TEST_CASE(test_binary_ops)
{
    using fp8_t = migraphx::fp8::fp8e4m3fnuz;
    auto neg_one  = fp8_t(-1.0);
    auto pos_one  = fp8_t(1.0);
    auto pos_zero = fp8_t(0.0);
    auto neg_zero = fp8_t(-0.0);

    // sums that produce zero must equal the values built from 0.0 and -0.0
    EXPECT(migraphx::float_equal((pos_zero + neg_zero), pos_zero));
    EXPECT(migraphx::float_equal((pos_zero + neg_zero), neg_zero));
    EXPECT(migraphx::float_equal((neg_one + pos_one), pos_zero));
    EXPECT(migraphx::float_equal((neg_one + pos_one), neg_zero));

    auto ten     = fp8_t(10.0);
    auto neg_ten = fp8_t(-10.0);
    EXPECT(bool{ten > neg_ten});
    EXPECT(bool{neg_ten < ten});
    EXPECT(bool{neg_ten <= ten});
    EXPECT(bool{ten >= neg_ten});
    EXPECT(bool{ten <= ten});
    EXPECT(bool{neg_ten >= neg_ten});
    EXPECT(not migraphx::float_equal(neg_ten, ten));
}
// fabs of -1 equals +1 for fp8e4m3fnuz.
TEST_CASE(test_fabs)
{
    using fp8_t = migraphx::fp8::fp8e4m3fnuz;
    auto neg_one = fp8_t(-1.0);
    auto pos_one = fp8_t(1.0);
    auto abs_val = migraphx::fp8::fabs(neg_one);
    EXPECT(migraphx::float_equal(pos_one, abs_val));
}
// operator<< prints -1 as "-1" and NaN as "nan".
TEST_CASE(test_stream_op)
{
    using fp8_t = migraphx::fp8::fp8e4m3fnuz;

    auto neg_one = fp8_t(-1.0);
    std::stringstream value_out;
    value_out << neg_one;
    EXPECT(std::string("-1") == value_out.str());

    // a fresh stream replaces resetting the first one
    auto nan_value = std::numeric_limits<fp8_t>::quiet_NaN();
    std::stringstream nan_out;
    nan_out << nan_value;
    EXPECT(std::string("nan") == nan_out.str());
}
// Entry point of the fp8e4m3fnuz test binary: passes the command line to the test runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
    return 0;
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
#include <limits>
#include <sstream>
// Reference decoder used by the tests below: maps a raw fp8e5m2 bit pattern to the float it
// encodes by indexing a hand-written 256-entry table (one row per exponent value, four mantissa
// codes per row). Codes 0x7C / 0xFC are +/- infinity, 0x7D-0x7F and 0xFD-0xFF are NaN, and
// 0x80 is negative zero.
//
// The table is `static constexpr` so it is built once at compile time instead of being
// re-materialised on every call, and it is named after the format it actually describes.
float fp8e5m2_to_fp32_value(uint8_t input)
{
    static constexpr std::array<float, 256> e5m2_lut = {
        // positive half: 0x00 - 0x7F
        0.0, 1.52587890625e-05, 3.0517578125e-05, 4.57763671875e-05, // 0x00
        6.103515625e-05, 7.62939453125e-05, 9.1552734375e-05, 0.0001068115234375, // 0x04
        0.0001220703125, 0.000152587890625, 0.00018310546875, 0.000213623046875, // 0x08
        0.000244140625, 0.00030517578125, 0.0003662109375, 0.00042724609375, // 0x0C
        0.00048828125, 0.0006103515625, 0.000732421875, 0.0008544921875, // 0x10
        0.0009765625, 0.001220703125, 0.00146484375, 0.001708984375, // 0x14
        0.001953125, 0.00244140625, 0.0029296875, 0.00341796875, // 0x18
        0.00390625, 0.0048828125, 0.005859375, 0.0068359375, // 0x1C
        0.0078125, 0.009765625, 0.01171875, 0.013671875, // 0x20
        0.015625, 0.01953125, 0.0234375, 0.02734375, // 0x24
        0.03125, 0.0390625, 0.046875, 0.0546875, // 0x28
        0.0625, 0.078125, 0.09375, 0.109375, // 0x2C
        0.125, 0.15625, 0.1875, 0.21875, // 0x30
        0.25, 0.3125, 0.375, 0.4375, // 0x34
        0.5, 0.625, 0.75, 0.875, // 0x38
        1.0, 1.25, 1.5, 1.75, // 0x3C
        2.0, 2.5, 3.0, 3.5, // 0x40
        4.0, 5.0, 6.0, 7.0, // 0x44
        8.0, 10.0, 12.0, 14.0, // 0x48
        16.0, 20.0, 24.0, 28.0, // 0x4C
        32.0, 40.0, 48.0, 56.0, // 0x50
        64.0, 80.0, 96.0, 112.0, // 0x54
        128.0, 160.0, 192.0, 224.0, // 0x58
        256.0, 320.0, 384.0, 448.0, // 0x5C
        512.0, 640.0, 768.0, 896.0, // 0x60
        1024.0, 1280.0, 1536.0, 1792.0, // 0x64
        2048.0, 2560.0, 3072.0, 3584.0, // 0x68
        4096.0, 5120.0, 6144.0, 7168.0, // 0x6C
        8192.0, 10240.0, 12288.0, 14336.0, // 0x70
        16384.0, 20480.0, 24576.0, 28672.0, // 0x74
        32768.0, 40960.0, 49152.0, 57344.0, // 0x78
        std::numeric_limits<float>::infinity(), // 0x7C
        std::numeric_limits<float>::quiet_NaN(),
        std::numeric_limits<float>::quiet_NaN(),
        std::numeric_limits<float>::quiet_NaN(),
        // negative half: 0x80 - 0xFF
        -0.0, -1.52587890625e-05, -3.0517578125e-05, -4.57763671875e-05, // 0x80
        -6.103515625e-05, -7.62939453125e-05, -9.1552734375e-05, -0.0001068115234375, // 0x84
        -0.0001220703125, -0.000152587890625, -0.00018310546875, -0.000213623046875, // 0x88
        -0.000244140625, -0.00030517578125, -0.0003662109375, -0.00042724609375, // 0x8C
        -0.00048828125, -0.0006103515625, -0.000732421875, -0.0008544921875, // 0x90
        -0.0009765625, -0.001220703125, -0.00146484375, -0.001708984375, // 0x94
        -0.001953125, -0.00244140625, -0.0029296875, -0.00341796875, // 0x98
        -0.00390625, -0.0048828125, -0.005859375, -0.0068359375, // 0x9C
        -0.0078125, -0.009765625, -0.01171875, -0.013671875, // 0xA0
        -0.015625, -0.01953125, -0.0234375, -0.02734375, // 0xA4
        -0.03125, -0.0390625, -0.046875, -0.0546875, // 0xA8
        -0.0625, -0.078125, -0.09375, -0.109375, // 0xAC
        -0.125, -0.15625, -0.1875, -0.21875, // 0xB0
        -0.25, -0.3125, -0.375, -0.4375, // 0xB4
        -0.5, -0.625, -0.75, -0.875, // 0xB8
        -1.0, -1.25, -1.5, -1.75, // 0xBC
        -2.0, -2.5, -3.0, -3.5, // 0xC0
        -4.0, -5.0, -6.0, -7.0, // 0xC4
        -8.0, -10.0, -12.0, -14.0, // 0xC8
        -16.0, -20.0, -24.0, -28.0, // 0xCC
        -32.0, -40.0, -48.0, -56.0, // 0xD0
        -64.0, -80.0, -96.0, -112.0, // 0xD4
        -128.0, -160.0, -192.0, -224.0, // 0xD8
        -256.0, -320.0, -384.0, -448.0, // 0xDC
        -512.0, -640.0, -768.0, -896.0, // 0xE0
        -1024.0, -1280.0, -1536.0, -1792.0, // 0xE4
        -2048.0, -2560.0, -3072.0, -3584.0, // 0xE8
        -4096.0, -5120.0, -6144.0, -7168.0, // 0xEC
        -8192.0, -10240.0, -12288.0, -14336.0, // 0xF0
        -16384.0, -20480.0, -24576.0, -28672.0, // 0xF4
        -32768.0, -40960.0, -49152.0, -57344.0, // 0xF8
        -std::numeric_limits<float>::infinity(), // 0xFC
        std::numeric_limits<float>::quiet_NaN(),
        std::numeric_limits<float>::quiet_NaN(),
        std::numeric_limits<float>::quiet_NaN(),
    };
    return e5m2_lut[input];
}
// Decodes every 8-bit pattern as fp8e5m2 and compares the float conversion with the reference
// table. NaN and infinity are handled explicitly before the float_equal check:
//   - NaN matches NaN;
//   - infinity matches infinity only when the signs agree (the earlier version accepted
//     +inf against -inf, so a sign error in the infinity decoding would have gone unnoticed).
TEST_CASE(test_fp8_cast_to_float)
{
    std::vector<uint8_t> bit_vals(256);
    std::iota(bit_vals.begin(), bit_vals.end(), 0);
    EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) {
        migraphx::fp8::fp8e5m2 fp8_val(bit_val, migraphx::fp8::fp8e5m2::from_bits());
        const float actual   = float(fp8_val);
        const float expected = fp8e5m2_to_fp32_value(bit_val);
        if(std::isnan(actual) and std::isnan(expected))
        {
            return true;
        }
        else if(std::isinf(actual) and std::isinf(expected))
        {
            // +inf and -inf are different values; require the same sign
            return std::signbit(actual) == std::signbit(expected);
        }
        return migraphx::float_equal(actual, expected);
    })});
}
// Converts a set of floats to fp8e5m2 and checks each result against the fp8 value built
// directly from the expected bit pattern.
TEST_CASE(test_fp8_cast_from_float)
{
    using fp8_t = migraphx::fp8::fp8e5m2;
    // float input -> expected raw fp8e5m2 bits
    std::unordered_map<float, uint8_t> expected_bits = {
        {-60000, 0xfb},       {-57344, 0xfb},       {-448, 0xdf},
        {-256, 0xdc},         {-240, 0xdc},         {-200, 0xda},
        {-20, 0xcd},          {-2, 0xc0},           {-1, 0xbc},
        {-0.5, 0xb8},         {-0.2, 0xb2},         {-0.1111, 0xaf},
        {-0.111, 0xaf},       {-0.11, 0xaf},        {-0.1, 0xae},
        {6.10351e-05, 0x4},   {-6.10351e-05, 0x84}, {3.05176e-05, 0x2},
        {-3.05176e-05, 0x82}, {1.52588e-05, 0x1},   {-1.52588e-05, 0x81},
        {7.62939e-06, 0x0},   {-7.62939e-06, 0x80}, {0.1, 0x2e},
        {0.11, 0x2f},         {0.111, 0x2f},        {0.1111, 0x2f},
        {0.2, 0x32},          {0.5, 0x38},          {1, 0x3c},
        {2, 0x40},            {20, 0x4d},           {200, 0x5a},
        {240, 0x5c},          {256, 0x5c},          {448, 0x5f},
        {57344, 0x7b},        {60000, 0x7b},        {1e+07, 0x7b},
    };
    const auto encodes_correctly = [](const auto& entry) {
        fp8_t from_float(entry.first);
        fp8_t from_raw(entry.second, fp8_t::from_bits());
        return migraphx::float_equal(from_float, from_raw);
    };
    EXPECT(bool{std::all_of(expected_bits.begin(), expected_bits.end(), encodes_correctly)});
}
// +0.0f converts to an fp8e5m2 zero that converts back to a float equal to the input.
TEST_CASE(test_positive_zero)
{
    const float pos_zero = 0.0f;
    migraphx::fp8::fp8e5m2 converted(pos_zero);
    EXPECT(converted.is_zero());
    const float round_trip = float(converted);
    EXPECT(migraphx::float_equal(pos_zero, round_trip));
}
// -0.0f converts to an fp8e5m2 value that reports is_zero() and converts back to a float
// that float_equal treats as equal to -0.0f.
TEST_CASE(test_negative_zero)
{
    const float neg_zero = -0.0f;
    migraphx::fp8::fp8e5m2 converted(neg_zero);
    EXPECT(converted.is_zero());
    // fp8e5m2 keeps negative zero
    const float round_trip = float(converted);
    EXPECT(migraphx::float_equal(neg_zero, round_trip));
}
// +0 and -0 must compare equal once converted to fp8e5m2.
TEST_CASE(test_pos_zero_eq_neg_zero)
{
    using fp8_t = migraphx::fp8::fp8e5m2;
    const float neg_zero = -0.0f;
    const float pos_zero = 0.0f;
    fp8_t from_neg_zero(neg_zero);
    fp8_t from_pos_zero(pos_zero);
    EXPECT(from_neg_zero == from_pos_zero);
}
// A float quiet NaN converts to an fp8e5m2 NaN, seen by both is_nan() and std::isnan.
TEST_CASE(test_nan_1)
{
    const auto float_nan = std::numeric_limits<float>::quiet_NaN();
    migraphx::fp8::fp8e5m2 from_nan(float_nan);
    EXPECT(from_nan.is_nan());
    EXPECT(std::isnan(from_nan));
}
// Rebuilding a value from the raw bits of numeric_limits::quiet_NaN() yields a NaN, both as
// fp8e5m2 and after conversion to float.
TEST_CASE(test_nan_2)
{
    using fp8_t = migraphx::fp8::fp8e5m2;
    const auto nan_bits = std::numeric_limits<fp8_t>::quiet_NaN().data;
    fp8_t rebuilt(nan_bits, fp8_t::from_bits());
    EXPECT(rebuilt.is_nan());
    EXPECT(std::isnan(rebuilt));
    EXPECT(std::isnan(float(rebuilt)));
}
// Converting float +inf to fp8e5m2 is expected to saturate to max().
TEST_CASE(test_infinity_1)
{
    using fp8_t = migraphx::fp8::fp8e5m2;
    const float pos_inf = std::numeric_limits<float>::infinity();
    fp8_t clipped(pos_inf);
    EXPECT(clipped == std::numeric_limits<fp8_t>::max());
}
// Converting float -inf to fp8e5m2 is expected to saturate to lowest().
TEST_CASE(test_infinity_2)
{
    using fp8_t = migraphx::fp8::fp8e5m2;
    // negative infinity
    const float neg_inf = -std::numeric_limits<float>::infinity();
    // the conversion clips it to lowest()
    fp8_t clipped(neg_inf);
    EXPECT(bool{clipped == std::numeric_limits<fp8_t>::lowest()});
}
// The largest finite float converts to fp8e5m2 max().
TEST_CASE(test_numeric_max_1)
{
    using fp8_t = migraphx::fp8::fp8e5m2;
    const float float_max = std::numeric_limits<float>::max();
    fp8_t converted(float_max);
    EXPECT(converted == std::numeric_limits<fp8_t>::max());
}
// Twice the fp8e5m2 maximum is out of range and is clipped back to max().
TEST_CASE(test_numeric_max_2)
{
    using fp8_t = migraphx::fp8::fp8e5m2;
    const float twice_max = 2 * std::numeric_limits<fp8_t>::max();
    fp8_t converted(twice_max);
    EXPECT(converted == std::numeric_limits<fp8_t>::max());
}
// The most negative finite float converts to fp8e5m2 lowest().
TEST_CASE(test_numeric_lowest_1)
{
    using fp8_t = migraphx::fp8::fp8e5m2;
    const float float_lowest = std::numeric_limits<float>::lowest();
    fp8_t converted(float_lowest);
    EXPECT(converted == std::numeric_limits<fp8_t>::lowest());
}
// Twice the fp8e5m2 lowest value is out of range and is clipped back to lowest().
TEST_CASE(test_numeric_lowest_2)
{
    using fp8_t = migraphx::fp8::fp8e5m2;
    const float twice_lowest = 2.0 * std::numeric_limits<fp8_t>::lowest();
    fp8_t converted(twice_lowest);
    EXPECT(converted == std::numeric_limits<fp8_t>::lowest());
}
// lowest() is the negation of max() for fp8e5m2.
TEST_CASE(test_max_eq_lowest)
{
    using fp8_t = migraphx::fp8::fp8e5m2;
    const auto lowest_value = std::numeric_limits<fp8_t>::lowest();
    const auto negated_max  = -1 * std::numeric_limits<fp8_t>::max();
    EXPECT(migraphx::float_equal(lowest_value, negated_max));
}
// std::isfinite for fp8e5m2: true for zeros, false for NaN and for the infinity bit patterns.
TEST_CASE(test_isfinite)
{
    using fp8_t = migraphx::fp8::fp8e5m2;
    EXPECT(std::isfinite(fp8_t(0.0)));
    EXPECT(std::isfinite(fp8_t(-0.0)));

    const fp8_t nan_value(std::numeric_limits<fp8_t>::quiet_NaN());
    EXPECT(not std::isfinite(nan_value));
    EXPECT(not std::isfinite(std::numeric_limits<fp8_t>::infinity()));

    // -1.0 * inf is evaluated outside fp8 and gives a negative infinity; converting that back
    // to fp8e5m2 clips/saturates to lowest(), which is finite
    const fp8_t clipped_neg_inf(-1.0 * std::numeric_limits<fp8_t>::infinity());
    EXPECT(std::isfinite(clipped_neg_inf));

    // 0xFC is expected to be non-finite when built directly from bits
    const fp8_t raw_neg_inf(0xFC, fp8_t::from_bits());
    EXPECT(not std::isfinite(raw_neg_inf));
}
// Addition around zero and the ordering operators for fp8e5m2.
TEST_CASE(test_binary_ops)
{
    using fp8_t = migraphx::fp8::fp8e5m2;
    auto neg_one  = fp8_t(-1.0);
    auto pos_one  = fp8_t(1.0);
    auto pos_zero = fp8_t(0.0);
    auto neg_zero = fp8_t(-0.0);

    // sums that produce zero must equal both signed zeros
    EXPECT(migraphx::float_equal((pos_zero + neg_zero), pos_zero));
    EXPECT(migraphx::float_equal((pos_zero + neg_zero), neg_zero));
    EXPECT(migraphx::float_equal((neg_one + pos_one), pos_zero));
    EXPECT(migraphx::float_equal((neg_one + pos_one), neg_zero));

    auto ten     = fp8_t(10.0);
    auto neg_ten = fp8_t(-10.0);
    EXPECT(bool{ten > neg_ten});
    EXPECT(bool{neg_ten < ten});
    EXPECT(bool{neg_ten <= ten});
    EXPECT(bool{ten >= neg_ten});
    EXPECT(bool{ten <= ten});
    EXPECT(bool{neg_ten >= neg_ten});
    EXPECT(not migraphx::float_equal(neg_ten, ten));
}
// fabs of -1 equals +1 for fp8e5m2.
TEST_CASE(test_fabs)
{
    using fp8_t = migraphx::fp8::fp8e5m2;
    auto neg_one = fp8_t(-1.0);
    auto pos_one = fp8_t(1.0);
    auto abs_val = migraphx::fp8::fabs(neg_one);
    EXPECT(migraphx::float_equal(pos_one, abs_val));
}
// operator<< prints -1 as "-1" and NaN as "nan".
TEST_CASE(test_stream_op)
{
    using fp8_t = migraphx::fp8::fp8e5m2;

    auto neg_one = fp8_t(-1.0);
    std::stringstream value_out;
    value_out << neg_one;
    EXPECT(std::string("-1") == value_out.str());

    // a fresh stream replaces resetting the first one
    auto nan_value = std::numeric_limits<fp8_t>::quiet_NaN();
    std::stringstream nan_out;
    nan_out << nan_value;
    EXPECT(std::string("nan") == nan_out.str());
}
// Entry point of the fp8e5m2 test binary: passes the command line to the test runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
    return 0;
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
#include <limits>
// Reference decoder used by the tests: returns the float value associated with a raw
// fp8e5m2fnuz byte by looking it up in a hand-written table.
//
// @param input  raw 8-bit encoding; every value 0x00..0xff is a valid index.
// @return       the table entry for that encoding. Entry 0x80 is NaN, and entries
//               0x81..0xff are the negated counterparts of entries 0x01..0x7f.
float fp8e5m2fnuz_to_fp32_value(uint8_t input)
{
    // `static` so the 256-entry table is a single constant object rather than a
    // per-call local. Rows hold four consecutive encodings each.
    static constexpr std::array<float, 256> e5m2fnuz_lut = {
        // 0x00 .. 0x7f : zero and the positive values
        0.0, 7.62939453125e-06, 1.52587890625e-05, 2.288818359375e-05,
        3.0517578125e-05, 3.814697265625e-05, 4.57763671875e-05, 5.340576171875e-05,
        6.103515625e-05, 7.62939453125e-05, 9.1552734375e-05, 0.0001068115234375,
        0.0001220703125, 0.000152587890625, 0.00018310546875, 0.000213623046875,
        0.000244140625, 0.00030517578125, 0.0003662109375, 0.00042724609375,
        0.00048828125, 0.0006103515625, 0.000732421875, 0.0008544921875,
        0.0009765625, 0.001220703125, 0.00146484375, 0.001708984375,
        0.001953125, 0.00244140625, 0.0029296875, 0.00341796875,
        0.00390625, 0.0048828125, 0.005859375, 0.0068359375,
        0.0078125, 0.009765625, 0.01171875, 0.013671875,
        0.015625, 0.01953125, 0.0234375, 0.02734375,
        0.03125, 0.0390625, 0.046875, 0.0546875,
        0.0625, 0.078125, 0.09375, 0.109375,
        0.125, 0.15625, 0.1875, 0.21875,
        0.25, 0.3125, 0.375, 0.4375,
        0.5, 0.625, 0.75, 0.875,
        1.0, 1.25, 1.5, 1.75,
        2.0, 2.5, 3.0, 3.5,
        4.0, 5.0, 6.0, 7.0,
        8.0, 10.0, 12.0, 14.0,
        16.0, 20.0, 24.0, 28.0,
        32.0, 40.0, 48.0, 56.0,
        64.0, 80.0, 96.0, 112.0,
        128.0, 160.0, 192.0, 224.0,
        256.0, 320.0, 384.0, 448.0,
        512.0, 640.0, 768.0, 896.0,
        1024.0, 1280.0, 1536.0, 1792.0,
        2048.0, 2560.0, 3072.0, 3584.0,
        4096.0, 5120.0, 6144.0, 7168.0,
        8192.0, 10240.0, 12288.0, 14336.0,
        16384.0, 20480.0, 24576.0, 28672.0,
        32768.0, 40960.0, 49152.0, 57344.0,
        // 0x80 : NaN, followed by 0x81 .. 0xff : the negative values
        std::numeric_limits<float>::quiet_NaN(), -7.62939453125e-06, -1.52587890625e-05, -2.288818359375e-05,
        -3.0517578125e-05, -3.814697265625e-05, -4.57763671875e-05, -5.340576171875e-05,
        -6.103515625e-05, -7.62939453125e-05, -9.1552734375e-05, -0.0001068115234375,
        -0.0001220703125, -0.000152587890625, -0.00018310546875, -0.000213623046875,
        -0.000244140625, -0.00030517578125, -0.0003662109375, -0.00042724609375,
        -0.00048828125, -0.0006103515625, -0.000732421875, -0.0008544921875,
        -0.0009765625, -0.001220703125, -0.00146484375, -0.001708984375,
        -0.001953125, -0.00244140625, -0.0029296875, -0.00341796875,
        -0.00390625, -0.0048828125, -0.005859375, -0.0068359375,
        -0.0078125, -0.009765625, -0.01171875, -0.013671875,
        -0.015625, -0.01953125, -0.0234375, -0.02734375,
        -0.03125, -0.0390625, -0.046875, -0.0546875,
        -0.0625, -0.078125, -0.09375, -0.109375,
        -0.125, -0.15625, -0.1875, -0.21875,
        -0.25, -0.3125, -0.375, -0.4375,
        -0.5, -0.625, -0.75, -0.875,
        -1.0, -1.25, -1.5, -1.75,
        -2.0, -2.5, -3.0, -3.5,
        -4.0, -5.0, -6.0, -7.0,
        -8.0, -10.0, -12.0, -14.0,
        -16.0, -20.0, -24.0, -28.0,
        -32.0, -40.0, -48.0, -56.0,
        -64.0, -80.0, -96.0, -112.0,
        -128.0, -160.0, -192.0, -224.0,
        -256.0, -320.0, -384.0, -448.0,
        -512.0, -640.0, -768.0, -896.0,
        -1024.0, -1280.0, -1536.0, -1792.0,
        -2048.0, -2560.0, -3072.0, -3584.0,
        -4096.0, -5120.0, -6144.0, -7168.0,
        -8192.0, -10240.0, -12288.0, -14336.0,
        -16384.0, -20480.0, -24576.0, -28672.0,
        -32768.0, -40960.0, -49152.0, -57344.0,
    };
    // A uint8_t can only take values 0..255, so the index is always inside the table.
    return e5m2fnuz_lut[input];
}
// Decodes every possible 8-bit pattern as fp8e5m2fnuz and compares the float result
// against the reference lookup table.
TEST_CASE(test_fp8_cast_to_float)
{
    using fp8_t = migraphx::fp8::fp8e5m2fnuz;

    // All encodings 0, 1, ..., 255.
    std::vector<uint8_t> encodings(256);
    std::iota(encodings.begin(), encodings.end(), 0);

    auto matches_reference = [](uint8_t bits) {
        fp8_t decoded(bits, fp8_t::from_bits());
        // Both sides being NaN counts as a match; this is checked before the
        // equality comparison below.
        if(std::isnan(float(decoded)) and std::isnan(fp8e5m2fnuz_to_fp32_value(bits)))
        {
            return true;
        }
        return migraphx::float_equal(float(decoded), fp8e5m2fnuz_to_fp32_value(bits));
    };

    EXPECT(bool{std::all_of(encodings.begin(), encodings.end(), matches_reference)});
}
// Converts a set of float inputs to fp8e5m2fnuz and checks each against the value
// built directly from the expected bit pattern.
TEST_CASE(test_fp8_cast_from_float)
{
    using fp8_t = migraphx::fp8::fp8e5m2fnuz;

    // float input -> expected fp8e5m2fnuz encoding
    std::unordered_map<float, uint8_t> expected_bits = {
        {57344, 0x7f},       {-57344, 0xff},
        {60000, 0x7f},       {-60000, 0xff},
        {448, 0x63},         {-448, 0xe3},
        {256, 0x60},         {-256, 0xe0},
        {240, 0x60},         {-240, 0xe0},
        {3.05176e-05, 0x4},  {-3.05176e-05, 0x84},
        {1.52588e-05, 0x2},  {-1.52588e-05, 0x82},
        {7.62939e-06, 0x1},  {-7.62939e-06, 0x81},
        {3.81469e-06, 0x0},  {-3.81469e-06, 0x0},
        {1e+07, 0x7f},       {1, 0x40},
        {-1, 0xc0},          {0.1, 0x32},
        {0.11, 0x33},        {0.111, 0x33},
        {0.1111, 0x33},      {-0.1, 0xb2},
        {-0.11, 0xb3},       {-0.111, 0xb3},
        {-0.1111, 0xb3},     {0.2, 0x36},
        {2, 0x44},           {20, 0x51},
        {200, 0x5e},         {-0.2, 0xb6},
        {-2, 0xc4},          {-20, 0xd1},
        {-200, 0xde},        {0.5, 0x3c},
        {-0.5, 0xbc},        {1.17549e-38, 0x0},
        {1.4013e-45, 0x0},
    };

    auto converts_as_expected = [](const auto& sample) {
        fp8_t from_float(sample.first);
        fp8_t from_encoding(sample.second, fp8_t::from_bits());
        return migraphx::float_equal(from_float, from_encoding);
    };

    EXPECT(bool{
        std::all_of(expected_bits.begin(), expected_bits.end(), converts_as_expected)});
}
// +0.0 converts to an fp8 zero and converts back to a float equal to +0.0.
TEST_CASE(test_positive_zero)
{
    using fp8_t = migraphx::fp8::fp8e5m2fnuz;
    float positive_zero = 0.0;
    fp8_t converted(positive_zero);
    EXPECT(converted.is_zero());
    EXPECT(migraphx::float_equal(positive_zero, float(converted)));
}
// -0.0 converts to an fp8 zero whose float value compares equal to +0.0.
TEST_CASE(test_negative_zero)
{
    using fp8_t = migraphx::fp8::fp8e5m2fnuz;
    float positive_zero = 0.0;
    float negative_zero = -0.0;
    fp8_t converted(negative_zero);
    EXPECT(converted.is_zero());
    // the negative zero input is expected to come back as positive zero
    EXPECT(migraphx::float_equal(positive_zero, float(converted)));
}
// A float NaN converted to fp8e5m2fnuz must still be reported as NaN.
TEST_CASE(test_nan_1)
{
    using fp8_t = migraphx::fp8::fp8e5m2fnuz;
    float float_nan = std::numeric_limits<float>::quiet_NaN();
    fp8_t converted(float_nan);
    EXPECT(converted.is_nan());
    EXPECT(std::isnan(converted));
}
// Rebuilding a value from the bits of numeric_limits::quiet_NaN() must give a NaN,
// both as fp8 and after conversion to float.
TEST_CASE(test_nan_2)
{
    using fp8_t = migraphx::fp8::fp8e5m2fnuz;
    auto limits_nan = std::numeric_limits<fp8_t>::quiet_NaN();
    fp8_t rebuilt(limits_nan.data, fp8_t::from_bits());
    EXPECT(rebuilt.is_nan());
    EXPECT(std::isnan(rebuilt));
    EXPECT(std::isnan(float(rebuilt)));
}
// Positive infinity has no fp8e5m2fnuz representation; the conversion is expected
// to produce NaN instead.
TEST_CASE(test_infinity_1)
{
    using fp8_t = migraphx::fp8::fp8e5m2fnuz;
    float positive_inf = std::numeric_limits<float>::infinity();
    fp8_t converted(positive_inf);
    EXPECT(converted.is_nan());
    EXPECT(std::isnan(float(converted)));
}
// Negative infinity has no fp8e5m2fnuz representation either; the conversion is
// expected to produce NaN.
TEST_CASE(test_infinity_2)
{
    using fp8_t = migraphx::fp8::fp8e5m2fnuz;
    float negative_inf = -1.0 * std::numeric_limits<float>::infinity();
    fp8_t converted(negative_inf);
    EXPECT(converted.is_nan());
    EXPECT(std::isnan(float(converted)));
}
// The largest finite float converts to the largest fp8e5m2fnuz value.
TEST_CASE(test_numeric_max_1)
{
    using fp8_t = migraphx::fp8::fp8e5m2fnuz;
    float float_max = std::numeric_limits<float>::max();
    fp8_t converted(float_max);
    EXPECT(converted == std::numeric_limits<fp8_t>::max());
}
// Twice the fp8e5m2fnuz maximum is out of range and is expected to clip to the maximum.
TEST_CASE(test_numeric_max_2)
{
    using fp8_t = migraphx::fp8::fp8e5m2fnuz;
    float twice_max = 2 * std::numeric_limits<fp8_t>::max();
    fp8_t converted(twice_max);
    EXPECT(converted == std::numeric_limits<fp8_t>::max());
}
// The lowest finite float converts to the lowest fp8e5m2fnuz value.
TEST_CASE(test_numeric_lowest_1)
{
    using fp8_t = migraphx::fp8::fp8e5m2fnuz;
    float float_lowest = std::numeric_limits<float>::lowest();
    fp8_t converted(float_lowest);
    EXPECT(converted == std::numeric_limits<fp8_t>::lowest());
}
// Twice the fp8e5m2fnuz lowest value is out of range and is expected to clip to lowest.
TEST_CASE(test_numeric_lowest_2)
{
    using fp8_t = migraphx::fp8::fp8e5m2fnuz;
    float twice_lowest = 2.0 * std::numeric_limits<fp8_t>::lowest();
    fp8_t converted(twice_lowest);
    EXPECT(converted == std::numeric_limits<fp8_t>::lowest());
}
// lowest() must equal the negation of max().
TEST_CASE(test_max_eq_lowest)
{
    using limits_t = std::numeric_limits<migraphx::fp8::fp8e5m2fnuz>;
    EXPECT(migraphx::float_equal(limits_t::lowest(), -1 * limits_t::max()));
}
// Both zero spellings are finite; quiet_NaN is not.
TEST_CASE(test_isfinite)
{
    using fp8_t = migraphx::fp8::fp8e5m2fnuz;
    EXPECT(std::isfinite(fp8_t(0.0)));
    EXPECT(std::isfinite(fp8_t(-0.0)));
    EXPECT(not std::isfinite(fp8_t(std::numeric_limits<fp8_t>::quiet_NaN())));
}
// numeric_limits must report that fp8e5m2fnuz has no infinity.
TEST_CASE(test_no_infinity)
{
    using limits_t = std::numeric_limits<migraphx::fp8::fp8e5m2fnuz>;
    EXPECT(not bool{limits_t::has_infinity});
}
// Exercises addition and the ordering operators on fp8e5m2fnuz values.
TEST_CASE(test_binary_ops)
{
    using fp8_t = migraphx::fp8::fp8e5m2fnuz;

    fp8_t minus_one(-1.0);
    fp8_t plus_one(1.0);
    fp8_t zero_from_pos(0.0);
    fp8_t zero_from_neg(-0.0);

    // Sums expected to be zero, compared against the values built from 0.0 and -0.0.
    EXPECT(migraphx::float_equal((zero_from_pos + zero_from_neg), zero_from_pos));
    EXPECT(migraphx::float_equal((zero_from_pos + zero_from_neg), zero_from_neg));
    EXPECT(migraphx::float_equal((minus_one + plus_one), zero_from_pos));
    EXPECT(migraphx::float_equal((minus_one + plus_one), zero_from_neg));

    // Ordering between a positive and a negative value, plus reflexive <= and >=.
    fp8_t plus_ten(10.0);
    fp8_t minus_ten(-10.0);
    EXPECT(bool{plus_ten > minus_ten});
    EXPECT(bool{minus_ten < plus_ten});
    EXPECT(bool{minus_ten <= plus_ten});
    EXPECT(bool{plus_ten >= minus_ten});
    EXPECT(bool{plus_ten <= plus_ten});
    EXPECT(bool{minus_ten >= minus_ten});
    EXPECT(not migraphx::float_equal(minus_ten, plus_ten));
}
// fabs applied to -1 must compare equal to +1 for fp8e5m2fnuz.
TEST_CASE(test_fabs)
{
    using fp8_t = migraphx::fp8::fp8e5m2fnuz;
    fp8_t negative_one(-1.0);
    fp8_t positive_one(1.0);
    EXPECT(migraphx::float_equal(positive_one, migraphx::fp8::fabs(negative_one)));
}
// Checks the text produced by streaming fp8e5m2fnuz values into an output stream.
TEST_CASE(test_stream_op)
{
    using fp8_t = migraphx::fp8::fp8e5m2fnuz;

    // A finite value: -1 is expected to print as "-1".
    fp8_t negative_one(-1.0);
    std::stringstream finite_out;
    finite_out << negative_one;
    EXPECT(std::string("-1") == finite_out.str());

    // NaN is expected to print as "nan"; a fresh stream keeps the two outputs separate.
    fp8_t nan_value = std::numeric_limits<fp8_t>::quiet_NaN();
    std::stringstream nan_out;
    nan_out << nan_value;
    EXPECT(std::string("nan") == nan_out.str());
}
// Entry point of the test binary: hands the command line to the test driver.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
......@@ -237,12 +237,12 @@ TEST_CASE(code_object_hip)
std::vector<migraphx::shape> expected_inputs = {input, input};
auto co = migraphx::make_op("gpu::code_object",
{{"code_object", migraphx::value::binary{binaries.front()}},
{"symbol_name", "add_2"},
{"global", input.elements()},
{"local", 1024},
{"expected_inputs", migraphx::to_value(expected_inputs)},
{"output", migraphx::to_value(input)}});
{{"code_object", migraphx::value::binary{binaries.front()}},
{"symbol_name", "add_2"},
{"global", input.elements()},
{"local", 1024},
{"expected_inputs", migraphx::to_value(expected_inputs)},
{"output", migraphx::to_value(input)}});
migraphx::program p;
auto* mm = p.get_main_module();
......@@ -348,7 +348,10 @@ TEST_CASE(compile_math)
auto vec_sizes = {2, 4, 6};
for(auto&& t : migraphx::shape::types())
{
if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t))
if(contains({migraphx::shape::bool_type,
migraphx::shape::fp8e4m3fnuz_type,
migraphx::shape::tuple_type},
t))
continue;
auto name = migraphx::shape::cpp_type(t);
if(t == migraphx::shape::half_type)
......@@ -396,7 +399,10 @@ TEST_CASE(assert_type_min_max)
migraphx::gpu::hip_compile_options options;
for(auto&& t : migraphx::shape::types())
{
if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t))
if(contains({migraphx::shape::bool_type,
migraphx::shape::fp8e4m3fnuz_type,
migraphx::shape::tuple_type},
t))
continue;
auto name = migraphx::shape::cpp_type(t);
if(t == migraphx::shape::half_type)
......
......@@ -44,7 +44,8 @@
m(int32_type, int32_t) \
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t)
m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz)
// clang-format on
#ifdef __cplusplus
......@@ -70,7 +71,9 @@ typedef enum
} migraphx_shape_datatype_t;
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
<% generate_c_header() %>
<%
generate_c_header()
%>
#ifdef __cplusplus
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment