Commit 9eebd123 authored by Chao Liu

overhaul vector_type: make int8x4_t a real vector instead of an alias of int32_t

parent 5602817f
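The gist of the change, visible in every hunk below: int8x4_t and friends become genuine clang ext_vector_type vectors instead of aliases of same-sized integers, and vector_type / vector_type_maker expose their storage through typed AsType<X>() views rather than the old Vector()/Scalars() pair. A minimal host-side sketch of that interface, with a hypothetical mock_int8x8 union standing in for the real vector_type_maker<int8_t, 8>::type (union punning kept for brevity; not CK's actual implementation):

```cpp
#include <cstdint>

// mirrors of the typedefs this commit turns into real vectors (clang extension)
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
typedef int8_t int8x8_t __attribute__((ext_vector_type(8)));

// hypothetical stand-in for vector_type_maker<int8_t, 8>::type: one 8-byte buffer
// with several typed views, which is roughly what AsType<X>() selects in the diff
union mock_int8x8
{
    int8x8_t whole;   // ~ AsType<int8x8_t>()(Number<0>{})
    int8x4_t quad[2]; // ~ AsType<int8x4_t>()(Number<i>{})
    int8_t   lane[8]; // ~ AsType<int8_t>()(Number<i>{})
};

int main()
{
    mock_int8x8 v{};
    int8x4_t lo = {1, 2, 3, 4};
    int8x4_t hi = {5, 6, 7, 8};
    v.quad[0] = lo;                // write the low four lanes as one int8x4_t
    v.quad[1] = hi;                // write the high four lanes
    return v.lane[2] == 3 ? 0 : 1; // read a single lane back through the scalar view
}
```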
@@ -104,7 +104,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
using vector_t = typename vector_type<Data, DataPerAccess>::type;
using vector_t = typename vector_type_maker<Data, DataPerAccess>::type::type;
static_for<0, NSliceRow, 1>{}([&](auto i) {
static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
......
@@ -172,16 +172,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
}();
// copy data
vector_type<DstData, DstScalarPerVector> dst_vector;
typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
using dst_vector_t = typename vector_type<DstData, DstScalarPerVector>::type;
using dst_vector_t =
typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset =
src_desc.CalculateOffset(to_multi_index(src_slice_origin_idx) + dst_data_idx +
i * dst_scalar_step_in_vector);
dst_vector.Scalars()(i) = type_convert<DstData>{}(p_src[Number<src_offset>{}]);
dst_vector.template AsType<DstData>()(i) =
type_convert<DstData>{}(p_src[Number<src_offset>{}]);
});
const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
@@ -192,7 +194,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
{
#if CK_USE_AMD_BUFFER_ADDRESSING
amd_buffer_store_v2<DstData, DstScalarPerVector>(
dst_vector.Vector(),
dst_vector.template AsType<dst_vector_t>()(Number<0>{}),
p_dst,
dst_slice_origin_coord_.GetOffset(),
is_dst_valid,
@@ -201,7 +203,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
if(is_dst_valid)
{
*reinterpret_cast<dst_vector_t*>(
&(p_dst[dst_slice_origin_coord_.GetOffset()])) = dst_vector.Vector();
&(p_dst[dst_slice_origin_coord_.GetOffset()])) =
dst_vector.template AsType<dst_vector_t>()[Number<0>{}];
}
#endif
}
@@ -210,7 +213,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
if(is_dst_valid)
{
*reinterpret_cast<dst_vector_t*>(
&(p_dst[dst_slice_origin_coord_.GetOffset()])) = dst_vector.Vector();
&(p_dst[dst_slice_origin_coord_.GetOffset()])) =
dst_vector.template AsType<dst_vector_t>()[Number<0>{}];
}
}
@@ -500,9 +504,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
// copy data
static_assert(DstAddressSpace == AddressSpace::Vgpr, "wrong! hardcode for vgpr dst");
vector_type<SrcData, SrcScalarPerVector> src_vector;
typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;
using src_vector_t = typename vector_type<SrcData, SrcScalarPerVector>::type;
using src_vector_t =
typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_slice_origin_coord_);
@@ -510,24 +515,25 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
if constexpr(SrcAddressSpace == AddressSpace::Global)
{
#if CK_USE_AMD_BUFFER_ADDRESSING
src_vector.Vector() = amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
p_src,
src_slice_origin_coord_.GetOffset(),
is_src_valid,
src_desc.GetElementSpaceSize());
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
p_src,
src_slice_origin_coord_.GetOffset(),
is_src_valid,
src_desc.GetElementSpaceSize());
#else
src_vector.Vector() = is_src_valid
? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0};
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
is_src_valid ? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0};
#endif
}
else
{
src_vector.Vector() = is_src_valid
? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0};
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
is_src_valid ? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0};
}
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
@@ -535,7 +541,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
i * src_scalar_step_in_vector);
p_dst[Number<dst_offset>{}] = src_vector.Scalars()[i];
p_dst[Number<dst_offset>{}] = src_vector.template AsType<SrcData>()[i];
});
constexpr auto move_on_dim = [&]() constexpr
@@ -833,9 +839,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
}();
// copy data
vector_type<SrcData, SrcScalarPerVector> src_vector;
typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;
using src_vector_t = typename vector_type<SrcData, SrcScalarPerVector>::type;
using src_vector_t =
typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_slice_origin_coord_);
@@ -843,31 +850,32 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
if constexpr(SrcAddressSpace == AddressSpace::Global)
{
#if CK_USE_AMD_BUFFER_ADDRESSING
src_vector.Vector() = amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
p_src,
src_slice_origin_coord_.GetOffset(),
is_src_valid,
src_desc.GetElementSpaceSize());
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
p_src,
src_slice_origin_coord_.GetOffset(),
is_src_valid,
src_desc.GetElementSpaceSize());
#else
src_vector.Vector() = is_src_valid
? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0};
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
is_src_valid ? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0};
#endif
}
else
{
src_vector.Vector() = is_src_valid
? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0};
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
is_src_valid ? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0};
}
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t buffer_offset =
buffer_desc_.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector);
buffer_(Number<buffer_offset>{}) = src_vector.Scalars()[i];
buffer_(Number<buffer_offset>{}) = src_vector.template AsType<SrcData>()[i];
});
constexpr auto move_on_dim = [&]() constexpr
@@ -1018,19 +1026,20 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
DstInMemOp == InMemoryDataOperation::Set,
"wrong! hardcoded for ds_write");
vector_type<DstData, DstScalarPerVector> dst_vector;
typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t buffer_offset =
buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector);
dst_vector.Scalars()(i) = buffer_[Number<buffer_offset>{}];
dst_vector.template AsType<DstData>()(i) = buffer_[Number<buffer_offset>{}];
});
using DstVectorType = typename vector_type<DstData, DstScalarPerVector>::type;
using DstVectorType =
typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
*reinterpret_cast<DstVectorType*>(p_dst + dst_slice_origin_coord_.GetOffset()) =
dst_vector.Vector();
dst_vector.template AsType<DstVectorType>()[Number<0>{}];
constexpr auto move_on_dim = [&]() constexpr
{
......
@@ -41,7 +41,7 @@ struct ThreadwiseMatrixSliceCopy_v2
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
using vector_t = typename vector_type<Data, DataPerAccess>::type;
using vector_t = typename vector_type_maker<Data, DataPerAccess>::type::type;
static_for<0, NSliceRow, 1>{}([&](auto i) {
static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
......
@@ -209,16 +209,12 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset)
{
static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, half2_t>::value && (N == 1)) ||
(is_same<T, half4_t>::value && (N == 1)) ||
(is_same<T, half8_t>::value && (N == 1)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32x2_t>::value && (N == 1)) ||
(is_same<T, int32x4_t>::value && (N == 1)),
"wrong! not implemented");
static_assert(
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)),
"wrong! not implemented");
if constexpr(is_same<T, float>::value)
{
@@ -241,16 +237,16 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
{
vector_type<float, 8> tmp;
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
tmp.AsType<float4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.Vectors(Number<4>{})(Number<1>{}) =
tmp.AsType<float4_t>()(Number<1>{}) =
__llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(float),
0);
return tmp.Vector();
return tmp.AsType<float8_t>()(Number<0>{});
}
}
else if constexpr(is_same<T, half_t>::value)
@@ -270,39 +266,20 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
return __llvm_amdgcn_raw_buffer_load_fp16x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
}
else if constexpr(is_same<T, half2_t>::value)
{
if constexpr(N == 1)
{
return __llvm_amdgcn_raw_buffer_load_fp16x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
}
else if constexpr(is_same<T, half4_t>::value)
{
if constexpr(N == 1)
{
return __llvm_amdgcn_raw_buffer_load_fp16x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
}
else if constexpr(is_same<T, half8_t>::value)
{
if constexpr(N == 1)
else if constexpr(N == 8)
{
vector_type<half_t, 8> tmp;
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp16x4(
tmp.AsType<half4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp16x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.Vectors(Number<4>{})(Number<1>{}) =
tmp.AsType<half4_t>()(Number<1>{}) =
__llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(half_t),
0);
return tmp.Vector();
return tmp.AsType<half8_t>()(Number<0>{});
}
}
else if constexpr(is_same<T, int32_t>::value)
@@ -326,15 +303,15 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
{
vector_type<int32_t, 8> tmp;
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i32x4(
tmp.AsType<int32x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.Vectors(Number<4>{})(Number<1>{}) =
tmp.AsType<int32x4_t>()(Number<1>{}) =
__llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int32_t),
0);
return tmp.Vector();
return tmp.AsType<int32x8_t>()(Number<0>{});
}
}
else if constexpr(is_same<T, int8_t>::value)
@@ -346,44 +323,83 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
}
else if constexpr(N == 2)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX
return __llvm_amdgcn_raw_buffer_load_i8x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
#else
int16_t tmp = __llvm_amdgcn_raw_buffer_load_i16(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return as_type<int8x2_t>(tmp);
#endif
}
else if constexpr(N == 4)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX
return __llvm_amdgcn_raw_buffer_load_i8x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
#else
int32_t tmp = __llvm_amdgcn_raw_buffer_load_i32(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return as_type<int8x4_t>(tmp);
#endif
}
else if constexpr(N == 8)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX
vector_type<int8_t, 8> tmp;
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4(
tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.Vectors(Number<4>{})(Number<1>{}) =
tmp.AsType<int8x4_t>()(Number<1>{}) =
__llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int8_t),
0);
return tmp.Vector();
}
}
else if constexpr(is_same<T, int32x2_t>::value)
{
if constexpr(N == 1)
{
return __llvm_amdgcn_raw_buffer_load_i32x2(
return tmp.AsType<int8x8_t>()(Number<0>{});
#else
int32x2_t tmp = __llvm_amdgcn_raw_buffer_load_i32x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return as_type<int8x8_t>(tmp);
#endif
}
}
else if constexpr(is_same<T, int32x4_t>::value)
{
if constexpr(N == 1)
else if constexpr(N == 16)
{
return __llvm_amdgcn_raw_buffer_load_i32x4(
#if !CK_WORKAROUND_SWDEV_XXXXXX
vector_type<int8_t, 16> tmp;
tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.AsType<int8x4_t>()(Number<1>{}) =
__llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int8_t),
0);
tmp.AsType<int8x4_t>()(Number<2>{}) =
__llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 8 * sizeof(int8_t),
0);
tmp.AsType<int8x4_t>()(Number<3>{}) =
__llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 12 * sizeof(int8_t),
0);
return tmp.AsType<int8x16_t>()(Number<0>{});
#else
int32x4_t tmp = __llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return as_type<int8x16_t>(tmp);
#endif
}
}
}
@@ -467,23 +483,39 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
}
else if constexpr(N == 2)
{
__llvm_amdgcn_raw_buffer_store_i16(src_thread_data,
#if !CK_WORKAROUND_SWDEV_XXXXXX
__llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
#else
__llvm_amdgcn_raw_buffer_store_i16(as_type<int16_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
#endif
}
else if constexpr(N == 4)
{
__llvm_amdgcn_raw_buffer_store_i32(src_thread_data,
#if !CK_WORKAROUND_SWDEV_XXXXXX
__llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
#else
__llvm_amdgcn_raw_buffer_store_i32(as_type<int32_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
#endif
}
else if constexpr(N == 8)
{
__llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data,
__llvm_amdgcn_raw_buffer_store_i32x2(as_type<int32x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
@@ -491,7 +523,7 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
}
else if constexpr(N == 16)
{
__llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data,
__llvm_amdgcn_raw_buffer_store_i32x4(as_type<int32x4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
@@ -528,13 +560,13 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
{
vector_type<half_t, 8> tmp{src_thread_data};
__llvm_amdgcn_raw_buffer_store_fp16x4(tmp.Vectors(Number<4>{})[Number<0>{}],
__llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
__llvm_amdgcn_raw_buffer_store_fp16x4(tmp.Vectors(Number<4>{})[Number<1>{}],
__llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(half_t),
@@ -548,26 +580,29 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
// 2) p_src_wave to be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>
__device__ typename vector_type<T, N>::type amd_buffer_load_v2(const T* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_element_space)
__device__ typename vector_type_maker<T, N>::type::type
amd_buffer_load_v2(const T* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_element_space)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space);
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(T);
using vector_t = typename vector_type_maker<T, N>::type::type;
using scalar_t = typename scalar_type<vector_t>::type;
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
return amd_buffer_load_impl_v2<T, N>(
return amd_buffer_load_impl_v2<scalar_t, vector_size>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#else
using vector_t = typename vector_type<T, N>::type;
vector_t tmp =
amd_buffer_load_impl_v2<T, N>(src_wave_buffer_resource, src_thread_addr_offset, 0);
vector_t tmp = amd_buffer_load_impl_v2<scalar_t, vector_size>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_data_valid ? tmp : vector_t(0);
#endif
@@ -578,26 +613,31 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_v2(const T* p_src_wa
// 2) p_dst_wave to be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>
__device__ void amd_buffer_store_v2(const typename vector_type<T, N>::type src_thread_data,
T* p_dst_wave,
const index_t dst_thread_data_offset,
const bool dst_thread_data_valid,
const index_t dst_element_space)
__device__ void
amd_buffer_store_v2(const typename vector_type_maker<T, N>::type::type src_thread_data,
T* p_dst_wave,
const index_t dst_thread_data_offset,
const bool dst_thread_data_valid,
const index_t dst_element_space)
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space);
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(T);
using vector_t = typename vector_type_maker<T, N>::type::type;
using scalar_t = typename scalar_type<vector_t>::type;
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
amd_buffer_store_impl_v2<T, N>(
amd_buffer_store_impl_v2<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
if(dst_thread_data_valid)
{
amd_buffer_store_impl_v2<T, N>(
amd_buffer_store_impl_v2<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
#endif
......
@@ -72,6 +72,7 @@ amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, flo
__device__ void
amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
{
// TODO remove pointer casting
const half2_t* p_a_half2 = reinterpret_cast<const half2_t*>(&a);
const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
@@ -132,6 +133,7 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a,
float& c2,
float& c3)
{
// TODO remove pointer casting
const half2_t* p_a_half2 = reinterpret_cast<const half2_t*>(&a);
const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
@@ -177,6 +179,7 @@ __device__ void amd_assembly_outer_product_1x4(half8_t a,
float& c3)
{
// TODO remove pointer casting
const half4_t* p_a_half4 = reinterpret_cast<const half4_t*>(&a);
const half4_t* p_b0_half4 = reinterpret_cast<const half4_t*>(&b0);
const half4_t* p_b1_half4 = reinterpret_cast<const half4_t*>(&b1);
@@ -200,6 +203,7 @@ __device__ void amd_assembly_outer_product_1x4(half16_t a,
float& c2,
float& c3)
{
// TODO remove pointer casting
const half8_t* p_a_half8 = reinterpret_cast<const half8_t*>(&a);
const half8_t* p_b0_half8 = reinterpret_cast<const half8_t*>(&b0);
const half8_t* p_b1_half8 = reinterpret_cast<const half8_t*>(&b1);
@@ -224,10 +228,14 @@ amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0
v_dot4_i32_i8 %1, %2, %4, %1\n \
"
: "=v"(c0), "=v"(c1)
: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
: "v"(as_type<int32_t>(a)),
"v"(as_type<int32_t>(b0)),
"v"(as_type<int32_t>(b1)),
"0"(c0),
"1"(c1));
#else
c0 = __builtin_amdgcn_sdot4(a, b0, c0, false);
c1 = __builtin_amdgcn_sdot4(a, b1, c1, false);
c0 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b1), c1, false);
#endif
}
@@ -253,12 +261,20 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a,
v_dot4_i32_i8 %3, %4, %8, %3\n \
"
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
: "v"(as_type<int32_t>(a)),
"v"(as_type<int32_t>(b0)),
"v"(as_type<int32_t>(b1)),
"v"(as_type<int32_t>(b2)),
"v"(as_type<int32_t>(b3)),
"0"(c0),
"1"(c1),
"2"(c2),
"3"(c3));
#else
c0 = __builtin_amdgcn_sdot4(a, b0, c0, false);
c1 = __builtin_amdgcn_sdot4(a, b1, c1, false);
c2 = __builtin_amdgcn_sdot4(a, b2, c2, false);
c3 = __builtin_amdgcn_sdot4(a, b3, c3, false);
c0 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b1), c1, false);
c2 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b2), c2, false);
c3 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b3), c3, false);
#endif
}
@@ -272,28 +288,24 @@ __device__ void amd_assembly_outer_product_1x4(int8x8_t a,
int32_t& c2,
int32_t& c3)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
const int8x4_t* p_a_int8x4_t = reinterpret_cast<const int8x4_t*>(&a);
const int8x4_t* p_b0_int8x4_t = reinterpret_cast<const int8x4_t*>(&b0);
const int8x4_t* p_b1_int8x4_t = reinterpret_cast<const int8x4_t*>(&b1);
const int8x4_t* p_b2_int8x4_t = reinterpret_cast<const int8x4_t*>(&b2);
const int8x4_t* p_b3_int8x4_t = reinterpret_cast<const int8x4_t*>(&b3);
amd_assembly_outer_product_1x4(p_a_int8x4_t[0],
p_b0_int8x4_t[0],
p_b1_int8x4_t[0],
p_b2_int8x4_t[0],
p_b3_int8x4_t[0],
amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I0],
c0,
c1,
c2,
c3);
amd_assembly_outer_product_1x4(p_a_int8x4_t[1],
p_b0_int8x4_t[1],
p_b1_int8x4_t[1],
p_b2_int8x4_t[1],
p_b3_int8x4_t[1],
amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I1],
c0,
c1,
c2,
@@ -311,28 +323,46 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
int32_t& c3)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const int8x8_t* p_a_int8x8_t = reinterpret_cast<const int8x8_t*>(&a);
const int8x8_t* p_b0_int8x8_t = reinterpret_cast<const int8x8_t*>(&b0);
const int8x8_t* p_b1_int8x8_t = reinterpret_cast<const int8x8_t*>(&b1);
const int8x8_t* p_b2_int8x8_t = reinterpret_cast<const int8x8_t*>(&b2);
const int8x8_t* p_b3_int8x8_t = reinterpret_cast<const int8x8_t*>(&b3);
amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I0],
c0,
c1,
c2,
c3);
amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I1],
c0,
c1,
c2,
c3);
amd_assembly_outer_product_1x4(p_a_int8x8_t[0],
p_b0_int8x8_t[0],
p_b1_int8x8_t[0],
p_b2_int8x8_t[0],
p_b3_int8x8_t[0],
amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I2],
vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I2],
vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I2],
vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I2],
c0,
c1,
c2,
c3);
amd_assembly_outer_product_1x4(p_a_int8x8_t[1],
p_b0_int8x8_t[1],
p_b1_int8x8_t[1],
p_b2_int8x8_t[1],
p_b3_int8x8_t[1],
amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I3],
vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I3],
vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I3],
vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I3],
c0,
c1,
c2,
......
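Because int8x4_t is no longer literally an int32_t, the dot-product call sites above now pass as_type<int32_t>(a) and so on: __builtin_amdgcn_sdot4 and the v_dot4_i32_i8 inline assembly expect packed 32-bit operands. A hedged host-side sketch of the semantics those bit-casts preserve (a plain reference loop; the real builtin is device-side):

```cpp
#include <cstdint>
#include <cstring>

typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));

// reference for what v_dot4_i32_i8 / __builtin_amdgcn_sdot4 compute: c += dot(a, b)
int32_t sdot4_reference(int8x4_t a, int8x4_t b, int32_t c)
{
    for(int i = 0; i < 4; ++i)
        c += int32_t{a[i]} * int32_t{b[i]};
    return c;
}

int main()
{
    int8x4_t a = {1, 2, 3, 4};
    int8x4_t b = {5, 6, 7, 8};

    // the builtin receives the same four bytes, just viewed as an int32_t,
    // which is all as_type<int32_t>(a) does in the diff above
    int32_t a_bits;
    std::memcpy(&a_bits, &a, sizeof(a_bits));
    (void)a_bits;

    return sdot4_reference(a, b, 0) == 70 ? 0 : 1; // 1*5 + 2*6 + 3*7 + 4*8
}
```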
@@ -142,6 +142,11 @@
#define CK_WORKAROUND_SWDEV_275126 1
#endif
// workaround for compiler crash when using buffer load/store for i8
#ifndef CK_WORKAROUND_SWDEV_XXXXXX
#define CK_WORKAROUND_SWDEV_XXXXXX 1
#endif
namespace ck {
enum AddressSpace
......
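The new CK_WORKAROUND_SWDEV_XXXXXX flag above routes i8 buffer traffic through the equally sized i16/i32/i32x2/i32x4 intrinsics until the i8x2/i8x4 ones stop crashing the compiler. A sketch of the store-side shape under that flag, with a hypothetical store_i32 standing in for __llvm_amdgcn_raw_buffer_store_i32 and a memcpy-based as_type stand-in (illustrative only, assumes 4-byte alignment):

```cpp
#include <cstdint>
#include <cstring>

typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));

// bit-cast helper in the spirit of the as_type<> used throughout the diff
template <typename Y, typename X>
Y as_type(X x)
{
    static_assert(sizeof(Y) == sizeof(X), "as_type needs equal sizes");
    Y y;
    std::memcpy(&y, &x, sizeof(Y));
    return y;
}

// hypothetical stand-in for the i32 raw buffer store intrinsic
void store_i32(int32_t* p, int32_t x) { *p = x; }

// workaround path for an N == 4 int8 store: same four bytes, pushed through the i32 op
void store_int8x4_via_i32(int32_t* p, int8x4_t v)
{
    store_i32(p, as_type<int32_t>(v));
}

int main()
{
    int32_t slot = 0;
    int8x4_t v   = {1, 2, 3, 4};
    store_int8x4_via_i32(&slot, v);
    return slot != 0 ? 0 : 1;
}
```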