Unverified Commit cde0f480 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer Committed by GitHub
Browse files

Merge pull request #200 from ROCm/lwpck-2390

Enable MXFP4 type
parents b7566434 773c0e70
...@@ -158,6 +158,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) ...@@ -158,6 +158,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// set rounding to nearest even as default for f8 conversions // set rounding to nearest even as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION 0 #define CK_USE_SR_F8_CONVERSION 0
// set rounding to nearest even as default for f4 conversions
#define CK_USE_SR_F4_CONVERSION 0
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0) // block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 #define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
......
...@@ -11,6 +11,40 @@ namespace ck { ...@@ -11,6 +11,40 @@ namespace ck {
using bhalf_t = ushort; using bhalf_t = ushort;
using half_t = _Float16; using half_t = _Float16;
using int4_t = _BitInt(4); using int4_t = _BitInt(4);
using f4_t = unsigned _BitInt(4);
struct e8m0_bexp_t
{
// E8M0 scale is biased
using type = uint8_t;
type data;
constexpr e8m0_bexp_t() : data{type{}} {}
constexpr e8m0_bexp_t(type init) : data{init} {}
bool operator==(const e8m0_bexp_t& other) const { return (data == other.data); }
};
struct f4x2_pk_t
{
using type = uint8_t;
type data;
f4x2_pk_t() : data{type{}} {}
f4x2_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline type unpack() const
{
if constexpr(I == 0)
return data & 0b00001111;
else
return (data >> 4);
}
__host__ __device__ inline type pack(const type x0, const type x1)
{
return (x1 << 4) | (x0 & 0b00001111);
}
};
inline constexpr auto next_pow2(uint32_t x) inline constexpr auto next_pow2(uint32_t x)
{ {
...@@ -26,7 +60,7 @@ inline constexpr bool is_native_type() ...@@ -26,7 +60,7 @@ inline constexpr bool is_native_type()
return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value || return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value ||
is_same<T, bhalf_t>::value || is_same<T, int32_t>::value || is_same<T, int8_t>::value || is_same<T, bhalf_t>::value || is_same<T, int32_t>::value || is_same<T, int8_t>::value ||
is_same<T, uint8_t>::value || is_same<T, f8_fnuz_t>::value || is_same<T, uint8_t>::value || is_same<T, f8_fnuz_t>::value ||
is_same<T, bf8_fnuz_t>::value || is_same<T, bool>::value; is_same<T, bf8_fnuz_t>::value || is_same<T, bool>::value || is_same<T, f4_t>::value;
} }
// vector_type // vector_type
...@@ -1871,6 +1905,14 @@ using uint8x16_t = typename vector_type<uint8_t, 16>::type; ...@@ -1871,6 +1905,14 @@ using uint8x16_t = typename vector_type<uint8_t, 16>::type;
using uint8x32_t = typename vector_type<uint8_t, 32>::type; using uint8x32_t = typename vector_type<uint8_t, 32>::type;
using uint8x64_t = typename vector_type<uint8_t, 64>::type; using uint8x64_t = typename vector_type<uint8_t, 64>::type;
// f4
using f4x2_t = typename vector_type<f4x2_pk_t, 1>::type;
using f4x4_t = typename vector_type<f4x2_pk_t, 2>::type;
using f4x8_t = typename vector_type<f4x2_pk_t, 4>::type;
using f4x16_t = typename vector_type<f4x2_pk_t, 8>::type;
using f4x32_t = typename vector_type<f4x2_pk_t, 16>::type;
using f4x64_t = typename vector_type<f4x2_pk_t, 32>::type;
template <typename T> template <typename T>
struct NumericLimits struct NumericLimits
{ {
...@@ -2009,6 +2051,59 @@ struct NumericLimits<bf8_ocp_t> ...@@ -2009,6 +2051,59 @@ struct NumericLimits<bf8_ocp_t>
} }
}; };
template <>
struct NumericLimits<f4_t>
{
static constexpr uint8_t binary_min_normal = 0x2; // 0b0010
static constexpr uint8_t binary_max_normal = 0x7; // 0b0111
static constexpr uint8_t binary_lowest_normal = 0xF; // 0b1111
static constexpr uint8_t binary_min_subnorm = 0x1; // 0b0001
static constexpr uint8_t binary_max_subnorm = 0x1; // 0b0001
static constexpr float data_max_normal_number = 6;
static constexpr float data_min_subnormal_number = 0.5;
__host__ __device__ static constexpr f4_t Min() { return f4_t(binary_min_normal); }
__host__ __device__ static constexpr f4_t Max() { return f4_t(binary_max_normal); }
__host__ __device__ static constexpr f4_t Lowest() { return f4_t(binary_lowest_normal); }
__host__ __device__ static constexpr f4_t MinSubnorm() { return f4_t(binary_min_subnorm); }
__host__ __device__ static constexpr f4_t MaxSubnorm() { return f4_t(binary_max_subnorm); }
__host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; }
__host__ __device__ static constexpr float DataMinSubnorm()
{
return data_min_subnormal_number;
}
};
template <>
struct NumericLimits<e8m0_bexp_t>
{
static constexpr e8m0_bexp_t binary_min = 0x00; // 0b00000000
static constexpr e8m0_bexp_t binary_max = 0xFE; // 0b11111110
static constexpr e8m0_bexp_t binary_qnan = 0xFF; // 0b11111111
static constexpr e8m0_bexp_t binary_1 = 0x7F; // 0b01111111
static constexpr e8m0_bexp_t binary_2 = 0x80; // 0b10000000
static constexpr e8m0_bexp_t binary_3 = 0x82; // 0b10000010
static constexpr e8m0_bexp_t binary_135 = 0x87; // 0b10000111
static constexpr e8m0_bexp_t binary_142 = 0x8E; // 0b10001110
__host__ __device__ static constexpr e8m0_bexp_t Min() { return e8m0_bexp_t(binary_min); }
__host__ __device__ static constexpr e8m0_bexp_t Max() { return e8m0_bexp_t(binary_max); }
__host__ __device__ static constexpr e8m0_bexp_t QuietNaN() { return e8m0_bexp_t(binary_qnan); }
__host__ __device__ static constexpr e8m0_bexp_t Binary_1() { return e8m0_bexp_t(binary_1); }
__host__ __device__ static constexpr e8m0_bexp_t Binary_2() { return e8m0_bexp_t(binary_2); }
__host__ __device__ static constexpr e8m0_bexp_t Binary_3() { return e8m0_bexp_t(binary_3); }
__host__ __device__ static constexpr e8m0_bexp_t Binary_135()
{
return e8m0_bexp_t(binary_135);
}
__host__ __device__ static constexpr e8m0_bexp_t Binary_142()
{
return e8m0_bexp_t(binary_142);
}
};
template <typename T> template <typename T>
struct NumericUtils struct NumericUtils
{ {
...@@ -2028,6 +2123,7 @@ struct NumericUtils<float> ...@@ -2028,6 +2123,7 @@ struct NumericUtils<float>
static constexpr uint32_t NegInf = 0xFF800000; static constexpr uint32_t NegInf = 0xFF800000;
static constexpr uint32_t NaN = 0x7F800001; static constexpr uint32_t NaN = 0x7F800001;
static constexpr uint32_t Neg0 = 0x80000000; static constexpr uint32_t Neg0 = 0x80000000;
static constexpr bool has_inf = true;
using bitwise_type = uint32_t; using bitwise_type = uint32_t;
}; };
...@@ -2045,9 +2141,19 @@ struct NumericUtils<half_t> ...@@ -2045,9 +2141,19 @@ struct NumericUtils<half_t>
static constexpr uint32_t NegInf = 0xFC00; static constexpr uint32_t NegInf = 0xFC00;
static constexpr uint32_t NaN = 0x7C01; static constexpr uint32_t NaN = 0x7C01;
static constexpr uint32_t Neg0 = 0x8000; static constexpr uint32_t Neg0 = 0x8000;
static constexpr bool has_inf = true;
using bitwise_type = uint16_t; using bitwise_type = uint16_t;
}; };
template <>
struct NumericUtils<bhalf_t>
{
static constexpr int exp = 8;
static constexpr int mant = 7;
static constexpr int bias = 128; // negative zero nan mode
// static constexpr int bias = 127; // ieee mode
};
template <> template <>
struct NumericUtils<f8_fnuz_t> struct NumericUtils<f8_fnuz_t>
{ {
...@@ -2055,6 +2161,7 @@ struct NumericUtils<f8_fnuz_t> ...@@ -2055,6 +2161,7 @@ struct NumericUtils<f8_fnuz_t>
static constexpr int mant = 3; static constexpr int mant = 3;
static constexpr int bias = 8; // negative zero nan mode static constexpr int bias = 8; // negative zero nan mode
// static constexpr int bias = 7; // ieee mode // static constexpr int bias = 7; // ieee mode
static constexpr bool has_inf = false;
}; };
template <> template <>
...@@ -2064,6 +2171,7 @@ struct NumericUtils<bf8_fnuz_t> ...@@ -2064,6 +2171,7 @@ struct NumericUtils<bf8_fnuz_t>
static constexpr int mant = 2; static constexpr int mant = 2;
static constexpr int bias = 16; // negative zero nan mode static constexpr int bias = 16; // negative zero nan mode
// static constexpr int bias = 15; // ieee mode // static constexpr int bias = 15; // ieee mode
static constexpr bool has_inf = false;
}; };
template <> template <>
struct NumericUtils<f8_ocp_t> struct NumericUtils<f8_ocp_t>
...@@ -2082,11 +2190,47 @@ struct NumericUtils<bf8_ocp_t> ...@@ -2082,11 +2190,47 @@ struct NumericUtils<bf8_ocp_t>
}; };
template <> template <>
struct NumericUtils<bhalf_t> struct NumericUtils<f4_t>
{
static constexpr int exp = 2;
static constexpr int mant = 1;
static constexpr int bias = 1;
static constexpr uint32_t sr_shift = 10;
static constexpr int unbiased_exp_min = 0;
static constexpr int unbiased_exp_max = 2;
static constexpr int biased_exp_min = 1;
static constexpr int biased_exp_max = 3;
static constexpr uint8_t positive_zero_mask = 0b0000;
static constexpr uint8_t negative_zero_mask = 0b1000;
static constexpr uint8_t one_mask = 0b0010;
static constexpr uint8_t set_sign_mask = 0b0111;
static constexpr uint8_t data_max_positive_normal_mask = 0b0111;
static constexpr uint8_t data_max_negative_normal_mask = 0b1111;
static constexpr uint8_t data_max_positive_subnormal_mask = 0b0001;
static constexpr uint8_t data_max_negative_subnormal_mask = 0b1001;
static constexpr bool has_inf = false;
using bitwise_type = uint8_t;
};
template <>
struct NumericUtils<e8m0_bexp_t>
{ {
static constexpr int exp = 8; static constexpr int exp = 8;
static constexpr int mant = 7; static constexpr int mant = 0;
static constexpr int bias = 128; // negative zero nan mode static constexpr int bias = 127;
// static constexpr int bias = 127; // ieee mode
static constexpr int unbiased_exp_min = -127;
static constexpr int unbiased_exp_max = 127;
static constexpr int biased_exp_min = 0;
static constexpr int biased_exp_max = 254;
using bitwise_type = uint8_t;
}; };
} // namespace ck } // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
namespace ck::utils {
__host__ __device__ inline float cast_to_float(e8m0_bexp_t const bexp)
{
// TODO: check performance and try bit shift impl
return std::powf(2, bit_cast<uint8_t>(bexp) - NumericUtils<e8m0_bexp_t>::bias);
}
__host__ __device__ inline e8m0_bexp_t cast_from_float(float const scale)
{
uint32_t e = bit_cast<uint32_t>(scale) & NumericUtils<float>::nan_mask;
return static_cast<uint8_t>(e >> 23);
}
template <>
__host__ __device__ inline int get_exponent_value<e8m0_bexp_t>(e8m0_bexp_t x)
{
x.data >>= NumericUtils<e8m0_bexp_t>::mant;
x.data &= ((1 << NumericUtils<e8m0_bexp_t>::exp) - 1);
return static_cast<int>(x.data);
}
} // namespace ck::utils
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
namespace ck::utils {
template <>
__host__ __device__ inline bool is_nan<f4_t>(e8m0_bexp_t const scale,
f4_t const dataBytes [[maybe_unused]])
{
// no need to check for data as it does not have NaN representation
return scale == NumericLimits<e8m0_bexp_t>::QuietNaN();
}
// no infinity representation in ocp_e2m1_mxfp4 will always return false
template <>
__host__ __device__ inline bool is_inf<f4_t>(e8m0_bexp_t const scale [[maybe_unused]],
f4_t const data [[maybe_unused]])
{
// no inf representation for ocp_e2m1_mxfp4
return false;
}
template <>
__host__ __device__ inline bool is_zero<f4_t>(e8m0_bexp_t const scale, f4_t const data)
{
if(is_nan<f4_t>(scale, data))
return false;
// no need to check for scale as it does not have a 0 representation
f4_t result = (data & 0b00001111) & NumericUtils<f4_t>::set_sign_mask;
return result == 0b0;
}
template <>
__host__ __device__ inline float to_float<f4_t>(e8m0_bexp_t const scale, f4_t const data)
{
if(is_nan<f4_t>(scale, data))
return std::numeric_limits<float>::quiet_NaN();
if(is_zero<f4_t>(scale, data))
return 0.0f;
f4_t prepared_data = data & 0b00001111;
int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);
return convert_to_float<f4_t>(prepared_data, scale_exp);
}
template <>
__host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value)
{
cvt t;
t.value_float = value;
uint32_t sign = t.value_bitwise >> 31;
if(std::isnan(value))
{
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
}
if(std::abs(value) > NumericLimits<f4_t>::Max()) // covers inf case as well
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
f4_t res = convert_to_type<f4_t>(value);
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<f4_t>::DataMinSubnorm())
return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
: NumericUtils<f4_t>::positive_zero_mask;
return res;
}
template <>
__host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32_t seed)
{
cvt t;
t.value_float = value;
uint32_t sign = t.value_bitwise >> 31;
if(std::isnan(value))
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
if(std::abs(value) > NumericLimits<f4_t>::Max()) // covers inf case as well
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
f4_t res = convert_to_type_sr<f4_t>(value, seed);
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<f4_t>::DataMinSubnorm())
return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
: NumericUtils<f4_t>::positive_zero_mask;
return res;
}
} // namespace ck::utils
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck::utils {
union cvt
{
float value_float;
uint32_t value_bitwise;
};
template <typename DTYPE>
inline bool getDataHasInf()
{
return DTYPE::dataInfo.hasInf;
}
template <typename T>
__host__ __device__ inline bool is_zero(e8m0_bexp_t const scale, T const data);
template <typename T>
__host__ __device__ inline bool is_nan(e8m0_bexp_t const scale, T const data);
template <typename T>
__host__ __device__ inline bool is_inf(e8m0_bexp_t const scale, T const data);
template <typename T>
__host__ __device__ inline int get_exponent_value(T x)
{
x >>= NumericUtils<T>::mant;
x &= ((1 << NumericUtils<T>::exp) - 1);
return static_cast<int>(x);
}
template <typename T>
__host__ __device__ inline bool is_subnormal(T x)
{
return get_exponent_value<T>(x) == 0;
}
template <typename T>
__host__ __device__ inline double get_mantissa_value(T x)
{
double mantissa = is_subnormal<T>(x) ? 0.0f : 1.0f;
for(uint i = 0; i < NumericUtils<T>::mant; i++)
{
mantissa += std::pow(2, -int32_t((NumericUtils<T>::mant - i))) * (x & 0b1);
x >>= 1;
}
return mantissa;
}
template <typename T>
__host__ __device__ inline bool get_data_has_inf()
{
return NumericUtils<T>::has_inf;
}
template <typename T>
__host__ __device__ float convert_to_float(T data, int scale_exp)
{
float d_sign =
std::pow(-1, static_cast<float>(data >> (NumericUtils<T>::exp + NumericUtils<T>::mant)));
float d_exp;
if(is_subnormal<T>(data))
d_exp = std::pow(2, 1 - static_cast<int>(NumericUtils<T>::bias));
else
d_exp = std::pow(2, get_exponent_value<T>(data) - static_cast<int>(NumericUtils<T>::bias));
float d_mant = get_mantissa_value<T>(data);
float data_value = d_sign * d_exp * d_mant;
float scale_value = std::pow(
2, static_cast<float>((scale_exp - static_cast<int>(NumericUtils<e8m0_bexp_t>::bias))));
return data_value * scale_value;
}
template <typename T>
__host__ __device__ inline float to_float(e8m0_bexp_t const scale, T const data);
template <typename T>
__host__ __device__ T sat_convert_to_type(float value);
template <typename T>
__host__ __device__ T sat_convert_to_type_sr(float value, uint32_t seed);
template <typename T>
inline T convert_to_type(float value)
{
using bitwise_type = typename NumericUtils<T>::bitwise_type;
if(std::abs(value) > NumericLimits<T>::Max())
{
float max_value = NumericLimits<T>::Max();
cvt t;
// cppcheck-suppress redundantAssignment
t.value_float = max_value;
uint32_t max_bitwise = t.value_bitwise;
// cppcheck-suppress redundantAssignment
t.value_float = value;
bitwise_type sign =
t.value_bitwise >> (NumericUtils<float>::exp + NumericUtils<float>::mant);
bitwise_type exp =
((max_bitwise >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask) -
(NumericUtils<float>::bias - NumericUtils<T>::bias);
bitwise_type mantissa = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
mant_prev &= ((1 << NumericUtils<T>::mant) - 1);
mant_prev--;
mant_prev <<= (NumericUtils<float>::mant - NumericUtils<T>::mant);
uint32_t prev_bit =
((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;
t.value_bitwise = prev_bit;
float prev_val = t.value_float;
float diff = max_value - prev_val;
float actual_max = max_value + (diff / 2);
if(std::abs(value) < actual_max)
{
return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
(exp << NumericUtils<T>::mant) | mantissa;
}
else
{
if(!get_data_has_inf<T>())
{
return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
}
else
{
exp++;
return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
(exp << NumericUtils<T>::mant);
}
}
}
const int mfmt = NumericUtils<float>::mant;
uint32_t x;
x = bit_cast<uint32_t>(value);
uint32_t head, mantissa;
int32_t exponent, bias;
uint32_t sign;
head = x & NumericUtils<float>::head_mask;
mantissa = x & NumericUtils<float>::mant_mask;
exponent = (head >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask;
sign = head >> (NumericUtils<float>::mant + NumericUtils<float>::exp);
bias = NumericUtils<float>::bias;
if(x == 0)
{
return 0b0;
}
const int mini_bias = NumericUtils<T>::bias;
const int mini_denormal_act_exponent = 1 - mini_bias;
int act_exponent, out_exponent, exponent_diff;
bool is_subnorm = false;
if(exponent == 0)
{
act_exponent = exponent - bias + 1;
exponent_diff = mini_denormal_act_exponent - act_exponent;
is_subnorm = true;
}
else
{
act_exponent = exponent - bias;
if(act_exponent <= mini_denormal_act_exponent)
{
exponent_diff = mini_denormal_act_exponent - act_exponent;
is_subnorm = true;
}
else
{
exponent_diff = 0;
}
mantissa += (1UL << mfmt);
}
auto shift_amount = (mfmt - NumericUtils<T>::mant + exponent_diff);
shift_amount = (shift_amount >= 64) ? 63 : shift_amount;
bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));
float min_subnorm = NumericLimits<T>::DataMinSubnorm() * (sign ? -1 : 1);
if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
{
// closer to 0
if(std::abs(value) <= std::abs(min_subnorm - value))
return 0;
else
return 1 | (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant));
}
if(exponent_diff > 0)
mantissa >>= exponent_diff;
else if(exponent_diff == -1)
mantissa <<= -exponent_diff;
bool implicit_one = mantissa & (1 << mfmt);
out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);
uint32_t drop_mask = (1UL << (mfmt - NumericUtils<T>::mant)) - 1;
bool odd = mantissa & (1UL << (mfmt - NumericUtils<T>::mant));
mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;
if(out_exponent == 0)
{
if((1UL << mfmt) & mantissa)
{
out_exponent = 1;
}
}
else
{
if((1UL << (mfmt + 1)) & mantissa)
{
mantissa >>= 1;
out_exponent++;
}
}
mantissa >>= (mfmt - NumericUtils<T>::mant);
if(out_exponent == 0 && mantissa == 0)
{
return 0;
}
mantissa &= (1UL << NumericUtils<T>::mant) - 1;
return (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant)) |
(out_exponent << NumericUtils<T>::mant) | mantissa;
}
template <typename T>
inline T convert_to_type_sr(float value, uint32_t seed)
{
if(std::abs(value) > NumericLimits<T>::Max())
{
float max_value = NumericLimits<T>::Max();
cvt t;
// cppcheck-suppress redundantAssignment
t.value_float = max_value;
uint max_bitwise = t.value_bitwise;
// cppcheck-suppress redundantAssignment
t.value_float = value;
T sign = t.value_bitwise >> (NumericUtils<float>::exp + NumericUtils<float>::mant);
T exp = ((max_bitwise >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask) -
(NumericUtils<float>::bias - NumericUtils<T>::bias);
uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
mant_prev &= ((1UL << NumericUtils<T>::mant) - 1);
mant_prev--;
mant_prev <<= (NumericUtils<float>::mant - NumericUtils<T>::mant);
uint32_t prev_bit =
((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;
t.value_bitwise = prev_bit;
float prev_val = t.value_float;
float diff = max_value - prev_val;
float actual_max = max_value + (diff / 2);
if(std::abs(value) < actual_max)
{
double d_max_value = static_cast<double>(max_value);
double d_actual_max = static_cast<double>(actual_max);
double d_value = static_cast<double>(value);
double d_is = std::abs(d_max_value - d_actual_max);
double d_seed = static_cast<double>(seed);
double d_prob = 1.0f - (std::abs(d_value - d_max_value) / d_is); // prob to round down
double thresh = UINT_MAX * d_prob;
if(!get_data_has_inf<T>() || d_seed <= thresh)
// return static_cast<T>(satConvertToType(getDataMax<DTYPE>())); //round down time
return sign == 0 ? NumericUtils<f4_t>::data_max_positive_normal_mask
: NumericUtils<f4_t>::data_max_negative_normal_mask;
else
{
exp++;
return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
| (exp << NumericUtils<T>::mant);
}
}
else
{
if(!get_data_has_inf<T>())
return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
else
{
exp++;
return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
| (exp << NumericUtils<T>::mant);
}
}
}
uint32_t f32 = bit_cast<uint32_t>(value);
auto f32_mant = f32 & NumericUtils<float>::mant_mask;
auto head = f32 & NumericUtils<float>::head_mask;
auto f32_exp = (head >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask;
auto sign_bit = head >> (NumericUtils<float>::mant + NumericUtils<float>::exp);
auto sign = sign_bit << (NumericUtils<T>::exp + NumericUtils<T>::mant);
f32_exp = static_cast<int32_t>(f32_exp) - NumericUtils<float>::bias;
int32_t exp = f32_exp;
auto mant = f32_mant;
bool subnorm = false;
if(f32 == 0)
return 0b0;
if(exp >= NumericUtils<T>::unbiased_exp_min)
{
mant = f32_mant;
}
// if the exponent bit is 8, then the subnormal is exactly the same as f32
else if(exp < NumericUtils<T>::unbiased_exp_min &&
NumericUtils<T>::exp < NumericUtils<float>::exp)
{
subnorm = true;
auto diff = static_cast<uint32_t>(NumericUtils<T>::unbiased_exp_min - exp);
if(diff >= 32)
{
mant = 0;
f32_mant = 0;
}
else
{
f32_mant |= static_cast<uint32_t>(1) << NumericUtils<float>::mant;
f32_mant >>= diff;
}
exp = 0;
mant = f32_mant;
}
uint32_t sr_shift = NumericUtils<T>::sr_shift;
// For stochastic-rounding we add the aligned random value to the
// mantissa and then truncate (RTZ).
mant += seed >> sr_shift;
// Increment exponent when mantissa overflows due to rounding
if(mant >= static_cast<uint32_t>(1) << NumericUtils<float>::mant)
++exp;
mant >>= (NumericUtils<float>::mant - NumericUtils<T>::mant);
mant &= ((1 << NumericUtils<T>::mant) - 1);
auto biased_exp = static_cast<uint32_t>(exp);
if(!subnorm)
biased_exp = static_cast<uint32_t>(exp + NumericUtils<T>::bias);
biased_exp &= ((1 << NumericUtils<T>::exp) - 1);
auto val = sign | biased_exp << NumericUtils<T>::mant | mant;
return val;
}
} // namespace ck::utils
This diff is collapsed.
...@@ -26,6 +26,7 @@ namespace utils { ...@@ -26,6 +26,7 @@ namespace utils {
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType> template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
double get_relative_threshold(const int number_of_accumulations = 1) double get_relative_threshold(const int number_of_accumulations = 1)
{ {
using F4 = ck::f4_t;
using F8 = ck::f8_t; using F8 = ck::f8_t;
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
...@@ -33,10 +34,10 @@ double get_relative_threshold(const int number_of_accumulations = 1) ...@@ -33,10 +34,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
using I8 = int8_t; using I8 = int8_t;
using I32 = int32_t; using I32 = int32_t;
static_assert(is_same_v<ComputeDataType, F8> || is_same_v<ComputeDataType, F16> || static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
is_same_v<ComputeDataType, BF16> || is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, I8> ||
is_same_v<ComputeDataType, int>, is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, int>,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!"); "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
double compute_error = 0; double compute_error = 0;
if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> || if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
...@@ -49,10 +50,10 @@ double get_relative_threshold(const int number_of_accumulations = 1) ...@@ -49,10 +50,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
compute_error = std::pow(2, -NumericUtils<ComputeDataType>::mant) * 0.5; compute_error = std::pow(2, -NumericUtils<ComputeDataType>::mant) * 0.5;
} }
static_assert(is_same_v<OutDataType, F8> || is_same_v<OutDataType, F16> || static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
is_same_v<OutDataType, BF16> || is_same_v<OutDataType, F32> || is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> || is_same_v<OutDataType, F32> || is_same_v<OutDataType, I8> ||
is_same_v<OutDataType, int>, is_same_v<OutDataType, I32> || is_same_v<OutDataType, int>,
"Warning: Unhandled OutDataType for setting up the relative threshold!"); "Warning: Unhandled OutDataType for setting up the relative threshold!");
double output_error = 0; double output_error = 0;
if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> || if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
...@@ -66,10 +67,10 @@ double get_relative_threshold(const int number_of_accumulations = 1) ...@@ -66,10 +67,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
} }
double midway_error = std::max(compute_error, output_error); double midway_error = std::max(compute_error, output_error);
static_assert(is_same_v<AccDataType, F8> || is_same_v<AccDataType, F16> || static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
is_same_v<AccDataType, BF16> || is_same_v<AccDataType, F32> || is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> || is_same_v<AccDataType, F32> || is_same_v<AccDataType, I8> ||
is_same_v<AccDataType, int>, is_same_v<AccDataType, I32> || is_same_v<AccDataType, int>,
"Warning: Unhandled AccDataType for setting up the relative threshold!"); "Warning: Unhandled AccDataType for setting up the relative threshold!");
double acc_error = 0; double acc_error = 0;
if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> || if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
...@@ -87,6 +88,7 @@ double get_relative_threshold(const int number_of_accumulations = 1) ...@@ -87,6 +88,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType> template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1) double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
{ {
using F4 = ck::f4_t;
using F8 = ck::f8_t; using F8 = ck::f8_t;
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
...@@ -94,10 +96,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of ...@@ -94,10 +96,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
using I8 = int8_t; using I8 = int8_t;
using I32 = int32_t; using I32 = int32_t;
static_assert(is_same_v<ComputeDataType, F8> || is_same_v<ComputeDataType, F16> || static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
is_same_v<ComputeDataType, BF16> || is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, I8> ||
is_same_v<ComputeDataType, int>, is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, int>,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
auto expo = std::log2(std::abs(max_possible_num)); auto expo = std::log2(std::abs(max_possible_num));
double compute_error = 0; double compute_error = 0;
...@@ -111,10 +113,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of ...@@ -111,10 +113,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
compute_error = std::pow(2, expo - NumericUtils<ComputeDataType>::mant) * 0.5; compute_error = std::pow(2, expo - NumericUtils<ComputeDataType>::mant) * 0.5;
} }
static_assert(is_same_v<OutDataType, F8> || is_same_v<OutDataType, F16> || static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
is_same_v<OutDataType, BF16> || is_same_v<OutDataType, F32> || is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> || is_same_v<OutDataType, F32> || is_same_v<OutDataType, I8> ||
is_same_v<OutDataType, int>, is_same_v<OutDataType, I32> || is_same_v<OutDataType, int>,
"Warning: Unhandled OutDataType for setting up the absolute threshold!"); "Warning: Unhandled OutDataType for setting up the absolute threshold!");
double output_error = 0; double output_error = 0;
if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> || if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
...@@ -128,10 +130,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of ...@@ -128,10 +130,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
} }
double midway_error = std::max(compute_error, output_error); double midway_error = std::max(compute_error, output_error);
static_assert(is_same_v<AccDataType, F8> || is_same_v<AccDataType, F16> || static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
is_same_v<AccDataType, BF16> || is_same_v<AccDataType, F32> || is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> || is_same_v<AccDataType, F32> || is_same_v<AccDataType, I8> ||
is_same_v<AccDataType, int>, is_same_v<AccDataType, I32> || is_same_v<AccDataType, int>,
"Warning: Unhandled AccDataType for setting up the absolute threshold!"); "Warning: Unhandled AccDataType for setting up the absolute threshold!");
double acc_error = 0; double acc_error = 0;
if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> || if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
...@@ -450,5 +452,54 @@ check_err(const Range& out, ...@@ -450,5 +452,54 @@ check_err(const Range& out,
return res; return res;
} }
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, f4_t>),
bool>
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 0.5,
double atol = 0.5)
{
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
return false;
}
bool res{true};
int err_count = 0;
double err = 0;
double max_err = std::numeric_limits<float>::min();
for(std::size_t i = 0; i < ref.size(); ++i)
{
const double o = type_convert<float>(*std::next(std::begin(out), i));
const double r = type_convert<float>(*std::next(std::begin(ref), i));
err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
}
res = false;
}
}
if(!res)
{
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err
<< " number of errors: " << err_count << std::endl;
}
return res;
}
} // namespace utils } // namespace utils
} // namespace ck } // namespace ck
...@@ -69,6 +69,18 @@ struct GeneratorTensor_1<ck::f8_t> ...@@ -69,6 +69,18 @@ struct GeneratorTensor_1<ck::f8_t>
}; };
#endif #endif
template <>
struct GeneratorTensor_1<ck::f4_t>
{
float value = 1.0;
template <typename... Is>
ck::f4_t operator()(Is...)
{
return ck::type_convert<ck::f4_t>(value);
}
};
template <> template <>
struct GeneratorTensor_1<int8_t> struct GeneratorTensor_1<int8_t>
{ {
...@@ -153,6 +165,20 @@ struct GeneratorTensor_2<ck::bf8_t> ...@@ -153,6 +165,20 @@ struct GeneratorTensor_2<ck::bf8_t>
}; };
#endif #endif
template <>
struct GeneratorTensor_2<ck::f4_t>
{
int min_value = 0;
int max_value = 1;
template <typename... Is>
ck::f4_t operator()(Is...)
{
float tmp = (std::rand() % (max_value - min_value)) + min_value;
return ck::type_convert<ck::f4_t>(tmp);
}
};
template <typename T> template <typename T>
struct GeneratorTensor_3 struct GeneratorTensor_3
{ {
...@@ -223,6 +249,23 @@ struct GeneratorTensor_3<ck::bf8_t> ...@@ -223,6 +249,23 @@ struct GeneratorTensor_3<ck::bf8_t>
}; };
#endif #endif
template <>
struct GeneratorTensor_3<ck::f4_t>
{
float min_value = 0;
float max_value = 1;
template <typename... Is>
ck::f4_t operator()(Is...)
{
float tmp = float(std::rand()) / float(RAND_MAX);
float fp32_tmp = min_value + tmp * (max_value - min_value);
return ck::type_convert<ck::f4_t>(fp32_tmp);
}
};
template <typename T> template <typename T>
struct GeneratorTensor_4 struct GeneratorTensor_4
{ {
......
...@@ -42,6 +42,10 @@ if (CK_USE_FNUZ_FP8) ...@@ -42,6 +42,10 @@ if (CK_USE_FNUZ_FP8)
add_dependencies(test_fp8 test_fp8_fnuz) add_dependencies(test_fp8 test_fp8_fnuz)
add_dependencies(test_fp8 test_bf8_fnuz) add_dependencies(test_fp8 test_bf8_fnuz)
endif() endif()
add_gtest_executable(test_fp4 test_fp4.cpp)
if(result EQUAL 0)
target_link_libraries(test_fp4 PRIVATE utility)
endif()
add_gtest_executable(test_custom_type test_custom_type.cpp) add_gtest_executable(test_custom_type test_custom_type.cpp)
if(result EQUAL 0) if(result EQUAL 0)
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment