Commit 3d51e246 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Update vector_type implementation

parent a92772bf
...@@ -32,10 +32,22 @@ struct bf8_t ...@@ -32,10 +32,22 @@ struct bf8_t
__host__ __device__ bf8_t() = default; __host__ __device__ bf8_t() = default;
}; };
template <typename T>
inline __host__ __device__ constexpr auto is_native()
{
return std::is_same<T, half_t>::value || std::is_same<T, float>::value ||
std::is_same<T, double>::value || std::is_same<T, bhalf_t>::value ||
std::is_same<T, int32_t>::value || std::is_same<T, int8_t>::value;
}
// vector_type // vector_type
template <typename T, index_t N> template <typename T, index_t N>
struct vector_type; struct vector_type;
// // non_native_vector_type
// template <typename T, index_t N>
// struct non_native_vector_type;
// Caution: DO NOT REMOVE // Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to // intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of // instantiate this template. The purpose is to catch user's mistake when trying to make "vector of
...@@ -168,6 +180,13 @@ struct scalar_type<f8_t> ...@@ -168,6 +180,13 @@ struct scalar_type<f8_t>
static constexpr index_t vector_size = 1; static constexpr index_t vector_size = 1;
}; };
// utility function for non native vector type
inline constexpr auto next_pow2(uint32_t x)
{
// Precondition: x > 1.
return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x;
}
// //
template <typename T> template <typename T>
struct vector_type<T, 1> struct vector_type<T, 1>
...@@ -206,7 +225,10 @@ template <typename T> ...@@ -206,7 +225,10 @@ template <typename T>
struct vector_type<T, 2> struct vector_type<T, 2>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); using ext_vect_t = conditional_t<is_native<T>(), T, uint32_t>;
using d2_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(2))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(2))))>;
using type = d2_t; using type = d2_t;
...@@ -256,8 +278,13 @@ template <typename T> ...@@ -256,8 +278,13 @@ template <typename T>
struct vector_type<T, 4> struct vector_type<T, 4>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); using ext_vect_t = conditional_t<is_native<T>(), T, uint32_t>;
typedef T d4_t __attribute__((ext_vector_type(4))); using d2_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(2))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(2))))>;
using d4_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(4))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(4))))>;
using type = d4_t; using type = d4_t;
...@@ -318,9 +345,16 @@ template <typename T> ...@@ -318,9 +345,16 @@ template <typename T>
struct vector_type<T, 8> struct vector_type<T, 8>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); using ext_vect_t = conditional_t<is_native<T>(), T, uint32_t>;
typedef T d4_t __attribute__((ext_vector_type(4))); using d2_t = conditional_t<is_native<T>(),
typedef T d8_t __attribute__((ext_vector_type(8))); ext_vect_t __attribute__((ext_vector_type(2))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(2))))>;
using d4_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(4))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(4))))>;
using d8_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(8))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(8))))>;
using type = d8_t; using type = d8_t;
...@@ -392,10 +426,19 @@ template <typename T> ...@@ -392,10 +426,19 @@ template <typename T>
struct vector_type<T, 16> struct vector_type<T, 16>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); using ext_vect_t = conditional_t<is_native<T>(), T, uint32_t>;
typedef T d4_t __attribute__((ext_vector_type(4))); using d2_t = conditional_t<is_native<T>(),
typedef T d8_t __attribute__((ext_vector_type(8))); ext_vect_t __attribute__((ext_vector_type(2))),
typedef T d16_t __attribute__((ext_vector_type(16))); ext_vect_t __attribute__((ext_vector_type(next_pow2(2))))>;
using d4_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(4))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(4))))>;
using d8_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(8))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(8))))>;
using d16_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(16))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(16))))>;
using type = d16_t; using type = d16_t;
...@@ -478,11 +521,22 @@ template <typename T> ...@@ -478,11 +521,22 @@ template <typename T>
struct vector_type<T, 32> struct vector_type<T, 32>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); using ext_vect_t = conditional_t<is_native<T>(), T, uint32_t>;
typedef T d4_t __attribute__((ext_vector_type(4))); using d2_t = conditional_t<is_native<T>(),
typedef T d8_t __attribute__((ext_vector_type(8))); ext_vect_t __attribute__((ext_vector_type(2))),
typedef T d16_t __attribute__((ext_vector_type(16))); ext_vect_t __attribute__((ext_vector_type(next_pow2(2))))>;
typedef T d32_t __attribute__((ext_vector_type(32))); using d4_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(4))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(4))))>;
using d8_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(8))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(8))))>;
using d16_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(16))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(16))))>;
using d32_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(32))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(32))))>;
using type = d32_t; using type = d32_t;
...@@ -574,12 +628,25 @@ template <typename T> ...@@ -574,12 +628,25 @@ template <typename T>
struct vector_type<T, 64> struct vector_type<T, 64>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); using ext_vect_t = conditional_t<is_native<T>(), T, uint32_t>;
typedef T d4_t __attribute__((ext_vector_type(4))); using d2_t = conditional_t<is_native<T>(),
typedef T d8_t __attribute__((ext_vector_type(8))); ext_vect_t __attribute__((ext_vector_type(2))),
typedef T d16_t __attribute__((ext_vector_type(16))); ext_vect_t __attribute__((ext_vector_type(next_pow2(2))))>;
typedef T d32_t __attribute__((ext_vector_type(32))); using d4_t = conditional_t<is_native<T>(),
typedef T d64_t __attribute__((ext_vector_type(64))); ext_vect_t __attribute__((ext_vector_type(4))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(4))))>;
using d8_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(8))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(8))))>;
using d16_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(16))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(16))))>;
using d32_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(32))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(32))))>;
using d64_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(64))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(64))))>;
using type = d64_t; using type = d64_t;
...@@ -682,13 +749,28 @@ template <typename T> ...@@ -682,13 +749,28 @@ template <typename T>
struct vector_type<T, 128> struct vector_type<T, 128>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); using ext_vect_t = conditional_t<is_native<T>(), T, uint32_t>;
typedef T d4_t __attribute__((ext_vector_type(4))); using d2_t = conditional_t<is_native<T>(),
typedef T d8_t __attribute__((ext_vector_type(8))); ext_vect_t __attribute__((ext_vector_type(2))),
typedef T d16_t __attribute__((ext_vector_type(16))); ext_vect_t __attribute__((ext_vector_type(next_pow2(2))))>;
typedef T d32_t __attribute__((ext_vector_type(32))); using d4_t = conditional_t<is_native<T>(),
typedef T d64_t __attribute__((ext_vector_type(64))); ext_vect_t __attribute__((ext_vector_type(4))),
typedef T d128_t __attribute__((ext_vector_type(128))); ext_vect_t __attribute__((ext_vector_type(next_pow2(4))))>;
using d8_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(8))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(8))))>;
using d16_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(16))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(16))))>;
using d32_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(32))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(32))))>;
using d64_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(64))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(64))))>;
using d128_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(128))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(128))))>;
using type = d128_t; using type = d128_t;
...@@ -800,14 +882,31 @@ template <typename T> ...@@ -800,14 +882,31 @@ template <typename T>
struct vector_type<T, 256> struct vector_type<T, 256>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); using ext_vect_t = conditional_t<is_native<T>(), T, uint32_t>;
typedef T d4_t __attribute__((ext_vector_type(4))); using d2_t = conditional_t<is_native<T>(),
typedef T d8_t __attribute__((ext_vector_type(8))); ext_vect_t __attribute__((ext_vector_type(2))),
typedef T d16_t __attribute__((ext_vector_type(16))); ext_vect_t __attribute__((ext_vector_type(next_pow2(2))))>;
typedef T d32_t __attribute__((ext_vector_type(32))); using d4_t = conditional_t<is_native<T>(),
typedef T d64_t __attribute__((ext_vector_type(64))); ext_vect_t __attribute__((ext_vector_type(4))),
typedef T d128_t __attribute__((ext_vector_type(128))); ext_vect_t __attribute__((ext_vector_type(next_pow2(4))))>;
typedef T d256_t __attribute__((ext_vector_type(256))); using d8_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(8))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(8))))>;
using d16_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(16))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(16))))>;
using d32_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(32))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(32))))>;
using d64_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(64))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(64))))>;
using d128_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(128))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(128))))>;
using d256_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(256))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(256))))>;
using type = d256_t; using type = d256_t;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment