Commit 155a2b17 authored by Umang Yadav
Browse files

Make FNUZ a template parameter instead of the MIGRAPHX_FP8_FNUZ preprocessor macro

parent 9bc18287
...@@ -29,10 +29,6 @@ ...@@ -29,10 +29,6 @@
#pragma clang diagnostic ignored "-Wc++20-extensions" #pragma clang diagnostic ignored "-Wc++20-extensions"
#endif // __clang__ #endif // __clang__
#ifndef MIGRAPHX_FP8_FNUZ
#define MIGRAPHX_FP8_FNUZ true
#endif // MIGRAPHX_FP8_FNUZ
// We are clipping in down conversion by default // We are clipping in down conversion by default
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 #define MIGRAPHX_F8_DOWNCAST_CLIPPING 1
...@@ -73,10 +69,10 @@ enum class f8_type ...@@ -73,10 +69,10 @@ enum class f8_type
fp8 = 1 // s1e4m3 fp8 = 1 // s1e4m3
}; };
template <typename T> template <typename T, bool FNUZ = true>
class numeric_limits; class numeric_limits;
template <migraphx_fp8::f8_type T = migraphx_fp8::f8_type::fp8> template <migraphx_fp8::f8_type T = migraphx_fp8::f8_type::fp8, bool FNUZ = true>
struct float8 struct float8
{ {
uint8_t data = 0x00; uint8_t data = 0x00;
...@@ -100,11 +96,11 @@ struct float8 ...@@ -100,11 +96,11 @@ struct float8
{ {
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING #ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx_f8_impl:: data = migraphx_f8_impl::
cast_to_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, true /*clip*/>( cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, true /*clip*/>(
v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng); v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING #else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx_f8_impl:: data = migraphx_f8_impl::
cast_to_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, false /*clip*/>( cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, false /*clip*/>(
v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng); v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng);
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING #endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
} }
...@@ -112,11 +108,11 @@ struct float8 ...@@ -112,11 +108,11 @@ struct float8
{ {
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING #ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx_f8_impl:: data = migraphx_f8_impl::
cast_to_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, true /*clip*/>( cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, true /*clip*/>(
v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng); v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING #else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx_f8_impl:: data = migraphx_f8_impl::
cast_to_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, false /*clip*/>( cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, false /*clip*/>(
v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng); v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng);
#endif // rocblas_F8_downcast_clipping} #endif // rocblas_F8_downcast_clipping}
} }
...@@ -126,16 +122,14 @@ struct float8 ...@@ -126,16 +122,14 @@ struct float8
{ {
if constexpr(T == migraphx_fp8::f8_type::fp8) if constexpr(T == migraphx_fp8::f8_type::fp8)
{ {
return migraphx_f8_impl:: return migraphx_f8_impl::cast_from_f8<3, 4, float, FNUZ /*negative_zero_nan*/>(data);
cast_from_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/>(data);
} // else } // else
return migraphx_f8_impl::cast_from_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/>( return migraphx_f8_impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data);
data);
} }
inline constexpr bool is_zero() const inline constexpr bool is_zero() const
{ {
if constexpr(MIGRAPHX_FP8_FNUZ) if constexpr(FNUZ)
{ {
return data == 0x00; return data == 0x00;
} }
...@@ -147,7 +141,7 @@ struct float8 ...@@ -147,7 +141,7 @@ struct float8
inline constexpr bool is_nan() const inline constexpr bool is_nan() const
{ {
if constexpr(MIGRAPHX_FP8_FNUZ) if constexpr(FNUZ)
{ {
return data == 0x80; return data == 0x80;
} }
...@@ -170,7 +164,7 @@ struct float8 ...@@ -170,7 +164,7 @@ struct float8
inline constexpr bool is_inf() const inline constexpr bool is_inf() const
{ {
if constexpr(MIGRAPHX_FP8_FNUZ) if constexpr(FNUZ)
{ {
return data == 0x80; return data == 0x80;
} }
...@@ -218,7 +212,7 @@ struct float8 ...@@ -218,7 +212,7 @@ struct float8
inline constexpr bool operator==(const float8& rhs) const inline constexpr bool operator==(const float8& rhs) const
{ {
if((rhs.is_zero() and this->is_zero()) or if((rhs.is_zero() and this->is_zero()) or
(fabs(rhs - *this) < migraphx_fp8::numeric_limits<float8<T>>::epsilon())) (fabs(rhs - *this) < migraphx_fp8::numeric_limits<float8<T, FNUZ>>::epsilon()))
return true; return true;
else if(rhs.is_nan() or rhs.is_inf() or this->is_nan() or this->is_inf()) else if(rhs.is_nan() or rhs.is_inf() or this->is_nan() or this->is_inf())
return false; return false;
...@@ -289,123 +283,90 @@ constexpr T F8_Lowest() ...@@ -289,123 +283,90 @@ constexpr T F8_Lowest()
return T{0xFF, T::from_bits()}; return T{0xFF, T::from_bits()};
} }
using fp8e4m3fnuz = float8<migraphx_fp8::f8_type::fp8>; using fp8e4m3fn = float8<migraphx_fp8::f8_type::fp8, false>;
using fp8e5m2 = float8<migraphx_fp8::f8_type::bf8, false>;
using fp8e4m3fnuz = float8<migraphx_fp8::f8_type::fp8, true>;
using fp8e5m2fnuz = float8<migraphx_fp8::f8_type::bf8, true>;
template <> template <>
class numeric_limits<migraphx_fp8::float8<migraphx_fp8::f8_type::fp8>> class numeric_limits<fp8e4m3fnuz>
{ {
public: public:
// TODO :figure out epsilon in Hex to make it constexpr static constexpr fp8e4m3fnuz epsilon()
static constexpr migraphx_fp8::float8<migraphx_fp8::f8_type::fp8> epsilon()
{ {
return migraphx_fp8::float8<migraphx_fp8::f8_type::fp8>( return fp8e4m3fnuz(0x28, migraphx_fp8::float8<>::from_bits());
0x28, migraphx_fp8::float8<>::from_bits());
} }
static constexpr migraphx_fp8::float8<migraphx_fp8::f8_type::fp8> quiet_NaN() static constexpr fp8e4m3fnuz quiet_NaN() { return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits()); }
{
return migraphx_fp8::float8<migraphx_fp8::f8_type::fp8>(
MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7F, migraphx_fp8::float8<>::from_bits());
}
static constexpr migraphx_fp8::float8<migraphx_fp8::f8_type::fp8> max() static constexpr fp8e4m3fnuz max() { return migraphx_fp8::F8_Max<fp8e4m3fnuz>(); }
{
return migraphx_fp8::F8_Max<migraphx_fp8::float8<migraphx_fp8::f8_type::fp8>>();
}
// TODO figure out Hex value // TODO figure out Hex value
static migraphx_fp8::float8<migraphx_fp8::f8_type::fp8> min() static fp8e4m3fnuz min()
{ {
return static_cast<migraphx_fp8::float8<migraphx_fp8::f8_type::fp8>>(-1.0f) * return static_cast<fp8e4m3fnuz>(-1.0f) * migraphx_fp8::F8_Max<fp8e4m3fnuz>();
migraphx_fp8::F8_Max<migraphx_fp8::float8<migraphx_fp8::f8_type::fp8>>();
} }
static constexpr migraphx_fp8::float8<migraphx_fp8::f8_type::fp8> lowest() static constexpr fp8e4m3fnuz lowest() { return migraphx_fp8::F8_Lowest<fp8e4m3fnuz>(); }
{
return migraphx_fp8::F8_Lowest<migraphx_fp8::float8<migraphx_fp8::f8_type::fp8>>();
}
static constexpr migraphx_fp8::float8<migraphx_fp8::f8_type::fp8> infinity() static constexpr fp8e4m3fnuz infinity() { return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits()); }
{
return migraphx_fp8::float8<migraphx_fp8::f8_type::fp8>(
MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7F, migraphx_fp8::float8<>::from_bits());
}
}; };
template <> template <>
class numeric_limits<migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>> class numeric_limits<fp8e5m2fnuz>
{ {
public: public:
static constexpr migraphx_fp8::float8<migraphx_fp8::f8_type::bf8> epsilon() static constexpr fp8e5m2fnuz epsilon() { return fp8e5m2fnuz(0x34, fp8e5m2fnuz::from_bits()); }
{
return migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>(
0x34, migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>::from_bits());
}
static constexpr migraphx_fp8::float8<migraphx_fp8::f8_type::bf8> quiet_NaN() static constexpr fp8e5m2fnuz quiet_NaN() { return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits()); }
{
return migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>(
MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7d,
migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>::from_bits());
}
static constexpr migraphx_fp8::float8<migraphx_fp8::f8_type::bf8> max() static constexpr fp8e5m2fnuz max()
{ {
return static_cast<migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>>( return static_cast<fp8e5m2fnuz>(migraphx_fp8::F8_Max<fp8e5m2fnuz>());
migraphx_fp8::F8_Max<migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>>());
} }
// TODO figure out constexpr value // TODO figure out constexpr value
static migraphx_fp8::float8<migraphx_fp8::f8_type::bf8> min() static fp8e5m2fnuz min()
{ {
return static_cast<migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>>(float(-1.0f)) * return static_cast<fp8e5m2fnuz>(float(-1.0f)) * migraphx_fp8::F8_Max<fp8e5m2fnuz>();
migraphx_fp8::F8_Max<migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>>();
}
static constexpr migraphx_fp8::float8<migraphx_fp8::f8_type::bf8> lowest()
{
return migraphx_fp8::F8_Lowest<migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>>();
} }
static constexpr fp8e5m2fnuz lowest() { return migraphx_fp8::F8_Lowest<fp8e5m2fnuz>(); }
static constexpr migraphx_fp8::float8<migraphx_fp8::f8_type::bf8> infinity() static constexpr fp8e5m2fnuz infinity() { return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits()); }
{
return migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>(
MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7c,
migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>::from_bits());
}
}; };
} // namespace migraphx_fp8 } // namespace migraphx_fp8
// ================================================================================================= // =================================================================================================
// define numeric limits for the new data type // define numeric limits for the new data type
namespace std { namespace std {
inline bool isfinite(migraphx_fp8::float8<migraphx_fp8::f8_type::fp8> x) // NOLINT inline bool isfinite(migraphx_fp8::fp8e4m3fnuz x) // NOLINT
{ {
return x.is_inf(); return x.is_inf();
} }
inline bool isfinite(migraphx_fp8::float8<migraphx_fp8::f8_type::bf8> x) // NOLINT inline bool isfinite(migraphx_fp8::fp8e5m2fnuz x) // NOLINT
{ {
return x.is_inf(); return x.is_inf();
} }
inline bool isnan(migraphx_fp8::float8<migraphx_fp8::f8_type::fp8> x) // NOLINT inline bool isnan(migraphx_fp8::fp8e4m3fnuz x) // NOLINT
{ {
return x.is_nan(); return x.is_nan();
} }
inline bool isnan(migraphx_fp8::float8<migraphx_fp8::f8_type::bf8> x) // NOLINT inline bool isnan(migraphx_fp8::fp8e5m2fnuz x) // NOLINT
{ {
return x.is_nan(); return x.is_nan();
} }
template <> template <>
class numeric_limits<migraphx_fp8::float8<migraphx_fp8::f8_type::fp8>> class numeric_limits<migraphx_fp8::fp8e4m3fnuz>
: public migraphx_fp8::numeric_limits<migraphx_fp8::float8<migraphx_fp8::f8_type::fp8>> : public migraphx_fp8::numeric_limits<migraphx_fp8::fp8e4m3fnuz>
{ {
}; };
template <> template <>
class numeric_limits<migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>> class numeric_limits<migraphx_fp8::fp8e5m2fnuz>
: public migraphx_fp8::numeric_limits<migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>> : public migraphx_fp8::numeric_limits<migraphx_fp8::fp8e5m2fnuz>
{ {
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment