Unverified commit c4a05057, authored by Andriy Roshchenko, committed by GitHub

[MX FP8] Add Scaled Type Convert Functions for OCP FP8/BF8 data types (#271)

* Move scaled_type_convert functions to a separate header

* Introduce MX data tests

* Build MX tests only on relevant architectures

* Refactor E8M0 scale implementation

* Fix `config.h` typo

* Cleanup deprecated symbols

* Refactor `amd_ck_fp8.hpp`

* `scaled_type_convert` for `f8_ocp_t`

* Implement test for MX FP8 scaled type convert

* Implement test for MX BF8 scaled type convert

* Scaled type convert for vectors of 2 FP8 elements

* Scaled type convert for vectors of 16 FP8 elements

* Implementation of scaled conversion from F32 to F8

* Add tests for scaled conversions from FP32 to FP8

* Add documentation to the test functions

* Implementation of scaled conversion from F32x2 to F8x2

* Implementation of scaled conversion from F32x16 to F8x16

* Implementation of scaled conversion from F32x32 to F8x32

* Implementation of scaled conversion from F8x32 to F32x32

* Verified on the emulator
parent 23e2309d
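
A minimal usage sketch of the API this commit adds, with names as they appear in the diffs below (the demo kernel itself is hypothetical):

#include "ck/utility/scaled_type_convert.hpp"

__device__ void demo(float* out)
{
    ck::e8m0_bexp_t scale(8.0f); // E8M0: a biased power-of-two exponent
    // Decode: V = scale * P, where P is an OCP E4M3 FP8 value.
    ck::f8_ocp_t p = ck::type_convert<ck::f8_ocp_t>(0.5f);
    out[0] = ck::scaled_type_convert<float>(scale, p); // 8 * 0.5 = 4.0f
    // Encode: x / scale, rounded to FP8 (RNE, or SR when
    // CK_USE_SR_F8_CONVERSION is set).
    ck::f8_ocp_t q = ck::scaled_type_convert<ck::f8_ocp_t>(scale, 4.0f); // 0.5
    out[1] = ck::type_convert<float>(q);
}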
@@ -18,25 +18,6 @@
#define CK_USE_OCP_FP8 0
#endif
namespace {
// https://en.cppreference.com/w/cpp/types/conditional
template <bool B, class T, class F>
struct conditional
{
using type = T;
};
template <class T, class F>
struct conditional<false, T, F>
{
using type = F;
};
} // namespace
namespace ck {
using f8_fnuz_t = _BitInt(8);
using bf8_fnuz_t = unsigned _BitInt(8);
#if(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx1200__) || \
defined(__gfx1201__) || defined(__gfx950__)) && \
__HIP_DEVICE_COMPILE__
@@ -51,6 +32,11 @@ using bf8_fnuz_t = unsigned _BitInt(8);
#define CK_OCP_FP8_CVT_FAST_PATH 0
#endif
namespace ck {
using f8_fnuz_t = _BitInt(8);
using bf8_fnuz_t = unsigned _BitInt(8);
typedef unsigned char fp8_storage_t;
/**
@@ -205,10 +191,11 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
}
}
typename std::conditional<
sizeof(T) == 2,
unsigned short int,
typename std::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type
retval;
if constexpr(we == 5 && is_half && !is_fnuz)
{
@@ -301,7 +288,6 @@ static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
return __builtin_amdgcn_cvt_pk_f32_bf8(i16val, false);
}
}
#endif
} // namespace fp8_impl
@@ -551,10 +537,10 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);
using T_bitwise = typename std::conditional<
sizeof(T) == 2,
unsigned short int,
typename std::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type;
T_bitwise x_bitwise = bit_cast<T_bitwise>(_x);
unsigned long long x{x_bitwise};
...
@@ -11,6 +11,12 @@ namespace ck {
* @brief Unsigned representation of a conventional biased Float32 exponent.
*
* bias = 127;
*
* E8M0_1   = 0b01111111; => 2^(127-127) = 1
* E8M0_2   = 0b10000000; => 2^(128-127) = 2^1 = 2
* E8M0_3   = 0b10000010; => 2^(130-127) = 2^3 = 8
* E8M0_135 = 0b10000111; => 2^(135-127) = 2^8 = 256
* E8M0_142 = 0b10001110; => 2^(142-127) = 2^15 = 32768
* E8M0_MIN = 0b00000000; => 2^-127
* E8M0_MAX = 0b11111110; => 2^127
* E8M0_NAN = 0b11111111; => NaN
...
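
The examples above follow mechanically from the bias-127 rule; a small host-side sketch of the decode (the helper name is hypothetical, not part of this commit):

#include <cmath>
#include <cstdint>
#include <limits>

// Decode an E8M0 biased exponent: 0xFF encodes NaN, anything else is 2^(bits - 127).
inline float e8m0_to_float(std::uint8_t bits)
{
    if(bits == 0xFF)
        return std::numeric_limits<float>::quiet_NaN();
    return std::exp2f(static_cast<float>(static_cast<int>(bits) - 127));
}
// e8m0_to_float(0b10000111) == 256.0f, matching the E8M0_135 example above.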
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__
#define CK_MX_FP8_CVT_FAST_PATH 1
#else
#define CK_MX_FP8_CVT_FAST_PATH 0
#endif
namespace ck {
namespace fp8_impl {
#if CK_MX_FP8_CVT_FAST_PATH
template <ck_fp8_interpretation_t interpret>
static __device__ float cast_to_f32_from_f8_scaled(float scale, fp8_storage_t v)
{
union
{
unsigned int i32val;
unsigned char i8val[4];
} val;
val.i8val[0] = v;
static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
"Only OCP interpretations are supported");
if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
return __builtin_amdgcn_cvt_scalef32_f32_fp8(val.i32val, scale, 0);
}
else
{
return __builtin_amdgcn_cvt_scalef32_f32_bf8(val.i32val, scale, 0);
}
}
template <ck_fp8_interpretation_t interpret>
static __device__ float2_t cast_to_f32x2_from_f8x2_scaled(float scale, fp8x2_storage_t v)
{
const auto i16val = bit_cast<uint16_t>(v);
static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
"Only OCP interpretations are supported");
if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp8(i16val, scale, 0);
}
else
{
return __builtin_amdgcn_cvt_scalef32_pk_f32_bf8(i16val, scale, 0);
}
}
template <ck_fp8_interpretation_t interpret, bool stochastic_rounding = false>
static __device__ fp8_storage_t cast_to_f8_from_f32_scaled(float v,
unsigned int rng = 0,
float scale = 1.0f)
{
fp8_storage_t i8data;
union
{
float fval;
unsigned int i32val;
} val;
union
{
uint32_t ival;
vector_type<int16_t, 2>::type v2i16;
fp8_storage_t v4i8[4];
} ret{};
// unsigned int ival = 0;
val.fval = v;
if constexpr(stochastic_rounding)
{
ret.ival =
(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
? __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, val.fval, rng, scale, 0)
: __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, val.fval, rng, scale, 0);
i8data = ret.v4i8[0];
}
else
{
// RNE CVT
// llvm.amdgcn.cvt.scalef32.pk.fp8.f32
// v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
// If fval / scale > max fp8, returns Nan
ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(/*old_vdst*/ ret.v2i16,
val.fval,
val.fval,
scale,
/*dst_lo_hi_sel*/ false);
}
else
{
// If fval / scale > max bf8, returns Inf
ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(/*old_vdst*/ ret.v2i16,
val.fval,
val.fval,
scale,
/*dst_lo_hi_sel*/ false);
}
i8data = ret.v4i8[0];
}
return i8data;
}
template <ck_fp8_interpretation_t interpret, bool stochastic_rounding = false>
static __device__ fp8x2_storage_t cast_to_f8_from_f32_scaled(float2_t v,
unsigned int rng = 0,
float scale = 1.0f)
{
union
{
uint32_t ival;
vector_type<int16_t, 2>::type v2i16;
StaticallyIndexedArray<fp8x2_storage_t, 2> v2f8x2;
} ret{};
if constexpr(stochastic_rounding)
{
fp8x2_storage_t f8x2;
if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[0], rng, scale, 0);
f8x2[0] = ret.v2f8x2(Number<0>{})[0];
ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[1], rng, scale, 0);
f8x2[1] = ret.v2f8x2(Number<0>{})[0];
}
else
{
ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[0], rng, scale, 0);
f8x2[0] = ret.v2f8x2(Number<0>{})[0];
ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[1], rng, scale, 0);
f8x2[1] = ret.v2f8x2(Number<0>{})[0];
}
return f8x2;
}
else
{
// RNE CVT
// llvm.amdgcn.cvt.scalef32.pk.fp8.f32
// v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
// If fval / scale > max fp8, returns Nan
ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(/*old_vdst*/ ret.v2i16,
v[0],
v[1],
scale,
/*dst_lo_hi_sel*/ false);
}
else
{
// If fval / scale > max bf8, returns Inf
ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(/*old_vdst*/ ret.v2i16,
v[0],
v[1],
scale,
/*dst_lo_hi_sel*/ false);
}
return ret.v2f8x2(Number<0>{});
}
}
#endif // CK_MX_FP8_CVT_FAST_PATH
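// Usage sketch (illustrative, gfx950 fast path only): the helpers above take
// the scale as a plain float, so an E8M0 scale must be decoded to float first.
//
//     __global__ void k(fp8_storage_t* out, const float* in, float scale)
//     {
//         out[threadIdx.x] =
//             cast_to_f8_from_f32_scaled<ck_fp8_interpretation_t::CK_E4M3_OCP>(
//                 in[threadIdx.x], /*rng=*/0, scale); // RNE; set the
//                 // stochastic_rounding template flag and pass rng for SR
//     }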
#if CK_MX_FP8_CVT_FAST_PATH
/**
* \brief convert float to @p fp8_storage_t with scaling
*
* This version is used when the fast path (MX FP8 hardware) is available
*
* \tparam interp interpretation of fp8
* \param f float number
* \param scale scaling factor
* \return fp8_storage_t
*/
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f, float scale)
{
__is_interpret_supported(interp);
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
}
return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
}
/**
* \brief convert 2xfloat to @p 2xfp8_storage_t with scaling
*
* This version is used when the fast path (MX FP8 hardware) is available
*
* \tparam interp interpretation of fp8
* \param f 2xfloat
* \param scale scaling factor
* \return 2xfp8_storage_t
*/
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f,
float scale)
{
__is_interpret_supported(interp);
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
}
return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
}
#else
/**
* \brief convert float to @p fp8_storage_t with scaling
*
* This version is used when the fast path (MX FP8 hardware) is not available
*
* \tparam interp interpretation of fp8
* \param f float number
* \param scale scaling factor
* \return fp8_storage_t
*/
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f, float scale)
{
static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
interp == ck_fp8_interpretation_t::CK_E5M2_OCP,
"Only OCP interpretations are supported");
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
}
if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
return cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f / scale, rng);
}
else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
{
return cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f / scale, rng);
}
else
{
__hip_assert(false && "FP8 type is not supported by current target device");
return 0;
}
}
/**
* \brief convert two float to @p 2xfp8_storage_t with scaling
*
* This version is used when the fast path (MX FP8 hardware) is not available
*
* \tparam interp interpretation of fp8
* \param f 2xfloat
* \param scale scaling factor
* \return 2xfp8_storage_t
*/
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f,
float scale)
{
static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
interp == ck_fp8_interpretation_t::CK_E5M2_OCP,
"Only OCP interpretations are supported");
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
}
if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
return {cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[0] / scale, rng),
cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[1] / scale, rng)};
}
else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
{
return {cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[0] / scale, rng),
cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[1] / scale, rng)};
}
else
{
__hip_assert(false && "FP8 type is not supported by current target device");
return 0;
}
}
#endif // CK_MX_FP8_CVT_FAST_PATH
} // namespace fp8_impl
// Declare a template function for fp8 conversion using SR
template <typename Y, typename X>
__host__ __device__ constexpr Y mxf8_convert_sr(X x, float scale);
// Declare a template function for fp8 conversion using RNE
template <typename Y, typename X>
__host__ __device__ constexpr Y mxf8_convert_rne(X x, float scale);
// convert fp32 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_ocp_t mxf8_convert_rne<f8_ocp_t, float>(float x, float scale)
{
return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
}
// convert fp32 to bf8 with rounding to nearest even
template <>
inline __host__ __device__ bf8_ocp_t mxf8_convert_rne<bf8_ocp_t, float>(float x, float scale)
{
return bf8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
}
// convert fp32x2 to fp8x2 with rounding to nearest even
template <>
inline __host__ __device__ f8x2_ocp_t mxf8_convert_rne<f8x2_ocp_t, float2_t>(float2_t x,
float scale)
{
return f8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
}
// convert fp32x2 to bf8x2 with rounding to nearest even
template <>
inline __host__ __device__ bf8x2_ocp_t mxf8_convert_rne<bf8x2_ocp_t, float2_t>(float2_t x,
float scale)
{
return bf8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
}
// convert fp32x16 to fp8x16 with rounding to nearest even
template <>
inline __host__ __device__ f8x16_ocp_t mxf8_convert_rne<f8x16_ocp_t, float16_t>(float16_t x,
float scale)
{
union
{
float16_t float_1x16;
float2_t float_2x8[8];
} in{x};
union
{
f8x16_ocp_t fp8_1x16;
f8x2_ocp_t fp8_2x8[8];
} out{};
ck::static_for<0, 8, 1>{}(
[&](auto i) { out.fp8_2x8[i] = mxf8_convert_rne<f8x2_ocp_t>(in.float_2x8[i], scale); });
return out.fp8_1x16;
}
// convert fp32x16 to bf8x16 with rounding to nearest even
template <>
inline __host__ __device__ bf8x16_ocp_t mxf8_convert_rne<bf8x16_ocp_t, float16_t>(float16_t x,
float scale)
{
union
{
float16_t float_1x16;
float2_t float_2x8[8];
} in{x};
union
{
bf8x16_ocp_t bf8_1x16;
bf8x2_ocp_t bf8_2x8[8];
} out{};
ck::static_for<0, 8, 1>{}(
[&](auto i) { out.bf8_2x8[i] = mxf8_convert_rne<bf8x2_ocp_t>(in.float_2x8[i], scale); });
return out.bf8_1x16;
}
// convert fp32x32 to fp8x32 with rounding to nearest even
template <>
inline __host__ __device__ f8x32_ocp_t mxf8_convert_rne<f8x32_ocp_t, float32_t>(float32_t x,
float scale)
{
union
{
float32_t float_1x32;
float16_t float_16x2[2];
} in{x};
union
{
f8x32_ocp_t fp8_1x32;
f8x16_ocp_t fp8_16x2[2];
} out{};
ck::static_for<0, 2, 1>{}(
[&](auto i) { out.fp8_16x2[i] = mxf8_convert_rne<f8x16_ocp_t>(in.float_16x2[i], scale); });
return out.fp8_1x32;
}
// convert fp32x32 to bf8x32 with rounding to nearest even
template <>
inline __host__ __device__ bf8x32_ocp_t mxf8_convert_rne<bf8x32_ocp_t, float32_t>(float32_t x,
float scale)
{
union
{
float32_t float_1x32;
float16_t float_16x2[2];
} in{x};
union
{
bf8x32_ocp_t bf8_1x32;
bf8x16_ocp_t bf8_16x2[2];
} out{};
ck::static_for<0, 2, 1>{}(
[&](auto i) { out.bf8_16x2[i] = mxf8_convert_rne<bf8x16_ocp_t>(in.float_16x2[i], scale); });
return out.bf8_1x32;
}
// convert fp32 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_ocp_t mxf8_convert_sr<f8_ocp_t, float>(float x, float scale)
{
return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
}
// convert fp32 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_ocp_t mxf8_convert_sr<bf8_ocp_t, float>(float x, float scale)
{
return bf8_ocp_t{
fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
}
// convert fp32x2 to fp8x2 with stochastic rounding
template <>
inline __host__ __device__ f8x2_ocp_t mxf8_convert_sr<f8x2_ocp_t, float2_t>(float2_t x, float scale)
{
return f8x2_ocp_t{
fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
}
// convert fp32x2 to bf8x2 with stochastic rounding
template <>
inline __host__ __device__ bf8x2_ocp_t mxf8_convert_sr<bf8x2_ocp_t, float2_t>(float2_t x,
float scale)
{
return bf8x2_ocp_t{
fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
}
// convert fp32x16 to fp8x16 with stochastic rounding
template <>
inline __host__ __device__ f8x16_ocp_t mxf8_convert_sr<f8x16_ocp_t, float16_t>(float16_t x,
float scale)
{
union
{
float16_t float_1x16;
float2_t float_2x8[8];
} in{x};
union
{
f8x16_ocp_t fp8_1x16;
f8x2_ocp_t fp8_2x8[8];
} out{};
ck::static_for<0, 8, 1>{}(
[&](auto i) { out.fp8_2x8[i] = mxf8_convert_sr<f8x2_ocp_t>(in.float_2x8[i], scale); });
return out.fp8_1x16;
}
// convert fp32x16 to bf8x16 with stochastic rounding
template <>
inline __host__ __device__ bf8x16_ocp_t mxf8_convert_sr<bf8x16_ocp_t, float16_t>(float16_t x,
float scale)
{
union
{
float16_t float_1x16;
float2_t float_2x8[8];
} in{x};
union
{
bf8x16_ocp_t bf8_1x16;
bf8x2_ocp_t bf8_2x8[8];
} out{};
ck::static_for<0, 8, 1>{}(
[&](auto i) { out.bf8_2x8[i] = mxf8_convert_sr<bf8x2_ocp_t>(in.float_2x8[i], scale); });
return out.bf8_1x16;
}
// convert fp32x32 to fp8x32 with stochastic rounding
template <>
inline __host__ __device__ f8x32_ocp_t mxf8_convert_sr<f8x32_ocp_t, float32_t>(float32_t x,
float scale)
{
union
{
float32_t float_1x32;
float16_t float_16x2[2];
} in{x};
union
{
f8x32_ocp_t fp8_1x32;
f8x16_ocp_t fp8_16x2[2];
} out{};
ck::static_for<0, 2, 1>{}(
[&](auto i) { out.fp8_16x2[i] = mxf8_convert_sr<f8x16_ocp_t>(in.float_16x2[i], scale); });
return out.fp8_1x32;
}
// convert fp32x32 to bf8x32 with stochastic rounding
template <>
inline __host__ __device__ bf8x32_ocp_t mxf8_convert_sr<bf8x32_ocp_t, float32_t>(float32_t x,
float scale)
{
union
{
float32_t float_1x32;
float16_t float_16x2[2];
} in{x};
union
{
bf8x32_ocp_t bf8_1x32;
bf8x16_ocp_t bf8_16x2[2];
} out{};
ck::static_for<0, 2, 1>{}(
[&](auto i) { out.bf8_16x2[i] = mxf8_convert_sr<bf8x16_ocp_t>(in.float_16x2[i], scale); });
return out.bf8_1x32;
}
} // namespace ck
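
A short usage sketch of the mxf8_convert_* entry points defined above; the scale is a plain float at this level, and the inputs are chosen so both roundings are exact (expected results taken from the tests later in this commit):

__device__ void mxf8_demo()
{
    ck::float2_t x{-8.0f, 4.0f};
    // Each lane is divided by the scale, then rounded to E4M3 with RNE.
    ck::f8x2_ocp_t y = ck::mxf8_convert_rne<ck::f8x2_ocp_t>(x, 2.0f); // {-4, 2}
    // Stochastic-rounding variant; exact here since -2 and 1 are representable.
    ck::f8x2_ocp_t z = ck::mxf8_convert_sr<ck::f8x2_ocp_t>(x, 4.0f); // {-2, 1}
    (void)y;
    (void)z;
}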
@@ -4,6 +4,7 @@
#pragma once
#include "ck/utility/type_convert.hpp"
#include "ck/utility/mxf8_utils.hpp"
namespace ck {
@@ -11,6 +12,194 @@ namespace ck {
template <typename Y, typename X>
__host__ __device__ constexpr Y scaled_type_convert(e8m0_bexp_t scale, X x);
// convert f8_ocp_t to fp32
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float scaled_type_convert<float, f8_ocp_t>(e8m0_bexp_t scale, f8_ocp_t x)
#else
inline __host__ float scaled_type_convert<float, f8_ocp_t>(e8m0_bexp_t scale, f8_ocp_t x)
#endif
{
#if CK_MX_FP8_CVT_FAST_PATH
return fp8_impl::cast_to_f32_from_f8_scaled<f8_ocp_t::default_interpret>(
type_convert<float>(scale), x.data);
#else
return type_convert<float>(scale) * type_convert<float>(x);
#endif
}
// convert bf8_ocp_t to fp32
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float scaled_type_convert<float, bf8_ocp_t>(e8m0_bexp_t scale,
bf8_ocp_t x)
#else
inline __host__ float scaled_type_convert<float, bf8_ocp_t>(e8m0_bexp_t scale, bf8_ocp_t x)
#endif
{
#if CK_MX_FP8_CVT_FAST_PATH
return fp8_impl::cast_to_f32_from_f8_scaled<bf8_ocp_t::default_interpret>(
type_convert<float>(scale), x.data);
#else
return type_convert<float>(scale) * type_convert<float>(x);
#endif
}
// convert 2 x f8_ocp_t to 2 x fp32
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float2_t scaled_type_convert<float2_t, f8x2_ocp_t>(e8m0_bexp_t scale,
f8x2_ocp_t x)
#else
inline __host__ float2_t scaled_type_convert<float2_t, f8x2_ocp_t>(e8m0_bexp_t scale, f8x2_ocp_t x)
#endif
{
#if CK_MX_FP8_CVT_FAST_PATH
return fp8_impl::cast_to_f32x2_from_f8x2_scaled<f8_ocp_t::default_interpret>(
type_convert<float>(scale), x.AsType<fp8_impl::fp8x2_storage_t>()[Number<0>{}]);
#else
return float2_t{scaled_type_convert<float>(scale, x.AsType<f8_ocp_t>()[Number<0>{}]),
scaled_type_convert<float>(scale, x.AsType<f8_ocp_t>()[Number<1>{}])};
#endif
}
// convert 2 x bf8_ocp_t to 2 x fp32
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float2_t scaled_type_convert<float2_t, bf8x2_ocp_t>(e8m0_bexp_t scale,
bf8x2_ocp_t x)
#else
inline __host__ float2_t scaled_type_convert<float2_t, bf8x2_ocp_t>(e8m0_bexp_t scale,
bf8x2_ocp_t x)
#endif
{
#if CK_MX_FP8_CVT_FAST_PATH
return fp8_impl::cast_to_f32x2_from_f8x2_scaled<bf8_ocp_t::default_interpret>(
type_convert<float>(scale), x.AsType<fp8_impl::fp8x2_storage_t>()[Number<0>{}]);
#else
return float2_t{scaled_type_convert<float>(scale, x.AsType<bf8_ocp_t>()[Number<0>{}]),
scaled_type_convert<float>(scale, x.AsType<bf8_ocp_t>()[Number<1>{}])};
#endif
}
// convert 16 x f8_ocp_t to 16 x fp32
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float16_t scaled_type_convert<float16_t, f8x16_ocp_t>(e8m0_bexp_t scale,
f8x16_ocp_t x)
#else
inline __host__ float16_t scaled_type_convert<float16_t, f8x16_ocp_t>(e8m0_bexp_t scale,
f8x16_ocp_t x)
#endif
{
union
{
f8x16_ocp_t f8_1x16;
f8x2_ocp_t f8_2x8[8];
} in{x};
union
{
float16_t float_1x16;
float2_t float_2x8[8];
} out{};
ck::static_for<0, 8, 1>{}([&](auto i) {
out.float_2x8[i] = scaled_type_convert<float2_t, f8x2_ocp_t>(scale, in.f8_2x8[i]);
});
return out.float_1x16;
}
// convert 16 x bf8_ocp_t to 16 x fp32
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float16_t scaled_type_convert<float16_t, bf8x16_ocp_t>(e8m0_bexp_t scale,
bf8x16_ocp_t x)
#else
inline __host__ float16_t scaled_type_convert<float16_t, bf8x16_ocp_t>(e8m0_bexp_t scale,
bf8x16_ocp_t x)
#endif
{
union
{
bf8x16_ocp_t bf8_1x16;
bf8x2_ocp_t bf8_2x8[8];
} in{x};
union
{
float16_t float_1x16;
float2_t float_2x8[8];
} out{};
ck::static_for<0, 8, 1>{}([&](auto i) {
out.float_2x8[i] = scaled_type_convert<float2_t, bf8x2_ocp_t>(scale, in.bf8_2x8[i]);
});
return out.float_1x16;
}
// convert 32 x f8_ocp_t to 32 x fp32
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float32_t scaled_type_convert<float32_t, f8x32_ocp_t>(e8m0_bexp_t scale,
f8x32_ocp_t x)
#else
inline __host__ float32_t scaled_type_convert<float32_t, f8x32_ocp_t>(e8m0_bexp_t scale,
f8x32_ocp_t x)
#endif
{
union
{
f8x32_ocp_t f8_1x32;
f8x16_ocp_t f8_16x2[2];
} in{x};
union
{
float32_t float_1x32;
float16_t float_16x2[2];
} out{};
ck::static_for<0, 2, 1>{}([&](auto i) {
out.float_16x2[i] = scaled_type_convert<float16_t, f8x16_ocp_t>(scale, in.f8_16x2[i]);
});
return out.float_1x32;
}
// convert 32 x bf8_ocp_t to 32 x fp32
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float32_t scaled_type_convert<float32_t, bf8x32_ocp_t>(e8m0_bexp_t scale,
bf8x32_ocp_t x)
#else
inline __host__ float32_t scaled_type_convert<float32_t, bf8x32_ocp_t>(e8m0_bexp_t scale,
bf8x32_ocp_t x)
#endif
{
union
{
bf8x32_ocp_t bf8_1x32;
bf8x16_ocp_t bf8_16x2[2];
} in{x};
union
{
float32_t float_1x32;
float16_t float_16x2[2];
} out{};
ck::static_for<0, 2, 1>{}([&](auto i) {
out.float_16x2[i] = scaled_type_convert<float16_t, bf8x16_ocp_t>(scale, in.bf8_16x2[i]);
});
return out.float_1x32;
}
// convert fp4 to fp32
template <>
inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_bexp_t scale, f4_t x)
@@ -29,6 +218,104 @@ inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_bexp_t sc
#endif
}
// convert fp32 to fp8
template <>
inline __host__ __device__ f8_ocp_t scaled_type_convert<f8_ocp_t, float>(e8m0_bexp_t scale, float x)
{
#if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<f8_ocp_t>(x, type_convert<float>(scale));
#else
return mxf8_convert_rne<f8_ocp_t>(x, type_convert<float>(scale));
#endif
}
// convert fp32 to bf8
template <>
inline __host__ __device__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_bexp_t scale,
float x)
{
#if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<bf8_ocp_t>(x, type_convert<float>(scale));
#else
return mxf8_convert_rne<bf8_ocp_t>(x, type_convert<float>(scale));
#endif
}
// convert fp32x2 to fp8x2
template <>
inline __host__ __device__ f8x2_ocp_t scaled_type_convert<f8x2_ocp_t, float2_t>(e8m0_bexp_t scale,
float2_t x)
{
#if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<f8x2_ocp_t>(x, type_convert<float>(scale));
#else
return mxf8_convert_rne<f8x2_ocp_t>(x, type_convert<float>(scale));
#endif
}
// convert fp32x2 to bf8x2
template <>
inline __host__ __device__ bf8x2_ocp_t scaled_type_convert<bf8x2_ocp_t, float2_t>(e8m0_bexp_t scale,
float2_t x)
{
#if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<bf8x2_ocp_t>(x, type_convert<float>(scale));
#else
return mxf8_convert_rne<bf8x2_ocp_t>(x, type_convert<float>(scale));
#endif
}
// convert fp32x16 to fp8x16
// @note Host version gives compilation error. Requires extra compiler options.
template <>
inline __host__ __device__ f8x16_ocp_t
scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
{
#if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<f8x16_ocp_t>(x, type_convert<float>(scale));
#else
return mxf8_convert_rne<f8x16_ocp_t>(x, type_convert<float>(scale));
#endif
}
// convert fp32x16 to bf8x16
// @note Host version gives compilation error. Requires extra compiler options.
template <>
inline __host__ __device__ bf8x16_ocp_t
scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
{
#if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<bf8x16_ocp_t>(x, type_convert<float>(scale));
#else
return mxf8_convert_rne<bf8x16_ocp_t>(x, type_convert<float>(scale));
#endif
}
// convert fp32x32 to fp8x32
// @note Host version gives compilation error. Requires extra compiler options.
template <>
inline __host__ __device__ f8x32_ocp_t
scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
{
#if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<f8x32_ocp_t>(x, type_convert<float>(scale));
#else
return mxf8_convert_rne<f8x32_ocp_t>(x, type_convert<float>(scale));
#endif
}
// convert fp32x32 to bf8x32
// @note Host version gives compilation error. Requires extra compiler options.
template <>
inline __host__ __device__ bf8x32_ocp_t
scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
{
#if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<bf8x32_ocp_t>(x, type_convert<float>(scale));
#else
return mxf8_convert_rne<bf8x32_ocp_t>(x, type_convert<float>(scale));
#endif
}
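// Together with the decode specializations earlier in this header, the
// conversions above form a scale-aware round trip. Illustrative sketch:
//
//     ck::e8m0_bexp_t scale(4.0f);
//     // Encode: 6.0 / 4 = 1.5, exactly representable in E4M3.
//     auto p = ck::scaled_type_convert<ck::f8_ocp_t>(scale, 6.0f);
//     // Decode: 4 * 1.5 recovers the original value.
//     float v = ck::scaled_type_convert<float>(scale, p); // 6.0f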
// convert vector of 2 fp4 to vector of 2 fp32
template <>
inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_bexp_t scale,
...
@@ -12,6 +12,7 @@ endif()
add_custom_target(test_fp8)
if (CK_USE_OCP_FP8)
# add test for ocp data types
add_gtest_executable(test_fp8_ocp test_fp8_ocp.cpp)
if(result EQUAL 0)
target_link_libraries(test_fp8_ocp PRIVATE utility)
@@ -62,13 +63,24 @@ if(GPU_TARGETS MATCHES "gfx950")
endif()
add_dependencies(test_mx_data_types test_bf6)
add_gtest_executable(test_mx_fp8 test_mx_fp8.cpp)
if(result EQUAL 0)
target_link_libraries(test_mx_fp8 PRIVATE utility)
endif()
add_dependencies(test_mx_data_types test_mx_fp8)
add_gtest_executable(test_mx_bf8 test_mx_bf8.cpp)
if(result EQUAL 0)
target_link_libraries(test_mx_bf8 PRIVATE utility)
endif()
add_dependencies(test_mx_data_types test_mx_bf8)
add_gtest_executable(test_e8m0 test_e8m0.cpp)
if(result EQUAL 0)
target_link_libraries(test_e8m0 PRIVATE utility)
endif()
add_dependencies(test_mx_data_types test_e8m0)
endif()
add_gtest_executable(test_custom_type test_custom_type.cpp)
if(result EQUAL 0)
target_link_libraries(test_custom_type PRIVATE utility)
...
@@ -60,8 +60,8 @@ TEST(FP8OCP, ConvertFP32Nearest)
float neg_float = -0.015625f; //-2^-6
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(neg_float)), 0.0f);
// positive subnorm fp8 value to fp8 and back, check if holds
pos_float = 0.00390625f; // 2^-8
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(pos_float)), abs_tol);
// min subnorm fp8 value to fp8 and back, check if holds
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/library/utility/device_memory.hpp"
#include "ck/utility/scaled_type_convert.hpp"
using ck::bf8_ocp_t;
using ck::bf8x16_ocp_t;
using ck::bf8x2_ocp_t;
using ck::bf8x32_ocp_t;
using ck::e8m0_bexp_t;
using ck::float16_t;
using ck::float2_t;
using ck::float32_t;
using ck::mxf8_convert_rne;
using ck::mxf8_convert_sr;
using ck::scaled_type_convert;
using ck::type_convert;
constexpr uint64_t test_size = 256 * 256 + 2 + 4 + 6;
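// Breakdown (matching the doc comment below): 256*256 exhaustive E8M0 x BF8
// decodes + 2 (bf8x2 -> f32x2) + 4 (f32x2 -> bf8x2, RNE and SR) + 6 scalar
// RNE edge cases.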
/**
* @brief Tests conversion of BF8 values to float using E8M0 exponent scaling.
*
* This function performs a series of conversions from BF8 values to float values using
* E8M0 exponent scaling. It handles all possible combinations of E8M0 and BF8 values,
* as well as specific vector and rounding conversions.
*
* @param N The maximum number of conversions to perform.
* @param p_test Pointer to the output array where the converted float values will be stored.
* @param p_completed Pointer to a variable that tracks the number of completed conversions.
*
* @note If either p_test or p_completed is nullptr, the function will return immediately.
* @note The function will stop converting if the number of conversions reaches N.
* @note The first 256*256 conversions cover all possible combinations of E8M0 and BF8 values,
* stored in memory sequentially with the BF8 value varying fastest.
*
* The function performs the following conversions:
* - All possible combinations of E8M0 and BF8 values. [256x256]
* - Vector conversions bf8x2 -> f32x2. [2]
* - Vector conversions f32x2 -> bf8x2 rne. [2]
* - Vector conversions f32x2 -> bf8x2 sr. [2]
* - Round to nearest even conversions for specific float values. [6]
*
* The results are stored in the p_test array, and the number of completed conversions
* is updated in the p_completed variable.
*/
__host__ __device__ void
test_mx_bf8_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
// All possible combinations of E8M0 and BF8
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
{
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
auto v = scaled_type_convert<float>(e8m0_bexp_t(exp_id), bf8_ocp_t{bf8_uid});
p_test[i] = v;
i++;
if(i >= N)
{
return;
}
}
}
/// Test vector conversions
// bf8x2 -> f32x2
bf8x2_ocp_t bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16
auto scale = e8m0_bexp_t(8.0f);
float2_t f32x2 = scaled_type_convert<float2_t>(scale, bf8x2);
p_test[i++] = f32x2[0];
if(i >= N)
{
return;
}
p_test[i++] = f32x2[1];
if(i >= N)
{
return;
}
// f32x2 -> bf8x2
f32x2 = {-8.0f, 4.0f};
auto scale2 = e8m0_bexp_t(2.0f);
bf8x2 = mxf8_convert_rne<bf8x2_ocp_t>(f32x2, type_convert<float>(scale2)); // expect {-4, 2}
p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-4f
if(i >= N)
{
return;
}
p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 2f
if(i >= N)
{
return;
}
auto scale4 = e8m0_bexp_t(4.0f);
bf8x2 = mxf8_convert_sr<bf8x2_ocp_t>(f32x2, type_convert<float>(scale4)); // expect {-2, 1}
p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-2f
if(i >= N)
{
return;
}
p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 1f
if(i >= N)
{
return;
}
/// Test round to nearest even
p_test[i++] = type_convert<float>(mxf8_convert_rne<bf8_ocp_t>(1024.0f, 4.0f)); // 1024/4
if(i >= N)
{
return;
}
p_test[i++] = type_convert<float>(
mxf8_convert_rne<bf8_ocp_t>(std::numeric_limits<float>::quiet_NaN(), 4.0f)); // => NaN
if(i >= N)
{
return;
}
p_test[i++] = type_convert<float>(mxf8_convert_rne<bf8_ocp_t>(
std::numeric_limits<float>::infinity(), 2.0f)); // => BF8 Inf on device
if(i >= N)
{
return;
}
// 31000/0.5 > 57344 => BF8 Inf on device
p_test[i++] = type_convert<float>(mxf8_convert_rne<bf8_ocp_t>(31000.0f, 0.5f));
if(i >= N)
{
return;
}
// -31000/0.5 < -57344 => -BF8 Inf on device
p_test[i++] = type_convert<float>(mxf8_convert_rne<bf8_ocp_t>(-31000.0f, 0.5f));
if(i >= N)
{
return;
}
p_test[i++] = type_convert<float>(
mxf8_convert_rne<bf8_ocp_t>(powf(2.0f, 16.0f), 4.0f)); // 2^16/4 = 65536/4
if(i >= N)
{
return;
}
}
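// The sweep above fills p_test in row-major (exp_id, bf8_id) order, so the
// host checks below recover each result as:
//
//     idx = exp_id * 256 + bf8_id; // BF8 id varies fastest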
TEST(MXBF8, HostScaledConvert)
{
std::vector<float> out(test_size, -1.0f);
uint64_t completed = 0;
test_mx_bf8_scaled_convert(test_size, out.data(), &completed);
// V = X * P; X - E8M0 scale, P - BF8
// If X = NaN, then V = NaN regardless of P
uint8_t e8m0_nan_id = ck::NumericLimits<e8m0_bexp_t>::QuietNaN().data;
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
{
auto idx = e8m0_nan_id * 256 + bf8_id;
ASSERT_TRUE(std::isnan(out[idx]));
}
// If P in {Inf, NaN}, then V = P
std::set<uint8_t> bf8_spec_ids;
bf8_spec_ids.insert(0b11111111); // -NaN
bf8_spec_ids.insert(0b01111111); // +NaN
bf8_spec_ids.insert(0b11111101); // -NaN
bf8_spec_ids.insert(0b01111101); // +NaN
bf8_spec_ids.insert(0b11111110); // -NaN
bf8_spec_ids.insert(0b01111110); // +NaN
bf8_spec_ids.insert(0b11111100); // -inf
bf8_spec_ids.insert(0b01111100); // +inf
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(auto bf8_spec_id : bf8_spec_ids)
{
auto idx = exp_id * 256 + bf8_spec_id;
if(std::isnan(type_convert<float>(bf8_ocp_t{bf8_spec_id})))
{
ASSERT_TRUE(std::isnan(out[idx]))
<< "exp_id: " << exp_id << " bf8_id: " << bf8_spec_id << std::endl
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
<< type_convert<float>(bf8_ocp_t{bf8_spec_id}) << " != " << out[idx];
}
else
{
ASSERT_EQ(out[idx], type_convert<float>(bf8_ocp_t{bf8_spec_id}))
<< "exp_id: " << exp_id << " bf8_id: " << bf8_spec_id << std::endl
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
<< type_convert<float>(bf8_ocp_t{bf8_spec_id}) << " != " << out[idx];
}
}
}
// V = X * P; X, P - finite
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
{
if(bf8_spec_ids.find(bf8_id) != bf8_spec_ids.end())
continue;
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
auto idx = exp_id * 256 + bf8_uid;
ASSERT_FLOAT_EQ(out[idx],
type_convert<float>(e8m0_bexp_t(exp_id)) *
type_convert<float>(bf8_ocp_t{bf8_uid}))
<< "exp_id: " << exp_id << " bf8_id: " << bf8_uid << std::endl
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
<< type_convert<float>(bf8_ocp_t{bf8_uid});
}
}
/// Test vector conversions
auto i = 256 * 256;
// bf8x2 -> f32x2
EXPECT_EQ(out[i++], -powf(2.0f, -11.0f));
EXPECT_EQ(out[i++], powf(2.0f, -13.0f));
// f32x2 -> bf8x2
// RNE
EXPECT_EQ(out[i++], -4.0f);
EXPECT_EQ(out[i++], 2.0f);
// SR
EXPECT_EQ(out[i++], -2.0f);
EXPECT_EQ(out[i++], 1.0f);
/// Test round to nearest even
EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Lowest()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], powf(2.0f, 14.0f)) << "out[i-1]: " << out[i - 1];
EXPECT_EQ(test_size, completed);
EXPECT_EQ(test_size, i);
}
__global__ void test_mx_bf8_device_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
test_mx_bf8_scaled_convert(N, p_test, p_completed);
}
TEST(MXBF8, DeviceScaledConvert)
{
std::vector<float> out(test_size, -1.0f);
DeviceMem device_out(test_size * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_bf8_device_scaled_convert<<<1, 1>>>(
test_size,
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
// V = X * P; X - E8M0 scale, P - BF8
// If X = NaN, then V = NaN regardless of P
uint8_t e8m0_nan_id = ck::NumericLimits<e8m0_bexp_t>::QuietNaN().data;
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
{
auto idx = e8m0_nan_id * 256 + bf8_id;
ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
}
// If P in {Inf, NaN}, then V = P
std::set<uint8_t> bf8_spec_ids;
bf8_spec_ids.insert(0b11111111); //-NaN
bf8_spec_ids.insert(0b01111111); // +NaN
bf8_spec_ids.insert(0b11111101); //-NaN
bf8_spec_ids.insert(0b01111101); // +NaN
bf8_spec_ids.insert(0b11111110); //-NaN
bf8_spec_ids.insert(0b01111110); // +NaN
bf8_spec_ids.insert(0b11111100); //-inf
bf8_spec_ids.insert(0b01111100); // +inf
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(auto bf8_spec_id : bf8_spec_ids)
{
auto idx = exp_id * 256 + bf8_spec_id;
if(std::isnan(type_convert<float>(bf8_ocp_t{bf8_spec_id})))
{
ASSERT_TRUE(std::isnan(out[idx]))
<< "exp_id: " << exp_id << " bf8_id: " << bf8_spec_id << std::endl
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
<< type_convert<float>(bf8_ocp_t{bf8_spec_id}) << " != " << out[idx];
}
else
{
ASSERT_EQ(out[idx], type_convert<float>(bf8_ocp_t{bf8_spec_id}))
<< "exp_id: " << exp_id << " bf8_id: " << bf8_spec_id << std::endl
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
<< type_convert<float>(bf8_ocp_t{bf8_spec_id}) << " != " << out[idx];
}
}
}
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
{
if(bf8_spec_ids.find(bf8_id) != bf8_spec_ids.end())
continue;
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
auto idx = exp_id * 256 + bf8_uid;
ASSERT_FLOAT_EQ(out[idx],
type_convert<float>(e8m0_bexp_t(exp_id)) *
type_convert<float>(bf8_ocp_t{bf8_uid}))
<< "exp_id: " << exp_id << " bf8_id: " << bf8_uid << std::endl
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
<< type_convert<float>(bf8_ocp_t{bf8_uid});
}
}
/// Test vector conversions
auto i = 256 * 256;
// bf8x2 -> f32x2
EXPECT_EQ(out[i++], -powf(2.0f, -11.0f));
EXPECT_EQ(out[i++], powf(2.0f, -13.0f));
// f32x2 -> bf8x2
// RNE
EXPECT_EQ(out[i++], -4.0f);
EXPECT_EQ(out[i++], 2.0f);
// SR
EXPECT_EQ(out[i++], -2.0f);
EXPECT_EQ(out[i++], 1.0f);
/// Test round to nearest even
EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
#if 1
EXPECT_TRUE(std::isinf(out[i++])) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isinf(out[i++])) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isinf(out[i++])) << "out[i-1]: " << out[i - 1];
#else
// NOTE: Host and Device have different behavior.
// Device returns Infs, while Host returns Max (saturation to finite value).
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Lowest()))
<< "out[i-1]: " << out[i - 1];
#endif
EXPECT_EQ(out[i++], powf(2.0f, 14.0f)) << "out[i-1]: " << out[i - 1];
EXPECT_EQ(test_size, completed);
EXPECT_EQ(test_size, i);
}
__host__ __device__ float vec16_generator(ck::index_t i) { return powf(-1.0f, i) * powf(2.0f, i); }
__global__ void test_mx_bf8x16_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
constexpr int N = 16;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
auto scale2 = e8m0_bexp_t(2.0f);
bf8x16_ocp_t bf8x16{};
float16_t float16{};
ck::static_for<0, N, 1>{}(
[&](auto ii) { float16[static_cast<int>(ii)] = vec16_generator(ii); });
bf8x16 = scaled_type_convert<bf8x16_ocp_t>(scale2, float16);
ck::static_for<0, N, 1>{}([&](auto ii) {
p_test[i++] = type_convert<float>(bf8x16.AsType<bf8_ocp_t>()(ck::Number<ii>{}));
});
}
TEST(MXBF8, DeviceF32x16ToBF8x16ScaledConvert)
{
constexpr int N = 16;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_bf8x16_device_scaled_convert<<<1, 1>>>(
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
auto i = 0;
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec16_generator(ii) / 2.0f) << "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
EXPECT_EQ(N, i);
}
__host__ __device__ float vec32_generator(ck::index_t i)
{
if(i < 16)
{
return vec16_generator(i % 16);
}
else
{
return 1.5f * vec16_generator(i % 16);
}
}
__global__ void test_mx_bf8x32_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
constexpr int N = 32;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
auto scale2 = e8m0_bexp_t(2.0f);
bf8x32_ocp_t bf8x32{};
float32_t float32{};
ck::static_for<0, N, 1>{}(
[&](auto ii) { float32[static_cast<int>(ii)] = vec32_generator(ii); });
bf8x32 = mxf8_convert_rne<bf8x32_ocp_t>(float32, type_convert<float>(scale2));
ck::static_for<0, N, 1>{}([&](auto ii) {
p_test[i++] = type_convert<float>(bf8x32.AsType<bf8_ocp_t>()(ck::Number<ii>{}));
});
}
TEST(MXBF8, DeviceF32x32ToBF8x32ScaledConvert)
{
constexpr int N = 32;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_bf8x32_device_scaled_convert<<<1, 1>>>(
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
auto i = 0;
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec32_generator(ii) / 2.0f) << "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
EXPECT_EQ(N, i);
}
__global__ void test_mx_bf8x32_device_scaled_convert_sr(float* p_test, uint64_t* p_completed)
{
constexpr int N = 32;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
auto scale2 = e8m0_bexp_t(8.0f);
bf8x32_ocp_t bf8x32{};
float32_t float32{};
ck::static_for<0, N, 1>{}(
[&](auto ii) { float32[static_cast<int>(ii)] = vec32_generator(ii); });
bf8x32 = mxf8_convert_sr<bf8x32_ocp_t>(float32, type_convert<float>(scale2));
ck::static_for<0, N, 1>{}([&](auto ii) {
p_test[i++] = type_convert<float>(bf8x32.AsType<bf8_ocp_t>()(ck::Number<ii>{}));
});
}
TEST(MXBF8, DeviceF32x32ToBF8x32ScaledConvertSR)
{
constexpr int N = 32;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_bf8x32_device_scaled_convert_sr<<<1, 1>>>(
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
auto i = 0;
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec32_generator(ii) / 8.0f) << "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
EXPECT_EQ(N, i);
}
__global__ void test_mx_f32x32_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
constexpr int N = 32;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
auto scale2 = e8m0_bexp_t(4.0f);
bf8x32_ocp_t bf8x32{};
float32_t float32{};
ck::static_for<0, N, 1>{}([&](auto ii) {
bf8x32.AsType<bf8_ocp_t>()(ii) = type_convert<bf8_ocp_t>(vec32_generator(ii) / 16.0f);
});
float32 = scaled_type_convert<float32_t>(scale2, bf8x32);
ck::static_for<0, N, 1>{}([&](auto ii) { p_test[i++] = float32[static_cast<int>(ii)]; });
}
TEST(MXBF8, DeviceBF8x32ToF32x32ScaledConvert)
{
constexpr int N = 32;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_f32x32_device_scaled_convert<<<1, 1>>>(
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
auto i = 0;
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec32_generator(ii) / 4.0f) << "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
EXPECT_EQ(N, i);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/library/utility/device_memory.hpp"
#include "ck/utility/scaled_type_convert.hpp"
using ck::e8m0_bexp_t;
using ck::f8_ocp_t;
using ck::f8x16_ocp_t;
using ck::f8x2_ocp_t;
using ck::f8x32_ocp_t;
using ck::float16_t;
using ck::float2_t;
using ck::float32_t;
using ck::mxf8_convert_rne;
using ck::mxf8_convert_sr;
using ck::scaled_type_convert;
using ck::type_convert;
using ck::fp8_impl::fp8x2_storage_t;
constexpr uint64_t test_size = 256 * 256 + 2 + 4 + 6;
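// Breakdown (matching the doc comment below): 256*256 exhaustive E8M0 x FP8
// decodes + 2 (f8x2 -> f32x2) + 4 (f32x2 -> f8x2, RNE and SR) + 6 scalar
// RNE edge cases.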
/**
* @brief Tests conversion of FP8 values to float using E8M0 exponent scaling.
*
* This function performs a series of conversions from FP8 values to float values using
* E8M0 exponent scaling. It handles all possible combinations of E8M0 and FP8 values,
* as well as specific vector and rounding conversions.
*
* @param N The maximum number of conversions to perform.
* @param p_test Pointer to the output array where the converted float values will be stored.
* @param p_completed Pointer to a variable that tracks the number of completed conversions.
*
* @note If either p_test or p_completed is nullptr, the function will return immediately.
* @note The function will stop converting if the number of conversions reaches N.
* @note The first 256*256 conversions cover all possible combinations of E8M0 and FP8 values,
* stored in memory sequentially with the FP8 value varying fastest.
*
* The function performs the following conversions:
* - All possible combinations of E8M0 and FP8 values. [256x256]
* - Vector conversions f8x2 -> f32x2. [2]
* - Vector conversions f32x2 -> f8x2 rne. [2]
* - Vector conversions f32x2 -> f8x2 sr. [2]
* - Round to nearest even conversions for specific float values. [6]
*
* The results are stored in the p_test array, and the number of completed conversions
* is updated in the p_completed variable.
*/
__host__ __device__ void
test_mx_fp8_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
// All possible combinations of E8M0 and FP8
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
{
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
auto v = scaled_type_convert<float>(e8m0_bexp_t(exp_id), f8_ocp_t{fp8_uid});
p_test[i] = v;
i++;
if(i >= N)
{
return;
}
}
}
/// Test vector conversions
// f8x2 -> f32x2
f8x2_ocp_t fp8x2{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9
auto scale2 = e8m0_bexp_t(2.0f);
float2_t f32x2 = scaled_type_convert<float2_t>(scale2, fp8x2);
p_test[i++] = f32x2[0];
if(i >= N)
{
return;
}
p_test[i++] = f32x2[1];
if(i >= N)
{
return;
}
// f32x2 -> f8x2
f32x2 = {-8.0f, 4.0f};
fp8x2 = mxf8_convert_rne<f8x2_ocp_t>(f32x2, type_convert<float>(scale2)); // expect {-4, 2}
p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})); //-4f
if(i >= N)
{
return;
}
p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{})); // 2f
if(i >= N)
{
return;
}
auto scale4 = e8m0_bexp_t(4.0f);
fp8x2 = mxf8_convert_sr<f8x2_ocp_t>(f32x2, type_convert<float>(scale4)); // expect {-2, 1}
p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})); //-2f
if(i >= N)
{
return;
}
p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{})); // 1f
if(i >= N)
{
return;
}
/// Test round to nearest even
p_test[i++] = type_convert<float>(mxf8_convert_rne<f8_ocp_t>(1024.0f, 4.0f)); // 1024/4
if(i >= N)
{
return;
}
p_test[i++] = type_convert<float>(
mxf8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::quiet_NaN(), 4.0f)); // => NaN
if(i >= N)
{
return;
}
// Inf/2 > 448 => NaN on device
p_test[i++] = type_convert<float>(
mxf8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::infinity(), 2.0f));
if(i >= N)
{
return;
}
// 256/0.5 > 448 => NaN on device
p_test[i++] = type_convert<float>(mxf8_convert_rne<f8_ocp_t>(256.0f, 0.5f));
if(i >= N)
{
return;
}
// -256/0.5 < -448 => NaN on device
p_test[i++] = type_convert<float>(mxf8_convert_rne<f8_ocp_t>(-256.0f, 0.5f));
if(i >= N)
{
return;
}
// proper scale selection 2^13 < 10000; 2^8 < 448 => scale = 2^(13-8) = 2^5
p_test[i++] =
type_convert<float>(mxf8_convert_rne<f8_ocp_t>(10000.0f, 32.0f)); // 10000/32 = 312.5
if(i >= N)
{
return;
}
}
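// The last case above ("proper scale selection") is the MX recipe in
// miniature: pick a power-of-two scale that brings the block maximum into
// E4M3 range. A sketch of that rule (hypothetical helper, not used by these
// tests; assumes <cmath>). Quotients land in [2^8, 2^9), so values at the
// very top of that bin can still saturate:
inline float select_pow2_scale(float max_abs)
{
    // For 10000.0f: 2^13 <= 10000 < 2^14 and 2^8 < 448 (E4M3 max), so this
    // yields 2^(13-8) = 32 and 10000/32 = 312.5, matching the case above.
    const int e = static_cast<int>(std::floor(std::log2(max_abs)));
    return std::ldexp(1.0f, e - 8);
}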
TEST(MXFP8, HostScaledConvert)
{
std::vector<float> out(test_size, -1.0f);
uint64_t completed = 0;
test_mx_fp8_scaled_convert(test_size, out.data(), &completed);
// V = X * P; X - E8M0 scale, P - FP8
// If X = NaN, then V = NaN regardless of P
uint8_t e8m0_nan_id = ck::NumericLimits<e8m0_bexp_t>::QuietNaN().data;
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
{
auto idx = e8m0_nan_id * 256 + fp8_id;
ASSERT_TRUE(std::isnan(out[idx]));
}
// If P in {Inf, NaN}, then V = P
std::set<uint8_t> fp8_nan_ids;
fp8_nan_ids.insert(0b11111111); //-NaN
fp8_nan_ids.insert(0b01111111); // +NaN
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(auto fp8_nan_id : fp8_nan_ids)
{
auto idx = exp_id * 256 + fp8_nan_id;
ASSERT_TRUE(std::isnan(out[idx]));
}
}
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
{
if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
continue;
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
auto idx = exp_id * 256 + fp8_uid;
ASSERT_FLOAT_EQ(out[idx],
type_convert<float>(e8m0_bexp_t(exp_id)) *
type_convert<float>(f8_ocp_t{fp8_uid}))
<< "exp_id: " << exp_id << " fp8_id: " << fp8_id << std::endl
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
<< type_convert<float>(f8_ocp_t{fp8_uid});
}
}
/// Test vector conversions
auto i = 256 * 256;
// f8x2 -> f32x2
EXPECT_EQ(out[i++], -powf(2.0f, -5.0f));
EXPECT_EQ(out[i++], powf(2.0f, -8.0f));
// f32x2 -> fp8x2
// RNE
EXPECT_EQ(out[i++], -4.0f);
EXPECT_EQ(out[i++], 2.0f);
// SR
EXPECT_EQ(out[i++], -2.0f);
EXPECT_EQ(out[i++], 1.0f);
/// Test round to nearest even
EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Lowest()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(type_convert<f8_ocp_t>(312.5f)))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(test_size, completed);
EXPECT_EQ(test_size, i);
}
__global__ void test_mx_fp8_device_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
test_mx_fp8_scaled_convert(N, p_test, p_completed);
}
TEST(MXFP8, DeviceScaledConvert)
{
std::vector<float> out(test_size, -1.0f);
DeviceMem device_out(test_size * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_fp8_device_scaled_convert<<<1, 1>>>(
test_size,
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
// V = X * P; X - E8M0 scale, P - FP8
// If X = NaN, then V = NaN regardless of P
uint8_t e8m0_nan_id = ck::NumericLimits<e8m0_bexp_t>::QuietNaN().data;
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
{
auto idx = e8m0_nan_id * 256 + fp8_id;
ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
}
// If P in {Inf, NaN}, then V = P
std::set<uint8_t> fp8_nan_ids;
fp8_nan_ids.insert(0b11111111); //-NaN
fp8_nan_ids.insert(0b01111111); // +NaN
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(auto fp8_nan_id : fp8_nan_ids)
{
auto idx = exp_id * 256 + fp8_nan_id;
ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
}
}
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
{
if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
continue;
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
auto idx = exp_id * 256 + fp8_uid;
ASSERT_FLOAT_EQ(out[idx],
type_convert<float>(e8m0_bexp_t(exp_id)) *
type_convert<float>(f8_ocp_t{fp8_uid}))
<< "exp_id: " << exp_id << " fp8_id: " << fp8_id << std::endl
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
<< type_convert<float>(f8_ocp_t{fp8_uid});
}
}
/// Test vector conversions
auto i = 256 * 256;
// f8x2 -> f32x2
EXPECT_EQ(out[i++], -powf(2.0f, -5.0f));
EXPECT_EQ(out[i++], powf(2.0f, -8.0f));
// f32x2 -> fp8x2
// RNE
EXPECT_EQ(out[i++], -4.0f);
EXPECT_EQ(out[i++], 2.0f);
// SR
EXPECT_EQ(out[i++], -2.0f);
EXPECT_EQ(out[i++], 1.0f);
/// Test round to nearest even
EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
#if 1
EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
#else
// NOTE: Host and Device have different behavior.
// Device returns NaN, while Host returns Max (saturation to finite value).
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Lowest()))
<< "out[i-1]: " << out[i - 1];
#endif
EXPECT_EQ(out[i++], type_convert<float>(type_convert<f8_ocp_t>(312.5f)))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(test_size, completed);
EXPECT_EQ(test_size, i);
}
__host__ __device__ float vec16_generator(ck::index_t i)
{
return (i < 8 ? -1.0 : 1.0) * powf(2.0f, i % 8);
}
__global__ void test_mx_fp8x16_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
constexpr int N = 16;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
auto scale2 = e8m0_bexp_t(2.0f);
f8x16_ocp_t fp8x16{};
float16_t float16{};
ck::static_for<0, N, 1>{}(
[&](auto ii) { float16[static_cast<int>(ii)] = vec16_generator(ii); });
fp8x16 = scaled_type_convert<ck::f8x16_ocp_t>(scale2, float16);
ck::static_for<0, N, 1>{}([&](auto ii) {
p_test[i++] = type_convert<float>(fp8x16.AsType<f8_ocp_t>()(ck::Number<ii>{}));
});
}
TEST(MXFP8, DeviceF32x16ToF8x16ScaledConvert)
{
constexpr int N = 16;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_fp8x16_device_scaled_convert<<<1, 1>>>(
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
auto i = 0;
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec16_generator(ii) / 2.0f) << "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
EXPECT_EQ(N, i);
}
__host__ __device__ float vec32_generator(ck::index_t i)
{
if(i < 16)
{
return vec16_generator(i % 16);
}
else
{
return 1.5f * vec16_generator(i % 16);
}
}
__global__ void test_mx_fp8x32_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
constexpr int N = 32;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
auto scale2 = e8m0_bexp_t(2.0f);
f8x32_ocp_t fp8x32{};
float32_t float32{};
ck::static_for<0, N, 1>{}(
[&](auto ii) { float32[static_cast<int>(ii)] = vec32_generator(ii); });
fp8x32 = mxf8_convert_rne<f8x32_ocp_t>(float32, type_convert<float>(scale2));
ck::static_for<0, N, 1>{}(
[&](auto ii) { p_test[i++] = type_convert<float>(fp8x32.AsType<f8_ocp_t>()(ii)); });
}
TEST(MXFP8, DeviceF32x32ToF8x32ScaledConvert)
{
constexpr int N = 32;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_fp8x32_device_scaled_convert<<<1, 1>>>(
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
auto i = 0;
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec32_generator(ii) / 2.0f) << "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
EXPECT_EQ(N, i);
}
__global__ void test_mx_fp8x32_device_scaled_convert_sr(float* p_test, uint64_t* p_completed)
{
constexpr int N = 32;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
auto scale2 = e8m0_bexp_t(8.0f);
f8x32_ocp_t fp8x32{};
float32_t float32{};
ck::static_for<0, N, 1>{}(
[&](auto ii) { float32[static_cast<int>(ii)] = vec32_generator(ii); });
fp8x32 = mxf8_convert_sr<f8x32_ocp_t>(float32, type_convert<float>(scale2));
ck::static_for<0, N, 1>{}(
[&](auto ii) { p_test[i++] = type_convert<float>(fp8x32.AsType<f8_ocp_t>()(ii)); });
}
TEST(MXFP8, DeviceF32x32ToF8x32ScaledConvertSR)
{
constexpr int N = 32;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_fp8x32_device_scaled_convert_sr<<<1, 1>>>(
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
auto i = 0;
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec32_generator(ii) / 8.0f) << "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
EXPECT_EQ(N, i);
}
__global__ void test_mx_f32x32_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
constexpr int N = 32;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
auto scale2 = e8m0_bexp_t(4.0f);
f8x32_ocp_t fp8x32{};
float32_t float32{};
ck::static_for<0, N, 1>{}([&](auto ii) {
fp8x32.AsType<f8_ocp_t>()(ii) = type_convert<f8_ocp_t>(vec32_generator(ii) / 16.0f);
});
float32 = scaled_type_convert<float32_t>(scale2, fp8x32);
ck::static_for<0, N, 1>{}([&](auto ii) { p_test[i++] = float32[static_cast<int>(ii)]; });
}
TEST(MXFP8, DeviceF8x32ToF32x32ScaledConvert)
{
constexpr int N = 32;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_f32x32_device_scaled_convert<<<1, 1>>>(
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
auto i = 0;
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec32_generator(ii) / 4.0f) << "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
EXPECT_EQ(N, i);
}