Commit 60942349 authored by Umang Yadav's avatar Umang Yadav
Browse files

Make FNUZ template param and add numeric limits

parent d7339e8a
...@@ -46,10 +46,6 @@ ...@@ -46,10 +46,6 @@
#define MIGRAPHX_HIP_DEVICE __device__ #define MIGRAPHX_HIP_DEVICE __device__
#ifndef MIGRAPHX_FP8_FNUZ
#define MIGRAPHX_FP8_FNUZ true
#endif // MIGRAPHX_FP8_FNUZ
// We are clipping in down conversion by default // We are clipping in down conversion by default
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 #define MIGRAPHX_F8_DOWNCAST_CLIPPING 1
#if defined(MIGRAPHX_JIT_USE_HIPRTC) #if defined(MIGRAPHX_JIT_USE_HIPRTC)
...@@ -90,14 +86,14 @@ enum class f8_type ...@@ -90,14 +86,14 @@ enum class f8_type
template <typename T> template <typename T>
class numeric_limits; class numeric_limits;
template <migraphx::fp8::f8_type T = migraphx::fp8::f8_type::fp8> template <migraphx::fp8::f8_type T = migraphx::fp8::f8_type::fp8, bool FNUZ = true>
struct float8 struct float8
{ {
uint8_t data; uint8_t data;
// default constructor // default constructor
MIGRAPHX_HIP_HOST_DEVICE constexpr float8() = default; MIGRAPHX_HIP_HOST_DEVICE constexpr float8() = default;
// default copy constructor // default copy constructor
MIGRAPHX_HIP_HOST_DEVICE constexpr float8(const float8<T>& y) = default; MIGRAPHX_HIP_HOST_DEVICE constexpr float8(const float8& y) = default;
struct from_bits_t struct from_bits_t
{ {
}; };
...@@ -195,11 +191,11 @@ struct float8 ...@@ -195,11 +191,11 @@ struct float8
{ {
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING #ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx::fp8::impl:: data = migraphx::fp8::impl::
cast_to_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, true /*clip*/>( cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, true /*clip*/>(
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING #else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx::fp8::impl:: data = migraphx::fp8::impl::
cast_to_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, false /*clip*/>( cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, false /*clip*/>(
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING #endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
} }
...@@ -207,11 +203,11 @@ struct float8 ...@@ -207,11 +203,11 @@ struct float8
{ {
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING #ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx::fp8::impl:: data = migraphx::fp8::impl::
cast_to_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, true /*clip*/>( cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, true /*clip*/>(
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING #else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx::fp8::impl:: data = migraphx::fp8::impl::
cast_to_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, false /*clip*/>( cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, false /*clip*/>(
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
#endif // rocblas_F8_downcast_clipping} #endif // rocblas_F8_downcast_clipping}
} }
...@@ -278,11 +274,9 @@ struct float8 ...@@ -278,11 +274,9 @@ struct float8
{ {
if constexpr(T == migraphx::fp8::f8_type::fp8) if constexpr(T == migraphx::fp8::f8_type::fp8)
{ {
return migraphx::fp8::impl:: return migraphx::fp8::impl::cast_from_f8<3, 4, float, FNUZ /*negative_zero_nan*/>(data);
cast_from_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/>(data);
} // else } // else
return migraphx::fp8::impl:: return migraphx::fp8::impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data);
cast_from_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/>(data);
} }
/* /*
...@@ -296,7 +290,7 @@ struct float8 ...@@ -296,7 +290,7 @@ struct float8
// check for zero // check for zero
inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_zero() const inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_zero() const
{ {
if constexpr(MIGRAPHX_FP8_FNUZ) if constexpr(FNUZ)
{ {
return data == 0x00; return data == 0x00;
} }
...@@ -309,7 +303,7 @@ struct float8 ...@@ -309,7 +303,7 @@ struct float8
// check for nan // check for nan
inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_nan() const inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_nan() const
{ {
if constexpr(MIGRAPHX_FP8_FNUZ) if constexpr(FNUZ)
{ {
return data == 0x80; return data == 0x80;
} }
...@@ -333,7 +327,7 @@ struct float8 ...@@ -333,7 +327,7 @@ struct float8
// check for inf // check for inf
inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_inf() const inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_inf() const
{ {
if constexpr(MIGRAPHX_FP8_FNUZ) if constexpr(FNUZ)
{ {
return data == 0x80; return data == 0x80;
} }
...@@ -458,97 +452,139 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T F8_Lowest() ...@@ -458,97 +452,139 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T F8_Lowest()
return T{0xFF, T::from_bits()}; return T{0xFF, T::from_bits()};
} }
using fp8e4m3fnuz = float8<migraphx::fp8::f8_type::fp8>; // https://onnx.ai/onnx/technical/float8.html
using fp8e4m3fn = float8<migraphx::fp8::f8_type::fp8, false>;
using fp8e5m2 = float8<migraphx::fp8::f8_type::bf8, false>;
using fp8e4m3fnuz = float8<migraphx::fp8::f8_type::fp8, true>;
using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
template <> template <>
class numeric_limits<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8>> class numeric_limits<fp8e4m3fnuz>
{ {
public: public:
// TODO :figure out epsilon in Hex to make it constexpr static constexpr bool has_infinity = false;
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<migraphx::fp8::f8_type::fp8> static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz epsilon()
epsilon() {
return fp8e4m3fnuz(0x28, fp8e4m3fnuz::from_bits());
}
// NOLINTNEXTLINE
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz quiet_NaN()
{ {
return migraphx::fp8::float8<migraphx::fp8::f8_type::fp8>( return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits());
0x28, migraphx::fp8::float8<>::from_bits());
} }
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<migraphx::fp8::f8_type::fp8> static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz max()
quiet_NaN() {
return fp8e4m3fnuz(0x7F, fp8e4m3fnuz::from_bits());
}
// this is min value that is not DeNorm. DeNorm min is 0x01
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz min()
{ {
return migraphx::fp8::float8<migraphx::fp8::f8_type::fp8>( return fp8e4m3fnuz(0x08, fp8e4m3fnuz::from_bits());
MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7F, migraphx::fp8::float8<>::from_bits());
} }
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<migraphx::fp8::f8_type::fp8> static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz lowest()
max()
{ {
return migraphx::fp8::F8_Max<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8>>(); return fp8e4m3fnuz(0xFF, fp8e4m3fnuz::from_bits());
} }
};
// TODO figure out Hex value template <>
static MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<migraphx::fp8::f8_type::fp8> min() class numeric_limits<fp8e4m3fn>
{
public:
static constexpr bool has_infinity = false;
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn epsilon()
{
return fp8e4m3fn(0x20, fp8e4m3fn::from_bits());
}
// NOLINTNEXTLINE
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn quiet_NaN()
{ {
return static_cast<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8>>(-1.0f) * return fp8e4m3fn(0x7F, fp8e4m3fn::from_bits());
migraphx::fp8::F8_Max<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8>>();
} }
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<migraphx::fp8::f8_type::fp8> static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn max()
lowest() {
return fp8e4m3fn(0x7E, fp8e4m3fn::from_bits());
}
// this is min value that is not DeNorm. DeNorm min is 0x01
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn min()
{ {
return migraphx::fp8::F8_Lowest<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8>>(); return fp8e4m3fn(0x08, fp8e4m3fn::from_bits());
} }
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<migraphx::fp8::f8_type::fp8> static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn lowest()
infinity()
{ {
return migraphx::fp8::float8<migraphx::fp8::f8_type::fp8>( return fp8e4m3fn(0xFE, fp8e4m3fn::from_bits());
MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7F, migraphx::fp8::float8<>::from_bits());
} }
}; };
template <> template <>
class numeric_limits<migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>> class numeric_limits<fp8e5m2fnuz>
{ {
public: public:
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<migraphx::fp8::f8_type::bf8> static constexpr bool has_infinity = false;
epsilon() static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz epsilon()
{ {
return migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>( return fp8e5m2fnuz(0x34, fp8e5m2fnuz::from_bits());
0x34, migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>::from_bits());
} }
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<migraphx::fp8::f8_type::bf8> static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz quiet_NaN() // NOLINT
quiet_NaN()
{ {
return migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>( return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits());
MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7d,
migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>::from_bits());
} }
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<migraphx::fp8::f8_type::bf8> static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz max()
max()
{ {
return static_cast<migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>>( return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits());
migraphx::fp8::F8_Max<migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>>());
} }
// TODO figure out constexpr value // this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
static MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<migraphx::fp8::f8_type::bf8> min() // this distinction. For the floating points we would end up using lowest most of the times.
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz min()
{ {
return static_cast<migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>>(float(-1.0f)) * return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits());
migraphx::fp8::F8_Max<migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>>();
} }
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>
lowest() static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz lowest()
{
return fp8e5m2fnuz(0xFF, fp8e5m2fnuz::from_bits());
}
};
template <>
class numeric_limits<fp8e5m2>
{
public:
static constexpr bool has_infinity = true;
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 epsilon()
{ {
return migraphx::fp8::F8_Lowest<migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>>(); return fp8e5m2(0x34, fp8e5m2::from_bits());
}
// 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 quiet_NaN()
{
return fp8e5m2(0xFF, fp8e5m2::from_bits());
} // NOLINT
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 max()
{
return fp8e5m2(0x7B, fp8e5m2::from_bits());
}
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times.
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 min()
{
return fp8e5m2(0x4, fp8e5m2::from_bits());
} }
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<migraphx::fp8::f8_type::bf8> static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 lowest()
infinity() {
return fp8e5m2(0xFB, fp8e5m2::from_bits());
}
// 7C and FC both are infinity
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 infinity()
{ {
return migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>( return fp8e5m2(0x7C, fp8e5m2::from_bits());
MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7c,
migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>::from_bits());
} }
}; };
/* /*
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment