Commit 3fa15bcb authored by Andriy Roshchenko
Browse files

Use `enum class` instead of `enum`

parent 97bad9f9
...@@ -32,9 +32,9 @@ using bf8_fnuz_t = unsigned _BitInt(8); ...@@ -32,9 +32,9 @@ using bf8_fnuz_t = unsigned _BitInt(8);
#endif #endif
#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__ #if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__
#define CK_OFP8_CVT_FAST_PATH 1 #define CK_OCP_FP8_CVT_FAST_PATH 1
#else #else
#define CK_OFP8_CVT_FAST_PATH 0 #define CK_OCP_FP8_CVT_FAST_PATH 0
#endif #endif
typedef unsigned char fp8_storage_t; typedef unsigned char fp8_storage_t;
...@@ -42,7 +42,7 @@ typedef unsigned char fp8_storage_t; ...@@ -42,7 +42,7 @@ typedef unsigned char fp8_storage_t;
/** /**
* \brief Describes FP8 interpretation * \brief Describes FP8 interpretation
*/ */
enum ck_fp8_interpretation_t enum class ck_fp8_interpretation_t
{ {
CK_E4M3_OCP = 0, // OCP E4M3 CK_E4M3_OCP = 0, // OCP E4M3
CK_E5M2_OCP = 1, // OCP E5M2 CK_E5M2_OCP = 1, // OCP E5M2
...@@ -53,7 +53,7 @@ enum ck_fp8_interpretation_t ...@@ -53,7 +53,7 @@ enum ck_fp8_interpretation_t
/** /**
* \brief Describes saturation behavior * \brief Describes saturation behavior
*/ */
enum ck_saturation_t enum class ck_saturation_t
{ {
CK_NOSAT = 0, // No saturation - replace with NaN or Inf CK_NOSAT = 0, // No saturation - replace with NaN or Inf
CK_SATFINITE = 1, // Saturate to finite CK_SATFINITE = 1, // Saturate to finite
...@@ -250,11 +250,14 @@ static __device__ float cast_to_f32_from_f8(fp8_storage_t v) ...@@ -250,11 +250,14 @@ static __device__ float cast_to_f32_from_f8(fp8_storage_t v)
} val; } val;
val.i8val[0] = v; val.i8val[0] = v;
static_assert(interpret == CK_E4M3_FNUZ || interpret == CK_E4M3_OCP || static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ ||
interpret == CK_E5M2_FNUZ || interpret == CK_E5M2_OCP, interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
interpret == ck_fp8_interpretation_t::CK_E5M2_FNUZ ||
interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
"Only FNUZ and OCP interpretations are supported"); "Only FNUZ and OCP interpretations are supported");
if constexpr((interpret == CK_E4M3_FNUZ) || (interpret == CK_E4M3_OCP)) if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP))
{ {
return __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0); return __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0);
} }
...@@ -269,11 +272,14 @@ static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v) ...@@ -269,11 +272,14 @@ static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
{ {
const auto i16val = bit_cast<uint16_t>(v); const auto i16val = bit_cast<uint16_t>(v);
static_assert(interpret == CK_E4M3_FNUZ || interpret == CK_E4M3_OCP || static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ ||
interpret == CK_E5M2_FNUZ || interpret == CK_E5M2_OCP, interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
interpret == ck_fp8_interpretation_t::CK_E5M2_FNUZ ||
interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
"Only FNUZ and OCP interpretations are supported"); "Only FNUZ and OCP interpretations are supported");
if constexpr((interpret == CK_E4M3_FNUZ) || (interpret == CK_E4M3_OCP)) if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP))
{ {
return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, false); return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, false);
} }
...@@ -295,8 +301,9 @@ struct f8_ocp_t ...@@ -295,8 +301,9 @@ struct f8_ocp_t
using data_type = fp8_storage_t; using data_type = fp8_storage_t;
data_type data; data_type data;
static constexpr ck_saturation_t default_saturation = CK_SATFINITE; static constexpr ck_saturation_t default_saturation = ck_saturation_t::CK_SATFINITE;
static constexpr ck_fp8_interpretation_t default_interpret = CK_E4M3_OCP; static constexpr ck_fp8_interpretation_t default_interpret =
ck_fp8_interpretation_t::CK_E4M3_OCP;
static constexpr unsigned int we = 4; // exponent width static constexpr unsigned int we = 4; // exponent width
static constexpr unsigned int wm = 3; // mantissa width static constexpr unsigned int wm = 3; // mantissa width
...@@ -312,7 +319,7 @@ struct f8_ocp_t ...@@ -312,7 +319,7 @@ struct f8_ocp_t
__host__ explicit operator float() const __host__ explicit operator float() const
#endif #endif
{ {
#if CK_OFP8_CVT_FAST_PATH #if CK_OCP_FP8_CVT_FAST_PATH
return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data); return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
#else #else
return fp8_impl::cast_from_f8<float, wm, we, false>( return fp8_impl::cast_from_f8<float, wm, we, false>(
...@@ -326,7 +333,7 @@ struct f8_ocp_t ...@@ -326,7 +333,7 @@ struct f8_ocp_t
__host__ explicit operator _Float16() const __host__ explicit operator _Float16() const
#endif #endif
{ {
#if CK_OFP8_CVT_FAST_PATH #if CK_OCP_FP8_CVT_FAST_PATH
return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data)); return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
#else #else
return fp8_impl::cast_from_f8<_Float16, wm, we, false>( return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
...@@ -340,8 +347,9 @@ struct bf8_ocp_t ...@@ -340,8 +347,9 @@ struct bf8_ocp_t
using data_type = fp8_storage_t; using data_type = fp8_storage_t;
data_type data; data_type data;
static constexpr ck_saturation_t default_saturation = CK_SATFINITE; static constexpr ck_saturation_t default_saturation = ck_saturation_t::CK_SATFINITE;
static constexpr ck_fp8_interpretation_t default_interpret = CK_E5M2_OCP; static constexpr ck_fp8_interpretation_t default_interpret =
ck_fp8_interpretation_t::CK_E5M2_OCP;
static constexpr unsigned int we = 5; // exponent width static constexpr unsigned int we = 5; // exponent width
static constexpr unsigned int wm = 2; // mantissa width static constexpr unsigned int wm = 2; // mantissa width
...@@ -442,7 +450,7 @@ struct non_native_vector_base<f8_ocp_t, 2> ...@@ -442,7 +450,7 @@ struct non_native_vector_base<f8_ocp_t, 2>
__host__ explicit operator float2_t() const __host__ explicit operator float2_t() const
#endif #endif
{ {
#if CK_OFP8_CVT_FAST_PATH #if CK_OCP_FP8_CVT_FAST_PATH
return fp8_impl::cast_to_f32x2_from_f8x2<f8_ocp_t::default_interpret>(d); return fp8_impl::cast_to_f32x2_from_f8x2<f8_ocp_t::default_interpret>(d);
#else #else
return float2_t{fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(d[0]), return float2_t{fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(d[0]),
...@@ -529,14 +537,16 @@ namespace fp8_impl { ...@@ -529,14 +537,16 @@ namespace fp8_impl {
// Assertions to check for supported conversion types // Assertions to check for supported conversion types
#define __assert_ocp_support(interp) \ #define __assert_ocp_support(interp) \
{ \ { \
if(interp != CK_E4M3_OCP && interp != CK_E5M2_OCP) \ if(interp != ck_fp8_interpretation_t::CK_E4M3_OCP && \
interp != ck_fp8_interpretation_t::CK_E5M2_OCP) \
{ \ { \
__hip_assert(false && "type is unsupported by current target device"); \ __hip_assert(false && "type is unsupported by current target device"); \
} \ } \
} }
#define __assert_fnuz_support(interp) \ #define __assert_fnuz_support(interp) \
{ \ { \
if(interp != CK_E4M3_FNUZ && interp != CK_E5M2_FNUZ) \ if(interp != ck_fp8_interpretation_t::CK_E4M3_FNUZ && \
interp != ck_fp8_interpretation_t::CK_E5M2_FNUZ) \
{ \ { \
__hip_assert(false && "type is unsupported by current target device"); \ __hip_assert(false && "type is unsupported by current target device"); \
} \ } \
...@@ -574,14 +584,14 @@ static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng = ...@@ -574,14 +584,14 @@ static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng =
if constexpr(saturate) if constexpr(saturate)
{ {
if constexpr(interpret == CK_E4M3_FNUZ) if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
{ {
if((val.i32val & 0x7F800000) != 0x7F800000) if((val.i32val & 0x7F800000) != 0x7F800000)
{ /// propagate NAN/INF, no clipping { /// propagate NAN/INF, no clipping
val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
} }
} }
else if constexpr(interpret == CK_E4M3_OCP) else if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
{ // OCP type { // OCP type
if((val.i32val & 0x7F800000) != 0x7F800000) if((val.i32val & 0x7F800000) != 0x7F800000)
{ /// propagate NAN/INF, no clipping { /// propagate NAN/INF, no clipping
...@@ -599,7 +609,8 @@ static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng = ...@@ -599,7 +609,8 @@ static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng =
if constexpr(stochastic_rounding) if constexpr(stochastic_rounding)
{ {
ival = (interpret == CK_E4M3_FNUZ) || (interpret == CK_E4M3_OCP) ival = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0) ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
: __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival; val.i32val = ival;
...@@ -607,7 +618,8 @@ static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng = ...@@ -607,7 +618,8 @@ static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng =
} }
else else
{ // RNE CVT { // RNE CVT
ival = (interpret == CK_E4M3_FNUZ) || (interpret == CK_E4M3_OCP) ival = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false) ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false)
: __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval,
val.fval, val.fval,
...@@ -897,7 +909,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn ...@@ -897,7 +909,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
* \return fp8_storage_t * \return fp8_storage_t
*/ */
template <ck_fp8_interpretation_t interp, template <ck_fp8_interpretation_t interp,
ck_saturation_t sat = CK_SATFINITE, ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
bool stochastic_rounding = false> bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH #if CK_FP8_CVT_FAST_PATH
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f) __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
...@@ -909,7 +921,8 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f) ...@@ -909,7 +921,8 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
constexpr int seed = 1254739; constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f); rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
} }
return cast_to_f8_from_f32<interp, sat == CK_SATFINITE, stochastic_rounding>(f, rng); return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
f, rng);
#else #else
#if CK_USE_OCP_FP8 #if CK_USE_OCP_FP8
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f) __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
...@@ -925,21 +938,41 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f) ...@@ -925,21 +938,41 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f); rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
} }
if constexpr(interp == CK_E4M3_FNUZ) if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
{ {
return cast_to_f8<float, 3, 4, true, sat == CK_SATFINITE, stochastic_rounding>(f, rng); return cast_to_f8<float,
3,
4,
true,
sat == ck_saturation_t::CK_SATFINITE,
stochastic_rounding>(f, rng);
} }
else if constexpr(interp == CK_E5M2_FNUZ) else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_FNUZ)
{ {
return cast_to_f8<float, 2, 5, true, sat == CK_SATFINITE, stochastic_rounding>(f, rng); return cast_to_f8<float,
2,
5,
true,
sat == ck_saturation_t::CK_SATFINITE,
stochastic_rounding>(f, rng);
} }
else if constexpr(interp == CK_E4M3_OCP) else if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
{ {
return cast_to_f8<float, 3, 4, false, sat == CK_SATFINITE, stochastic_rounding>(f, rng); return cast_to_f8<float,
3,
4,
false,
sat == ck_saturation_t::CK_SATFINITE,
stochastic_rounding>(f, rng);
} }
else if constexpr(interp == CK_E5M2_OCP) else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
{ {
return cast_to_f8<float, 2, 5, false, sat == CK_SATFINITE, stochastic_rounding>(f, rng); return cast_to_f8<float,
2,
5,
false,
sat == ck_saturation_t::CK_SATFINITE,
stochastic_rounding>(f, rng);
} }
else else
{ {
...@@ -959,7 +992,7 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f) ...@@ -959,7 +992,7 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
* \return fp8_storage_t * \return fp8_storage_t
*/ */
template <ck_fp8_interpretation_t interp, template <ck_fp8_interpretation_t interp,
ck_saturation_t sat = CK_SATFINITE, ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
bool stochastic_rounding = false> bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8 #if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
__host__ __device__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x) __host__ __device__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment