"sgl-kernel/csrc/vscode:/vscode.git/clone" did not exist on "19995dd78efd62e27a300911277da36b60b92cee"
Commit 2bd1b9cf authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Implement FP8OCP tests for half_t type conversions.

parent c76b765a
......@@ -359,10 +359,9 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
// The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220
// This has been modified to handle double types as well
template <typename T, bool is_fnuz>
__host__ __device__ static inline T cast_from_f8(fp8_storage_t x, int wm, int we, bool clip = false)
template <typename T, int wm, int we, bool is_fnuz, bool clip = false>
__host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
{
// TODO: synchronize with f8_utils.hpp implementation for FNUZ
constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
constexpr bool is_float = __hip_internal::is_same<T, float>::value;
constexpr bool is_double = __hip_internal::is_same<T, double>::value;
......@@ -514,7 +513,8 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x, int wm, int we
}
#if CK_FP8_CVT_FAST_PATH
static __device__ float cast_to_f32_from_f8(fp8_storage_t v, uint32_t interpret)
template <ck_fp8_interpretation_t interpret>
static __device__ float cast_to_f32_from_f8(fp8_storage_t v)
{
union
{
......@@ -523,10 +523,18 @@ static __device__ float cast_to_f32_from_f8(fp8_storage_t v, uint32_t interpret)
} val;
val.i8val[0] = v;
float fval = (interpret == internal::CK_E4M3_FNUZ) || (interpret == internal::CK_E4M3_OCP)
? __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0)
: __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
return fval;
static_assert(interpret == CK_E4M3_FNUZ || interpret == CK_E4M3_OCP ||
interpret == CK_E5M2_FNUZ || interpret == CK_E5M2_OCP,
"Only FNUZ and OCP interpretations are supported");
if constexpr((interpret == internal::CK_E4M3_FNUZ) || (interpret == internal::CK_E4M3_OCP))
{
return __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0);
}
else
{
return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
}
}
// The conversion function is from rocblas
......@@ -659,6 +667,32 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
#endif // CK_FP8_CVT_FAST_PATH
}
/**
* \brief convert half_t to @p fp8_storage_t
*
* \tparam sat saturation of fp8
* \tparam interp interpretation of fp8
* \tparam stochastic_rounding switch between RNE and SR
* \param x half_t value
* \return fp8_storage_t
*/
template <ck_fp8_interpretation_t interp,
ck_saturation_t sat = CK_SATFINITE,
bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH
__host__ __device__ static inline fp8_storage_t cvt_half_t_to_fp8(const half_t x)
{
internal::__is_interpret_supported(interp);
#elif CK_USE_OCP_FP8
__host__ __device__ static inline fp8_storage_t cvt_half_t_to_fp8(const half_t x)
{
#else
__host__ static inline fp8_storage_t cvt_half_t_to_fp8(const half_t x)
{
#endif
return cvt_float_to_fp8<interp, sat, stochastic_rounding>(static_cast<float>(x));
}
/* For fp8 fnuz types, finite and NaN values are supported. Zero is unsigned.
Inf are not supported. This gives us one additional number to represent.
NaN are represented by 1-0000-000 or 1-00000-00 */
......@@ -706,15 +740,31 @@ struct f8_ocp_t
}
#if CK_USE_OCP_FP8
__host__ __device__ explicit operator float() const {
__host__ __device__ explicit operator float() const
{
#else
__host__ explicit operator float() const
{
#endif
#if CK_FP8_CVT_FAST_PATH
return internal::cast_to_f32_from_f8(this->data, default_interpret);
return internal::cast_to_f32_from_f8<default_interpret>(this->data);
#else
return internal::cast_from_f8<float, false>(this->data, wm, we);
return internal::cast_from_f8<float, wm, we, false>(
this->data); // XXX: clip==false must be consistent with operator half_t
#endif
}
#if CK_USE_OCP_FP8
__host__ __device__ explicit operator half_t() const {
#else
__host__ explicit operator half_t() const
{
#endif
#if CK_FP8_CVT_FAST_PATH
return static_cast<half_t>(internal::cast_to_f32_from_f8<default_interpret>(this->data));
#else
return internal::cast_from_f8<half_t, wm, we, false>(
this->data); // XXX: clip==false must be consistent with operator float
#endif
}
}; // namespace ck
......@@ -752,12 +802,18 @@ inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, float>(float x)
return f8_ocp_t{
internal::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
}
// convert half_t to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, half_t>(half_t x)
{
return f8_ocp_t{
internal::cvt_half_t_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
}
// Declare a template function for fp8 conversion using RNE
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_sr(X x);
// convert fp32 to fp8 with rounding to nearest even
// convert fp32 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, float>(float x)
{
......@@ -765,6 +821,14 @@ inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, float>(float x)
internal::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation, true>(
x)};
}
// convert half_t to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, half_t>(half_t x)
{
return f8_ocp_t{internal::cvt_half_t_to_fp8<f8_ocp_t::default_interpret,
f8_ocp_t::default_saturation,
true>(x)};
}
#if CK_USE_OCP_FP8
using f8_t = f8_ocp_t;
......
......@@ -130,6 +130,122 @@ TEST(FP8OCP, ConvertFP32Stochastic)
ASSERT_TRUE((f8_nan.data & 0x7f) == 0x7f);
}
TEST(FP8OCP, ConvertFP16Nearest) {}
TEST(FP8OCP, ConvertFP16Nearest)
{
// fix the tolerance value
constexpr half_t half_t_tol = 1e-3;
constexpr half_t half_t_zero = 0.0;
// convert 0 half_t to fp8 and back, check if holds
ASSERT_NEAR(
half_t_zero, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(half_t_zero)), half_t_zero);
// convert minimal half_t to fp8 and back, check if holds
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_rne<f8_ocp_t>(ck::NumericLimits<half_t>::Min())),
half_t_tol);
const auto max_f8_t_half_t = type_convert<half_t>(ck::NumericLimits<f8_ocp_t>::Max());
// convert maximal f8_ocp_t to half_t and check if equal to fp8 max
ASSERT_NEAR(max_f8_t_half_t,
type_convert<half_t>(f8_convert_rne<f8_ocp_t>(max_f8_t_half_t)),
half_t_zero);
// convert maximal half_t to fp8 and back, check if clipped to fp8 max (saturation to finite)
ASSERT_NEAR(max_f8_t_half_t,
type_convert<half_t>(f8_convert_rne<f8_ocp_t>(ck::NumericLimits<half_t>::Max())),
half_t_zero);
// convert half_t infinity to f8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ(
ck::NumericLimits<f8_ocp_t>::Max(),
f8_convert_rne<f8_ocp_t>(type_convert<half_t>(std::numeric_limits<float>::infinity())));
// positive norm half_t value to fp8 and back, check if holds
half_t pos_half_t{0.017578125f};
ASSERT_NEAR(pos_half_t, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(pos_half_t)), half_t_tol);
// smallest normal fp8 value to fp8 and back, check if holds
half_t neg_half_t{-0.015625f}; //-2^-6
ASSERT_NEAR(
neg_half_t, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(neg_half_t)), half_t_zero);
// positive subnorm half_t value to fp8 and back, check if holds
pos_half_t = half_t{0.00390625f};
ASSERT_NEAR(pos_half_t, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(pos_half_t)), half_t_tol);
// min subnorm fp8 value to fp8 and back, check if holds
neg_half_t = half_t{-0.001953125f}; //-2^-9
ASSERT_NEAR(
neg_half_t, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(neg_half_t)), half_t_zero);
// smaller than min subnorm fp8 value to fp8 must be zero
auto less_than_min_subnorm = half_t{0.0009765625f}; // 2^-10
ASSERT_EQ(half_t_zero, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(less_than_min_subnorm)));
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
auto f8_nan = f8_convert_rne<f8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
ASSERT_TRUE(ck::internal::ocp_f8_is_nan(f8_nan.data));
}
TEST(FP8OCP, ConvertFP16Stochastic) {}
TEST(FP8OCP, ConvertFP16Stochastic)
{
// fix the tolerance value
constexpr half_t half_t_tol = 1e-3;
constexpr half_t half_t_zero = 0.0;
constexpr auto min_subnorm_fp8 = 0.001953125f; // 2^-9
// convert 0 half_t to fp8 and back, check if holds
ASSERT_NEAR(
half_t_zero, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(half_t_zero)), half_t_zero);
// convert minimal half_t (6.103515625e-05) to fp8 and back
// alternates between 0 and 2^-9 (0.001953125)
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_sr<f8_ocp_t>(ck::NumericLimits<half_t>::Min())),
type_convert<half_t>(min_subnorm_fp8));
const auto max_f8_t_half_t = type_convert<half_t>(ck::NumericLimits<f8_ocp_t>::Max());
// convert maximal f8_ocp_t to half_t and check if equal to fp8 max
ASSERT_NEAR(max_f8_t_half_t,
type_convert<half_t>(f8_convert_sr<f8_ocp_t>(max_f8_t_half_t)),
half_t_zero);
// convert maximal half_t to fp8 and back, check if clipped to fp8 max (saturation to finite)
ASSERT_NEAR(max_f8_t_half_t,
type_convert<half_t>(f8_convert_sr<f8_ocp_t>(ck::NumericLimits<half_t>::Max())),
half_t_zero);
// convert half_t infinity to f8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ(
ck::NumericLimits<f8_ocp_t>::Max(),
f8_convert_sr<f8_ocp_t>(type_convert<half_t>(std::numeric_limits<float>::infinity())));
// positive norm half_t value to fp8 and back, check if holds
half_t pos_half_t{0.017578125f};
ASSERT_NEAR(pos_half_t, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(pos_half_t)), half_t_tol);
// smallest normal fp8 value to fp8 and back, check if holds
half_t neg_half_t{-0.015625f}; //-2^-6
ASSERT_NEAR(neg_half_t, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(neg_half_t)), half_t_zero);
// positive subnorm half_t value to fp8 and back, check if holds
pos_half_t = half_t{0.00390625f};
ASSERT_NEAR(pos_half_t, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(pos_half_t)), half_t_tol);
// min subnorm fp8 value to fp8 and back, check if holds
neg_half_t = half_t{-min_subnorm_fp8}; //-2^-9
ASSERT_NEAR(neg_half_t, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(neg_half_t)), half_t_zero);
// smaller than min subnorm fp8 value to fp8 alternates between 0 and 2^-9
auto less_than_min_subnorm = half_t{0.0009765625f}; // 2^-10
ASSERT_NEAR(
type_convert<float>(half_t_zero),
type_convert<float>(type_convert<half_t>(f8_convert_sr<f8_ocp_t>(less_than_min_subnorm))),
min_subnorm_fp8);
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
auto f8_nan = f8_convert_sr<f8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
ASSERT_TRUE(ck::internal::ocp_f8_is_nan(f8_nan.data));
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment