Commit cf7e20a8 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Update conversions

parent f90f5da6
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -1229,6 +1229,7 @@ struct NumericUtils<float> ...@@ -1229,6 +1229,7 @@ struct NumericUtils<float>
static constexpr uint32_t NegInf = 0xFF800000; static constexpr uint32_t NegInf = 0xFF800000;
static constexpr uint32_t NaN = 0x7F800001; static constexpr uint32_t NaN = 0x7F800001;
static constexpr uint32_t Neg0 = 0x80000000; static constexpr uint32_t Neg0 = 0x80000000;
static constexpr bool has_inf = true;
using bitwise_type = uint32_t; using bitwise_type = uint32_t;
}; };
...@@ -1246,6 +1247,7 @@ struct NumericUtils<half_t> ...@@ -1246,6 +1247,7 @@ struct NumericUtils<half_t>
static constexpr uint32_t NegInf = 0xFC00; static constexpr uint32_t NegInf = 0xFC00;
static constexpr uint32_t NaN = 0x7C01; static constexpr uint32_t NaN = 0x7C01;
static constexpr uint32_t Neg0 = 0x8000; static constexpr uint32_t Neg0 = 0x8000;
static constexpr bool has_inf = true;
using bitwise_type = uint16_t; using bitwise_type = uint16_t;
}; };
...@@ -1256,6 +1258,7 @@ struct NumericUtils<f8_t> ...@@ -1256,6 +1258,7 @@ struct NumericUtils<f8_t>
static constexpr int mant = 3; static constexpr int mant = 3;
static constexpr int bias = 8; // negative zero nan mode static constexpr int bias = 8; // negative zero nan mode
// static constexpr int bias = 7; // ieee mode // static constexpr int bias = 7; // ieee mode
static constexpr bool has_inf = false;
}; };
template <> template <>
...@@ -1265,14 +1268,15 @@ struct NumericUtils<bf8_t> ...@@ -1265,14 +1268,15 @@ struct NumericUtils<bf8_t>
static constexpr int mant = 2; static constexpr int mant = 2;
static constexpr int bias = 16; // negative zero nan mode static constexpr int bias = 16; // negative zero nan mode
// static constexpr int bias = 15; // ieee mode // static constexpr int bias = 15; // ieee mode
static constexpr bool has_inf = false;
}; };
template <> template <>
struct NumericUtils<f4_t> struct NumericUtils<f4_t>
{ {
static constexpr int exp = 2; static constexpr int exp = 2;
static constexpr int mant = 1; static constexpr int mant = 1;
static constexpr int bias = 1; static constexpr int bias = 1;
static constexpr uint32_t sr_shift = 10; static constexpr uint32_t sr_shift = 10;
static constexpr int unbiased_exp_min = 0; static constexpr int unbiased_exp_min = 0;
...@@ -1292,6 +1296,8 @@ struct NumericUtils<f4_t> ...@@ -1292,6 +1296,8 @@ struct NumericUtils<f4_t>
static constexpr uint8_t data_max_positive_subnormal_mask = 0b0001; static constexpr uint8_t data_max_positive_subnormal_mask = 0b0001;
static constexpr uint8_t data_max_negative_subnormal_mask = 0b1001; static constexpr uint8_t data_max_negative_subnormal_mask = 0b1001;
static constexpr bool has_inf = false;
using bitwise_type = uint8_t; using bitwise_type = uint8_t;
}; };
......
...@@ -6,37 +6,8 @@ ...@@ -6,37 +6,8 @@
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp" #include "ck/utility/mxfp_utils.hpp"
namespace ck {
__host__ inline int clz(uint32_t x) { return __builtin_clz(x); }
__device__ inline int clz(uint32_t x) { return __clz(x); }
} // namespace ck
namespace ck::utils { namespace ck::utils {
// template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
// __host__ __device__ Y cast_to_f8(X x, uint32_t rng)
// {
// // check datatypes
// constexpr bool is_half = std::is_same<X, half_t>::value;
// constexpr bool is_float = std::is_same<X, float>::value;
// static_assert(is_half || is_float, "Only half and float can be casted.");
// return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
// }
// template <typename X, typename Y, bool negative_zero_nan>
// __host__ __device__ Y cast_from_f8(X x)
// {
// // check datatype
// constexpr bool is_half = std::is_same<Y, half_t>::value;
// constexpr bool is_float = std::is_same<Y, float>::value;
// static_assert(is_half || is_float, "only half and float are supported.");
// return run_cast_from_f8<X, Y, negative_zero_nan>(x);
// }
template <> template <>
__host__ __device__ inline bool is_nan<f4_t>(e8m0_scale_t const scale, __host__ __device__ inline bool is_nan<f4_t>(e8m0_scale_t const scale,
f4_t const dataBytes [[maybe_unused]]) f4_t const dataBytes [[maybe_unused]])
...@@ -61,9 +32,9 @@ __host__ __device__ inline bool is_zero<f4_t>(e8m0_scale_t const scale, f4_t con ...@@ -61,9 +32,9 @@ __host__ __device__ inline bool is_zero<f4_t>(e8m0_scale_t const scale, f4_t con
return false; return false;
// no need to check for scale as it does not have a 0 representation // no need to check for scale as it does not have a 0 representation
f4_t data = (data & 0b00001111) & NumericUtils<e8m0_scale_t>::set_sign_mask; f4_t result = (data & 0b00001111) & NumericUtils<f4_t>::set_sign_mask;
return data == 0b0; return result == 0b0;
} }
template <> template <>
...@@ -75,11 +46,11 @@ __host__ __device__ inline float to_float<f4_t>(e8m0_scale_t const scale, f4_t c ...@@ -75,11 +46,11 @@ __host__ __device__ inline float to_float<f4_t>(e8m0_scale_t const scale, f4_t c
if(is_zero<f4_t>(scale, data)) if(is_zero<f4_t>(scale, data))
return 0.0f; return 0.0f;
uint8_t data = data & 0b00001111; f4_t prepared_data = data & 0b00001111;
int scale_exp = get_exponent_value<e8m0_scale_t>(scale); int scale_exp = get_exponent_value<e8m0_scale_t>(scale);
return convert_to_float<f4_t>(data, scale_exp); return convert_to_float<f4_t>(prepared_data, scale_exp);
} }
template <> template <>
...@@ -103,7 +74,7 @@ __host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value) ...@@ -103,7 +74,7 @@ __host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value)
f4_t res = convert_to_type<f4_t>(value); f4_t res = convert_to_type<f4_t>(value);
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), res)) < if(std::abs(to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), res)) <
NumericUtils<f4_t>::DataMinSubnorm()) NumericLimits<f4_t>::DataMinSubnorm())
return value < 0 ? NumericUtils<f4_t>::negative_zero_mask return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
: NumericUtils<f4_t>::positive_zero_mask; : NumericUtils<f4_t>::positive_zero_mask;
...@@ -128,7 +99,7 @@ __host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32 ...@@ -128,7 +99,7 @@ __host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32
f4_t res = convert_to_type_sr<f4_t>(value, seed); f4_t res = convert_to_type_sr<f4_t>(value, seed);
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), res)) < if(std::abs(to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), res)) <
NumericUtils<f4_t>::DataMinSubnorm()) NumericLimits<f4_t>::DataMinSubnorm())
return value < 0 ? NumericUtils<f4_t>::negative_zero_mask return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
: NumericUtils<f4_t>::positive_zero_mask; : NumericUtils<f4_t>::positive_zero_mask;
......
...@@ -58,11 +58,17 @@ __host__ __device__ inline double get_mantissa_value(T x) ...@@ -58,11 +58,17 @@ __host__ __device__ inline double get_mantissa_value(T x)
return mantissa; return mantissa;
} }
template <typename T>
__host__ __device__ inline bool get_data_has_inf()
{
return NumericUtils<T>::has_inf;
}
template <typename T> template <typename T>
__host__ __device__ float convert_to_float(T data, int scale_exp) __host__ __device__ float convert_to_float(T data, int scale_exp)
{ {
float d_sign = float d_sign =
std::pow(-1, static_cast<float>(data >> (NumericUtils<T>::exp + NumericUtils<t>::mant))); std::pow(-1, static_cast<float>(data >> (NumericUtils<T>::exp + NumericUtils<T>::mant)));
float d_exp; float d_exp;
if(is_subnormal<T>(data)) if(is_subnormal<T>(data))
...@@ -90,7 +96,7 @@ __host__ __device__ T sat_convert_to_type_sr(float value, uint32_t seed); ...@@ -90,7 +96,7 @@ __host__ __device__ T sat_convert_to_type_sr(float value, uint32_t seed);
template <typename T> template <typename T>
inline T convert_to_type(float value) inline T convert_to_type(float value)
{ {
using bitwise_type = NumericUtils<T>::bitwise_type; using bitwise_type = typename NumericUtils<T>::bitwise_type;
if(std::abs(value) > NumericLimits<T>::Max()) if(std::abs(value) > NumericLimits<T>::Max())
{ {
...@@ -197,9 +203,9 @@ inline T convert_to_type(float value) ...@@ -197,9 +203,9 @@ inline T convert_to_type(float value)
shift_amount = (shift_amount >= 64) ? 63 : shift_amount; shift_amount = (shift_amount >= 64) ? 63 : shift_amount;
bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1)); bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));
float min_subnorm = NumericLimits<T> DataMinSubnorm() * (sign ? -1 : 1); float min_subnorm = NumericLimits<T>::DataMinSubnorm() * (sign ? -1 : 1);
if(isSubNorm && std::abs(value) < std::abs(min_subnorm)) if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
{ {
// closer to 0 // closer to 0
if(std::abs(value) <= std::abs(min_subnorm - value)) if(std::abs(value) <= std::abs(min_subnorm - value))
...@@ -250,7 +256,7 @@ inline T convert_to_type(float value) ...@@ -250,7 +256,7 @@ inline T convert_to_type(float value)
template <typename T> template <typename T>
inline T convert_to_type_sr(float value, uint32_t seed) inline T convert_to_type_sr(float value, uint32_t seed)
{ {
using bitwise_type = NumericUtils<T>::bitwise_type; // using bitwise_type = typename NumericUtils<T>::bitwise_type;
if(std::abs(value) > NumericLimits<T>::Max()) if(std::abs(value) > NumericLimits<T>::Max())
{ {
...@@ -295,9 +301,8 @@ inline T convert_to_type_sr(float value, uint32_t seed) ...@@ -295,9 +301,8 @@ inline T convert_to_type_sr(float value, uint32_t seed)
if(!get_data_has_inf<T>() || d_seed <= thresh) if(!get_data_has_inf<T>() || d_seed <= thresh)
// return static_cast<T>(satConvertToType(getDataMax<DTYPE>())); //round down time // return static_cast<T>(satConvertToType(getDataMax<DTYPE>())); //round down time
return sign == 0 ? NumericUtils<f4_t> data_max_positive_normal_mask return sign == 0 ? NumericUtils<f4_t>::data_max_positive_normal_mask
: NumericUtils<f4_t> : NumericUtils<f4_t>::data_max_negative_normal_mask;
data_max_negative_normal_mask;
else else
{ {
exp++; exp++;
...@@ -328,12 +333,12 @@ inline T convert_to_type_sr(float value, uint32_t seed) ...@@ -328,12 +333,12 @@ inline T convert_to_type_sr(float value, uint32_t seed)
auto sign_bit = head >> (NumericUtils<float>::mant + NumericUtils<float>::exp); auto sign_bit = head >> (NumericUtils<float>::mant + NumericUtils<float>::exp);
auto sign = sign_bit << (NumericUtils<T>::exp + NumericUtils<T>::mant); auto sign = sign_bit << (NumericUtils<T>::exp + NumericUtils<T>::mant);
f32_exp = (int32_t)f32exp - NumericUtils<float>::bias; f32_exp = static_cast<int32_t>(f32_exp) - NumericUtils<float>::bias;
int32_t exp = f32_exp; int32_t exp = f32_exp;
auto mant = f32_mant; auto mant = f32_mant;
bool subnorm = false; bool subnorm = false;
if(value == 0) if(f32 == 0)
return 0b0; return 0b0;
if(exp >= NumericUtils<T>::unbiased_exp_min) if(exp >= NumericUtils<T>::unbiased_exp_min)
...@@ -345,7 +350,7 @@ inline T convert_to_type_sr(float value, uint32_t seed) ...@@ -345,7 +350,7 @@ inline T convert_to_type_sr(float value, uint32_t seed)
NumericUtils<T>::exp < NumericUtils<float>::exp) NumericUtils<T>::exp < NumericUtils<float>::exp)
{ {
subnorm = true; subnorm = true;
auto diff = (uint32_t)(NumericUtils<T>::unbiased_exp_min - exp); auto diff = static_cast<uint32_t>(NumericUtils<T>::unbiased_exp_min - exp);
if(diff >= 32) if(diff >= 32)
{ {
mant = 0; mant = 0;
...@@ -353,7 +358,7 @@ inline T convert_to_type_sr(float value, uint32_t seed) ...@@ -353,7 +358,7 @@ inline T convert_to_type_sr(float value, uint32_t seed)
} }
else else
{ {
f32_mant |= (uint32_t)1 << NumericUtils<float>::mant; f32_mant |= static_cast<uint32_t>(1) << NumericUtils<float>::mant;
f32_mant >>= diff; f32_mant >>= diff;
} }
exp = 0; exp = 0;
...@@ -367,14 +372,14 @@ inline T convert_to_type_sr(float value, uint32_t seed) ...@@ -367,14 +372,14 @@ inline T convert_to_type_sr(float value, uint32_t seed)
mant += seed >> sr_shift; mant += seed >> sr_shift;
// Increment exponent when mantissa overflows due to rounding // Increment exponent when mantissa overflows due to rounding
if(mant >= (uint32_t)1 << NumericUtils<float>::mant) if(mant >= static_cast<uint32_t>(1) << NumericUtils<float>::mant)
++exp; ++exp;
mant >>= (NumericUtils<float>::mant - NumericUtils<T>::mant); mant >>= (NumericUtils<float>::mant - NumericUtils<T>::mant);
mant &= ((1 << NumericUtils<T>::mant) - 1); mant &= ((1 << NumericUtils<T>::mant) - 1);
auto biased_exp = (uint32_t)exp; auto biased_exp = static_cast<uint32_t>(exp);
if(!subnorm) if(!subnorm)
biased_exp = (uint32_t)(exp + NumericUtils<T>::bias); biased_exp = static_cast<uint32_t>(exp + NumericUtils<T>::bias);
biased_exp &= ((1 << NumericUtils<T>::exp) - 1); biased_exp &= ((1 << NumericUtils<T>::exp) - 1);
auto val = sign | biased_exp << NumericUtils<T>::mant | mant; auto val = sign | biased_exp << NumericUtils<T>::mant | mant;
return val; return val;
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/utility/f8_utils.hpp" #include "ck/utility/f8_utils.hpp"
#include "ck/utility/mxf4_utils.hpp"
#include "ck/utility/random_gen.hpp" #include "ck/utility/random_gen.hpp"
#include "ck/utility/array.hpp" #include "ck/utility/array.hpp"
...@@ -501,6 +502,56 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x) ...@@ -501,6 +502,56 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
#endif #endif
} }
// convert fp32 to fp4 with rounding to nearest even
inline __host__ __device__ f4_t f4_convert_rne(float x)
{
#if defined(__gfx94__)
// union
// {
// float fval;
// uint32_t i32val;
// uint8_t i8val[4]; // not endian independent
// } val;
// val.fval = x;
// uint32_t ival = 0;
// const float max_fp8 = 240.0f;
// // if x is not +/- infinity or nan
// if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
// // clip float value
// val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8);
// ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false ->
// WORD0 val.i32val = ival; return val.i8val[0];
#else rng);
return utils::sat_convert_to_type<f4_t>(x);
#endif
}
// convert fp32 to fp4
template <>
inline __host__ __device__ f4_t type_convert<f4_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
return f4_convert_sr(x);
#else
return f4_convert_rne(x);
#endif
}
// convert fp4 to fp32
template <>
inline __host__ __device__ float type_convert<float, f4_t>(f4_t data)
{
#if defined(__gfx94__)
// float fval;
// uint32_t i32val = static_cast<uint32_t>(x);
// fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
// // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
// return fval;
#else
return utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), data);
#endif
}
template <typename Y, typename X, std::size_t NumElems> template <typename Y, typename X, std::size_t NumElems>
inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y, inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
const std::array<X, NumElems>& x) const std::array<X, NumElems>& x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment