Commit 923c1700 authored by Rostyslav Geyyer
Browse files

Add bf8 conversion methods

parent 2776c177
...@@ -22,16 +22,18 @@ namespace ck::utils { ...@@ -22,16 +22,18 @@ namespace ck::utils {
namespace { namespace {
template <typename T, bool negative_zero_nan, bool clip, bool stoch> template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ uint8_t run_cast_to_f8(T x, uint32_t rng) __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
{ {
// check data type // check data type
constexpr bool is_half = std::is_same<T, half_t>::value; constexpr bool is_half = std::is_same<X, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value; constexpr bool is_float = std::is_same<X, float>::value;
constexpr bool is_f8_t = std::is_same<Y, f8_t>::value;
constexpr bool is_bf8_t = std::is_same<Y, bf8_t>::value;
// fp8 exponent/mantissa layout // fp8/bf8 exponent/mantissa layout
constexpr int f8_exp = 4; constexpr int f8_exp = is_f8_t ? 4 : 5;
constexpr int f8_mant = 3; constexpr int f8_mant = is_f8_t ? 3 : 2;
// resulting type exponent/mantissa layout // resulting type exponent/mantissa layout
constexpr int type_exp = is_half ? 5 : 8; constexpr int type_exp = is_half ? 5 : 8;
...@@ -40,11 +42,11 @@ __host__ __device__ uint8_t run_cast_to_f8(T x, uint32_t rng) ...@@ -40,11 +42,11 @@ __host__ __device__ uint8_t run_cast_to_f8(T x, uint32_t rng)
int exponent; int exponent;
uint32_t head, mantissa, sign; uint32_t head, mantissa, sign;
// nan code is same for float and half // nan code is same for float and half
constexpr uint8_t nan_code = 0x80; constexpr Y nan_code = 0x80;
constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000; constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000;
// convert to bitwise // convert to bitwise
typedef typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type typedef typename std::conditional<std::is_same<X, half_t>::value, uint16_t, uint32_t>::type
T_bitwise; T_bitwise;
T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x)); T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
...@@ -81,6 +83,15 @@ __host__ __device__ uint8_t run_cast_to_f8(T x, uint32_t rng) ...@@ -81,6 +83,15 @@ __host__ __device__ uint8_t run_cast_to_f8(T x, uint32_t rng)
return signed_inf + (mantissa != 0 ? 1 : 0); return signed_inf + (mantissa != 0 ? 1 : 0);
} }
if(is_half && is_bf8_t && negative_zero_nan && exponent == 0)
{
exponent += 1;
int sh = 1 + __builtin_clz(mantissa) - (32 - type_mant);
mantissa <<= sh;
exponent -= sh;
mantissa &= ~(1 << type_mant);
}
// check if x is 0.0 // check if x is 0.0
if(x_bitwise == 0) if(x_bitwise == 0)
return 0; return 0;
...@@ -132,24 +143,25 @@ __host__ __device__ uint8_t run_cast_to_f8(T x, uint32_t rng) ...@@ -132,24 +143,25 @@ __host__ __device__ uint8_t run_cast_to_f8(T x, uint32_t rng)
return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa; return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa;
} }
template <typename T, bool negative_zero_nan> template <typename X, typename Y, bool negative_zero_nan>
__host__ __device__ T run_cast_from_f8(uint8_t x) __host__ __device__ Y run_cast_from_f8(X x)
{ {
// check data type // check data type
constexpr bool is_half = std::is_same<T, half_t>::value; constexpr bool is_half = std::is_same<Y, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value; constexpr bool is_float = std::is_same<Y, float>::value;
constexpr bool is_f8_t = std::is_same<X, f8_t>::value;
// fp8 exponent/mantissa layout // fp8/bf8 exponent/mantissa layout
constexpr int f8_exp = 4; constexpr int f8_exp = is_f8_t ? 4 : 5;
constexpr int f8_mant = 3; constexpr int f8_mant = is_f8_t ? 3 : 2;
// resulting type exponent/mantissa layout // resulting type exponent/mantissa layout
constexpr int type_exp = is_half ? 5 : 8; constexpr int type_exp = is_half ? 5 : 8;
constexpr int type_mant = is_half ? 10 : 23; constexpr int type_mant = is_half ? 10 : 23;
// prepare the codes // prepare the codes
constexpr uint8_t nan_code = 0x80; constexpr X nan_code = 0x80;
T fInf, fNegInf, fNaN, fNeg0; Y fInf, fNegInf, fNaN, fNeg0;
if constexpr(is_half) if constexpr(is_half)
{ {
constexpr uint16_t ihInf = 0x7C00; constexpr uint16_t ihInf = 0x7C00;
...@@ -180,7 +192,7 @@ __host__ __device__ T run_cast_from_f8(uint8_t x) ...@@ -180,7 +192,7 @@ __host__ __device__ T run_cast_from_f8(uint8_t x)
constexpr int exp_low_cutoff = constexpr int exp_low_cutoff =
(1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type retval; typename std::conditional<std::is_same<Y, half_t>::value, uint16_t, uint32_t>::type retval;
if constexpr(negative_zero_nan) if constexpr(negative_zero_nan)
{ {
...@@ -216,38 +228,41 @@ __host__ __device__ T run_cast_from_f8(uint8_t x) ...@@ -216,38 +228,41 @@ __host__ __device__ T run_cast_from_f8(uint8_t x)
} }
retval = (sign << (type_exp + type_mant)) | (exponent << type_mant) | mantissa; retval = (sign << (type_exp + type_mant)) | (exponent << type_mant) | mantissa;
return *(reinterpret_cast<const T*>(&retval)); return *(reinterpret_cast<const Y*>(&retval));
} }
} // namespace } // namespace
// Converts a half or float value to an 8-bit floating-point type.
//
// Template parameters:
//   X                 - source type; must be half_t or float (static_assert).
//   Y                 - destination type; must be f8_t or bf8_t (static_assert).
//   negative_zero_nan - forwarded unchanged to run_cast_to_f8.
//   clip              - forwarded unchanged to run_cast_to_f8.
//   stoch             - forwarded unchanged to run_cast_to_f8.
// Parameters:
//   x   - value to convert.
//   rng - forwarded unchanged to run_cast_to_f8.
//         NOTE(review): presumably the random bits used when stoch is true;
//         confirm against run_cast_to_f8.
// Returns: the value produced by run_cast_to_f8<X, Y, ...>(x, rng).
template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ Y cast_to_f8(X x, uint32_t rng)
{
    // Reject unsupported source types at compile time.
    constexpr bool is_half  = std::is_same<X, half_t>::value;
    constexpr bool is_float = std::is_same<X, float>::value;
    static_assert(is_half || is_float, "Only half and float can be casted.");

    // Reject unsupported destination types at compile time.
    constexpr bool is_f8  = std::is_same<Y, f8_t>::value;
    constexpr bool is_bf8 = std::is_same<Y, bf8_t>::value;
    static_assert(is_f8 || is_bf8, "Casting to f8 and bf8 only is supported.");

    return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
}
// Converts an 8-bit floating-point value to half or float.
//
// Template parameters:
//   X                 - source type; must be f8_t or bf8_t (static_assert).
//   Y                 - destination type; must be half_t or float (static_assert).
//   negative_zero_nan - forwarded unchanged to run_cast_from_f8.
// Parameters:
//   x - the 8-bit value to convert.
// Returns: Y(0) when x compares equal to 0, otherwise the value produced by
//          run_cast_from_f8<X, Y, negative_zero_nan>(x).
template <typename X, typename Y, bool negative_zero_nan>
__host__ __device__ Y cast_from_f8(X x)
{
    // Reject unsupported destination types at compile time.
    constexpr bool is_half  = std::is_same<Y, half_t>::value;
    constexpr bool is_float = std::is_same<Y, float>::value;
    static_assert(is_half || is_float, "only half and float are supported.");

    // Reject unsupported source types at compile time. The message names the
    // "from" direction, since X is the 8-bit input type here.
    constexpr bool is_f8  = std::is_same<X, f8_t>::value;
    constexpr bool is_bf8 = std::is_same<X, bf8_t>::value;
    static_assert(is_f8 || is_bf8, "Casting from f8 and bf8 only is supported.");

    // check if x is 0.0 -- short-circuit before the general conversion path
    if(x == 0)
        return static_cast<Y>(0);

    return run_cast_from_f8<X, Y, negative_zero_nan>(x);
}
} // namespace ck::utils } // namespace ck::utils
// f8_t constructor impl: stores the given 8-bit value `init` in the `data` member.
inline __host__ __device__ ck::f8_t::f8_t(uint8_t init) { data = init; }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment