Commit 9ba9ebec authored by Rostyslav Geyyer
Browse files

Fix conversion

parent 846a6773
...@@ -86,9 +86,11 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) ...@@ -86,9 +86,11 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
if(is_half && is_bf8_t && negative_zero_nan && exponent == 0) if(is_half && is_bf8_t && negative_zero_nan && exponent == 0)
{ {
exponent += 1; exponent += 1;
int sh = 1 + __builtin_clz(mantissa) - (32 - type_mant); while(mantissa < (1 << type_mant))
mantissa <<= sh; {
exponent -= sh; mantissa <<= 1;
exponent -= 1;
}
mantissa &= ~(1 << type_mant); mantissa &= ~(1 << type_mant);
} }
...@@ -150,6 +152,7 @@ __host__ __device__ Y run_cast_from_f8(X x) ...@@ -150,6 +152,7 @@ __host__ __device__ Y run_cast_from_f8(X x)
constexpr bool is_half = std::is_same<Y, half_t>::value; constexpr bool is_half = std::is_same<Y, half_t>::value;
constexpr bool is_float = std::is_same<Y, float>::value; constexpr bool is_float = std::is_same<Y, float>::value;
constexpr bool is_f8_t = std::is_same<X, f8_t>::value; constexpr bool is_f8_t = std::is_same<X, f8_t>::value;
constexpr bool is_bf8_t = std::is_same<X, bf8_t>::value;
// fp8/bf8 exponent/mantissa layout // fp8/bf8 exponent/mantissa layout
constexpr int f8_exp = is_f8_t ? 4 : 5; constexpr int f8_exp = is_f8_t ? 4 : 5;
...@@ -185,6 +188,10 @@ __host__ __device__ Y run_cast_from_f8(X x) ...@@ -185,6 +188,10 @@ __host__ __device__ Y run_cast_from_f8(X x)
fNeg0 = *(reinterpret_cast<const float*>(&ifNeg0)); fNeg0 = *(reinterpret_cast<const float*>(&ifNeg0));
} }
// check if x is 0.0
if(x == 0)
return static_cast<Y>(0);
// unpack the input // unpack the input
uint32_t sign = x >> (f8_exp + f8_mant); uint32_t sign = x >> (f8_exp + f8_mant);
uint32_t mantissa = x & ((1 << f8_mant) - 1); uint32_t mantissa = x & ((1 << f8_mant) - 1);
...@@ -207,14 +214,24 @@ __host__ __device__ Y run_cast_from_f8(X x) ...@@ -207,14 +214,24 @@ __host__ __device__ Y run_cast_from_f8(X x)
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
} }
if(is_bf8_t && is_half && !negative_zero_nan)
{
retval = x;
retval <<= 8;
return *(reinterpret_cast<const Y*>(&retval));
}
// subnormal input // subnormal input
if(exponent == 0) if(exponent == 0)
{ {
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + __builtin_clz(mantissa) - ((1 + type_exp + type_mant) - f8_mant); exponent++;
mantissa <<= sh; while(mantissa < (1 << f8_mant))
{
mantissa <<= 1;
exponent--;
}
mantissa &= ((1 << f8_mant) - 1); mantissa &= ((1 << f8_mant) - 1);
exponent += 1 - sh;
} }
exponent += exp_low_cutoff - 1; exponent += exp_low_cutoff - 1;
mantissa <<= type_mant - f8_mant; mantissa <<= type_mant - f8_mant;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment