Commit 5038b95b authored by Rostyslav Geyyer
Browse files

Split type_convert and cast_to/from_f8

parent f07a74d1
......@@ -14,6 +14,13 @@ using int4_t = _BitInt(4);
#endif
using f8_t = uint8_t;
// Rounding policy selector for float -> fp8 conversion.
// type_convert<f8_t, float> compares its chosen mode against `stochastic`
// to set the `stoch` template flag of cast_to_f8.
enum class f8_rounding_mode
{
    // stoch == false: cast_to_f8 adds the mantissa's own to-be-dropped bits
    // before truncating.
    standard,
    // stoch == true: cast_to_f8 adds the caller-supplied rng bits (masked to
    // the to-be-dropped bit range) before truncating.
    stochastic
};
// vector_type
// Forward declaration only: a class template taking a type T and an index_t
// value N. No definition is given at this point.
template <typename T, index_t N>
struct vector_type;
......@@ -1049,68 +1056,67 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return type_convert<bhalf_t>(x_fp32);
}
template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
// cast fp32 to fp8
template <bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ uint8_t cast_to_f8(float x, uint32_t rng)
{
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
// fp8 exponent/mantissa layout
constexpr int we = 4;
constexpr int wm = 3;
constexpr int we_f8 = 4;
constexpr int wm_f8 = 3;
// fp32 exponent/mantissa layout
constexpr int weo = 8;
constexpr int wmo = 23;
// constexpr int we_f32 = 8;
constexpr int wm_f32 = 23;
const int mfmt = 23;
uint32_t _x;
_x = *(reinterpret_cast<uint32_t*>(&x));
uint32_t x_bitwise;
x_bitwise = *(reinterpret_cast<uint32_t*>(&x));
// unpack the input
uint32_t head, mantissa;
int exponent;
uint32_t sign;
head = _x & 0xFF800000;
mantissa = _x & 0x7FFFFF;
head = x_bitwise & 0xFF800000;
mantissa = x_bitwise & 0x7FFFFF;
exponent = (head >> 23) & 0xFF;
sign = head >> 31;
uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);
uint32_t signed_inf = (sign << (we_f8 + wm_f8)) + (((1 << we_f8) - 1) << wm_f8);
if(negative_zero_nan)
{
if((_x & 0x7F800000) == 0x7F800000)
if((x_bitwise & 0x7F800000) == 0x7F800000)
return 0x80;
}
else
{
if((_x & 0x7F800000) == 0x7F800000)
if((x_bitwise & 0x7F800000) == 0x7F800000)
return signed_inf + (mantissa != 0 ? 1 : 0);
}
if(_x == 0)
if(x_bitwise == 0)
return 0;
uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
const int exp_low_cutoff = 128 - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
uint32_t drop_mask = (1 << (wm_f32 - wm_f8)) - 1;
const int max_exp = (1 << we_f8) - (negative_zero_nan ? 1 : 2);
const int exp_low_cutoff =
0x80 - (1 << (we_f8 - 1)) + 1 - (negative_zero_nan ? 1 : 0);
exponent -= exp_low_cutoff - 1;
if(exponent <= 0)
drop_mask = (1 << (mfmt - wm + 1 - exponent)) - 1;
mantissa += 1 << mfmt;
drop_mask = (1 << (wm_f32 - wm_f8 + 1 - exponent)) - 1;
mantissa += 1 << wm_f32;
mantissa += (stoch ? rng : mantissa) & drop_mask;
if(mantissa >= (2 << mfmt))
if(mantissa >= (2 << wm_f32))
{
mantissa >>= 1;
exponent++;
}
mantissa >>= (mfmt - wm);
mantissa >>= (wm_f32 - wm_f8);
if(exponent <= 0)
{
if(_x == 0)
if(x_bitwise == 0)
return 0;
else
{
......@@ -1125,7 +1131,7 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
{
if(clip)
{
mantissa = (1 << wm) - 1;
mantissa = (1 << wm_f8) - 1;
exponent = max_exp;
}
else
......@@ -1135,22 +1141,32 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
}
if(exponent == 0 && mantissa == 0)
return negative_zero_nan ? 0 : (sign << 7);
mantissa &= (1 << wm) - 1;
return (sign << 7) | (exponent << wm) | mantissa;
mantissa &= (1 << wm_f8) - 1;
return (sign << 7) | (exponent << wm_f8) | mantissa;
}
// convert fp32 to fp8
template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
{
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return cast_to_f8<negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
}
// cast fp8 to fp32
template <bool negative_zero_nan>
__host__ __device__ float cast_from_f8(uint8_t x)
{
// fp8 exponent/mantissa layout
constexpr int we = 4;
constexpr int wm = 3;
constexpr int we_f8 = 4;
constexpr int wm_f8 = 3;
// fp32 exponent/mantissa layout
constexpr int weo = 8;
constexpr int wmo = 23;
constexpr int we_f32 = 8;
constexpr int wm_f32 = 23;
float fInf, fNegInf, fNaN, fNeg0;
const uint32_t ifInf = 0x7F800000;
......@@ -1165,10 +1181,11 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
if(x == 0)
return static_cast<float>(0);
uint32_t sign = x >> 7;
uint32_t mantissa = x & ((1 << wm) - 1);
int exponent = (x & 0x7F) >> wm;
// unpack the input
uint32_t sign = x >> (we_f8 + wm_f8);
uint32_t mantissa = x & ((1 << wm_f8) - 1);
int exponent = (x & 0x7F) >> wm_f8;
if(negative_zero_nan)
{
if(x == 0x80)
......@@ -1178,17 +1195,18 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
{
if(x == 0x80)
return fNeg0;
if(exponent == ((1 << we) - 1))
if(exponent == ((1 << we_f8) - 1))
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
}
const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
uint32_t retval;
const int exp_low_cutoff = (1 << (we_f32 - 1)) - (1 << (we_f8 - 1)) + 1 - (negative_zero_nan ? 1 : 0);
// subnormal input
if(exponent == 0)
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + __builtin_clz(mantissa) - (32 - wm);
int sh = 1 + __builtin_clz(mantissa) - ((1 + we_f32 + wm_f32) - wm_f8);
mantissa <<= sh;
exponent += 1 - sh;
/*
......@@ -1198,23 +1216,31 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
exponent--;
}
*/
mantissa &= ((1 << wm) - 1);
mantissa &= ((1 << wm_f8) - 1);
}
exponent += exp_low_cutoff - 1;
mantissa <<= wmo - wm;
mantissa <<= wm_f32 - wm_f8;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if(exponent <= 0)
{
mantissa |= 1 << wmo;
mantissa |= 1 << wm_f32;
mantissa >>= 1 - exponent;
exponent = 0;
}
uint32_t retval;
retval = (sign << 31) | (exponent << 23) | mantissa;
retval = (sign << (we_f32 + wm_f32)) | (exponent << wm_f32) | mantissa;
return *(reinterpret_cast<const float*>(&retval));
}
// fp8 -> fp32 specialization of type_convert.
// Thin wrapper: decodes the 8-bit pattern through cast_from_f8 with the
// negative_zero_nan template flag fixed to true.
template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
{
    return cast_from_f8</*negative_zero_nan=*/true>(x);
}
// Declare a template function for bf16 conversion using RTN
// (RTN presumably stands for round-to-nearest -- confirm against the
// definitions/specializations, which are not given at this point).
// Y is the return type and X the argument type.
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment