Commit 48d58131 authored by Rostyslav Geyyer
Browse files

Rename E8M0 type

parent b1ad4b4f
...@@ -14,15 +14,15 @@ using f4_t = unsigned _BitInt(4); ...@@ -14,15 +14,15 @@ using f4_t = unsigned _BitInt(4);
using f8_t = _BitInt(8); using f8_t = _BitInt(8);
using bf8_t = unsigned _BitInt(8); using bf8_t = unsigned _BitInt(8);
struct e8m0_scale_t struct e8m0_bexp_t
{ {
// E8M0 scale is biased // E8M0 scale is biased
using type = uint8_t; using type = uint8_t;
type data; type data;
constexpr e8m0_scale_t() : data{type{}} {} constexpr e8m0_bexp_t() : data{type{}} {}
constexpr e8m0_scale_t(type init) : data{init} {} constexpr e8m0_bexp_t(type init) : data{init} {}
bool operator==(const e8m0_scale_t& other) const { return (data == other.data); } bool operator==(const e8m0_bexp_t& other) const { return (data == other.data); }
}; };
struct f4x2_pk_t struct f4x2_pk_t
...@@ -1813,33 +1813,30 @@ struct NumericLimits<f4_t> ...@@ -1813,33 +1813,30 @@ struct NumericLimits<f4_t>
}; };
template <> template <>
struct NumericLimits<e8m0_scale_t> struct NumericLimits<e8m0_bexp_t>
{ {
static constexpr e8m0_scale_t binary_min = 0x00; // 0b00000000 static constexpr e8m0_bexp_t binary_min = 0x00; // 0b00000000
static constexpr e8m0_scale_t binary_max = 0xFE; // 0b11111110 static constexpr e8m0_bexp_t binary_max = 0xFE; // 0b11111110
static constexpr e8m0_scale_t binary_qnan = 0xFF; // 0b11111111 static constexpr e8m0_bexp_t binary_qnan = 0xFF; // 0b11111111
static constexpr e8m0_scale_t binary_1 = 0x7F; // 0b01111111 static constexpr e8m0_bexp_t binary_1 = 0x7F; // 0b01111111
static constexpr e8m0_scale_t binary_2 = 0x80; // 0b10000000 static constexpr e8m0_bexp_t binary_2 = 0x80; // 0b10000000
static constexpr e8m0_scale_t binary_3 = 0x82; // 0b10000010 static constexpr e8m0_bexp_t binary_3 = 0x82; // 0b10000010
static constexpr e8m0_scale_t binary_135 = 0x87; // 0b10000111 static constexpr e8m0_bexp_t binary_135 = 0x87; // 0b10000111
static constexpr e8m0_scale_t binary_142 = 0x8E; // 0b10001110 static constexpr e8m0_bexp_t binary_142 = 0x8E; // 0b10001110
__host__ __device__ static constexpr e8m0_scale_t Min() { return e8m0_scale_t(binary_min); } __host__ __device__ static constexpr e8m0_bexp_t Min() { return e8m0_bexp_t(binary_min); }
__host__ __device__ static constexpr e8m0_scale_t Max() { return e8m0_scale_t(binary_max); } __host__ __device__ static constexpr e8m0_bexp_t Max() { return e8m0_bexp_t(binary_max); }
__host__ __device__ static constexpr e8m0_scale_t QuietNaN() __host__ __device__ static constexpr e8m0_bexp_t QuietNaN() { return e8m0_bexp_t(binary_qnan); }
__host__ __device__ static constexpr e8m0_bexp_t Binary_1() { return e8m0_bexp_t(binary_1); }
__host__ __device__ static constexpr e8m0_bexp_t Binary_2() { return e8m0_bexp_t(binary_2); }
__host__ __device__ static constexpr e8m0_bexp_t Binary_3() { return e8m0_bexp_t(binary_3); }
__host__ __device__ static constexpr e8m0_bexp_t Binary_135()
{ {
return e8m0_scale_t(binary_qnan); return e8m0_bexp_t(binary_135);
} }
__host__ __device__ static constexpr e8m0_scale_t Binary_1() { return e8m0_scale_t(binary_1); } __host__ __device__ static constexpr e8m0_bexp_t Binary_142()
__host__ __device__ static constexpr e8m0_scale_t Binary_2() { return e8m0_scale_t(binary_2); }
__host__ __device__ static constexpr e8m0_scale_t Binary_3() { return e8m0_scale_t(binary_3); }
__host__ __device__ static constexpr e8m0_scale_t Binary_135()
{ {
return e8m0_scale_t(binary_135); return e8m0_bexp_t(binary_142);
}
__host__ __device__ static constexpr e8m0_scale_t Binary_142()
{
return e8m0_scale_t(binary_142);
} }
}; };
...@@ -1944,7 +1941,7 @@ struct NumericUtils<f4_t> ...@@ -1944,7 +1941,7 @@ struct NumericUtils<f4_t>
}; };
template <> template <>
struct NumericUtils<e8m0_scale_t> struct NumericUtils<e8m0_bexp_t>
{ {
static constexpr int exp = 8; static constexpr int exp = 8;
static constexpr int mant = 0; static constexpr int mant = 0;
......
...@@ -8,24 +8,24 @@ ...@@ -8,24 +8,24 @@
namespace ck::utils { namespace ck::utils {
__host__ __device__ inline float cast_to_float(e8m0_scale_t const scale) __host__ __device__ inline float cast_to_float(e8m0_bexp_t const bexp)
{ {
// TODO: check performance and try bit shift impl // TODO: check performance and try bit shift impl
return std::powf(2, bit_cast<uint8_t>(scale) - NumericUtils<e8m0_scale_t>::bias); return std::powf(2, bit_cast<uint8_t>(bexp) - NumericUtils<e8m0_bexp_t>::bias);
} }
__host__ __device__ inline e8m0_scale_t cast_from_float(float const scale) __host__ __device__ inline e8m0_bexp_t cast_from_float(float const scale)
{ {
uint32_t e = bit_cast<uint32_t>(scale) & NumericUtils<float>::nan_mask; uint32_t e = bit_cast<uint32_t>(scale) & NumericUtils<float>::nan_mask;
return static_cast<uint8_t>(e >> 23); return static_cast<uint8_t>(e >> 23);
} }
template <> template <>
__host__ __device__ inline int get_exponent_value<e8m0_scale_t>(e8m0_scale_t x) __host__ __device__ inline int get_exponent_value<e8m0_bexp_t>(e8m0_bexp_t x)
{ {
x.data >>= NumericUtils<e8m0_scale_t>::mant; x.data >>= NumericUtils<e8m0_bexp_t>::mant;
x.data &= ((1 << NumericUtils<e8m0_scale_t>::exp) - 1); x.data &= ((1 << NumericUtils<e8m0_bexp_t>::exp) - 1);
return static_cast<int>(x.data); return static_cast<int>(x.data);
} }
......
...@@ -9,16 +9,16 @@ ...@@ -9,16 +9,16 @@
namespace ck::utils { namespace ck::utils {
template <> template <>
__host__ __device__ inline bool is_nan<f4_t>(e8m0_scale_t const scale, __host__ __device__ inline bool is_nan<f4_t>(e8m0_bexp_t const scale,
f4_t const dataBytes [[maybe_unused]]) f4_t const dataBytes [[maybe_unused]])
{ {
// no need to check for data as it does not have NaN representation // no need to check for data as it does not have NaN representation
return scale == NumericLimits<e8m0_scale_t>::QuietNaN(); return scale == NumericLimits<e8m0_bexp_t>::QuietNaN();
} }
// no infinity representation in ocp_e2m1_mxfp4 will always return false // no infinity representation in ocp_e2m1_mxfp4 will always return false
template <> template <>
__host__ __device__ inline bool is_inf<f4_t>(e8m0_scale_t const scale [[maybe_unused]], __host__ __device__ inline bool is_inf<f4_t>(e8m0_bexp_t const scale [[maybe_unused]],
f4_t const data [[maybe_unused]]) f4_t const data [[maybe_unused]])
{ {
// no inf representation for ocp_e2m1_mxfp4 // no inf representation for ocp_e2m1_mxfp4
...@@ -26,7 +26,7 @@ __host__ __device__ inline bool is_inf<f4_t>(e8m0_scale_t const scale [[maybe_un ...@@ -26,7 +26,7 @@ __host__ __device__ inline bool is_inf<f4_t>(e8m0_scale_t const scale [[maybe_un
} }
template <> template <>
__host__ __device__ inline bool is_zero<f4_t>(e8m0_scale_t const scale, f4_t const data) __host__ __device__ inline bool is_zero<f4_t>(e8m0_bexp_t const scale, f4_t const data)
{ {
if(is_nan<f4_t>(scale, data)) if(is_nan<f4_t>(scale, data))
return false; return false;
...@@ -38,7 +38,7 @@ __host__ __device__ inline bool is_zero<f4_t>(e8m0_scale_t const scale, f4_t con ...@@ -38,7 +38,7 @@ __host__ __device__ inline bool is_zero<f4_t>(e8m0_scale_t const scale, f4_t con
} }
template <> template <>
__host__ __device__ inline float to_float<f4_t>(e8m0_scale_t const scale, f4_t const data) __host__ __device__ inline float to_float<f4_t>(e8m0_bexp_t const scale, f4_t const data)
{ {
if(is_nan<f4_t>(scale, data)) if(is_nan<f4_t>(scale, data))
return std::numeric_limits<float>::quiet_NaN(); return std::numeric_limits<float>::quiet_NaN();
...@@ -48,7 +48,7 @@ __host__ __device__ inline float to_float<f4_t>(e8m0_scale_t const scale, f4_t c ...@@ -48,7 +48,7 @@ __host__ __device__ inline float to_float<f4_t>(e8m0_scale_t const scale, f4_t c
f4_t prepared_data = data & 0b00001111; f4_t prepared_data = data & 0b00001111;
int scale_exp = get_exponent_value<e8m0_scale_t>(scale); int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);
return convert_to_float<f4_t>(prepared_data, scale_exp); return convert_to_float<f4_t>(prepared_data, scale_exp);
} }
...@@ -73,7 +73,7 @@ __host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value) ...@@ -73,7 +73,7 @@ __host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value)
f4_t res = convert_to_type<f4_t>(value); f4_t res = convert_to_type<f4_t>(value);
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), res)) < if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<f4_t>::DataMinSubnorm()) NumericLimits<f4_t>::DataMinSubnorm())
return value < 0 ? NumericUtils<f4_t>::negative_zero_mask return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
: NumericUtils<f4_t>::positive_zero_mask; : NumericUtils<f4_t>::positive_zero_mask;
...@@ -98,7 +98,7 @@ __host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32 ...@@ -98,7 +98,7 @@ __host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32
f4_t res = convert_to_type_sr<f4_t>(value, seed); f4_t res = convert_to_type_sr<f4_t>(value, seed);
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), res)) < if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<f4_t>::DataMinSubnorm()) NumericLimits<f4_t>::DataMinSubnorm())
return value < 0 ? NumericUtils<f4_t>::negative_zero_mask return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
: NumericUtils<f4_t>::positive_zero_mask; : NumericUtils<f4_t>::positive_zero_mask;
......
...@@ -18,13 +18,13 @@ inline bool getDataHasInf() ...@@ -18,13 +18,13 @@ inline bool getDataHasInf()
} }
template <typename T> template <typename T>
__host__ __device__ inline bool is_zero(e8m0_scale_t const scale, T const data); __host__ __device__ inline bool is_zero(e8m0_bexp_t const scale, T const data);
template <typename T> template <typename T>
__host__ __device__ inline bool is_nan(e8m0_scale_t const scale, T const data); __host__ __device__ inline bool is_nan(e8m0_bexp_t const scale, T const data);
template <typename T> template <typename T>
__host__ __device__ inline bool is_inf(e8m0_scale_t const scale, T const data); __host__ __device__ inline bool is_inf(e8m0_bexp_t const scale, T const data);
template <typename T> template <typename T>
__host__ __device__ inline int get_exponent_value(T x) __host__ __device__ inline int get_exponent_value(T x)
...@@ -79,13 +79,13 @@ __host__ __device__ float convert_to_float(T data, int scale_exp) ...@@ -79,13 +79,13 @@ __host__ __device__ float convert_to_float(T data, int scale_exp)
float data_value = d_sign * d_exp * d_mant; float data_value = d_sign * d_exp * d_mant;
float scale_value = std::pow( float scale_value = std::pow(
2, static_cast<float>((scale_exp - static_cast<int>(NumericUtils<e8m0_scale_t>::bias)))); 2, static_cast<float>((scale_exp - static_cast<int>(NumericUtils<e8m0_bexp_t>::bias))));
return data_value * scale_value; return data_value * scale_value;
} }
template <typename T> template <typename T>
__host__ __device__ inline float to_float(e8m0_scale_t const scale, T const data); __host__ __device__ inline float to_float(e8m0_bexp_t const scale, T const data);
template <typename T> template <typename T>
__host__ __device__ T sat_convert_to_type(float value); __host__ __device__ T sat_convert_to_type(float value);
......
...@@ -1000,7 +1000,7 @@ inline __host__ __device__ float type_convert<float, f4_t>(f4_t x) ...@@ -1000,7 +1000,7 @@ inline __host__ __device__ float type_convert<float, f4_t>(f4_t x)
float_values.float2_array = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, scale, 0); float_values.float2_array = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, scale, 0);
return float_values.float_array[0]; return float_values.float_array[0];
#else #else
return utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), x); return utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif #endif
} }
...@@ -1018,8 +1018,8 @@ inline __host__ __device__ float2_t type_convert<float2_t, f4x2_t>(f4x2_t x) ...@@ -1018,8 +1018,8 @@ inline __host__ __device__ float2_t type_convert<float2_t, f4x2_t>(f4x2_t x)
float scale = 1.0f; float scale = 1.0f;
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0); return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0);
#else #else
float2_t ret{utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), x.unpack(1)), float2_t ret{utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x.unpack(1)),
utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), x.unpack(0))}; utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x.unpack(0))};
return ret; return ret;
#endif #endif
} }
...@@ -1153,72 +1153,72 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x) ...@@ -1153,72 +1153,72 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
f4x32_t f4x32_array; f4x32_t f4x32_array;
} f4_values{bit_cast<__uint128_t>(x)}; } f4_values{bit_cast<__uint128_t>(x)};
// TODO: pack in a loop // TODO: pack in a loop
float_values.float_array[0] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[0] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[0].unpack(0)); f4_values.f4x2_array[0].unpack(0));
float_values.float_array[1] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[1] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[0].unpack(1)); f4_values.f4x2_array[0].unpack(1));
float_values.float_array[2] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[2] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[1].unpack(0)); f4_values.f4x2_array[1].unpack(0));
float_values.float_array[3] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[3] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[1].unpack(1)); f4_values.f4x2_array[1].unpack(1));
float_values.float_array[4] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[4] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[2].unpack(0)); f4_values.f4x2_array[2].unpack(0));
float_values.float_array[5] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[5] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[2].unpack(1)); f4_values.f4x2_array[2].unpack(1));
float_values.float_array[6] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[6] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[3].unpack(0)); f4_values.f4x2_array[3].unpack(0));
float_values.float_array[7] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[7] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[3].unpack(1)); f4_values.f4x2_array[3].unpack(1));
float_values.float_array[0] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[0] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[4].unpack(0)); f4_values.f4x2_array[4].unpack(0));
float_values.float_array[1] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[1] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[4].unpack(1)); f4_values.f4x2_array[4].unpack(1));
float_values.float_array[2] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[2] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[5].unpack(0)); f4_values.f4x2_array[5].unpack(0));
float_values.float_array[3] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[3] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[5].unpack(1)); f4_values.f4x2_array[5].unpack(1));
float_values.float_array[4] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[4] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[6].unpack(0)); f4_values.f4x2_array[6].unpack(0));
float_values.float_array[5] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[5] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[6].unpack(1)); f4_values.f4x2_array[6].unpack(1));
float_values.float_array[6] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[6] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[7].unpack(0)); f4_values.f4x2_array[7].unpack(0));
float_values.float_array[7] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[7] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[7].unpack(1)); f4_values.f4x2_array[7].unpack(1));
float_values.float_array[0] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[0] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[8].unpack(0)); f4_values.f4x2_array[8].unpack(0));
float_values.float_array[1] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[1] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[8].unpack(1)); f4_values.f4x2_array[8].unpack(1));
float_values.float_array[2] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[2] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[9].unpack(0)); f4_values.f4x2_array[9].unpack(0));
float_values.float_array[3] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[3] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[9].unpack(1)); f4_values.f4x2_array[9].unpack(1));
float_values.float_array[4] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[4] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[10].unpack(0)); f4_values.f4x2_array[10].unpack(0));
float_values.float_array[5] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[5] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[10].unpack(1)); f4_values.f4x2_array[10].unpack(1));
float_values.float_array[6] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[6] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[11].unpack(0)); f4_values.f4x2_array[11].unpack(0));
float_values.float_array[7] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[7] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[11].unpack(1)); f4_values.f4x2_array[11].unpack(1));
float_values.float_array[0] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[0] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[12].unpack(0)); f4_values.f4x2_array[12].unpack(0));
float_values.float_array[1] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[1] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[12].unpack(1)); f4_values.f4x2_array[12].unpack(1));
float_values.float_array[2] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[2] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[13].unpack(0)); f4_values.f4x2_array[13].unpack(0));
float_values.float_array[3] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[3] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[13].unpack(1)); f4_values.f4x2_array[13].unpack(1));
float_values.float_array[4] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[4] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[14].unpack(0)); f4_values.f4x2_array[14].unpack(0));
float_values.float_array[5] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[5] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[14].unpack(1)); f4_values.f4x2_array[14].unpack(1));
float_values.float_array[6] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[6] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[15].unpack(0)); f4_values.f4x2_array[15].unpack(0));
float_values.float_array[7] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), float_values.float_array[7] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[15].unpack(1)); f4_values.f4x2_array[15].unpack(1));
return float_values.float32_array; return float_values.float32_array;
...@@ -1226,24 +1226,24 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x) ...@@ -1226,24 +1226,24 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
} }
template <> template <>
inline __host__ __device__ float type_convert<float, e8m0_scale_t>(e8m0_scale_t scale) inline __host__ __device__ float type_convert<float, e8m0_bexp_t>(e8m0_bexp_t scale)
{ {
return utils::cast_to_float(scale); return utils::cast_to_float(scale);
} }
template <> template <>
inline __host__ __device__ e8m0_scale_t type_convert<e8m0_scale_t, float>(float scale) inline __host__ __device__ e8m0_bexp_t type_convert<e8m0_bexp_t, float>(float scale)
{ {
return utils::cast_from_float(scale); return utils::cast_from_float(scale);
} }
// Declare a template function for scaled conversion // Declare a template function for scaled conversion
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ constexpr Y scaled_type_convert(e8m0_scale_t scale, X x); __host__ __device__ constexpr Y scaled_type_convert(e8m0_bexp_t scale, X x);
// convert fp4 to fp32 // convert fp4 to fp32
template <> template <>
inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_scale_t scale, f4_t x) inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_bexp_t scale, f4_t x)
{ {
#if defined(__gfx950__) #if defined(__gfx950__)
union union
...@@ -1261,7 +1261,7 @@ inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_scale_t s ...@@ -1261,7 +1261,7 @@ inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_scale_t s
// convert vector of 2 fp4 to vector of 2 fp32 // convert vector of 2 fp4 to vector of 2 fp32
template <> template <>
inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_scale_t scale, inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_bexp_t scale,
f4x2_t x) f4x2_t x)
{ {
#if defined(__gfx950__) #if defined(__gfx950__)
...@@ -1281,7 +1281,7 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_s ...@@ -1281,7 +1281,7 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_s
// convert vector of 32 fp4 to vector of 32 fp32 // convert vector of 32 fp4 to vector of 32 fp32
template <> template <>
inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m0_scale_t scale, inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m0_bexp_t scale,
f4x32_t x) f4x32_t x)
{ {
#if defined(__gfx950__) #if defined(__gfx950__)
...@@ -1450,7 +1450,7 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m ...@@ -1450,7 +1450,7 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m
// convert fp32 to fp4 // convert fp32 to fp4
template <> template <>
inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_scale_t scale, float x) inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_bexp_t scale, float x)
{ {
#if CK_USE_SR_F4_CONVERSION #if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x, type_convert<float>(scale)); return f4_convert_sr(x, type_convert<float>(scale));
...@@ -1461,7 +1461,7 @@ inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_scale_t sc ...@@ -1461,7 +1461,7 @@ inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_scale_t sc
// convert vector of 2 fp32 to vector of 2 fp4 // convert vector of 2 fp32 to vector of 2 fp4
template <> template <>
inline __host__ __device__ f4x2_t scaled_type_convert<f4x2_t, float2_t>(e8m0_scale_t scale, inline __host__ __device__ f4x2_t scaled_type_convert<f4x2_t, float2_t>(e8m0_bexp_t scale,
float2_t x) float2_t x)
{ {
#if CK_USE_SR_F4_CONVERSION #if CK_USE_SR_F4_CONVERSION
...@@ -1473,7 +1473,7 @@ inline __host__ __device__ f4x2_t scaled_type_convert<f4x2_t, float2_t>(e8m0_sca ...@@ -1473,7 +1473,7 @@ inline __host__ __device__ f4x2_t scaled_type_convert<f4x2_t, float2_t>(e8m0_sca
// convert vector of 32 fp32 to vector of 32 fp4 // convert vector of 32 fp32 to vector of 32 fp4
template <> template <>
inline __host__ __device__ f4x32_t scaled_type_convert<f4x32_t, float32_t>(e8m0_scale_t scale, inline __host__ __device__ f4x32_t scaled_type_convert<f4x32_t, float32_t>(e8m0_bexp_t scale,
float32_t x) float32_t x)
{ {
#if CK_USE_SR_F4_CONVERSION #if CK_USE_SR_F4_CONVERSION
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp" #include "ck/utility/type_convert.hpp"
using ck::e8m0_scale_t; using ck::e8m0_bexp_t;
using ck::f4_convert_rne; using ck::f4_convert_rne;
using ck::f4_convert_sr; using ck::f4_convert_sr;
using ck::f4_t; using ck::f4_t;
...@@ -90,10 +90,10 @@ TEST(FP4, ScaledConvertFP32Nearest) ...@@ -90,10 +90,10 @@ TEST(FP4, ScaledConvertFP32Nearest)
float max_fp4 = 6.0f; float max_fp4 = 6.0f;
// set maximum scale // set maximum scale
float max_scale = std::pow(2, float max_scale = std::pow(2,
ck::NumericLimits<e8m0_scale_t>::Max().data - ck::NumericLimits<e8m0_bexp_t>::Max().data -
ck::NumericUtils<e8m0_scale_t>::bias); // 0xFE -> float ck::NumericUtils<e8m0_bexp_t>::bias); // 0xFE -> float
// set minimum scale // set minimum scale
float min_scale = std::pow(2, -ck::NumericUtils<e8m0_scale_t>::bias); // 0x00 -> float float min_scale = std::pow(2, -ck::NumericUtils<e8m0_bexp_t>::bias); // 0x00 -> float
// set arbitrary scale to 256.0 // set arbitrary scale to 256.0
float test_scale = 256.0f; // 0b10000111 float test_scale = 256.0f; // 0b10000111
// convert 0 float to fp4 and back with maximal scale, check if holds // convert 0 float to fp4 and back with maximal scale, check if holds
...@@ -162,10 +162,10 @@ TEST(FP4, ScaledConvertFP32Stochastic) ...@@ -162,10 +162,10 @@ TEST(FP4, ScaledConvertFP32Stochastic)
float max_fp4 = 6.0f; float max_fp4 = 6.0f;
// set maximum scale // set maximum scale
float max_scale = std::pow(2, float max_scale = std::pow(2,
ck::NumericLimits<e8m0_scale_t>::Max().data - ck::NumericLimits<e8m0_bexp_t>::Max().data -
ck::NumericUtils<e8m0_scale_t>::bias); // 0xFE -> float ck::NumericUtils<e8m0_bexp_t>::bias); // 0xFE -> float
// set minimum scale // set minimum scale
float min_scale = std::pow(2, -ck::NumericUtils<e8m0_scale_t>::bias); // 0x00 -> float float min_scale = std::pow(2, -ck::NumericUtils<e8m0_bexp_t>::bias); // 0x00 -> float
// set arbitrary scale to 256.0 // set arbitrary scale to 256.0
float test_scale = 256.0f; // 0b10000111 float test_scale = 256.0f; // 0b10000111
// convert 0 float to fp4 and back with maximal scale, check if holds // convert 0 float to fp4 and back with maximal scale, check if holds
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment