Commit c76b765a authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Implement FP8OCP test for stochastic rounding mode.

parent d40d1ff1
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#pragma once #pragma once
#include "ck/utility/statically_indexed_array.hpp" #include "ck/utility/statically_indexed_array.hpp"
#include "ck/utility/random_gen.hpp"
#ifdef CK_USE_FNUZ_FP8 #ifdef CK_USE_FNUZ_FP8
#define CK_USE_FNUZ_FP8 1 #define CK_USE_FNUZ_FP8 1
...@@ -240,7 +241,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn ...@@ -240,7 +241,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
} }
// First need to check if it is normal or denorm as there is a difference of // First need to check if it is normal or denorm as there is a difference of
// implict 1 Then need to adjust the exponent to align with the F8 exponent, // implicit 1 Then need to adjust the exponent to align with the F8 exponent,
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
// to mantissa and truncate. And for RNE, no need to add rng. Then probably // to mantissa and truncate. And for RNE, no need to add rng. Then probably
// need to check whether there is carry and adjust exponent and mantissa again // need to check whether there is carry and adjust exponent and mantissa again
...@@ -275,7 +276,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn ...@@ -275,7 +276,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
{ {
/* This is the case where fp32/fp16 is normal but it is in f8 denormal /* This is the case where fp32/fp16 is normal but it is in f8 denormal
range. For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16 range. For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, it is actually larger due to the implict 1, actual exponent is -7, it is actually larger due to the implicit 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1. Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
exponent_diff = f8_denormal_act_exponent - act_exponent; exponent_diff = f8_denormal_act_exponent - act_exponent;
...@@ -303,7 +304,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn ...@@ -303,7 +304,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
else if(exponent_diff == -1) else if(exponent_diff == -1)
mantissa <<= -exponent_diff; mantissa <<= -exponent_diff;
bool implicit_one = mantissa & (1ull << mfmt); bool implicit_one = mantissa & (1ull << mfmt);
// if there is no implict 1, it means the f8 is denormal and need to adjust // if there is no implicit 1, it means the f8 is denormal and need to adjust
// to denorm exponent // to denorm exponent
f8_exponent = f8_exponent =
(act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1); (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
...@@ -530,9 +531,8 @@ static __device__ float cast_to_f32_from_f8(fp8_storage_t v, uint32_t interpret) ...@@ -530,9 +531,8 @@ static __device__ float cast_to_f32_from_f8(fp8_storage_t v, uint32_t interpret)
// The conversion function is from rocblas // The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79 // https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79
template <bool stochastic_rounding = false> template <ck_fp8_interpretation_t interpret, bool saturate, bool stochastic_rounding = false>
static __device__ fp8_storage_t static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng = 0)
cast_to_f8_from_f32(float v, bool saturate, ck_fp8_interpretation_t interpret, unsigned int rng = 0)
{ {
fp8_storage_t i8data; fp8_storage_t i8data;
union union
...@@ -545,9 +545,9 @@ cast_to_f8_from_f32(float v, bool saturate, ck_fp8_interpretation_t interpret, u ...@@ -545,9 +545,9 @@ cast_to_f8_from_f32(float v, bool saturate, ck_fp8_interpretation_t interpret, u
unsigned int ival = 0; unsigned int ival = 0;
val.fval = v; val.fval = v;
if(saturate) if constexpr(saturate)
{ {
if(interpret == CK_E4M3_FNUZ) if constexpr(interpret == CK_E4M3_FNUZ)
{ {
if((val.i32val & 0x7F800000) != 0x7F800000) if((val.i32val & 0x7F800000) != 0x7F800000)
{ /// propagate NAN/INF, no clipping { /// propagate NAN/INF, no clipping
...@@ -570,7 +570,7 @@ cast_to_f8_from_f32(float v, bool saturate, ck_fp8_interpretation_t interpret, u ...@@ -570,7 +570,7 @@ cast_to_f8_from_f32(float v, bool saturate, ck_fp8_interpretation_t interpret, u
} }
} }
if(stochastic_rounding) if constexpr(stochastic_rounding)
{ {
ival = (interpret == CK_E4M3_FNUZ) || (interpret == CK_E4M3_OCP) ival = (interpret == CK_E4M3_FNUZ) || (interpret == CK_E4M3_OCP)
? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0) ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
...@@ -597,43 +597,59 @@ cast_to_f8_from_f32(float v, bool saturate, ck_fp8_interpretation_t interpret, u ...@@ -597,43 +597,59 @@ cast_to_f8_from_f32(float v, bool saturate, ck_fp8_interpretation_t interpret, u
/** /**
* \brief convert float to @p fp8_storage_t * \brief convert float to @p fp8_storage_t
* *
* \tparam interp interpretation of fp8
* \tparam sat saturation of fp8 * \tparam sat saturation of fp8
* \param f float number * \param f float number
* \param interp interpretation of fp8
* \return fp8_storage_t * \return fp8_storage_t
*/ */
template <ck_saturation_t sat = CK_SATFINITE> template <ck_fp8_interpretation_t interp,
ck_saturation_t sat = CK_SATFINITE,
bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH #if CK_FP8_CVT_FAST_PATH
__host__ __device__ static inline fp8_storage_t __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
cvt_float_to_fp8(const float f, const ck_fp8_interpretation_t interp)
{ {
internal::__is_interpret_supported(interp); internal::__is_interpret_supported(interp);
return internal::cast_to_f8_from_f32<false>(f, sat == CK_SATFINITE, interp); uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
}
return internal::cast_to_f8_from_f32<interp, sat == CK_SATFINITE, stochastic_rounding>(f, rng);
#else #else
#if CK_USE_OCP_FP8 #if CK_USE_OCP_FP8
__host__ __device__ static inline fp8_storage_t __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
cvt_float_to_fp8(const float f, const ck_fp8_interpretation_t interp)
{ {
#else #else
__host__ static inline fp8_storage_t cvt_float_to_fp8(const float f, __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
const ck_fp8_interpretation_t interp)
{ {
#endif #endif
if(interp == CK_E4M3_FNUZ) uint32_t rng = 0;
if constexpr(stochastic_rounding)
{ {
return internal::cast_to_f8<float, 3, 4, true, sat == CK_SATFINITE>(f); constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
}
if constexpr(interp == CK_E4M3_FNUZ)
{
return internal::cast_to_f8<float, 3, 4, true, sat == CK_SATFINITE, stochastic_rounding>(
f, rng);
} }
else if(interp == CK_E5M2_FNUZ) else if(interp == CK_E5M2_FNUZ)
{ {
return internal::cast_to_f8<float, 2, 5, true, sat == CK_SATFINITE>(f); return internal::cast_to_f8<float, 2, 5, true, sat == CK_SATFINITE, stochastic_rounding>(
f, rng);
} }
else if(interp == CK_E4M3_OCP) else if(interp == CK_E4M3_OCP)
{ {
return internal::cast_to_f8<float, 3, 4, false, sat == CK_SATFINITE>(f); return internal::cast_to_f8<float, 3, 4, false, sat == CK_SATFINITE, stochastic_rounding>(
f, rng);
} }
else if(interp == CK_E5M2_OCP) else if(interp == CK_E5M2_OCP)
{ {
return internal::cast_to_f8<float, 2, 5, false, sat == CK_SATFINITE>(f); return internal::cast_to_f8<float, 2, 5, false, sat == CK_SATFINITE, stochastic_rounding>(
f, rng);
} }
else else
{ {
...@@ -734,7 +750,20 @@ template <> ...@@ -734,7 +750,20 @@ template <>
inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, float>(float x) inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, float>(float x)
{ {
return f8_ocp_t{ return f8_ocp_t{
internal::cvt_float_to_fp8<f8_ocp_t::default_saturation>(x, f8_ocp_t::default_interpret)}; internal::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
}
// Declare a template function for fp8 conversion using stochastic rounding (SR)
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_sr(X x);
// convert fp32 to fp8 (f8_ocp_t) with stochastic rounding, using the type's default
// interpretation and saturation mode; the `true` template argument selects the
// stochastic_rounding path of cvt_float_to_fp8
template <>
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, float>(float x)
{
return f8_ocp_t{
internal::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation, true>(
x)};
} }
#if CK_USE_OCP_FP8 #if CK_USE_OCP_FP8
......
...@@ -77,7 +77,58 @@ TEST(FP8OCP, ConvertFP32Nearest) ...@@ -77,7 +77,58 @@ TEST(FP8OCP, ConvertFP32Nearest)
ASSERT_TRUE((f8_nan.data & 0x7f) == 0x7f); ASSERT_TRUE((f8_nan.data & 0x7f) == 0x7f);
} }
// Exercises float -> f8_ocp_t conversion through f8_convert_sr (stochastic rounding)
// and back to float via type_convert, covering zero, saturation, normal and
// subnormal fp8 values, and NaN.
TEST(FP8OCP, ConvertFP32Stochastic) {} TEST(FP8OCP, ConvertFP32Stochastic)
{
// absolute tolerance used where the round trip is not required to be bit-exact
float abs_tol = 1e-6;
// convert 0 float to fp8 and back, must be exactly 0 (tolerance 0)
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<f8_ocp_t>(0.0f)), 0.0f);
// convert the smallest positive normal float to fp8 and back; the assertion only
// requires the result to stay within abs_tol of it
ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(f8_convert_sr<f8_ocp_t>(std::numeric_limits<float>::min())),
abs_tol);
const auto max_f8_t_float = type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max());
// convert fp8 max (expressed as float) to fp8 and back, check it round-trips exactly
ASSERT_NEAR(max_f8_t_float, type_convert<float>(f8_convert_sr<f8_ocp_t>(max_f8_t_float)), 0.0f);
// convert maximal float to fp8 and back, check if clipped to fp8 max (saturation to finite)
ASSERT_NEAR(max_f8_t_float,
type_convert<float>(f8_convert_sr<f8_ocp_t>(std::numeric_limits<float>::max())),
0.0f);
// convert float infinity to f8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ(ck::NumericLimits<f8_ocp_t>::Max(),
f8_convert_sr<f8_ocp_t>(std::numeric_limits<float>::infinity()));
// positive normal value to fp8 and back, check it is preserved within abs_tol
float pos_float = 0.017578125f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_ocp_t>(pos_float)), abs_tol);
// negative smallest-magnitude normal fp8 value to fp8 and back, must round-trip exactly
float neg_float = -0.015625f; // -2^-6
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_ocp_t>(neg_float)), 0.0f);
// positive value in the fp8 subnormal range (2^-8) to fp8 and back, check it is
// preserved within abs_tol
pos_float = 0.00390625f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_ocp_t>(pos_float)), abs_tol);
// negative smallest-magnitude subnormal fp8 value to fp8 and back, must round-trip exactly
constexpr auto min_subnorm_fp8 = -0.001953125f; // -2^-9
ASSERT_NEAR(
min_subnorm_fp8, type_convert<float>(f8_convert_sr<f8_ocp_t>(min_subnorm_fp8)), 0.0f);
// a value below the smallest fp8 subnormal may round to either 0 or 2^-9 under
// stochastic rounding, so accept any result within 2^-9 of zero
auto less_than_min_subnorm = 0.0009765625f; // 2^-10
ASSERT_NEAR(
0.0f, type_convert<float>(f8_convert_sr<f8_ocp_t>(less_than_min_subnorm)), 0.001953125f);
// convert quiet NaN to f8_ocp_t and check the NaN encoding: all seven non-sign bits set
auto f8_nan = f8_convert_sr<f8_ocp_t>(std::numeric_limits<float>::quiet_NaN());
ASSERT_TRUE((f8_nan.data & 0x7f) == 0x7f);
}
TEST(FP8OCP, ConvertFP16Nearest) {} TEST(FP8OCP, ConvertFP16Nearest) {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment