Commit ab653aff authored by Umang Yadav's avatar Umang Yadav
Browse files

Review updates

parent 183db78a
......@@ -45,7 +45,7 @@
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx_fp8::fp8e4m3fnuz)
m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz)
// clang-format on
#ifdef __cplusplus
......
/* ************************************************************************
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
* ies of the Software, and to permit persons to whom the Software is furnished
* to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
* PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
* CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* ************************************************************************ */
#ifndef MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#define MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#include <migraphx/config.hpp>
#define MIGRAPHX_CONST_FOLD(x) (__builtin_constant_p(x) ? (x) : (x))
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <typename To, typename From>
inline constexpr To bit_cast(From fr) noexcept
{
static_assert(sizeof(To) == sizeof(From));
#if defined(__GNUC__) and !defined(__clang__)
return MIGRAPHX_CONST_FOLD(*reinterpret_cast<To*>(&fr));
#else
return __builtin_bit_cast(To, fr);
#endif
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
......@@ -44,20 +44,12 @@
#include <iostream>
#include <string>
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/float8_impl.hpp>
namespace migraphx_f8_impl {
template <int wm, int we, typename T, bool negative_zero_nan, bool clip>
constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0);
template <int wm, int we, typename T, bool negative_zero_nan>
constexpr T cast_from_f8(uint8_t x);
} // namespace migraphx_f8_impl
#include <migraphx/migraphx_f8_impl.hpp>
namespace migraphx_fp8 {
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace fp8 {
enum class migraphx_f8_rounding_mode
{
......@@ -74,7 +66,7 @@ enum class f8_type
template <typename T, bool FNUZ = true>
class numeric_limits;
template <migraphx_fp8::f8_type T = migraphx_fp8::f8_type::fp8, bool FNUZ = true>
template <migraphx::fp8::f8_type T = migraphx::fp8::f8_type::fp8, bool FNUZ = true>
struct float8
{
uint8_t data = 0x00;
......@@ -90,43 +82,43 @@ struct float8
explicit constexpr float8(uint8_t bits, from_bits_t) : data(bits) {}
explicit constexpr float8(float v,
migraphx_fp8::migraphx_f8_rounding_mode rm =
migraphx_fp8::migraphx_f8_rounding_mode::standard,
migraphx::fp8::migraphx_f8_rounding_mode rm =
migraphx::fp8::migraphx_f8_rounding_mode::standard,
uint32_t rng = 0)
{
if constexpr(T == migraphx_fp8::f8_type::fp8)
if constexpr(T == migraphx::fp8::f8_type::fp8)
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx_f8_impl::
data = migraphx::fp8::impl::
cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, true /*clip*/>(
v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng);
v, (rm == migraphx::fp8::migraphx_f8_rounding_mode::stochastic), rng);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx_f8_impl::
data = migraphx::fp8::impl::
cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, false /*clip*/>(
v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng);
v, (rm == migraphx::fp8::migraphx_f8_rounding_mode::stochastic), rng);
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
}
else
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx_f8_impl::
data = migraphx::fp8::impl::
cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, true /*clip*/>(
v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng);
v, (rm == migraphx::fp8::migraphx_f8_rounding_mode::stochastic), rng);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx_f8_impl::
data = migraphx::fp8::impl::
cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, false /*clip*/>(
v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng);
v, (rm == migraphx::fp8::migraphx_f8_rounding_mode::stochastic), rng);
#endif // rocblas_F8_downcast_clipping}
}
}
inline constexpr operator float() const
{
if constexpr(T == migraphx_fp8::f8_type::fp8)
if constexpr(T == migraphx::fp8::f8_type::fp8)
{
return migraphx_f8_impl::cast_from_f8<3, 4, float, FNUZ /*negative_zero_nan*/>(data);
return migraphx::fp8::impl::cast_from_f8<3, 4, float, FNUZ /*negative_zero_nan*/>(data);
} // else
return migraphx_f8_impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data);
return migraphx::fp8::impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data);
}
inline constexpr bool is_zero() const
......@@ -149,7 +141,7 @@ struct float8
}
else
{
if(T == migraphx_fp8::f8_type::bf8)
if(T == migraphx::fp8::f8_type::bf8)
{
return (data == 0x7D) or (data == 0x7E) or (data == 0x7F) or (data == 0xFD) or
(data == 0xFE) or (data == 0xFF);
......@@ -169,7 +161,7 @@ struct float8
}
else
{
if(T == migraphx_fp8::f8_type::bf8)
if(T == migraphx::fp8::f8_type::bf8)
{
return (data == 0x7C) or (data == 0xFC);
}
......@@ -236,26 +228,26 @@ struct float8
};
// Special operator overloading
template <migraphx_fp8::f8_type T>
inline std::ostream& operator<<(std::ostream& os, const migraphx_fp8::float8<T>& rhs)
template <migraphx::fp8::f8_type T>
inline std::ostream& operator<<(std::ostream& os, const migraphx::fp8::float8<T>& rhs)
{
return os << static_cast<float>(rhs);
}
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \
template <migraphx_fp8::f8_type T> \
inline constexpr U operator binary_op(const migraphx_fp8::float8<T>& lhs, \
const migraphx_fp8::float8<T>& rhs) \
{ \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \
template <migraphx::fp8::f8_type T> \
inline constexpr U operator binary_op(const migraphx::fp8::float8<T>& lhs, \
const migraphx::fp8::float8<T>& rhs) \
{ \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
}
// TODO: these should return floats
MIGRAPHX_FP8_BINARY_OP(*, migraphx_fp8::float8<T>)
MIGRAPHX_FP8_BINARY_OP(-, migraphx_fp8::float8<T>)
MIGRAPHX_FP8_BINARY_OP(/, migraphx_fp8::float8<T>)
MIGRAPHX_FP8_BINARY_OP(+, migraphx_fp8::float8<T>)
MIGRAPHX_FP8_BINARY_OP(*, migraphx::fp8::float8<T>)
MIGRAPHX_FP8_BINARY_OP(-, migraphx::fp8::float8<T>)
MIGRAPHX_FP8_BINARY_OP(/, migraphx::fp8::float8<T>)
MIGRAPHX_FP8_BINARY_OP(+, migraphx::fp8::float8<T>)
// TODO: Comparison ops shouldn't convert to float, need to check if need to take care of rounding
// effects.
MIGRAPHX_FP8_BINARY_OP(==, bool)
......@@ -265,18 +257,18 @@ MIGRAPHX_FP8_BINARY_OP(>, bool)
MIGRAPHX_FP8_BINARY_OP(<, bool)
MIGRAPHX_FP8_BINARY_OP(!=, bool)
template <migraphx_fp8::f8_type T>
inline migraphx_fp8::float8<T> fabs(migraphx_fp8::float8<T> v)
template <migraphx::fp8::f8_type T>
inline migraphx::fp8::float8<T> fabs(migraphx::fp8::float8<T> v)
{
v.data = v.data & 0x7f;
return v;
}
// https://onnx.ai/onnx/technical/float8.html
using fp8e4m3fn = float8<migraphx_fp8::f8_type::fp8, false>;
using fp8e5m2 = float8<migraphx_fp8::f8_type::bf8, false>;
using fp8e4m3fnuz = float8<migraphx_fp8::f8_type::fp8, true>;
using fp8e5m2fnuz = float8<migraphx_fp8::f8_type::bf8, true>;
using fp8e4m3fn = float8<migraphx::fp8::f8_type::fp8, false>;
using fp8e5m2 = float8<migraphx::fp8::f8_type::bf8, false>;
using fp8e4m3fnuz = float8<migraphx::fp8::f8_type::fp8, true>;
using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
template <>
class numeric_limits<fp8e4m3fnuz>
......@@ -347,37 +339,39 @@ class numeric_limits<fp8e5m2>
// 7C and FC both are infinity
static constexpr fp8e5m2 infinity() { return fp8e5m2(0x7C, fp8e5m2::from_bits()); }
};
} // namespace migraphx_fp8
} // namespace fp8
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
// =================================================================================================
// define numeric limits for the new data type
namespace std {
#define MIGRAPHX_FP8_STD_OVERLOADS(T) \
inline bool isfinite(T x) { return x.is_inf(); } \
inline bool isnan(T x) { return x.is_nan(); } \
template <> \
class numeric_limits<T> : public migraphx_fp8::numeric_limits<T> \
{ \
}; \
template <class U> \
struct common_type<T, U> : std::common_type<float, U> \
{ \
}; \
template <class U> \
struct common_type<U, T> : std::common_type<float, U> \
{ \
}; \
template <> \
struct common_type<T, T> \
{ \
using type = T; \
#define MIGRAPHX_FP8_STD_OVERLOADS(T) \
inline bool isfinite(T x) { return x.is_inf(); } \
inline bool isnan(T x) { return x.is_nan(); } \
template <> \
class numeric_limits<T> : public migraphx::fp8::numeric_limits<T> \
{ \
}; \
template <class U> \
struct common_type<T, U> : std::common_type<float, U> \
{ \
}; \
template <class U> \
struct common_type<U, T> : std::common_type<float, U> \
{ \
}; \
template <> \
struct common_type<T, T> \
{ \
using type = T; \
};
MIGRAPHX_FP8_STD_OVERLOADS(migraphx_fp8::fp8e4m3fn)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx_fp8::fp8e5m2)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx_fp8::fp8e4m3fnuz)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx_fp8::fp8e5m2fnuz)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fn)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2fnuz)
} // namespace std
// =================================================================================================
......
......@@ -20,49 +20,32 @@
*
* ************************************************************************ */
#ifndef MIGRAPHX_FP8_IMPL_HPP
#define MIGRAPHX_FP8_IMPL_HPP
#define MIGRAPHX_CONST_FOLD(x) (__builtin_constant_p(x) ? (x) : (x))
namespace migraphx_f8_impl {
namespace detail {
template <bool B, class T, class F>
struct conditional
{
using type = T;
};
template <class T, class F>
struct conditional<false, T, F>
{
using type = F;
};
template <typename To, typename From>
inline constexpr To bit_cast(From fr) noexcept
{
static_assert(sizeof(To) == sizeof(From));
#if defined(__GNUC__) and !defined(__clang__)
return MIGRAPHX_CONST_FOLD(*reinterpret_cast<To*>(&fr));
#else
return __builtin_bit_cast(To, fr);
#endif
}
} // namespace detail
#ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL_HPP
#include <type_traits>
#include <migraphx/config.hpp>
#include <migraphx/bit_cast.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace fp8 {
namespace impl {
template <int wm, int we, typename T, bool negative_zero_nan, bool clip>
constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0)
{
constexpr bool is_float = std::is_same<T, float>::value;
// half is not supported for now
constexpr bool is_half = false;
static_assert(wm + we == 7, "wm+we==7");
static_assert(is_float or is_half, "Only float can be cast to f8");
const int mfmt = (sizeof(T) == 4) ? 23 : 10;
typename detail::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type x;
typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type x;
if constexpr(sizeof(T) == 4)
x = detail::bit_cast<uint32_t>(_x);
x = migraphx::bit_cast<uint32_t>(_x);
else
x = detail::bit_cast<uint16_t>(_x);
x = migraphx::bit_cast<uint16_t>(_x);
uint32_t head, mantissa;
int exponent, bias;
......@@ -271,19 +254,27 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
template <int wm, int we, typename T, bool negative_zero_nan>
constexpr T cast_from_f8(uint8_t x)
{
constexpr int weo = 8;
constexpr int wmo = 23;
// half is not supported for now
constexpr bool is_half = false;
constexpr bool is_float = std::is_same<T, float>::value;
static_assert(is_float or is_half, "Only float are supported");
constexpr int weo = is_half ? 5 : 8;
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
T fInf, fNegInf, fNaN, fNeg0;
uint32_t ifInf = 0x7F800000;
uint32_t ifNegInf = 0xFF800000;
uint32_t ifNaN = 0x7F800001;
uint32_t ifNeg0 = 0x80000000;
fInf = detail::bit_cast<float>(ifInf);
fNegInf = detail::bit_cast<float>(ifNegInf);
fNaN = detail::bit_cast<float>(ifNaN);
fNeg0 = detail::bit_cast<float>(ifNeg0);
if constexpr(is_float)
{
const uint32_t ifInf = 0x7F800000;
const uint32_t ifNegInf = 0xFF800000;
const uint32_t ifNaN = 0x7F800001;
const uint32_t ifNeg0 = 0x80000000;
fInf = migraphx::bit_cast<float>(ifInf);
fNegInf = migraphx::bit_cast<float>(ifNegInf);
fNaN = migraphx::bit_cast<float>(ifNaN);
fNeg0 = migraphx::bit_cast<float>(ifNeg0);
}
if(x == 0)
return 0;
......@@ -305,7 +296,7 @@ constexpr T cast_from_f8(uint8_t x)
else if(wm == 3 and (x == 0x7F or x == 0xFF))
return fNaN;
}
typename detail::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
......@@ -333,8 +324,11 @@ constexpr T cast_from_f8(uint8_t x)
retval = (sign << 15) | (exponent << 10) | mantissa;
else
retval = (sign << 31) | (exponent << 23) | mantissa;
return detail::bit_cast<T>(retval);
return migraphx::bit_cast<T>(retval);
}
} // namespace migraphx_f8_impl
#endif // MIGRAPHX_FP8_IMPL_HPP
} // namespace impl
} // namespace fp8
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL
......@@ -27,7 +27,7 @@
#include <half/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/migraphx_float8.hpp>
#include <migraphx/float8.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -69,13 +69,13 @@ struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
};
template <>
struct common_type<migraphx_fp8::fp8e4m3fnuz, migraphx::half>
struct common_type<migraphx::fp8::fp8e4m3fnuz, migraphx::half>
{
using type = float;
};
template <>
struct common_type<migraphx::half, migraphx_fp8::fp8e4m3fnuz>
struct common_type<migraphx::half, migraphx::fp8::fp8e4m3fnuz>
{
using type = float;
};
......
......@@ -34,7 +34,7 @@
#include <migraphx/functional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/half.hpp>
#include <migraphx/migraphx_float8.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/config.hpp>
......@@ -62,7 +62,7 @@ struct MIGRAPHX_EXPORT shape
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx_fp8::fp8e4m3fnuz)
m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz)
// clang-format on
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
......
......@@ -28,7 +28,7 @@
#include <type_traits>
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/migraphx_float8.hpp>
#include <migraphx/float8.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -49,23 +49,13 @@ MIGRAPHX_DETAIL_DEFINE_TRAIT(is_floating_point);
MIGRAPHX_DETAIL_DEFINE_TRAIT(is_arithmetic);
MIGRAPHX_DETAIL_DEFINE_TRAIT(is_signed);
template <class T, class U>
struct is_same : std::is_same<T, U>
{
};
template <bool B, class T, class U>
struct conditional : std::conditional<B, T, U>
{
};
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx_fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx_fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx_fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx::fp8::fp8e4m3fnuz)
template <class T>
using accumulator_type =
......
......@@ -40,7 +40,7 @@
#include <migraphx/json.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/migraphx_float8.hpp>
#include <migraphx/float8.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp>
#endif
......@@ -145,7 +145,7 @@ struct npy_format_descriptor<half>
};
template <>
struct npy_format_descriptor<migraphx_fp8::fp8e4m3fnuz>
struct npy_format_descriptor<migraphx::fp8::fp8e4m3fnuz>
{
static std::string format()
{
......
......@@ -150,7 +150,7 @@ function(test_headers PREFIX)
list(REMOVE_ITEM HEADERS
${CMAKE_SOURCE_DIR}/src/targets/gpu/include/migraphx/gpu/ck.hpp)
endif()
list(REMOVE_ITEM HEADERS ${CMAKE_SOURCE_DIR}/src/include/migraphx/migraphx_f8_impl.hpp)
list(REMOVE_ITEM HEADERS ${CMAKE_SOURCE_DIR}/src/include/migraphx/float8_impl.hpp)
foreach(HEADER ${HEADERS})
file(RELATIVE_PATH HEADER_REL ${CMAKE_SOURCE_DIR} ${HEADER})
string(MAKE_C_IDENTIFIER ${HEADER_REL} TEST_NAME)
......
......@@ -22,7 +22,7 @@
* THE SOFTWARE.
*/
#include <migraphx/float_equal.hpp>
#include <migraphx/migraphx_float8.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include "test.hpp"
......@@ -72,12 +72,12 @@ void test_equality()
TEST_CASE_REGISTER(test_equality<double, float>);
TEST_CASE_REGISTER(test_equality<double, int>);
TEST_CASE_REGISTER(test_equality<double, migraphx::half>);
TEST_CASE_REGISTER(test_equality<double, migraphx_fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_equality<double, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_equality<float, int>);
TEST_CASE_REGISTER(test_equality<float, migraphx_fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_equality<float, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_equality<migraphx::half, int>);
TEST_CASE_REGISTER(test_equality<migraphx::half, migraphx_fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_equality<migraphx_fp8::fp8e4m3fnuz, int>);
TEST_CASE_REGISTER(test_equality<migraphx::half, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_equality<migraphx::fp8::fp8e4m3fnuz, int>);
template <class T, class U>
void test_limits()
......@@ -115,12 +115,12 @@ void test_limits()
TEST_CASE_REGISTER(test_limits<double, float>);
TEST_CASE_REGISTER(test_limits<double, int>);
TEST_CASE_REGISTER(test_limits<double, migraphx::half>);
TEST_CASE_REGISTER(test_limits<double, migraphx_fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_limits<double, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_limits<float, int>);
TEST_CASE_REGISTER(test_limits<float, migraphx_fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_limits<float, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_limits<int, migraphx::half>);
TEST_CASE_REGISTER(test_limits<int, migraphx_fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_limits<migraphx_fp8::fp8e4m3fnuz, migraphx::half>);
TEST_CASE_REGISTER(test_limits<int, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_limits<migraphx::fp8::fp8e4m3fnuz, migraphx::half>);
#ifndef _WIN32
// On Windows, types int and long have the same min and max values.
......
......@@ -23,7 +23,7 @@
*/
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/migraphx_float8.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
......@@ -108,7 +108,7 @@ TEST_CASE(test_fp8_cast_to_float)
std::vector<uint8_t> bit_vals(256);
std::iota(bit_vals.begin(), bit_vals.end(), 0);
EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) {
migraphx_fp8::fp8e4m3fn fp8_val(bit_val, migraphx_fp8::fp8e4m3fn::from_bits());
migraphx::fp8::fp8e4m3fn fp8_val(bit_val, migraphx::fp8::fp8e4m3fn::from_bits());
if(std::isnan(float(fp8_val)) and std::isnan(fp8e4m3fn_to_fp32_value(bit_val)))
{
return true;
......@@ -120,7 +120,7 @@ TEST_CASE(test_fp8_cast_to_float)
TEST_CASE(test_positive_zero)
{
float zero = 0.0;
migraphx_fp8::fp8e4m3fn fp8_zero(zero);
migraphx::fp8::fp8e4m3fn fp8_zero(zero);
EXPECT(fp8_zero.is_zero());
EXPECT(migraphx::float_equal(zero, float(fp8_zero)));
}
......@@ -128,7 +128,7 @@ TEST_CASE(test_positive_zero)
TEST_CASE(test_negative_zero)
{
float nzero = -0.0;
migraphx_fp8::fp8e4m3fn fp8_nzero(nzero);
migraphx::fp8::fp8e4m3fn fp8_nzero(nzero);
EXPECT(fp8_nzero.is_zero());
// negative zero is preserved for fp8e4m3fn
EXPECT(migraphx::float_equal(nzero, float(fp8_nzero)));
......@@ -137,15 +137,15 @@ TEST_CASE(test_negative_zero)
TEST_CASE(test_nan_1)
{
float fnan = std::numeric_limits<float>::quiet_NaN();
migraphx_fp8::fp8e4m3fn fp8_nan(fnan);
migraphx::fp8::fp8e4m3fn fp8_nan(fnan);
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan));
}
TEST_CASE(test_nan_2)
{
auto fnan = std::numeric_limits<migraphx_fp8::fp8e4m3fn>::quiet_NaN();
migraphx_fp8::fp8e4m3fn fp8_nan(fnan.data, migraphx_fp8::fp8e4m3fn::from_bits());
auto fnan = std::numeric_limits<migraphx::fp8::fp8e4m3fn>::quiet_NaN();
migraphx::fp8::fp8e4m3fn fp8_nan(fnan.data, migraphx::fp8::fp8e4m3fn::from_bits());
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan));
EXPECT(std::isnan(float(fp8_nan)));
......@@ -155,8 +155,8 @@ TEST_CASE(test_infinity_1)
{
float finf = std::numeric_limits<float>::infinity();
// no inf in fp8e4m3fn, it gets clipped to max()
migraphx_fp8::fp8e4m3fn fp8_max(finf);
EXPECT(fp8_max == std::numeric_limits<migraphx_fp8::fp8e4m3fn>::max());
migraphx::fp8::fp8e4m3fn fp8_max(finf);
EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::max());
}
TEST_CASE(test_infinity_2)
......@@ -164,43 +164,43 @@ TEST_CASE(test_infinity_2)
// neg inf
float finf = -1.0 * std::numeric_limits<float>::infinity();
// no inf in fp8e4m3fn, it gets clipped to lowest
migraphx_fp8::fp8e4m3fn fp8_lowest(finf);
EXPECT(bool{fp8_lowest == std::numeric_limits<migraphx_fp8::fp8e4m3fn>::lowest()});
migraphx::fp8::fp8e4m3fn fp8_lowest(finf);
EXPECT(bool{fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::lowest()});
}
TEST_CASE(test_numeric_max_1)
{
float fmax = std::numeric_limits<float>::max();
migraphx_fp8::fp8e4m3fn fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx_fp8::fp8e4m3fn>::max());
migraphx::fp8::fp8e4m3fn fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::max());
}
TEST_CASE(test_numeric_max_2)
{
// gets clipped to max
float fmax = 2 * std::numeric_limits<migraphx_fp8::fp8e4m3fn>::max();
migraphx_fp8::fp8e4m3fn fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx_fp8::fp8e4m3fn>::max());
float fmax = 2 * std::numeric_limits<migraphx::fp8::fp8e4m3fn>::max();
migraphx::fp8::fp8e4m3fn fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::max());
}
TEST_CASE(test_numeric_lowest_1)
{
float flowest = std::numeric_limits<float>::lowest();
migraphx_fp8::fp8e4m3fn fp8_lowest(flowest);
EXPECT(fp8_lowest == std::numeric_limits<migraphx_fp8::fp8e4m3fn>::lowest());
migraphx::fp8::fp8e4m3fn fp8_lowest(flowest);
EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::lowest());
}
TEST_CASE(test_numeric_lowest_2)
{
// gets clipped to lowest
float fmin = 2.0 * std::numeric_limits<migraphx_fp8::fp8e4m3fn>::lowest();
migraphx_fp8::fp8e4m3fn fp8_lowest(fmin);
EXPECT(fp8_lowest == std::numeric_limits<migraphx_fp8::fp8e4m3fn>::lowest());
float fmin = 2.0 * std::numeric_limits<migraphx::fp8::fp8e4m3fn>::lowest();
migraphx::fp8::fp8e4m3fn fp8_lowest(fmin);
EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::lowest());
}
TEST_CASE(test_max_eq_lowest)
{
EXPECT(migraphx::float_equal(std::numeric_limits<migraphx_fp8::fp8e4m3fn>::lowest(),
-1 * std::numeric_limits<migraphx_fp8::fp8e4m3fn>::max()));
EXPECT(migraphx::float_equal(std::numeric_limits<migraphx::fp8::fp8e4m3fn>::lowest(),
-1 * std::numeric_limits<migraphx::fp8::fp8e4m3fn>::max()));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -23,7 +23,7 @@
*/
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/migraphx_float8.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
......@@ -129,7 +129,7 @@ TEST_CASE(test_fp8_cast_to_float)
std::vector<uint8_t> bit_vals(256);
std::iota(bit_vals.begin(), bit_vals.end(), 0);
EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) {
migraphx_fp8::fp8e4m3fnuz fp8_val(bit_val, migraphx_fp8::fp8e4m3fnuz::from_bits());
migraphx::fp8::fp8e4m3fnuz fp8_val(bit_val, migraphx::fp8::fp8e4m3fnuz::from_bits());
if(std::isnan(float(fp8_val)) and std::isnan(fp8e4m3fnuz_to_fp32_value(bit_val)))
{
return true;
......@@ -141,7 +141,7 @@ TEST_CASE(test_fp8_cast_to_float)
TEST_CASE(test_positive_zero)
{
float zero = 0.0;
migraphx_fp8::fp8e4m3fnuz fp8_zero(zero);
migraphx::fp8::fp8e4m3fnuz fp8_zero(zero);
EXPECT(fp8_zero.is_zero());
EXPECT(migraphx::float_equal(zero, float(fp8_zero)));
}
......@@ -150,7 +150,7 @@ TEST_CASE(test_negative_zero)
{
float nzero = -0.0;
float pzero = 0.0;
migraphx_fp8::fp8e4m3fnuz fp8_nzero(nzero);
migraphx::fp8::fp8e4m3fnuz fp8_nzero(nzero);
EXPECT(fp8_nzero.is_zero());
// negative zero gets converted to positive zero
EXPECT(migraphx::float_equal(pzero, float(fp8_nzero)));
......@@ -159,15 +159,15 @@ TEST_CASE(test_negative_zero)
TEST_CASE(test_nan_1)
{
float fnan = std::numeric_limits<float>::quiet_NaN();
migraphx_fp8::fp8e4m3fnuz fp8_nan(fnan);
migraphx::fp8::fp8e4m3fnuz fp8_nan(fnan);
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan));
}
TEST_CASE(test_nan_2)
{
auto fnan = std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::quiet_NaN();
migraphx_fp8::fp8e4m3fnuz fp8_nan(fnan.data, migraphx_fp8::fp8e4m3fnuz::from_bits());
auto fnan = std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::quiet_NaN();
migraphx::fp8::fp8e4m3fnuz fp8_nan(fnan.data, migraphx::fp8::fp8e4m3fnuz::from_bits());
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan));
EXPECT(std::isnan(float(fp8_nan)));
......@@ -177,7 +177,7 @@ TEST_CASE(test_infinity_1)
{
float finf = std::numeric_limits<float>::infinity();
// no inf in fp8e4m3fnuz it gets clipped to Nans
migraphx_fp8::fp8e4m3fnuz fp8_nan(finf);
migraphx::fp8::fp8e4m3fnuz fp8_nan(finf);
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(float(fp8_nan)));
}
......@@ -187,7 +187,7 @@ TEST_CASE(test_infinity_2)
// neg inf
float finf = -1.0 * std::numeric_limits<float>::infinity();
// no inf in fp8e4m3fnuz it gets clipped to NaNs
migraphx_fp8::fp8e4m3fnuz fp8_nan(finf);
migraphx::fp8::fp8e4m3fnuz fp8_nan(finf);
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(float(fp8_nan)));
}
......@@ -195,36 +195,36 @@ TEST_CASE(test_infinity_2)
TEST_CASE(test_numeric_max_1)
{
float fmax = std::numeric_limits<float>::max();
migraphx_fp8::fp8e4m3fnuz fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::max());
migraphx::fp8::fp8e4m3fnuz fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::max());
}
TEST_CASE(test_numeric_max_2)
{
// gets clipped to max
float fmax = 2 * std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::max();
migraphx_fp8::fp8e4m3fnuz fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::max());
float fmax = 2 * std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::max();
migraphx::fp8::fp8e4m3fnuz fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::max());
}
TEST_CASE(test_numeric_lowest_1)
{
float flowest = std::numeric_limits<float>::lowest();
migraphx_fp8::fp8e4m3fnuz fp8_lowest(flowest);
EXPECT(fp8_lowest == std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::lowest());
migraphx::fp8::fp8e4m3fnuz fp8_lowest(flowest);
EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::lowest());
}
TEST_CASE(test_numeric_lowest_2)
{
// gets clipped to lowest
float fmin = 2.0 * std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::lowest();
migraphx_fp8::fp8e4m3fnuz fp8_lowest(fmin);
EXPECT(fp8_lowest == std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::lowest());
float fmin = 2.0 * std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::lowest();
migraphx::fp8::fp8e4m3fnuz fp8_lowest(fmin);
EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::lowest());
}
TEST_CASE(test_max_eq_lowest)
{
EXPECT(migraphx::float_equal(std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::lowest(),
-1 * std::numeric_limits<migraphx_fp8::fp8e4m3fnuz>::max()));
EXPECT(migraphx::float_equal(std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::lowest(),
-1 * std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::max()));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -23,7 +23,7 @@
*/
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/migraphx_float8.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
......@@ -301,7 +301,7 @@ TEST_CASE(test_fp8_cast_to_float)
std::vector<uint8_t> bit_vals(256);
std::iota(bit_vals.begin(), bit_vals.end(), 0);
EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) {
migraphx_fp8::fp8e5m2 fp8_val(bit_val, migraphx_fp8::fp8e5m2::from_bits());
migraphx::fp8::fp8e5m2 fp8_val(bit_val, migraphx::fp8::fp8e5m2::from_bits());
if(std::isnan(float(fp8_val)) and std::isnan(fp8e5m2_to_fp32_value(bit_val)))
{
return true;
......@@ -317,7 +317,7 @@ TEST_CASE(test_fp8_cast_to_float)
TEST_CASE(test_positive_zero)
{
float zero = 0.0;
migraphx_fp8::fp8e5m2 fp8_zero(zero);
migraphx::fp8::fp8e5m2 fp8_zero(zero);
EXPECT(fp8_zero.is_zero());
EXPECT(migraphx::float_equal(zero, float(fp8_zero)));
}
......@@ -325,7 +325,7 @@ TEST_CASE(test_positive_zero)
TEST_CASE(test_negative_zero)
{
float nzero = -0.0;
migraphx_fp8::fp8e5m2 fp8_nzero(nzero);
migraphx::fp8::fp8e5m2 fp8_nzero(nzero);
EXPECT(fp8_nzero.is_zero());
// negative zero is preserved for fp8e5m2
EXPECT(migraphx::float_equal(nzero, float(fp8_nzero)));
......@@ -334,15 +334,15 @@ TEST_CASE(test_negative_zero)
TEST_CASE(test_nan_1)
{
float fnan = std::numeric_limits<float>::quiet_NaN();
migraphx_fp8::fp8e5m2 fp8_nan(fnan);
migraphx::fp8::fp8e5m2 fp8_nan(fnan);
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan));
}
TEST_CASE(test_nan_2)
{
auto fnan = std::numeric_limits<migraphx_fp8::fp8e5m2>::quiet_NaN();
migraphx_fp8::fp8e5m2 fp8_nan(fnan.data, migraphx_fp8::fp8e5m2::from_bits());
auto fnan = std::numeric_limits<migraphx::fp8::fp8e5m2>::quiet_NaN();
migraphx::fp8::fp8e5m2 fp8_nan(fnan.data, migraphx::fp8::fp8e5m2::from_bits());
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan));
EXPECT(std::isnan(float(fp8_nan)));
......@@ -352,8 +352,8 @@ TEST_CASE(test_infinity_1)
{
// float infinity should get clipped to max
float finf = std::numeric_limits<float>::infinity();
migraphx_fp8::fp8e5m2 fp8_max(finf);
EXPECT(fp8_max == std::numeric_limits<migraphx_fp8::fp8e5m2>::max());
migraphx::fp8::fp8e5m2 fp8_max(finf);
EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e5m2>::max());
}
TEST_CASE(test_infinity_2)
......@@ -361,43 +361,43 @@ TEST_CASE(test_infinity_2)
// neg inf
float finf = -1.0 * std::numeric_limits<float>::infinity();
// no inf in fp8e5m2, it gets clipped to lowest
migraphx_fp8::fp8e5m2 fp8_lowest(finf);
EXPECT(bool{fp8_lowest == std::numeric_limits<migraphx_fp8::fp8e5m2>::lowest()});
migraphx::fp8::fp8e5m2 fp8_lowest(finf);
EXPECT(bool{fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e5m2>::lowest()});
}
TEST_CASE(test_numeric_max_1)
{
float fmax = std::numeric_limits<float>::max();
migraphx_fp8::fp8e5m2 fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx_fp8::fp8e5m2>::max());
migraphx::fp8::fp8e5m2 fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e5m2>::max());
}
TEST_CASE(test_numeric_max_2)
{
// gets clipped to max
float fmax = 2 * std::numeric_limits<migraphx_fp8::fp8e5m2>::max();
migraphx_fp8::fp8e5m2 fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx_fp8::fp8e5m2>::max());
float fmax = 2 * std::numeric_limits<migraphx::fp8::fp8e5m2>::max();
migraphx::fp8::fp8e5m2 fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e5m2>::max());
}
TEST_CASE(test_numeric_lowest_1)
{
float flowest = std::numeric_limits<float>::lowest();
migraphx_fp8::fp8e5m2 fp8_lowest(flowest);
EXPECT(fp8_lowest == std::numeric_limits<migraphx_fp8::fp8e5m2>::lowest());
migraphx::fp8::fp8e5m2 fp8_lowest(flowest);
EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e5m2>::lowest());
}
TEST_CASE(test_numeric_lowest_2)
{
// gets clipped to lowest
float fmin = 2.0 * std::numeric_limits<migraphx_fp8::fp8e5m2>::lowest();
migraphx_fp8::fp8e5m2 fp8_lowest(fmin);
EXPECT(fp8_lowest == std::numeric_limits<migraphx_fp8::fp8e5m2>::lowest());
float fmin = 2.0 * std::numeric_limits<migraphx::fp8::fp8e5m2>::lowest();
migraphx::fp8::fp8e5m2 fp8_lowest(fmin);
EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e5m2>::lowest());
}
TEST_CASE(test_max_eq_lowest)
{
EXPECT(migraphx::float_equal(std::numeric_limits<migraphx_fp8::fp8e5m2>::lowest(),
-1 * std::numeric_limits<migraphx_fp8::fp8e5m2>::max()));
EXPECT(migraphx::float_equal(std::numeric_limits<migraphx::fp8::fp8e5m2>::lowest(),
-1 * std::numeric_limits<migraphx::fp8::fp8e5m2>::max()));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -23,7 +23,7 @@
*/
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/migraphx_float8.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
......@@ -299,7 +299,7 @@ TEST_CASE(test_fp8_cast_to_float)
std::vector<uint8_t> bit_vals(256);
std::iota(bit_vals.begin(), bit_vals.end(), 0);
EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) {
migraphx_fp8::fp8e5m2fnuz fp8_val(bit_val, migraphx_fp8::fp8e5m2fnuz::from_bits());
migraphx::fp8::fp8e5m2fnuz fp8_val(bit_val, migraphx::fp8::fp8e5m2fnuz::from_bits());
if(std::isnan(float(fp8_val)) and std::isnan(fp8e5m2fnuz_to_fp32_value(bit_val)))
{
return true;
......@@ -311,7 +311,7 @@ TEST_CASE(test_fp8_cast_to_float)
TEST_CASE(test_positive_zero)
{
float zero = 0.0;
migraphx_fp8::fp8e5m2fnuz fp8_zero(zero);
migraphx::fp8::fp8e5m2fnuz fp8_zero(zero);
EXPECT(fp8_zero.is_zero());
EXPECT(migraphx::float_equal(zero, float(fp8_zero)));
}
......@@ -320,7 +320,7 @@ TEST_CASE(test_negative_zero)
{
float nzero = -0.0;
float pzero = 0.0;
migraphx_fp8::fp8e5m2fnuz fp8_nzero(nzero);
migraphx::fp8::fp8e5m2fnuz fp8_nzero(nzero);
EXPECT(fp8_nzero.is_zero());
// negative zero gets converted to positive zero
EXPECT(migraphx::float_equal(pzero, float(fp8_nzero)));
......@@ -329,15 +329,15 @@ TEST_CASE(test_negative_zero)
TEST_CASE(test_nan_1)
{
float fnan = std::numeric_limits<float>::quiet_NaN();
migraphx_fp8::fp8e5m2fnuz fp8_nan(fnan);
migraphx::fp8::fp8e5m2fnuz fp8_nan(fnan);
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan));
}
TEST_CASE(test_nan_2)
{
auto fnan = std::numeric_limits<migraphx_fp8::fp8e5m2fnuz>::quiet_NaN();
migraphx_fp8::fp8e5m2fnuz fp8_nan(fnan.data, migraphx_fp8::fp8e5m2fnuz::from_bits());
auto fnan = std::numeric_limits<migraphx::fp8::fp8e5m2fnuz>::quiet_NaN();
migraphx::fp8::fp8e5m2fnuz fp8_nan(fnan.data, migraphx::fp8::fp8e5m2fnuz::from_bits());
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan));
EXPECT(std::isnan(float(fp8_nan)));
......@@ -347,7 +347,7 @@ TEST_CASE(test_infinity_1)
{
float finf = std::numeric_limits<float>::infinity();
// no inf in fp8e5m2fnuz it gets clipped to Nans
migraphx_fp8::fp8e5m2fnuz fp8_nan(finf);
migraphx::fp8::fp8e5m2fnuz fp8_nan(finf);
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(float(fp8_nan)));
}
......@@ -357,7 +357,7 @@ TEST_CASE(test_infinity_2)
// neg inf
float finf = -1.0 * std::numeric_limits<float>::infinity();
// no inf in fp8e5m2fnuz it gets clipped to NaNs
migraphx_fp8::fp8e5m2fnuz fp8_nan(finf);
migraphx::fp8::fp8e5m2fnuz fp8_nan(finf);
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(float(fp8_nan)));
}
......@@ -365,36 +365,36 @@ TEST_CASE(test_infinity_2)
TEST_CASE(test_numeric_max_1)
{
float fmax = std::numeric_limits<float>::max();
migraphx_fp8::fp8e5m2fnuz fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx_fp8::fp8e5m2fnuz>::max());
migraphx::fp8::fp8e5m2fnuz fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e5m2fnuz>::max());
}
TEST_CASE(test_numeric_max_2)
{
// gets clipped to max
float fmax = 2 * std::numeric_limits<migraphx_fp8::fp8e5m2fnuz>::max();
migraphx_fp8::fp8e5m2fnuz fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx_fp8::fp8e5m2fnuz>::max());
float fmax = 2 * std::numeric_limits<migraphx::fp8::fp8e5m2fnuz>::max();
migraphx::fp8::fp8e5m2fnuz fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e5m2fnuz>::max());
}
TEST_CASE(test_numeric_lowest_1)
{
float flowest = std::numeric_limits<float>::lowest();
migraphx_fp8::fp8e5m2fnuz fp8_lowest(flowest);
EXPECT(fp8_lowest == std::numeric_limits<migraphx_fp8::fp8e5m2fnuz>::lowest());
migraphx::fp8::fp8e5m2fnuz fp8_lowest(flowest);
EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e5m2fnuz>::lowest());
}
TEST_CASE(test_numeric_lowest_2)
{
// gets clipped to lowest
float fmin = 2.0 * std::numeric_limits<migraphx_fp8::fp8e5m2fnuz>::lowest();
migraphx_fp8::fp8e5m2fnuz fp8_lowest(fmin);
EXPECT(fp8_lowest == std::numeric_limits<migraphx_fp8::fp8e5m2fnuz>::lowest());
float fmin = 2.0 * std::numeric_limits<migraphx::fp8::fp8e5m2fnuz>::lowest();
migraphx::fp8::fp8e5m2fnuz fp8_lowest(fmin);
EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e5m2fnuz>::lowest());
}
TEST_CASE(test_max_eq_lowest)
{
EXPECT(migraphx::float_equal(std::numeric_limits<migraphx_fp8::fp8e5m2fnuz>::lowest(),
-1 * std::numeric_limits<migraphx_fp8::fp8e5m2fnuz>::max()));
EXPECT(migraphx::float_equal(std::numeric_limits<migraphx::fp8::fp8e5m2fnuz>::lowest(),
-1 * std::numeric_limits<migraphx::fp8::fp8e5m2fnuz>::max()));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -45,7 +45,7 @@
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx_fp8::fp8e4m3fnuz)
m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz)
// clang-format on
#ifdef __cplusplus
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment