Commit 532bbe53 authored by Rostyslav Geyyer
Browse files

Add fp16 casting functions

parent c1ba7c63
......@@ -8,6 +8,7 @@
namespace ck {
using f8_t = uint8_t;
using half_t = _Float16;
// fp8 rounding modes
enum class f8_rounding_mode
......@@ -16,66 +17,81 @@ enum class f8_rounding_mode
stochastic
};
// cast fp32 or fp16 to fp8
template <bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t cast_to_f8(float x, uint32_t rng)
template <typename T, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
{
// fp8 exponent/mantissa layout
constexpr int we_f8 = 4;
constexpr int wm_f8 = 3;
// check data type
constexpr bool is_half = std::is_same<T, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value;
// fp32 exponent/mantissa layout
constexpr int we_f32 = 8;
constexpr int wm_f32 = 23;
// fp8 exponent/mantissa layout
constexpr int f8_exp = 4;
constexpr int f8_mant = 3;
uint32_t x_bitwise;
x_bitwise = *(reinterpret_cast<uint32_t*>(&x));
// input type (T) exponent/mantissa layout
constexpr int type_exp = is_half ? 5 : 8;
constexpr int type_mant = is_half ? 10 : 23;
// unpack the input
uint32_t head, mantissa;
int exponent;
uint32_t sign;
uint32_t head, mantissa, sign;
// nan code is same for float and half
constexpr uint8_t nan_code = 0x80;
constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000;
head = x_bitwise & 0xFF800000;
mantissa = x_bitwise & 0x7FFFFF;
exponent = (head >> wm_f32) & 0xFF;
sign = head >> (we_f32 + wm_f32);
// convert to bitwise
typedef typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type T_bitwise;
T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
uint32_t signed_inf = (sign << (we_f8 + wm_f8)) + (((1 << we_f8) - 1) << wm_f8);
uint32_t drop_mask = (1 << (wm_f32 - wm_f8)) - 1;
int max_exp;
int exp_low_cutoff;
// unpack the input, depends on datatype
if constexpr(is_float)
{
head = x_bitwise & 0xFF800000;
mantissa = x_bitwise & 0x7FFFFF;
exponent = (head >> type_mant) & 0xFF;
sign = head >> (type_exp + type_mant);
}
else if constexpr(is_half)
{
head = x_bitwise & 0xFC00;
mantissa = x_bitwise & 0x3FF;
exponent = (head >> type_mant) & 0x1F;
sign = head >> (type_exp + type_mant);
}
uint32_t signed_inf = (sign << (type_exp + type_mant)) + (((1 << type_exp) - 1) << type_mant);
uint32_t drop_mask = (1 << (type_mant - f8_mant)) - 1;
constexpr int max_exp = (1 << f8_exp) - (negative_zero_nan ? 1 : 2);
constexpr int exp_low_cutoff = (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
if constexpr(negative_zero_nan)
{
if((x_bitwise & 0x7F800000) == 0x7F800000)
return 0x80;
max_exp = (1 << we_f8) - 1;
exp_low_cutoff = 0x80 - (1 << (we_f8 - 1));
if((x_bitwise & nan_mask) == nan_mask)
return nan_code;
}
else
{
if((x_bitwise & 0x7F800000) == 0x7F800000)
if((x_bitwise & nan_mask) == nan_mask)
return signed_inf + (mantissa != 0 ? 1 : 0);
max_exp = (1 << we_f8) - 2;
exp_low_cutoff = 0x80 - (1 << (we_f8 - 1)) + 1;
}
// check if x is 0.0
if(x_bitwise == 0)
return 0;
exponent -= exp_low_cutoff - 1;
if(exponent <= 0)
drop_mask = (1 << (wm_f32 - wm_f8 + 1 - exponent)) - 1;
mantissa += 1 << wm_f32;
drop_mask = (1 << (type_mant - f8_mant + 1 - exponent)) - 1;
mantissa += 1 << type_mant;
// apply random number if needed
mantissa += (stoch ? rng : mantissa) & drop_mask;
if(mantissa >= (2 << wm_f32))
if(mantissa >= (2 << type_mant))
{
mantissa >>= 1;
exponent++;
}
mantissa >>= (wm_f32 - wm_f8);
mantissa >>= (type_mant - f8_mant);
// check negative exponent
if(exponent <= 0)
{
if(x_bitwise == 0)
......@@ -93,7 +109,7 @@ __host__ __device__ f8_t cast_to_f8(float x, uint32_t rng)
{
if(clip)
{
mantissa = (1 << wm_f8) - 1;
mantissa = (1 << f8_mant) - 1;
exponent = max_exp;
}
else
......@@ -101,65 +117,92 @@ __host__ __device__ f8_t cast_to_f8(float x, uint32_t rng)
return signed_inf;
}
}
// check if x is 0.0 or -0.0
if(exponent == 0 && mantissa == 0)
return negative_zero_nan ? 0 : (sign << (we_f8 + wm_f8));
mantissa &= (1 << wm_f8) - 1;
return (sign << (we_f8 + wm_f8)) | (exponent << wm_f8) | mantissa;
return negative_zero_nan ? 0 : (sign << (f8_exp + f8_mant));
mantissa &= (1 << f8_mant) - 1;
return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa;
}
// Public entry point for converting a value of type T to fp8.
// T must be half_t or float; any other type is rejected at compile time.
// All template flags, the value x and rng are forwarded unchanged to
// run_cast_to_f8, whose result is returned as-is.
template <typename T, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng)
{
    // compile-time guard on the source type
    static_assert(std::is_same<T, half_t>::value || std::is_same<T, float>::value,
                  "Only half and float can be casted to f8.");

    return run_cast_to_f8<T, negative_zero_nan, clip, stoch>(x, rng);
}
// cast fp8 to fp32 or fp16
template <bool negative_zero_nan>
__host__ __device__ float cast_from_f8(f8_t x)
template <typename T, bool negative_zero_nan>
__host__ __device__ T run_cast_from_f8(f8_t x)
{
// check data type
constexpr bool is_half = std::is_same<T, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value;
// fp8 exponent/mantissa layout
constexpr int we_f8 = 4;
constexpr int wm_f8 = 3;
// fp32 exponent/mantissa layout
constexpr int we_f32 = 8;
constexpr int wm_f32 = 23;
float fInf, fNegInf, fNaN, fNeg0;
const uint32_t ifInf = 0x7F800000;
const uint32_t ifNegInf = 0xFF800000;
const uint32_t ifNaN = 0x7F800001;
const uint32_t ifNeg0 = 0x80000000;
fInf = *(reinterpret_cast<const float*>(&ifInf));
fNegInf = *(reinterpret_cast<const float*>(&ifNegInf));
fNaN = *(reinterpret_cast<const float*>(&ifNaN));
fNeg0 = *(reinterpret_cast<const float*>(&ifNeg0));
constexpr int f8_exp = 4;
constexpr int f8_mant = 3;
if(x == 0)
return static_cast<float>(0);
// resulting type exponent/mantissa layout
constexpr int type_exp = is_half ? 5 : 8;
constexpr int type_mant = is_half ? 10 : 23;
// prepare the codes
constexpr uint8_t nan_code = 0x80;
T fInf, fNegInf, fNaN, fNeg0;
if constexpr(is_half)
{
constexpr uint16_t ihInf = 0x7C00;
constexpr uint16_t ihNegInf = 0xFC00;
constexpr uint16_t ihNaN = 0x7C01;
constexpr uint16_t ihNeg0 = 0x8000;
fInf = *(reinterpret_cast<const half_t*>(&ihInf));
fNegInf = *(reinterpret_cast<const half_t*>(&ihNegInf));
fNaN = *(reinterpret_cast<const half_t*>(&ihNaN));
fNeg0 = *(reinterpret_cast<const half_t*>(&ihNeg0));
}
else if constexpr(is_float)
{
constexpr uint32_t ifInf = 0x7F800000;
constexpr uint32_t ifNegInf = 0xFF800000;
constexpr uint32_t ifNaN = 0x7F800001;
constexpr uint32_t ifNeg0 = 0x80000000;
fInf = *(reinterpret_cast<const float*>(&ifInf));
fNegInf = *(reinterpret_cast<const float*>(&ifNegInf));
fNaN = *(reinterpret_cast<const float*>(&ifNaN));
fNeg0 = *(reinterpret_cast<const float*>(&ifNeg0));
}
// unpack the input
uint32_t sign = x >> (we_f8 + wm_f8);
uint32_t mantissa = x & ((1 << wm_f8) - 1);
int exponent = (x & 0x7F) >> wm_f8;
uint32_t sign = x >> (f8_exp + f8_mant);
uint32_t mantissa = x & ((1 << f8_mant) - 1);
int exponent = (x & 0x7F) >> f8_mant;
int exp_low_cutoff;
uint32_t retval;
constexpr int exp_low_cutoff = (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type retval;
if constexpr(negative_zero_nan)
{
if(x == 0x80)
if(x == nan_code)
return fNaN;
exp_low_cutoff = (1 << (we_f32 - 1)) - (1 << (we_f8 - 1));
}
else
{
if(x == 0x80)
if(x == nan_code)
return fNeg0;
if(exponent == ((1 << we_f8) - 1))
if(exponent == ((1 << f8_exp) - 1))
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
exp_low_cutoff = (1 << (we_f32 - 1)) - (1 << (we_f8 - 1)) + 1;
}
// subnormal input
if(exponent == 0)
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + __builtin_clz(mantissa) - ((1 + we_f32 + wm_f32) - wm_f8);
int sh = 1 + __builtin_clz(mantissa) - ((1 + type_exp + type_mant) - f8_mant);
mantissa <<= sh;
exponent += 1 - sh;
/*
......@@ -169,21 +212,36 @@ __host__ __device__ float cast_from_f8(f8_t x)
exponent--;
}
*/
mantissa &= ((1 << wm_f8) - 1);
mantissa &= ((1 << f8_mant) - 1);
}
exponent += exp_low_cutoff - 1;
mantissa <<= wm_f32 - wm_f8;
mantissa <<= type_mant - f8_mant;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if(exponent <= 0)
{
mantissa |= 1 << wm_f32;
mantissa |= 1 << type_mant;
mantissa >>= 1 - exponent;
exponent = 0;
}
retval = (sign << (we_f32 + wm_f32)) | (exponent << wm_f32) | mantissa;
return *(reinterpret_cast<const float*>(&retval));
retval = (sign << (type_exp + type_mant)) | (exponent << type_mant) | mantissa;
return *(reinterpret_cast<const T*>(&retval));
}
// Public entry point for converting an fp8 value to type T.
// T must be half_t or float; any other type is rejected at compile time.
// The all-zero bit pattern is answered here with T(0); every other input
// is forwarded to run_cast_from_f8.
template <typename T, bool negative_zero_nan>
__host__ __device__ T cast_from_f8(f8_t x)
{
    // compile-time guard on the destination type
    static_assert(std::is_same<T, half_t>::value || std::is_same<T, float>::value,
                  "only half and float are supported.");

    // x == 0 encodes 0.0 and needs no unpacking
    return (x == 0) ? static_cast<T>(0) : run_cast_from_f8<T, negative_zero_nan>(x);
}
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment