Commit 532bbe53 authored by Rostyslav Geyyer
Browse files

Add fp16 casting functions

parent c1ba7c63
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
namespace ck { namespace ck {
using f8_t = uint8_t; using f8_t = uint8_t;
using half_t = _Float16;
// fp8 rounding modes // fp8 rounding modes
enum class f8_rounding_mode enum class f8_rounding_mode
...@@ -16,66 +17,81 @@ enum class f8_rounding_mode ...@@ -16,66 +17,81 @@ enum class f8_rounding_mode
stochastic stochastic
}; };
// cast fp32 to fp8 template <typename T, bool negative_zero_nan, bool clip, bool stoch>
template <bool negative_zero_nan, bool clip, bool stoch> __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
__host__ __device__ f8_t cast_to_f8(float x, uint32_t rng)
{ {
// fp8 exponent/mantissa layout // check data type
constexpr int we_f8 = 4; constexpr bool is_half = std::is_same<T, half_t>::value;
constexpr int wm_f8 = 3; constexpr bool is_float = std::is_same<T, float>::value;
// fp32 exponent/mantissa layout // fp8 exponent/mantissa layout
constexpr int we_f32 = 8; constexpr int f8_exp = 4;
constexpr int wm_f32 = 23; constexpr int f8_mant = 3;
uint32_t x_bitwise; // resulting type exponent/mantissa layout
x_bitwise = *(reinterpret_cast<uint32_t*>(&x)); constexpr int type_exp = is_half ? 5 : 8;
constexpr int type_mant = is_half ? 10 : 23;
// unpack the input
uint32_t head, mantissa;
int exponent; int exponent;
uint32_t sign; uint32_t head, mantissa, sign;
// nan code is same for float and half
constexpr uint8_t nan_code = 0x80;
constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000;
head = x_bitwise & 0xFF800000; // convert to bitwise
mantissa = x_bitwise & 0x7FFFFF; typedef typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type T_bitwise;
exponent = (head >> wm_f32) & 0xFF; T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
sign = head >> (we_f32 + wm_f32);
uint32_t signed_inf = (sign << (we_f8 + wm_f8)) + (((1 << we_f8) - 1) << wm_f8); // unpack the input, depends on datatype
uint32_t drop_mask = (1 << (wm_f32 - wm_f8)) - 1; if constexpr(is_float)
int max_exp; {
int exp_low_cutoff; head = x_bitwise & 0xFF800000;
mantissa = x_bitwise & 0x7FFFFF;
exponent = (head >> type_mant) & 0xFF;
sign = head >> (type_exp + type_mant);
}
else if constexpr(is_half)
{
head = x_bitwise & 0xFC00;
mantissa = x_bitwise & 0x3FF;
exponent = (head >> type_mant) & 0x1F;
sign = head >> (type_exp + type_mant);
}
uint32_t signed_inf = (sign << (type_exp + type_mant)) + (((1 << type_exp) - 1) << type_mant);
uint32_t drop_mask = (1 << (type_mant - f8_mant)) - 1;
constexpr int max_exp = (1 << f8_exp) - (negative_zero_nan ? 1 : 2);
constexpr int exp_low_cutoff = (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
if constexpr(negative_zero_nan) if constexpr(negative_zero_nan)
{ {
if((x_bitwise & 0x7F800000) == 0x7F800000) if((x_bitwise & nan_mask) == nan_mask)
return 0x80; return nan_code;
max_exp = (1 << we_f8) - 1;
exp_low_cutoff = 0x80 - (1 << (we_f8 - 1));
} }
else else
{ {
if((x_bitwise & 0x7F800000) == 0x7F800000) if((x_bitwise & nan_mask) == nan_mask)
return signed_inf + (mantissa != 0 ? 1 : 0); return signed_inf + (mantissa != 0 ? 1 : 0);
max_exp = (1 << we_f8) - 2;
exp_low_cutoff = 0x80 - (1 << (we_f8 - 1)) + 1;
} }
// check if x is 0.0
if(x_bitwise == 0) if(x_bitwise == 0)
return 0; return 0;
exponent -= exp_low_cutoff - 1; exponent -= exp_low_cutoff - 1;
if(exponent <= 0) if(exponent <= 0)
drop_mask = (1 << (wm_f32 - wm_f8 + 1 - exponent)) - 1; drop_mask = (1 << (type_mant - f8_mant + 1 - exponent)) - 1;
mantissa += 1 << wm_f32; mantissa += 1 << type_mant;
// apply random number if needed
mantissa += (stoch ? rng : mantissa) & drop_mask; mantissa += (stoch ? rng : mantissa) & drop_mask;
if(mantissa >= (2 << wm_f32)) if(mantissa >= (2 << type_mant))
{ {
mantissa >>= 1; mantissa >>= 1;
exponent++; exponent++;
} }
mantissa >>= (wm_f32 - wm_f8); mantissa >>= (type_mant - f8_mant);
// check negative exponent
if(exponent <= 0) if(exponent <= 0)
{ {
if(x_bitwise == 0) if(x_bitwise == 0)
...@@ -93,7 +109,7 @@ __host__ __device__ f8_t cast_to_f8(float x, uint32_t rng) ...@@ -93,7 +109,7 @@ __host__ __device__ f8_t cast_to_f8(float x, uint32_t rng)
{ {
if(clip) if(clip)
{ {
mantissa = (1 << wm_f8) - 1; mantissa = (1 << f8_mant) - 1;
exponent = max_exp; exponent = max_exp;
} }
else else
...@@ -101,65 +117,92 @@ __host__ __device__ f8_t cast_to_f8(float x, uint32_t rng) ...@@ -101,65 +117,92 @@ __host__ __device__ f8_t cast_to_f8(float x, uint32_t rng)
return signed_inf; return signed_inf;
} }
} }
// check if x is 0.0 or -0.0
if(exponent == 0 && mantissa == 0) if(exponent == 0 && mantissa == 0)
return negative_zero_nan ? 0 : (sign << (we_f8 + wm_f8)); return negative_zero_nan ? 0 : (sign << (f8_exp + f8_mant));
mantissa &= (1 << wm_f8) - 1; mantissa &= (1 << f8_mant) - 1;
return (sign << (we_f8 + wm_f8)) | (exponent << wm_f8) | mantissa; return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa;
}
// Public entry point for converting a value of type T to f8_t.
// Template parameters:
//   T                 - source type; must be half_t or float, any other type
//                       fails the static_assert below at compile time
//   negative_zero_nan - forwarded unchanged to run_cast_to_f8
//   clip              - forwarded unchanged to run_cast_to_f8
//   stoch             - forwarded unchanged to run_cast_to_f8
// Parameters:
//   x   - value to convert
//   rng - forwarded unchanged to run_cast_to_f8
// Returns: whatever run_cast_to_f8 returns for the same arguments.
template <typename T, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng)
{
// check datatype: only half_t and float are accepted as source types
constexpr bool is_half = std::is_same<T, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value;
static_assert(is_half || is_float, "Only half and float can be casted to f8.");
// the actual conversion is delegated to run_cast_to_f8
return run_cast_to_f8<T, negative_zero_nan, clip, stoch>(x, rng);
} }
// cast fp8 to fp32 template <typename T, bool negative_zero_nan>
template <bool negative_zero_nan> __host__ __device__ T run_cast_from_f8(f8_t x)
__host__ __device__ float cast_from_f8(f8_t x)
{ {
// check data type
constexpr bool is_half = std::is_same<T, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value;
// fp8 exponent/mantissa layout // fp8 exponent/mantissa layout
constexpr int we_f8 = 4; constexpr int f8_exp = 4;
constexpr int wm_f8 = 3; constexpr int f8_mant = 3;
// fp32 exponent/mantissa layout
constexpr int we_f32 = 8;
constexpr int wm_f32 = 23;
float fInf, fNegInf, fNaN, fNeg0;
const uint32_t ifInf = 0x7F800000;
const uint32_t ifNegInf = 0xFF800000;
const uint32_t ifNaN = 0x7F800001;
const uint32_t ifNeg0 = 0x80000000;
fInf = *(reinterpret_cast<const float*>(&ifInf));
fNegInf = *(reinterpret_cast<const float*>(&ifNegInf));
fNaN = *(reinterpret_cast<const float*>(&ifNaN));
fNeg0 = *(reinterpret_cast<const float*>(&ifNeg0));
if(x == 0) // resulting type exponent/mantissa layout
return static_cast<float>(0); constexpr int type_exp = is_half ? 5 : 8;
constexpr int type_mant = is_half ? 10 : 23;
// prepare the codes
constexpr uint8_t nan_code = 0x80;
T fInf, fNegInf, fNaN, fNeg0;
if constexpr(is_half)
{
constexpr uint16_t ihInf = 0x7C00;
constexpr uint16_t ihNegInf = 0xFC00;
constexpr uint16_t ihNaN = 0x7C01;
constexpr uint16_t ihNeg0 = 0x8000;
fInf = *(reinterpret_cast<const half_t*>(&ihInf));
fNegInf = *(reinterpret_cast<const half_t*>(&ihNegInf));
fNaN = *(reinterpret_cast<const half_t*>(&ihNaN));
fNeg0 = *(reinterpret_cast<const half_t*>(&ihNeg0));
}
else if constexpr(is_float)
{
constexpr uint32_t ifInf = 0x7F800000;
constexpr uint32_t ifNegInf = 0xFF800000;
constexpr uint32_t ifNaN = 0x7F800001;
constexpr uint32_t ifNeg0 = 0x80000000;
fInf = *(reinterpret_cast<const float*>(&ifInf));
fNegInf = *(reinterpret_cast<const float*>(&ifNegInf));
fNaN = *(reinterpret_cast<const float*>(&ifNaN));
fNeg0 = *(reinterpret_cast<const float*>(&ifNeg0));
}
// unpack the input // unpack the input
uint32_t sign = x >> (we_f8 + wm_f8); uint32_t sign = x >> (f8_exp + f8_mant);
uint32_t mantissa = x & ((1 << wm_f8) - 1); uint32_t mantissa = x & ((1 << f8_mant) - 1);
int exponent = (x & 0x7F) >> wm_f8; int exponent = (x & 0x7F) >> f8_mant;
int exp_low_cutoff; constexpr int exp_low_cutoff = (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
uint32_t retval; typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type retval;
if constexpr(negative_zero_nan) if constexpr(negative_zero_nan)
{ {
if(x == 0x80) if(x == nan_code)
return fNaN; return fNaN;
exp_low_cutoff = (1 << (we_f32 - 1)) - (1 << (we_f8 - 1));
} }
else else
{ {
if(x == 0x80) if(x == nan_code)
return fNeg0; return fNeg0;
if(exponent == ((1 << we_f8) - 1)) if(exponent == ((1 << f8_exp) - 1))
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
exp_low_cutoff = (1 << (we_f32 - 1)) - (1 << (we_f8 - 1)) + 1;
} }
// subnormal input // subnormal input
if(exponent == 0) if(exponent == 0)
{ {
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + __builtin_clz(mantissa) - ((1 + we_f32 + wm_f32) - wm_f8); int sh = 1 + __builtin_clz(mantissa) - ((1 + type_exp + type_mant) - f8_mant);
mantissa <<= sh; mantissa <<= sh;
exponent += 1 - sh; exponent += 1 - sh;
/* /*
...@@ -169,21 +212,36 @@ __host__ __device__ float cast_from_f8(f8_t x) ...@@ -169,21 +212,36 @@ __host__ __device__ float cast_from_f8(f8_t x)
exponent--; exponent--;
} }
*/ */
mantissa &= ((1 << wm_f8) - 1); mantissa &= ((1 << f8_mant) - 1);
} }
exponent += exp_low_cutoff - 1; exponent += exp_low_cutoff - 1;
mantissa <<= wm_f32 - wm_f8; mantissa <<= type_mant - f8_mant;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true) // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if(exponent <= 0) if(exponent <= 0)
{ {
mantissa |= 1 << wm_f32; mantissa |= 1 << type_mant;
mantissa >>= 1 - exponent; mantissa >>= 1 - exponent;
exponent = 0; exponent = 0;
} }
retval = (sign << (we_f32 + wm_f32)) | (exponent << wm_f32) | mantissa; retval = (sign << (type_exp + type_mant)) | (exponent << type_mant) | mantissa;
return *(reinterpret_cast<const float*>(&retval)); return *(reinterpret_cast<const T*>(&retval));
}
// Public entry point for converting an f8_t value to type T.
// Template parameters:
//   T                 - destination type; must be half_t or float, any other
//                       type fails the static_assert below at compile time
//   negative_zero_nan - forwarded unchanged to run_cast_from_f8
// Parameters:
//   x - f8 value to convert
// Returns: static_cast<T>(0) when x == 0, otherwise the result of
//          run_cast_from_f8 for the same x.
template <typename T, bool negative_zero_nan>
__host__ __device__ T cast_from_f8(f8_t x)
{
// check datatype: only half_t and float are accepted as destination types
constexpr bool is_half = std::is_same<T, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value;
static_assert(is_half || is_float, "only half and float are supported.");
// check if x is 0.0; handled here so run_cast_from_f8 is never called with x == 0
if(x == 0)
return static_cast<T>(0);
// every non-zero input is delegated to run_cast_from_f8
return run_cast_from_f8<T, negative_zero_nan>(x);
} }
} // namespace ck } // namespace ck
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment