Commit 803b9db8 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Refactor f8_t to add bf8_t

parent 1cf50031
...@@ -12,7 +12,25 @@ using half_t = _Float16; ...@@ -12,7 +12,25 @@ using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using int4_t = _BitInt(4); using int4_t = _BitInt(4);
#endif #endif
using f8_t = uint8_t;
struct f8_t
{
uint8_t data;
using type = f8_t;
using data_type = uint8_t;
__host__ __device__ f8_t() = default;
__host__ __device__ f8_t(uint8_t init);
};
struct bf8_t
{
uint8_t data;
using type = bf8_t;
using data_type = uint8_t;
__host__ __device__ bf8_t() = default;
};
// vector_type // vector_type
template <typename T, index_t N> template <typename T, index_t N>
...@@ -187,8 +205,10 @@ struct vector_type<T, 1> ...@@ -187,8 +205,10 @@ struct vector_type<T, 1>
template <typename T> template <typename T>
struct vector_type<T, 2> struct vector_type<T, 2>
{ {
using d1_t = T; using T_adjusted = typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
typedef T d2_t __attribute__((ext_vector_type(2)));
using d1_t = T_adjusted;
typedef T_adjusted d2_t __attribute__((ext_vector_type(2)));
using type = d2_t; using type = d2_t;
...@@ -237,9 +257,11 @@ struct vector_type<T, 2> ...@@ -237,9 +257,11 @@ struct vector_type<T, 2>
template <typename T> template <typename T>
struct vector_type<T, 4> struct vector_type<T, 4>
{ {
using d1_t = T; using T_adjusted = typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4))); using d1_t = T_adjusted;
typedef T_adjusted d2_t __attribute__((ext_vector_type(2)));
typedef T_adjusted d4_t __attribute__((ext_vector_type(4)));
using type = d4_t; using type = d4_t;
...@@ -299,10 +321,12 @@ struct vector_type<T, 4> ...@@ -299,10 +321,12 @@ struct vector_type<T, 4>
template <typename T> template <typename T>
struct vector_type<T, 8> struct vector_type<T, 8>
{ {
using d1_t = T; using T_adjusted = typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4))); using d1_t = T_adjusted;
typedef T d8_t __attribute__((ext_vector_type(8))); typedef T_adjusted d2_t __attribute__((ext_vector_type(2)));
typedef T_adjusted d4_t __attribute__((ext_vector_type(4)));
typedef T_adjusted d8_t __attribute__((ext_vector_type(8)));
using type = d8_t; using type = d8_t;
...@@ -373,11 +397,13 @@ struct vector_type<T, 8> ...@@ -373,11 +397,13 @@ struct vector_type<T, 8>
template <typename T> template <typename T>
struct vector_type<T, 16> struct vector_type<T, 16>
{ {
using d1_t = T; using T_adjusted = typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4))); using d1_t = T_adjusted;
typedef T d8_t __attribute__((ext_vector_type(8))); typedef T_adjusted d2_t __attribute__((ext_vector_type(2)));
typedef T d16_t __attribute__((ext_vector_type(16))); typedef T_adjusted d4_t __attribute__((ext_vector_type(4)));
typedef T_adjusted d8_t __attribute__((ext_vector_type(8)));
typedef T_adjusted d16_t __attribute__((ext_vector_type(16)));
using type = d16_t; using type = d16_t;
...@@ -459,12 +485,14 @@ struct vector_type<T, 16> ...@@ -459,12 +485,14 @@ struct vector_type<T, 16>
template <typename T> template <typename T>
struct vector_type<T, 32> struct vector_type<T, 32>
{ {
using d1_t = T; using T_adjusted = typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4))); using d1_t = T_adjusted;
typedef T d8_t __attribute__((ext_vector_type(8))); typedef T_adjusted d2_t __attribute__((ext_vector_type(2)));
typedef T d16_t __attribute__((ext_vector_type(16))); typedef T_adjusted d4_t __attribute__((ext_vector_type(4)));
typedef T d32_t __attribute__((ext_vector_type(32))); typedef T_adjusted d8_t __attribute__((ext_vector_type(8)));
typedef T_adjusted d16_t __attribute__((ext_vector_type(16)));
typedef T_adjusted d32_t __attribute__((ext_vector_type(32)));
using type = d32_t; using type = d32_t;
...@@ -555,13 +583,15 @@ struct vector_type<T, 32> ...@@ -555,13 +583,15 @@ struct vector_type<T, 32>
template <typename T> template <typename T>
struct vector_type<T, 64> struct vector_type<T, 64>
{ {
using d1_t = T; using T_adjusted = typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4))); using d1_t = T_adjusted;
typedef T d8_t __attribute__((ext_vector_type(8))); typedef T_adjusted d2_t __attribute__((ext_vector_type(2)));
typedef T d16_t __attribute__((ext_vector_type(16))); typedef T_adjusted d4_t __attribute__((ext_vector_type(4)));
typedef T d32_t __attribute__((ext_vector_type(32))); typedef T_adjusted d8_t __attribute__((ext_vector_type(8)));
typedef T d64_t __attribute__((ext_vector_type(64))); typedef T_adjusted d16_t __attribute__((ext_vector_type(16)));
typedef T_adjusted d32_t __attribute__((ext_vector_type(32)));
typedef T_adjusted d64_t __attribute__((ext_vector_type(64)));
using type = d64_t; using type = d64_t;
...@@ -663,14 +693,16 @@ struct vector_type<T, 64> ...@@ -663,14 +693,16 @@ struct vector_type<T, 64>
template <typename T> template <typename T>
struct vector_type<T, 128> struct vector_type<T, 128>
{ {
using d1_t = T; using T_adjusted = typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4))); using d1_t = T_adjusted;
typedef T d8_t __attribute__((ext_vector_type(8))); typedef T_adjusted d2_t __attribute__((ext_vector_type(2)));
typedef T d16_t __attribute__((ext_vector_type(16))); typedef T_adjusted d4_t __attribute__((ext_vector_type(4)));
typedef T d32_t __attribute__((ext_vector_type(32))); typedef T_adjusted d8_t __attribute__((ext_vector_type(8)));
typedef T d64_t __attribute__((ext_vector_type(64))); typedef T_adjusted d16_t __attribute__((ext_vector_type(16)));
typedef T d128_t __attribute__((ext_vector_type(128))); typedef T_adjusted d32_t __attribute__((ext_vector_type(32)));
typedef T_adjusted d64_t __attribute__((ext_vector_type(64)));
typedef T_adjusted d128_t __attribute__((ext_vector_type(128)));
using type = d128_t; using type = d128_t;
...@@ -781,15 +813,17 @@ struct vector_type<T, 128> ...@@ -781,15 +813,17 @@ struct vector_type<T, 128>
template <typename T> template <typename T>
struct vector_type<T, 256> struct vector_type<T, 256>
{ {
using d1_t = T; using T_adjusted = typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4))); using d1_t = T_adjusted;
typedef T d8_t __attribute__((ext_vector_type(8))); typedef T_adjusted d2_t __attribute__((ext_vector_type(2)));
typedef T d16_t __attribute__((ext_vector_type(16))); typedef T_adjusted d4_t __attribute__((ext_vector_type(4)));
typedef T d32_t __attribute__((ext_vector_type(32))); typedef T_adjusted d8_t __attribute__((ext_vector_type(8)));
typedef T d64_t __attribute__((ext_vector_type(64))); typedef T_adjusted d16_t __attribute__((ext_vector_type(16)));
typedef T d128_t __attribute__((ext_vector_type(128))); typedef T_adjusted d32_t __attribute__((ext_vector_type(32)));
typedef T d256_t __attribute__((ext_vector_type(256))); typedef T_adjusted d64_t __attribute__((ext_vector_type(64)));
typedef T_adjusted d128_t __attribute__((ext_vector_type(128)));
typedef T_adjusted d256_t __attribute__((ext_vector_type(256)));
using type = d256_t; using type = d256_t;
...@@ -1013,14 +1047,34 @@ struct NumericLimits<f8_t> ...@@ -1013,14 +1047,34 @@ struct NumericLimits<f8_t>
static constexpr uint8_t binary_max = 0x77; // 0b01110111 static constexpr uint8_t binary_max = 0x77; // 0b01110111
static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
__host__ __device__ static f8_t Min()
{
f8_t x;
x.data=binary_min;
return x;
}
__host__ __device__ static constexpr f8_t Min() { return bit_cast<f8_t>(binary_min); } __host__ __device__ static f8_t Max()
{
__host__ __device__ static constexpr f8_t Max() { return bit_cast<f8_t>(binary_max); } f8_t x;
x.data=binary_max;
return x;
}
__host__ __device__ static constexpr f8_t Lowest() { return bit_cast<f8_t>(binary_lowest); } __host__ __device__ static f8_t Lowest()
{
f8_t x;
x.data=binary_lowest;
return x;
}
__host__ __device__ static constexpr f8_t QuietNaN() { return bit_cast<f8_t>(binary_qnan); } __host__ __device__ static f8_t QuietNaN()
{
f8_t x;
x.data=binary_qnan;
return x;
}
}; };
} // namespace ck } // namespace ck
...@@ -23,7 +23,7 @@ namespace ck::utils { ...@@ -23,7 +23,7 @@ namespace ck::utils {
namespace { namespace {
template <typename T, bool negative_zero_nan, bool clip, bool stoch> template <typename T, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) __host__ __device__ uint8_t run_cast_to_f8(T x, uint32_t rng)
{ {
// check data type // check data type
constexpr bool is_half = std::is_same<T, half_t>::value; constexpr bool is_half = std::is_same<T, half_t>::value;
...@@ -133,7 +133,7 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) ...@@ -133,7 +133,7 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
} }
template <typename T, bool negative_zero_nan> template <typename T, bool negative_zero_nan>
__host__ __device__ T run_cast_from_f8(f8_t x) __host__ __device__ T run_cast_from_f8(uint8_t x)
{ {
// check data type // check data type
constexpr bool is_half = std::is_same<T, half_t>::value; constexpr bool is_half = std::is_same<T, half_t>::value;
...@@ -222,7 +222,7 @@ __host__ __device__ T run_cast_from_f8(f8_t x) ...@@ -222,7 +222,7 @@ __host__ __device__ T run_cast_from_f8(f8_t x)
} // namespace } // namespace
template <typename T, bool negative_zero_nan, bool clip, bool stoch> template <typename T, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng) __host__ __device__ uint8_t cast_to_f8(T x, uint32_t rng)
{ {
// check datatype // check datatype
constexpr bool is_half = std::is_same<T, half_t>::value; constexpr bool is_half = std::is_same<T, half_t>::value;
...@@ -233,7 +233,7 @@ __host__ __device__ f8_t cast_to_f8(T x, uint32_t rng) ...@@ -233,7 +233,7 @@ __host__ __device__ f8_t cast_to_f8(T x, uint32_t rng)
} }
template <typename T, bool negative_zero_nan> template <typename T, bool negative_zero_nan>
__host__ __device__ T cast_from_f8(f8_t x) __host__ __device__ T cast_from_f8(uint8_t x)
{ {
// check datatype // check datatype
constexpr bool is_half = std::is_same<T, half_t>::value; constexpr bool is_half = std::is_same<T, half_t>::value;
...@@ -248,3 +248,9 @@ __host__ __device__ T cast_from_f8(f8_t x) ...@@ -248,3 +248,9 @@ __host__ __device__ T cast_from_f8(f8_t x)
} }
} // namespace ck::utils } // namespace ck::utils
// f8_t constuctor impl
inline __host__ __device__ ck::f8_t::f8_t(uint8_t init)
{
data = init;
}
...@@ -106,8 +106,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x) ...@@ -106,8 +106,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0; constexpr uint32_t rng = 0;
return utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( return f8_t(utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng); x, rng));
} }
// convert fp8 to fp32 // convert fp8 to fp32
...@@ -115,7 +115,7 @@ template <> ...@@ -115,7 +115,7 @@ template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x) inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
{ {
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<float, negative_zero_nan>(x); return utils::cast_from_f8<float, negative_zero_nan>(x.data);
} }
// convert fp16 to fp8 // convert fp16 to fp8
...@@ -126,8 +126,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x) ...@@ -126,8 +126,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0; constexpr uint32_t rng = 0;
return utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( return f8_t(utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng); x, rng));
} }
// convert fp8 to fp16 // convert fp8 to fp16
...@@ -135,7 +135,7 @@ template <> ...@@ -135,7 +135,7 @@ template <>
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x) inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
{ {
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<half_t, negative_zero_nan>(x); return utils::cast_from_f8<half_t, negative_zero_nan>(x.data);
} }
// Declare a template function for bf16 conversion using RTN // Declare a template function for bf16 conversion using RTN
...@@ -209,8 +209,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x) ...@@ -209,8 +209,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr int seed = 42; constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation // as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( return f8_t(utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng); x, rng));
} }
// convert fp16 to fp8 with stochastic rounding // convert fp16 to fp8 with stochastic rounding
...@@ -223,8 +223,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x) ...@@ -223,8 +223,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr int seed = 42; constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation // as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( return f8_t(utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng); x, rng));
} }
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment