// SPDX-License-Identifier: MIT // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once namespace ck::utils { union cvt { float value_float; uint32_t value_bitwise; }; template inline bool getDataHasInf() { return DTYPE::dataInfo.hasInf; } template __host__ __device__ inline bool is_zero(e8m0_bexp_t const scale, T const data); template __host__ __device__ inline bool is_nan(e8m0_bexp_t const scale, T const data); template __host__ __device__ inline bool is_inf(e8m0_bexp_t const scale, T const data); template __host__ __device__ inline int get_exponent_value(T x) { x >>= NumericUtils::mant; x &= ((1 << NumericUtils::exp) - 1); return static_cast(x); } template __host__ __device__ inline bool is_subnormal(T x) { return get_exponent_value(x) == 0; } template __host__ __device__ inline double get_mantissa_value(T x) { double mantissa = is_subnormal(x) ? 0.0f : 1.0f; for(uint i = 0; i < NumericUtils::mant; i++) { mantissa += std::pow(2, -int32_t((NumericUtils::mant - i))) * (x & 0b1); x >>= 1; } return mantissa; } template __host__ __device__ inline bool get_data_has_inf() { return NumericUtils::has_inf; } template __host__ __device__ float convert_to_float(T data, int scale_exp) { float d_sign = std::pow(-1, static_cast(data >> (NumericUtils::exp + NumericUtils::mant))); float d_exp; if(is_subnormal(data)) d_exp = std::pow(2, 1 - static_cast(NumericUtils::bias)); else d_exp = std::pow(2, get_exponent_value(data) - static_cast(NumericUtils::bias)); float d_mant = get_mantissa_value(data); float data_value = d_sign * d_exp * d_mant; float scale_value = std::pow( 2, static_cast((scale_exp - static_cast(NumericUtils::bias)))); return data_value * scale_value; } template __host__ __device__ inline float to_float(e8m0_bexp_t const scale, T const data); template __host__ __device__ T sat_convert_to_type(float value); template __host__ __device__ T sat_convert_to_type_sr(float value, uint32_t seed); template inline T convert_to_type(float value) { using bitwise_type = typename NumericUtils::bitwise_type; if(std::abs(value) > NumericLimits::Max()) { float max_value = NumericLimits::Max(); cvt t; // cppcheck-suppress redundantAssignment t.value_float = max_value; uint32_t max_bitwise = t.value_bitwise; // cppcheck-suppress redundantAssignment t.value_float = value; bitwise_type sign = t.value_bitwise >> (NumericUtils::exp + NumericUtils::mant); bitwise_type exp = ((max_bitwise >> NumericUtils::mant) & NumericUtils::exp_mask) - (NumericUtils::bias - NumericUtils::bias); bitwise_type mantissa = max_bitwise >> (NumericUtils::mant - NumericUtils::mant); uint32_t mant_prev = max_bitwise >> (NumericUtils::mant - NumericUtils::mant); mant_prev &= ((1 << NumericUtils::mant) - 1); mant_prev--; mant_prev <<= (NumericUtils::mant - NumericUtils::mant); uint32_t prev_bit = ((max_bitwise >> NumericUtils::mant) << NumericUtils::mant) | mant_prev; t.value_bitwise = prev_bit; float prev_val = t.value_float; float diff = max_value - prev_val; float actual_max = max_value + (diff / 2); if(std::abs(value) < actual_max) { return sign << ((NumericUtils::exp + NumericUtils::mant)) | (exp << NumericUtils::mant) | mantissa; } else { if(!get_data_has_inf()) { return (1 << (NumericUtils::mant + NumericUtils::exp)) - 1; } else { exp++; return sign << ((NumericUtils::exp + NumericUtils::mant)) | (exp << NumericUtils::mant); } } } const int mfmt = NumericUtils::mant; uint32_t x; x = bit_cast(value); uint32_t head, mantissa; int32_t exponent, bias; uint32_t sign; head = x & NumericUtils::head_mask; mantissa = x & NumericUtils::mant_mask; exponent = (head >> NumericUtils::mant) & NumericUtils::exp_mask; sign = head >> (NumericUtils::mant + NumericUtils::exp); bias = NumericUtils::bias; if(x == 0) { return 0b0; } const int mini_bias = NumericUtils::bias; const int mini_denormal_act_exponent = 1 - mini_bias; int act_exponent, out_exponent, exponent_diff; bool is_subnorm = false; if(exponent == 0) { act_exponent = exponent - bias + 1; exponent_diff = mini_denormal_act_exponent - act_exponent; is_subnorm = true; } else { act_exponent = exponent - bias; if(act_exponent <= mini_denormal_act_exponent) { exponent_diff = mini_denormal_act_exponent - act_exponent; is_subnorm = true; } else { exponent_diff = 0; } mantissa += (1UL << mfmt); } auto shift_amount = (mfmt - NumericUtils::mant + exponent_diff); shift_amount = (shift_amount >= 64) ? 63 : shift_amount; bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1)); float min_subnorm = NumericLimits::DataMinSubnorm() * (sign ? -1 : 1); if(is_subnorm && std::abs(value) < std::abs(min_subnorm)) { // closer to 0 if(std::abs(value) <= std::abs(min_subnorm - value)) return 0; else return 1 | (sign << (NumericUtils::exp + NumericUtils::mant)); } if(exponent_diff > 0) mantissa >>= exponent_diff; else if(exponent_diff == -1) mantissa <<= -exponent_diff; bool implicit_one = mantissa & (1 << mfmt); out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1); uint32_t drop_mask = (1UL << (mfmt - NumericUtils::mant)) - 1; bool odd = mantissa & (1UL << (mfmt - NumericUtils::mant)); mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask; if(out_exponent == 0) { if((1UL << mfmt) & mantissa) { out_exponent = 1; } } else { if((1UL << (mfmt + 1)) & mantissa) { mantissa >>= 1; out_exponent++; } } mantissa >>= (mfmt - NumericUtils::mant); if(out_exponent == 0 && mantissa == 0) { return 0; } mantissa &= (1UL << NumericUtils::mant) - 1; return (sign << (NumericUtils::exp + NumericUtils::mant)) | (out_exponent << NumericUtils::mant) | mantissa; } template inline T convert_to_type_sr(float value, uint32_t seed) { if(std::abs(value) > NumericLimits::Max()) { float max_value = NumericLimits::Max(); cvt t; // cppcheck-suppress redundantAssignment t.value_float = max_value; uint max_bitwise = t.value_bitwise; // cppcheck-suppress redundantAssignment t.value_float = value; T sign = t.value_bitwise >> (NumericUtils::exp + NumericUtils::mant); T exp = ((max_bitwise >> NumericUtils::mant) & NumericUtils::exp_mask) - (NumericUtils::bias - NumericUtils::bias); uint32_t mant_prev = max_bitwise >> (NumericUtils::mant - NumericUtils::mant); mant_prev &= ((1UL << NumericUtils::mant) - 1); mant_prev--; mant_prev <<= (NumericUtils::mant - NumericUtils::mant); uint32_t prev_bit = ((max_bitwise >> NumericUtils::mant) << NumericUtils::mant) | mant_prev; t.value_bitwise = prev_bit; float prev_val = t.value_float; float diff = max_value - prev_val; float actual_max = max_value + (diff / 2); if(std::abs(value) < actual_max) { double d_max_value = static_cast(max_value); double d_actual_max = static_cast(actual_max); double d_value = static_cast(value); double d_is = std::abs(d_max_value - d_actual_max); double d_seed = static_cast(seed); double d_prob = 1.0f - (std::abs(d_value - d_max_value) / d_is); // prob to round down double thresh = UINT_MAX * d_prob; if(!get_data_has_inf() || d_seed <= thresh) // return static_cast(satConvertToType(getDataMax())); //round down time return sign == 0 ? NumericUtils::data_max_positive_normal_mask : NumericUtils::data_max_negative_normal_mask; else { exp++; return sign << ((NumericUtils::exp + NumericUtils::mant)) // inf | (exp << NumericUtils::mant); } } else { if(!get_data_has_inf()) return (1 << (NumericUtils::mant + NumericUtils::exp)) - 1; else { exp++; return sign << ((NumericUtils::exp + NumericUtils::mant)) // inf | (exp << NumericUtils::mant); } } } uint32_t f32 = bit_cast(value); auto f32_mant = f32 & NumericUtils::mant_mask; auto head = f32 & NumericUtils::head_mask; auto f32_exp = (head >> NumericUtils::mant) & NumericUtils::exp_mask; auto sign_bit = head >> (NumericUtils::mant + NumericUtils::exp); auto sign = sign_bit << (NumericUtils::exp + NumericUtils::mant); f32_exp = static_cast(f32_exp) - NumericUtils::bias; int32_t exp = f32_exp; auto mant = f32_mant; bool subnorm = false; if(f32 == 0) return 0b0; if(exp >= NumericUtils::unbiased_exp_min) { mant = f32_mant; } // if the exponent bit is 8, then the subnormal is exactly the same as f32 else if(exp < NumericUtils::unbiased_exp_min && NumericUtils::exp < NumericUtils::exp) { subnorm = true; auto diff = static_cast(NumericUtils::unbiased_exp_min - exp); if(diff >= 32) { mant = 0; f32_mant = 0; } else { f32_mant |= static_cast(1) << NumericUtils::mant; f32_mant >>= diff; } exp = 0; mant = f32_mant; } uint32_t sr_shift = NumericUtils::sr_shift; // For stochastic-rounding we add the aligned random value to the // mantissa and then truncate (RTZ). mant += seed >> sr_shift; // Increment exponent when mantissa overflows due to rounding if(mant >= static_cast(1) << NumericUtils::mant) ++exp; mant >>= (NumericUtils::mant - NumericUtils::mant); mant &= ((1 << NumericUtils::mant) - 1); auto biased_exp = static_cast(exp); if(!subnorm) biased_exp = static_cast(exp + NumericUtils::bias); biased_exp &= ((1 << NumericUtils::exp) - 1); auto val = sign | biased_exp << NumericUtils::mant | mant; return val; } } // namespace ck::utils