// SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/utility/data_type.hpp" #include "ck/utility/mxfp_utils.hpp" namespace ck::utils { template <> __host__ __device__ inline bool is_nan(e8m0_bexp_t const scale, f4_t const dataBytes [[maybe_unused]]) { // no need to check for data as it does not have NaN representation return scale == NumericLimits::QuietNaN(); } // no infinity representation in ocp_e2m1_mxfp4 will always return false template <> __host__ __device__ inline bool is_inf(e8m0_bexp_t const scale [[maybe_unused]], f4_t const data [[maybe_unused]]) { // no inf representation for ocp_e2m1_mxfp4 return false; } template <> __host__ __device__ inline bool is_zero(e8m0_bexp_t const scale, f4_t const data) { if(is_nan(scale, data)) return false; // no need to check for scale as it does not have a 0 representation f4_t result = (data & 0b00001111) & NumericUtils::set_sign_mask; return result == 0b0; } template <> __host__ __device__ inline float to_float(e8m0_bexp_t const scale, f4_t const data) { if(is_nan(scale, data)) return std::numeric_limits::quiet_NaN(); if(is_zero(scale, data)) return 0.0f; f4_t prepared_data = data & 0b00001111; int scale_exp = get_exponent_value(scale); return convert_to_float(prepared_data, scale_exp); } template <> __host__ __device__ inline f4_t sat_convert_to_type(float value) { cvt t; t.value_float = value; uint32_t sign = t.value_bitwise >> 31; if(std::isnan(value)) { return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; } if(std::abs(value) > NumericLimits::Max()) // covers inf case as well return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; f4_t res = convert_to_type(value); if(std::abs(to_float(NumericLimits::Binary_1(), res)) < NumericLimits::DataMinSubnorm()) return value < 0 ? NumericUtils::negative_zero_mask : NumericUtils::positive_zero_mask; return res; } template <> __host__ __device__ inline f4_t sat_convert_to_type_sr(float value, uint32_t seed) { cvt t; t.value_float = value; uint32_t sign = t.value_bitwise >> 31; if(std::isnan(value)) return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; if(std::abs(value) > NumericLimits::Max()) // covers inf case as well return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; f4_t res = convert_to_type_sr(value, seed); if(std::abs(to_float(NumericLimits::Binary_1(), res)) < NumericLimits::DataMinSubnorm()) return value < 0 ? NumericUtils::negative_zero_mask : NumericUtils::positive_zero_mask; return res; } } // namespace ck::utils