Commit 803b9db8 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Refactor f8_t to add bf8_t

parent 1cf50031
......@@ -12,7 +12,25 @@ using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using int4_t = _BitInt(4);
#endif
using f8_t = uint8_t;
// One-byte wrapper type for an fp8 value; the raw bit pattern lives in `data`.
// Being a distinct struct (rather than an alias of uint8_t) lets templates
// tell f8_t apart from plain uint8_t.
struct f8_t
{
// Raw 8-bit storage for the value.
uint8_t data;
// Self-referential alias for the wrapper type.
using type = f8_t;
// Underlying storage type; vector_type uses f8_t::data_type as the element
// type of its ext_vector_type typedefs when T is f8_t.
using data_type = uint8_t;
// Defaulted default constructor: `data` is left uninitialized unless the
// object is value-initialized.
__host__ __device__ f8_t() = default;
// Non-explicit constructor from a raw byte; only declared here, the
// definition is provided out of line.
__host__ __device__ f8_t(uint8_t init);
};
// Second one-byte wrapper type with the same layout as f8_t (a single
// uint8_t member). Unlike f8_t it declares no constructor taking a uint8_t.
// NOTE(review): the vector_type specializations visible alongside this struct
// only special-case f8_t when choosing their element type - confirm whether
// vector_type<bf8_t, N> is expected to work.
struct bf8_t
{
// Raw 8-bit storage for the value.
uint8_t data;
// Self-referential alias for the wrapper type.
using type = bf8_t;
// Underlying storage type.
using data_type = uint8_t;
// Defaulted default constructor: `data` is left uninitialized unless the
// object is value-initialized.
__host__ __device__ bf8_t() = default;
};
// vector_type
template <typename T, index_t N>
......@@ -187,8 +205,10 @@ struct vector_type<T, 1>
template <typename T>
struct vector_type<T, 2>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
using T_adjusted = typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
using d1_t = T_adjusted;
typedef T_adjusted d2_t __attribute__((ext_vector_type(2)));
using type = d2_t;
......@@ -237,9 +257,11 @@ struct vector_type<T, 2>
template <typename T>
struct vector_type<T, 4>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
using T_adjusted = typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
using d1_t = T_adjusted;
typedef T_adjusted d2_t __attribute__((ext_vector_type(2)));
typedef T_adjusted d4_t __attribute__((ext_vector_type(4)));
using type = d4_t;
......@@ -299,10 +321,12 @@ struct vector_type<T, 4>
template <typename T>
struct vector_type<T, 8>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
using T_adjusted = typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
using d1_t = T_adjusted;
typedef T_adjusted d2_t __attribute__((ext_vector_type(2)));
typedef T_adjusted d4_t __attribute__((ext_vector_type(4)));
typedef T_adjusted d8_t __attribute__((ext_vector_type(8)));
using type = d8_t;
......@@ -373,11 +397,13 @@ struct vector_type<T, 8>
template <typename T>
struct vector_type<T, 16>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
using T_adjusted = typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
using d1_t = T_adjusted;
typedef T_adjusted d2_t __attribute__((ext_vector_type(2)));
typedef T_adjusted d4_t __attribute__((ext_vector_type(4)));
typedef T_adjusted d8_t __attribute__((ext_vector_type(8)));
typedef T_adjusted d16_t __attribute__((ext_vector_type(16)));
using type = d16_t;
......@@ -459,12 +485,14 @@ struct vector_type<T, 16>
template <typename T>
struct vector_type<T, 32>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
using T_adjusted = typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
using d1_t = T_adjusted;
typedef T_adjusted d2_t __attribute__((ext_vector_type(2)));
typedef T_adjusted d4_t __attribute__((ext_vector_type(4)));
typedef T_adjusted d8_t __attribute__((ext_vector_type(8)));
typedef T_adjusted d16_t __attribute__((ext_vector_type(16)));
typedef T_adjusted d32_t __attribute__((ext_vector_type(32)));
using type = d32_t;
......@@ -555,13 +583,15 @@ struct vector_type<T, 32>
template <typename T>
struct vector_type<T, 64>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
typedef T d64_t __attribute__((ext_vector_type(64)));
using T_adjusted = typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
using d1_t = T_adjusted;
typedef T_adjusted d2_t __attribute__((ext_vector_type(2)));
typedef T_adjusted d4_t __attribute__((ext_vector_type(4)));
typedef T_adjusted d8_t __attribute__((ext_vector_type(8)));
typedef T_adjusted d16_t __attribute__((ext_vector_type(16)));
typedef T_adjusted d32_t __attribute__((ext_vector_type(32)));
typedef T_adjusted d64_t __attribute__((ext_vector_type(64)));
using type = d64_t;
......@@ -663,14 +693,16 @@ struct vector_type<T, 64>
template <typename T>
struct vector_type<T, 128>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
typedef T d64_t __attribute__((ext_vector_type(64)));
typedef T d128_t __attribute__((ext_vector_type(128)));
using T_adjusted = typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
using d1_t = T_adjusted;
typedef T_adjusted d2_t __attribute__((ext_vector_type(2)));
typedef T_adjusted d4_t __attribute__((ext_vector_type(4)));
typedef T_adjusted d8_t __attribute__((ext_vector_type(8)));
typedef T_adjusted d16_t __attribute__((ext_vector_type(16)));
typedef T_adjusted d32_t __attribute__((ext_vector_type(32)));
typedef T_adjusted d64_t __attribute__((ext_vector_type(64)));
typedef T_adjusted d128_t __attribute__((ext_vector_type(128)));
using type = d128_t;
......@@ -781,15 +813,17 @@ struct vector_type<T, 128>
template <typename T>
struct vector_type<T, 256>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
typedef T d64_t __attribute__((ext_vector_type(64)));
typedef T d128_t __attribute__((ext_vector_type(128)));
typedef T d256_t __attribute__((ext_vector_type(256)));
using T_adjusted = typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
using d1_t = T_adjusted;
typedef T_adjusted d2_t __attribute__((ext_vector_type(2)));
typedef T_adjusted d4_t __attribute__((ext_vector_type(4)));
typedef T_adjusted d8_t __attribute__((ext_vector_type(8)));
typedef T_adjusted d16_t __attribute__((ext_vector_type(16)));
typedef T_adjusted d32_t __attribute__((ext_vector_type(32)));
typedef T_adjusted d64_t __attribute__((ext_vector_type(64)));
typedef T_adjusted d128_t __attribute__((ext_vector_type(128)));
typedef T_adjusted d256_t __attribute__((ext_vector_type(256)));
using type = d256_t;
......@@ -1013,14 +1047,34 @@ struct NumericLimits<f8_t>
static constexpr uint8_t binary_max = 0x77; // 0b01110111
static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
// Returns an f8_t whose raw byte is binary_min.
// The object is default-constructed and its `data` member written directly,
// so the f8_t(uint8_t) constructor is not called here.
__host__ __device__ static f8_t Min()
{
f8_t x;
x.data=binary_min;
return x;
}
__host__ __device__ static constexpr f8_t Min() { return bit_cast<f8_t>(binary_min); }
__host__ __device__ static constexpr f8_t Max() { return bit_cast<f8_t>(binary_max); }
// Returns an f8_t whose raw byte is binary_max (0x77).
// The object is default-constructed and its `data` member written directly,
// so the f8_t(uint8_t) constructor is not called here.
__host__ __device__ static f8_t Max()
{
f8_t x;
x.data=binary_max;
return x;
}
__host__ __device__ static constexpr f8_t Lowest() { return bit_cast<f8_t>(binary_lowest); }
// Returns an f8_t whose raw byte is binary_lowest (0xF7).
// The object is default-constructed and its `data` member written directly,
// so the f8_t(uint8_t) constructor is not called here.
__host__ __device__ static f8_t Lowest()
{
f8_t x;
x.data=binary_lowest;
return x;
}
__host__ __device__ static constexpr f8_t QuietNaN() { return bit_cast<f8_t>(binary_qnan); }
// Returns an f8_t whose raw byte is binary_qnan (0x80).
// The object is default-constructed and its `data` member written directly,
// so the f8_t(uint8_t) constructor is not called here.
__host__ __device__ static f8_t QuietNaN()
{
f8_t x;
x.data=binary_qnan;
return x;
}
};
} // namespace ck
......@@ -23,7 +23,7 @@ namespace ck::utils {
namespace {
template <typename T, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
__host__ __device__ uint8_t run_cast_to_f8(T x, uint32_t rng)
{
// check data type
constexpr bool is_half = std::is_same<T, half_t>::value;
......@@ -133,7 +133,7 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
}
template <typename T, bool negative_zero_nan>
__host__ __device__ T run_cast_from_f8(f8_t x)
__host__ __device__ T run_cast_from_f8(uint8_t x)
{
// check data type
constexpr bool is_half = std::is_same<T, half_t>::value;
......@@ -222,7 +222,7 @@ __host__ __device__ T run_cast_from_f8(f8_t x)
} // namespace
template <typename T, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng)
__host__ __device__ uint8_t cast_to_f8(T x, uint32_t rng)
{
// check datatype
constexpr bool is_half = std::is_same<T, half_t>::value;
......@@ -233,7 +233,7 @@ __host__ __device__ f8_t cast_to_f8(T x, uint32_t rng)
}
template <typename T, bool negative_zero_nan>
__host__ __device__ T cast_from_f8(f8_t x)
__host__ __device__ T cast_from_f8(uint8_t x)
{
// check datatype
constexpr bool is_half = std::is_same<T, half_t>::value;
......@@ -248,3 +248,9 @@ __host__ __device__ T cast_from_f8(f8_t x)
}
} // namespace ck::utils
// Out-of-line definition of the f8_t constructor taking a raw byte:
// the byte is stored unchanged in `data`.
inline __host__ __device__ ck::f8_t::f8_t(uint8_t init) : data(init) {}
......@@ -106,8 +106,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return f8_t(utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng));
}
// convert fp8 to fp32
......@@ -115,7 +115,7 @@ template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
{
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<float, negative_zero_nan>(x);
return utils::cast_from_f8<float, negative_zero_nan>(x.data);
}
// convert fp16 to fp8
......@@ -126,8 +126,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return f8_t(utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng));
}
// convert fp8 to fp16
......@@ -135,7 +135,7 @@ template <>
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
{
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<half_t, negative_zero_nan>(x);
return utils::cast_from_f8<half_t, negative_zero_nan>(x.data);
}
// Declare a template function for bf16 conversion using RTN
......@@ -209,8 +209,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return f8_t(utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng));
}
// convert fp16 to fp8 with stochastic rounding
......@@ -223,8 +223,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return f8_t(utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng));
}
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment