Commit 48d58131 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Rename E8M0 type

parent b1ad4b4f
......@@ -14,15 +14,15 @@ using f4_t = unsigned _BitInt(4);
using f8_t = _BitInt(8);
using bf8_t = unsigned _BitInt(8);
struct e8m0_scale_t
struct e8m0_bexp_t
{
// E8M0 scale is biased
using type = uint8_t;
type data;
constexpr e8m0_scale_t() : data{type{}} {}
constexpr e8m0_scale_t(type init) : data{init} {}
constexpr e8m0_bexp_t() : data{type{}} {}
constexpr e8m0_bexp_t(type init) : data{init} {}
bool operator==(const e8m0_scale_t& other) const { return (data == other.data); }
bool operator==(const e8m0_bexp_t& other) const { return (data == other.data); }
};
struct f4x2_pk_t
......@@ -1813,33 +1813,30 @@ struct NumericLimits<f4_t>
};
template <>
struct NumericLimits<e8m0_scale_t>
struct NumericLimits<e8m0_bexp_t>
{
static constexpr e8m0_scale_t binary_min = 0x00; // 0b00000000
static constexpr e8m0_scale_t binary_max = 0xFE; // 0b11111110
static constexpr e8m0_scale_t binary_qnan = 0xFF; // 0b11111111
static constexpr e8m0_scale_t binary_1 = 0x7F; // 0b01111111
static constexpr e8m0_scale_t binary_2 = 0x80; // 0b10000000
static constexpr e8m0_scale_t binary_3 = 0x82; // 0b10000010
static constexpr e8m0_scale_t binary_135 = 0x87; // 0b10000111
static constexpr e8m0_scale_t binary_142 = 0x8E; // 0b10001110
__host__ __device__ static constexpr e8m0_scale_t Min() { return e8m0_scale_t(binary_min); }
__host__ __device__ static constexpr e8m0_scale_t Max() { return e8m0_scale_t(binary_max); }
__host__ __device__ static constexpr e8m0_scale_t QuietNaN()
static constexpr e8m0_bexp_t binary_min = 0x00; // 0b00000000
static constexpr e8m0_bexp_t binary_max = 0xFE; // 0b11111110
static constexpr e8m0_bexp_t binary_qnan = 0xFF; // 0b11111111
static constexpr e8m0_bexp_t binary_1 = 0x7F; // 0b01111111
static constexpr e8m0_bexp_t binary_2 = 0x80; // 0b10000000
static constexpr e8m0_bexp_t binary_3 = 0x82; // 0b10000010
static constexpr e8m0_bexp_t binary_135 = 0x87; // 0b10000111
static constexpr e8m0_bexp_t binary_142 = 0x8E; // 0b10001110
__host__ __device__ static constexpr e8m0_bexp_t Min() { return e8m0_bexp_t(binary_min); }
__host__ __device__ static constexpr e8m0_bexp_t Max() { return e8m0_bexp_t(binary_max); }
__host__ __device__ static constexpr e8m0_bexp_t QuietNaN() { return e8m0_bexp_t(binary_qnan); }
__host__ __device__ static constexpr e8m0_bexp_t Binary_1() { return e8m0_bexp_t(binary_1); }
__host__ __device__ static constexpr e8m0_bexp_t Binary_2() { return e8m0_bexp_t(binary_2); }
__host__ __device__ static constexpr e8m0_bexp_t Binary_3() { return e8m0_bexp_t(binary_3); }
__host__ __device__ static constexpr e8m0_bexp_t Binary_135()
{
return e8m0_scale_t(binary_qnan);
return e8m0_bexp_t(binary_135);
}
__host__ __device__ static constexpr e8m0_scale_t Binary_1() { return e8m0_scale_t(binary_1); }
__host__ __device__ static constexpr e8m0_scale_t Binary_2() { return e8m0_scale_t(binary_2); }
__host__ __device__ static constexpr e8m0_scale_t Binary_3() { return e8m0_scale_t(binary_3); }
__host__ __device__ static constexpr e8m0_scale_t Binary_135()
__host__ __device__ static constexpr e8m0_bexp_t Binary_142()
{
return e8m0_scale_t(binary_135);
}
__host__ __device__ static constexpr e8m0_scale_t Binary_142()
{
return e8m0_scale_t(binary_142);
return e8m0_bexp_t(binary_142);
}
};
......@@ -1944,7 +1941,7 @@ struct NumericUtils<f4_t>
};
template <>
struct NumericUtils<e8m0_scale_t>
struct NumericUtils<e8m0_bexp_t>
{
static constexpr int exp = 8;
static constexpr int mant = 0;
......
......@@ -8,24 +8,24 @@
namespace ck::utils {
__host__ __device__ inline float cast_to_float(e8m0_scale_t const scale)
__host__ __device__ inline float cast_to_float(e8m0_bexp_t const bexp)
{
// TODO: check performance and try bit shift impl
return std::powf(2, bit_cast<uint8_t>(scale) - NumericUtils<e8m0_scale_t>::bias);
return std::powf(2, bit_cast<uint8_t>(bexp) - NumericUtils<e8m0_bexp_t>::bias);
}
__host__ __device__ inline e8m0_scale_t cast_from_float(float const scale)
__host__ __device__ inline e8m0_bexp_t cast_from_float(float const scale)
{
uint32_t e = bit_cast<uint32_t>(scale) & NumericUtils<float>::nan_mask;
return static_cast<uint8_t>(e >> 23);
}
template <>
__host__ __device__ inline int get_exponent_value<e8m0_scale_t>(e8m0_scale_t x)
__host__ __device__ inline int get_exponent_value<e8m0_bexp_t>(e8m0_bexp_t x)
{
x.data >>= NumericUtils<e8m0_scale_t>::mant;
x.data >>= NumericUtils<e8m0_bexp_t>::mant;
x.data &= ((1 << NumericUtils<e8m0_scale_t>::exp) - 1);
x.data &= ((1 << NumericUtils<e8m0_bexp_t>::exp) - 1);
return static_cast<int>(x.data);
}
......
......@@ -9,16 +9,16 @@
namespace ck::utils {
template <>
__host__ __device__ inline bool is_nan<f4_t>(e8m0_scale_t const scale,
__host__ __device__ inline bool is_nan<f4_t>(e8m0_bexp_t const scale,
f4_t const dataBytes [[maybe_unused]])
{
// no need to check for data as it does not have NaN representation
return scale == NumericLimits<e8m0_scale_t>::QuietNaN();
return scale == NumericLimits<e8m0_bexp_t>::QuietNaN();
}
// no infinity representation in ocp_e2m1_mxfp4 will always return false
template <>
__host__ __device__ inline bool is_inf<f4_t>(e8m0_scale_t const scale [[maybe_unused]],
__host__ __device__ inline bool is_inf<f4_t>(e8m0_bexp_t const scale [[maybe_unused]],
f4_t const data [[maybe_unused]])
{
// no inf representation for ocp_e2m1_mxfp4
......@@ -26,7 +26,7 @@ __host__ __device__ inline bool is_inf<f4_t>(e8m0_scale_t const scale [[maybe_un
}
template <>
__host__ __device__ inline bool is_zero<f4_t>(e8m0_scale_t const scale, f4_t const data)
__host__ __device__ inline bool is_zero<f4_t>(e8m0_bexp_t const scale, f4_t const data)
{
if(is_nan<f4_t>(scale, data))
return false;
......@@ -38,7 +38,7 @@ __host__ __device__ inline bool is_zero<f4_t>(e8m0_scale_t const scale, f4_t con
}
template <>
__host__ __device__ inline float to_float<f4_t>(e8m0_scale_t const scale, f4_t const data)
__host__ __device__ inline float to_float<f4_t>(e8m0_bexp_t const scale, f4_t const data)
{
if(is_nan<f4_t>(scale, data))
return std::numeric_limits<float>::quiet_NaN();
......@@ -48,7 +48,7 @@ __host__ __device__ inline float to_float<f4_t>(e8m0_scale_t const scale, f4_t c
f4_t prepared_data = data & 0b00001111;
int scale_exp = get_exponent_value<e8m0_scale_t>(scale);
int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);
return convert_to_float<f4_t>(prepared_data, scale_exp);
}
......@@ -73,7 +73,7 @@ __host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value)
f4_t res = convert_to_type<f4_t>(value);
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), res)) <
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<f4_t>::DataMinSubnorm())
return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
: NumericUtils<f4_t>::positive_zero_mask;
......@@ -98,7 +98,7 @@ __host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32
f4_t res = convert_to_type_sr<f4_t>(value, seed);
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), res)) <
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<f4_t>::DataMinSubnorm())
return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
: NumericUtils<f4_t>::positive_zero_mask;
......
......@@ -18,13 +18,13 @@ inline bool getDataHasInf()
}
template <typename T>
__host__ __device__ inline bool is_zero(e8m0_scale_t const scale, T const data);
__host__ __device__ inline bool is_zero(e8m0_bexp_t const scale, T const data);
template <typename T>
__host__ __device__ inline bool is_nan(e8m0_scale_t const scale, T const data);
__host__ __device__ inline bool is_nan(e8m0_bexp_t const scale, T const data);
template <typename T>
__host__ __device__ inline bool is_inf(e8m0_scale_t const scale, T const data);
__host__ __device__ inline bool is_inf(e8m0_bexp_t const scale, T const data);
template <typename T>
__host__ __device__ inline int get_exponent_value(T x)
......@@ -79,13 +79,13 @@ __host__ __device__ float convert_to_float(T data, int scale_exp)
float data_value = d_sign * d_exp * d_mant;
float scale_value = std::pow(
2, static_cast<float>((scale_exp - static_cast<int>(NumericUtils<e8m0_scale_t>::bias))));
2, static_cast<float>((scale_exp - static_cast<int>(NumericUtils<e8m0_bexp_t>::bias))));
return data_value * scale_value;
}
template <typename T>
__host__ __device__ inline float to_float(e8m0_scale_t const scale, T const data);
__host__ __device__ inline float to_float(e8m0_bexp_t const scale, T const data);
template <typename T>
__host__ __device__ T sat_convert_to_type(float value);
......
......@@ -1000,7 +1000,7 @@ inline __host__ __device__ float type_convert<float, f4_t>(f4_t x)
float_values.float2_array = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, scale, 0);
return float_values.float_array[0];
#else
return utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), x);
return utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
......@@ -1018,8 +1018,8 @@ inline __host__ __device__ float2_t type_convert<float2_t, f4x2_t>(f4x2_t x)
float scale = 1.0f;
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0);
#else
float2_t ret{utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), x.unpack(1)),
utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), x.unpack(0))};
float2_t ret{utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x.unpack(1)),
utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x.unpack(0))};
return ret;
#endif
}
......@@ -1153,72 +1153,72 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
f4x32_t f4x32_array;
} f4_values{bit_cast<__uint128_t>(x)};
// TODO: pack in a loop
float_values.float_array[0] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[0] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[0].unpack(0));
float_values.float_array[1] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[1] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[0].unpack(1));
float_values.float_array[2] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[2] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[1].unpack(0));
float_values.float_array[3] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[3] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[1].unpack(1));
float_values.float_array[4] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[4] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[2].unpack(0));
float_values.float_array[5] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[5] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[2].unpack(1));
float_values.float_array[6] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[6] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[3].unpack(0));
float_values.float_array[7] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[7] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[3].unpack(1));
float_values.float_array[0] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[0] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[4].unpack(0));
float_values.float_array[1] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[1] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[4].unpack(1));
float_values.float_array[2] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[2] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[5].unpack(0));
float_values.float_array[3] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[3] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[5].unpack(1));
float_values.float_array[4] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[4] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[6].unpack(0));
float_values.float_array[5] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[5] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[6].unpack(1));
float_values.float_array[6] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[6] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[7].unpack(0));
float_values.float_array[7] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[7] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[7].unpack(1));
float_values.float_array[0] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[0] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[8].unpack(0));
float_values.float_array[1] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[1] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[8].unpack(1));
float_values.float_array[2] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[2] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[9].unpack(0));
float_values.float_array[3] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[3] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[9].unpack(1));
float_values.float_array[4] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[4] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[10].unpack(0));
float_values.float_array[5] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[5] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[10].unpack(1));
float_values.float_array[6] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[6] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[11].unpack(0));
float_values.float_array[7] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[7] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[11].unpack(1));
float_values.float_array[0] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[0] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[12].unpack(0));
float_values.float_array[1] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[1] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[12].unpack(1));
float_values.float_array[2] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[2] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[13].unpack(0));
float_values.float_array[3] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[3] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[13].unpack(1));
float_values.float_array[4] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[4] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[14].unpack(0));
float_values.float_array[5] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[5] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[14].unpack(1));
float_values.float_array[6] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[6] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[15].unpack(0));
float_values.float_array[7] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
float_values.float_array[7] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[15].unpack(1));
return float_values.float32_array;
......@@ -1226,24 +1226,24 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
}
template <>
inline __host__ __device__ float type_convert<float, e8m0_scale_t>(e8m0_scale_t scale)
inline __host__ __device__ float type_convert<float, e8m0_bexp_t>(e8m0_bexp_t scale)
{
return utils::cast_to_float(scale);
}
template <>
inline __host__ __device__ e8m0_scale_t type_convert<e8m0_scale_t, float>(float scale)
inline __host__ __device__ e8m0_bexp_t type_convert<e8m0_bexp_t, float>(float scale)
{
return utils::cast_from_float(scale);
}
// Declare a template function for scaled conversion
template <typename Y, typename X>
__host__ __device__ constexpr Y scaled_type_convert(e8m0_scale_t scale, X x);
__host__ __device__ constexpr Y scaled_type_convert(e8m0_bexp_t scale, X x);
// convert fp4 to fp32
template <>
inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_scale_t scale, f4_t x)
inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_bexp_t scale, f4_t x)
{
#if defined(__gfx950__)
union
......@@ -1261,7 +1261,7 @@ inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_scale_t s
// convert vector of 2 fp4 to vector of 2 fp32
template <>
inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_scale_t scale,
inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_bexp_t scale,
f4x2_t x)
{
#if defined(__gfx950__)
......@@ -1281,7 +1281,7 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_s
// convert vector of 32 fp4 to vector of 32 fp32
template <>
inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m0_scale_t scale,
inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m0_bexp_t scale,
f4x32_t x)
{
#if defined(__gfx950__)
......@@ -1450,7 +1450,7 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m
// convert fp32 to fp4
template <>
inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_scale_t scale, float x)
inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_bexp_t scale, float x)
{
#if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x, type_convert<float>(scale));
......@@ -1461,7 +1461,7 @@ inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_scale_t sc
// convert vector of 2 fp32 to vector of 2 fp4
template <>
inline __host__ __device__ f4x2_t scaled_type_convert<f4x2_t, float2_t>(e8m0_scale_t scale,
inline __host__ __device__ f4x2_t scaled_type_convert<f4x2_t, float2_t>(e8m0_bexp_t scale,
float2_t x)
{
#if CK_USE_SR_F4_CONVERSION
......@@ -1473,7 +1473,7 @@ inline __host__ __device__ f4x2_t scaled_type_convert<f4x2_t, float2_t>(e8m0_sca
// convert vector of 32 fp32 to vector of 32 fp4
template <>
inline __host__ __device__ f4x32_t scaled_type_convert<f4x32_t, float32_t>(e8m0_scale_t scale,
inline __host__ __device__ f4x32_t scaled_type_convert<f4x32_t, float32_t>(e8m0_bexp_t scale,
float32_t x)
{
#if CK_USE_SR_F4_CONVERSION
......
......@@ -5,7 +5,7 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using ck::e8m0_scale_t;
using ck::e8m0_bexp_t;
using ck::f4_convert_rne;
using ck::f4_convert_sr;
using ck::f4_t;
......@@ -90,10 +90,10 @@ TEST(FP4, ScaledConvertFP32Nearest)
float max_fp4 = 6.0f;
// set maximum scale
float max_scale = std::pow(2,
ck::NumericLimits<e8m0_scale_t>::Max().data -
ck::NumericUtils<e8m0_scale_t>::bias); // 0xFE -> float
ck::NumericLimits<e8m0_bexp_t>::Max().data -
ck::NumericUtils<e8m0_bexp_t>::bias); // 0xFE -> float
// set minimum scale
float min_scale = std::pow(2, -ck::NumericUtils<e8m0_scale_t>::bias); // 0x00 -> float
float min_scale = std::pow(2, -ck::NumericUtils<e8m0_bexp_t>::bias); // 0x00 -> float
// set arbitrary scale to 256.0
float test_scale = 256.0f; // 0b10000111
// convert 0 float to fp4 and back with maximal scale, check if holds
......@@ -162,10 +162,10 @@ TEST(FP4, ScaledConvertFP32Stochastic)
float max_fp4 = 6.0f;
// set maximum scale
float max_scale = std::pow(2,
ck::NumericLimits<e8m0_scale_t>::Max().data -
ck::NumericUtils<e8m0_scale_t>::bias); // 0xFE -> float
ck::NumericLimits<e8m0_bexp_t>::Max().data -
ck::NumericUtils<e8m0_bexp_t>::bias); // 0xFE -> float
// set minimum scale
float min_scale = std::pow(2, -ck::NumericUtils<e8m0_scale_t>::bias); // 0x00 -> float
float min_scale = std::pow(2, -ck::NumericUtils<e8m0_bexp_t>::bias); // 0x00 -> float
// set arbitrary scale to 256.0
float test_scale = 256.0f; // 0b10000111
// convert 0 float to fp4 and back with maximal scale, check if holds
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment