Commit 155a2b17 authored by Umang Yadav
Browse files

move FNUZ as template parameter

parent 9bc18287
......@@ -29,10 +29,6 @@
#pragma clang diagnostic ignored "-Wc++20-extensions"
#endif // __clang__
#ifndef MIGRAPHX_FP8_FNUZ
#define MIGRAPHX_FP8_FNUZ true
#endif // MIGRAPHX_FP8_FNUZ
// We are clipping in down conversion by default
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1
......@@ -73,10 +69,10 @@ enum class f8_type
fp8 = 1 // s1e4m3
};
template <typename T>
template <typename T, bool FNUZ = true>
class numeric_limits;
template <migraphx_fp8::f8_type T = migraphx_fp8::f8_type::fp8>
template <migraphx_fp8::f8_type T = migraphx_fp8::f8_type::fp8, bool FNUZ = true>
struct float8
{
uint8_t data = 0x00;
......@@ -100,11 +96,11 @@ struct float8
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx_f8_impl::
cast_to_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, true /*clip*/>(
cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, true /*clip*/>(
v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx_f8_impl::
cast_to_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, false /*clip*/>(
cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, false /*clip*/>(
v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng);
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
}
......@@ -112,11 +108,11 @@ struct float8
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx_f8_impl::
cast_to_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, true /*clip*/>(
cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, true /*clip*/>(
v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx_f8_impl::
cast_to_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, false /*clip*/>(
cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, false /*clip*/>(
v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng);
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
}
......@@ -126,16 +122,14 @@ struct float8
{
if constexpr(T == migraphx_fp8::f8_type::fp8)
{
return migraphx_f8_impl::
cast_from_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/>(data);
return migraphx_f8_impl::cast_from_f8<3, 4, float, FNUZ /*negative_zero_nan*/>(data);
} // else
return migraphx_f8_impl::cast_from_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/>(
data);
return migraphx_f8_impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data);
}
inline constexpr bool is_zero() const
{
if constexpr(MIGRAPHX_FP8_FNUZ)
if constexpr(FNUZ)
{
return data == 0x00;
}
......@@ -147,7 +141,7 @@ struct float8
inline constexpr bool is_nan() const
{
if constexpr(MIGRAPHX_FP8_FNUZ)
if constexpr(FNUZ)
{
return data == 0x80;
}
......@@ -170,7 +164,7 @@ struct float8
inline constexpr bool is_inf() const
{
if constexpr(MIGRAPHX_FP8_FNUZ)
if constexpr(FNUZ)
{
return data == 0x80;
}
......@@ -218,7 +212,7 @@ struct float8
inline constexpr bool operator==(const float8& rhs) const
{
if((rhs.is_zero() and this->is_zero()) or
(fabs(rhs - *this) < migraphx_fp8::numeric_limits<float8<T>>::epsilon()))
(fabs(rhs - *this) < migraphx_fp8::numeric_limits<float8<T, FNUZ>>::epsilon()))
return true;
else if(rhs.is_nan() or rhs.is_inf() or this->is_nan() or this->is_inf())
return false;
......@@ -289,123 +283,90 @@ constexpr T F8_Lowest()
return T{0xFF, T::from_bits()};
}
using fp8e4m3fnuz = float8<migraphx_fp8::f8_type::fp8>;
// Short names for each layout / FNUZ combination of float8.
// The first template argument picks the layout (f8_type::fp8 is s1e4m3); the second is
// float8's FNUZ flag, which float8 passes to the cast helpers as negative_zero_nan.
using fp8e4m3fn = float8<migraphx_fp8::f8_type::fp8, false>;
using fp8e5m2 = float8<migraphx_fp8::f8_type::bf8, false>;
using fp8e4m3fnuz = float8<migraphx_fp8::f8_type::fp8, true>;
using fp8e5m2fnuz = float8<migraphx_fp8::f8_type::bf8, true>;
template <>
class numeric_limits<migraphx_fp8::float8<migraphx_fp8::f8_type::fp8>>
// Limits for fp8e4m3fnuz, i.e. float8<f8_type::fp8, true>.
class numeric_limits<fp8e4m3fnuz>
{
public:
// TODO :figure out epsilon in Hex to make it constexpr
// NOTE(review): epsilon() is already constexpr and built from the bit pattern 0x28,
// so the TODO above may be stale - confirm.
static constexpr migraphx_fp8::float8<migraphx_fp8::f8_type::fp8> epsilon()
static constexpr fp8e4m3fnuz epsilon()
{
return migraphx_fp8::float8<migraphx_fp8::f8_type::fp8>(
0x28, migraphx_fp8::float8<>::from_bits());
return fp8e4m3fnuz(0x28, migraphx_fp8::float8<>::from_bits());
}
static constexpr migraphx_fp8::float8<migraphx_fp8::f8_type::fp8> quiet_NaN()
{
return migraphx_fp8::float8<migraphx_fp8::f8_type::fp8>(
MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7F, migraphx_fp8::float8<>::from_bits());
}
// 0x80 is the bit pattern float8::is_nan() tests for when FNUZ is true.
static constexpr fp8e4m3fnuz quiet_NaN() { return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits()); }
static constexpr migraphx_fp8::float8<migraphx_fp8::f8_type::fp8> max()
{
return migraphx_fp8::F8_Max<migraphx_fp8::float8<migraphx_fp8::f8_type::fp8>>();
}
static constexpr fp8e4m3fnuz max() { return migraphx_fp8::F8_Max<fp8e4m3fnuz>(); }
// TODO figure out Hex value
// NOTE(review): min() is computed as -1.0f * F8_Max, which is a negative value, whereas
// std::numeric_limits<T>::min() for floating-point types is the smallest positive
// normal value - confirm which meaning callers expect.
static migraphx_fp8::float8<migraphx_fp8::f8_type::fp8> min()
static fp8e4m3fnuz min()
{
return static_cast<migraphx_fp8::float8<migraphx_fp8::f8_type::fp8>>(-1.0f) *
migraphx_fp8::F8_Max<migraphx_fp8::float8<migraphx_fp8::f8_type::fp8>>();
return static_cast<fp8e4m3fnuz>(-1.0f) * migraphx_fp8::F8_Max<fp8e4m3fnuz>();
}
static constexpr migraphx_fp8::float8<migraphx_fp8::f8_type::fp8> lowest()
{
return migraphx_fp8::F8_Lowest<migraphx_fp8::float8<migraphx_fp8::f8_type::fp8>>();
}
static constexpr fp8e4m3fnuz lowest() { return migraphx_fp8::F8_Lowest<fp8e4m3fnuz>(); }
static constexpr migraphx_fp8::float8<migraphx_fp8::f8_type::fp8> infinity()
{
return migraphx_fp8::float8<migraphx_fp8::f8_type::fp8>(
MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7F, migraphx_fp8::float8<>::from_bits());
}
// Same bit pattern as quiet_NaN(): float8::is_inf() also tests for 0x80 when FNUZ is true.
static constexpr fp8e4m3fnuz infinity() { return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits()); }
};
template <>
class numeric_limits<migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>>
// Limits for fp8e5m2fnuz, i.e. float8<f8_type::bf8, true>.
class numeric_limits<fp8e5m2fnuz>
{
public:
static constexpr migraphx_fp8::float8<migraphx_fp8::f8_type::bf8> epsilon()
{
return migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>(
0x34, migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>::from_bits());
}
static constexpr fp8e5m2fnuz epsilon() { return fp8e5m2fnuz(0x34, fp8e5m2fnuz::from_bits()); }
static constexpr migraphx_fp8::float8<migraphx_fp8::f8_type::bf8> quiet_NaN()
{
return migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>(
MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7d,
migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>::from_bits());
}
// 0x80 is the bit pattern float8::is_nan() tests for when FNUZ is true.
static constexpr fp8e5m2fnuz quiet_NaN() { return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits()); }
static constexpr migraphx_fp8::float8<migraphx_fp8::f8_type::bf8> max()
static constexpr fp8e5m2fnuz max()
{
return static_cast<migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>>(
migraphx_fp8::F8_Max<migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>>());
return static_cast<fp8e5m2fnuz>(migraphx_fp8::F8_Max<fp8e5m2fnuz>());
}
// TODO figure out constexpr value
// NOTE(review): min() is computed as -1.0f * F8_Max, which is a negative value, whereas
// std::numeric_limits<T>::min() for floating-point types is the smallest positive
// normal value - confirm which meaning callers expect.
static migraphx_fp8::float8<migraphx_fp8::f8_type::bf8> min()
static fp8e5m2fnuz min()
{
return static_cast<migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>>(float(-1.0f)) *
migraphx_fp8::F8_Max<migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>>();
}
static constexpr migraphx_fp8::float8<migraphx_fp8::f8_type::bf8> lowest()
{
return migraphx_fp8::F8_Lowest<migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>>();
return static_cast<fp8e5m2fnuz>(float(-1.0f)) * migraphx_fp8::F8_Max<fp8e5m2fnuz>();
}
static constexpr fp8e5m2fnuz lowest() { return migraphx_fp8::F8_Lowest<fp8e5m2fnuz>(); }
static constexpr migraphx_fp8::float8<migraphx_fp8::f8_type::bf8> infinity()
{
return migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>(
MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7c,
migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>::from_bits());
}
// Same bit pattern as quiet_NaN(): float8::is_inf() also tests for 0x80 when FNUZ is true.
static constexpr fp8e5m2fnuz infinity() { return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits()); }
};
} // namespace migraphx_fp8
// =================================================================================================
// define numeric limits for the new data type
namespace std {
inline bool isfinite(migraphx_fp8::float8<migraphx_fp8::f8_type::fp8> x) // NOLINT
// Reports whether x holds a finite value.
// A value is finite exactly when it is neither infinite nor NaN; returning x.is_inf()
// directly would answer the opposite question.
inline bool isfinite(migraphx_fp8::fp8e4m3fnuz x) // NOLINT
{
    return not x.is_inf() and not x.is_nan();
}
inline bool isfinite(migraphx_fp8::float8<migraphx_fp8::f8_type::bf8> x) // NOLINT
// Reports whether x holds a finite value.
// A value is finite exactly when it is neither infinite nor NaN; returning x.is_inf()
// directly would answer the opposite question.
inline bool isfinite(migraphx_fp8::fp8e5m2fnuz x) // NOLINT
{
    return not x.is_inf() and not x.is_nan();
}
inline bool isnan(migraphx_fp8::float8<migraphx_fp8::f8_type::fp8> x) // NOLINT
// NaN test for fp8e4m3fnuz; the answer comes straight from the type's own is_nan() member.
inline bool isnan(migraphx_fp8::fp8e4m3fnuz x) // NOLINT
{
    const bool holds_nan = x.is_nan();
    return holds_nan;
}
inline bool isnan(migraphx_fp8::float8<migraphx_fp8::f8_type::bf8> x) // NOLINT
// NaN test for fp8e5m2fnuz; the answer comes straight from the type's own is_nan() member.
inline bool isnan(migraphx_fp8::fp8e5m2fnuz x) // NOLINT
{
    const bool holds_nan = x.is_nan();
    return holds_nan;
}
template <>
class numeric_limits<migraphx_fp8::float8<migraphx_fp8::f8_type::fp8>>
: public migraphx_fp8::numeric_limits<migraphx_fp8::float8<migraphx_fp8::f8_type::fp8>>
// Makes the migraphx_fp8 limits for fp8e4m3fnuz reachable through std::numeric_limits
// by publicly inheriting them; the class adds no members of its own.
class numeric_limits<migraphx_fp8::fp8e4m3fnuz>
: public migraphx_fp8::numeric_limits<migraphx_fp8::fp8e4m3fnuz>
{
};
template <>
class numeric_limits<migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>>
: public migraphx_fp8::numeric_limits<migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>>
// Makes the migraphx_fp8 limits for fp8e5m2fnuz reachable through std::numeric_limits
// by publicly inheriting them; the class adds no members of its own.
class numeric_limits<migraphx_fp8::fp8e5m2fnuz>
: public migraphx_fp8::numeric_limits<migraphx_fp8::fp8e5m2fnuz>
{
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment