Commit 60942349 authored by Umang Yadav's avatar Umang Yadav
Browse files

Make FNUZ template param and add numeric limits

parent d7339e8a
......@@ -46,10 +46,6 @@
#define MIGRAPHX_HIP_DEVICE __device__
#ifndef MIGRAPHX_FP8_FNUZ
#define MIGRAPHX_FP8_FNUZ true
#endif // MIGRAPHX_FP8_FNUZ
// We are clipping in down conversion by default
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1
#if defined(MIGRAPHX_JIT_USE_HIPRTC)
......@@ -90,14 +86,14 @@ enum class f8_type
template <typename T>
class numeric_limits;
template <migraphx::fp8::f8_type T = migraphx::fp8::f8_type::fp8>
template <migraphx::fp8::f8_type T = migraphx::fp8::f8_type::fp8, bool FNUZ = true>
struct float8
{
uint8_t data;
// default constructor
MIGRAPHX_HIP_HOST_DEVICE constexpr float8() = default;
// default copy constructor
MIGRAPHX_HIP_HOST_DEVICE constexpr float8(const float8<T>& y) = default;
MIGRAPHX_HIP_HOST_DEVICE constexpr float8(const float8& y) = default;
struct from_bits_t
{
};
......@@ -195,11 +191,11 @@ struct float8
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx::fp8::impl::
cast_to_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, true /*clip*/>(
cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, true /*clip*/>(
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx::fp8::impl::
cast_to_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, false /*clip*/>(
cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, false /*clip*/>(
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
}
......@@ -207,11 +203,11 @@ struct float8
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx::fp8::impl::
cast_to_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, true /*clip*/>(
cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, true /*clip*/>(
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx::fp8::impl::
cast_to_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, false /*clip*/>(
cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, false /*clip*/>(
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
#endif // rocblas_F8_downcast_clipping}
}
......@@ -278,11 +274,9 @@ struct float8
{
if constexpr(T == migraphx::fp8::f8_type::fp8)
{
return migraphx::fp8::impl::
cast_from_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/>(data);
return migraphx::fp8::impl::cast_from_f8<3, 4, float, FNUZ /*negative_zero_nan*/>(data);
} // else
return migraphx::fp8::impl::
cast_from_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/>(data);
return migraphx::fp8::impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data);
}
/*
......@@ -296,7 +290,7 @@ struct float8
// check for zero
inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_zero() const
{
if constexpr(MIGRAPHX_FP8_FNUZ)
if constexpr(FNUZ)
{
return data == 0x00;
}
......@@ -309,7 +303,7 @@ struct float8
// check for nan
inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_nan() const
{
if constexpr(MIGRAPHX_FP8_FNUZ)
if constexpr(FNUZ)
{
return data == 0x80;
}
......@@ -333,7 +327,7 @@ struct float8
// check for inf
inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_inf() const
{
if constexpr(MIGRAPHX_FP8_FNUZ)
if constexpr(FNUZ)
{
return data == 0x80;
}
......@@ -458,97 +452,139 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T F8_Lowest()
return T{0xFF, T::from_bits()};
}
using fp8e4m3fnuz = float8<migraphx::fp8::f8_type::fp8>;
// https://onnx.ai/onnx/technical/float8.html
using fp8e4m3fn = float8<migraphx::fp8::f8_type::fp8, false>;
using fp8e5m2 = float8<migraphx::fp8::f8_type::bf8, false>;
using fp8e4m3fnuz = float8<migraphx::fp8::f8_type::fp8, true>;
using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
template <>
class numeric_limits<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8>>
class numeric_limits<fp8e4m3fnuz>
{
public:
// TODO :figure out epsilon in Hex to make it constexpr
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<migraphx::fp8::f8_type::fp8>
epsilon()
static constexpr bool has_infinity = false;
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz epsilon()
{
return fp8e4m3fnuz(0x28, fp8e4m3fnuz::from_bits());
}
// NOLINTNEXTLINE
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz quiet_NaN()
{
return migraphx::fp8::float8<migraphx::fp8::f8_type::fp8>(
0x28, migraphx::fp8::float8<>::from_bits());
return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits());
}
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<migraphx::fp8::f8_type::fp8>
quiet_NaN()
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz max()
{
return fp8e4m3fnuz(0x7F, fp8e4m3fnuz::from_bits());
}
// this is min value that is not DeNorm. DeNorm min is 0x01
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz min()
{
return migraphx::fp8::float8<migraphx::fp8::f8_type::fp8>(
MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7F, migraphx::fp8::float8<>::from_bits());
return fp8e4m3fnuz(0x08, fp8e4m3fnuz::from_bits());
}
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<migraphx::fp8::f8_type::fp8>
max()
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz lowest()
{
return migraphx::fp8::F8_Max<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8>>();
return fp8e4m3fnuz(0xFF, fp8e4m3fnuz::from_bits());
}
};
// TODO figure out Hex value
static MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<migraphx::fp8::f8_type::fp8> min()
template <>
class numeric_limits<fp8e4m3fn>
{
public:
static constexpr bool has_infinity = false;
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn epsilon()
{
return fp8e4m3fn(0x20, fp8e4m3fn::from_bits());
}
// NOLINTNEXTLINE
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn quiet_NaN()
{
return static_cast<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8>>(-1.0f) *
migraphx::fp8::F8_Max<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8>>();
return fp8e4m3fn(0x7F, fp8e4m3fn::from_bits());
}
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<migraphx::fp8::f8_type::fp8>
lowest()
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn max()
{
return fp8e4m3fn(0x7E, fp8e4m3fn::from_bits());
}
// this is min value that is not DeNorm. DeNorm min is 0x01
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn min()
{
return migraphx::fp8::F8_Lowest<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8>>();
return fp8e4m3fn(0x08, fp8e4m3fn::from_bits());
}
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<migraphx::fp8::f8_type::fp8>
infinity()
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn lowest()
{
return migraphx::fp8::float8<migraphx::fp8::f8_type::fp8>(
MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7F, migraphx::fp8::float8<>::from_bits());
return fp8e4m3fn(0xFE, fp8e4m3fn::from_bits());
}
};
template <>
class numeric_limits<migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>>
class numeric_limits<fp8e5m2fnuz>
{
public:
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>
epsilon()
static constexpr bool has_infinity = false;
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz epsilon()
{
return migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>(
0x34, migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>::from_bits());
return fp8e5m2fnuz(0x34, fp8e5m2fnuz::from_bits());
}
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>
quiet_NaN()
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz quiet_NaN() // NOLINT
{
return migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>(
MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7d,
migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>::from_bits());
return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits());
}
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>
max()
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz max()
{
return static_cast<migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>>(
migraphx::fp8::F8_Max<migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>>());
return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits());
}
// TODO figure out constexpr value
static MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<migraphx::fp8::f8_type::bf8> min()
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times.
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz min()
{
return static_cast<migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>>(float(-1.0f)) *
migraphx::fp8::F8_Max<migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>>();
return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits());
}
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>
lowest()
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz lowest()
{
return fp8e5m2fnuz(0xFF, fp8e5m2fnuz::from_bits());
}
};
template <>
class numeric_limits<fp8e5m2>
{
public:
static constexpr bool has_infinity = true;
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 epsilon()
{
return migraphx::fp8::F8_Lowest<migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>>();
return fp8e5m2(0x34, fp8e5m2::from_bits());
}
// 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 quiet_NaN()
{
return fp8e5m2(0xFF, fp8e5m2::from_bits());
} // NOLINT
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 max()
{
return fp8e5m2(0x7B, fp8e5m2::from_bits());
}
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times.
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 min()
{
return fp8e5m2(0x4, fp8e5m2::from_bits());
}
static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>
infinity()
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 lowest()
{
return fp8e5m2(0xFB, fp8e5m2::from_bits());
}
// 7C and FC both are infinity
static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 infinity()
{
return migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>(
MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7c,
migraphx::fp8::float8<migraphx::fp8::f8_type::bf8>::from_bits());
return fp8e5m2(0x7C, fp8e5m2::from_bits());
}
};
/*
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment