"...composable_kernel.git" did not exist on "2ffa27083848cfb3d1a139c794296542c54f2fed"
Commit a997a51d authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Refactor f8-related type_converts

parent 62f6a1ce
......@@ -106,6 +106,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
union
{
float fval;
......@@ -186,7 +188,6 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
......@@ -194,13 +195,14 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
#endif
}
// convert fp32 to fp8
// Declare a template function for fp8 conversion using RNE
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_rne(X x);
// convert fp32 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
{
#if defined CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_t>(x);
#else // defined CK_USE_SR_F8_CONVERSION
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
......@@ -224,7 +226,80 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng);
#endif
#endif // defined CK_USE_SR_F8_CONVERSION
}
// convert fp16 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return f8_convert_rne<f8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp32 to bf8 with rounding to nearest even
template <>
inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
return val.i8val[0];
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp16 to bf8 with rounding to nearest even
template <>
inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return f8_convert_rne<bf8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp32 to fp8
template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
{
#if defined CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_t>(x);
#else
return f8_convert_rne<f8_t>(x);
#endif
}
// convert fp8 to fp32
......@@ -279,20 +354,9 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
{
#if defined CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_t>(x);
#else // defined CK_USE_SR_F8_CONVERSION
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return type_convert<f8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return f8_convert_nre<f8_t>(x);
#endif
#endif // defined CK_USE_SR_F8_CONVERSION
}
// convert fp8 to fp16
......@@ -314,29 +378,9 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
{
#if defined CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_t>(x);
#else // defined CK_USE_SR_F8_CONVERSION
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
return val.i8val[0];
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return f8_convert_rne<bf8_t>(x);
#endif
#endif // defined CK_USE_SR_F8_CONVERSION
}
// convert bf8 to fp32
......@@ -361,20 +405,9 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
{
#if defined CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_t>(x);
#else // defined CK_USE_SR_F8_CONVERSION
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return type_convert<bf8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return f8_convert_rne<bf8_t>(x);
#endif
#endif // defined CK_USE_SR_F8_CONVERSION
}
// convert bf8 to fp16
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment