Commit 3a64757f authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Add vector support

parent d44b24d1
......@@ -26,6 +26,7 @@ struct f4x2_pk_t
template <index_t I>
__host__ __device__ inline type unpack() const
{
static_assert(I < 2, "Index is out of range.");
if constexpr(I == 0)
return data & 0b00001111;
else
......@@ -38,6 +39,126 @@ struct f4x2_pk_t
}
};
struct f6x16_pk_t
{
// store 16 elements of f6_t in an array of 3 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 3>;
type data;
f6x16_pk_t() : data{type{}} {}
f6x16_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline f6_t unpack()
{
static_assert(I < 16, "Index is out of range.");
union
{
StaticallyIndexedArray_v2<element_type, 3> uint32_array;
f6_t f6_array[16];
} data_union{data};
return data_union.f6_array[I];
}
__host__ __device__ inline type pack(f6_t* x)
{
type* retval = reinterpret_cast<type*>(x);
return *retval;
}
};
struct f6x32_pk_t
{
// store 16 elements of f6_t in an array of 6 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 6>;
type data;
f6x32_pk_t() : data{type{}} {}
f6x32_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline f6_t unpack()
{
static_assert(I < 32, "Index is out of range.");
union
{
StaticallyIndexedArray_v2<element_type, 6> uint32_array;
f6_t f6_array[32];
} data_union{data};
return data_union.f6_array[I];
}
__host__ __device__ inline type pack(f6_t* x)
{
type* retval = reinterpret_cast<type*>(x);
return *retval;
}
};
struct bf6x16_pk_t
{
// store 16 elements of bf6_t in an array of 3 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 3>;
type data;
bf6x16_pk_t() : data{type{}} {}
bf6x16_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline bf6_t unpack()
{
static_assert(I < 16, "Index is out of range.");
union
{
StaticallyIndexedArray_v2<element_type, 3> uint32_array;
bf6_t bf6_array[16];
} data_union{data};
return data_union.bf6_array[I];
}
__host__ __device__ inline type pack(bf6_t* x)
{
type* retval = reinterpret_cast<type*>(x);
return *retval;
}
};
struct bf6x32_pk_t
{
// store 16 elements of bf6_t in an array of 6 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 6>;
type data;
bf6x32_pk_t() : data{type{}} {}
bf6x32_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline bf6_t unpack()
{
static_assert(I < 32, "Index is out of range.");
union
{
StaticallyIndexedArray_v2<element_type, 6> uint32_array;
bf6_t bf6_array[32];
} data_union{data};
return data_union.bf6_array[I];
}
__host__ __device__ inline type pack(bf6_t* x)
{
type* retval = reinterpret_cast<type*>(x);
return *retval;
}
};
inline constexpr auto next_pow2(uint32_t x)
{
// Precondition: x > 1.
......@@ -45,7 +166,7 @@ inline constexpr auto next_pow2(uint32_t x)
}
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
// native types: bool
// native types: bool, f4_t, f6_t, bf6_t
template <typename T>
inline constexpr bool is_native_type()
{
......@@ -1065,12 +1186,37 @@ struct nnvb_data_t_selector<f8_ocp_t>
{
using type = f8_ocp_t::data_type;
};
template <>
struct nnvb_data_t_selector<bf8_ocp_t>
{
using type = bf8_ocp_t::data_type;
};
template <>
struct nnvb_data_t_selector<f6x16_pk_t>
{
using type = f6x16_pk_t::type;
};
template <>
struct nnvb_data_t_selector<f6x32_pk_t>
{
using type = f6x32_pk_t::type;
};
template <>
struct nnvb_data_t_selector<bf6x16_pk_t>
{
using type = bf6x16_pk_t::type;
};
template <>
struct nnvb_data_t_selector<bf6x32_pk_t>
{
using type = bf6x32_pk_t::type;
};
template <typename T, index_t N>
struct non_native_vector_base<
T,
......@@ -1171,6 +1317,111 @@ struct non_native_vector_base<
}
};
// implementation for f6x16 and f6x32
template <typename T, index_t N>
struct non_native_vector_base<T, N, std::enable_if_t<sizeof(T) == 12 || sizeof(T) == 24>>
{
using data_t =
typename nnvb_data_t_selector<T>::type; // select data_t based on declared base type
using element_t = typename T::element_type; // select element_t based on declared element type
static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch");
static constexpr size_t size_factor =
sizeof(data_t) / sizeof(element_t); // f6x16: 12/4 = 3, f6x32: 24/4 = 6
using data_v = element_t __attribute__((ext_vector_type(N * size_factor)));
using type = non_native_vector_base<T, N>;
union alignas(next_pow2(N * sizeof(T)))
{
data_v dN; // storage vector;
StaticallyIndexedArray<data_t, N> dxN;
StaticallyIndexedArray<T, N> dTxN;
StaticallyIndexedArray<data_v, 1> dNx1;
} data_;
__host__ __device__ constexpr non_native_vector_base(data_t a)
: data_{data_v(a.At(Number<0>{}))}
{
}
__host__ __device__ constexpr non_native_vector_base(T f)
: non_native_vector_base(bit_cast<data_t>(f))
{
}
__host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){};
__host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {}
__host__ __device__ constexpr operator data_v() const { return data_.dN; }
__host__ __device__ constexpr operator data_t() const
{
if constexpr(N == 1)
{
return data_.dxN[Number<0>{}];
}
else
{
return data_.dxN; // XXX this should cause an error
}
}
__host__ __device__ constexpr operator T() const
{
if constexpr(N == 1)
{
return data_.dTxN[Number<0>{}];
}
else
{
return data_.dTxN; // XXX this should cause an error
}
}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same_v<X, data_t> || is_same_v<X, T> || is_same_v<X, data_v>,
"Something went wrong, please check src and dst types.");
if constexpr(is_same_v<X, data_t>)
{
return data_.dxN;
}
else if constexpr(is_same_v<X, T>)
{
return data_.dTxN;
}
else if constexpr(is_same_v<X, data_v>)
{
return data_.dNx1;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same_v<X, data_t> || is_same_v<X, T> || is_same_v<X, data_v>,
"Something went wrong, please check src and dst types.");
if constexpr(is_same_v<X, data_t>)
{
return data_.dxN;
}
else if constexpr(is_same_v<X, T>)
{
return data_.dTxN;
}
else if constexpr(is_same_v<X, data_v>)
{
return data_.dNx1;
}
else
{
return err;
}
}
};
template <typename T, index_t N>
struct scalar_type<non_native_vector_base<T, N>>;
......@@ -1906,6 +2157,10 @@ using f4x16_t = typename vector_type<f4x2_pk_t, 8>::type;
using f4x32_t = typename vector_type<f4x2_pk_t, 16>::type;
using f4x64_t = typename vector_type<f4x2_pk_t, 32>::type;
// f6
using f6x16_t = typename vector_type<f6x16_pk_t, 1>::type;
using f6x32_t = typename vector_type<f6x32_pk_t, 1>::type;
template <typename T>
struct NumericLimits
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment