"...ops/git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "c9b69f5ae1e01e350039ff7e362ea2913b27a6e2"
Commit c86e0696 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Add non-native vector type

parent 7ad7e3da
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -322,8 +322,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -322,8 +322,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
b_thread_buf); b_thread_buf);
static_for<0, KPerThread, KPack>{}([&](auto k) { static_for<0, KPerThread, KPack>{}([&](auto k) {
vector_type<ComputeTypeA, KPack> a_thread_vec; non_native_vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec; non_native_vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) { static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<ComputeTypeA>()(i) = a_thread_buf a_thread_vec.template AsType<ComputeTypeA>()(i) = a_thread_buf
...@@ -333,9 +333,11 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -333,9 +333,11 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
}); });
using mfma_input_type_a = using mfma_input_type_a =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type; typename non_native_vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops>::type;
using mfma_input_type_b = using mfma_input_type_b =
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type; typename non_native_vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
...@@ -947,8 +949,8 @@ struct BlockwiseGemmXdlops_v2 ...@@ -947,8 +949,8 @@ struct BlockwiseGemmXdlops_v2
b_thread_desc_, b_thread_desc_,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_buf); b_thread_buf);
vector_type<FloatAB, KPack> a_thread_vec; non_native_vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec; non_native_vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) { static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
...@@ -958,7 +960,7 @@ struct BlockwiseGemmXdlops_v2 ...@@ -958,7 +960,7 @@ struct BlockwiseGemmXdlops_v2
}); });
using mfma_input_type = using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type; typename non_native_vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -375,8 +375,8 @@ struct intrin_mfma_f32_32x32x16f8f8<32, 32> ...@@ -375,8 +375,8 @@ struct intrin_mfma_f32_32x32x16f8f8<32, 32>
0, 0,
0); 0);
#else #else
vector_type<f8_t, 8> reg_a_v(reg_a); non_native_vector_type<f8_t, 8> reg_a_v(reg_a);
vector_type<f8_t, 8> reg_b_v(reg_b); non_native_vector_type<f8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) { static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]); float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
...@@ -406,8 +406,8 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16> ...@@ -406,8 +406,8 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
0, 0,
0); 0);
#else #else
vector_type<f8_t, 8> reg_a_v(reg_a); non_native_vector_type<f8_t, 8> reg_a_v(reg_a);
vector_type<f8_t, 8> reg_b_v(reg_b); non_native_vector_type<f8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) { static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]); float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
...@@ -438,8 +438,8 @@ struct intrin_mfma_f32_32x32x16bf8bf8<32, 32> ...@@ -438,8 +438,8 @@ struct intrin_mfma_f32_32x32x16bf8bf8<32, 32>
0, 0,
0); 0);
#else #else
vector_type<bf8_t, 8> reg_a_v(reg_a); non_native_vector_type<bf8_t, 8> reg_a_v(reg_a);
vector_type<bf8_t, 8> reg_b_v(reg_b); non_native_vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) { static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]); float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
...@@ -469,8 +469,8 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16> ...@@ -469,8 +469,8 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
0, 0,
0); 0);
#else #else
vector_type<bf8_t, 8> reg_a_v(reg_a); non_native_vector_type<bf8_t, 8> reg_a_v(reg_a);
vector_type<bf8_t, 8> reg_b_v(reg_b); non_native_vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) { static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]); float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
...@@ -501,8 +501,8 @@ struct intrin_mfma_f32_32x32x16f8bf8<32, 32> ...@@ -501,8 +501,8 @@ struct intrin_mfma_f32_32x32x16f8bf8<32, 32>
0, 0,
0); 0);
#else #else
vector_type<f8_t, 8> reg_a_v(reg_a); non_native_vector_type<f8_t, 8> reg_a_v(reg_a);
vector_type<bf8_t, 8> reg_b_v(reg_b); non_native_vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) { static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]); float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
...@@ -532,8 +532,8 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16> ...@@ -532,8 +532,8 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
0, 0,
0); 0);
#else #else
vector_type<f8_t, 8> reg_a_v(reg_a); non_native_vector_type<f8_t, 8> reg_a_v(reg_a);
vector_type<bf8_t, 8> reg_b_v(reg_b); non_native_vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) { static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]); float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
...@@ -564,8 +564,8 @@ struct intrin_mfma_f32_32x32x16bf8f8<32, 32> ...@@ -564,8 +564,8 @@ struct intrin_mfma_f32_32x32x16bf8f8<32, 32>
0, 0,
0); 0);
#else #else
vector_type<bf8_t, 8> reg_a_v(reg_a); non_native_vector_type<bf8_t, 8> reg_a_v(reg_a);
vector_type<f8_t, 8> reg_b_v(reg_b); non_native_vector_type<f8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) { static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]); float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
...@@ -595,8 +595,8 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16> ...@@ -595,8 +595,8 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
0, 0,
0); 0);
#else #else
vector_type<bf8_t, 8> reg_a_v(reg_a); non_native_vector_type<bf8_t, 8> reg_a_v(reg_a);
vector_type<f8_t, 8> reg_b_v(reg_b); non_native_vector_type<f8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) { static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]); float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
......
...@@ -1184,6 +1184,410 @@ struct non_native_vector_type<T, 4> ...@@ -1184,6 +1184,410 @@ struct non_native_vector_type<T, 4>
} }
}; };
template <typename T>
struct non_native_vector_type<T, 8>
{
using Native_vec_ = non_native_vector_base<T, 8>;
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = Native_vec_;
using type = d8_t;
union
{
d8_t d8_;
StaticallyIndexedArray<d1_t, 8> d1x8_;
StaticallyIndexedArray<d2_t, 4> d2x4_;
StaticallyIndexedArray<d4_t, 2> d4x2_;
StaticallyIndexedArray<d8_t, 1> d8x1_;
} data_;
__host__ __device__ constexpr non_native_vector_type() : data_{type{0}} {}
__host__ __device__ constexpr non_native_vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x8_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x4_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x2_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x8_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x4_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x2_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct non_native_vector_type<T, 16>
{
using Native_vec_ = non_native_vector_base<T, 16>;
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
using d16_t = Native_vec_;
using type = d16_t;
union
{
d16_t d16_;
StaticallyIndexedArray<d1_t, 16> d1x16_;
StaticallyIndexedArray<d2_t, 8> d2x8_;
StaticallyIndexedArray<d4_t, 4> d4x4_;
StaticallyIndexedArray<d8_t, 2> d8x2_;
StaticallyIndexedArray<d16_t, 1> d16x1_;
} data_;
__host__ __device__ constexpr non_native_vector_type() : data_{type{}} {}
__host__ __device__ constexpr non_native_vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x16_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x8_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x4_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x2_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x16_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x8_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x4_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x2_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct non_native_vector_type<T, 32>
{
using Native_vec_ = non_native_vector_base<T, 32>;
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
using d16_t = non_native_vector_base<T, 16>;
using d32_t = Native_vec_;
using type = d32_t;
union
{
d32_t d32_;
StaticallyIndexedArray<d1_t, 32> d1x32_;
StaticallyIndexedArray<d2_t, 16> d2x16_;
StaticallyIndexedArray<d4_t, 8> d4x8_;
StaticallyIndexedArray<d8_t, 4> d8x4_;
StaticallyIndexedArray<d16_t, 2> d16x2_;
StaticallyIndexedArray<d32_t, 1> d32x1_;
} data_;
__host__ __device__ constexpr non_native_vector_type() : data_{type{0}} {}
__host__ __device__ constexpr non_native_vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x32_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x16_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x8_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x4_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x2_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x32_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x16_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x8_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x4_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x2_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct non_native_vector_type<T, 64>
{
using Native_vec_ = non_native_vector_base<T, 64>;
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
using d16_t = non_native_vector_base<T, 16>;
using d32_t = non_native_vector_base<T, 32>;
using d64_t = Native_vec_;
using type = d64_t;
union
{
d64_t d64_;
StaticallyIndexedArray<d1_t, 64> d1x64_;
StaticallyIndexedArray<d2_t, 32> d2x32_;
StaticallyIndexedArray<d4_t, 16> d4x16_;
StaticallyIndexedArray<d8_t, 8> d8x8_;
StaticallyIndexedArray<d16_t, 4> d16x4_;
StaticallyIndexedArray<d32_t, 2> d32x2_;
StaticallyIndexedArray<d64_t, 1> d64x1_;
} data_;
__host__ __device__ constexpr non_native_vector_type() : data_{type{0}} {}
__host__ __device__ constexpr non_native_vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x4_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x2_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x4_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x2_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x1_;
}
else
{
return err;
}
}
};
using int64_t = long; using int64_t = long;
// fp64 // fp64
...@@ -1231,20 +1635,33 @@ using int8x32_t = typename vector_type<int8_t, 32>::type; ...@@ -1231,20 +1635,33 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type; using int8x64_t = typename vector_type<int8_t, 64>::type;
// f8 // f8
using f8x2_t = typename vector_type<f8_t, 2>::type; // using f8x2_t = typename vector_type<f8_t, 2>::type;
using f8x4_t = typename vector_type<f8_t, 4>::type; // using f8x4_t = typename vector_type<f8_t, 4>::type;
using f8x8_t = typename vector_type<f8_t, 8>::type; // using f8x8_t = typename vector_type<f8_t, 8>::type;
using f8x16_t = typename vector_type<f8_t, 16>::type; // using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type; // using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type; // using f8x64_t = typename vector_type<f8_t, 64>::type;
using f8x2_t = typename non_native_vector_type<f8_t, 2>::type;
using f8x4_t = typename non_native_vector_type<f8_t, 4>::type;
using f8x8_t = typename non_native_vector_type<f8_t, 8>::type;
using f8x16_t = typename non_native_vector_type<f8_t, 16>::type;
using f8x32_t = typename non_native_vector_type<f8_t, 32>::type;
using f8x64_t = typename non_native_vector_type<f8_t, 64>::type;
// bf8 // bf8
using bf8x2_t = typename vector_type<bf8_t, 2>::type; // using bf8x2_t = typename vector_type<bf8_t, 2>::type;
using bf8x4_t = typename vector_type<bf8_t, 4>::type; // using bf8x4_t = typename vector_type<bf8_t, 4>::type;
using bf8x8_t = typename vector_type<bf8_t, 8>::type; // using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type; // using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type; // using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type; // using bf8x64_t = typename vector_type<bf8_t, 64>::type;
using bf8x2_t = typename non_native_vector_type<bf8_t, 2>::type;
using bf8x4_t = typename non_native_vector_type<bf8_t, 4>::type;
using bf8x8_t = typename non_native_vector_type<bf8_t, 8>::type;
using bf8x16_t = typename non_native_vector_type<bf8_t, 16>::type;
using bf8x32_t = typename non_native_vector_type<bf8_t, 32>::type;
using bf8x64_t = typename non_native_vector_type<bf8_t, 64>::type;
// u8 // u8
// i8 // i8
using uint8x2_t = typename vector_type<uint8_t, 2>::type; using uint8x2_t = typename vector_type<uint8_t, 2>::type;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -192,10 +192,18 @@ __device__ void transpose_f8_4x4(const f8x4_t& x0, ...@@ -192,10 +192,18 @@ __device__ void transpose_f8_4x4(const f8x4_t& x0,
z2 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1); z2 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
z3 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2); z3 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
y0 = bit_cast<f8x4_t>(z0); // y0 = bit_cast<f8x4_t>(z0);
y1 = bit_cast<f8x4_t>(z1); // y1 = bit_cast<f8x4_t>(z1);
y2 = bit_cast<f8x4_t>(z2); // y2 = bit_cast<f8x4_t>(z2);
y3 = bit_cast<f8x4_t>(z3); // y3 = bit_cast<f8x4_t>(z3);
std::ignore = z0;
std::ignore = z1;
std::ignore = z2;
std::ignore = z3;
std::ignore = y0;
std::ignore = y1;
std::ignore = y2;
std::ignore = y3;
} }
template <index_t NX, index_t NY> template <index_t NX, index_t NY>
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -403,7 +403,7 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x) ...@@ -403,7 +403,7 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0); return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
const auto f8x2_v = vector_type<f8_t, 2>(x); const auto f8x2_v = non_native_vector_type<f8_t, 2>(x);
vector_type<float, 2> f32x2_v; vector_type<float, 2> f32x2_v;
f32x2_v.template AsType<float>()(Number<0>{}) = f32x2_v.template AsType<float>()(Number<0>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>( utils::cast_from_f8<f8_t, float, negative_zero_nan>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment