// SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/utility/data_type.hpp" #include "ck/utility/mxfp_utils.hpp" namespace ck { __host__ inline int clz(uint32_t x) { return __builtin_clz(x); } __device__ inline int clz(uint32_t x) { return __clz(x); } } // namespace ck namespace ck::utils { // template // __host__ __device__ Y cast_to_f8(X x, uint32_t rng) // { // // check datatypes // constexpr bool is_half = std::is_same::value; // constexpr bool is_float = std::is_same::value; // static_assert(is_half || is_float, "Only half and float can be casted."); // return run_cast_to_f8(x, rng); // } // template // __host__ __device__ Y cast_from_f8(X x) // { // // check datatype // constexpr bool is_half = std::is_same::value; // constexpr bool is_float = std::is_same::value; // static_assert(is_half || is_float, "only half and float are supported."); // return run_cast_from_f8(x); // } template <> __host__ __device__ inline bool is_nan(e8m0_scale_t const scale, f4_t const dataBytes [[maybe_unused]]) { // no need to check for data as it does not have NaN representation return scale == NumericLimits::QuietNaN(); } // no infinity representation in ocp_e2m1_mxfp4 will always return false template <> __host__ __device__ inline bool is_inf(e8m0_scale_t const scale [[maybe_unused]], f4_t const data [[maybe_unused]]) { // no inf representation for ocp_e2m1_mxfp4 return false; } template <> __host__ __device__ inline bool is_zero(e8m0_scale_t const scale, f4_t const data) { if(is_nan(scale, data)) return false; // no need to check for scale as it does not have a 0 representation f4_t data = (data & 0b00001111) & NumericUtils::set_sign_mask; return data == 0b0; } template <> __host__ __device__ inline float to_float(e8m0_scale_t const scale, f4_t const data) { if(is_nan(scale, data)) return std::numeric_limits::quiet_NaN(); if(is_zero(scale, data)) return 0.0f; uint8_t data = data & 0b00001111; int scale_exp = get_exponent_value(scale); return convert_to_float(data, scale_exp); } template <> __host__ __device__ inline f4_t sat_convert_to_type(float value) { cvt t; t.value_float = value; uint32_t sign = t.value_bitwise >> 31; if(std::isnan(value)) { return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; } if(std::abs(value) > NumericLimits::Max()) // covers inf case as well return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; f4_t res = convert_to_type(value); if(std::abs(to_float(NumericLimits::Binary_1(), res)) < NumericUtils::DataMinSubnorm()) return value < 0 ? NumericUtils::negative_zero_mask : NumericUtils::positive_zero_mask; return res; } template <> __host__ __device__ inline f4_t sat_convert_to_type_sr(float value, uint32_t seed) { cvt t; t.value_float = value; uint32_t sign = t.value_bitwise >> 31; if(std::isnan(value)) return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; if(std::abs(value) > NumericLimits::Max()) // covers inf case as well return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; f4_t res = convert_to_type_sr(value, seed); if(std::abs(to_float(NumericLimits::Binary_1(), res)) < NumericUtils::DataMinSubnorm()) return value < 0 ? NumericUtils::negative_zero_mask : NumericUtils::positive_zero_mask; return res; } } // namespace ck::utils