Commit f9542d5b authored by Umang Yadav's avatar Umang Yadav
Browse files

Put numeric_max and numeeric lowest into float8

parent 836e201e
......@@ -33,6 +33,7 @@
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/float8_impl.hpp>
#include <migraphx/kernels/type_traits.hpp>
namespace migraphx {
namespace fp8 {
......@@ -538,6 +539,24 @@ class numeric_limits<fp8e5m2>
};
} // namespace fp8
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_MIN_MAX(T) \
template <> \
constexpr T numeric_max<T, void>() \
{ \
return fp8::numeric_limits<T>::max(); \
} \
template <> \
constexpr T numeric_lowest<T>() \
{ \
return fp8::numeric_limits<T>::lowest(); \
}
MIGRAPHX_FP8_MIN_MAX(fp8::fp8e4m3fnuz);
MIGRAPHX_FP8_MIN_MAX(fp8::fp8e5m2fnuz);
MIGRAPHX_FP8_MIN_MAX(fp8::fp8e4m3fn);
MIGRAPHX_FP8_MIN_MAX(fp8::fp8e5m2);
} // namespace migraphx
// =================================================================================================
#if defined(__clang__)
......
......@@ -23,26 +23,13 @@
#ifndef MIGRAPHX_GUARD_KERNELS_FP8_IMPL_HPP
#define MIGRAPHX_GUARD_KERNELS_FP8_IMPL_HPP
#include <migraphx/kernels/bit_cast.hpp>
#include <migraphx/kernels/type_traits.hpp>
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreserved-identifier"
#endif
namespace migraphx {
namespace detail {
template <bool B, class T, class F>
struct conditional
{
using type = T;
};
template <class T, class F>
struct conditional<false, T, F>
{
using type = F;
};
} // namespace detail
namespace fp8 {
namespace impl {
......@@ -58,7 +45,7 @@ __device__ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng
static_assert(is_float or is_half, "Only float can be cast to f8");
const uint32_t mfmt = (sizeof(T) == 4) ? 23 : 10;
typename detail::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type x;
typename migraphx::conditional_t<sizeof(T) == 2, uint16_t, uint32_t> x;
if constexpr(sizeof(T) == 4)
x = migraphx::bit_cast<uint32_t>(f_x);
......@@ -304,7 +291,7 @@ __device__ constexpr T cast_from_f8(uint8_t x)
else if(Wm == 3 and (x == 0x7F or x == 0xFF))
return f_nan;
}
typename detail::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
typename migraphx::conditional_t<sizeof(T) == 2, uint16_t, uint32_t> retval;
const int exp_low_cutoff =
(1 << (weo - 1)) - (1 << (We - 1)) + 1 - (NegativeZeroNan ? 1 : 0); // NOLINT
......
......@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_KERNELS_MATH_HPP
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/float8.hpp>
#include <migraphx/kernels/vec.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp>
......
......@@ -26,7 +26,6 @@
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/float8.hpp>
namespace migraphx {
......@@ -231,8 +230,7 @@ constexpr unsigned long int_max(unsigned long n)
template <class T,
MIGRAPHX_REQUIRES(is_integral<T>{} or is_floating_point<T>{} or
is_same<T, migraphx::half>{} or
is_same<T, migraphx::fp8::fp8e4m3fnuz>{})>
is_same<T, migraphx::half>{})>
constexpr T numeric_max()
{
if constexpr(is_integral<T>{})
......@@ -248,9 +246,6 @@ constexpr T numeric_max()
return __FLT_MAX__;
else if constexpr(is_same<T, migraphx::half>{})
return __FLT16_MAX__;
// TODO: Do it generically for all fp8 types
else if constexpr(is_same<T, migraphx::fp8::fp8e4m3fnuz>{})
return migraphx::fp8::numeric_limits<T>::max();
else
return 0;
}
......@@ -265,8 +260,6 @@ constexpr T numeric_lowest()
else
return -numeric_max<T>() - 1;
}
else if constexpr(is_same<T, migraphx::fp8::fp8e4m3fnuz>{})
return migraphx::fp8::numeric_limits<T>::lowest();
else
{
return -numeric_max<T>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment