Commit 2776c177 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Add bf8, use BitInt types

parent 9967360c
...@@ -12,25 +12,8 @@ using half_t = _Float16; ...@@ -12,25 +12,8 @@ using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using int4_t = _BitInt(4); using int4_t = _BitInt(4);
#endif #endif
using f8_t = _BitInt(8);
struct f8_t using bf8_t = unsigned _BitInt(8);
{
uint8_t data;
using type = f8_t;
using data_type = uint8_t;
__host__ __device__ f8_t() = default;
__host__ __device__ f8_t(uint8_t init);
};
struct bf8_t
{
uint8_t data;
using type = bf8_t;
using data_type = uint8_t;
__host__ __device__ bf8_t() = default;
};
template <typename T> template <typename T>
inline __host__ __device__ constexpr auto is_native() inline __host__ __device__ constexpr auto is_native()
...@@ -44,10 +27,6 @@ inline __host__ __device__ constexpr auto is_native() ...@@ -44,10 +27,6 @@ inline __host__ __device__ constexpr auto is_native()
template <typename T, index_t N> template <typename T, index_t N>
struct vector_type; struct vector_type;
// // non_native_vector_type
// template <typename T, index_t N>
// struct non_native_vector_type;
// Caution: DO NOT REMOVE // Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to // intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of // instantiate this template. The purpose is to catch user's mistake when trying to make "vector of
...@@ -180,14 +159,13 @@ struct scalar_type<f8_t> ...@@ -180,14 +159,13 @@ struct scalar_type<f8_t>
static constexpr index_t vector_size = 1; static constexpr index_t vector_size = 1;
}; };
// utility function for non native vector type template <>
inline constexpr auto next_pow2(uint32_t x) struct scalar_type<bf8_t>
{ {
// Precondition: x > 1. using type = bf8_t;
return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x; static constexpr index_t vector_size = 1;
} };
//
template <typename T> template <typename T>
struct vector_type<T, 1> struct vector_type<T, 1>
{ {
...@@ -225,10 +203,7 @@ template <typename T> ...@@ -225,10 +203,7 @@ template <typename T>
struct vector_type<T, 2> struct vector_type<T, 2>
{ {
using d1_t = T; using d1_t = T;
using ext_vect_t = conditional_t<is_native<T>(), T, uint32_t>; typedef T d2_t __attribute__((ext_vector_type(2)));
using d2_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(2))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(2))))>;
using type = d2_t; using type = d2_t;
...@@ -278,13 +253,8 @@ template <typename T> ...@@ -278,13 +253,8 @@ template <typename T>
struct vector_type<T, 4> struct vector_type<T, 4>
{ {
using d1_t = T; using d1_t = T;
using ext_vect_t = conditional_t<is_native<T>(), T, uint32_t>; typedef T d2_t __attribute__((ext_vector_type(2)));
using d2_t = conditional_t<is_native<T>(), typedef T d4_t __attribute__((ext_vector_type(4)));
ext_vect_t __attribute__((ext_vector_type(2))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(2))))>;
using d4_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(4))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(4))))>;
using type = d4_t; using type = d4_t;
...@@ -345,16 +315,9 @@ template <typename T> ...@@ -345,16 +315,9 @@ template <typename T>
struct vector_type<T, 8> struct vector_type<T, 8>
{ {
using d1_t = T; using d1_t = T;
using ext_vect_t = conditional_t<is_native<T>(), T, uint32_t>; typedef T d2_t __attribute__((ext_vector_type(2)));
using d2_t = conditional_t<is_native<T>(), typedef T d4_t __attribute__((ext_vector_type(4)));
ext_vect_t __attribute__((ext_vector_type(2))), typedef T d8_t __attribute__((ext_vector_type(8)));
ext_vect_t __attribute__((ext_vector_type(next_pow2(2))))>;
using d4_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(4))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(4))))>;
using d8_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(8))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(8))))>;
using type = d8_t; using type = d8_t;
...@@ -426,19 +389,10 @@ template <typename T> ...@@ -426,19 +389,10 @@ template <typename T>
struct vector_type<T, 16> struct vector_type<T, 16>
{ {
using d1_t = T; using d1_t = T;
using ext_vect_t = conditional_t<is_native<T>(), T, uint32_t>; typedef T d2_t __attribute__((ext_vector_type(2)));
using d2_t = conditional_t<is_native<T>(), typedef T d4_t __attribute__((ext_vector_type(4)));
ext_vect_t __attribute__((ext_vector_type(2))), typedef T d8_t __attribute__((ext_vector_type(8)));
ext_vect_t __attribute__((ext_vector_type(next_pow2(2))))>; typedef T d16_t __attribute__((ext_vector_type(16)));
using d4_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(4))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(4))))>;
using d8_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(8))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(8))))>;
using d16_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(16))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(16))))>;
using type = d16_t; using type = d16_t;
...@@ -521,22 +475,11 @@ template <typename T> ...@@ -521,22 +475,11 @@ template <typename T>
struct vector_type<T, 32> struct vector_type<T, 32>
{ {
using d1_t = T; using d1_t = T;
using ext_vect_t = conditional_t<is_native<T>(), T, uint32_t>; typedef T d2_t __attribute__((ext_vector_type(2)));
using d2_t = conditional_t<is_native<T>(), typedef T d4_t __attribute__((ext_vector_type(4)));
ext_vect_t __attribute__((ext_vector_type(2))), typedef T d8_t __attribute__((ext_vector_type(8)));
ext_vect_t __attribute__((ext_vector_type(next_pow2(2))))>; typedef T d16_t __attribute__((ext_vector_type(16)));
using d4_t = conditional_t<is_native<T>(), typedef T d32_t __attribute__((ext_vector_type(32)));
ext_vect_t __attribute__((ext_vector_type(4))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(4))))>;
using d8_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(8))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(8))))>;
using d16_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(16))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(16))))>;
using d32_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(32))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(32))))>;
using type = d32_t; using type = d32_t;
...@@ -628,25 +571,12 @@ template <typename T> ...@@ -628,25 +571,12 @@ template <typename T>
struct vector_type<T, 64> struct vector_type<T, 64>
{ {
using d1_t = T; using d1_t = T;
using ext_vect_t = conditional_t<is_native<T>(), T, uint32_t>; typedef T d2_t __attribute__((ext_vector_type(2)));
using d2_t = conditional_t<is_native<T>(), typedef T d4_t __attribute__((ext_vector_type(4)));
ext_vect_t __attribute__((ext_vector_type(2))), typedef T d8_t __attribute__((ext_vector_type(8)));
ext_vect_t __attribute__((ext_vector_type(next_pow2(2))))>; typedef T d16_t __attribute__((ext_vector_type(16)));
using d4_t = conditional_t<is_native<T>(), typedef T d32_t __attribute__((ext_vector_type(32)));
ext_vect_t __attribute__((ext_vector_type(4))), typedef T d64_t __attribute__((ext_vector_type(64)));
ext_vect_t __attribute__((ext_vector_type(next_pow2(4))))>;
using d8_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(8))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(8))))>;
using d16_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(16))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(16))))>;
using d32_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(32))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(32))))>;
using d64_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(64))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(64))))>;
using type = d64_t; using type = d64_t;
...@@ -749,28 +679,13 @@ template <typename T> ...@@ -749,28 +679,13 @@ template <typename T>
struct vector_type<T, 128> struct vector_type<T, 128>
{ {
using d1_t = T; using d1_t = T;
using ext_vect_t = conditional_t<is_native<T>(), T, uint32_t>; typedef T d2_t __attribute__((ext_vector_type(2)));
using d2_t = conditional_t<is_native<T>(), typedef T d4_t __attribute__((ext_vector_type(4)));
ext_vect_t __attribute__((ext_vector_type(2))), typedef T d8_t __attribute__((ext_vector_type(8)));
ext_vect_t __attribute__((ext_vector_type(next_pow2(2))))>; typedef T d16_t __attribute__((ext_vector_type(16)));
using d4_t = conditional_t<is_native<T>(), typedef T d32_t __attribute__((ext_vector_type(32)));
ext_vect_t __attribute__((ext_vector_type(4))), typedef T d64_t __attribute__((ext_vector_type(64)));
ext_vect_t __attribute__((ext_vector_type(next_pow2(4))))>; typedef T d128_t __attribute__((ext_vector_type(128)));
using d8_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(8))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(8))))>;
using d16_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(16))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(16))))>;
using d32_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(32))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(32))))>;
using d64_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(64))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(64))))>;
using d128_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(128))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(128))))>;
using type = d128_t; using type = d128_t;
...@@ -882,31 +797,14 @@ template <typename T> ...@@ -882,31 +797,14 @@ template <typename T>
struct vector_type<T, 256> struct vector_type<T, 256>
{ {
using d1_t = T; using d1_t = T;
using ext_vect_t = conditional_t<is_native<T>(), T, uint32_t>; typedef T d2_t __attribute__((ext_vector_type(2)));
using d2_t = conditional_t<is_native<T>(), typedef T d4_t __attribute__((ext_vector_type(4)));
ext_vect_t __attribute__((ext_vector_type(2))), typedef T d8_t __attribute__((ext_vector_type(8)));
ext_vect_t __attribute__((ext_vector_type(next_pow2(2))))>; typedef T d16_t __attribute__((ext_vector_type(16)));
using d4_t = conditional_t<is_native<T>(), typedef T d32_t __attribute__((ext_vector_type(32)));
ext_vect_t __attribute__((ext_vector_type(4))), typedef T d64_t __attribute__((ext_vector_type(64)));
ext_vect_t __attribute__((ext_vector_type(next_pow2(4))))>; typedef T d128_t __attribute__((ext_vector_type(128)));
using d8_t = conditional_t<is_native<T>(), typedef T d256_t __attribute__((ext_vector_type(256)));
ext_vect_t __attribute__((ext_vector_type(8))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(8))))>;
using d16_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(16))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(16))))>;
using d32_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(32))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(32))))>;
using d64_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(64))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(64))))>;
using d128_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(128))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(128))))>;
using d256_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(256))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(256))))>;
using type = d256_t; using type = d256_t;
...@@ -1077,6 +975,14 @@ using f8x16_t = typename vector_type<f8_t, 16>::type; ...@@ -1077,6 +975,14 @@ using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type; using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type; using f8x64_t = typename vector_type<f8_t, 64>::type;
// bf8
using bf8x2_t = typename vector_type<bf8_t, 2>::type;
using bf8x4_t = typename vector_type<bf8_t, 4>::type;
using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
template <typename T> template <typename T>
struct NumericLimits struct NumericLimits
{ {
...@@ -1126,38 +1032,47 @@ struct NumericLimits<int4_t> ...@@ -1126,38 +1032,47 @@ struct NumericLimits<int4_t>
template <> template <>
struct NumericLimits<f8_t> struct NumericLimits<f8_t>
{ {
// negative zero nan mode with exp bias = 8
static constexpr uint8_t binary_min = 0x08; // 0b00001000 static constexpr uint8_t binary_min = 0x08; // 0b00001000
static constexpr uint8_t binary_max = 0x77; // 0b01110111 static constexpr uint8_t binary_max = 0x7F; // 0b01111111
static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
// ieee mode with exp bias = 7
// static constexpr uint8_t binary_min = 0x08; // 0b00001000
// static constexpr uint8_t binary_max = 0x77; // 0b01110111
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
__host__ __device__ static f8_t Min() __host__ __device__ static constexpr f8_t Min() { return f8_t(binary_min); }
{
f8_t x;
x.data = binary_min;
return x;
}
__host__ __device__ static f8_t Max() __host__ __device__ static constexpr f8_t Max() { return f8_t(binary_max); }
{
f8_t x;
x.data = binary_max;
return x;
}
__host__ __device__ static f8_t Lowest() __host__ __device__ static constexpr f8_t Lowest() { return f8_t(binary_lowest); }
{
f8_t x;
x.data = binary_lowest;
return x;
}
__host__ __device__ static f8_t QuietNaN() __host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); }
{ };
f8_t x;
x.data = binary_qnan; template <>
return x; struct NumericLimits<bf8_t>
} {
// negative zero nan mode with exp bias = 16
static constexpr uint8_t binary_min = 0x04; // 0b00000100
static constexpr uint8_t binary_max = 0x7F; // 0b01111111
static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
// ieee mode with exp bias = 15
// static constexpr uint8_t binary_min = 0x04; // 0b00000100
// static constexpr uint8_t binary_max = 0x7B; // 0b01111011
// static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=
__host__ __device__ static constexpr bf8_t Min() { return bf8_t(binary_min); }
__host__ __device__ static constexpr bf8_t Max() { return bf8_t(binary_max); }
__host__ __device__ static constexpr bf8_t Lowest() { return bf8_t(binary_lowest); }
__host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
}; };
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment