Commit 61e09a18 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Add non_native_vector_type

parent ac58cc5d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -13,6 +13,12 @@ using int4_t = _BitInt(4);
using f8_t = _BitInt(8);
using bf8_t = unsigned _BitInt(8);
inline constexpr auto next_pow2(uint32_t x)
{
// Precondition: x > 1.
return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x;
}
// vector_type
template <typename T, index_t N>
struct vector_type;
......@@ -990,6 +996,194 @@ struct vector_type<T, 256>
}
};
template <typename T, index_t N>
struct non_native_vector_base
{
using BoolVecT = non_native_vector_base<bool, N>;
using VecT = non_native_vector_base<T, N>;
__host__ __device__ non_native_vector_base() = default;
__host__ __device__ non_native_vector_base(const VecT&) = default;
__host__ __device__ non_native_vector_base(VecT&&) = default;
__host__ __device__ ~non_native_vector_base() = default;
T d[N];
};
// non-native vector_type
template <typename T, index_t N>
struct non_native_vector_type;
template <typename T>
struct non_native_vector_type<T, 1>
{
using Native_vec_ = non_native_vector_base<T, 1>;
using d1_t = Native_vec_;
using type = d1_t;
union alignas(next_pow2(1 * sizeof(T)))
{
d1_t d1_;
StaticallyIndexedArray<d1_t, 1> d1x1_;
} data_;
__host__ __device__ constexpr non_native_vector_type() : data_{type{0}} {}
__host__ __device__ constexpr non_native_vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value, "wrong!");
return data_.d1x1_;
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value, "wrong!");
return data_.d1x1_;
}
};
template <typename T>
struct non_native_vector_type<T, 2>
{
using Native_vec_ = non_native_vector_base<T, 2>;
using d1_t = T;
using d2_t = Native_vec_;
using type = d2_t;
union alignas(next_pow2(2 * sizeof(T)))
{
d2_t d2_;
StaticallyIndexedArray<d1_t, 2> d1x2_;
StaticallyIndexedArray<d2_t, 1> d2x1_;
} data_;
__host__ __device__ constexpr non_native_vector_type() : data_{type{0}} {}
__host__ __device__ constexpr non_native_vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value, "wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x2_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value, "wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x2_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct non_native_vector_type<T, 4>
{
using Native_vec_ = non_native_vector_base<T, 4>;
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = Native_vec_;
using type = d4_t;
union alignas(next_pow2(4 * sizeof(T)))
{
d4_t d4_;
StaticallyIndexedArray<d1_t, 4> d1x4_;
StaticallyIndexedArray<d2_t, 2> d2x2_;
StaticallyIndexedArray<d4_t, 1> d4x1_;
} data_;
__host__ __device__ constexpr non_native_vector_type() : data_{type{0}} {}
__host__ __device__ constexpr non_native_vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x4_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x2_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x4_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x2_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x1_;
}
else
{
return err;
}
}
};
using int64_t = long;
// fp64
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment