"git@developer.sourcefind.cn:dadigang/Ventoy.git" did not exist on "2aae096c2aac382bb9e7851fed14d207bf016f1e"
Commit 2de38d06 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Add type traits, decouple f8/bf8 casting

parent 8f510c03
......@@ -1075,4 +1075,61 @@ struct NumericLimits<bf8_t>
};
#endif
template <typename T>
struct NumericUtils
{
static constexpr int exp = 0;
static constexpr int mant = 0;
};
template <>
struct NumericUtils<float>
{
static constexpr int exp = 8;
static constexpr int mant = 23;
static constexpr uint32_t nan_mask = 0x7F800000;
static constexpr uint32_t head_mask = 0xFF800000;
static constexpr uint32_t mant_mask = 0x7FFFFF;
static constexpr uint32_t exp_mask = 0xFF;
static constexpr uint32_t Inf = 0x7F800000;
static constexpr uint32_t NegInf = 0xFF800000;
static constexpr uint32_t NaN = 0x7F800001;
static constexpr uint32_t Neg0 = 0x80000000;
using bitwise_type = uint32_t;
};
template <>
struct NumericUtils<half_t>
{
static constexpr int exp = 5;
static constexpr int mant = 10;
static constexpr uint16_t nan_mask = 0x7C00;
static constexpr uint16_t head_mask = 0xFC00;
static constexpr uint16_t mant_mask = 0x3FF;
static constexpr uint16_t exp_mask = 0x1F;
static constexpr uint32_t Inf = 0x7C00;
static constexpr uint32_t NegInf = 0xFC00;
static constexpr uint32_t NaN = 0x7C01;
static constexpr uint32_t Neg0 = 0x8000;
using bitwise_type = uint16_t;
};
#if defined CK_ENABLE_FP8
template <>
struct NumericUtils<f8_t>
{
static constexpr int exp = 4;
static constexpr int mant = 3;
};
#endif
#if defined CK_ENABLE_BF8
template <>
struct NumericUtils<bf8_t>
{
static constexpr int exp = 5;
static constexpr int mant = 2;
};
#endif
} // namespace ck
......@@ -26,52 +26,35 @@ namespace {
template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
{
// check data type
constexpr bool is_half = std::is_same<X, half_t>::value;
constexpr bool is_float = std::is_same<X, float>::value;
constexpr bool is_f8_t = std::is_same<Y, f8_t>::value;
constexpr bool is_bf8_t = std::is_same<Y, bf8_t>::value;
// fp8/bf8 exponent/mantissa layout
constexpr int f8_exp = is_f8_t ? 4 : 5;
constexpr int f8_mant = is_f8_t ? 3 : 2;
constexpr int out_exp = NumericUtils<Y>::exp;
constexpr int out_mant = NumericUtils<Y>::mant;
// resulting type exponent/mantissa layout
constexpr int type_exp = is_half ? 5 : 8;
constexpr int type_mant = is_half ? 10 : 23;
// original type exponent/mantissa layout
constexpr int in_exp = NumericUtils<X>::exp;
constexpr int in_mant = NumericUtils<X>::mant;
int exponent;
uint32_t head, mantissa, sign;
// nan code is same for float and half
constexpr Y nan_code = 0x80;
constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000;
constexpr uint32_t nan_mask = NumericUtils<X>::nan_mask;
// convert to bitwise
typedef typename std::conditional<std::is_same<X, half_t>::value, uint16_t, uint32_t>::type
T_bitwise;
using T_bitwise = typename NumericUtils<X>::bitwise_type;
T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
// unpack the input, depends on datatype
if constexpr(is_float)
{
head = x_bitwise & 0xFF800000;
mantissa = x_bitwise & 0x7FFFFF;
exponent = (head >> type_mant) & 0xFF;
sign = head >> (type_exp + type_mant);
}
else if constexpr(is_half)
{
head = x_bitwise & 0xFC00;
mantissa = x_bitwise & 0x3FF;
exponent = (head >> type_mant) & 0x1F;
sign = head >> (type_exp + type_mant);
}
uint32_t signed_inf = (sign << (type_exp + type_mant)) + (((1 << type_exp) - 1) << type_mant);
uint32_t drop_mask = (1 << (type_mant - f8_mant)) - 1;
constexpr int max_exp = (1 << f8_exp) - (negative_zero_nan ? 1 : 2);
head = x_bitwise & NumericUtils<X>::head_mask;
mantissa = x_bitwise & NumericUtils<X>::mant_mask;
exponent = (head >> in_mant) & NumericUtils<X>::exp_mask;
sign = head >> (in_exp + in_mant);
uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1;
constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
constexpr int exp_low_cutoff =
(1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
(1 << (in_exp - 1)) - (1 << (out_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
if constexpr(negative_zero_nan)
{
......@@ -84,15 +67,17 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
return signed_inf + (mantissa != 0 ? 1 : 0);
}
if(is_half && is_bf8_t && negative_zero_nan && exponent == 0)
// if input is half and output is bf8
if((NumericUtils<X>::mant == 10) && (NumericUtils<Y>::mant == 2) && negative_zero_nan &&
exponent == 0)
{
exponent += 1;
while(mantissa < (1 << type_mant))
while(mantissa < (1 << in_mant))
{
mantissa <<= 1;
exponent -= 1;
}
mantissa &= ~(1 << type_mant);
mantissa &= ~(1 << in_mant);
}
// check if x is 0.0
......@@ -101,16 +86,16 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
exponent -= exp_low_cutoff - 1;
if(exponent <= 0)
drop_mask = (1 << (type_mant - f8_mant + 1 - exponent)) - 1;
mantissa += 1 << type_mant;
drop_mask = (1 << (in_mant - out_mant + 1 - exponent)) - 1;
mantissa += 1 << in_mant;
// apply random number if needed
mantissa += (stoch ? rng : mantissa) & drop_mask;
if(mantissa >= (2 << type_mant))
if(mantissa >= (2 << in_mant))
{
mantissa >>= 1;
exponent++;
}
mantissa >>= (type_mant - f8_mant);
mantissa >>= (in_mant - out_mant);
// check negative exponent
if(exponent <= 0)
......@@ -130,7 +115,7 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
{
if(clip)
{
mantissa = (1 << f8_mant) - 1;
mantissa = (1 << out_mant) - 1;
exponent = max_exp;
}
else
......@@ -141,81 +126,70 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
// check if x is 0.0 or -0.0
if(exponent == 0 && mantissa == 0)
return negative_zero_nan ? 0 : (sign << (f8_exp + f8_mant));
mantissa &= (1 << f8_mant) - 1;
return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa;
return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
mantissa &= (1 << out_mant) - 1;
return (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
}
template <typename X, typename Y, bool negative_zero_nan>
__host__ __device__ Y run_cast_from_f8(X x)
{
// check data type
constexpr bool is_half = std::is_same<Y, half_t>::value;
constexpr bool is_float = std::is_same<Y, float>::value;
constexpr bool is_f8_t = std::is_same<X, f8_t>::value;
constexpr bool is_bf8_t = std::is_same<X, bf8_t>::value;
// constexpr bool is_half = std::is_same<Y, half_t>::value;
// constexpr bool is_float = std::is_same<Y, float>::value;
// constexpr bool is_f8_t = std::is_same<X, f8_t>::value;
// constexpr bool is_bf8_t = std::is_same<X, bf8_t>::value;
// fp8/bf8 exponent/mantissa layout
constexpr int f8_exp = is_f8_t ? 4 : 5;
constexpr int f8_mant = is_f8_t ? 3 : 2;
constexpr int in_exp = NumericUtils<X>::exp;
constexpr int in_mant = NumericUtils<X>::mant;
// resulting type exponent/mantissa layout
constexpr int type_exp = is_half ? 5 : 8;
constexpr int type_mant = is_half ? 10 : 23;
constexpr int out_exp = NumericUtils<Y>::exp;
constexpr int out_mant = NumericUtils<Y>::mant;
// prepare the codes
constexpr X nan_code = 0x80;
Y fInf, fNegInf, fNaN, fNeg0;
if constexpr(is_half)
{
constexpr uint16_t ihInf = 0x7C00;
constexpr uint16_t ihNegInf = 0xFC00;
constexpr uint16_t ihNaN = 0x7C01;
constexpr uint16_t ihNeg0 = 0x8000;
fInf = *(reinterpret_cast<const half_t*>(&ihInf));
fNegInf = *(reinterpret_cast<const half_t*>(&ihNegInf));
fNaN = *(reinterpret_cast<const half_t*>(&ihNaN));
fNeg0 = *(reinterpret_cast<const half_t*>(&ihNeg0));
}
else if constexpr(is_float)
{
constexpr uint32_t ifInf = 0x7F800000;
constexpr uint32_t ifNegInf = 0xFF800000;
constexpr uint32_t ifNaN = 0x7F800001;
constexpr uint32_t ifNeg0 = 0x80000000;
fInf = *(reinterpret_cast<const float*>(&ifInf));
fNegInf = *(reinterpret_cast<const float*>(&ifNegInf));
fNaN = *(reinterpret_cast<const float*>(&ifNaN));
fNeg0 = *(reinterpret_cast<const float*>(&ifNeg0));
}
Y Inf, NegInf, NaN, Neg0;
using T_bitwise = typename NumericUtils<Y>::bitwise_type;
constexpr T_bitwise Inf_bitwise = NumericUtils<Y>::Inf;
constexpr T_bitwise NegInf_bitwise = NumericUtils<Y>::NegInf;
constexpr T_bitwise NaN_bitwise = NumericUtils<Y>::NaN;
constexpr T_bitwise Neg0_bitwise = NumericUtils<Y>::Neg0;
Inf = *(reinterpret_cast<const Y*>(&Inf_bitwise));
NegInf = *(reinterpret_cast<const Y*>(&NegInf_bitwise));
NaN = *(reinterpret_cast<const Y*>(&NaN_bitwise));
Neg0 = *(reinterpret_cast<const Y*>(&Neg0_bitwise));
// check if x is 0.0
if(x == 0)
return static_cast<Y>(0);
// unpack the input
uint32_t sign = x >> (f8_exp + f8_mant);
uint32_t mantissa = x & ((1 << f8_mant) - 1);
int exponent = (x & 0x7F) >> f8_mant;
uint32_t sign = x >> (in_exp + in_mant);
uint32_t mantissa = x & ((1 << in_mant) - 1);
int exponent = (x & 0x7F) >> in_mant;
constexpr int exp_low_cutoff =
(1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
typename std::conditional<std::is_same<Y, half_t>::value, uint16_t, uint32_t>::type retval;
(1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
T_bitwise retval;
if constexpr(negative_zero_nan)
{
if(x == nan_code)
return fNaN;
return NaN;
}
else
{
if(x == nan_code)
return fNeg0;
if(exponent == ((1 << f8_exp) - 1))
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
return Neg0;
if(exponent == ((1 << in_exp) - 1))
return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
}
if(is_bf8_t && is_half && !negative_zero_nan)
if((NumericUtils<Y>::mant == 10) && (NumericUtils<X>::mant == 2) && !negative_zero_nan)
{
retval = x;
retval <<= 8;
......@@ -227,25 +201,25 @@ __host__ __device__ Y run_cast_from_f8(X x)
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
exponent++;
while(mantissa < (1 << f8_mant))
while(mantissa < (1 << in_mant))
{
mantissa <<= 1;
exponent--;
}
mantissa &= ((1 << f8_mant) - 1);
mantissa &= ((1 << in_mant) - 1);
}
exponent += exp_low_cutoff - 1;
mantissa <<= type_mant - f8_mant;
mantissa <<= out_mant - in_mant;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if(exponent <= 0)
{
mantissa |= 1 << type_mant;
mantissa |= 1 << out_mant;
mantissa >>= 1 - exponent;
exponent = 0;
}
retval = (sign << (type_exp + type_mant)) | (exponent << type_mant) | mantissa;
retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
return *(reinterpret_cast<const Y*>(&retval));
}
......@@ -258,9 +232,9 @@ __host__ __device__ Y cast_to_f8(X x, uint32_t rng)
constexpr bool is_half = std::is_same<X, half_t>::value;
constexpr bool is_float = std::is_same<X, float>::value;
static_assert(is_half || is_float, "Only half and float can be casted.");
constexpr bool is_f8 = std::is_same<Y, f8_t>::value;
constexpr bool is_bf8 = std::is_same<Y, bf8_t>::value;
static_assert(is_f8 || is_bf8, "Casting to f8 and bf8 only is supported.");
// constexpr bool is_f8 = std::is_same<Y, f8_t>::value;
// constexpr bool is_bf8 = std::is_same<Y, bf8_t>::value;
// static_assert(is_f8 || is_bf8, "Casting to f8 and bf8 only is supported.");
return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
}
......@@ -272,13 +246,9 @@ __host__ __device__ Y cast_from_f8(X x)
constexpr bool is_half = std::is_same<Y, half_t>::value;
constexpr bool is_float = std::is_same<Y, float>::value;
static_assert(is_half || is_float, "only half and float are supported.");
constexpr bool is_f8 = std::is_same<X, f8_t>::value;
constexpr bool is_bf8 = std::is_same<X, bf8_t>::value;
static_assert(is_f8 || is_bf8, "Casting to f8 and bf8 only is supported.");
// check if x is 0.0
if(x == 0)
return static_cast<Y>(0);
// constexpr bool is_f8 = std::is_same<X, f8_t>::value;
// constexpr bool is_bf8 = std::is_same<X, bf8_t>::value;
// static_assert(is_f8 || is_bf8, "Casting to f8 and bf8 only is supported.");
return run_cast_from_f8<X, Y, negative_zero_nan>(x);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment