Commit 2de38d06 authored by Rostyslav Geyyer
Browse files

Add type traits, decouple f8/bf8 casting

parent 8f510c03
...@@ -1075,4 +1075,61 @@ struct NumericLimits<bf8_t> ...@@ -1075,4 +1075,61 @@ struct NumericLimits<bf8_t>
}; };
#endif #endif
/// Bit-layout traits for a numeric type.
///
/// This is the primary template: any type without a dedicated specialization
/// reports zero exponent bits and zero mantissa bits.
template <class T>
struct NumericUtils
{
    /// Number of exponent bits in the encoding.
    static constexpr int exp{0};
    /// Number of mantissa bits in the encoding.
    static constexpr int mant{0};
};
/// Bit-layout traits for `float`: 1 sign bit, 8 exponent bits, 23 mantissa bits.
template <>
struct NumericUtils<float>
{
    /// Unsigned integer type used to hold the raw bit pattern.
    using bitwise_type = uint32_t;

    // ---- field widths ----
    static constexpr int exp{8};
    static constexpr int mant{23};

    // ---- masks ----
    /// All exponent bits set.
    static constexpr bitwise_type nan_mask{0x7F800000};
    /// Sign bit together with the exponent field.
    static constexpr bitwise_type head_mask{0xFF800000};
    /// Mantissa field.
    static constexpr bitwise_type mant_mask{0x7FFFFF};
    /// Exponent field once it has been shifted down to bit 0.
    static constexpr bitwise_type exp_mask{0xFF};

    // ---- special values as raw bit patterns ----
    static constexpr bitwise_type Inf{0x7F800000};
    static constexpr bitwise_type NegInf{0xFF800000};
    /// Exponent all ones with a non-zero mantissa.
    static constexpr bitwise_type NaN{0x7F800001};
    /// Sign bit only.
    static constexpr bitwise_type Neg0{0x80000000};
};
/// Bit-layout traits for `half_t`: 1 sign bit, 5 exponent bits, 10 mantissa bits.
///
/// Every constant is typed as `bitwise_type` (uint16_t) so that the special
/// value patterns have the same width as the masks and as the storage they
/// are reinterpreted into; callers no longer rely on an implicit
/// uint32_t -> uint16_t narrowing when copying them into a `bitwise_type`.
template <>
struct NumericUtils<half_t>
{
    /// Unsigned integer type used to hold the raw bit pattern.
    using bitwise_type = uint16_t;

    // field widths
    static constexpr int exp  = 5;
    static constexpr int mant = 10;

    // masks
    static constexpr bitwise_type nan_mask  = 0x7C00; // all exponent bits set
    static constexpr bitwise_type head_mask = 0xFC00; // sign bit + exponent field
    static constexpr bitwise_type mant_mask = 0x3FF;  // mantissa field
    static constexpr bitwise_type exp_mask  = 0x1F;   // exponent field after shifting to bit 0

    // special values as raw bit patterns
    static constexpr bitwise_type Inf    = 0x7C00;
    static constexpr bitwise_type NegInf = 0xFC00;
    static constexpr bitwise_type NaN    = 0x7C01; // exponent all ones, non-zero mantissa
    static constexpr bitwise_type Neg0   = 0x8000; // sign bit only
};
#ifdef CK_ENABLE_FP8
/// Bit-layout traits for `f8_t`: 4 exponent bits and 3 mantissa bits.
/// Only available when CK_ENABLE_FP8 is defined.
template <>
struct NumericUtils<f8_t>
{
    /// Number of exponent bits.
    static constexpr int exp{4};
    /// Number of mantissa bits.
    static constexpr int mant{3};
};
#endif // CK_ENABLE_FP8
#ifdef CK_ENABLE_BF8
/// Bit-layout traits for `bf8_t`: 5 exponent bits and 2 mantissa bits.
/// Only available when CK_ENABLE_BF8 is defined.
template <>
struct NumericUtils<bf8_t>
{
    /// Number of exponent bits.
    static constexpr int exp{5};
    /// Number of mantissa bits.
    static constexpr int mant{2};
};
#endif // CK_ENABLE_BF8
} // namespace ck } // namespace ck
...@@ -26,52 +26,35 @@ namespace { ...@@ -26,52 +26,35 @@ namespace {
template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch> template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
{ {
// check data type
constexpr bool is_half = std::is_same<X, half_t>::value;
constexpr bool is_float = std::is_same<X, float>::value;
constexpr bool is_f8_t = std::is_same<Y, f8_t>::value;
constexpr bool is_bf8_t = std::is_same<Y, bf8_t>::value;
// fp8/bf8 exponent/mantissa layout // fp8/bf8 exponent/mantissa layout
constexpr int f8_exp = is_f8_t ? 4 : 5; constexpr int out_exp = NumericUtils<Y>::exp;
constexpr int f8_mant = is_f8_t ? 3 : 2; constexpr int out_mant = NumericUtils<Y>::mant;
// resulting type exponent/mantissa layout // original type exponent/mantissa layout
constexpr int type_exp = is_half ? 5 : 8; constexpr int in_exp = NumericUtils<X>::exp;
constexpr int type_mant = is_half ? 10 : 23; constexpr int in_mant = NumericUtils<X>::mant;
int exponent; int exponent;
uint32_t head, mantissa, sign; uint32_t head, mantissa, sign;
// nan code is same for float and half // nan code is same for float and half
constexpr Y nan_code = 0x80; constexpr Y nan_code = 0x80;
constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000; constexpr uint32_t nan_mask = NumericUtils<X>::nan_mask;
// convert to bitwise // convert to bitwise
typedef typename std::conditional<std::is_same<X, half_t>::value, uint16_t, uint32_t>::type using T_bitwise = typename NumericUtils<X>::bitwise_type;
T_bitwise;
T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x)); T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
// unpack the input, depends on datatype // unpack the input, depends on datatype
if constexpr(is_float) head = x_bitwise & NumericUtils<X>::head_mask;
{ mantissa = x_bitwise & NumericUtils<X>::mant_mask;
head = x_bitwise & 0xFF800000; exponent = (head >> in_mant) & NumericUtils<X>::exp_mask;
mantissa = x_bitwise & 0x7FFFFF; sign = head >> (in_exp + in_mant);
exponent = (head >> type_mant) & 0xFF;
sign = head >> (type_exp + type_mant); uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
} uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1;
else if constexpr(is_half) constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
{
head = x_bitwise & 0xFC00;
mantissa = x_bitwise & 0x3FF;
exponent = (head >> type_mant) & 0x1F;
sign = head >> (type_exp + type_mant);
}
uint32_t signed_inf = (sign << (type_exp + type_mant)) + (((1 << type_exp) - 1) << type_mant);
uint32_t drop_mask = (1 << (type_mant - f8_mant)) - 1;
constexpr int max_exp = (1 << f8_exp) - (negative_zero_nan ? 1 : 2);
constexpr int exp_low_cutoff = constexpr int exp_low_cutoff =
(1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); (1 << (in_exp - 1)) - (1 << (out_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
if constexpr(negative_zero_nan) if constexpr(negative_zero_nan)
{ {
...@@ -84,15 +67,17 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) ...@@ -84,15 +67,17 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
return signed_inf + (mantissa != 0 ? 1 : 0); return signed_inf + (mantissa != 0 ? 1 : 0);
} }
if(is_half && is_bf8_t && negative_zero_nan && exponent == 0) // if input is half and output is bf8
if((NumericUtils<X>::mant == 10) && (NumericUtils<Y>::mant == 2) && negative_zero_nan &&
exponent == 0)
{ {
exponent += 1; exponent += 1;
while(mantissa < (1 << type_mant)) while(mantissa < (1 << in_mant))
{ {
mantissa <<= 1; mantissa <<= 1;
exponent -= 1; exponent -= 1;
} }
mantissa &= ~(1 << type_mant); mantissa &= ~(1 << in_mant);
} }
// check if x is 0.0 // check if x is 0.0
...@@ -101,16 +86,16 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) ...@@ -101,16 +86,16 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
exponent -= exp_low_cutoff - 1; exponent -= exp_low_cutoff - 1;
if(exponent <= 0) if(exponent <= 0)
drop_mask = (1 << (type_mant - f8_mant + 1 - exponent)) - 1; drop_mask = (1 << (in_mant - out_mant + 1 - exponent)) - 1;
mantissa += 1 << type_mant; mantissa += 1 << in_mant;
// apply random number if needed // apply random number if needed
mantissa += (stoch ? rng : mantissa) & drop_mask; mantissa += (stoch ? rng : mantissa) & drop_mask;
if(mantissa >= (2 << type_mant)) if(mantissa >= (2 << in_mant))
{ {
mantissa >>= 1; mantissa >>= 1;
exponent++; exponent++;
} }
mantissa >>= (type_mant - f8_mant); mantissa >>= (in_mant - out_mant);
// check negative exponent // check negative exponent
if(exponent <= 0) if(exponent <= 0)
...@@ -130,7 +115,7 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) ...@@ -130,7 +115,7 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
{ {
if(clip) if(clip)
{ {
mantissa = (1 << f8_mant) - 1; mantissa = (1 << out_mant) - 1;
exponent = max_exp; exponent = max_exp;
} }
else else
...@@ -141,81 +126,70 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) ...@@ -141,81 +126,70 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
// check if x is 0.0 or -0.0 // check if x is 0.0 or -0.0
if(exponent == 0 && mantissa == 0) if(exponent == 0 && mantissa == 0)
return negative_zero_nan ? 0 : (sign << (f8_exp + f8_mant)); return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
mantissa &= (1 << f8_mant) - 1; mantissa &= (1 << out_mant) - 1;
return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa; return (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
} }
template <typename X, typename Y, bool negative_zero_nan> template <typename X, typename Y, bool negative_zero_nan>
__host__ __device__ Y run_cast_from_f8(X x) __host__ __device__ Y run_cast_from_f8(X x)
{ {
// check data type // check data type
constexpr bool is_half = std::is_same<Y, half_t>::value; // constexpr bool is_half = std::is_same<Y, half_t>::value;
constexpr bool is_float = std::is_same<Y, float>::value; // constexpr bool is_float = std::is_same<Y, float>::value;
constexpr bool is_f8_t = std::is_same<X, f8_t>::value; // constexpr bool is_f8_t = std::is_same<X, f8_t>::value;
constexpr bool is_bf8_t = std::is_same<X, bf8_t>::value; // constexpr bool is_bf8_t = std::is_same<X, bf8_t>::value;
// fp8/bf8 exponent/mantissa layout // fp8/bf8 exponent/mantissa layout
constexpr int f8_exp = is_f8_t ? 4 : 5; constexpr int in_exp = NumericUtils<X>::exp;
constexpr int f8_mant = is_f8_t ? 3 : 2; constexpr int in_mant = NumericUtils<X>::mant;
// resulting type exponent/mantissa layout // resulting type exponent/mantissa layout
constexpr int type_exp = is_half ? 5 : 8; constexpr int out_exp = NumericUtils<Y>::exp;
constexpr int type_mant = is_half ? 10 : 23; constexpr int out_mant = NumericUtils<Y>::mant;
// prepare the codes // prepare the codes
constexpr X nan_code = 0x80; constexpr X nan_code = 0x80;
Y fInf, fNegInf, fNaN, fNeg0; Y Inf, NegInf, NaN, Neg0;
if constexpr(is_half) using T_bitwise = typename NumericUtils<Y>::bitwise_type;
{
constexpr uint16_t ihInf = 0x7C00; constexpr T_bitwise Inf_bitwise = NumericUtils<Y>::Inf;
constexpr uint16_t ihNegInf = 0xFC00; constexpr T_bitwise NegInf_bitwise = NumericUtils<Y>::NegInf;
constexpr uint16_t ihNaN = 0x7C01; constexpr T_bitwise NaN_bitwise = NumericUtils<Y>::NaN;
constexpr uint16_t ihNeg0 = 0x8000; constexpr T_bitwise Neg0_bitwise = NumericUtils<Y>::Neg0;
fInf = *(reinterpret_cast<const half_t*>(&ihInf));
fNegInf = *(reinterpret_cast<const half_t*>(&ihNegInf)); Inf = *(reinterpret_cast<const Y*>(&Inf_bitwise));
fNaN = *(reinterpret_cast<const half_t*>(&ihNaN)); NegInf = *(reinterpret_cast<const Y*>(&NegInf_bitwise));
fNeg0 = *(reinterpret_cast<const half_t*>(&ihNeg0)); NaN = *(reinterpret_cast<const Y*>(&NaN_bitwise));
} Neg0 = *(reinterpret_cast<const Y*>(&Neg0_bitwise));
else if constexpr(is_float)
{
constexpr uint32_t ifInf = 0x7F800000;
constexpr uint32_t ifNegInf = 0xFF800000;
constexpr uint32_t ifNaN = 0x7F800001;
constexpr uint32_t ifNeg0 = 0x80000000;
fInf = *(reinterpret_cast<const float*>(&ifInf));
fNegInf = *(reinterpret_cast<const float*>(&ifNegInf));
fNaN = *(reinterpret_cast<const float*>(&ifNaN));
fNeg0 = *(reinterpret_cast<const float*>(&ifNeg0));
}
// check if x is 0.0 // check if x is 0.0
if(x == 0) if(x == 0)
return static_cast<Y>(0); return static_cast<Y>(0);
// unpack the input // unpack the input
uint32_t sign = x >> (f8_exp + f8_mant); uint32_t sign = x >> (in_exp + in_mant);
uint32_t mantissa = x & ((1 << f8_mant) - 1); uint32_t mantissa = x & ((1 << in_mant) - 1);
int exponent = (x & 0x7F) >> f8_mant; int exponent = (x & 0x7F) >> in_mant;
constexpr int exp_low_cutoff = constexpr int exp_low_cutoff =
(1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
typename std::conditional<std::is_same<Y, half_t>::value, uint16_t, uint32_t>::type retval; T_bitwise retval;
if constexpr(negative_zero_nan) if constexpr(negative_zero_nan)
{ {
if(x == nan_code) if(x == nan_code)
return fNaN; return NaN;
} }
else else
{ {
if(x == nan_code) if(x == nan_code)
return fNeg0; return Neg0;
if(exponent == ((1 << f8_exp) - 1)) if(exponent == ((1 << in_exp) - 1))
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
} }
if(is_bf8_t && is_half && !negative_zero_nan) if((NumericUtils<Y>::mant == 10) && (NumericUtils<X>::mant == 2) && !negative_zero_nan)
{ {
retval = x; retval = x;
retval <<= 8; retval <<= 8;
...@@ -227,25 +201,25 @@ __host__ __device__ Y run_cast_from_f8(X x) ...@@ -227,25 +201,25 @@ __host__ __device__ Y run_cast_from_f8(X x)
{ {
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
exponent++; exponent++;
while(mantissa < (1 << f8_mant)) while(mantissa < (1 << in_mant))
{ {
mantissa <<= 1; mantissa <<= 1;
exponent--; exponent--;
} }
mantissa &= ((1 << f8_mant) - 1); mantissa &= ((1 << in_mant) - 1);
} }
exponent += exp_low_cutoff - 1; exponent += exp_low_cutoff - 1;
mantissa <<= type_mant - f8_mant; mantissa <<= out_mant - in_mant;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true) // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if(exponent <= 0) if(exponent <= 0)
{ {
mantissa |= 1 << type_mant; mantissa |= 1 << out_mant;
mantissa >>= 1 - exponent; mantissa >>= 1 - exponent;
exponent = 0; exponent = 0;
} }
retval = (sign << (type_exp + type_mant)) | (exponent << type_mant) | mantissa; retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
return *(reinterpret_cast<const Y*>(&retval)); return *(reinterpret_cast<const Y*>(&retval));
} }
...@@ -258,9 +232,9 @@ __host__ __device__ Y cast_to_f8(X x, uint32_t rng) ...@@ -258,9 +232,9 @@ __host__ __device__ Y cast_to_f8(X x, uint32_t rng)
constexpr bool is_half = std::is_same<X, half_t>::value; constexpr bool is_half = std::is_same<X, half_t>::value;
constexpr bool is_float = std::is_same<X, float>::value; constexpr bool is_float = std::is_same<X, float>::value;
static_assert(is_half || is_float, "Only half and float can be casted."); static_assert(is_half || is_float, "Only half and float can be casted.");
constexpr bool is_f8 = std::is_same<Y, f8_t>::value; // constexpr bool is_f8 = std::is_same<Y, f8_t>::value;
constexpr bool is_bf8 = std::is_same<Y, bf8_t>::value; // constexpr bool is_bf8 = std::is_same<Y, bf8_t>::value;
static_assert(is_f8 || is_bf8, "Casting to f8 and bf8 only is supported."); // static_assert(is_f8 || is_bf8, "Casting to f8 and bf8 only is supported.");
return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng); return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
} }
...@@ -272,13 +246,9 @@ __host__ __device__ Y cast_from_f8(X x) ...@@ -272,13 +246,9 @@ __host__ __device__ Y cast_from_f8(X x)
constexpr bool is_half = std::is_same<Y, half_t>::value; constexpr bool is_half = std::is_same<Y, half_t>::value;
constexpr bool is_float = std::is_same<Y, float>::value; constexpr bool is_float = std::is_same<Y, float>::value;
static_assert(is_half || is_float, "only half and float are supported."); static_assert(is_half || is_float, "only half and float are supported.");
constexpr bool is_f8 = std::is_same<X, f8_t>::value; // constexpr bool is_f8 = std::is_same<X, f8_t>::value;
constexpr bool is_bf8 = std::is_same<X, bf8_t>::value; // constexpr bool is_bf8 = std::is_same<X, bf8_t>::value;
static_assert(is_f8 || is_bf8, "Casting to f8 and bf8 only is supported."); // static_assert(is_f8 || is_bf8, "Casting to f8 and bf8 only is supported.");
// check if x is 0.0
if(x == 0)
return static_cast<Y>(0);
return run_cast_from_f8<X, Y, negative_zero_nan>(x); return run_cast_from_f8<X, Y, negative_zero_nan>(x);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment