"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "36802715112055a70cd57abe858b211dd8b321dd"
Commit 502942fe authored by Rostyslav Geyyer
Browse files

Format

parent 532bbe53
...@@ -7,8 +7,8 @@ ...@@ -7,8 +7,8 @@
namespace ck { namespace ck {
using f8_t = uint8_t; using f8_t = uint8_t;
using half_t = _Float16; using half_t = _Float16;
// fp8 rounding modes // fp8 rounding modes
enum class f8_rounding_mode enum class f8_rounding_mode
...@@ -25,21 +25,22 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) ...@@ -25,21 +25,22 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
constexpr bool is_float = std::is_same<T, float>::value; constexpr bool is_float = std::is_same<T, float>::value;
// fp8 exponent/mantissa layout // fp8 exponent/mantissa layout
constexpr int f8_exp = 4; constexpr int f8_exp = 4;
constexpr int f8_mant = 3; constexpr int f8_mant = 3;
// resulting type exponent/mantissa layout // resulting type exponent/mantissa layout
constexpr int type_exp = is_half ? 5 : 8; constexpr int type_exp = is_half ? 5 : 8;
constexpr int type_mant = is_half ? 10 : 23; constexpr int type_mant = is_half ? 10 : 23;
int exponent; int exponent;
uint32_t head, mantissa, sign; uint32_t head, mantissa, sign;
// nan code is same for float and half // nan code is same for float and half
constexpr uint8_t nan_code = 0x80; constexpr uint8_t nan_code = 0x80;
constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000; constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000;
// convert to bitwise // convert to bitwise
typedef typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type T_bitwise; typedef typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type
T_bitwise;
T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x)); T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
// unpack the input, depends on datatype // unpack the input, depends on datatype
...@@ -58,10 +59,11 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) ...@@ -58,10 +59,11 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
sign = head >> (type_exp + type_mant); sign = head >> (type_exp + type_mant);
} }
uint32_t signed_inf = (sign << (type_exp + type_mant)) + (((1 << type_exp) - 1) << type_mant); uint32_t signed_inf = (sign << (type_exp + type_mant)) + (((1 << type_exp) - 1) << type_mant);
uint32_t drop_mask = (1 << (type_mant - f8_mant)) - 1; uint32_t drop_mask = (1 << (type_mant - f8_mant)) - 1;
constexpr int max_exp = (1 << f8_exp) - (negative_zero_nan ? 1 : 2); constexpr int max_exp = (1 << f8_exp) - (negative_zero_nan ? 1 : 2);
constexpr int exp_low_cutoff = (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); constexpr int exp_low_cutoff =
(1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
if constexpr(negative_zero_nan) if constexpr(negative_zero_nan)
{ {
...@@ -144,11 +146,11 @@ __host__ __device__ T run_cast_from_f8(f8_t x) ...@@ -144,11 +146,11 @@ __host__ __device__ T run_cast_from_f8(f8_t x)
constexpr bool is_float = std::is_same<T, float>::value; constexpr bool is_float = std::is_same<T, float>::value;
// fp8 exponent/mantissa layout // fp8 exponent/mantissa layout
constexpr int f8_exp = 4; constexpr int f8_exp = 4;
constexpr int f8_mant = 3; constexpr int f8_mant = 3;
// resulting type exponent/mantissa layout // resulting type exponent/mantissa layout
constexpr int type_exp = is_half ? 5 : 8; constexpr int type_exp = is_half ? 5 : 8;
constexpr int type_mant = is_half ? 10 : 23; constexpr int type_mant = is_half ? 10 : 23;
// prepare the codes // prepare the codes
...@@ -160,10 +162,10 @@ __host__ __device__ T run_cast_from_f8(f8_t x) ...@@ -160,10 +162,10 @@ __host__ __device__ T run_cast_from_f8(f8_t x)
constexpr uint16_t ihNegInf = 0xFC00; constexpr uint16_t ihNegInf = 0xFC00;
constexpr uint16_t ihNaN = 0x7C01; constexpr uint16_t ihNaN = 0x7C01;
constexpr uint16_t ihNeg0 = 0x8000; constexpr uint16_t ihNeg0 = 0x8000;
fInf = *(reinterpret_cast<const half_t*>(&ihInf)); fInf = *(reinterpret_cast<const half_t*>(&ihInf));
fNegInf = *(reinterpret_cast<const half_t*>(&ihNegInf)); fNegInf = *(reinterpret_cast<const half_t*>(&ihNegInf));
fNaN = *(reinterpret_cast<const half_t*>(&ihNaN)); fNaN = *(reinterpret_cast<const half_t*>(&ihNaN));
fNeg0 = *(reinterpret_cast<const half_t*>(&ihNeg0)); fNeg0 = *(reinterpret_cast<const half_t*>(&ihNeg0));
} }
else if constexpr(is_float) else if constexpr(is_float)
{ {
...@@ -171,10 +173,10 @@ __host__ __device__ T run_cast_from_f8(f8_t x) ...@@ -171,10 +173,10 @@ __host__ __device__ T run_cast_from_f8(f8_t x)
constexpr uint32_t ifNegInf = 0xFF800000; constexpr uint32_t ifNegInf = 0xFF800000;
constexpr uint32_t ifNaN = 0x7F800001; constexpr uint32_t ifNaN = 0x7F800001;
constexpr uint32_t ifNeg0 = 0x80000000; constexpr uint32_t ifNeg0 = 0x80000000;
fInf = *(reinterpret_cast<const float*>(&ifInf)); fInf = *(reinterpret_cast<const float*>(&ifInf));
fNegInf = *(reinterpret_cast<const float*>(&ifNegInf)); fNegInf = *(reinterpret_cast<const float*>(&ifNegInf));
fNaN = *(reinterpret_cast<const float*>(&ifNaN)); fNaN = *(reinterpret_cast<const float*>(&ifNaN));
fNeg0 = *(reinterpret_cast<const float*>(&ifNeg0)); fNeg0 = *(reinterpret_cast<const float*>(&ifNeg0));
} }
// unpack the input // unpack the input
...@@ -182,7 +184,8 @@ __host__ __device__ T run_cast_from_f8(f8_t x) ...@@ -182,7 +184,8 @@ __host__ __device__ T run_cast_from_f8(f8_t x)
uint32_t mantissa = x & ((1 << f8_mant) - 1); uint32_t mantissa = x & ((1 << f8_mant) - 1);
int exponent = (x & 0x7F) >> f8_mant; int exponent = (x & 0x7F) >> f8_mant;
constexpr int exp_low_cutoff = (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); constexpr int exp_low_cutoff =
(1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type retval; typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type retval;
if constexpr(negative_zero_nan) if constexpr(negative_zero_nan)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment