Commit 21481b44 authored by Rostyslav Geyyer
Browse files

Add fp8<->fp32 type_convert

parent d3929cb0
...@@ -1049,6 +1049,173 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_ ...@@ -1049,6 +1049,173 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return type_convert<bhalf_t>(x_fp32); return type_convert<bhalf_t>(x_fp32);
} }
// Convert an fp32 value to the 8-bit float encoding used by f8_t:
// 1 sign bit, 4 exponent bits (bias 8), 3 mantissa bits.
//
// With negative_zero_nan enabled there is no negative zero and no infinity:
// the bit pattern 0x80 is the single NaN, and exponent field 15 holds ordinary
// finite values (largest magnitude 240, encoded as 0x7F / 0xFF).
//
// Behaviour for the compile-time settings below:
//   - fp32 Inf / NaN            -> 0x80 (NaN)
//   - +0.0f / -0.0f             -> 0x00
//   - |x| above the fp8 range   -> saturates to the largest finite value (clip)
//   - |x| below half of the smallest fp8 subnormal (2^-10) -> 0x00
//   - everything else is rounded to nearest, ties away from zero
template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
{
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    // Rounding mode: deterministic round-to-nearest. `rng` is only consulted
    // when `stoch` is true (stochastic rounding), which is disabled here.
    constexpr bool stoch   = false;
    constexpr uint32_t rng = 0;
    // fp8 exponent/mantissa layout
    constexpr int we = 4;
    constexpr int wm = 3;
    // fp32 mantissa width
    constexpr int mfmt = 23;

    // Fetch the raw fp32 bit pattern without violating strict aliasing.
    uint32_t _x;
    __builtin_memcpy(&_x, &x, sizeof(_x));

    uint32_t mantissa   = _x & 0x7FFFFF;
    int exponent        = static_cast<int>((_x >> mfmt) & 0xFF);
    const uint32_t sign = _x >> 31;

    // Encoding of +/-Inf when the format has one (negative_zero_nan == false).
    const uint32_t signed_inf = (sign << 7) + (((1u << we) - 1) << wm);

    // fp32 Inf / NaN
    if((_x & 0x7F800000) == 0x7F800000)
    {
        if(negative_zero_nan)
            return static_cast<f8_t>(0x80);
        return static_cast<f8_t>(signed_inf + (mantissa != 0 ? 1 : 0));
    }

    if(_x == 0)
        return static_cast<f8_t>(0);

    const int max_exp        = (1 << we) - (negative_zero_nan ? 1 : 2);
    const int exp_low_cutoff = 128 - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);

    // Re-bias: after this, `exponent` is the fp8 exponent field for normal
    // results and <= 0 for values that land in the fp8 subnormal range.
    exponent -= exp_low_cutoff - 1;

    // Anything this small is below half of the smallest subnormal and rounds to
    // zero under the deterministic rounding used here. Returning early also
    // keeps every shift count below well under 32; without this guard,
    // -0.0f, fp32 denormals and other tiny inputs would shift by 32 or more
    // bits, which is undefined behaviour.
    if(exponent < -wm)
        return static_cast<f8_t>(negative_zero_nan ? 0 : (sign << 7));

    // Bits that will be discarded by the final right shift(s).
    uint32_t drop_mask = (1u << (mfmt - wm)) - 1;
    if(exponent <= 0)
        drop_mask = (1u << (mfmt - wm + 1 - exponent)) - 1;

    // Make the implicit leading one explicit, then round by adding the
    // to-be-dropped bits onto the value (carry propagates when they are >= 1/2).
    mantissa += 1u << mfmt;
    mantissa += (stoch ? rng : mantissa) & drop_mask;
    if(mantissa >= (2u << mfmt))
    {
        // Rounding carried out of the mantissa: renormalise.
        mantissa >>= 1;
        exponent++;
    }
    mantissa >>= (mfmt - wm);

    if(exponent <= 0)
    {
        // subnormal range; represented by a subnormal float8 (exponent 0)
        // and involves loss of accuracy
        mantissa >>= 1 - exponent;
        exponent = 0;
    }
    // above range: quantize to maximum possible float of the same sign
    else if(exponent > max_exp)
    {
        if(clip)
        {
            mantissa = (1u << wm) - 1;
            exponent = max_exp;
        }
        else
        {
            return static_cast<f8_t>(signed_inf);
        }
    }

    // Result rounded to zero: never emit 0x80 here, since that is NaN when
    // negative_zero_nan is enabled.
    if(exponent == 0 && mantissa == 0)
        return static_cast<f8_t>(negative_zero_nan ? 0 : (sign << 7));

    mantissa &= (1u << wm) - 1;
    return static_cast<f8_t>((sign << 7) | (static_cast<uint32_t>(exponent) << wm) | mantissa);
}
// Convert an 8-bit float (1 sign bit, 4 exponent bits with bias 8, 3 mantissa
// bits) to fp32. Every finite fp8 value is exactly representable in fp32, so
// the conversion is lossless.
//
// With negative_zero_nan enabled the pattern 0x80 decodes to NaN and there is
// no infinity; 0x00 decodes to +0.0f.
template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
{
    constexpr bool negative_zero_nan = true;
    // fp8 exponent/mantissa layout
    constexpr int we = 4;
    constexpr int wm = 3;
    // fp32 exponent/mantissa layout
    constexpr int weo = 8;
    constexpr int wmo = 23;

    // fp32 bit patterns for the special results
    constexpr uint32_t ifInf  = 0x7F800000;
    constexpr uint32_t ifNaN  = 0x7F800001;
    constexpr uint32_t ifNeg0 = 0x80000000;

    // Work on the raw 8-bit pattern as an unsigned value, so that a signed
    // storage type for f8_t cannot sign-extend into the comparisons and shifts.
    const uint32_t bits = static_cast<uint32_t>(x) & 0xFFu;

    if(bits == 0)
        return 0.0f;

    const uint32_t sign = bits >> 7;
    uint32_t mantissa   = bits & ((1u << wm) - 1);
    int exponent        = static_cast<int>((bits & 0x7Fu) >> wm);

    uint32_t retval;
    if(negative_zero_nan && bits == 0x80u)
    {
        // "negative zero" encodes NaN in this format
        retval = ifNaN;
    }
    else if(!negative_zero_nan && bits == 0x80u)
    {
        retval = ifNeg0;
    }
    else if(!negative_zero_nan && exponent == ((1 << we) - 1))
    {
        // all-ones exponent: Inf (mantissa 0) or NaN
        retval = (mantissa == 0) ? ((sign << 31) | ifInf) : ifNaN;
    }
    else
    {
        const int exp_low_cutoff =
            (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);

        // subnormal input: shift the leading one up into the implicit-bit
        // position and lower the exponent accordingly
        if(exponent == 0)
        {
            // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
            const int sh = 1 + __builtin_clz(mantissa) - (32 - wm);
            mantissa <<= sh;
            exponent += 1 - sh;
            mantissa &= ((1u << wm) - 1);
        }

        // Re-bias to fp32. The smallest fp8 subnormal ends up with an fp32
        // exponent field of 117, so the result is always a normal fp32 number
        // and no subnormal-output handling is needed.
        exponent += exp_low_cutoff - 1;
        mantissa <<= wmo - wm;

        retval = (sign << 31) | (static_cast<uint32_t>(exponent) << wmo) | mantissa;
    }

    // Reinterpret the assembled bit pattern as fp32 without violating strict
    // aliasing.
    float result;
    __builtin_memcpy(&result, &retval, sizeof(result));
    return result;
}
// Declare a template function for bf16 conversion using RTN
// (RTN presumably means round-to-nearest; the definition is not visible here).
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment