Commit f9542d5b authored by Umang Yadav's avatar Umang Yadav
Browse files

Put numeric_max and numeeric lowest into float8

parent 836e201e
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
#include <migraphx/kernels/types.hpp> #include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/float8_impl.hpp> #include <migraphx/kernels/float8_impl.hpp>
#include <migraphx/kernels/type_traits.hpp>
namespace migraphx { namespace migraphx {
namespace fp8 { namespace fp8 {
...@@ -538,6 +539,24 @@ class numeric_limits<fp8e5m2> ...@@ -538,6 +539,24 @@ class numeric_limits<fp8e5m2>
}; };
} // namespace fp8 } // namespace fp8
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_MIN_MAX(T) \
template <> \
constexpr T numeric_max<T, void>() \
{ \
return fp8::numeric_limits<T>::max(); \
} \
template <> \
constexpr T numeric_lowest<T>() \
{ \
return fp8::numeric_limits<T>::lowest(); \
}
MIGRAPHX_FP8_MIN_MAX(fp8::fp8e4m3fnuz);
MIGRAPHX_FP8_MIN_MAX(fp8::fp8e5m2fnuz);
MIGRAPHX_FP8_MIN_MAX(fp8::fp8e4m3fn);
MIGRAPHX_FP8_MIN_MAX(fp8::fp8e5m2);
} // namespace migraphx } // namespace migraphx
// ================================================================================================= // =================================================================================================
#if defined(__clang__) #if defined(__clang__)
......
...@@ -23,26 +23,13 @@ ...@@ -23,26 +23,13 @@
#ifndef MIGRAPHX_GUARD_KERNELS_FP8_IMPL_HPP #ifndef MIGRAPHX_GUARD_KERNELS_FP8_IMPL_HPP
#define MIGRAPHX_GUARD_KERNELS_FP8_IMPL_HPP #define MIGRAPHX_GUARD_KERNELS_FP8_IMPL_HPP
#include <migraphx/kernels/bit_cast.hpp> #include <migraphx/kernels/bit_cast.hpp>
#include <migraphx/kernels/type_traits.hpp>
#if defined(__clang__) #if defined(__clang__)
#pragma clang diagnostic push #pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreserved-identifier" #pragma clang diagnostic ignored "-Wreserved-identifier"
#endif #endif
namespace migraphx { namespace migraphx {
namespace detail {
template <bool B, class T, class F>
struct conditional
{
using type = T;
};
template <class T, class F>
struct conditional<false, T, F>
{
using type = F;
};
} // namespace detail
namespace fp8 { namespace fp8 {
namespace impl { namespace impl {
...@@ -58,7 +45,7 @@ __device__ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng ...@@ -58,7 +45,7 @@ __device__ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng
static_assert(is_float or is_half, "Only float can be cast to f8"); static_assert(is_float or is_half, "Only float can be cast to f8");
const uint32_t mfmt = (sizeof(T) == 4) ? 23 : 10; const uint32_t mfmt = (sizeof(T) == 4) ? 23 : 10;
typename detail::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type x; typename migraphx::conditional_t<sizeof(T) == 2, uint16_t, uint32_t> x;
if constexpr(sizeof(T) == 4) if constexpr(sizeof(T) == 4)
x = migraphx::bit_cast<uint32_t>(f_x); x = migraphx::bit_cast<uint32_t>(f_x);
...@@ -304,7 +291,7 @@ __device__ constexpr T cast_from_f8(uint8_t x) ...@@ -304,7 +291,7 @@ __device__ constexpr T cast_from_f8(uint8_t x)
else if(Wm == 3 and (x == 0x7F or x == 0xFF)) else if(Wm == 3 and (x == 0x7F or x == 0xFF))
return f_nan; return f_nan;
} }
typename detail::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval; typename migraphx::conditional_t<sizeof(T) == 2, uint16_t, uint32_t> retval;
const int exp_low_cutoff = const int exp_low_cutoff =
(1 << (weo - 1)) - (1 << (We - 1)) + 1 - (NegativeZeroNan ? 1 : 0); // NOLINT (1 << (weo - 1)) - (1 << (We - 1)) + 1 - (NegativeZeroNan ? 1 : 0); // NOLINT
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_KERNELS_MATH_HPP #define MIGRAPHX_GUARD_KERNELS_MATH_HPP
#include <migraphx/kernels/types.hpp> #include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/float8.hpp>
#include <migraphx/kernels/vec.hpp> #include <migraphx/kernels/vec.hpp>
#include <migraphx/kernels/functional.hpp> #include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp> #include <migraphx/kernels/type_traits.hpp>
......
...@@ -26,7 +26,6 @@ ...@@ -26,7 +26,6 @@
#include <migraphx/kernels/types.hpp> #include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/float8.hpp>
namespace migraphx { namespace migraphx {
...@@ -231,8 +230,7 @@ constexpr unsigned long int_max(unsigned long n) ...@@ -231,8 +230,7 @@ constexpr unsigned long int_max(unsigned long n)
template <class T, template <class T,
MIGRAPHX_REQUIRES(is_integral<T>{} or is_floating_point<T>{} or MIGRAPHX_REQUIRES(is_integral<T>{} or is_floating_point<T>{} or
is_same<T, migraphx::half>{} or is_same<T, migraphx::half>{})>
is_same<T, migraphx::fp8::fp8e4m3fnuz>{})>
constexpr T numeric_max() constexpr T numeric_max()
{ {
if constexpr(is_integral<T>{}) if constexpr(is_integral<T>{})
...@@ -248,9 +246,6 @@ constexpr T numeric_max() ...@@ -248,9 +246,6 @@ constexpr T numeric_max()
return __FLT_MAX__; return __FLT_MAX__;
else if constexpr(is_same<T, migraphx::half>{}) else if constexpr(is_same<T, migraphx::half>{})
return __FLT16_MAX__; return __FLT16_MAX__;
// TODO: Do it generically for all fp8 types
else if constexpr(is_same<T, migraphx::fp8::fp8e4m3fnuz>{})
return migraphx::fp8::numeric_limits<T>::max();
else else
return 0; return 0;
} }
...@@ -265,8 +260,6 @@ constexpr T numeric_lowest() ...@@ -265,8 +260,6 @@ constexpr T numeric_lowest()
else else
return -numeric_max<T>() - 1; return -numeric_max<T>() - 1;
} }
else if constexpr(is_same<T, migraphx::fp8::fp8e4m3fnuz>{})
return migraphx::fp8::numeric_limits<T>::lowest();
else else
{ {
return -numeric_max<T>(); return -numeric_max<T>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment