Unverified Commit cde0f480 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer Committed by GitHub
Browse files

Merge pull request #200 from ROCm/lwpck-2390

Enable MXFP4 type
parents b7566434 773c0e70
...@@ -158,6 +158,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) ...@@ -158,6 +158,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// set rounding to nearest even as default for f8 conversions // set rounding to nearest even as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION 0 #define CK_USE_SR_F8_CONVERSION 0
// set rounding to nearest even as default for f4 conversions
#define CK_USE_SR_F4_CONVERSION 0
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0) // block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 #define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
......
...@@ -11,6 +11,40 @@ namespace ck { ...@@ -11,6 +11,40 @@ namespace ck {
using bhalf_t = ushort; using bhalf_t = ushort;
using half_t = _Float16; using half_t = _Float16;
using int4_t = _BitInt(4); using int4_t = _BitInt(4);
using f4_t = unsigned _BitInt(4);
struct e8m0_bexp_t
{
// E8M0 scale is biased
using type = uint8_t;
type data;
constexpr e8m0_bexp_t() : data{type{}} {}
constexpr e8m0_bexp_t(type init) : data{init} {}
bool operator==(const e8m0_bexp_t& other) const { return (data == other.data); }
};
struct f4x2_pk_t
{
using type = uint8_t;
type data;
f4x2_pk_t() : data{type{}} {}
f4x2_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline type unpack() const
{
if constexpr(I == 0)
return data & 0b00001111;
else
return (data >> 4);
}
__host__ __device__ inline type pack(const type x0, const type x1)
{
return (x1 << 4) | (x0 & 0b00001111);
}
};
inline constexpr auto next_pow2(uint32_t x) inline constexpr auto next_pow2(uint32_t x)
{ {
...@@ -26,7 +60,7 @@ inline constexpr bool is_native_type() ...@@ -26,7 +60,7 @@ inline constexpr bool is_native_type()
return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value || return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value ||
is_same<T, bhalf_t>::value || is_same<T, int32_t>::value || is_same<T, int8_t>::value || is_same<T, bhalf_t>::value || is_same<T, int32_t>::value || is_same<T, int8_t>::value ||
is_same<T, uint8_t>::value || is_same<T, f8_fnuz_t>::value || is_same<T, uint8_t>::value || is_same<T, f8_fnuz_t>::value ||
is_same<T, bf8_fnuz_t>::value || is_same<T, bool>::value; is_same<T, bf8_fnuz_t>::value || is_same<T, bool>::value || is_same<T, f4_t>::value;
} }
// vector_type // vector_type
...@@ -1871,6 +1905,14 @@ using uint8x16_t = typename vector_type<uint8_t, 16>::type; ...@@ -1871,6 +1905,14 @@ using uint8x16_t = typename vector_type<uint8_t, 16>::type;
using uint8x32_t = typename vector_type<uint8_t, 32>::type; using uint8x32_t = typename vector_type<uint8_t, 32>::type;
using uint8x64_t = typename vector_type<uint8_t, 64>::type; using uint8x64_t = typename vector_type<uint8_t, 64>::type;
// f4
using f4x2_t = typename vector_type<f4x2_pk_t, 1>::type;
using f4x4_t = typename vector_type<f4x2_pk_t, 2>::type;
using f4x8_t = typename vector_type<f4x2_pk_t, 4>::type;
using f4x16_t = typename vector_type<f4x2_pk_t, 8>::type;
using f4x32_t = typename vector_type<f4x2_pk_t, 16>::type;
using f4x64_t = typename vector_type<f4x2_pk_t, 32>::type;
template <typename T> template <typename T>
struct NumericLimits struct NumericLimits
{ {
...@@ -2009,6 +2051,59 @@ struct NumericLimits<bf8_ocp_t> ...@@ -2009,6 +2051,59 @@ struct NumericLimits<bf8_ocp_t>
} }
}; };
template <>
struct NumericLimits<f4_t>
{
static constexpr uint8_t binary_min_normal = 0x2; // 0b0010
static constexpr uint8_t binary_max_normal = 0x7; // 0b0111
static constexpr uint8_t binary_lowest_normal = 0xF; // 0b1111
static constexpr uint8_t binary_min_subnorm = 0x1; // 0b0001
static constexpr uint8_t binary_max_subnorm = 0x1; // 0b0001
static constexpr float data_max_normal_number = 6;
static constexpr float data_min_subnormal_number = 0.5;
__host__ __device__ static constexpr f4_t Min() { return f4_t(binary_min_normal); }
__host__ __device__ static constexpr f4_t Max() { return f4_t(binary_max_normal); }
__host__ __device__ static constexpr f4_t Lowest() { return f4_t(binary_lowest_normal); }
__host__ __device__ static constexpr f4_t MinSubnorm() { return f4_t(binary_min_subnorm); }
__host__ __device__ static constexpr f4_t MaxSubnorm() { return f4_t(binary_max_subnorm); }
__host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; }
__host__ __device__ static constexpr float DataMinSubnorm()
{
return data_min_subnormal_number;
}
};
template <>
struct NumericLimits<e8m0_bexp_t>
{
static constexpr e8m0_bexp_t binary_min = 0x00; // 0b00000000
static constexpr e8m0_bexp_t binary_max = 0xFE; // 0b11111110
static constexpr e8m0_bexp_t binary_qnan = 0xFF; // 0b11111111
static constexpr e8m0_bexp_t binary_1 = 0x7F; // 0b01111111
static constexpr e8m0_bexp_t binary_2 = 0x80; // 0b10000000
static constexpr e8m0_bexp_t binary_3 = 0x82; // 0b10000010
static constexpr e8m0_bexp_t binary_135 = 0x87; // 0b10000111
static constexpr e8m0_bexp_t binary_142 = 0x8E; // 0b10001110
__host__ __device__ static constexpr e8m0_bexp_t Min() { return e8m0_bexp_t(binary_min); }
__host__ __device__ static constexpr e8m0_bexp_t Max() { return e8m0_bexp_t(binary_max); }
__host__ __device__ static constexpr e8m0_bexp_t QuietNaN() { return e8m0_bexp_t(binary_qnan); }
__host__ __device__ static constexpr e8m0_bexp_t Binary_1() { return e8m0_bexp_t(binary_1); }
__host__ __device__ static constexpr e8m0_bexp_t Binary_2() { return e8m0_bexp_t(binary_2); }
__host__ __device__ static constexpr e8m0_bexp_t Binary_3() { return e8m0_bexp_t(binary_3); }
__host__ __device__ static constexpr e8m0_bexp_t Binary_135()
{
return e8m0_bexp_t(binary_135);
}
__host__ __device__ static constexpr e8m0_bexp_t Binary_142()
{
return e8m0_bexp_t(binary_142);
}
};
template <typename T> template <typename T>
struct NumericUtils struct NumericUtils
{ {
...@@ -2028,6 +2123,7 @@ struct NumericUtils<float> ...@@ -2028,6 +2123,7 @@ struct NumericUtils<float>
static constexpr uint32_t NegInf = 0xFF800000; static constexpr uint32_t NegInf = 0xFF800000;
static constexpr uint32_t NaN = 0x7F800001; static constexpr uint32_t NaN = 0x7F800001;
static constexpr uint32_t Neg0 = 0x80000000; static constexpr uint32_t Neg0 = 0x80000000;
static constexpr bool has_inf = true;
using bitwise_type = uint32_t; using bitwise_type = uint32_t;
}; };
...@@ -2045,9 +2141,19 @@ struct NumericUtils<half_t> ...@@ -2045,9 +2141,19 @@ struct NumericUtils<half_t>
static constexpr uint32_t NegInf = 0xFC00; static constexpr uint32_t NegInf = 0xFC00;
static constexpr uint32_t NaN = 0x7C01; static constexpr uint32_t NaN = 0x7C01;
static constexpr uint32_t Neg0 = 0x8000; static constexpr uint32_t Neg0 = 0x8000;
static constexpr bool has_inf = true;
using bitwise_type = uint16_t; using bitwise_type = uint16_t;
}; };
template <>
struct NumericUtils<bhalf_t>
{
static constexpr int exp = 8;
static constexpr int mant = 7;
static constexpr int bias = 128; // negative zero nan mode
// static constexpr int bias = 127; // ieee mode
};
template <> template <>
struct NumericUtils<f8_fnuz_t> struct NumericUtils<f8_fnuz_t>
{ {
...@@ -2055,6 +2161,7 @@ struct NumericUtils<f8_fnuz_t> ...@@ -2055,6 +2161,7 @@ struct NumericUtils<f8_fnuz_t>
static constexpr int mant = 3; static constexpr int mant = 3;
static constexpr int bias = 8; // negative zero nan mode static constexpr int bias = 8; // negative zero nan mode
// static constexpr int bias = 7; // ieee mode // static constexpr int bias = 7; // ieee mode
static constexpr bool has_inf = false;
}; };
template <> template <>
...@@ -2064,6 +2171,7 @@ struct NumericUtils<bf8_fnuz_t> ...@@ -2064,6 +2171,7 @@ struct NumericUtils<bf8_fnuz_t>
static constexpr int mant = 2; static constexpr int mant = 2;
static constexpr int bias = 16; // negative zero nan mode static constexpr int bias = 16; // negative zero nan mode
// static constexpr int bias = 15; // ieee mode // static constexpr int bias = 15; // ieee mode
static constexpr bool has_inf = false;
}; };
template <> template <>
struct NumericUtils<f8_ocp_t> struct NumericUtils<f8_ocp_t>
...@@ -2082,11 +2190,47 @@ struct NumericUtils<bf8_ocp_t> ...@@ -2082,11 +2190,47 @@ struct NumericUtils<bf8_ocp_t>
}; };
template <> template <>
struct NumericUtils<bhalf_t> struct NumericUtils<f4_t>
{
static constexpr int exp = 2;
static constexpr int mant = 1;
static constexpr int bias = 1;
static constexpr uint32_t sr_shift = 10;
static constexpr int unbiased_exp_min = 0;
static constexpr int unbiased_exp_max = 2;
static constexpr int biased_exp_min = 1;
static constexpr int biased_exp_max = 3;
static constexpr uint8_t positive_zero_mask = 0b0000;
static constexpr uint8_t negative_zero_mask = 0b1000;
static constexpr uint8_t one_mask = 0b0010;
static constexpr uint8_t set_sign_mask = 0b0111;
static constexpr uint8_t data_max_positive_normal_mask = 0b0111;
static constexpr uint8_t data_max_negative_normal_mask = 0b1111;
static constexpr uint8_t data_max_positive_subnormal_mask = 0b0001;
static constexpr uint8_t data_max_negative_subnormal_mask = 0b1001;
static constexpr bool has_inf = false;
using bitwise_type = uint8_t;
};
template <>
struct NumericUtils<e8m0_bexp_t>
{ {
static constexpr int exp = 8; static constexpr int exp = 8;
static constexpr int mant = 7; static constexpr int mant = 0;
static constexpr int bias = 128; // negative zero nan mode static constexpr int bias = 127;
// static constexpr int bias = 127; // ieee mode
static constexpr int unbiased_exp_min = -127;
static constexpr int unbiased_exp_max = 127;
static constexpr int biased_exp_min = 0;
static constexpr int biased_exp_max = 254;
using bitwise_type = uint8_t;
}; };
} // namespace ck } // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
namespace ck::utils {
__host__ __device__ inline float cast_to_float(e8m0_bexp_t const bexp)
{
// TODO: check performance and try bit shift impl
return std::powf(2, bit_cast<uint8_t>(bexp) - NumericUtils<e8m0_bexp_t>::bias);
}
__host__ __device__ inline e8m0_bexp_t cast_from_float(float const scale)
{
uint32_t e = bit_cast<uint32_t>(scale) & NumericUtils<float>::nan_mask;
return static_cast<uint8_t>(e >> 23);
}
template <>
__host__ __device__ inline int get_exponent_value<e8m0_bexp_t>(e8m0_bexp_t x)
{
x.data >>= NumericUtils<e8m0_bexp_t>::mant;
x.data &= ((1 << NumericUtils<e8m0_bexp_t>::exp) - 1);
return static_cast<int>(x.data);
}
} // namespace ck::utils
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
namespace ck::utils {
template <>
__host__ __device__ inline bool is_nan<f4_t>(e8m0_bexp_t const scale,
f4_t const dataBytes [[maybe_unused]])
{
// no need to check for data as it does not have NaN representation
return scale == NumericLimits<e8m0_bexp_t>::QuietNaN();
}
// no infinity representation in ocp_e2m1_mxfp4 will always return false
template <>
__host__ __device__ inline bool is_inf<f4_t>(e8m0_bexp_t const scale [[maybe_unused]],
f4_t const data [[maybe_unused]])
{
// no inf representation for ocp_e2m1_mxfp4
return false;
}
template <>
__host__ __device__ inline bool is_zero<f4_t>(e8m0_bexp_t const scale, f4_t const data)
{
if(is_nan<f4_t>(scale, data))
return false;
// no need to check for scale as it does not have a 0 representation
f4_t result = (data & 0b00001111) & NumericUtils<f4_t>::set_sign_mask;
return result == 0b0;
}
template <>
__host__ __device__ inline float to_float<f4_t>(e8m0_bexp_t const scale, f4_t const data)
{
if(is_nan<f4_t>(scale, data))
return std::numeric_limits<float>::quiet_NaN();
if(is_zero<f4_t>(scale, data))
return 0.0f;
f4_t prepared_data = data & 0b00001111;
int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);
return convert_to_float<f4_t>(prepared_data, scale_exp);
}
template <>
__host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value)
{
cvt t;
t.value_float = value;
uint32_t sign = t.value_bitwise >> 31;
if(std::isnan(value))
{
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
}
if(std::abs(value) > NumericLimits<f4_t>::Max()) // covers inf case as well
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
f4_t res = convert_to_type<f4_t>(value);
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<f4_t>::DataMinSubnorm())
return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
: NumericUtils<f4_t>::positive_zero_mask;
return res;
}
template <>
__host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32_t seed)
{
cvt t;
t.value_float = value;
uint32_t sign = t.value_bitwise >> 31;
if(std::isnan(value))
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
if(std::abs(value) > NumericLimits<f4_t>::Max()) // covers inf case as well
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
f4_t res = convert_to_type_sr<f4_t>(value, seed);
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<f4_t>::DataMinSubnorm())
return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
: NumericUtils<f4_t>::positive_zero_mask;
return res;
}
} // namespace ck::utils
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck::utils {
union cvt
{
float value_float;
uint32_t value_bitwise;
};
template <typename DTYPE>
inline bool getDataHasInf()
{
return DTYPE::dataInfo.hasInf;
}
template <typename T>
__host__ __device__ inline bool is_zero(e8m0_bexp_t const scale, T const data);
template <typename T>
__host__ __device__ inline bool is_nan(e8m0_bexp_t const scale, T const data);
template <typename T>
__host__ __device__ inline bool is_inf(e8m0_bexp_t const scale, T const data);
template <typename T>
__host__ __device__ inline int get_exponent_value(T x)
{
x >>= NumericUtils<T>::mant;
x &= ((1 << NumericUtils<T>::exp) - 1);
return static_cast<int>(x);
}
template <typename T>
__host__ __device__ inline bool is_subnormal(T x)
{
return get_exponent_value<T>(x) == 0;
}
template <typename T>
__host__ __device__ inline double get_mantissa_value(T x)
{
double mantissa = is_subnormal<T>(x) ? 0.0f : 1.0f;
for(uint i = 0; i < NumericUtils<T>::mant; i++)
{
mantissa += std::pow(2, -int32_t((NumericUtils<T>::mant - i))) * (x & 0b1);
x >>= 1;
}
return mantissa;
}
template <typename T>
__host__ __device__ inline bool get_data_has_inf()
{
return NumericUtils<T>::has_inf;
}
template <typename T>
__host__ __device__ float convert_to_float(T data, int scale_exp)
{
float d_sign =
std::pow(-1, static_cast<float>(data >> (NumericUtils<T>::exp + NumericUtils<T>::mant)));
float d_exp;
if(is_subnormal<T>(data))
d_exp = std::pow(2, 1 - static_cast<int>(NumericUtils<T>::bias));
else
d_exp = std::pow(2, get_exponent_value<T>(data) - static_cast<int>(NumericUtils<T>::bias));
float d_mant = get_mantissa_value<T>(data);
float data_value = d_sign * d_exp * d_mant;
float scale_value = std::pow(
2, static_cast<float>((scale_exp - static_cast<int>(NumericUtils<e8m0_bexp_t>::bias))));
return data_value * scale_value;
}
template <typename T>
__host__ __device__ inline float to_float(e8m0_bexp_t const scale, T const data);
template <typename T>
__host__ __device__ T sat_convert_to_type(float value);
template <typename T>
__host__ __device__ T sat_convert_to_type_sr(float value, uint32_t seed);
template <typename T>
inline T convert_to_type(float value)
{
using bitwise_type = typename NumericUtils<T>::bitwise_type;
if(std::abs(value) > NumericLimits<T>::Max())
{
float max_value = NumericLimits<T>::Max();
cvt t;
// cppcheck-suppress redundantAssignment
t.value_float = max_value;
uint32_t max_bitwise = t.value_bitwise;
// cppcheck-suppress redundantAssignment
t.value_float = value;
bitwise_type sign =
t.value_bitwise >> (NumericUtils<float>::exp + NumericUtils<float>::mant);
bitwise_type exp =
((max_bitwise >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask) -
(NumericUtils<float>::bias - NumericUtils<T>::bias);
bitwise_type mantissa = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
mant_prev &= ((1 << NumericUtils<T>::mant) - 1);
mant_prev--;
mant_prev <<= (NumericUtils<float>::mant - NumericUtils<T>::mant);
uint32_t prev_bit =
((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;
t.value_bitwise = prev_bit;
float prev_val = t.value_float;
float diff = max_value - prev_val;
float actual_max = max_value + (diff / 2);
if(std::abs(value) < actual_max)
{
return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
(exp << NumericUtils<T>::mant) | mantissa;
}
else
{
if(!get_data_has_inf<T>())
{
return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
}
else
{
exp++;
return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
(exp << NumericUtils<T>::mant);
}
}
}
const int mfmt = NumericUtils<float>::mant;
uint32_t x;
x = bit_cast<uint32_t>(value);
uint32_t head, mantissa;
int32_t exponent, bias;
uint32_t sign;
head = x & NumericUtils<float>::head_mask;
mantissa = x & NumericUtils<float>::mant_mask;
exponent = (head >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask;
sign = head >> (NumericUtils<float>::mant + NumericUtils<float>::exp);
bias = NumericUtils<float>::bias;
if(x == 0)
{
return 0b0;
}
const int mini_bias = NumericUtils<T>::bias;
const int mini_denormal_act_exponent = 1 - mini_bias;
int act_exponent, out_exponent, exponent_diff;
bool is_subnorm = false;
if(exponent == 0)
{
act_exponent = exponent - bias + 1;
exponent_diff = mini_denormal_act_exponent - act_exponent;
is_subnorm = true;
}
else
{
act_exponent = exponent - bias;
if(act_exponent <= mini_denormal_act_exponent)
{
exponent_diff = mini_denormal_act_exponent - act_exponent;
is_subnorm = true;
}
else
{
exponent_diff = 0;
}
mantissa += (1UL << mfmt);
}
auto shift_amount = (mfmt - NumericUtils<T>::mant + exponent_diff);
shift_amount = (shift_amount >= 64) ? 63 : shift_amount;
bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));
float min_subnorm = NumericLimits<T>::DataMinSubnorm() * (sign ? -1 : 1);
if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
{
// closer to 0
if(std::abs(value) <= std::abs(min_subnorm - value))
return 0;
else
return 1 | (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant));
}
if(exponent_diff > 0)
mantissa >>= exponent_diff;
else if(exponent_diff == -1)
mantissa <<= -exponent_diff;
bool implicit_one = mantissa & (1 << mfmt);
out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);
uint32_t drop_mask = (1UL << (mfmt - NumericUtils<T>::mant)) - 1;
bool odd = mantissa & (1UL << (mfmt - NumericUtils<T>::mant));
mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;
if(out_exponent == 0)
{
if((1UL << mfmt) & mantissa)
{
out_exponent = 1;
}
}
else
{
if((1UL << (mfmt + 1)) & mantissa)
{
mantissa >>= 1;
out_exponent++;
}
}
mantissa >>= (mfmt - NumericUtils<T>::mant);
if(out_exponent == 0 && mantissa == 0)
{
return 0;
}
mantissa &= (1UL << NumericUtils<T>::mant) - 1;
return (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant)) |
(out_exponent << NumericUtils<T>::mant) | mantissa;
}
template <typename T>
inline T convert_to_type_sr(float value, uint32_t seed)
{
if(std::abs(value) > NumericLimits<T>::Max())
{
float max_value = NumericLimits<T>::Max();
cvt t;
// cppcheck-suppress redundantAssignment
t.value_float = max_value;
uint max_bitwise = t.value_bitwise;
// cppcheck-suppress redundantAssignment
t.value_float = value;
T sign = t.value_bitwise >> (NumericUtils<float>::exp + NumericUtils<float>::mant);
T exp = ((max_bitwise >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask) -
(NumericUtils<float>::bias - NumericUtils<T>::bias);
uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
mant_prev &= ((1UL << NumericUtils<T>::mant) - 1);
mant_prev--;
mant_prev <<= (NumericUtils<float>::mant - NumericUtils<T>::mant);
uint32_t prev_bit =
((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;
t.value_bitwise = prev_bit;
float prev_val = t.value_float;
float diff = max_value - prev_val;
float actual_max = max_value + (diff / 2);
if(std::abs(value) < actual_max)
{
double d_max_value = static_cast<double>(max_value);
double d_actual_max = static_cast<double>(actual_max);
double d_value = static_cast<double>(value);
double d_is = std::abs(d_max_value - d_actual_max);
double d_seed = static_cast<double>(seed);
double d_prob = 1.0f - (std::abs(d_value - d_max_value) / d_is); // prob to round down
double thresh = UINT_MAX * d_prob;
if(!get_data_has_inf<T>() || d_seed <= thresh)
// return static_cast<T>(satConvertToType(getDataMax<DTYPE>())); //round down time
return sign == 0 ? NumericUtils<f4_t>::data_max_positive_normal_mask
: NumericUtils<f4_t>::data_max_negative_normal_mask;
else
{
exp++;
return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
| (exp << NumericUtils<T>::mant);
}
}
else
{
if(!get_data_has_inf<T>())
return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
else
{
exp++;
return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
| (exp << NumericUtils<T>::mant);
}
}
}
uint32_t f32 = bit_cast<uint32_t>(value);
auto f32_mant = f32 & NumericUtils<float>::mant_mask;
auto head = f32 & NumericUtils<float>::head_mask;
auto f32_exp = (head >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask;
auto sign_bit = head >> (NumericUtils<float>::mant + NumericUtils<float>::exp);
auto sign = sign_bit << (NumericUtils<T>::exp + NumericUtils<T>::mant);
f32_exp = static_cast<int32_t>(f32_exp) - NumericUtils<float>::bias;
int32_t exp = f32_exp;
auto mant = f32_mant;
bool subnorm = false;
if(f32 == 0)
return 0b0;
if(exp >= NumericUtils<T>::unbiased_exp_min)
{
mant = f32_mant;
}
// if the exponent bit is 8, then the subnormal is exactly the same as f32
else if(exp < NumericUtils<T>::unbiased_exp_min &&
NumericUtils<T>::exp < NumericUtils<float>::exp)
{
subnorm = true;
auto diff = static_cast<uint32_t>(NumericUtils<T>::unbiased_exp_min - exp);
if(diff >= 32)
{
mant = 0;
f32_mant = 0;
}
else
{
f32_mant |= static_cast<uint32_t>(1) << NumericUtils<float>::mant;
f32_mant >>= diff;
}
exp = 0;
mant = f32_mant;
}
uint32_t sr_shift = NumericUtils<T>::sr_shift;
// For stochastic-rounding we add the aligned random value to the
// mantissa and then truncate (RTZ).
mant += seed >> sr_shift;
// Increment exponent when mantissa overflows due to rounding
if(mant >= static_cast<uint32_t>(1) << NumericUtils<float>::mant)
++exp;
mant >>= (NumericUtils<float>::mant - NumericUtils<T>::mant);
mant &= ((1 << NumericUtils<T>::mant) - 1);
auto biased_exp = static_cast<uint32_t>(exp);
if(!subnorm)
biased_exp = static_cast<uint32_t>(exp + NumericUtils<T>::bias);
biased_exp &= ((1 << NumericUtils<T>::exp) - 1);
auto val = sign | biased_exp << NumericUtils<T>::mant | mant;
return val;
}
} // namespace ck::utils
...@@ -4,7 +4,9 @@ ...@@ -4,7 +4,9 @@
#pragma once #pragma once
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/utility/e8m0_utils.hpp"
#include "ck/utility/f8_utils.hpp" #include "ck/utility/f8_utils.hpp"
#include "ck/utility/mxf4_utils.hpp"
#include "ck/utility/random_gen.hpp" #include "ck/utility/random_gen.hpp"
#include "ck/utility/array.hpp" #include "ck/utility/array.hpp"
...@@ -583,6 +585,1053 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x) ...@@ -583,6 +585,1053 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
#endif #endif
} }
// convert fp32 to fp4 with rounding to nearest even
inline __host__ __device__ f4_t f4_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
union
{
uint32_t bitwise;
f4_t f4_array[4];
} value{0};
value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x, x, scale, 0);
return value.f4_array[0];
#else
return utils::sat_convert_to_type<f4_t>(x / scale);
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4 with rne
inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{0};
value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[0], x[1], scale, 0);
return value.f4x2_array[0];
#else
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{0};
uint8_t l = utils::sat_convert_to_type<f4_t>(x[1] / scale);
uint8_t h = utils::sat_convert_to_type<f4_t>(x[0] / scale);
value.bitwise = (h << 4) | l;
return value.f4x2_array[0];
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4 with rne
inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{}, tmp_values{};
// TODO: pack in a loop
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[0], x[1], scale, 0);
f4_values.f4x2_array[0] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[2], x[3], scale, 0);
f4_values.f4x2_array[1] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[4], x[5], scale, 0);
f4_values.f4x2_array[2] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[6], x[7], scale, 0);
f4_values.f4x2_array[3] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[8], x[9], scale, 0);
f4_values.f4x2_array[4] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[10], x[11], scale, 0);
f4_values.f4x2_array[5] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[12], x[13], scale, 0);
f4_values.f4x2_array[6] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[14], x[15], scale, 0);
f4_values.f4x2_array[7] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[16], x[17], scale, 0);
f4_values.f4x2_array[8] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[18], x[19], scale, 0);
f4_values.f4x2_array[9] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[20], x[21], scale, 0);
f4_values.f4x2_array[10] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[22], x[23], scale, 0);
f4_values.f4x2_array[11] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[24], x[25], scale, 0);
f4_values.f4x2_array[12] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[26], x[27], scale, 0);
f4_values.f4x2_array[13] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[28], x[29], scale, 0);
f4_values.f4x2_array[14] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[30], x[31], scale, 0);
f4_values.f4x2_array[15] = tmp_values.f4x2_array[0];
return f4_values.f4x32_array;
#else
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{};
// TODO: pack in a loop
auto tmp = utils::sat_convert_to_type<f4_t>(x[0] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[1] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[2] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[3] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[4] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[5] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[6] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[7] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[8] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[9] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[10] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[11] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[12] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[13] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[14] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[15] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[16] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[17] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[18] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[19] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[20] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[21] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[22] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[23] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[24] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[25] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[26] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[27] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[28] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[29] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[30] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[31] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
return f4_values.f4x32_array;
#endif
}
// convert fp32 to fp4 with stochastic rounding
inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
union
{
uint32_t bitwise;
f4_t f4_array[4];
} value{0};
union
{
float float_array[2];
float2_t float2_array;
} float_values{{x}};
value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.float2_array, rng, scale, 0);
return value.f4_array[0];
#else
return utils::sat_convert_to_type_sr<f4_t>(x / scale, rng);
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4 with sr
inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#if defined(__gfx950__)
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{0};
value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(value.bitwise, x, rng, scale, 0);
return value.f4x2_array[0];
#else
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{0};
uint8_t l = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
uint8_t h = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
value.bitwise = (h << 4) | l;
return value.f4x2_array[0];
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4 with sr
inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#if defined(__gfx950__)
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{0}, tmp_values{0};
union
{
float2_t floatx2_array[16];
float32_t floatx32_array;
} float_values{{0}};
// TODO: pack in a loop
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[0], rng, scale, 0);
f4_values.f4x2_array[0] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[1], rng, scale, 0);
f4_values.f4x2_array[1] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[2], rng, scale, 0);
f4_values.f4x2_array[2] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[3], rng, scale, 0);
f4_values.f4x2_array[3] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[4], rng, scale, 0);
f4_values.f4x2_array[4] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[5], rng, scale, 0);
f4_values.f4x2_array[5] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[6], rng, scale, 0);
f4_values.f4x2_array[6] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[7], rng, scale, 0);
f4_values.f4x2_array[7] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[8], rng, scale, 0);
f4_values.f4x2_array[8] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[9], rng, scale, 0);
f4_values.f4x2_array[9] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[10], rng, scale, 0);
f4_values.f4x2_array[10] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[11], rng, scale, 0);
f4_values.f4x2_array[11] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[12], rng, scale, 0);
f4_values.f4x2_array[12] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[13], rng, scale, 0);
f4_values.f4x2_array[13] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[14], rng, scale, 0);
f4_values.f4x2_array[14] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[15], rng, scale, 0);
f4_values.f4x2_array[15] = tmp_values.f4x2_array[0];
return f4_values.f4x32_array;
#else
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{0};
// TODO: pack in a loop
auto tmp = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[2] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[3] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[4] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[5] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[6] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[7] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[8] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[9] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[10] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[11] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[12] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[13] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[14] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[15] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[16] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[17] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[18] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[19] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[20] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[21] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[22] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[23] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[24] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[25] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[26] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[27] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[28] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[29] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[30] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[31] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
return f4_values.f4x32_array;
#endif
}
// convert fp32 to fp4
template <>
inline __host__ __device__ f4_t type_convert<f4_t, float>(float x)
{
#if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x);
#else
return f4_convert_rne(x);
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4
template <>
inline __host__ __device__ f4x2_t type_convert<f4x2_t, float2_t>(float2_t x)
{
#if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x);
#else
return f4_convert_rne(x);
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4
template <>
inline __host__ __device__ f4x32_t type_convert<f4x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x);
#else
return f4_convert_rne(x);
#endif
}
// convert fp4 to fp32
template <>
inline __host__ __device__ float type_convert<float, f4_t>(f4_t x)
{
#if defined(__gfx950__)
union
{
float float_array[2];
float2_t float2_array;
} float_values{};
float scale = 1.0f;
float_values.float2_array = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, scale, 0);
return float_values.float_array[0];
#else
return utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
// convert vector of 2 fp4 to vector of 2 fp32
template <>
inline __host__ __device__ float2_t type_convert<float2_t, f4x2_t>(f4x2_t x)
{
#if defined(__gfx950__)
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{};
value.f4x2_array[0] = x;
float scale = 1.0f;
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0);
#else
float2_t ret{utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>()),
utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>())};
return ret;
#endif
}
// convert vector of 32 fp4 to vector of 32 fp32
template <>
inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
{
#if defined(__gfx950__)
union
{
f4x32_t f4x32_array;
f4x2_t fp4x2[16];
} value{x};
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} bitwise_value{};
float2_t op;
float32_t ret;
float scale = 1.0f;
// TODO: pack in a loop
bitwise_value.f4x2_array[0] = value.fp4x2[0];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[0] = op[0];
ret[1] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[2] = op[0];
ret[3] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[2];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[4] = op[0];
ret[5] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[3];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[6] = op[0];
ret[7] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[4];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[8] = op[0];
ret[9] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[5];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[10] = op[0];
ret[11] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[6];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[12] = op[0];
ret[13] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[7];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[14] = op[0];
ret[15] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[8];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[16] = op[0];
ret[17] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[9];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[18] = op[0];
ret[19] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[10];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[20] = op[0];
ret[21] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[11];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[22] = op[0];
ret[23] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[12];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[24] = op[0];
ret[25] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[13];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[26] = op[0];
ret[27] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[14];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[28] = op[0];
ret[29] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[15];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[30] = op[0];
ret[31] = op[1];
return ret;
#else
union
{
float32_t float32_array;
float float_array[32];
} float_values{};
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{bit_cast<__uint128_t>(x)};
// TODO: pack in a loop
float_values.float_array[0] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[1] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[2] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[3] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[4] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[5] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[6] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[7] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[0] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[1] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[2] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[3] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[4] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[5] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[6] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[7] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[0] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[1] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[2] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[3] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[4] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[5] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[6] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[7] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[0] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[1] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[2] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[3] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[4] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[5] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[6] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[7] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
return float_values.float32_array;
#endif
}
template <>
inline __host__ __device__ float type_convert<float, e8m0_bexp_t>(e8m0_bexp_t scale)
{
return utils::cast_to_float(scale);
}
template <>
inline __host__ __device__ e8m0_bexp_t type_convert<e8m0_bexp_t, float>(float scale)
{
return utils::cast_from_float(scale);
}
// Declare a template function for scaled conversion
template <typename Y, typename X>
__host__ __device__ constexpr Y scaled_type_convert(e8m0_bexp_t scale, X x);
// convert fp4 to fp32
template <>
inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_bexp_t scale, f4_t x)
{
#if defined(__gfx950__)
union
{
float float_array[2];
float2_t float2_array;
} float_values{};
float_values.float2_array =
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, type_convert<float>(scale), 0);
return float_values.float_array[0];
#else
return utils::to_float<f4_t>(scale, x);
#endif
}
// convert vector of 2 fp4 to vector of 2 fp32
template <>
inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_bexp_t scale,
f4x2_t x)
{
#if defined(__gfx950__)
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{};
value.f4x2_array[0] = x;
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert<float>(scale), 0);
#else
float2_t ret{
utils::to_float<f4_t>(scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>()),
utils::to_float<f4_t>(scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>())};
return ret;
#endif
}
// convert vector of 32 fp4 to vector of 32 fp32
template <>
inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m0_bexp_t scale,
f4x32_t x)
{
#if defined(__gfx950__)
union
{
f4x32_t f4x32_array;
f4x2_t fp4x2[16];
} value{x};
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} bitwise_value{};
float2_t op;
float32_t ret;
// TODO: pack in a loop
bitwise_value.f4x2_array[0] = value.fp4x2[0];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[0] = op[0];
ret[1] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[2] = op[0];
ret[3] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[2];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[4] = op[0];
ret[5] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[3];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[6] = op[0];
ret[7] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[4];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[8] = op[0];
ret[9] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[5];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[10] = op[0];
ret[11] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[6];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[12] = op[0];
ret[13] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[7];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[14] = op[0];
ret[15] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[8];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[16] = op[0];
ret[17] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[9];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[18] = op[0];
ret[19] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[10];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[20] = op[0];
ret[21] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[11];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[22] = op[0];
ret[23] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[12];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[24] = op[0];
ret[25] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[13];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[26] = op[0];
ret[27] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[14];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[28] = op[0];
ret[29] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[15];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[30] = op[0];
ret[31] = op[1];
return ret;
#else
union
{
float32_t float32_array;
float float_array[32];
} float_values{};
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{bit_cast<__uint128_t>(x)};
// TODO: pack in a loop
float_values.float_array[0] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[1] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[2] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[3] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[4] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[5] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[6] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[7] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[0] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[1] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[2] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[3] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[4] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[5] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[6] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[7] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[0] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[1] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[2] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[3] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[4] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[5] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[6] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[7] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[0] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[1] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[2] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[3] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[4] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[5] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
float_values.float_array[6] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
float_values.float_array[7] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
return float_values.float32_array;
#endif
}
// convert fp32 to fp4
template <>
inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_bexp_t scale, float x)
{
#if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x, type_convert<float>(scale));
#else
return f4_convert_rne(x, type_convert<float>(scale));
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4
template <>
inline __host__ __device__ f4x2_t scaled_type_convert<f4x2_t, float2_t>(e8m0_bexp_t scale,
float2_t x)
{
#if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x, type_convert<float>(scale));
#else
return f4_convert_rne(x, type_convert<float>(scale));
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4
template <>
inline __host__ __device__ f4x32_t scaled_type_convert<f4x32_t, float32_t>(e8m0_bexp_t scale,
float32_t x)
{
#if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x, type_convert<float>(scale));
#else
return f4_convert_rne(x, type_convert<float>(scale));
#endif
}
template <typename Y, typename X, std::size_t NumElems> template <typename Y, typename X, std::size_t NumElems>
inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y, inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
const std::array<X, NumElems>& x) const std::array<X, NumElems>& x)
......
...@@ -26,6 +26,7 @@ namespace utils { ...@@ -26,6 +26,7 @@ namespace utils {
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType> template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
double get_relative_threshold(const int number_of_accumulations = 1) double get_relative_threshold(const int number_of_accumulations = 1)
{ {
using F4 = ck::f4_t;
using F8 = ck::f8_t; using F8 = ck::f8_t;
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
...@@ -33,10 +34,10 @@ double get_relative_threshold(const int number_of_accumulations = 1) ...@@ -33,10 +34,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
using I8 = int8_t; using I8 = int8_t;
using I32 = int32_t; using I32 = int32_t;
static_assert(is_same_v<ComputeDataType, F8> || is_same_v<ComputeDataType, F16> || static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
is_same_v<ComputeDataType, BF16> || is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, I8> ||
is_same_v<ComputeDataType, int>, is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, int>,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!"); "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
double compute_error = 0; double compute_error = 0;
if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> || if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
...@@ -49,10 +50,10 @@ double get_relative_threshold(const int number_of_accumulations = 1) ...@@ -49,10 +50,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
compute_error = std::pow(2, -NumericUtils<ComputeDataType>::mant) * 0.5; compute_error = std::pow(2, -NumericUtils<ComputeDataType>::mant) * 0.5;
} }
static_assert(is_same_v<OutDataType, F8> || is_same_v<OutDataType, F16> || static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
is_same_v<OutDataType, BF16> || is_same_v<OutDataType, F32> || is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> || is_same_v<OutDataType, F32> || is_same_v<OutDataType, I8> ||
is_same_v<OutDataType, int>, is_same_v<OutDataType, I32> || is_same_v<OutDataType, int>,
"Warning: Unhandled OutDataType for setting up the relative threshold!"); "Warning: Unhandled OutDataType for setting up the relative threshold!");
double output_error = 0; double output_error = 0;
if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> || if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
...@@ -66,10 +67,10 @@ double get_relative_threshold(const int number_of_accumulations = 1) ...@@ -66,10 +67,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
} }
double midway_error = std::max(compute_error, output_error); double midway_error = std::max(compute_error, output_error);
static_assert(is_same_v<AccDataType, F8> || is_same_v<AccDataType, F16> || static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
is_same_v<AccDataType, BF16> || is_same_v<AccDataType, F32> || is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> || is_same_v<AccDataType, F32> || is_same_v<AccDataType, I8> ||
is_same_v<AccDataType, int>, is_same_v<AccDataType, I32> || is_same_v<AccDataType, int>,
"Warning: Unhandled AccDataType for setting up the relative threshold!"); "Warning: Unhandled AccDataType for setting up the relative threshold!");
double acc_error = 0; double acc_error = 0;
if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> || if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
...@@ -87,6 +88,7 @@ double get_relative_threshold(const int number_of_accumulations = 1) ...@@ -87,6 +88,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType> template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1) double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
{ {
using F4 = ck::f4_t;
using F8 = ck::f8_t; using F8 = ck::f8_t;
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
...@@ -94,10 +96,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of ...@@ -94,10 +96,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
using I8 = int8_t; using I8 = int8_t;
using I32 = int32_t; using I32 = int32_t;
static_assert(is_same_v<ComputeDataType, F8> || is_same_v<ComputeDataType, F16> || static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
is_same_v<ComputeDataType, BF16> || is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, I8> ||
is_same_v<ComputeDataType, int>, is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, int>,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
auto expo = std::log2(std::abs(max_possible_num)); auto expo = std::log2(std::abs(max_possible_num));
double compute_error = 0; double compute_error = 0;
...@@ -111,10 +113,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of ...@@ -111,10 +113,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
compute_error = std::pow(2, expo - NumericUtils<ComputeDataType>::mant) * 0.5; compute_error = std::pow(2, expo - NumericUtils<ComputeDataType>::mant) * 0.5;
} }
static_assert(is_same_v<OutDataType, F8> || is_same_v<OutDataType, F16> || static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
is_same_v<OutDataType, BF16> || is_same_v<OutDataType, F32> || is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> || is_same_v<OutDataType, F32> || is_same_v<OutDataType, I8> ||
is_same_v<OutDataType, int>, is_same_v<OutDataType, I32> || is_same_v<OutDataType, int>,
"Warning: Unhandled OutDataType for setting up the absolute threshold!"); "Warning: Unhandled OutDataType for setting up the absolute threshold!");
double output_error = 0; double output_error = 0;
if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> || if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
...@@ -128,10 +130,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of ...@@ -128,10 +130,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
} }
double midway_error = std::max(compute_error, output_error); double midway_error = std::max(compute_error, output_error);
static_assert(is_same_v<AccDataType, F8> || is_same_v<AccDataType, F16> || static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
is_same_v<AccDataType, BF16> || is_same_v<AccDataType, F32> || is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> || is_same_v<AccDataType, F32> || is_same_v<AccDataType, I8> ||
is_same_v<AccDataType, int>, is_same_v<AccDataType, I32> || is_same_v<AccDataType, int>,
"Warning: Unhandled AccDataType for setting up the absolute threshold!"); "Warning: Unhandled AccDataType for setting up the absolute threshold!");
double acc_error = 0; double acc_error = 0;
if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> || if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
...@@ -450,5 +452,54 @@ check_err(const Range& out, ...@@ -450,5 +452,54 @@ check_err(const Range& out,
return res; return res;
} }
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, f4_t>),
bool>
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 0.5,
double atol = 0.5)
{
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
return false;
}
bool res{true};
int err_count = 0;
double err = 0;
double max_err = std::numeric_limits<float>::min();
for(std::size_t i = 0; i < ref.size(); ++i)
{
const double o = type_convert<float>(*std::next(std::begin(out), i));
const double r = type_convert<float>(*std::next(std::begin(ref), i));
err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
}
res = false;
}
}
if(!res)
{
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err
<< " number of errors: " << err_count << std::endl;
}
return res;
}
} // namespace utils } // namespace utils
} // namespace ck } // namespace ck
...@@ -69,6 +69,18 @@ struct GeneratorTensor_1<ck::f8_t> ...@@ -69,6 +69,18 @@ struct GeneratorTensor_1<ck::f8_t>
}; };
#endif #endif
template <>
struct GeneratorTensor_1<ck::f4_t>
{
float value = 1.0;
template <typename... Is>
ck::f4_t operator()(Is...)
{
return ck::type_convert<ck::f4_t>(value);
}
};
template <> template <>
struct GeneratorTensor_1<int8_t> struct GeneratorTensor_1<int8_t>
{ {
...@@ -153,6 +165,20 @@ struct GeneratorTensor_2<ck::bf8_t> ...@@ -153,6 +165,20 @@ struct GeneratorTensor_2<ck::bf8_t>
}; };
#endif #endif
template <>
struct GeneratorTensor_2<ck::f4_t>
{
int min_value = 0;
int max_value = 1;
template <typename... Is>
ck::f4_t operator()(Is...)
{
float tmp = (std::rand() % (max_value - min_value)) + min_value;
return ck::type_convert<ck::f4_t>(tmp);
}
};
template <typename T> template <typename T>
struct GeneratorTensor_3 struct GeneratorTensor_3
{ {
...@@ -223,6 +249,23 @@ struct GeneratorTensor_3<ck::bf8_t> ...@@ -223,6 +249,23 @@ struct GeneratorTensor_3<ck::bf8_t>
}; };
#endif #endif
template <>
struct GeneratorTensor_3<ck::f4_t>
{
float min_value = 0;
float max_value = 1;
template <typename... Is>
ck::f4_t operator()(Is...)
{
float tmp = float(std::rand()) / float(RAND_MAX);
float fp32_tmp = min_value + tmp * (max_value - min_value);
return ck::type_convert<ck::f4_t>(fp32_tmp);
}
};
template <typename T> template <typename T>
struct GeneratorTensor_4 struct GeneratorTensor_4
{ {
......
...@@ -42,6 +42,10 @@ if (CK_USE_FNUZ_FP8) ...@@ -42,6 +42,10 @@ if (CK_USE_FNUZ_FP8)
add_dependencies(test_fp8 test_fp8_fnuz) add_dependencies(test_fp8 test_fp8_fnuz)
add_dependencies(test_fp8 test_bf8_fnuz) add_dependencies(test_fp8 test_bf8_fnuz)
endif() endif()
add_gtest_executable(test_fp4 test_fp4.cpp)
if(result EQUAL 0)
target_link_libraries(test_fp4 PRIVATE utility)
endif()
add_gtest_executable(test_custom_type test_custom_type.cpp) add_gtest_executable(test_custom_type test_custom_type.cpp)
if(result EQUAL 0) if(result EQUAL 0)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using ck::e8m0_bexp_t;
using ck::f4_convert_rne;
using ck::f4_convert_sr;
using ck::f4_t;
using ck::f4x2_pk_t;
using ck::Number;
using ck::scaled_type_convert;
using ck::type_convert;
using ck::vector_type;
using ck::utils::cast_from_float;
using ck::utils::cast_to_float;
TEST(FP4, NumericLimits)
{
// constants given for negative zero nan mode
EXPECT_EQ(ck::NumericLimits<f4_t>::Min(), f4_t{0x2});
EXPECT_EQ(ck::NumericLimits<f4_t>::Max(), f4_t{0x7});
EXPECT_EQ(ck::NumericLimits<f4_t>::Lowest(), f4_t{0xF});
EXPECT_EQ(ck::NumericLimits<f4_t>::MinSubnorm(), f4_t{0x1});
EXPECT_EQ(ck::NumericLimits<f4_t>::MaxSubnorm(), f4_t{0x1});
}
TEST(FP4, ConvertFP32Nearest)
{
// fix the tolerance value
float abs_tol = 1e-6;
// set maximum fp4 value
float max_fp4 = 6.0f;
// convert 0 float to fp4 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f4_convert_rne(0.0f)), abs_tol);
// convert maximal f4_t to float and check if equal to 6.0
ASSERT_NEAR(max_fp4, type_convert<float>(f4_convert_rne(max_fp4)), abs_tol);
// convert maximal float to fp4 and back, check if clipped to 6.0
ASSERT_NEAR(
max_fp4, type_convert<float>(f4_convert_rne(std::numeric_limits<float>::max())), abs_tol);
// positive norm float value to fp4 and back, check if holds
float pos_float = 1.0f;
ASSERT_NEAR(pos_float, type_convert<float>(f4_convert_rne(pos_float)), abs_tol);
// negative norm float value to fp4 and back, check if holds
float neg_float = -1.5f;
ASSERT_NEAR(neg_float, type_convert<float>(f4_convert_rne(neg_float)), abs_tol);
// positive subnorm float value to fp4 and back, check if holds
pos_float = 0.5f;
ASSERT_NEAR(pos_float, type_convert<float>(f4_convert_rne(pos_float)), abs_tol);
// negative subnorm float value to fp4 and back, check if holds
neg_float = -0.5f;
ASSERT_NEAR(neg_float, type_convert<float>(f4_convert_rne(neg_float)), abs_tol);
}
TEST(FP4, ConvertFP32Stochastic)
{
// fix the tolerance value
float abs_tol = 1e-6;
// set maximum fp4 value
float max_fp4 = 6.0f;
// convert 0 float to fp4 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f4_convert_sr(0.0f)), abs_tol);
// convert maximal f4_t to float and check if equal to 6.0
ASSERT_NEAR(max_fp4, type_convert<float>(f4_convert_sr(max_fp4)), abs_tol);
// convert maximal float to fp4 and back, check if clipped to 6.0
ASSERT_NEAR(
max_fp4, type_convert<float>(f4_convert_sr(std::numeric_limits<float>::max())), abs_tol);
// positive norm float value to fp4 and back, check if holds
float pos_float = 1.0f;
ASSERT_NEAR(pos_float, type_convert<float>(f4_convert_sr(pos_float)), abs_tol);
// negative norm float value to fp4 and back, check if holds
float neg_float = -1.5f;
ASSERT_NEAR(neg_float, type_convert<float>(f4_convert_sr(neg_float)), abs_tol);
// positive subnorm float value to fp4 and back, check if holds
pos_float = 0.5f;
ASSERT_NEAR(pos_float, type_convert<float>(f4_convert_sr(pos_float)), abs_tol);
// negative subnorm float value to fp4 and back, check if holds
neg_float = -0.5f;
ASSERT_NEAR(neg_float, type_convert<float>(f4_convert_sr(neg_float)), abs_tol);
}
TEST(FP4, ScaledConvertFP32Nearest)
{
// fix the tolerance value
float abs_tol = 1e-6;
// set maximum fp4 value
float max_fp4 = 6.0f;
// set maximum scale
float max_scale = std::pow(2,
ck::NumericLimits<e8m0_bexp_t>::Max().data -
ck::NumericUtils<e8m0_bexp_t>::bias); // 0xFE -> float
// set minimum scale
float min_scale = std::pow(2, -ck::NumericUtils<e8m0_bexp_t>::bias); // 0x00 -> float
// set arbitrary scale to 256.0
float test_scale = 256.0f; // 0b10000111
// convert 0 float to fp4 and back with maximal scale, check if holds
ASSERT_NEAR(0.0f,
scaled_type_convert<float>(cast_from_float(max_scale), f4_convert_rne(0.0f)),
abs_tol);
// convert 0 float to fp4 and back with minimal scale, check if holds
ASSERT_NEAR(0.0f,
scaled_type_convert<float>(cast_from_float(min_scale), f4_convert_rne(0.0f)),
abs_tol);
// convert maximal f4_t with minimal scale to float and check if equal to minimal float
ASSERT_NEAR(ck::NumericLimits<float>::Min(),
scaled_type_convert<float>(cast_from_float(min_scale), f4_convert_rne(max_fp4)),
abs_tol);
// positive norm float value to fp4 and back with various scales, check if holds
float pos_float = 1.0f;
ASSERT_NEAR(pos_float * test_scale,
scaled_type_convert<float>(cast_from_float(test_scale), f4_convert_rne(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * max_scale,
scaled_type_convert<float>(cast_from_float(max_scale), f4_convert_rne(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * min_scale,
scaled_type_convert<float>(cast_from_float(min_scale), f4_convert_rne(pos_float)),
abs_tol);
// negative norm float value to fp4 and back with various scales, check if holds
float neg_float = -1.5f;
ASSERT_NEAR(neg_float * test_scale,
scaled_type_convert<float>(cast_from_float(test_scale), f4_convert_rne(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * max_scale,
scaled_type_convert<float>(cast_from_float(max_scale), f4_convert_rne(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * min_scale,
scaled_type_convert<float>(cast_from_float(min_scale), f4_convert_rne(neg_float)),
abs_tol);
// positive subnorm float value to fp4 and back with various scales, check if holds
pos_float = 0.5f;
ASSERT_NEAR(pos_float * test_scale,
scaled_type_convert<float>(cast_from_float(test_scale), f4_convert_rne(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * max_scale,
scaled_type_convert<float>(cast_from_float(max_scale), f4_convert_rne(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * min_scale,
scaled_type_convert<float>(cast_from_float(min_scale), f4_convert_rne(pos_float)),
abs_tol);
// negative subnorm float value to fp4 and back with various scales, check if holds
neg_float = -0.5f;
ASSERT_NEAR(neg_float * test_scale,
scaled_type_convert<float>(cast_from_float(test_scale), f4_convert_rne(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * max_scale,
scaled_type_convert<float>(cast_from_float(max_scale), f4_convert_rne(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * min_scale,
scaled_type_convert<float>(cast_from_float(min_scale), f4_convert_rne(neg_float)),
abs_tol);
}
TEST(FP4, ScaledConvertFP32Stochastic)
{
// fix the tolerance value
float abs_tol = 1e-6;
// set maximum fp4 value
float max_fp4 = 6.0f;
// set maximum scale
float max_scale = std::pow(2,
ck::NumericLimits<e8m0_bexp_t>::Max().data -
ck::NumericUtils<e8m0_bexp_t>::bias); // 0xFE -> float
// set minimum scale
float min_scale = std::pow(2, -ck::NumericUtils<e8m0_bexp_t>::bias); // 0x00 -> float
// set arbitrary scale to 256.0
float test_scale = 256.0f; // 0b10000111
// convert 0 float to fp4 and back with maximal scale, check if holds
ASSERT_NEAR(
0.0f, scaled_type_convert<float>(cast_from_float(max_scale), f4_convert_sr(0.0f)), abs_tol);
// convert 0 float to fp4 and back with minimal scale, check if holds
ASSERT_NEAR(
0.0f, scaled_type_convert<float>(cast_from_float(min_scale), f4_convert_sr(0.0f)), abs_tol);
// convert maximal f4_t with minimal scale to float and check if equal to minimal float
ASSERT_NEAR(ck::NumericLimits<float>::Min(),
scaled_type_convert<float>(cast_from_float(min_scale), f4_convert_sr(max_fp4)),
abs_tol);
// positive norm float value to fp4 and back with various scales, check if holds
float pos_float = 1.0f;
ASSERT_NEAR(pos_float * test_scale,
scaled_type_convert<float>(cast_from_float(test_scale), f4_convert_sr(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * max_scale,
scaled_type_convert<float>(cast_from_float(max_scale), f4_convert_sr(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * min_scale,
scaled_type_convert<float>(cast_from_float(min_scale), f4_convert_sr(pos_float)),
abs_tol);
// negative norm float value to fp4 and back with various scales, check if holds
float neg_float = -1.5f;
ASSERT_NEAR(neg_float * test_scale,
scaled_type_convert<float>(cast_from_float(test_scale), f4_convert_sr(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * max_scale,
scaled_type_convert<float>(cast_from_float(max_scale), f4_convert_sr(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * min_scale,
scaled_type_convert<float>(cast_from_float(min_scale), f4_convert_sr(neg_float)),
abs_tol);
// positive subnorm float value to fp4 and back with various scales, check if holds
pos_float = 0.5f;
ASSERT_NEAR(pos_float * test_scale,
scaled_type_convert<float>(cast_from_float(test_scale), f4_convert_sr(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * max_scale,
scaled_type_convert<float>(cast_from_float(max_scale), f4_convert_sr(pos_float)),
abs_tol);
ASSERT_NEAR(pos_float * min_scale,
scaled_type_convert<float>(cast_from_float(min_scale), f4_convert_sr(pos_float)),
abs_tol);
// negative subnorm float value to fp4 and back with various scales, check if holds
neg_float = -0.5f;
ASSERT_NEAR(neg_float * test_scale,
scaled_type_convert<float>(cast_from_float(test_scale), f4_convert_sr(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * max_scale,
scaled_type_convert<float>(cast_from_float(max_scale), f4_convert_sr(neg_float)),
abs_tol);
ASSERT_NEAR(neg_float * min_scale,
scaled_type_convert<float>(cast_from_float(min_scale), f4_convert_sr(neg_float)),
abs_tol);
}
TEST(FP4, TestSize)
{
ASSERT_EQ(1, sizeof(f4x2_pk_t));
ASSERT_EQ(1, sizeof(vector_type<f4x2_pk_t, 1>));
ASSERT_EQ(2, sizeof(vector_type<f4x2_pk_t, 2>));
ASSERT_EQ(4, sizeof(vector_type<f4x2_pk_t, 4>));
ASSERT_EQ(8, sizeof(vector_type<f4x2_pk_t, 8>));
ASSERT_EQ(16, sizeof(vector_type<f4x2_pk_t, 16>));
ASSERT_EQ(32, sizeof(vector_type<f4x2_pk_t, 32>));
}
TEST(FP4, TestAlignment)
{
ASSERT_EQ(1, alignof(f4x2_pk_t));
ASSERT_EQ(1, alignof(vector_type<f4x2_pk_t, 1>));
ASSERT_EQ(2, alignof(vector_type<f4x2_pk_t, 2>));
ASSERT_EQ(4, alignof(vector_type<f4x2_pk_t, 4>));
ASSERT_EQ(8, alignof(vector_type<f4x2_pk_t, 8>));
ASSERT_EQ(16, alignof(vector_type<f4x2_pk_t, 16>));
ASSERT_EQ(32, alignof(vector_type<f4x2_pk_t, 32>));
}
// test vector of 1 f4x2_pk_t, contains 2 f4_t
TEST(FP4, TestAsType1)
{
// test size
const int size = 1;
std::vector<f4x2_pk_t::type> test_vec = {f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001}};
// reference vector
vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), 0);
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), 0);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}) =
f4x2_pk_t{}.pack(test_vec.at(i), test_vec.at(i + 1));
});
// copy the vector
vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(),
test_vec.at(i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(),
test_vec.at(i + 1));
});
}
// test vector of 2 f4x2_pk_t, contains 4 f4_t
TEST(FP4, TestAsType2)
{
// test size
const int size = 2;
std::vector<f4x2_pk_t::type> test_vec = {f4x2_pk_t::type{0b0010},
f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0111}};
// reference vector
vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), 0);
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), 0);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}) =
f4x2_pk_t{}.pack(test_vec.at(i), test_vec.at(i + 1));
});
// copy the vector
vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(),
test_vec.at(i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(),
test_vec.at(i + 1));
});
}
// test vector of 4 f4x2_pk_t, contains 8 f4_t
TEST(FP4, TestAsType4)
{
// test size
const int size = 4;
std::vector<f4x2_pk_t::type> test_vec = {f4x2_pk_t::type{0b0010},
f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0111},
f4x2_pk_t::type{0b1010},
f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b1111}};
// reference vector
vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), 0);
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), 0);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}) =
f4x2_pk_t{}.pack(test_vec.at(i), test_vec.at(i + 1));
});
// copy the vector
vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(),
test_vec.at(i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(),
test_vec.at(i + 1));
});
}
// test vector of 8 f4x2_pk_t, contains 16 f4_t
TEST(FP4, TestAsType8)
{
// test size
const int size = 8;
std::vector<f4x2_pk_t::type> test_vec = {f4x2_pk_t::type{0b0010},
f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0111},
f4x2_pk_t::type{0b1010},
f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b1111},
f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0111},
f4x2_pk_t::type{0b1010},
f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0010},
f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b1111}};
// reference vector
vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), 0);
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), 0);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}) =
f4x2_pk_t{}.pack(test_vec.at(i), test_vec.at(i + 1));
});
// copy the vector
vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(),
test_vec.at(i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(),
test_vec.at(i + 1));
});
}
// test vector of 16 f4x2_pk_t, contains 32 f4_t
TEST(FP4, TestAsType16)
{
// test size
const int size = 16;
std::vector<f4x2_pk_t::type> test_vec = {
f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010}, f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1111}, f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010}, f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b1111}, f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010},
f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1111},
f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010},
f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1111}};
// reference vector
vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), 0);
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), 0);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}) =
f4x2_pk_t{}.pack(test_vec.at(i), test_vec.at(i + 1));
});
// copy the vector
vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(),
test_vec.at(i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(),
test_vec.at(i + 1));
});
}
// test vector of 32 f4x2_pk_t, contains 64 f4_t
TEST(FP4, TestAsType32)
{
// test size
const int size = 32;
std::vector<f4x2_pk_t::type> test_vec = {
f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010}, f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1111}, f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010}, f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b1111}, f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010},
f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1111},
f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010},
f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1111}, f4x2_pk_t::type{0b0010},
f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0111},
f4x2_pk_t::type{0b1010}, f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b1111}, f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0111},
f4x2_pk_t::type{0b1010}, f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0010},
f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1111},
f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010}, f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1111}, f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010}, f4x2_pk_t::type{0b0001},
f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1001},
f4x2_pk_t::type{0b1111}};
// reference vector
vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), 0);
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), 0);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}) =
f4x2_pk_t{}.pack(test_vec.at(i), test_vec.at(i + 1));
});
// copy the vector
vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(),
test_vec.at(i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(),
test_vec.at(i + 1));
});
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment