Commit dda07ed0 authored by Rostyslav Geyyer
Browse files

Use vector_type to cover non-native implementation as well

parent fd019e14
...@@ -322,8 +322,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -322,8 +322,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
b_thread_buf); b_thread_buf);
static_for<0, KPerThread, KPack>{}([&](auto k) { static_for<0, KPerThread, KPack>{}([&](auto k) {
non_native_vector_type<ComputeTypeA, KPack> a_thread_vec; vector_type<ComputeTypeA, KPack> a_thread_vec;
non_native_vector_type<ComputeTypeB, KPack> b_thread_vec; vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) { static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<ComputeTypeA>()(i) = a_thread_buf a_thread_vec.template AsType<ComputeTypeA>()(i) = a_thread_buf
...@@ -333,11 +333,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -333,11 +333,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
}); });
using mfma_input_type_a = using mfma_input_type_a =
typename non_native_vector_type<ComputeTypeA, typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.K1PerXdlops>::type;
using mfma_input_type_b = using mfma_input_type_b =
typename non_native_vector_type<ComputeTypeB, typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
...@@ -949,8 +947,8 @@ struct BlockwiseGemmXdlops_v2 ...@@ -949,8 +947,8 @@ struct BlockwiseGemmXdlops_v2
b_thread_desc_, b_thread_desc_,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_buf); b_thread_buf);
non_native_vector_type<FloatAB, KPack> a_thread_vec; vector_type<FloatAB, KPack> a_thread_vec;
non_native_vector_type<FloatAB, KPack> b_thread_vec; vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) { static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
...@@ -960,7 +958,7 @@ struct BlockwiseGemmXdlops_v2 ...@@ -960,7 +958,7 @@ struct BlockwiseGemmXdlops_v2
}); });
using mfma_input_type = using mfma_input_type =
typename non_native_vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type; typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
......
...@@ -375,8 +375,8 @@ struct intrin_mfma_f32_32x32x16f8f8<32, 32> ...@@ -375,8 +375,8 @@ struct intrin_mfma_f32_32x32x16f8f8<32, 32>
0, 0,
0); 0);
#else #else
non_native_vector_type<f8_t, 8> reg_a_v(reg_a); vector_type<f8_t, 8> reg_a_v(reg_a);
non_native_vector_type<f8_t, 8> reg_b_v(reg_b); vector_type<f8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) { static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]); float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
...@@ -406,8 +406,8 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16> ...@@ -406,8 +406,8 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
0, 0,
0); 0);
#else #else
non_native_vector_type<f8_t, 8> reg_a_v(reg_a); vector_type<f8_t, 8> reg_a_v(reg_a);
non_native_vector_type<f8_t, 8> reg_b_v(reg_b); vector_type<f8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) { static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]); float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
...@@ -438,8 +438,8 @@ struct intrin_mfma_f32_32x32x16bf8bf8<32, 32> ...@@ -438,8 +438,8 @@ struct intrin_mfma_f32_32x32x16bf8bf8<32, 32>
0, 0,
0); 0);
#else #else
non_native_vector_type<bf8_t, 8> reg_a_v(reg_a); vector_type<bf8_t, 8> reg_a_v(reg_a);
non_native_vector_type<bf8_t, 8> reg_b_v(reg_b); vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) { static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]); float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
...@@ -469,8 +469,8 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16> ...@@ -469,8 +469,8 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
0, 0,
0); 0);
#else #else
non_native_vector_type<bf8_t, 8> reg_a_v(reg_a); vector_type<bf8_t, 8> reg_a_v(reg_a);
non_native_vector_type<bf8_t, 8> reg_b_v(reg_b); vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) { static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]); float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
...@@ -501,8 +501,8 @@ struct intrin_mfma_f32_32x32x16f8bf8<32, 32> ...@@ -501,8 +501,8 @@ struct intrin_mfma_f32_32x32x16f8bf8<32, 32>
0, 0,
0); 0);
#else #else
non_native_vector_type<f8_t, 8> reg_a_v(reg_a); vector_type<f8_t, 8> reg_a_v(reg_a);
non_native_vector_type<bf8_t, 8> reg_b_v(reg_b); vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) { static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]); float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
...@@ -532,8 +532,8 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16> ...@@ -532,8 +532,8 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
0, 0,
0); 0);
#else #else
non_native_vector_type<f8_t, 8> reg_a_v(reg_a); vector_type<f8_t, 8> reg_a_v(reg_a);
non_native_vector_type<bf8_t, 8> reg_b_v(reg_b); vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) { static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]); float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
...@@ -564,8 +564,8 @@ struct intrin_mfma_f32_32x32x16bf8f8<32, 32> ...@@ -564,8 +564,8 @@ struct intrin_mfma_f32_32x32x16bf8f8<32, 32>
0, 0,
0); 0);
#else #else
non_native_vector_type<bf8_t, 8> reg_a_v(reg_a); vector_type<bf8_t, 8> reg_a_v(reg_a);
non_native_vector_type<f8_t, 8> reg_b_v(reg_b); vector_type<f8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) { static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]); float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
...@@ -595,8 +595,8 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16> ...@@ -595,8 +595,8 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
0, 0,
0); 0);
#else #else
non_native_vector_type<bf8_t, 8> reg_a_v(reg_a); vector_type<bf8_t, 8> reg_a_v(reg_a);
non_native_vector_type<f8_t, 8> reg_b_v(reg_b); vector_type<f8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) { static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]); float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
......
...@@ -19,8 +19,18 @@ inline constexpr auto next_pow2(uint32_t x) ...@@ -19,8 +19,18 @@ inline constexpr auto next_pow2(uint32_t x)
return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x; return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x;
} }
// Compile-time predicate: true when T is exactly one of double, float, half_t,
// bhalf_t, int32_t, int8_t, uint8_t, _BitInt(8), unsigned _BitInt(8) or bool.
// The vector_type specializations in this file are selected on this result
// (is_native_type<T>() vs. !is_native_type<T>()).
// NOTE(review): the previous comment listed _Float16 and ushort here; presumably
// those are what half_t and bhalf_t alias -- confirm against their definitions.
template <typename T>
inline constexpr bool is_native_type()
{
    // double / float
    constexpr bool is_full_float = is_same<T, double>::value || is_same<T, float>::value;

    // half_t / bhalf_t
    constexpr bool is_half_alias = is_same<T, half_t>::value || is_same<T, bhalf_t>::value;

    // fixed-width integers from <cstdint>
    constexpr bool is_fixed_int = is_same<T, int32_t>::value || is_same<T, int8_t>::value ||
                                  is_same<T, uint8_t>::value;

    // 8-bit _BitInt in both signednesses
    constexpr bool is_bit_int8 =
        is_same<T, _BitInt(8)>::value || is_same<T, unsigned _BitInt(8)>::value;

    constexpr bool is_boolean = is_same<T, bool>::value;

    return is_full_float || is_half_alias || is_fixed_int || is_bit_int8 || is_boolean;
}
// vector_type // vector_type
template <typename T, index_t N> template <typename T, index_t N, typename Enable = void>
struct vector_type; struct vector_type;
// Caution: DO NOT REMOVE // Caution: DO NOT REMOVE
...@@ -177,7 +187,7 @@ struct scalar_type<bool> ...@@ -177,7 +187,7 @@ struct scalar_type<bool>
}; };
template <typename T> template <typename T>
struct vector_type<T, 1> struct vector_type<T, 1, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
using type = d1_t; using type = d1_t;
...@@ -211,7 +221,7 @@ struct vector_type<T, 1> ...@@ -211,7 +221,7 @@ struct vector_type<T, 1>
int static err = 0; int static err = 0;
template <typename T> template <typename T>
struct vector_type<T, 2> struct vector_type<T, 2, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -269,7 +279,7 @@ struct vector_type<T, 2> ...@@ -269,7 +279,7 @@ struct vector_type<T, 2>
}; };
template <typename T> template <typename T>
struct vector_type<T, 4> struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -339,7 +349,7 @@ struct vector_type<T, 4> ...@@ -339,7 +349,7 @@ struct vector_type<T, 4>
}; };
template <typename T> template <typename T>
struct vector_type<T, 8> struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -421,7 +431,7 @@ struct vector_type<T, 8> ...@@ -421,7 +431,7 @@ struct vector_type<T, 8>
}; };
template <typename T> template <typename T>
struct vector_type<T, 16> struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -515,7 +525,7 @@ struct vector_type<T, 16> ...@@ -515,7 +525,7 @@ struct vector_type<T, 16>
}; };
template <typename T> template <typename T>
struct vector_type<T, 32> struct vector_type<T, 32, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -619,7 +629,7 @@ struct vector_type<T, 32> ...@@ -619,7 +629,7 @@ struct vector_type<T, 32>
}; };
template <typename T> template <typename T>
struct vector_type<T, 64> struct vector_type<T, 64, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -735,7 +745,7 @@ struct vector_type<T, 64> ...@@ -735,7 +745,7 @@ struct vector_type<T, 64>
}; };
template <typename T> template <typename T>
struct vector_type<T, 128> struct vector_type<T, 128, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -861,7 +871,7 @@ struct vector_type<T, 128> ...@@ -861,7 +871,7 @@ struct vector_type<T, 128>
}; };
template <typename T> template <typename T>
struct vector_type<T, 256> struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -1013,12 +1023,9 @@ struct non_native_vector_base ...@@ -1013,12 +1023,9 @@ struct non_native_vector_base
T d[N]; T d[N];
}; };
// non-native vector_type // non-native vector_type implementation
template <typename T, index_t N>
struct non_native_vector_type;
template <typename T> template <typename T>
struct non_native_vector_type<T, 1> struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
{ {
using Native_vec_ = non_native_vector_base<T, 1>; using Native_vec_ = non_native_vector_base<T, 1>;
...@@ -1031,9 +1038,9 @@ struct non_native_vector_type<T, 1> ...@@ -1031,9 +1038,9 @@ struct non_native_vector_type<T, 1>
StaticallyIndexedArray<d1_t, 1> d1x1_; StaticallyIndexedArray<d1_t, 1> d1x1_;
} data_; } data_;
__host__ __device__ constexpr non_native_vector_type() : data_{type{}} {} __host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr non_native_vector_type(type v) : data_{v} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X> template <typename X>
__host__ __device__ constexpr const auto& AsType() const __host__ __device__ constexpr const auto& AsType() const
...@@ -1053,7 +1060,7 @@ struct non_native_vector_type<T, 1> ...@@ -1053,7 +1060,7 @@ struct non_native_vector_type<T, 1>
}; };
template <typename T> template <typename T>
struct non_native_vector_type<T, 2> struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
{ {
using Native_vec_ = non_native_vector_base<T, 2>; using Native_vec_ = non_native_vector_base<T, 2>;
...@@ -1069,9 +1076,9 @@ struct non_native_vector_type<T, 2> ...@@ -1069,9 +1076,9 @@ struct non_native_vector_type<T, 2>
StaticallyIndexedArray<d2_t, 1> d2x1_; StaticallyIndexedArray<d2_t, 1> d2x1_;
} data_; } data_;
__host__ __device__ constexpr non_native_vector_type() : data_{type{}} {} __host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr non_native_vector_type(type v) : data_{v} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X> template <typename X>
__host__ __device__ constexpr const auto& AsType() const __host__ __device__ constexpr const auto& AsType() const
...@@ -1113,7 +1120,7 @@ struct non_native_vector_type<T, 2> ...@@ -1113,7 +1120,7 @@ struct non_native_vector_type<T, 2>
}; };
template <typename T> template <typename T>
struct non_native_vector_type<T, 4> struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
{ {
using Native_vec_ = non_native_vector_base<T, 4>; using Native_vec_ = non_native_vector_base<T, 4>;
...@@ -1131,9 +1138,9 @@ struct non_native_vector_type<T, 4> ...@@ -1131,9 +1138,9 @@ struct non_native_vector_type<T, 4>
StaticallyIndexedArray<d4_t, 1> d4x1_; StaticallyIndexedArray<d4_t, 1> d4x1_;
} data_; } data_;
__host__ __device__ constexpr non_native_vector_type() : data_{type{}} {} __host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr non_native_vector_type(type v) : data_{v} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X> template <typename X>
__host__ __device__ constexpr const auto& AsType() const __host__ __device__ constexpr const auto& AsType() const
...@@ -1185,7 +1192,7 @@ struct non_native_vector_type<T, 4> ...@@ -1185,7 +1192,7 @@ struct non_native_vector_type<T, 4>
}; };
template <typename T> template <typename T>
struct non_native_vector_type<T, 8> struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
{ {
using Native_vec_ = non_native_vector_base<T, 8>; using Native_vec_ = non_native_vector_base<T, 8>;
...@@ -1205,9 +1212,9 @@ struct non_native_vector_type<T, 8> ...@@ -1205,9 +1212,9 @@ struct non_native_vector_type<T, 8>
StaticallyIndexedArray<d8_t, 1> d8x1_; StaticallyIndexedArray<d8_t, 1> d8x1_;
} data_; } data_;
__host__ __device__ constexpr non_native_vector_type() : data_{type{}} {} __host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr non_native_vector_type(type v) : data_{v} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X> template <typename X>
__host__ __device__ constexpr const auto& AsType() const __host__ __device__ constexpr const auto& AsType() const
...@@ -1269,7 +1276,7 @@ struct non_native_vector_type<T, 8> ...@@ -1269,7 +1276,7 @@ struct non_native_vector_type<T, 8>
}; };
template <typename T> template <typename T>
struct non_native_vector_type<T, 16> struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
{ {
using Native_vec_ = non_native_vector_base<T, 16>; using Native_vec_ = non_native_vector_base<T, 16>;
...@@ -1291,9 +1298,9 @@ struct non_native_vector_type<T, 16> ...@@ -1291,9 +1298,9 @@ struct non_native_vector_type<T, 16>
StaticallyIndexedArray<d16_t, 1> d16x1_; StaticallyIndexedArray<d16_t, 1> d16x1_;
} data_; } data_;
__host__ __device__ constexpr non_native_vector_type() : data_{type{}} {} __host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr non_native_vector_type(type v) : data_{v} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X> template <typename X>
__host__ __device__ constexpr const auto& AsType() const __host__ __device__ constexpr const auto& AsType() const
...@@ -1365,7 +1372,7 @@ struct non_native_vector_type<T, 16> ...@@ -1365,7 +1372,7 @@ struct non_native_vector_type<T, 16>
}; };
template <typename T> template <typename T>
struct non_native_vector_type<T, 32> struct vector_type<T, 32, typename std::enable_if_t<!is_native_type<T>()>>
{ {
using Native_vec_ = non_native_vector_base<T, 32>; using Native_vec_ = non_native_vector_base<T, 32>;
...@@ -1389,9 +1396,9 @@ struct non_native_vector_type<T, 32> ...@@ -1389,9 +1396,9 @@ struct non_native_vector_type<T, 32>
StaticallyIndexedArray<d32_t, 1> d32x1_; StaticallyIndexedArray<d32_t, 1> d32x1_;
} data_; } data_;
__host__ __device__ constexpr non_native_vector_type() : data_{type{}} {} __host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr non_native_vector_type(type v) : data_{v} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X> template <typename X>
__host__ __device__ constexpr const auto& AsType() const __host__ __device__ constexpr const auto& AsType() const
...@@ -1471,7 +1478,7 @@ struct non_native_vector_type<T, 32> ...@@ -1471,7 +1478,7 @@ struct non_native_vector_type<T, 32>
}; };
template <typename T> template <typename T>
struct non_native_vector_type<T, 64> struct vector_type<T, 64, typename std::enable_if_t<!is_native_type<T>()>>
{ {
using Native_vec_ = non_native_vector_base<T, 64>; using Native_vec_ = non_native_vector_base<T, 64>;
...@@ -1497,9 +1504,9 @@ struct non_native_vector_type<T, 64> ...@@ -1497,9 +1504,9 @@ struct non_native_vector_type<T, 64>
StaticallyIndexedArray<d64_t, 1> d64x1_; StaticallyIndexedArray<d64_t, 1> d64x1_;
} data_; } data_;
__host__ __device__ constexpr non_native_vector_type() : data_{type{}} {} __host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr non_native_vector_type(type v) : data_{v} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X> template <typename X>
__host__ __device__ constexpr const auto& AsType() const __host__ __device__ constexpr const auto& AsType() const
...@@ -1641,12 +1648,12 @@ using int8x64_t = typename vector_type<int8_t, 64>::type; ...@@ -1641,12 +1648,12 @@ using int8x64_t = typename vector_type<int8_t, 64>::type;
// using f8x16_t = typename vector_type<f8_t, 16>::type; // using f8x16_t = typename vector_type<f8_t, 16>::type;
// using f8x32_t = typename vector_type<f8_t, 32>::type; // using f8x32_t = typename vector_type<f8_t, 32>::type;
// using f8x64_t = typename vector_type<f8_t, 64>::type; // using f8x64_t = typename vector_type<f8_t, 64>::type;
using f8x2_t = typename non_native_vector_type<f8_t, 2>::type; using f8x2_t = typename vector_type<f8_t, 2>::type;
using f8x4_t = typename non_native_vector_type<f8_t, 4>::type; using f8x4_t = typename vector_type<f8_t, 4>::type;
using f8x8_t = typename non_native_vector_type<f8_t, 8>::type; using f8x8_t = typename vector_type<f8_t, 8>::type;
using f8x16_t = typename non_native_vector_type<f8_t, 16>::type; using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename non_native_vector_type<f8_t, 32>::type; using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename non_native_vector_type<f8_t, 64>::type; using f8x64_t = typename vector_type<f8_t, 64>::type;
// bf8 // bf8
// using bf8x2_t = typename vector_type<bf8_t, 2>::type; // using bf8x2_t = typename vector_type<bf8_t, 2>::type;
...@@ -1655,12 +1662,12 @@ using f8x64_t = typename non_native_vector_type<f8_t, 64>::type; ...@@ -1655,12 +1662,12 @@ using f8x64_t = typename non_native_vector_type<f8_t, 64>::type;
// using bf8x16_t = typename vector_type<bf8_t, 16>::type; // using bf8x16_t = typename vector_type<bf8_t, 16>::type;
// using bf8x32_t = typename vector_type<bf8_t, 32>::type; // using bf8x32_t = typename vector_type<bf8_t, 32>::type;
// using bf8x64_t = typename vector_type<bf8_t, 64>::type; // using bf8x64_t = typename vector_type<bf8_t, 64>::type;
using bf8x2_t = typename non_native_vector_type<bf8_t, 2>::type; using bf8x2_t = typename vector_type<bf8_t, 2>::type;
using bf8x4_t = typename non_native_vector_type<bf8_t, 4>::type; using bf8x4_t = typename vector_type<bf8_t, 4>::type;
using bf8x8_t = typename non_native_vector_type<bf8_t, 8>::type; using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename non_native_vector_type<bf8_t, 16>::type; using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename non_native_vector_type<bf8_t, 32>::type; using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename non_native_vector_type<bf8_t, 64>::type; using bf8x64_t = typename vector_type<bf8_t, 64>::type;
// u8 // u8
// i8 // i8
......
...@@ -192,18 +192,10 @@ __device__ void transpose_f8_4x4(const f8x4_t& x0, ...@@ -192,18 +192,10 @@ __device__ void transpose_f8_4x4(const f8x4_t& x0,
z2 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1); z2 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
z3 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2); z3 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
// y0 = bit_cast<f8x4_t>(z0); y0 = bit_cast<f8x4_t>(z0);
// y1 = bit_cast<f8x4_t>(z1); y1 = bit_cast<f8x4_t>(z1);
// y2 = bit_cast<f8x4_t>(z2); y2 = bit_cast<f8x4_t>(z2);
// y3 = bit_cast<f8x4_t>(z3); y3 = bit_cast<f8x4_t>(z3);
std::ignore = z0;
std::ignore = z1;
std::ignore = z2;
std::ignore = z3;
std::ignore = y0;
std::ignore = y1;
std::ignore = y2;
std::ignore = y3;
} }
template <index_t NX, index_t NY> template <index_t NX, index_t NY>
......
...@@ -403,7 +403,7 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x) ...@@ -403,7 +403,7 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0); return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
const auto f8x2_v = non_native_vector_type<f8_t, 2>(x); const auto f8x2_v = vector_type<f8_t, 2>(x);
vector_type<float, 2> f32x2_v; vector_type<float, 2> f32x2_v;
f32x2_v.template AsType<float>()(Number<0>{}) = f32x2_v.template AsType<float>()(Number<0>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>( utils::cast_from_f8<f8_t, float, negative_zero_nan>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment