Commit 2bd1b9cf authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Implement FP8OCP tests for half_t type conversions.

parent c76b765a
...@@ -359,10 +359,9 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn ...@@ -359,10 +359,9 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
// The conversion function is from rocblas // The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220 // https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220
// This has been modified to handle double types as well // This has been modified to handle double types as well
template <typename T, bool is_fnuz> template <typename T, int wm, int we, bool is_fnuz, bool clip = false>
__host__ __device__ static inline T cast_from_f8(fp8_storage_t x, int wm, int we, bool clip = false) __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
{ {
// TODO: synchronize with f8_utils.hpp implementation for FNUZ
constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value; constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
constexpr bool is_float = __hip_internal::is_same<T, float>::value; constexpr bool is_float = __hip_internal::is_same<T, float>::value;
constexpr bool is_double = __hip_internal::is_same<T, double>::value; constexpr bool is_double = __hip_internal::is_same<T, double>::value;
...@@ -514,7 +513,8 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x, int wm, int we ...@@ -514,7 +513,8 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x, int wm, int we
} }
#if CK_FP8_CVT_FAST_PATH #if CK_FP8_CVT_FAST_PATH
static __device__ float cast_to_f32_from_f8(fp8_storage_t v, uint32_t interpret) template <ck_fp8_interpretation_t interpret>
static __device__ float cast_to_f32_from_f8(fp8_storage_t v)
{ {
union union
{ {
...@@ -523,10 +523,18 @@ static __device__ float cast_to_f32_from_f8(fp8_storage_t v, uint32_t interpret) ...@@ -523,10 +523,18 @@ static __device__ float cast_to_f32_from_f8(fp8_storage_t v, uint32_t interpret)
} val; } val;
val.i8val[0] = v; val.i8val[0] = v;
float fval = (interpret == internal::CK_E4M3_FNUZ) || (interpret == internal::CK_E4M3_OCP) static_assert(interpret == CK_E4M3_FNUZ || interpret == CK_E4M3_OCP ||
? __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0) interpret == CK_E5M2_FNUZ || interpret == CK_E5M2_OCP,
: __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0); "Only FNUZ and OCP interpretations are supported");
return fval;
if constexpr((interpret == internal::CK_E4M3_FNUZ) || (interpret == internal::CK_E4M3_OCP))
{
return __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0);
}
else
{
return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
}
} }
// The conversion function is from rocblas // The conversion function is from rocblas
...@@ -659,6 +667,32 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f) ...@@ -659,6 +667,32 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
#endif // CK_FP8_CVT_FAST_PATH #endif // CK_FP8_CVT_FAST_PATH
} }
/**
* \brief convert half_t to @p fp8_storage_t
*
* \tparam sat saturation of fp8
* \tparam interp interpretation of fp8
* \tparam stochastic_rounding switch between RNE and SR
* \param x half_t value
* \return fp8_storage_t
*/
template <ck_fp8_interpretation_t interp,
ck_saturation_t sat = CK_SATFINITE,
bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH
__host__ __device__ static inline fp8_storage_t cvt_half_t_to_fp8(const half_t x)
{
internal::__is_interpret_supported(interp);
#elif CK_USE_OCP_FP8
__host__ __device__ static inline fp8_storage_t cvt_half_t_to_fp8(const half_t x)
{
#else
__host__ static inline fp8_storage_t cvt_half_t_to_fp8(const half_t x)
{
#endif
return cvt_float_to_fp8<interp, sat, stochastic_rounding>(static_cast<float>(x));
}
/* For fp8 fnuz types, finite and NaN values are supported. Zero is unsigned. /* For fp8 fnuz types, finite and NaN values are supported. Zero is unsigned.
Inf are not supported. This gives us one additional number to represent. Inf are not supported. This gives us one additional number to represent.
NaN are represented by 1-0000-000 or 1-00000-00 */ NaN are represented by 1-0000-000 or 1-00000-00 */
...@@ -706,15 +740,31 @@ struct f8_ocp_t ...@@ -706,15 +740,31 @@ struct f8_ocp_t
} }
#if CK_USE_OCP_FP8 #if CK_USE_OCP_FP8
__host__ __device__ explicit operator float() const { __host__ __device__ explicit operator float() const
{
#else #else
__host__ explicit operator float() const __host__ explicit operator float() const
{ {
#endif #endif
#if CK_FP8_CVT_FAST_PATH #if CK_FP8_CVT_FAST_PATH
return internal::cast_to_f32_from_f8(this->data, default_interpret); return internal::cast_to_f32_from_f8<default_interpret>(this->data);
#else #else
return internal::cast_from_f8<float, false>(this->data, wm, we); return internal::cast_from_f8<float, wm, we, false>(
this->data); // XXX: clip==false must be consistent with operator half_t
#endif
}
#if CK_USE_OCP_FP8
__host__ __device__ explicit operator half_t() const {
#else
__host__ explicit operator half_t() const
{
#endif
#if CK_FP8_CVT_FAST_PATH
return static_cast<half_t>(internal::cast_to_f32_from_f8<default_interpret>(this->data));
#else
return internal::cast_from_f8<half_t, wm, we, false>(
this->data); // XXX: clip==false must be consistent with operator float
#endif #endif
} }
}; // namespace ck }; // namespace ck
...@@ -752,12 +802,18 @@ inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, float>(float x) ...@@ -752,12 +802,18 @@ inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, float>(float x)
return f8_ocp_t{ return f8_ocp_t{
internal::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)}; internal::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
} }
// convert half_t to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, half_t>(half_t x)
{
return f8_ocp_t{
internal::cvt_half_t_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
}
// Declare a template function for fp8 conversion using RNE // Declare a template function for fp8 conversion using RNE
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_sr(X x); __host__ __device__ constexpr Y f8_convert_sr(X x);
// convert fp32 to fp8 with rounding to nearest even // convert fp32 to fp8 with stochastic rounding
template <> template <>
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, float>(float x) inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, float>(float x)
{ {
...@@ -765,6 +821,14 @@ inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, float>(float x) ...@@ -765,6 +821,14 @@ inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, float>(float x)
internal::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation, true>( internal::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation, true>(
x)}; x)};
} }
// convert half_t to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, half_t>(half_t x)
{
return f8_ocp_t{internal::cvt_half_t_to_fp8<f8_ocp_t::default_interpret,
f8_ocp_t::default_saturation,
true>(x)};
}
#if CK_USE_OCP_FP8 #if CK_USE_OCP_FP8
using f8_t = f8_ocp_t; using f8_t = f8_ocp_t;
......
...@@ -130,6 +130,122 @@ TEST(FP8OCP, ConvertFP32Stochastic) ...@@ -130,6 +130,122 @@ TEST(FP8OCP, ConvertFP32Stochastic)
ASSERT_TRUE((f8_nan.data & 0x7f) == 0x7f); ASSERT_TRUE((f8_nan.data & 0x7f) == 0x7f);
} }
TEST(FP8OCP, ConvertFP16Nearest) {} TEST(FP8OCP, ConvertFP16Nearest)
{
// fix the tolerance value
constexpr half_t half_t_tol = 1e-3;
constexpr half_t half_t_zero = 0.0;
// convert 0 half_t to fp8 and back, check if holds
ASSERT_NEAR(
half_t_zero, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(half_t_zero)), half_t_zero);
// convert minimal half_t to fp8 and back, check if holds
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_rne<f8_ocp_t>(ck::NumericLimits<half_t>::Min())),
half_t_tol);
const auto max_f8_t_half_t = type_convert<half_t>(ck::NumericLimits<f8_ocp_t>::Max());
// convert maximal f8_ocp_t to half_t and check if equal to fp8 max
ASSERT_NEAR(max_f8_t_half_t,
type_convert<half_t>(f8_convert_rne<f8_ocp_t>(max_f8_t_half_t)),
half_t_zero);
// convert maximal half_t to fp8 and back, check if clipped to fp8 max (saturation to finite)
ASSERT_NEAR(max_f8_t_half_t,
type_convert<half_t>(f8_convert_rne<f8_ocp_t>(ck::NumericLimits<half_t>::Max())),
half_t_zero);
// convert half_t infinity to f8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ(
ck::NumericLimits<f8_ocp_t>::Max(),
f8_convert_rne<f8_ocp_t>(type_convert<half_t>(std::numeric_limits<float>::infinity())));
// positive norm half_t value to fp8 and back, check if holds
half_t pos_half_t{0.017578125f};
ASSERT_NEAR(pos_half_t, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(pos_half_t)), half_t_tol);
// smallest normal fp8 value to fp8 and back, check if holds
half_t neg_half_t{-0.015625f}; //-2^-6
ASSERT_NEAR(
neg_half_t, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(neg_half_t)), half_t_zero);
// positive subnorm half_t value to fp8 and back, check if holds
pos_half_t = half_t{0.00390625f};
ASSERT_NEAR(pos_half_t, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(pos_half_t)), half_t_tol);
// min subnorm fp8 value to fp8 and back, check if holds
neg_half_t = half_t{-0.001953125f}; //-2^-9
ASSERT_NEAR(
neg_half_t, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(neg_half_t)), half_t_zero);
// smaller than min subnorm fp8 value to fp8 must be zero
auto less_than_min_subnorm = half_t{0.0009765625f}; // 2^-10
ASSERT_EQ(half_t_zero, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(less_than_min_subnorm)));
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
auto f8_nan = f8_convert_rne<f8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
ASSERT_TRUE(ck::internal::ocp_f8_is_nan(f8_nan.data));
}
TEST(FP8OCP, ConvertFP16Stochastic) {} TEST(FP8OCP, ConvertFP16Stochastic)
{
// fix the tolerance value
constexpr half_t half_t_tol = 1e-3;
constexpr half_t half_t_zero = 0.0;
constexpr auto min_subnorm_fp8 = 0.001953125f; // 2^-9
// convert 0 half_t to fp8 and back, check if holds
ASSERT_NEAR(
half_t_zero, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(half_t_zero)), half_t_zero);
// convert minimal half_t (6.103515625e-05) to fp8 and back
// alternates between 0 and 2^-9 (0.001953125)
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_sr<f8_ocp_t>(ck::NumericLimits<half_t>::Min())),
type_convert<half_t>(min_subnorm_fp8));
const auto max_f8_t_half_t = type_convert<half_t>(ck::NumericLimits<f8_ocp_t>::Max());
// convert maximal f8_ocp_t to half_t and check if equal to fp8 max
ASSERT_NEAR(max_f8_t_half_t,
type_convert<half_t>(f8_convert_sr<f8_ocp_t>(max_f8_t_half_t)),
half_t_zero);
// convert maximal half_t to fp8 and back, check if clipped to fp8 max (saturation to finite)
ASSERT_NEAR(max_f8_t_half_t,
type_convert<half_t>(f8_convert_sr<f8_ocp_t>(ck::NumericLimits<half_t>::Max())),
half_t_zero);
// convert half_t infinity to f8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ(
ck::NumericLimits<f8_ocp_t>::Max(),
f8_convert_sr<f8_ocp_t>(type_convert<half_t>(std::numeric_limits<float>::infinity())));
// positive norm half_t value to fp8 and back, check if holds
half_t pos_half_t{0.017578125f};
ASSERT_NEAR(pos_half_t, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(pos_half_t)), half_t_tol);
// smallest normal fp8 value to fp8 and back, check if holds
half_t neg_half_t{-0.015625f}; //-2^-6
ASSERT_NEAR(neg_half_t, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(neg_half_t)), half_t_zero);
// positive subnorm half_t value to fp8 and back, check if holds
pos_half_t = half_t{0.00390625f};
ASSERT_NEAR(pos_half_t, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(pos_half_t)), half_t_tol);
// min subnorm fp8 value to fp8 and back, check if holds
neg_half_t = half_t{-min_subnorm_fp8}; //-2^-9
ASSERT_NEAR(neg_half_t, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(neg_half_t)), half_t_zero);
// smaller than min subnorm fp8 value to fp8 alternates between 0 and 2^-9
auto less_than_min_subnorm = half_t{0.0009765625f}; // 2^-10
ASSERT_NEAR(
type_convert<float>(half_t_zero),
type_convert<float>(type_convert<half_t>(f8_convert_sr<f8_ocp_t>(less_than_min_subnorm))),
min_subnorm_fp8);
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
auto f8_nan = f8_convert_sr<f8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
ASSERT_TRUE(ck::internal::ocp_f8_is_nan(f8_nan.data));
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment