Commit 3d51e246 authored by Rostyslav Geyyer
Browse files

Update vector_type implementation

parent a92772bf
...@@ -32,10 +32,22 @@ struct bf8_t ...@@ -32,10 +32,22 @@ struct bf8_t
__host__ __device__ bf8_t() = default; __host__ __device__ bf8_t() = default;
}; };
// Compile-time predicate: yields true when T is exactly one of
// half_t, bhalf_t, float, double, int32_t or int8_t, and false for any other type.
template <typename T>
inline __host__ __device__ constexpr auto is_native()
{
    // Floating-point element types from the list.
    constexpr bool matches_floating = std::is_same<T, half_t>::value ||
                                      std::is_same<T, float>::value ||
                                      std::is_same<T, double>::value ||
                                      std::is_same<T, bhalf_t>::value;

    // Integer element types from the list.
    constexpr bool matches_integer =
        std::is_same<T, int32_t>::value || std::is_same<T, int8_t>::value;

    return matches_floating || matches_integer;
}
// vector_type // vector_type
template <typename T, index_t N> template <typename T, index_t N>
struct vector_type; struct vector_type;
// // non_native_vector_type
// template <typename T, index_t N>
// struct non_native_vector_type;
// Caution: DO NOT REMOVE // Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to // intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of // instantiate this template. The purpose is to catch user's mistake when trying to make "vector of
...@@ -168,6 +180,13 @@ struct scalar_type<f8_t> ...@@ -168,6 +180,13 @@ struct scalar_type<f8_t>
static constexpr index_t vector_size = 1; static constexpr index_t vector_size = 1;
}; };
// Utility function for non-native vector types: rounds x up to the nearest power of two.
//
// Returns:
//   - x itself when x <= 1 (0 and 1 are passed through unchanged),
//   - the smallest power of two that is >= x when 1 < x <= 2^31,
//   - 0 when x > 2^31, because the next power of two (2^32) is not representable
//     in 32 bits. Previously this case shifted by the full type width, which is
//     undefined behaviour.
inline constexpr auto next_pow2(uint32_t x)
{
    // __builtin_clz(0) is undefined, so values 0 and 1 must never reach the
    // (x - 1) computation below.
    if(x <= 1u)
    {
        return x;
    }

    // Guard against a shift count of 32 (x - 1 would have its top bit set).
    if(x > 0x80000000u)
    {
        return uint32_t{0};
    }

    // Number of significant bits in (x - 1); 1 << that many bits is the first
    // power of two that is not smaller than x.
    const uint32_t shift = 32u - static_cast<uint32_t>(__builtin_clz(x - 1u));
    return static_cast<uint32_t>(uint32_t{1} << shift);
}
// //
template <typename T> template <typename T>
struct vector_type<T, 1> struct vector_type<T, 1>
...@@ -205,8 +224,11 @@ struct vector_type<T, 1> ...@@ -205,8 +224,11 @@ struct vector_type<T, 1>
template <typename T> template <typename T>
struct vector_type<T, 2> struct vector_type<T, 2>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); using ext_vect_t = conditional_t<is_native<T>(), T, uint32_t>;
using d2_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(2))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(2))))>;
using type = d2_t; using type = d2_t;
...@@ -255,9 +277,14 @@ struct vector_type<T, 2> ...@@ -255,9 +277,14 @@ struct vector_type<T, 2>
template <typename T> template <typename T>
struct vector_type<T, 4> struct vector_type<T, 4>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); using ext_vect_t = conditional_t<is_native<T>(), T, uint32_t>;
typedef T d4_t __attribute__((ext_vector_type(4))); using d2_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(2))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(2))))>;
using d4_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(4))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(4))))>;
using type = d4_t; using type = d4_t;
...@@ -317,10 +344,17 @@ struct vector_type<T, 4> ...@@ -317,10 +344,17 @@ struct vector_type<T, 4>
template <typename T> template <typename T>
struct vector_type<T, 8> struct vector_type<T, 8>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); using ext_vect_t = conditional_t<is_native<T>(), T, uint32_t>;
typedef T d4_t __attribute__((ext_vector_type(4))); using d2_t = conditional_t<is_native<T>(),
typedef T d8_t __attribute__((ext_vector_type(8))); ext_vect_t __attribute__((ext_vector_type(2))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(2))))>;
using d4_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(4))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(4))))>;
using d8_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(8))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(8))))>;
using type = d8_t; using type = d8_t;
...@@ -391,11 +425,20 @@ struct vector_type<T, 8> ...@@ -391,11 +425,20 @@ struct vector_type<T, 8>
template <typename T> template <typename T>
struct vector_type<T, 16> struct vector_type<T, 16>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); using ext_vect_t = conditional_t<is_native<T>(), T, uint32_t>;
typedef T d4_t __attribute__((ext_vector_type(4))); using d2_t = conditional_t<is_native<T>(),
typedef T d8_t __attribute__((ext_vector_type(8))); ext_vect_t __attribute__((ext_vector_type(2))),
typedef T d16_t __attribute__((ext_vector_type(16))); ext_vect_t __attribute__((ext_vector_type(next_pow2(2))))>;
using d4_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(4))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(4))))>;
using d8_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(8))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(8))))>;
using d16_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(16))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(16))))>;
using type = d16_t; using type = d16_t;
...@@ -477,12 +520,23 @@ struct vector_type<T, 16> ...@@ -477,12 +520,23 @@ struct vector_type<T, 16>
template <typename T> template <typename T>
struct vector_type<T, 32> struct vector_type<T, 32>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); using ext_vect_t = conditional_t<is_native<T>(), T, uint32_t>;
typedef T d4_t __attribute__((ext_vector_type(4))); using d2_t = conditional_t<is_native<T>(),
typedef T d8_t __attribute__((ext_vector_type(8))); ext_vect_t __attribute__((ext_vector_type(2))),
typedef T d16_t __attribute__((ext_vector_type(16))); ext_vect_t __attribute__((ext_vector_type(next_pow2(2))))>;
typedef T d32_t __attribute__((ext_vector_type(32))); using d4_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(4))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(4))))>;
using d8_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(8))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(8))))>;
using d16_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(16))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(16))))>;
using d32_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(32))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(32))))>;
using type = d32_t; using type = d32_t;
...@@ -573,13 +627,26 @@ struct vector_type<T, 32> ...@@ -573,13 +627,26 @@ struct vector_type<T, 32>
template <typename T> template <typename T>
struct vector_type<T, 64> struct vector_type<T, 64>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); using ext_vect_t = conditional_t<is_native<T>(), T, uint32_t>;
typedef T d4_t __attribute__((ext_vector_type(4))); using d2_t = conditional_t<is_native<T>(),
typedef T d8_t __attribute__((ext_vector_type(8))); ext_vect_t __attribute__((ext_vector_type(2))),
typedef T d16_t __attribute__((ext_vector_type(16))); ext_vect_t __attribute__((ext_vector_type(next_pow2(2))))>;
typedef T d32_t __attribute__((ext_vector_type(32))); using d4_t = conditional_t<is_native<T>(),
typedef T d64_t __attribute__((ext_vector_type(64))); ext_vect_t __attribute__((ext_vector_type(4))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(4))))>;
using d8_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(8))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(8))))>;
using d16_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(16))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(16))))>;
using d32_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(32))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(32))))>;
using d64_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(64))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(64))))>;
using type = d64_t; using type = d64_t;
...@@ -681,14 +748,29 @@ struct vector_type<T, 64> ...@@ -681,14 +748,29 @@ struct vector_type<T, 64>
template <typename T> template <typename T>
struct vector_type<T, 128> struct vector_type<T, 128>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); using ext_vect_t = conditional_t<is_native<T>(), T, uint32_t>;
typedef T d4_t __attribute__((ext_vector_type(4))); using d2_t = conditional_t<is_native<T>(),
typedef T d8_t __attribute__((ext_vector_type(8))); ext_vect_t __attribute__((ext_vector_type(2))),
typedef T d16_t __attribute__((ext_vector_type(16))); ext_vect_t __attribute__((ext_vector_type(next_pow2(2))))>;
typedef T d32_t __attribute__((ext_vector_type(32))); using d4_t = conditional_t<is_native<T>(),
typedef T d64_t __attribute__((ext_vector_type(64))); ext_vect_t __attribute__((ext_vector_type(4))),
typedef T d128_t __attribute__((ext_vector_type(128))); ext_vect_t __attribute__((ext_vector_type(next_pow2(4))))>;
using d8_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(8))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(8))))>;
using d16_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(16))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(16))))>;
using d32_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(32))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(32))))>;
using d64_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(64))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(64))))>;
using d128_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(128))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(128))))>;
using type = d128_t; using type = d128_t;
...@@ -799,15 +881,32 @@ struct vector_type<T, 128> ...@@ -799,15 +881,32 @@ struct vector_type<T, 128>
template <typename T> template <typename T>
struct vector_type<T, 256> struct vector_type<T, 256>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); using ext_vect_t = conditional_t<is_native<T>(), T, uint32_t>;
typedef T d4_t __attribute__((ext_vector_type(4))); using d2_t = conditional_t<is_native<T>(),
typedef T d8_t __attribute__((ext_vector_type(8))); ext_vect_t __attribute__((ext_vector_type(2))),
typedef T d16_t __attribute__((ext_vector_type(16))); ext_vect_t __attribute__((ext_vector_type(next_pow2(2))))>;
typedef T d32_t __attribute__((ext_vector_type(32))); using d4_t = conditional_t<is_native<T>(),
typedef T d64_t __attribute__((ext_vector_type(64))); ext_vect_t __attribute__((ext_vector_type(4))),
typedef T d128_t __attribute__((ext_vector_type(128))); ext_vect_t __attribute__((ext_vector_type(next_pow2(4))))>;
typedef T d256_t __attribute__((ext_vector_type(256))); using d8_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(8))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(8))))>;
using d16_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(16))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(16))))>;
using d32_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(32))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(32))))>;
using d64_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(64))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(64))))>;
using d128_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(128))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(128))))>;
using d256_t = conditional_t<is_native<T>(),
ext_vect_t __attribute__((ext_vector_type(256))),
ext_vect_t __attribute__((ext_vector_type(next_pow2(256))))>;
using type = d256_t; using type = d256_t;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment