Unverified commit e4790c25 authored by Chao Liu, committed by GitHub

Overhaul vector_type and use real vector for int8x4_t instead of aliasing from int32_t (#29)

* overhaul vector_type, make int8x4_t a real vector instead of an alias of int32_t
parent 3bf52e60
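The diff below swaps the old `vector_type<T, N>` accessors (`.Vector()`, `.Scalars()`) for `vector_type_maker<T, N>::type` storage read through `AsType<X>()`, and adds explicit `as_type<...>` bit-casts wherever a packed integer view of the int8 lanes is still needed. A loose host-side sketch of that idea, with hypothetical names and `std::array` standing in for the GPU ext-vector types (this is not the library's actual definition):

```cpp
#include <array>
#include <cstdint>
#include <cstring>
#include <iostream>

// Hypothetical stand-in for the access pattern introduced by this commit:
// real per-lane int8 storage with lane-wise and packed views, instead of
// int8x4_t being a typedef of int32_t.
struct vector_storage_int8x4
{
    std::array<int8_t, 4> lanes{}; // real int8 storage, not an int32_t alias

    int8_t& lane(int i) { return lanes[i]; }

    // When a packed 32-bit view is needed (e.g. to feed a dot-product builtin),
    // the bits are reinterpreted explicitly, analogous to as_type<int32_t>(...)
    // in the diff, rather than relying on the type alias.
    int32_t as_int32() const
    {
        int32_t packed;
        std::memcpy(&packed, lanes.data(), sizeof(packed));
        return packed;
    }
};

int main()
{
    vector_storage_int8x4 v;
    for (int i = 0; i < 4; ++i)
        v.lane(i) = static_cast<int8_t>(i + 1);
    std::cout << std::hex << v.as_int32() << "\n"; // 0x4030201 on little-endian
}
```

The point of making int8x4_t a real vector is that lane-wise writes and whole-vector reads become views of the same storage; reinterpreting the bits as an int32_t is then an explicit, local operation rather than a property of the type alias.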
@@ -104,7 +104,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
 "wrong! Desc should be known at compile-time");
-using vector_t = typename vector_type<Data, DataPerAccess>::type;
+using vector_t = typename vector_type_maker<Data, DataPerAccess>::type::type;
 static_for<0, NSliceRow, 1>{}([&](auto i) {
 static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
...
@@ -172,16 +172,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
 }();
 // copy data
-vector_type<DstData, DstScalarPerVector> dst_vector;
-using dst_vector_t = typename vector_type<DstData, DstScalarPerVector>::type;
+typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
+using dst_vector_t =
+typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
 static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
 constexpr index_t src_offset =
 src_desc.CalculateOffset(to_multi_index(src_slice_origin_idx) + dst_data_idx +
 i * dst_scalar_step_in_vector);
-dst_vector.Scalars()(i) = type_convert<DstData>{}(p_src[Number<src_offset>{}]);
+dst_vector.template AsType<DstData>()(i) =
+type_convert<DstData>{}(p_src[Number<src_offset>{}]);
 });
 const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
@@ -192,7 +194,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
 {
 #if CK_USE_AMD_BUFFER_ADDRESSING
 amd_buffer_store_v2<DstData, DstScalarPerVector>(
-dst_vector.Vector(),
+dst_vector.template AsType<dst_vector_t>()(Number<0>{}),
 p_dst,
 dst_slice_origin_coord_.GetOffset(),
 is_dst_valid,
@@ -201,7 +203,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
 if(is_dst_valid)
 {
 *reinterpret_cast<dst_vector_t*>(
-&(p_dst[dst_slice_origin_coord_.GetOffset()])) = dst_vector.Vector();
+&(p_dst[dst_slice_origin_coord_.GetOffset()])) =
+dst_vector.template AsType<dst_vector_t>()[Number<0>{}];
 }
 #endif
 }
@@ -210,7 +213,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
 if(is_dst_valid)
 {
 *reinterpret_cast<dst_vector_t*>(
-&(p_dst[dst_slice_origin_coord_.GetOffset()])) = dst_vector.Vector();
+&(p_dst[dst_slice_origin_coord_.GetOffset()])) =
+dst_vector.template AsType<dst_vector_t>()[Number<0>{}];
 }
 }
@@ -500,9 +504,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
 // copy data
 static_assert(DstAddressSpace == AddressSpace::Vgpr, "wrong! hardcode for vgpr dst");
-vector_type<SrcData, SrcScalarPerVector> src_vector;
-using src_vector_t = typename vector_type<SrcData, SrcScalarPerVector>::type;
+typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;
+using src_vector_t =
+typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
 const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
 src_desc, src_slice_origin_coord_);
@@ -510,24 +515,25 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
 if constexpr(SrcAddressSpace == AddressSpace::Global)
 {
 #if CK_USE_AMD_BUFFER_ADDRESSING
-src_vector.Vector() = amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
-p_src,
-src_slice_origin_coord_.GetOffset(),
-is_src_valid,
-src_desc.GetElementSpaceSize());
+src_vector.template AsType<src_vector_t>()(Number<0>{}) =
+amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
+p_src,
+src_slice_origin_coord_.GetOffset(),
+is_src_valid,
+src_desc.GetElementSpaceSize());
 #else
-src_vector.Vector() = is_src_valid
-? *reinterpret_cast<const src_vector_t*>(
-&p_src[src_slice_origin_coord_.GetOffset()])
-: src_vector_t{0};
+src_vector.template AsType<src_vector_t>()(Number<0>{}) =
+is_src_valid ? *reinterpret_cast<const src_vector_t*>(
+&p_src[src_slice_origin_coord_.GetOffset()])
+: src_vector_t{0};
 #endif
 }
 else
 {
-src_vector.Vector() = is_src_valid
-? *reinterpret_cast<const src_vector_t*>(
-&p_src[src_slice_origin_coord_.GetOffset()])
-: src_vector_t{0};
+src_vector.template AsType<src_vector_t>()(Number<0>{}) =
+is_src_valid ? *reinterpret_cast<const src_vector_t*>(
+&p_src[src_slice_origin_coord_.GetOffset()])
+: src_vector_t{0};
 }
 static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
@@ -535,7 +541,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
 dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
 i * src_scalar_step_in_vector);
-p_dst[Number<dst_offset>{}] = src_vector.Scalars()[i];
+p_dst[Number<dst_offset>{}] = src_vector.template AsType<SrcData>()[i];
 });
 constexpr auto move_on_dim = [&]() constexpr
@@ -833,9 +839,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
 }();
 // copy data
-vector_type<SrcData, SrcScalarPerVector> src_vector;
-using src_vector_t = typename vector_type<SrcData, SrcScalarPerVector>::type;
+typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;
+using src_vector_t =
+typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
 const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
 src_desc, src_slice_origin_coord_);
@@ -843,31 +850,32 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
 if constexpr(SrcAddressSpace == AddressSpace::Global)
 {
 #if CK_USE_AMD_BUFFER_ADDRESSING
-src_vector.Vector() = amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
-p_src,
-src_slice_origin_coord_.GetOffset(),
-is_src_valid,
-src_desc.GetElementSpaceSize());
+src_vector.template AsType<src_vector_t>()(Number<0>{}) =
+amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
+p_src,
+src_slice_origin_coord_.GetOffset(),
+is_src_valid,
+src_desc.GetElementSpaceSize());
 #else
-src_vector.Vector() = is_src_valid
-? *reinterpret_cast<const src_vector_t*>(
-&p_src[src_slice_origin_coord_.GetOffset()])
-: src_vector_t{0};
+src_vector.template AsType<src_vector_t>()(Number<0>{}) =
+is_src_valid ? *reinterpret_cast<const src_vector_t*>(
+&p_src[src_slice_origin_coord_.GetOffset()])
+: src_vector_t{0};
 #endif
 }
 else
 {
-src_vector.Vector() = is_src_valid
-? *reinterpret_cast<const src_vector_t*>(
-&p_src[src_slice_origin_coord_.GetOffset()])
-: src_vector_t{0};
+src_vector.template AsType<src_vector_t>()(Number<0>{}) =
+is_src_valid ? *reinterpret_cast<const src_vector_t*>(
+&p_src[src_slice_origin_coord_.GetOffset()])
+: src_vector_t{0};
 }
 static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
 constexpr index_t buffer_offset =
 buffer_desc_.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector);
-buffer_(Number<buffer_offset>{}) = src_vector.Scalars()[i];
+buffer_(Number<buffer_offset>{}) = src_vector.template AsType<SrcData>()[i];
 });
 constexpr auto move_on_dim = [&]() constexpr
@@ -1018,19 +1026,20 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
 DstInMemOp == InMemoryDataOperation::Set,
 "wrong! hardcoded for ds_write");
-vector_type<DstData, DstScalarPerVector> dst_vector;
+typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
 static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
 constexpr index_t buffer_offset =
 buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector);
-dst_vector.Scalars()(i) = buffer_[Number<buffer_offset>{}];
+dst_vector.template AsType<DstData>()(i) = buffer_[Number<buffer_offset>{}];
 });
-using DstVectorType = typename vector_type<DstData, DstScalarPerVector>::type;
+using DstVectorType =
+typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
 *reinterpret_cast<DstVectorType*>(p_dst + dst_slice_origin_coord_.GetOffset()) =
-dst_vector.Vector();
+dst_vector.template AsType<DstVectorType>()[Number<0>{}];
 constexpr auto move_on_dim = [&]() constexpr
 {
...
@@ -41,7 +41,7 @@ struct ThreadwiseMatrixSliceCopy_v2
 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
 "wrong! Desc should be known at compile-time");
-using vector_t = typename vector_type<Data, DataPerAccess>::type;
+using vector_t = typename vector_type_maker<Data, DataPerAccess>::type::type;
 static_for<0, NSliceRow, 1>{}([&](auto i) {
 static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
...
@@ -6,6 +6,17 @@
 namespace ck {
+template <typename T>
+union BufferResource
+{
+// 128 bit SGPRs to supply buffer resource in buffer instructions
+// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
+int32x4_t data;
+T* address[2];
+int32_t range[4];
+int32_t config[4];
+};
 __device__ float __llvm_amdgcn_buffer_load_f32(int32x4_t srsrc,
 index_t vindex,
 index_t offset,
...
@@ -6,27 +6,27 @@
 namespace ck {
 template <typename T>
-union BufferResource
+union BufferResource_v2
 {
 // 128 bit SGPRs to supply buffer resource in buffer instructions
 // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
 int32x4_t data;
-T* address[2];
-int32_t range[4];
-int32_t config[4];
+StaticallyIndexedArray<T*, 2> address;
+StaticallyIndexedArray<int32_t, 4> range;
+StaticallyIndexedArray<int32_t, 4> config;
 };
 template <typename T>
 __device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_size)
 {
-BufferResource<T> wave_buffer_resource;
+BufferResource_v2<T> wave_buffer_resource;
 // wavewise base address (64 bit)
-wave_buffer_resource.address[0] = const_cast<remove_cv_t<T>*>(p_wave);
+wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
 // wavewise range (32 bit)
-wave_buffer_resource.range[2] = data_space_size * sizeof(T);
+wave_buffer_resource.range(Number<2>{}) = data_space_size * sizeof(T);
 // wavewise setting (32 bit)
-wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
+wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;
 return wave_buffer_resource.data;
 }
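`make_wave_buffer_resource` above fills the four dwords of a 128-bit buffer resource through the typed members of the union. A rough host-side sketch of that overlay, with illustrative field names and a placeholder config value (the real third dword is `CK_BUFFER_RESOURCE_3RD_DWORD` and its encoding is hardware-specific); the union type punning mirrors what the kernel code itself relies on:

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative layout only: four dwords overlaying a 64-bit base address,
// a byte range, and a config dword, matching the comments in the diff.
union BufferResourceSketch
{
    uint32_t dword[4]; // the packed descriptor handed to the buffer instructions
    struct
    {
        const void* base;   // dwords 0-1: wavewise base address (64 bit)
        uint32_t    range;  // dword 2: addressable size in bytes
        uint32_t    config; // dword 3: format/stride configuration
    } fields;
};

int main()
{
    static float data[256];
    BufferResourceSketch r{};
    r.fields.base   = data;
    r.fields.range  = sizeof(data);
    r.fields.config = 0; // placeholder for CK_BUFFER_RESOURCE_3RD_DWORD
    std::printf("range dword = %u bytes\n", (unsigned)r.dword[2]);
}
```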
@@ -37,6 +37,19 @@ __llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
 index_t voffset,
 index_t soffset,
 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8");
+__device__ int8x2_t
+__llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc,
+index_t voffset,
+index_t soffset,
+index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8");
+__device__ int8x4_t
+__llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
+index_t voffset,
+index_t soffset,
+index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");
 __device__ int16_t
 __llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
 index_t voffset,
@@ -105,6 +118,20 @@ __llvm_amdgcn_raw_buffer_store_i8(int8_t vdata,
 index_t soffset,
 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8");
+__device__ void
+__llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata,
+int32x4_t rsrc,
+index_t voffset,
+index_t soffset,
+index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8");
+__device__ void
+__llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
+int32x4_t rsrc,
+index_t voffset,
+index_t soffset,
+index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8");
 __device__ void
 __llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
 int32x4_t rsrc,
@@ -182,15 +209,12 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
 index_t src_thread_addr_offset,
 index_t src_wave_addr_offset)
 {
-static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
-(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4)) ||
-(is_same<T, half2_t>::value && (N == 1)) ||
-(is_same<T, half4_t>::value && (N == 1)) ||
-(is_same<T, half8_t>::value && (N == 1)) ||
-(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
-(is_same<T, int32x2_t>::value && (N == 1)) ||
-(is_same<T, int32x4_t>::value && (N == 1)),
-"wrong! not implemented");
+static_assert(
+(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
+(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
+(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)),
+"wrong! not implemented");
 if constexpr(is_same<T, float>::value)
 {
@@ -213,16 +237,16 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
 {
 vector_type<float, 8> tmp;
-tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
+tmp.AsType<float4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
 src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
-tmp.Vectors(Number<4>{})(Number<1>{}) =
+tmp.AsType<float4_t>()(Number<1>{}) =
 __llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
 src_thread_addr_offset,
 src_wave_addr_offset + 4 * sizeof(float),
 0);
-return tmp.Vector();
+return tmp.AsType<float8_t>()(Number<0>{});
 }
 }
 else if constexpr(is_same<T, half_t>::value)
@@ -242,39 +266,20 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
 return __llvm_amdgcn_raw_buffer_load_fp16x4(
 src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
 }
-}
-else if constexpr(is_same<T, half2_t>::value)
-{
-if constexpr(N == 1)
-{
-return __llvm_amdgcn_raw_buffer_load_fp16x2(
-src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
-}
-}
-else if constexpr(is_same<T, half4_t>::value)
-{
-if constexpr(N == 1)
-{
-return __llvm_amdgcn_raw_buffer_load_fp16x4(
-src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
-}
-}
-else if constexpr(is_same<T, half8_t>::value)
-{
-if constexpr(N == 1)
+else if constexpr(N == 8)
 {
 vector_type<half_t, 8> tmp;
-tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp16x4(
+tmp.AsType<half4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp16x4(
 src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
-tmp.Vectors(Number<4>{})(Number<1>{}) =
+tmp.AsType<half4_t>()(Number<1>{}) =
 __llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
 src_thread_addr_offset,
 src_wave_addr_offset + 4 * sizeof(half_t),
 0);
-return tmp.Vector();
+return tmp.AsType<half8_t>()(Number<0>{});
 }
 }
 else if constexpr(is_same<T, int32_t>::value)
@@ -298,32 +303,103 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
 {
 vector_type<int32_t, 8> tmp;
-tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i32x4(
+tmp.AsType<int32x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i32x4(
 src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
-tmp.Vectors(Number<4>{})(Number<1>{}) =
+tmp.AsType<int32x4_t>()(Number<1>{}) =
 __llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
 src_thread_addr_offset,
 src_wave_addr_offset + 4 * sizeof(int32_t),
 0);
-return tmp.Vector();
+return tmp.AsType<int32x8_t>()(Number<0>{});
 }
 }
-else if constexpr(is_same<T, int32x2_t>::value)
+else if constexpr(is_same<T, int8_t>::value)
 {
 if constexpr(N == 1)
 {
-return __llvm_amdgcn_raw_buffer_load_i32x2(
+return __llvm_amdgcn_raw_buffer_load_i8(
 src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
 }
-}
-else if constexpr(is_same<T, int32x4_t>::value)
-{
-if constexpr(N == 1)
-{
-return __llvm_amdgcn_raw_buffer_load_i32x4(
-src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+else if constexpr(N == 2)
+{
+#if !CK_WORKAROUND_SWDEV_XXXXXX
+return __llvm_amdgcn_raw_buffer_load_i8x2(
+src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+#else
+int16_t tmp = __llvm_amdgcn_raw_buffer_load_i16(
+src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+return as_type<int8x2_t>(tmp);
+#endif
+}
+else if constexpr(N == 4)
+{
+#if !CK_WORKAROUND_SWDEV_XXXXXX
+return __llvm_amdgcn_raw_buffer_load_i8x4(
+src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+#else
+int32_t tmp = __llvm_amdgcn_raw_buffer_load_i32(
+src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+return as_type<int8x4_t>(tmp);
+#endif
+}
+else if constexpr(N == 8)
+{
+#if !CK_WORKAROUND_SWDEV_XXXXXX
+vector_type<int8_t, 8> tmp;
+tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4(
+src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+tmp.AsType<int8x4_t>()(Number<1>{}) =
+__llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
+src_thread_addr_offset,
+src_wave_addr_offset + 4 * sizeof(int8_t),
+0);
+return tmp.AsType<int8x8_t>()(Number<0>{});
+#else
+int32x2_t tmp = __llvm_amdgcn_raw_buffer_load_i32x2(
+src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+return as_type<int8x8_t>(tmp);
+#endif
+}
+else if constexpr(N == 16)
+{
+#if !CK_WORKAROUND_SWDEV_XXXXXX
+vector_type<int8_t, 16> tmp;
+tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4(
+src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+tmp.AsType<int8x4_t>()(Number<1>{}) =
+__llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
+src_thread_addr_offset,
+src_wave_addr_offset + 4 * sizeof(int8_t),
+0);
+tmp.AsType<int8x4_t>()(Number<2>{}) =
+__llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
+src_thread_addr_offset,
+src_wave_addr_offset + 8 * sizeof(int8_t),
+0);
+tmp.AsType<int8x4_t>()(Number<3>{}) =
+__llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
+src_thread_addr_offset,
+src_wave_addr_offset + 12 * sizeof(int8_t),
+0);
+return tmp.AsType<int8x16_t>()(Number<0>{});
+#else
+int32x4_t tmp = __llvm_amdgcn_raw_buffer_load_i32x4(
+src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+return as_type<int8x16_t>(tmp);
+#endif
+}
 }
 }
 }
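The `CK_WORKAROUND_SWDEV_XXXXXX` branches above avoid the native byte-vector buffer loads by fetching the same bytes through a wider integer load and then reinterpreting the bits as int8 lanes via `as_type<...>`. A minimal host-side sketch of that load-wide-then-bit-cast idea, using `memcpy` as the stand-in for both the buffer intrinsic and `as_type` (the `load_i32` helper is hypothetical, not a real intrinsic):

```cpp
#include <cstdint>
#include <cstring>
#include <cstdio>

// Stand-in for a 4-byte buffer load such as __llvm_amdgcn_raw_buffer_load_i32.
static int32_t load_i32(const void* src)
{
    int32_t tmp;
    std::memcpy(&tmp, src, sizeof(tmp));
    return tmp;
}

int main()
{
    const int8_t data[4] = {7, -1, 42, 3};

    int32_t wide = load_i32(data); // one wide load instead of a v4i8 load
    int8_t lanes[4];
    std::memcpy(lanes, &wide, 4);  // bit-cast back to four int8 lanes, as as_type<int8x4_t> does
    std::printf("%d %d %d %d\n", lanes[0], lanes[1], lanes[2], lanes[3]); // 7 -1 42 3
}
```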
@@ -407,23 +483,39 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
 }
 else if constexpr(N == 2)
 {
-__llvm_amdgcn_raw_buffer_store_i16(src_thread_data,
+#if !CK_WORKAROUND_SWDEV_XXXXXX
+__llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data,
+dst_wave_buffer_resource,
+dst_thread_addr_offset,
+dst_wave_addr_offset,
+0);
+#else
+__llvm_amdgcn_raw_buffer_store_i16(as_type<int16_t>(src_thread_data),
 dst_wave_buffer_resource,
 dst_thread_addr_offset,
 dst_wave_addr_offset,
 0);
+#endif
 }
 else if constexpr(N == 4)
 {
-__llvm_amdgcn_raw_buffer_store_i32(src_thread_data,
+#if !CK_WORKAROUND_SWDEV_XXXXXX
+__llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data,
+dst_wave_buffer_resource,
+dst_thread_addr_offset,
+dst_wave_addr_offset,
+0);
+#else
+__llvm_amdgcn_raw_buffer_store_i32(as_type<int32_t>(src_thread_data),
 dst_wave_buffer_resource,
 dst_thread_addr_offset,
 dst_wave_addr_offset,
 0);
+#endif
 }
 else if constexpr(N == 8)
 {
-__llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data,
+__llvm_amdgcn_raw_buffer_store_i32x2(as_type<int32x2_t>(src_thread_data),
 dst_wave_buffer_resource,
 dst_thread_addr_offset,
 dst_wave_addr_offset,
@@ -431,7 +523,7 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
 }
 else if constexpr(N == 16)
 {
-__llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data,
+__llvm_amdgcn_raw_buffer_store_i32x4(as_type<int32x4_t>(src_thread_data),
 dst_wave_buffer_resource,
 dst_thread_addr_offset,
 dst_wave_addr_offset,
@@ -468,13 +560,13 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
 {
 vector_type<half_t, 8> tmp{src_thread_data};
-__llvm_amdgcn_raw_buffer_store_fp16x4(tmp.Vectors(Number<4>{})[Number<0>{}],
+__llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}],
 dst_wave_buffer_resource,
 dst_thread_addr_offset,
 dst_wave_addr_offset,
 0);
-__llvm_amdgcn_raw_buffer_store_fp16x4(tmp.Vectors(Number<4>{})[Number<1>{}],
+__llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
 dst_wave_buffer_resource,
 dst_thread_addr_offset,
 dst_wave_addr_offset + 4 * sizeof(half_t),
@@ -488,26 +580,29 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
 // 2) p_src_wave to be a wavewise pointer.
 // It is user's responsibility to make sure that is true.
 template <typename T, index_t N>
-__device__ typename vector_type<T, N>::type amd_buffer_load_v2(const T* p_src_wave,
-index_t src_thread_data_offset,
-bool src_thread_data_valid,
-index_t src_element_space)
+__device__ typename vector_type_maker<T, N>::type::type
+amd_buffer_load_v2(const T* p_src_wave,
+index_t src_thread_data_offset,
+bool src_thread_data_valid,
+index_t src_element_space)
 {
 const int32x4_t src_wave_buffer_resource =
 make_wave_buffer_resource(p_src_wave, src_element_space);
 index_t src_thread_addr_offset = src_thread_data_offset * sizeof(T);
+using vector_t = typename vector_type_maker<T, N>::type::type;
+using scalar_t = typename scalar_type<vector_t>::type;
+constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
 #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
 uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
-return amd_buffer_load_impl_v2<T, N>(
+return amd_buffer_load_impl_v2<scalar_t, vector_size>(
 src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
 #else
-using vector_t = typename vector_type<T, N>::type;
-vector_t tmp =
-amd_buffer_load_impl_v2<T, N>(src_wave_buffer_resource, src_thread_addr_offset, 0);
+vector_t tmp = amd_buffer_load_impl_v2<scalar_t, vector_size>(
+src_wave_buffer_resource, src_thread_addr_offset, 0);
 return src_thread_data_valid ? tmp : vector_t(0);
 #endif
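The wrapper above no longer forwards `<T, N>` directly; it builds the vector type with `vector_type_maker`, then derives the element type and lane count from `scalar_type<vector_t>` and calls the impl as `amd_buffer_load_impl_v2<scalar_t, vector_size>`. A loose sketch of what such a trait could look like (hypothetical name, `std::array` in place of the ext-vector types; the library's real trait may differ):

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <type_traits>

// Hypothetical analogue of scalar_type<vector_t>: map a vector type to its
// element type and lane count.
template <typename V>
struct scalar_type_sketch; // primary template left undefined

template <typename E, std::size_t N>
struct scalar_type_sketch<std::array<E, N>>
{
    using type = E;
    static constexpr std::size_t vector_size = N;
};

using int8x4_sketch = std::array<int8_t, 4>;

static_assert(std::is_same_v<scalar_type_sketch<int8x4_sketch>::type, int8_t>);
static_assert(scalar_type_sketch<int8x4_sketch>::vector_size == 4);

int main() {}
```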
@@ -518,26 +613,31 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_v2(const T* p_src_wa
 // 2) p_dst_wave to be a wavewise pointer.
 // It is user's responsibility to make sure that is true.
 template <typename T, index_t N>
-__device__ void amd_buffer_store_v2(const typename vector_type<T, N>::type src_thread_data,
-T* p_dst_wave,
-const index_t dst_thread_data_offset,
-const bool dst_thread_data_valid,
-const index_t dst_element_space)
+__device__ void
+amd_buffer_store_v2(const typename vector_type_maker<T, N>::type::type src_thread_data,
+T* p_dst_wave,
+const index_t dst_thread_data_offset,
+const bool dst_thread_data_valid,
+const index_t dst_element_space)
 {
 const int32x4_t dst_wave_buffer_resource =
 make_wave_buffer_resource(p_dst_wave, dst_element_space);
 index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(T);
+using vector_t = typename vector_type_maker<T, N>::type::type;
+using scalar_t = typename scalar_type<vector_t>::type;
+constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
 #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
 uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
-amd_buffer_store_impl_v2<T, N>(
+amd_buffer_store_impl_v2<scalar_t, vector_size>(
 src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
 #else
 if(dst_thread_data_valid)
 {
-amd_buffer_store_impl_v2<T, N>(
+amd_buffer_store_impl_v2<scalar_t, vector_size>(
 src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
 }
 #endif
...
@@ -72,6 +72,7 @@ amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, flo
 __device__ void
 amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
 {
+// TODO remove pointer casting
 const half2_t* p_a_half2 = reinterpret_cast<const half2_t*>(&a);
 const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
 const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
@@ -132,6 +133,7 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a,
 float& c2,
 float& c3)
 {
+// TODO remove pointer casting
 const half2_t* p_a_half2 = reinterpret_cast<const half2_t*>(&a);
 const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
 const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
@@ -177,6 +179,7 @@ __device__ void amd_assembly_outer_product_1x4(half8_t a,
 float& c3)
 {
+// TODO remove pointer casting
 const half4_t* p_a_half4 = reinterpret_cast<const half4_t*>(&a);
 const half4_t* p_b0_half4 = reinterpret_cast<const half4_t*>(&b0);
 const half4_t* p_b1_half4 = reinterpret_cast<const half4_t*>(&b1);
@@ -200,6 +203,7 @@ __device__ void amd_assembly_outer_product_1x4(half16_t a,
 float& c2,
 float& c3)
 {
+// TODO remove pointer casting
 const half8_t* p_a_half8 = reinterpret_cast<const half8_t*>(&a);
 const half8_t* p_b0_half8 = reinterpret_cast<const half8_t*>(&b0);
 const half8_t* p_b1_half8 = reinterpret_cast<const half8_t*>(&b1);
@@ -224,10 +228,14 @@ amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0
 v_dot4_i32_i8 %1, %2, %4, %1\n \
 "
 : "=v"(c0), "=v"(c1)
-: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
+: "v"(as_type<int32_t>(a)),
+"v"(as_type<int32_t>(b0)),
+"v"(as_type<int32_t>(b1)),
+"0"(c0),
+"1"(c1));
 #else
-c0 = __builtin_amdgcn_sdot4(a, b0, c0, false);
-c1 = __builtin_amdgcn_sdot4(a, b1, c1, false);
+c0 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b0), c0, false);
+c1 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b1), c1, false);
 #endif
 }
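`v_dot4_i32_i8` / `__builtin_amdgcn_sdot4` consume the four int8 lanes as one packed 32-bit operand, which is why the diff adds `as_type<int32_t>(...)` around the now-real int8x4_t vectors: only the view of the bits changes, not the values. A host-side reference of the computation itself:

```cpp
#include <cstdint>
#include <cstring>
#include <cstdio>

// Reference for the dot product the builtin performs: interpret each 32-bit
// operand as four signed bytes, multiply lane-wise, add to the accumulator.
int32_t sdot4_reference(int32_t a_packed, int32_t b_packed, int32_t c)
{
    int8_t a[4], b[4];
    std::memcpy(a, &a_packed, 4);
    std::memcpy(b, &b_packed, 4);
    for (int i = 0; i < 4; ++i)
        c += int(a[i]) * int(b[i]);
    return c;
}

int main()
{
    const int8_t av[4] = {1, 2, 3, 4};
    const int8_t bv[4] = {5, 6, 7, 8};
    int32_t a, b;
    std::memcpy(&a, av, 4);
    std::memcpy(&b, bv, 4);
    std::printf("%d\n", sdot4_reference(a, b, 0)); // 1*5 + 2*6 + 3*7 + 4*8 = 70
}
```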
@@ -253,12 +261,20 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a,
 v_dot4_i32_i8 %3, %4, %8, %3\n \
 "
 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
-: "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
+: "v"(as_type<int32_t>(a)),
+"v"(as_type<int32_t>(b0)),
+"v"(as_type<int32_t>(b1)),
+"v"(as_type<int32_t>(b2)),
+"v"(as_type<int32_t>(b3)),
+"0"(c0),
+"1"(c1),
+"2"(c2),
+"3"(c3));
 #else
-c0 = __builtin_amdgcn_sdot4(a, b0, c0, false);
-c1 = __builtin_amdgcn_sdot4(a, b1, c1, false);
-c2 = __builtin_amdgcn_sdot4(a, b2, c2, false);
-c3 = __builtin_amdgcn_sdot4(a, b3, c3, false);
+c0 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b0), c0, false);
+c1 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b1), c1, false);
+c2 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b2), c2, false);
+c3 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b3), c3, false);
 #endif
 }
@@ -272,28 +288,24 @@ __device__ void amd_assembly_outer_product_1x4(int8x8_t a,
 int32_t& c2,
 int32_t& c3)
 {
-const int8x4_t* p_a_int8x4_t = reinterpret_cast<const int8x4_t*>(&a);
-const int8x4_t* p_b0_int8x4_t = reinterpret_cast<const int8x4_t*>(&b0);
-const int8x4_t* p_b1_int8x4_t = reinterpret_cast<const int8x4_t*>(&b1);
-const int8x4_t* p_b2_int8x4_t = reinterpret_cast<const int8x4_t*>(&b2);
-const int8x4_t* p_b3_int8x4_t = reinterpret_cast<const int8x4_t*>(&b3);
-amd_assembly_outer_product_1x4(p_a_int8x4_t[0],
-p_b0_int8x4_t[0],
-p_b1_int8x4_t[0],
-p_b2_int8x4_t[0],
-p_b3_int8x4_t[0],
+constexpr auto I0 = Number<0>{};
+constexpr auto I1 = Number<1>{};
+amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
+vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I0],
+vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I0],
+vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I0],
+vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I0],
 c0,
 c1,
 c2,
 c3);
-amd_assembly_outer_product_1x4(p_a_int8x4_t[1],
-p_b0_int8x4_t[1],
-p_b1_int8x4_t[1],
-p_b2_int8x4_t[1],
-p_b3_int8x4_t[1],
+amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
+vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I1],
+vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I1],
+vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I1],
+vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I1],
 c0,
 c1,
 c2,
@@ -311,28 +323,46 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
 int32_t& c3)
 {
-const int8x8_t* p_a_int8x8_t = reinterpret_cast<const int8x8_t*>(&a);
-const int8x8_t* p_b0_int8x8_t = reinterpret_cast<const int8x8_t*>(&b0);
-const int8x8_t* p_b1_int8x8_t = reinterpret_cast<const int8x8_t*>(&b1);
-const int8x8_t* p_b2_int8x8_t = reinterpret_cast<const int8x8_t*>(&b2);
-const int8x8_t* p_b3_int8x8_t = reinterpret_cast<const int8x8_t*>(&b3);
-amd_assembly_outer_product_1x4(p_a_int8x8_t[0],
-p_b0_int8x8_t[0],
-p_b1_int8x8_t[0],
-p_b2_int8x8_t[0],
-p_b3_int8x8_t[0],
+constexpr auto I0 = Number<0>{};
+constexpr auto I1 = Number<1>{};
+constexpr auto I2 = Number<2>{};
+constexpr auto I3 = Number<3>{};
+amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
+vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I0],
+vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I0],
+vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I0],
+vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I0],
+c0,
+c1,
+c2,
+c3);
+amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
+vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I1],
+vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I1],
+vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I1],
+vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I1],
+c0,
+c1,
+c2,
+c3);
+amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
+vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I2],
+vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I2],
+vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I2],
+vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I2],
 c0,
 c1,
 c2,
 c3);
-amd_assembly_outer_product_1x4(p_a_int8x8_t[1],
-p_b0_int8x8_t[1],
-p_b1_int8x8_t[1],
-p_b2_int8x8_t[1],
-p_b3_int8x8_t[1],
+amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
+vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I3],
+vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I3],
+vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I3],
+vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I3],
 c0,
 c1,
 c2,
...
@@ -14,11 +14,11 @@
 #define CK_DEVICE_BACKEND_AMD 1
 // GPU ID
-#if 0
+#if 1
 #define CK_AMD_GPU_GFX906 1
 #elif 0
 #define CK_AMD_GPU_GFX908 1
-#elif 1
+#elif 0
 #define CK_AMD_GPU_GFX1030 1
 #endif
@@ -88,7 +88,7 @@
 // experimental implementation
 #ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
-#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
+#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
 #endif
 #ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
@@ -142,6 +142,11 @@
 #define CK_WORKAROUND_SWDEV_275126 1
 #endif
+// workaround for compiler crash when using buffer load/store for i8
+#ifndef CK_WORKAROUND_SWDEV_XXXXXX
+#define CK_WORKAROUND_SWDEV_XXXXXX 1
+#endif
 namespace ck {
 enum AddressSpace
...
@@ -40,7 +40,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
 wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
 out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
-#if 0
+#if 1
 // run-time variables
 const auto in_n_c_hi_wi_desc =
 make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
@@ -167,7 +167,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
 constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
 constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
-#elif 1
+#elif 0
 // cdata = 64, BlockSize 64, 16x256x4
 constexpr index_t BlockSize = 64;
...
@@ -53,7 +53,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
 constexpr auto C0 = C / Number<InWeiVectorSize>{};
 constexpr auto C1 = Number<InWeiVectorSize>{};
-#if 0
+#if 1
 // run-time variables
 constexpr auto in_n_hi_wi_c0_desc =
 make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C0));
@@ -112,7 +112,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
 wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
 out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
-#if 1
+#if 0
 // cdata = 16, BlockSize = 64, 16x64x4
 constexpr index_t BlockSize = 64;
@@ -211,7 +211,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
 constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
 constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
-#elif 1
+#elif 0
 // cdata = 64, BlockSize = 64, 16x256x4
 constexpr index_t BlockSize = 64;
@@ -310,7 +310,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
 constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
 constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
-#elif 0
+#elif 1
 // cdata = 64, BlockSize = 256, 128x128x8
 constexpr index_t BlockSize = 256;
...
@@ -83,10 +83,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
 const auto out_n_k0_ho_wo_k1_desc =
 make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1));
 const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
 const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
 const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
 const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
 #endif
 Tensor<TInWei> in_n_c0_hi_wi_c1(make_HostTensorDescriptor(
...
@@ -48,8 +48,8 @@ int main(int argc, char* argv[])
 using ConvStrides = Sequence<1, 1>;
 using ConvDilations = Sequence<1, 1>;
 using LeftPads = Sequence<0, 0>;
 using RightPads = Sequence<0, 0>;
 #elif 0
 constexpr index_t N = 1;
 constexpr index_t C = 16;
@@ -62,9 +62,9 @@ int main(int argc, char* argv[])
 using ConvStrides = Sequence<1, 1>;
 using ConvDilations = Sequence<1, 1>;
 using LeftPads = Sequence<0, 0>;
 using RightPads = Sequence<0, 0>;
-#elif 1
+#elif 0
 constexpr index_t N = 1;
 constexpr index_t C = 16;
 constexpr index_t HI = 1080;
@@ -150,7 +150,7 @@ int main(int argc, char* argv[])
 using LeftPads = Sequence<0, 0>;
 using RightPads = Sequence<0, 0>;
-#elif 0
+#elif 1
 // 3x3, 71x71
 constexpr index_t N = 128;
 constexpr index_t C = 192;
@@ -630,7 +630,7 @@ int main(int argc, char* argv[])
 print_array("ConvStrides", to_multi_index(ConvStrides{}));
 print_array("ConvDilations", to_multi_index(ConvDilations{}));
-#if 0
+#if 1
 using in_data_t = float;
 constexpr index_t in_vector_size = 1;
 using acc_data_t = float;
@@ -724,7 +724,7 @@ int main(int argc, char* argv[])
 LeftPads{},
 RightPads{},
 nrepeat);
-#elif 0
+#elif 1
 device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
 in_vector_size,
 acc_data_t,
...