Commit dda07ed0 authored by Rostyslav Geyyer
Browse files

Use vector_type to cover non-native implementation as well

parent fd019e14
......@@ -322,8 +322,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
b_thread_buf);
static_for<0, KPerThread, KPack>{}([&](auto k) {
non_native_vector_type<ComputeTypeA, KPack> a_thread_vec;
non_native_vector_type<ComputeTypeB, KPack> b_thread_vec;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<ComputeTypeA>()(i) = a_thread_buf
......@@ -333,11 +333,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
});
using mfma_input_type_a =
typename non_native_vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
using mfma_input_type_b =
typename non_native_vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
......@@ -949,8 +947,8 @@ struct BlockwiseGemmXdlops_v2
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
non_native_vector_type<FloatAB, KPack> a_thread_vec;
non_native_vector_type<FloatAB, KPack> b_thread_vec;
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
......@@ -960,7 +958,7 @@ struct BlockwiseGemmXdlops_v2
});
using mfma_input_type =
typename non_native_vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
......
......@@ -375,8 +375,8 @@ struct intrin_mfma_f32_32x32x16f8f8<32, 32>
0,
0);
#else
non_native_vector_type<f8_t, 8> reg_a_v(reg_a);
non_native_vector_type<f8_t, 8> reg_b_v(reg_b);
vector_type<f8_t, 8> reg_a_v(reg_a);
vector_type<f8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
......@@ -406,8 +406,8 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
0,
0);
#else
non_native_vector_type<f8_t, 8> reg_a_v(reg_a);
non_native_vector_type<f8_t, 8> reg_b_v(reg_b);
vector_type<f8_t, 8> reg_a_v(reg_a);
vector_type<f8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
......@@ -438,8 +438,8 @@ struct intrin_mfma_f32_32x32x16bf8bf8<32, 32>
0,
0);
#else
non_native_vector_type<bf8_t, 8> reg_a_v(reg_a);
non_native_vector_type<bf8_t, 8> reg_b_v(reg_b);
vector_type<bf8_t, 8> reg_a_v(reg_a);
vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
......@@ -469,8 +469,8 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
0,
0);
#else
non_native_vector_type<bf8_t, 8> reg_a_v(reg_a);
non_native_vector_type<bf8_t, 8> reg_b_v(reg_b);
vector_type<bf8_t, 8> reg_a_v(reg_a);
vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
......@@ -501,8 +501,8 @@ struct intrin_mfma_f32_32x32x16f8bf8<32, 32>
0,
0);
#else
non_native_vector_type<f8_t, 8> reg_a_v(reg_a);
non_native_vector_type<bf8_t, 8> reg_b_v(reg_b);
vector_type<f8_t, 8> reg_a_v(reg_a);
vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
......@@ -532,8 +532,8 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
0,
0);
#else
non_native_vector_type<f8_t, 8> reg_a_v(reg_a);
non_native_vector_type<bf8_t, 8> reg_b_v(reg_b);
vector_type<f8_t, 8> reg_a_v(reg_a);
vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
......@@ -564,8 +564,8 @@ struct intrin_mfma_f32_32x32x16bf8f8<32, 32>
0,
0);
#else
non_native_vector_type<bf8_t, 8> reg_a_v(reg_a);
non_native_vector_type<f8_t, 8> reg_b_v(reg_b);
vector_type<bf8_t, 8> reg_a_v(reg_a);
vector_type<f8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
......@@ -595,8 +595,8 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
0,
0);
#else
non_native_vector_type<bf8_t, 8> reg_a_v(reg_a);
non_native_vector_type<f8_t, 8> reg_b_v(reg_b);
vector_type<bf8_t, 8> reg_a_v(reg_a);
vector_type<f8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
......
......@@ -19,8 +19,18 @@ inline constexpr auto next_pow2(uint32_t x)
return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x;
}
// Compile-time predicate: reports whether T is one of the element types treated
// as "native" by this header. The accepted types are exactly:
//   double, float, half_t, bhalf_t, int32_t, int8_t, uint8_t,
//   _BitInt(8), unsigned _BitInt(8), bool
// Any other T yields false. The vector_type specializations use this result
// (through std::enable_if_t) to choose between their native and non-native
// implementations.
template <typename T>
inline constexpr bool is_native_type()
{
    // Each flag covers one group of exact type matches; the groups are only a
    // readability aid and carry no meaning beyond the is_same checks themselves.
    constexpr bool is_double_or_float = is_same<T, double>::value || is_same<T, float>::value;
    constexpr bool is_half_or_bhalf = is_same<T, half_t>::value || is_same<T, bhalf_t>::value;
    constexpr bool is_fixed_int = is_same<T, int32_t>::value || is_same<T, int8_t>::value ||
                                  is_same<T, uint8_t>::value;
    constexpr bool is_bitint8 =
        is_same<T, _BitInt(8)>::value || is_same<T, unsigned _BitInt(8)>::value;
    constexpr bool is_bool = is_same<T, bool>::value;

    return is_double_or_float || is_half_or_bhalf || is_fixed_int || is_bitint8 || is_bool;
}
// vector_type
template <typename T, index_t N>
template <typename T, index_t N, typename Enable = void>
struct vector_type;
// Caution: DO NOT REMOVE
......@@ -177,7 +187,7 @@ struct scalar_type<bool>
};
template <typename T>
struct vector_type<T, 1>
struct vector_type<T, 1, typename std::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
using type = d1_t;
......@@ -211,7 +221,7 @@ struct vector_type<T, 1>
int static err = 0;
template <typename T>
struct vector_type<T, 2>
struct vector_type<T, 2, typename std::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
......@@ -269,7 +279,7 @@ struct vector_type<T, 2>
};
template <typename T>
struct vector_type<T, 4>
struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
......@@ -339,7 +349,7 @@ struct vector_type<T, 4>
};
template <typename T>
struct vector_type<T, 8>
struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
......@@ -421,7 +431,7 @@ struct vector_type<T, 8>
};
template <typename T>
struct vector_type<T, 16>
struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
......@@ -515,7 +525,7 @@ struct vector_type<T, 16>
};
template <typename T>
struct vector_type<T, 32>
struct vector_type<T, 32, typename std::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
......@@ -619,7 +629,7 @@ struct vector_type<T, 32>
};
template <typename T>
struct vector_type<T, 64>
struct vector_type<T, 64, typename std::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
......@@ -735,7 +745,7 @@ struct vector_type<T, 64>
};
template <typename T>
struct vector_type<T, 128>
struct vector_type<T, 128, typename std::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
......@@ -861,7 +871,7 @@ struct vector_type<T, 128>
};
template <typename T>
struct vector_type<T, 256>
struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
......@@ -1013,12 +1023,9 @@ struct non_native_vector_base
T d[N];
};
// non-native vector_type
template <typename T, index_t N>
struct non_native_vector_type;
// non-native vector_type implementation
template <typename T>
struct non_native_vector_type<T, 1>
struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
{
using Native_vec_ = non_native_vector_base<T, 1>;
......@@ -1031,9 +1038,9 @@ struct non_native_vector_type<T, 1>
StaticallyIndexedArray<d1_t, 1> d1x1_;
} data_;
__host__ __device__ constexpr non_native_vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr non_native_vector_type(type v) : data_{v} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
......@@ -1053,7 +1060,7 @@ struct non_native_vector_type<T, 1>
};
template <typename T>
struct non_native_vector_type<T, 2>
struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
{
using Native_vec_ = non_native_vector_base<T, 2>;
......@@ -1069,9 +1076,9 @@ struct non_native_vector_type<T, 2>
StaticallyIndexedArray<d2_t, 1> d2x1_;
} data_;
__host__ __device__ constexpr non_native_vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr non_native_vector_type(type v) : data_{v} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
......@@ -1113,7 +1120,7 @@ struct non_native_vector_type<T, 2>
};
template <typename T>
struct non_native_vector_type<T, 4>
struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
{
using Native_vec_ = non_native_vector_base<T, 4>;
......@@ -1131,9 +1138,9 @@ struct non_native_vector_type<T, 4>
StaticallyIndexedArray<d4_t, 1> d4x1_;
} data_;
__host__ __device__ constexpr non_native_vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr non_native_vector_type(type v) : data_{v} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
......@@ -1185,7 +1192,7 @@ struct non_native_vector_type<T, 4>
};
template <typename T>
struct non_native_vector_type<T, 8>
struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
{
using Native_vec_ = non_native_vector_base<T, 8>;
......@@ -1205,9 +1212,9 @@ struct non_native_vector_type<T, 8>
StaticallyIndexedArray<d8_t, 1> d8x1_;
} data_;
__host__ __device__ constexpr non_native_vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr non_native_vector_type(type v) : data_{v} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
......@@ -1269,7 +1276,7 @@ struct non_native_vector_type<T, 8>
};
template <typename T>
struct non_native_vector_type<T, 16>
struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
{
using Native_vec_ = non_native_vector_base<T, 16>;
......@@ -1291,9 +1298,9 @@ struct non_native_vector_type<T, 16>
StaticallyIndexedArray<d16_t, 1> d16x1_;
} data_;
__host__ __device__ constexpr non_native_vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr non_native_vector_type(type v) : data_{v} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
......@@ -1365,7 +1372,7 @@ struct non_native_vector_type<T, 16>
};
template <typename T>
struct non_native_vector_type<T, 32>
struct vector_type<T, 32, typename std::enable_if_t<!is_native_type<T>()>>
{
using Native_vec_ = non_native_vector_base<T, 32>;
......@@ -1389,9 +1396,9 @@ struct non_native_vector_type<T, 32>
StaticallyIndexedArray<d32_t, 1> d32x1_;
} data_;
__host__ __device__ constexpr non_native_vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr non_native_vector_type(type v) : data_{v} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
......@@ -1471,7 +1478,7 @@ struct non_native_vector_type<T, 32>
};
template <typename T>
struct non_native_vector_type<T, 64>
struct vector_type<T, 64, typename std::enable_if_t<!is_native_type<T>()>>
{
using Native_vec_ = non_native_vector_base<T, 64>;
......@@ -1497,9 +1504,9 @@ struct non_native_vector_type<T, 64>
StaticallyIndexedArray<d64_t, 1> d64x1_;
} data_;
__host__ __device__ constexpr non_native_vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr non_native_vector_type(type v) : data_{v} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
......@@ -1641,12 +1648,12 @@ using int8x64_t = typename vector_type<int8_t, 64>::type;
// using f8x16_t = typename vector_type<f8_t, 16>::type;
// using f8x32_t = typename vector_type<f8_t, 32>::type;
// using f8x64_t = typename vector_type<f8_t, 64>::type;
using f8x2_t = typename non_native_vector_type<f8_t, 2>::type;
using f8x4_t = typename non_native_vector_type<f8_t, 4>::type;
using f8x8_t = typename non_native_vector_type<f8_t, 8>::type;
using f8x16_t = typename non_native_vector_type<f8_t, 16>::type;
using f8x32_t = typename non_native_vector_type<f8_t, 32>::type;
using f8x64_t = typename non_native_vector_type<f8_t, 64>::type;
using f8x2_t = typename vector_type<f8_t, 2>::type;
using f8x4_t = typename vector_type<f8_t, 4>::type;
using f8x8_t = typename vector_type<f8_t, 8>::type;
using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type;
// bf8
// using bf8x2_t = typename vector_type<bf8_t, 2>::type;
......@@ -1655,12 +1662,12 @@ using f8x64_t = typename non_native_vector_type<f8_t, 64>::type;
// using bf8x16_t = typename vector_type<bf8_t, 16>::type;
// using bf8x32_t = typename vector_type<bf8_t, 32>::type;
// using bf8x64_t = typename vector_type<bf8_t, 64>::type;
using bf8x2_t = typename non_native_vector_type<bf8_t, 2>::type;
using bf8x4_t = typename non_native_vector_type<bf8_t, 4>::type;
using bf8x8_t = typename non_native_vector_type<bf8_t, 8>::type;
using bf8x16_t = typename non_native_vector_type<bf8_t, 16>::type;
using bf8x32_t = typename non_native_vector_type<bf8_t, 32>::type;
using bf8x64_t = typename non_native_vector_type<bf8_t, 64>::type;
using bf8x2_t = typename vector_type<bf8_t, 2>::type;
using bf8x4_t = typename vector_type<bf8_t, 4>::type;
using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
// u8
// i8
......
......@@ -192,18 +192,10 @@ __device__ void transpose_f8_4x4(const f8x4_t& x0,
z2 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
z3 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
// y0 = bit_cast<f8x4_t>(z0);
// y1 = bit_cast<f8x4_t>(z1);
// y2 = bit_cast<f8x4_t>(z2);
// y3 = bit_cast<f8x4_t>(z3);
std::ignore = z0;
std::ignore = z1;
std::ignore = z2;
std::ignore = z3;
std::ignore = y0;
std::ignore = y1;
std::ignore = y2;
std::ignore = y3;
y0 = bit_cast<f8x4_t>(z0);
y1 = bit_cast<f8x4_t>(z1);
y2 = bit_cast<f8x4_t>(z2);
y3 = bit_cast<f8x4_t>(z3);
}
template <index_t NX, index_t NY>
......
......@@ -403,7 +403,7 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
#else
constexpr bool negative_zero_nan = true;
const auto f8x2_v = non_native_vector_type<f8_t, 2>(x);
const auto f8x2_v = vector_type<f8_t, 2>(x);
vector_type<float, 2> f32x2_v;
f32x2_v.template AsType<float>()(Number<0>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment