"docs/vscode:/vscode.git/clone" did not exist on "2b981012a6eb27d566f03cf61c06b1ef7a522f27"
Commit ab653aff authored by Umang Yadav's avatar Umang Yadav
Browse files

Review updates

parent 183db78a
...@@ -45,7 +45,7 @@ ...@@ -45,7 +45,7 @@
m(int64_type, int64_t) \ m(int64_type, int64_t) \
m(uint32_type, uint32_t) \ m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \ m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx_fp8::fp8e4m3fnuz) m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz)
// clang-format on // clang-format on
#ifdef __cplusplus #ifdef __cplusplus
......
/* ************************************************************************
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
* ies of the Software, and to permit persons to whom the Software is furnished
* to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
* PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
* CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* ************************************************************************ */
#ifndef MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#define MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#include <migraphx/config.hpp>
#define MIGRAPHX_CONST_FOLD(x) (__builtin_constant_p(x) ? (x) : (x))
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <typename To, typename From>
inline constexpr To bit_cast(From fr) noexcept
{
static_assert(sizeof(To) == sizeof(From));
#if defined(__GNUC__) and !defined(__clang__)
return MIGRAPHX_CONST_FOLD(*reinterpret_cast<To*>(&fr));
#else
return __builtin_bit_cast(To, fr);
#endif
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
...@@ -44,20 +44,12 @@ ...@@ -44,20 +44,12 @@
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <utility> #include <utility>
#include <migraphx/config.hpp>
#include <migraphx/float8_impl.hpp>
namespace migraphx_f8_impl { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <int wm, int we, typename T, bool negative_zero_nan, bool clip> namespace fp8 {
constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0);
template <int wm, int we, typename T, bool negative_zero_nan>
constexpr T cast_from_f8(uint8_t x);
} // namespace migraphx_f8_impl
#include <migraphx/migraphx_f8_impl.hpp>
namespace migraphx_fp8 {
enum class migraphx_f8_rounding_mode enum class migraphx_f8_rounding_mode
{ {
...@@ -74,7 +66,7 @@ enum class f8_type ...@@ -74,7 +66,7 @@ enum class f8_type
template <typename T, bool FNUZ = true> template <typename T, bool FNUZ = true>
class numeric_limits; class numeric_limits;
template <migraphx_fp8::f8_type T = migraphx_fp8::f8_type::fp8, bool FNUZ = true> template <migraphx::fp8::f8_type T = migraphx::fp8::f8_type::fp8, bool FNUZ = true>
struct float8 struct float8
{ {
uint8_t data = 0x00; uint8_t data = 0x00;
...@@ -90,43 +82,43 @@ struct float8 ...@@ -90,43 +82,43 @@ struct float8
explicit constexpr float8(uint8_t bits, from_bits_t) : data(bits) {} explicit constexpr float8(uint8_t bits, from_bits_t) : data(bits) {}
explicit constexpr float8(float v, explicit constexpr float8(float v,
migraphx_fp8::migraphx_f8_rounding_mode rm = migraphx::fp8::migraphx_f8_rounding_mode rm =
migraphx_fp8::migraphx_f8_rounding_mode::standard, migraphx::fp8::migraphx_f8_rounding_mode::standard,
uint32_t rng = 0) uint32_t rng = 0)
{ {
if constexpr(T == migraphx_fp8::f8_type::fp8) if constexpr(T == migraphx::fp8::f8_type::fp8)
{ {
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING #ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx_f8_impl:: data = migraphx::fp8::impl::
cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, true /*clip*/>( cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, true /*clip*/>(
v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng); v, (rm == migraphx::fp8::migraphx_f8_rounding_mode::stochastic), rng);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING #else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx_f8_impl:: data = migraphx::fp8::impl::
cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, false /*clip*/>( cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, false /*clip*/>(
v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng); v, (rm == migraphx::fp8::migraphx_f8_rounding_mode::stochastic), rng);
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING #endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
} }
else else
{ {
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING #ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx_f8_impl:: data = migraphx::fp8::impl::
cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, true /*clip*/>( cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, true /*clip*/>(
v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng); v, (rm == migraphx::fp8::migraphx_f8_rounding_mode::stochastic), rng);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING #else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx_f8_impl:: data = migraphx::fp8::impl::
cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, false /*clip*/>( cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, false /*clip*/>(
v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng); v, (rm == migraphx::fp8::migraphx_f8_rounding_mode::stochastic), rng);
#endif // rocblas_F8_downcast_clipping} #endif // rocblas_F8_downcast_clipping}
} }
} }
inline constexpr operator float() const inline constexpr operator float() const
{ {
if constexpr(T == migraphx_fp8::f8_type::fp8) if constexpr(T == migraphx::fp8::f8_type::fp8)
{ {
return migraphx_f8_impl::cast_from_f8<3, 4, float, FNUZ /*negative_zero_nan*/>(data); return migraphx::fp8::impl::cast_from_f8<3, 4, float, FNUZ /*negative_zero_nan*/>(data);
} // else } // else
return migraphx_f8_impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data); return migraphx::fp8::impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data);
} }
inline constexpr bool is_zero() const inline constexpr bool is_zero() const
...@@ -149,7 +141,7 @@ struct float8 ...@@ -149,7 +141,7 @@ struct float8
} }
else else
{ {
if(T == migraphx_fp8::f8_type::bf8) if(T == migraphx::fp8::f8_type::bf8)
{ {
return (data == 0x7D) or (data == 0x7E) or (data == 0x7F) or (data == 0xFD) or return (data == 0x7D) or (data == 0x7E) or (data == 0x7F) or (data == 0xFD) or
(data == 0xFE) or (data == 0xFF); (data == 0xFE) or (data == 0xFF);
...@@ -169,7 +161,7 @@ struct float8 ...@@ -169,7 +161,7 @@ struct float8
} }
else else
{ {
if(T == migraphx_fp8::f8_type::bf8) if(T == migraphx::fp8::f8_type::bf8)
{ {
return (data == 0x7C) or (data == 0xFC); return (data == 0x7C) or (data == 0xFC);
} }
...@@ -236,26 +228,26 @@ struct float8 ...@@ -236,26 +228,26 @@ struct float8
}; };
// Special operator overloading // Special operator overloading
template <migraphx_fp8::f8_type T> template <migraphx::fp8::f8_type T>
inline std::ostream& operator<<(std::ostream& os, const migraphx_fp8::float8<T>& rhs) inline std::ostream& operator<<(std::ostream& os, const migraphx::fp8::float8<T>& rhs)
{ {
return os << static_cast<float>(rhs); return os << static_cast<float>(rhs);
} }
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \ #define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \
template <migraphx_fp8::f8_type T> \ template <migraphx::fp8::f8_type T> \
inline constexpr U operator binary_op(const migraphx_fp8::float8<T>& lhs, \ inline constexpr U operator binary_op(const migraphx::fp8::float8<T>& lhs, \
const migraphx_fp8::float8<T>& rhs) \ const migraphx::fp8::float8<T>& rhs) \
{ \ { \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \ return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
} }
// TODO: these should return floats // TODO: these should return floats
MIGRAPHX_FP8_BINARY_OP(*, migraphx_fp8::float8<T>) MIGRAPHX_FP8_BINARY_OP(*, migraphx::fp8::float8<T>)
MIGRAPHX_FP8_BINARY_OP(-, migraphx_fp8::float8<T>) MIGRAPHX_FP8_BINARY_OP(-, migraphx::fp8::float8<T>)
MIGRAPHX_FP8_BINARY_OP(/, migraphx_fp8::float8<T>) MIGRAPHX_FP8_BINARY_OP(/, migraphx::fp8::float8<T>)
MIGRAPHX_FP8_BINARY_OP(+, migraphx_fp8::float8<T>) MIGRAPHX_FP8_BINARY_OP(+, migraphx::fp8::float8<T>)
// TODO: Comparison ops shouldn't convert to float, need to check if need to take care of rounding // TODO: Comparison ops shouldn't convert to float, need to check if need to take care of rounding
// effects. // effects.
MIGRAPHX_FP8_BINARY_OP(==, bool) MIGRAPHX_FP8_BINARY_OP(==, bool)
...@@ -265,18 +257,18 @@ MIGRAPHX_FP8_BINARY_OP(>, bool) ...@@ -265,18 +257,18 @@ MIGRAPHX_FP8_BINARY_OP(>, bool)
MIGRAPHX_FP8_BINARY_OP(<, bool) MIGRAPHX_FP8_BINARY_OP(<, bool)
MIGRAPHX_FP8_BINARY_OP(!=, bool) MIGRAPHX_FP8_BINARY_OP(!=, bool)
template <migraphx_fp8::f8_type T> template <migraphx::fp8::f8_type T>
inline migraphx_fp8::float8<T> fabs(migraphx_fp8::float8<T> v) inline migraphx::fp8::float8<T> fabs(migraphx::fp8::float8<T> v)
{ {
v.data = v.data & 0x7f; v.data = v.data & 0x7f;
return v; return v;
} }
// https://onnx.ai/onnx/technical/float8.html // https://onnx.ai/onnx/technical/float8.html
using fp8e4m3fn = float8<migraphx_fp8::f8_type::fp8, false>; using fp8e4m3fn = float8<migraphx::fp8::f8_type::fp8, false>;
using fp8e5m2 = float8<migraphx_fp8::f8_type::bf8, false>; using fp8e5m2 = float8<migraphx::fp8::f8_type::bf8, false>;
using fp8e4m3fnuz = float8<migraphx_fp8::f8_type::fp8, true>; using fp8e4m3fnuz = float8<migraphx::fp8::f8_type::fp8, true>;
using fp8e5m2fnuz = float8<migraphx_fp8::f8_type::bf8, true>; using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
template <> template <>
class numeric_limits<fp8e4m3fnuz> class numeric_limits<fp8e4m3fnuz>
...@@ -347,37 +339,39 @@ class numeric_limits<fp8e5m2> ...@@ -347,37 +339,39 @@ class numeric_limits<fp8e5m2>
// 7C and FC both are infinity // 7C and FC both are infinity
static constexpr fp8e5m2 infinity() { return fp8e5m2(0x7C, fp8e5m2::from_bits()); } static constexpr fp8e5m2 infinity() { return fp8e5m2(0x7C, fp8e5m2::from_bits()); }
}; };
} // namespace migraphx_fp8 } // namespace fp8
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
// ================================================================================================= // =================================================================================================
// define numeric limits for the new data type // define numeric limits for the new data type
namespace std { namespace std {
#define MIGRAPHX_FP8_STD_OVERLOADS(T) \ #define MIGRAPHX_FP8_STD_OVERLOADS(T) \
inline bool isfinite(T x) { return x.is_inf(); } \ inline bool isfinite(T x) { return x.is_inf(); } \
inline bool isnan(T x) { return x.is_nan(); } \ inline bool isnan(T x) { return x.is_nan(); } \
template <> \ template <> \
class numeric_limits<T> : public migraphx_fp8::numeric_limits<T> \ class numeric_limits<T> : public migraphx::fp8::numeric_limits<T> \
{ \ { \
}; \ }; \
template <class U> \ template <class U> \
struct common_type<T, U> : std::common_type<float, U> \ struct common_type<T, U> : std::common_type<float, U> \
{ \ { \
}; \ }; \
template <class U> \ template <class U> \
struct common_type<U, T> : std::common_type<float, U> \ struct common_type<U, T> : std::common_type<float, U> \
{ \ { \
}; \ }; \
template <> \ template <> \
struct common_type<T, T> \ struct common_type<T, T> \
{ \ { \
using type = T; \ using type = T; \
}; };
MIGRAPHX_FP8_STD_OVERLOADS(migraphx_fp8::fp8e4m3fn) MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fn)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx_fp8::fp8e5m2) MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx_fp8::fp8e4m3fnuz) MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx_fp8::fp8e5m2fnuz) MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2fnuz)
} // namespace std } // namespace std
// ================================================================================================= // =================================================================================================
......
...@@ -20,49 +20,32 @@ ...@@ -20,49 +20,32 @@
* *
* ************************************************************************ */ * ************************************************************************ */
#ifndef MIGRAPHX_FP8_IMPL_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL_HPP
#define MIGRAPHX_FP8_IMPL_HPP #define MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL_HPP
#include <type_traits>
#define MIGRAPHX_CONST_FOLD(x) (__builtin_constant_p(x) ? (x) : (x)) #include <migraphx/config.hpp>
namespace migraphx_f8_impl { #include <migraphx/bit_cast.hpp>
namespace detail { namespace migraphx {
template <bool B, class T, class F> inline namespace MIGRAPHX_INLINE_NS {
struct conditional namespace fp8 {
{ namespace impl {
using type = T;
};
template <class T, class F>
struct conditional<false, T, F>
{
using type = F;
};
template <typename To, typename From>
inline constexpr To bit_cast(From fr) noexcept
{
static_assert(sizeof(To) == sizeof(From));
#if defined(__GNUC__) and !defined(__clang__)
return MIGRAPHX_CONST_FOLD(*reinterpret_cast<To*>(&fr));
#else
return __builtin_bit_cast(To, fr);
#endif
}
} // namespace detail
template <int wm, int we, typename T, bool negative_zero_nan, bool clip> template <int wm, int we, typename T, bool negative_zero_nan, bool clip>
constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0)
{ {
constexpr bool is_float = std::is_same<T, float>::value;
// half is not supported for now
constexpr bool is_half = false;
static_assert(wm + we == 7, "wm+we==7"); static_assert(wm + we == 7, "wm+we==7");
static_assert(is_float or is_half, "Only float can be cast to f8");
const int mfmt = (sizeof(T) == 4) ? 23 : 10; const int mfmt = (sizeof(T) == 4) ? 23 : 10;
typename detail::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type x; typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type x;
if constexpr(sizeof(T) == 4) if constexpr(sizeof(T) == 4)
x = detail::bit_cast<uint32_t>(_x); x = migraphx::bit_cast<uint32_t>(_x);
else else
x = detail::bit_cast<uint16_t>(_x); x = migraphx::bit_cast<uint16_t>(_x);
uint32_t head, mantissa; uint32_t head, mantissa;
int exponent, bias; int exponent, bias;
...@@ -271,19 +254,27 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) ...@@ -271,19 +254,27 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
template <int wm, int we, typename T, bool negative_zero_nan> template <int wm, int we, typename T, bool negative_zero_nan>
constexpr T cast_from_f8(uint8_t x) constexpr T cast_from_f8(uint8_t x)
{ {
constexpr int weo = 8; // half is not supported for now
constexpr int wmo = 23; constexpr bool is_half = false;
constexpr bool is_float = std::is_same<T, float>::value;
static_assert(is_float or is_half, "Only float are supported");
constexpr int weo = is_half ? 5 : 8;
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
T fInf, fNegInf, fNaN, fNeg0; T fInf, fNegInf, fNaN, fNeg0;
uint32_t ifInf = 0x7F800000;
uint32_t ifNegInf = 0xFF800000;
uint32_t ifNaN = 0x7F800001;
uint32_t ifNeg0 = 0x80000000;
fInf = detail::bit_cast<float>(ifInf); if constexpr(is_float)
fNegInf = detail::bit_cast<float>(ifNegInf); {
fNaN = detail::bit_cast<float>(ifNaN); const uint32_t ifInf = 0x7F800000;
fNeg0 = detail::bit_cast<float>(ifNeg0); const uint32_t ifNegInf = 0xFF800000;
const uint32_t ifNaN = 0x7F800001;
const uint32_t ifNeg0 = 0x80000000;
fInf = migraphx::bit_cast<float>(ifInf);
fNegInf = migraphx::bit_cast<float>(ifNegInf);
fNaN = migraphx::bit_cast<float>(ifNaN);
fNeg0 = migraphx::bit_cast<float>(ifNeg0);
}
if(x == 0) if(x == 0)
return 0; return 0;
...@@ -305,7 +296,7 @@ constexpr T cast_from_f8(uint8_t x) ...@@ -305,7 +296,7 @@ constexpr T cast_from_f8(uint8_t x)
else if(wm == 3 and (x == 0x7F or x == 0xFF)) else if(wm == 3 and (x == 0x7F or x == 0xFF))
return fNaN; return fNaN;
} }
typename detail::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval; typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
...@@ -333,8 +324,11 @@ constexpr T cast_from_f8(uint8_t x) ...@@ -333,8 +324,11 @@ constexpr T cast_from_f8(uint8_t x)
retval = (sign << 15) | (exponent << 10) | mantissa; retval = (sign << 15) | (exponent << 10) | mantissa;
else else
retval = (sign << 31) | (exponent << 23) | mantissa; retval = (sign << 31) | (exponent << 23) | mantissa;
return detail::bit_cast<T>(retval); return migraphx::bit_cast<T>(retval);
} }
} // namespace migraphx_f8_impl } // namespace impl
#endif // MIGRAPHX_FP8_IMPL_HPP } // namespace fp8
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <half/half.hpp> #include <half/half.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/migraphx_float8.hpp> #include <migraphx/float8.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -69,13 +69,13 @@ struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT ...@@ -69,13 +69,13 @@ struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
}; };
template <> template <>
struct common_type<migraphx_fp8::fp8e4m3fnuz, migraphx::half> struct common_type<migraphx::fp8::fp8e4m3fnuz, migraphx::half>
{ {
using type = float; using type = float;
}; };
template <> template <>
struct common_type<migraphx::half, migraphx_fp8::fp8e4m3fnuz> struct common_type<migraphx::half, migraphx::fp8::fp8e4m3fnuz>
{ {
using type = float; using type = float;
}; };
......
...@@ -34,7 +34,7 @@ ...@@ -34,7 +34,7 @@
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/errors.hpp> #include <migraphx/errors.hpp>
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
#include <migraphx/migraphx_float8.hpp> #include <migraphx/float8.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
...@@ -62,7 +62,7 @@ struct MIGRAPHX_EXPORT shape ...@@ -62,7 +62,7 @@ struct MIGRAPHX_EXPORT shape
m(int64_type, int64_t) \ m(int64_type, int64_t) \
m(uint32_type, uint32_t) \ m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \ m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx_fp8::fp8e4m3fnuz) m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz)
// clang-format on // clang-format on
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x, #define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
......
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
#include <type_traits> #include <type_traits>
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/migraphx_float8.hpp> #include <migraphx/float8.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -49,23 +49,13 @@ MIGRAPHX_DETAIL_DEFINE_TRAIT(is_floating_point); ...@@ -49,23 +49,13 @@ MIGRAPHX_DETAIL_DEFINE_TRAIT(is_floating_point);
MIGRAPHX_DETAIL_DEFINE_TRAIT(is_arithmetic); MIGRAPHX_DETAIL_DEFINE_TRAIT(is_arithmetic);
MIGRAPHX_DETAIL_DEFINE_TRAIT(is_signed); MIGRAPHX_DETAIL_DEFINE_TRAIT(is_signed);
template <class T, class U>
struct is_same : std::is_same<T, U>
{
};
template <bool B, class T, class U>
struct conditional : std::conditional<B, T, U>
{
};
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx_fp8::fp8e4m3fnuz) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx_fp8::fp8e4m3fnuz) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx_fp8::fp8e4m3fnuz) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx::fp8::fp8e4m3fnuz)
template <class T> template <class T>
using accumulator_type = using accumulator_type =
......
...@@ -40,7 +40,7 @@ ...@@ -40,7 +40,7 @@
#include <migraphx/json.hpp> #include <migraphx/json.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp> #include <migraphx/op/common.hpp>
#include <migraphx/migraphx_float8.hpp> #include <migraphx/float8.hpp>
#ifdef HAVE_GPU #ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
#endif #endif
...@@ -145,7 +145,7 @@ struct npy_format_descriptor<half> ...@@ -145,7 +145,7 @@ struct npy_format_descriptor<half>
}; };
template <> template <>
struct npy_format_descriptor<migraphx_fp8::fp8e4m3fnuz> struct npy_format_descriptor<migraphx::fp8::fp8e4m3fnuz>
{ {
static std::string format() static std::string format()
{ {
......
...@@ -150,7 +150,7 @@ function(test_headers PREFIX) ...@@ -150,7 +150,7 @@ function(test_headers PREFIX)
list(REMOVE_ITEM HEADERS list(REMOVE_ITEM HEADERS
${CMAKE_SOURCE_DIR}/src/targets/gpu/include/migraphx/gpu/ck.hpp) ${CMAKE_SOURCE_DIR}/src/targets/gpu/include/migraphx/gpu/ck.hpp)
endif() endif()
list(REMOVE_ITEM HEADERS ${CMAKE_SOURCE_DIR}/src/include/migraphx/migraphx_f8_impl.hpp) list(REMOVE_ITEM HEADERS ${CMAKE_SOURCE_DIR}/src/include/migraphx/float8_impl.hpp)
foreach(HEADER ${HEADERS}) foreach(HEADER ${HEADERS})
file(RELATIVE_PATH HEADER_REL ${CMAKE_SOURCE_DIR} ${HEADER}) file(RELATIVE_PATH HEADER_REL ${CMAKE_SOURCE_DIR} ${HEADER})
string(MAKE_C_IDENTIFIER ${HEADER_REL} TEST_NAME) string(MAKE_C_IDENTIFIER ${HEADER_REL} TEST_NAME)
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/float_equal.hpp> #include <migraphx/float_equal.hpp>
#include <migraphx/migraphx_float8.hpp> #include <migraphx/float8.hpp>
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
#include "test.hpp" #include "test.hpp"
...@@ -72,12 +72,12 @@ void test_equality() ...@@ -72,12 +72,12 @@ void test_equality()
TEST_CASE_REGISTER(test_equality<double, float>); TEST_CASE_REGISTER(test_equality<double, float>);
TEST_CASE_REGISTER(test_equality<double, int>); TEST_CASE_REGISTER(test_equality<double, int>);
TEST_CASE_REGISTER(test_equality<double, migraphx::half>); TEST_CASE_REGISTER(test_equality<double, migraphx::half>);
TEST_CASE_REGISTER(test_equality<double, migraphx_fp8::fp8e4m3fnuz>); TEST_CASE_REGISTER(test_equality<double, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_equality<float, int>); TEST_CASE_REGISTER(test_equality<float, int>);
TEST_CASE_REGISTER(test_equality<float, migraphx_fp8::fp8e4m3fnuz>); TEST_CASE_REGISTER(test_equality<float, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_equality<migraphx::half, int>); TEST_CASE_REGISTER(test_equality<migraphx::half, int>);
TEST_CASE_REGISTER(test_equality<migraphx::half, migraphx_fp8::fp8e4m3fnuz>); TEST_CASE_REGISTER(test_equality<migraphx::half, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_equality<migraphx_fp8::fp8e4m3fnuz, int>); TEST_CASE_REGISTER(test_equality<migraphx::fp8::fp8e4m3fnuz, int>);
template <class T, class U> template <class T, class U>
void test_limits() void test_limits()
...@@ -115,12 +115,12 @@ void test_limits() ...@@ -115,12 +115,12 @@ void test_limits()
TEST_CASE_REGISTER(test_limits<double, float>); TEST_CASE_REGISTER(test_limits<double, float>);
TEST_CASE_REGISTER(test_limits<double, int>); TEST_CASE_REGISTER(test_limits<double, int>);
TEST_CASE_REGISTER(test_limits<double, migraphx::half>); TEST_CASE_REGISTER(test_limits<double, migraphx::half>);
TEST_CASE_REGISTER(test_limits<double, migraphx_fp8::fp8e4m3fnuz>); TEST_CASE_REGISTER(test_limits<double, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_limits<float, int>); TEST_CASE_REGISTER(test_limits<float, int>);
TEST_CASE_REGISTER(test_limits<float, migraphx_fp8::fp8e4m3fnuz>); TEST_CASE_REGISTER(test_limits<float, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_limits<int, migraphx::half>); TEST_CASE_REGISTER(test_limits<int, migraphx::half>);
TEST_CASE_REGISTER(test_limits<int, migraphx_fp8::fp8e4m3fnuz>); TEST_CASE_REGISTER(test_limits<int, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_limits<migraphx_fp8::fp8e4m3fnuz, migraphx::half>); TEST_CASE_REGISTER(test_limits<migraphx::fp8::fp8e4m3fnuz, migraphx::half>);
#ifndef _WIN32 #ifndef _WIN32
// On Windows, types int and long have the same min and max values. // On Windows, types int and long have the same min and max values.
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
*/ */
#include <cmath> #include <cmath>
#include <migraphx/float_equal.hpp> #include <migraphx/float_equal.hpp>
#include <migraphx/migraphx_float8.hpp> #include <migraphx/float8.hpp>
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include "test.hpp" #include "test.hpp"
...@@ -108,7 +108,7 @@ TEST_CASE(test_fp8_cast_to_float) ...@@ -108,7 +108,7 @@ TEST_CASE(test_fp8_cast_to_float)
std::vector<uint8_t> bit_vals(256); std::vector<uint8_t> bit_vals(256);
std::iota(bit_vals.begin(), bit_vals.end(), 0); std::iota(bit_vals.begin(), bit_vals.end(), 0);
EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) { EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) {
migraphx_fp8::fp8e4m3fn fp8_val(bit_val, migraphx_fp8::fp8e4m3fn::from_bits()); migraphx::fp8::fp8e4m3fn fp8_val(bit_val, migraphx::fp8::fp8e4m3fn::from_bits());
if(std::isnan(float(fp8_val)) and std::isnan(fp8e4m3fn_to_fp32_value(bit_val))) if(std::isnan(float(fp8_val)) and std::isnan(fp8e4m3fn_to_fp32_value(bit_val)))
{ {
return true; return true;
...@@ -120,7 +120,7 @@ TEST_CASE(test_fp8_cast_to_float) ...@@ -120,7 +120,7 @@ TEST_CASE(test_fp8_cast_to_float)
TEST_CASE(test_positive_zero) TEST_CASE(test_positive_zero)
{ {
float zero = 0.0; float zero = 0.0;
migraphx_fp8::fp8e4m3fn fp8_zero(zero); migraphx::fp8::fp8e4m3fn fp8_zero(zero);
EXPECT(fp8_zero.is_zero()); EXPECT(fp8_zero.is_zero());
EXPECT(migraphx::float_equal(zero, float(fp8_zero))); EXPECT(migraphx::float_equal(zero, float(fp8_zero)));
} }
...@@ -128,7 +128,7 @@ TEST_CASE(test_positive_zero) ...@@ -128,7 +128,7 @@ TEST_CASE(test_positive_zero)
TEST_CASE(test_negative_zero) TEST_CASE(test_negative_zero)
{ {
float nzero = -0.0; float nzero = -0.0;
migraphx_fp8::fp8e4m3fn fp8_nzero(nzero); migraphx::fp8::fp8e4m3fn fp8_nzero(nzero);
EXPECT(fp8_nzero.is_zero()); EXPECT(fp8_nzero.is_zero());
// negative zero is preserved for fp8e4m3fn // negative zero is preserved for fp8e4m3fn
EXPECT(migraphx::float_equal(nzero, float(fp8_nzero))); EXPECT(migraphx::float_equal(nzero, float(fp8_nzero)));
...@@ -137,15 +137,15 @@ TEST_CASE(test_negative_zero) ...@@ -137,15 +137,15 @@ TEST_CASE(test_negative_zero)
TEST_CASE(test_nan_1) TEST_CASE(test_nan_1)
{ {
float fnan = std::numeric_limits<float>::quiet_NaN(); float fnan = std::numeric_limits<float>::quiet_NaN();
migraphx_fp8::fp8e4m3fn fp8_nan(fnan); migraphx::fp8::fp8e4m3fn fp8_nan(fnan);
EXPECT(fp8_nan.is_nan()); EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan)); EXPECT(std::isnan(fp8_nan));
} }
TEST_CASE(test_nan_2) TEST_CASE(test_nan_2)
{ {
auto fnan = std::numeric_limits<migraphx_fp8::fp8e4m3fn>::quiet_NaN(); auto fnan = std::numeric_limits<migraphx::fp8::fp8e4m3fn>::quiet_NaN();
migraphx_fp8::fp8e4m3fn fp8_nan(fnan.data, migraphx_fp8::fp8e4m3fn::from_bits()); migraphx::fp8::fp8e4m3fn fp8_nan(fnan.data, migraphx::fp8::fp8e4m3fn::from_bits());
EXPECT(fp8_nan.is_nan()); EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan)); EXPECT(std::isnan(fp8_nan));
EXPECT(std::isnan(float(fp8_nan))); EXPECT(std::isnan(float(fp8_nan)));
...@@ -155,8 +155,8 @@ TEST_CASE(test_infinity_1) ...@@ -155,8 +155,8 @@ TEST_CASE(test_infinity_1)
{ {
float finf = std::numeric_limits<float>::infinity(); float finf = std::numeric_limits<float>::infinity();
// no inf in fp8e4m3fn, it gets clipped to max() // no inf in fp8e4m3fn, it gets clipped to max()
migraphx_fp8::fp8e4m3fn fp8_max(finf); migraphx::fp8::fp8e4m3fn fp8_max(finf);
EXPECT(fp8_max == std::numeric_limits<migraphx_fp8::fp8e4m3fn>::max()); EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::max());
} }
TEST_CASE(test_infinity_2) TEST_CASE(test_infinity_2)
...@@ -164,43 +164,43 @@ TEST_CASE(test_infinity_2) ...@@ -164,43 +164,43 @@ TEST_CASE(test_infinity_2)
// neg inf // neg inf
float finf = -1.0 * std::numeric_limits<float>::infinity(); float finf = -1.0 * std::numeric_limits<float>::infinity();
// no inf in fp8e4m3fn, it gets clipped to lowest // no inf in fp8e4m3fn, it gets clipped to lowest
migraphx_fp8::fp8e4m3fn fp8_lowest(finf); migraphx::fp8::fp8e4m3fn fp8_lowest(finf);
EXPECT(bool{fp8_lowest == std::numeric_limits<migraphx_fp8::fp8e4m3fn>::lowest()}); EXPECT(bool{fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::lowest()});
} }
TEST_CASE(test_numeric_max_1) TEST_CASE(test_numeric_max_1)
{ {
float fmax = std::numeric_limits<float>::max(); float fmax = std::numeric_limits<float>::max();
migraphx_fp8::fp8e4m3fn fp8_max(fmax); migraphx::fp8::fp8e4m3fn fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx_fp8::fp8e4m3fn>::max()); EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::max());
} }
TEST_CASE(test_numeric_max_2) TEST_CASE(test_numeric_max_2)
{ {
// gets clipped to max // gets clipped to max
float fmax = 2 * std::numeric_limits<migraphx_fp8::fp8e4m3fn>::max(); float fmax = 2 * std::numeric_limits<migraphx::fp8::fp8e4m3fn>::max();
migraphx_fp8::fp8e4m3fn fp8_max(fmax); migraphx::fp8::fp8e4m3fn fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx_fp8::fp8e4m3fn>::max()); EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::max());
} }
TEST_CASE(test_numeric_lowest_1) TEST_CASE(test_numeric_lowest_1)
{ {
float flowest = std::numeric_limits<float>::lowest(); float flowest = std::numeric_limits<float>::lowest();
migraphx_fp8::fp8e4m3fn fp8_lowest(flowest); migraphx::fp8::fp8e4m3fn fp8_lowest(flowest);
EXPECT(fp8_lowest == std::numeric_limits<migraphx_fp8::fp8e4m3fn>::lowest()); EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::lowest());
} }
TEST_CASE(test_numeric_lowest_2) TEST_CASE(test_numeric_lowest_2)
{ {
// gets clipped to lowest // gets clipped to lowest
float fmin = 2.0 * std::numeric_limits<migraphx_fp8::fp8e4m3fn>::lowest(); float fmin = 2.0 * std::numeric_limits<migraphx::fp8::fp8e4m3fn>::lowest();
migraphx_fp8::fp8e4m3fn fp8_lowest(fmin); migraphx::fp8::fp8e4m3fn fp8_lowest(fmin);
EXPECT(fp8_lowest == std::numeric_limits<migraphx_fp8::fp8e4m3fn>::lowest()); EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::lowest());
} }
TEST_CASE(test_max_eq_lowest) TEST_CASE(test_max_eq_lowest)
{ {
EXPECT(migraphx::float_equal(std::numeric_limits<migraphx_fp8::fp8e4m3fn>::lowest(), EXPECT(migraphx::float_equal(std::numeric_limits<migraphx::fp8::fp8e4m3fn>::lowest(),
-1 * std::numeric_limits<migraphx_fp8::fp8e4m3fn>::max())); -1 * std::numeric_limits<migraphx::fp8::fp8e4m3fn>::max()));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
*/ */
#include <cmath> #include <cmath>
#include <migraphx/float_equal.hpp> #include <migraphx/float_equal.hpp>
#include <migraphx/migraphx_float8.hpp> #include <migraphx/float8.hpp>
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include "test.hpp" #include "test.hpp"
...@@ -129,7 +129,7 @@ TEST_CASE(test_fp8_cast_to_float) ...@@ -129,7 +129,7 @@ TEST_CASE(test_fp8_cast_to_float)
std::vector<uint8_t> bit_vals(256); std::vector<uint8_t> bit_vals(256);
std::iota(bit_vals.begin(), bit_vals.end(), 0); std::iota(bit_vals.begin(), bit_vals.end(), 0);
EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) { EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) {
migraphx_fp8::fp8e4m3fnuz fp8_val(bit_val, migraphx_fp8::fp8e4m3fnuz::from_bits()); migraphx::fp8::fp8e4m3fnuz fp8_val(bit_val, migraphx::fp8::fp8e4m3fnuz::from_bits());
if(std::isnan(float(fp8_val)) and std::isnan(fp8e4m3fnuz_to_fp32_value(bit_val))) if(std::isnan(float(fp8_val)) and std::isnan(fp8e4m3fnuz_to_fp32_value(bit_val)))
{ {
return true; return true;
...@@ -141,7 +141,7 @@ TEST_CASE(test_fp8_cast_to_float) ...@@ -141,7 +141,7 @@ TEST_CASE(test_fp8_cast_to_float)
TEST_CASE(test_positive_zero) TEST_CASE(test_positive_zero)
{ {
float zero = 0.0; float zero = 0.0;
migraphx_fp8::fp8e4m3fnuz fp8_zero(zero); migraphx::fp8::fp8e4m3fnuz fp8_zero(zero);
EXPECT(fp8_zero.is_zero()); EXPECT(fp8_zero.is_zero());
EXPECT(migraphx::float_equal(zero, float(fp8_zero))); EXPECT(migraphx::float_equal(zero, float(fp8_zero)));
} }
...@@ -150,7 +150,7 @@ TEST_CASE(test_negative_zero) ...@@ -150,7 +150,7 @@ TEST_CASE(test_negative_zero)
{ {
float nzero = -0.0; float nzero = -0.0;
float pzero = 0.0; float pzero = 0.0;
migraphx_fp8::fp8e4m3fnuz fp8_nzero(nzero); migraphx::fp8::fp8e4m3fnuz fp8_nzero(nzero);
EXPECT(fp8_nzero.is_zero()); EXPECT(fp8_nzero.is_zero());
// negative zero gets converted to positive zero // negative zero gets converted to positive zero
EXPECT(migraphx::float_equal(pzero, float(fp8_nzero))); EXPECT(migraphx::float_equal(pzero, float(fp8_nzero)));
...@@ -159,15 +159,15 @@ TEST_CASE(test_negative_zero) ...@@ -159,15 +159,15 @@ TEST_CASE(test_negative_zero)
TEST_CASE(test_nan_1) TEST_CASE(test_nan_1)
{ {
float fnan = std::numeric_limits<float>::quiet_NaN(); float fnan = std::numeric_limits<float>::quiet_NaN();
migraphx_fp8::fp8e4m3fnuz fp8_nan(fnan); migraphx::fp8::fp8e4m3fnuz fp8_nan(fnan);
EXPECT(fp8_nan.is_nan()); EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan)); EXPECT(std::isnan(fp8_nan));
} }
TEST_CASE(test_nan_2) TEST_CASE(test_nan_2)
{ {
auto fnan = std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::quiet_NaN(); auto fnan = std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::quiet_NaN();
migraphx_fp8::fp8e4m3fnuz fp8_nan(fnan.data, migraphx_fp8::fp8e4m3fnuz::from_bits()); migraphx::fp8::fp8e4m3fnuz fp8_nan(fnan.data, migraphx::fp8::fp8e4m3fnuz::from_bits());
EXPECT(fp8_nan.is_nan()); EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan)); EXPECT(std::isnan(fp8_nan));
EXPECT(std::isnan(float(fp8_nan))); EXPECT(std::isnan(float(fp8_nan)));
...@@ -177,7 +177,7 @@ TEST_CASE(test_infinity_1) ...@@ -177,7 +177,7 @@ TEST_CASE(test_infinity_1)
{ {
float finf = std::numeric_limits<float>::infinity(); float finf = std::numeric_limits<float>::infinity();
// no inf in fp8e4m3fnuz it gets clipped to Nans // no inf in fp8e4m3fnuz it gets clipped to Nans
migraphx_fp8::fp8e4m3fnuz fp8_nan(finf); migraphx::fp8::fp8e4m3fnuz fp8_nan(finf);
EXPECT(fp8_nan.is_nan()); EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(float(fp8_nan))); EXPECT(std::isnan(float(fp8_nan)));
} }
...@@ -187,7 +187,7 @@ TEST_CASE(test_infinity_2) ...@@ -187,7 +187,7 @@ TEST_CASE(test_infinity_2)
// neg inf // neg inf
float finf = -1.0 * std::numeric_limits<float>::infinity(); float finf = -1.0 * std::numeric_limits<float>::infinity();
// no inf in fp8e4m3fnuz it gets clipped to NaNs // no inf in fp8e4m3fnuz it gets clipped to NaNs
migraphx_fp8::fp8e4m3fnuz fp8_nan(finf); migraphx::fp8::fp8e4m3fnuz fp8_nan(finf);
EXPECT(fp8_nan.is_nan()); EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(float(fp8_nan))); EXPECT(std::isnan(float(fp8_nan)));
} }
...@@ -195,36 +195,36 @@ TEST_CASE(test_infinity_2) ...@@ -195,36 +195,36 @@ TEST_CASE(test_infinity_2)
TEST_CASE(test_numeric_max_1) TEST_CASE(test_numeric_max_1)
{ {
float fmax = std::numeric_limits<float>::max(); float fmax = std::numeric_limits<float>::max();
migraphx_fp8::fp8e4m3fnuz fp8_max(fmax); migraphx::fp8::fp8e4m3fnuz fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::max()); EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::max());
} }
TEST_CASE(test_numeric_max_2) TEST_CASE(test_numeric_max_2)
{ {
// gets clipped to max // gets clipped to max
float fmax = 2 * std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::max(); float fmax = 2 * std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::max();
migraphx_fp8::fp8e4m3fnuz fp8_max(fmax); migraphx::fp8::fp8e4m3fnuz fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::max()); EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::max());
} }
TEST_CASE(test_numeric_lowest_1) TEST_CASE(test_numeric_lowest_1)
{ {
float flowest = std::numeric_limits<float>::lowest(); float flowest = std::numeric_limits<float>::lowest();
migraphx_fp8::fp8e4m3fnuz fp8_lowest(flowest); migraphx::fp8::fp8e4m3fnuz fp8_lowest(flowest);
EXPECT(fp8_lowest == std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::lowest()); EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::lowest());
} }
TEST_CASE(test_numeric_lowest_2) TEST_CASE(test_numeric_lowest_2)
{ {
// gets clipped to lowest // gets clipped to lowest
float fmin = 2.0 * std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::lowest(); float fmin = 2.0 * std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::lowest();
migraphx_fp8::fp8e4m3fnuz fp8_lowest(fmin); migraphx::fp8::fp8e4m3fnuz fp8_lowest(fmin);
EXPECT(fp8_lowest == std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::lowest()); EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::lowest());
} }
TEST_CASE(test_max_eq_lowest) TEST_CASE(test_max_eq_lowest)
{ {
EXPECT(migraphx::float_equal(std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::lowest(), EXPECT(migraphx::float_equal(std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::lowest(),
-1 * std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::max())); -1 * std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::max()));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
*/ */
#include <cmath> #include <cmath>
#include <migraphx/float_equal.hpp> #include <migraphx/float_equal.hpp>
#include <migraphx/migraphx_float8.hpp> #include <migraphx/float8.hpp>
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include "test.hpp" #include "test.hpp"
...@@ -301,7 +301,7 @@ TEST_CASE(test_fp8_cast_to_float) ...@@ -301,7 +301,7 @@ TEST_CASE(test_fp8_cast_to_float)
std::vector<uint8_t> bit_vals(256); std::vector<uint8_t> bit_vals(256);
std::iota(bit_vals.begin(), bit_vals.end(), 0); std::iota(bit_vals.begin(), bit_vals.end(), 0);
EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) { EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) {
migraphx_fp8::fp8e5m2 fp8_val(bit_val, migraphx_fp8::fp8e5m2::from_bits()); migraphx::fp8::fp8e5m2 fp8_val(bit_val, migraphx::fp8::fp8e5m2::from_bits());
if(std::isnan(float(fp8_val)) and std::isnan(fp8e5m2_to_fp32_value(bit_val))) if(std::isnan(float(fp8_val)) and std::isnan(fp8e5m2_to_fp32_value(bit_val)))
{ {
return true; return true;
...@@ -317,7 +317,7 @@ TEST_CASE(test_fp8_cast_to_float) ...@@ -317,7 +317,7 @@ TEST_CASE(test_fp8_cast_to_float)
TEST_CASE(test_positive_zero) TEST_CASE(test_positive_zero)
{ {
float zero = 0.0; float zero = 0.0;
migraphx_fp8::fp8e5m2 fp8_zero(zero); migraphx::fp8::fp8e5m2 fp8_zero(zero);
EXPECT(fp8_zero.is_zero()); EXPECT(fp8_zero.is_zero());
EXPECT(migraphx::float_equal(zero, float(fp8_zero))); EXPECT(migraphx::float_equal(zero, float(fp8_zero)));
} }
...@@ -325,7 +325,7 @@ TEST_CASE(test_positive_zero) ...@@ -325,7 +325,7 @@ TEST_CASE(test_positive_zero)
TEST_CASE(test_negative_zero) TEST_CASE(test_negative_zero)
{ {
float nzero = -0.0; float nzero = -0.0;
migraphx_fp8::fp8e5m2 fp8_nzero(nzero); migraphx::fp8::fp8e5m2 fp8_nzero(nzero);
EXPECT(fp8_nzero.is_zero()); EXPECT(fp8_nzero.is_zero());
// negative zero is preserved for fp8e5m2 // negative zero is preserved for fp8e5m2
EXPECT(migraphx::float_equal(nzero, float(fp8_nzero))); EXPECT(migraphx::float_equal(nzero, float(fp8_nzero)));
...@@ -334,15 +334,15 @@ TEST_CASE(test_negative_zero) ...@@ -334,15 +334,15 @@ TEST_CASE(test_negative_zero)
TEST_CASE(test_nan_1) TEST_CASE(test_nan_1)
{ {
float fnan = std::numeric_limits<float>::quiet_NaN(); float fnan = std::numeric_limits<float>::quiet_NaN();
migraphx_fp8::fp8e5m2 fp8_nan(fnan); migraphx::fp8::fp8e5m2 fp8_nan(fnan);
EXPECT(fp8_nan.is_nan()); EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan)); EXPECT(std::isnan(fp8_nan));
} }
TEST_CASE(test_nan_2) TEST_CASE(test_nan_2)
{ {
auto fnan = std::numeric_limits<migraphx_fp8::fp8e5m2>::quiet_NaN(); auto fnan = std::numeric_limits<migraphx::fp8::fp8e5m2>::quiet_NaN();
migraphx_fp8::fp8e5m2 fp8_nan(fnan.data, migraphx_fp8::fp8e5m2::from_bits()); migraphx::fp8::fp8e5m2 fp8_nan(fnan.data, migraphx::fp8::fp8e5m2::from_bits());
EXPECT(fp8_nan.is_nan()); EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan)); EXPECT(std::isnan(fp8_nan));
EXPECT(std::isnan(float(fp8_nan))); EXPECT(std::isnan(float(fp8_nan)));
...@@ -352,8 +352,8 @@ TEST_CASE(test_infinity_1) ...@@ -352,8 +352,8 @@ TEST_CASE(test_infinity_1)
{ {
// float infinity should get clipped to max // float infinity should get clipped to max
float finf = std::numeric_limits<float>::infinity(); float finf = std::numeric_limits<float>::infinity();
migraphx_fp8::fp8e5m2 fp8_max(finf); migraphx::fp8::fp8e5m2 fp8_max(finf);
EXPECT(fp8_max == std::numeric_limits<migraphx_fp8::fp8e5m2>::max()); EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e5m2>::max());
} }
TEST_CASE(test_infinity_2) TEST_CASE(test_infinity_2)
...@@ -361,43 +361,43 @@ TEST_CASE(test_infinity_2) ...@@ -361,43 +361,43 @@ TEST_CASE(test_infinity_2)
// neg inf // neg inf
float finf = -1.0 * std::numeric_limits<float>::infinity(); float finf = -1.0 * std::numeric_limits<float>::infinity();
// no inf in fp8e5m2, it gets clipped to lowest // no inf in fp8e5m2, it gets clipped to lowest
migraphx_fp8::fp8e5m2 fp8_lowest(finf); migraphx::fp8::fp8e5m2 fp8_lowest(finf);
EXPECT(bool{fp8_lowest == std::numeric_limits<migraphx_fp8::fp8e5m2>::lowest()}); EXPECT(bool{fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e5m2>::lowest()});
} }
TEST_CASE(test_numeric_max_1) TEST_CASE(test_numeric_max_1)
{ {
float fmax = std::numeric_limits<float>::max(); float fmax = std::numeric_limits<float>::max();
migraphx_fp8::fp8e5m2 fp8_max(fmax); migraphx::fp8::fp8e5m2 fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx_fp8::fp8e5m2>::max()); EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e5m2>::max());
} }
TEST_CASE(test_numeric_max_2) TEST_CASE(test_numeric_max_2)
{ {
// gets clipped to max // gets clipped to max
float fmax = 2 * std::numeric_limits<migraphx_fp8::fp8e5m2>::max(); float fmax = 2 * std::numeric_limits<migraphx::fp8::fp8e5m2>::max();
migraphx_fp8::fp8e5m2 fp8_max(fmax); migraphx::fp8::fp8e5m2 fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx_fp8::fp8e5m2>::max()); EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e5m2>::max());
} }
TEST_CASE(test_numeric_lowest_1) TEST_CASE(test_numeric_lowest_1)
{ {
float flowest = std::numeric_limits<float>::lowest(); float flowest = std::numeric_limits<float>::lowest();
migraphx_fp8::fp8e5m2 fp8_lowest(flowest); migraphx::fp8::fp8e5m2 fp8_lowest(flowest);
EXPECT(fp8_lowest == std::numeric_limits<migraphx_fp8::fp8e5m2>::lowest()); EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e5m2>::lowest());
} }
TEST_CASE(test_numeric_lowest_2) TEST_CASE(test_numeric_lowest_2)
{ {
// gets clipped to lowest // gets clipped to lowest
float fmin = 2.0 * std::numeric_limits<migraphx_fp8::fp8e5m2>::lowest(); float fmin = 2.0 * std::numeric_limits<migraphx::fp8::fp8e5m2>::lowest();
migraphx_fp8::fp8e5m2 fp8_lowest(fmin); migraphx::fp8::fp8e5m2 fp8_lowest(fmin);
EXPECT(fp8_lowest == std::numeric_limits<migraphx_fp8::fp8e5m2>::lowest()); EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e5m2>::lowest());
} }
TEST_CASE(test_max_eq_lowest) TEST_CASE(test_max_eq_lowest)
{ {
EXPECT(migraphx::float_equal(std::numeric_limits<migraphx_fp8::fp8e5m2>::lowest(), EXPECT(migraphx::float_equal(std::numeric_limits<migraphx::fp8::fp8e5m2>::lowest(),
-1 * std::numeric_limits<migraphx_fp8::fp8e5m2>::max())); -1 * std::numeric_limits<migraphx::fp8::fp8e5m2>::max()));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
*/ */
#include <cmath> #include <cmath>
#include <migraphx/float_equal.hpp> #include <migraphx/float_equal.hpp>
#include <migraphx/migraphx_float8.hpp> #include <migraphx/float8.hpp>
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include "test.hpp" #include "test.hpp"
...@@ -299,7 +299,7 @@ TEST_CASE(test_fp8_cast_to_float) ...@@ -299,7 +299,7 @@ TEST_CASE(test_fp8_cast_to_float)
std::vector<uint8_t> bit_vals(256); std::vector<uint8_t> bit_vals(256);
std::iota(bit_vals.begin(), bit_vals.end(), 0); std::iota(bit_vals.begin(), bit_vals.end(), 0);
EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) { EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) {
migraphx_fp8::fp8e5m2fnuz fp8_val(bit_val, migraphx_fp8::fp8e5m2fnuz::from_bits()); migraphx::fp8::fp8e5m2fnuz fp8_val(bit_val, migraphx::fp8::fp8e5m2fnuz::from_bits());
if(std::isnan(float(fp8_val)) and std::isnan(fp8e5m2fnuz_to_fp32_value(bit_val))) if(std::isnan(float(fp8_val)) and std::isnan(fp8e5m2fnuz_to_fp32_value(bit_val)))
{ {
return true; return true;
...@@ -311,7 +311,7 @@ TEST_CASE(test_fp8_cast_to_float) ...@@ -311,7 +311,7 @@ TEST_CASE(test_fp8_cast_to_float)
TEST_CASE(test_positive_zero) TEST_CASE(test_positive_zero)
{ {
float zero = 0.0; float zero = 0.0;
migraphx_fp8::fp8e5m2fnuz fp8_zero(zero); migraphx::fp8::fp8e5m2fnuz fp8_zero(zero);
EXPECT(fp8_zero.is_zero()); EXPECT(fp8_zero.is_zero());
EXPECT(migraphx::float_equal(zero, float(fp8_zero))); EXPECT(migraphx::float_equal(zero, float(fp8_zero)));
} }
...@@ -320,7 +320,7 @@ TEST_CASE(test_negative_zero) ...@@ -320,7 +320,7 @@ TEST_CASE(test_negative_zero)
{ {
float nzero = -0.0; float nzero = -0.0;
float pzero = 0.0; float pzero = 0.0;
migraphx_fp8::fp8e5m2fnuz fp8_nzero(nzero); migraphx::fp8::fp8e5m2fnuz fp8_nzero(nzero);
EXPECT(fp8_nzero.is_zero()); EXPECT(fp8_nzero.is_zero());
// negative zero gets converted to positive zero // negative zero gets converted to positive zero
EXPECT(migraphx::float_equal(pzero, float(fp8_nzero))); EXPECT(migraphx::float_equal(pzero, float(fp8_nzero)));
...@@ -329,15 +329,15 @@ TEST_CASE(test_negative_zero) ...@@ -329,15 +329,15 @@ TEST_CASE(test_negative_zero)
TEST_CASE(test_nan_1) TEST_CASE(test_nan_1)
{ {
float fnan = std::numeric_limits<float>::quiet_NaN(); float fnan = std::numeric_limits<float>::quiet_NaN();
migraphx_fp8::fp8e5m2fnuz fp8_nan(fnan); migraphx::fp8::fp8e5m2fnuz fp8_nan(fnan);
EXPECT(fp8_nan.is_nan()); EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan)); EXPECT(std::isnan(fp8_nan));
} }
TEST_CASE(test_nan_2) TEST_CASE(test_nan_2)
{ {
auto fnan = std::numeric_limits<migraphx_fp8::fp8e5m2fnuz>::quiet_NaN(); auto fnan = std::numeric_limits<migraphx::fp8::fp8e5m2fnuz>::quiet_NaN();
migraphx_fp8::fp8e5m2fnuz fp8_nan(fnan.data, migraphx_fp8::fp8e5m2fnuz::from_bits()); migraphx::fp8::fp8e5m2fnuz fp8_nan(fnan.data, migraphx::fp8::fp8e5m2fnuz::from_bits());
EXPECT(fp8_nan.is_nan()); EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan)); EXPECT(std::isnan(fp8_nan));
EXPECT(std::isnan(float(fp8_nan))); EXPECT(std::isnan(float(fp8_nan)));
...@@ -347,7 +347,7 @@ TEST_CASE(test_infinity_1) ...@@ -347,7 +347,7 @@ TEST_CASE(test_infinity_1)
{ {
float finf = std::numeric_limits<float>::infinity(); float finf = std::numeric_limits<float>::infinity();
// no inf in fp8e5m2fnuz it gets clipped to Nans // no inf in fp8e5m2fnuz it gets clipped to Nans
migraphx_fp8::fp8e5m2fnuz fp8_nan(finf); migraphx::fp8::fp8e5m2fnuz fp8_nan(finf);
EXPECT(fp8_nan.is_nan()); EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(float(fp8_nan))); EXPECT(std::isnan(float(fp8_nan)));
} }
...@@ -357,7 +357,7 @@ TEST_CASE(test_infinity_2) ...@@ -357,7 +357,7 @@ TEST_CASE(test_infinity_2)
// neg inf // neg inf
float finf = -1.0 * std::numeric_limits<float>::infinity(); float finf = -1.0 * std::numeric_limits<float>::infinity();
// no inf in fp8e5m2fnuz it gets clipped to NaNs // no inf in fp8e5m2fnuz it gets clipped to NaNs
migraphx_fp8::fp8e5m2fnuz fp8_nan(finf); migraphx::fp8::fp8e5m2fnuz fp8_nan(finf);
EXPECT(fp8_nan.is_nan()); EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(float(fp8_nan))); EXPECT(std::isnan(float(fp8_nan)));
} }
...@@ -365,36 +365,36 @@ TEST_CASE(test_infinity_2) ...@@ -365,36 +365,36 @@ TEST_CASE(test_infinity_2)
TEST_CASE(test_numeric_max_1) TEST_CASE(test_numeric_max_1)
{ {
float fmax = std::numeric_limits<float>::max(); float fmax = std::numeric_limits<float>::max();
migraphx_fp8::fp8e5m2fnuz fp8_max(fmax); migraphx::fp8::fp8e5m2fnuz fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx_fp8::fp8e5m2fnuz>::max()); EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e5m2fnuz>::max());
} }
TEST_CASE(test_numeric_max_2) TEST_CASE(test_numeric_max_2)
{ {
// gets clipped to max // gets clipped to max
float fmax = 2 * std::numeric_limits<migraphx_fp8::fp8e5m2fnuz>::max(); float fmax = 2 * std::numeric_limits<migraphx::fp8::fp8e5m2fnuz>::max();
migraphx_fp8::fp8e5m2fnuz fp8_max(fmax); migraphx::fp8::fp8e5m2fnuz fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx_fp8::fp8e5m2fnuz>::max()); EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e5m2fnuz>::max());
} }
TEST_CASE(test_numeric_lowest_1) TEST_CASE(test_numeric_lowest_1)
{ {
float flowest = std::numeric_limits<float>::lowest(); float flowest = std::numeric_limits<float>::lowest();
migraphx_fp8::fp8e5m2fnuz fp8_lowest(flowest); migraphx::fp8::fp8e5m2fnuz fp8_lowest(flowest);
EXPECT(fp8_lowest == std::numeric_limits<migraphx_fp8::fp8e5m2fnuz>::lowest()); EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e5m2fnuz>::lowest());
} }
TEST_CASE(test_numeric_lowest_2) TEST_CASE(test_numeric_lowest_2)
{ {
// gets clipped to lowest // gets clipped to lowest
float fmin = 2.0 * std::numeric_limits<migraphx_fp8::fp8e5m2fnuz>::lowest(); float fmin = 2.0 * std::numeric_limits<migraphx::fp8::fp8e5m2fnuz>::lowest();
migraphx_fp8::fp8e5m2fnuz fp8_lowest(fmin); migraphx::fp8::fp8e5m2fnuz fp8_lowest(fmin);
EXPECT(fp8_lowest == std::numeric_limits<migraphx_fp8::fp8e5m2fnuz>::lowest()); EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e5m2fnuz>::lowest());
} }
TEST_CASE(test_max_eq_lowest) TEST_CASE(test_max_eq_lowest)
{ {
EXPECT(migraphx::float_equal(std::numeric_limits<migraphx_fp8::fp8e5m2fnuz>::lowest(), EXPECT(migraphx::float_equal(std::numeric_limits<migraphx::fp8::fp8e5m2fnuz>::lowest(),
-1 * std::numeric_limits<migraphx_fp8::fp8e5m2fnuz>::max())); -1 * std::numeric_limits<migraphx::fp8::fp8e5m2fnuz>::max()));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -45,7 +45,7 @@ ...@@ -45,7 +45,7 @@
m(int64_type, int64_t) \ m(int64_type, int64_t) \
m(uint32_type, uint32_t) \ m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \ m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx_fp8::fp8e4m3fnuz) m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz)
// clang-format on // clang-format on
#ifdef __cplusplus #ifdef __cplusplus
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment