Commit df83690d authored by Chao Liu

update fwd-v5r1 to use DynamicBuffer

parent 01055d95
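This commit threads a DynamicBuffer wrapper through the v5r1 forward path in place of raw SrcData*/DstData* pointers: the blockwise and threadwise transfer classes become templates over the buffer type, element types are checked with static_assert, and the int8 LDS-write miscompile is handled once inside DynamicBuffer::Set instead of at every call site. A minimal sketch of the buffer shape this diff relies on (p_data_, ::type, Get/Set, IsStaticBuffer, and make_dynamic_buffer all appear in the hunks below; the exact repo definition may differ):

    #include <cstdint>

    using index_t = std::int32_t; // stand-in for ck::index_t

    template <typename T>
    struct DynamicBuffer
    {
        using type = T;

        T* p_data_; // runtime pointer into global or LDS memory

        // read sizeof(X) bytes starting at element i, reinterpreted as X
        template <typename X>
        __host__ __device__ constexpr auto Get(index_t i) const
        {
            return *reinterpret_cast<const X*>(&p_data_[i]);
        }

        // write an X-sized chunk at element i
        template <typename X>
        __host__ __device__ void Set(index_t i, const X& x)
        {
            *reinterpret_cast<X*>(&p_data_[i]) = x;
        }

        __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }
    };

    template <typename T>
    __host__ __device__ constexpr auto make_dynamic_buffer(T* p)
    {
        return DynamicBuffer<T>{p};
    }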
@@ -79,24 +79,25 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
         }
     }
 
-    template <typename SrcIteratorHacks>
+    template <typename SrcBuffer, typename SrcIteratorHacks>
     __device__ void RunRead(const SrcDesc& src_desc,
-                            const SrcData* p_src,
+                            const SrcBuffer& src_buf,
                             const SrcIteratorHacks& src_iterator_hacks)
     {
         if(BlockSize == thread_cluster_desc_.GetElementSize() or
            get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_transfer_.RunRead(src_desc, p_src, src_iterator_hacks);
+            threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks);
         }
     }
 
-    __device__ void RunWrite(const DstDesc& dst_desc, DstData* p_dst)
+    template <typename DstBuffer>
+    __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
    {
         if(BlockSize == thread_cluster_desc_.GetElementSize() or
            get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_transfer_.RunWrite(dst_desc, p_dst);
+            threadwise_transfer_.RunWrite(dst_desc, dst_buf);
         }
     }
...
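With RunRead/RunWrite templated on the buffer type, a caller wraps each pointer once and passes buffer objects through. This is the call shape the gridwise hunks below adopt; note that the construction of a_block_buf is not shown in this excerpt, so treating it as make_dynamic_buffer over the LDS pointer is an assumption:

    // call shape after this change (identifiers as in the gridwise kernel below)
    auto a_global_buf = make_dynamic_buffer(p_a_global); // global memory view
    auto a_block_buf  = make_dynamic_buffer(p_a_block);  // LDS view (assumed)

    a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_iterator_hacks);
    a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);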
@@ -84,6 +84,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};
 
+        auto a_global_buf = make_dynamic_buffer(p_a_global);
+        auto b_global_buf = make_dynamic_buffer(p_b_global);
+        auto c_global_buf = make_dynamic_buffer(p_c_global);
+
         constexpr auto E = EPerBlock * 3 * 3;
         // const auto E = a_e_k_global_desc.GetLength(I0);
@@ -255,16 +259,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         // LDS double buffer: preload data
         {
-            a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_e_k_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_iterator_hacks);
 
             b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
-                                      p_b_global,
+                                      b_global_buf,
                                       b_e_n_ho_wo_thread_desc,
                                       make_tuple(I0, I0, I0, I0),
                                       b_thread_even_buf,
                                       b_e_n_ho_wo_global_iterator_hacks);
 
-            a_blockwise_copy.RunWrite(a_e_k_desc, p_a_block);
+            a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);
         }
 
         __syncthreads();
@@ -282,7 +286,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                           b_thread_slice_copy_step);
 
                 b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
-                                          p_b_global,
+                                          b_global_buf,
                                           b_e_n_ho_wo_thread_desc,
                                           make_tuple(I0, I0, I0, I0),
                                           b_thread_odd_buf,
@@ -298,7 +302,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                           b_thread_slice_copy_step);
 
                 b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
-                                          p_b_global,
+                                          b_global_buf,
                                           b_e_n_ho_wo_thread_desc,
                                           make_tuple(I0, I0, I0, I0),
                                           b_thread_even_buf,
@@ -321,7 +325,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                           b_thread_slice_copy_step);
 
                 b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
-                                          p_b_global,
+                                          b_global_buf,
                                           b_e_n_ho_wo_thread_desc,
                                           make_tuple(I0, I0, I0, I0),
                                           b_thread_odd_buf,
@@ -370,7 +374,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                           make_tuple(I0, I0, I0, I0),
                                           c_thread_buf,
                                           c_k_n_ho_wo_global_desc,
-                                          p_c_global,
+                                          c_global_buf,
                                           c_k_n_ho_wo_global_tensor_iterator_hacks);
         }
     }
...
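These hunks only swap p_b_global for b_global_buf, but they sit inside an even/odd double-buffered pipeline: preload one tile, then alternate between b_thread_even_buf and b_thread_odd_buf so the next global read overlaps the current GEMM step. A standalone host-side sketch of that pattern, with illustrative names rather than the repo's API:

    #include <array>
    #include <cstdio>
    #include <utility>

    int main()
    {
        std::array<int, 4> even_buf{}, odd_buf{};
        int next = 0;

        auto load = [&](std::array<int, 4>& buf) {
            for(auto& v : buf) v = next++;   // stands in for the global read
        };
        auto consume = [&](const std::array<int, 4>& buf) {
            int s = 0;
            for(auto v : buf) s += v;        // stands in for the GEMM step
            std::printf("%d\n", s);
        };

        load(even_buf);                      // preload, as in the hunk above
        for(int iter = 0; iter < 3; ++iter)
        {
            load(odd_buf);                   // prefetch the next tile
            consume(even_buf);               // compute on the current tile
            std::swap(even_buf, odd_buf);    // ping-pong the roles
        }
        consume(even_buf);
        return 0;
    }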
@@ -83,12 +83,15 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
         dst_slice_origin_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
     }
 
-    template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstIteratorHacks>
+    template <typename SrcSliceOriginIdx,
+              typename SrcBuffer,
+              typename DstBuffer,
+              typename DstIteratorHacks>
     __device__ void Run(const SrcDesc&,
                         const SrcSliceOriginIdx&,
                         const SrcBuffer& src_buf,
                         const DstDesc& dst_desc,
-                        DstData* p_dst,
+                        DstBuffer& dst_buf,
                         const DstIteratorHacks& dst_iterator_hacks)
     {
         static_assert(SrcDesc::IsKnownAtCompileTime(),
@@ -214,7 +217,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
 #if CK_USE_AMD_BUFFER_ADDRESSING
             amd_buffer_store_v2<DstData, DstScalarPerVector>(
                 dst_vector.template AsType<dst_vector_t>()(Number<0>{}),
-                p_dst,
+                dst_buf.p_data_,
                 dst_slice_origin_coord_.GetOffset(),
                 is_dst_valid,
                 dst_desc.GetElementSpaceSize());
@@ -222,7 +225,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
             if(is_dst_valid)
             {
                 *reinterpret_cast<dst_vector_t*>(
-                    &(p_dst[dst_slice_origin_coord_.GetOffset()])) =
+                    &(dst_buf.p_data_[dst_slice_origin_coord_.GetOffset()])) =
                     dst_vector.template AsType<dst_vector_t>()[Number<0>{}];
             }
 #endif
@@ -232,7 +235,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
             if(is_dst_valid)
             {
                 *reinterpret_cast<dst_vector_t*>(
-                    &(p_dst[dst_slice_origin_coord_.GetOffset()])) =
+                    &(dst_buf.p_data_[dst_slice_origin_coord_.GetOffset()])) =
                     dst_vector.template AsType<dst_vector_t>()[Number<0>{}];
             }
         }
@@ -283,7 +286,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
         }
     }
 
-    __device__ void Run(const SrcData* p_src, const DstDesc& dst_desc, DstData* p_dst)
+    template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstBuffer>
+    __device__ void Run(const SrcDesc&,
+                        const SrcSliceOriginIdx&,
+                        const SrcBuffer& src_buf,
+                        const DstDesc& dst_desc,
+                        DstBuffer& dst_buf)
     {
         constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();
@@ -293,7 +301,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
             make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                        generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
 
-        Run(p_src, dst_desc, p_dst, dst_iterator_hacks);
+        Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_iterator_hacks);
     }
 
     __device__ static constexpr auto GetDstCoordinateResetStep()
@@ -379,10 +387,14 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
 }; // namespace ck
 
 // Assume:
-// 1. src_desc is not known at compile-time
-// 2. dst_desc is known at compile-time
-// 3. src_slice_origin_idx is not known at compile-time
-// 4. dst_slice_origin_idx is known at compile-time and it's 0
+// 1. src:
+//    1. SrcDesc is not known at compile-time
+//    2. SrcBuffer is DynamicBuffer
+//    3. src_slice_origin_idx is not known at compile-time
+// 2. dst:
+//    1. DstDesc is known at compile-time
+//    2. DstBuffer is StaticBuffer
+//    3. dst_slice_origin_idx is known at compile-time
 template <typename SrcData,
           typename DstData,
           typename SrcDesc,
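The updated comment distinguishes the two buffer kinds by where their state lives. A hedged sketch of the distinction (the repo's actual StaticBuffer is not shown in this diff, so these shapes are illustrative only):

    // DynamicBuffer: wraps a runtime pointer; offsets are computed per access
    // (global memory or LDS). StaticBuffer: compile-time-sized, statically
    // indexed storage that can stay entirely in registers.
    template <typename T>
    struct DynamicBufferShape
    {
        using type = T;
        T* p_data_; // runtime address
        __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }
    };

    template <typename T, int N>
    struct StaticBufferShape
    {
        using type = T;
        T data_[N]; // register-resident when every index is compile-time
        __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
    };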
@@ -419,9 +431,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
         src_slice_origin_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
     }
 
-    template <typename DstBuffer, typename DstSliceOriginIdx, typename SrcIteratorHacks>
+    template <typename SrcBuffer,
+              typename DstBuffer,
+              typename DstSliceOriginIdx,
+              typename SrcIteratorHacks>
     __device__ void Run(const SrcDesc& src_desc,
-                        const SrcData* p_src,
+                        const SrcBuffer& src_buf,
                         const DstDesc&,
                         const DstSliceOriginIdx&,
                         DstBuffer& dst_buf,
@@ -541,14 +556,14 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
 #if CK_USE_AMD_BUFFER_ADDRESSING
                 src_vector.template AsType<src_vector_t>()(Number<0>{}) =
                     amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
-                        p_src,
+                        src_buf.p_data_,
                         src_slice_origin_coord_.GetOffset(),
                         is_src_valid,
                         src_desc.GetElementSpaceSize());
 #else
                 src_vector.template AsType<src_vector_t>()(Number<0>{}) =
                     is_src_valid ? *reinterpret_cast<const src_vector_t*>(
-                                       &p_src[src_slice_origin_coord_.GetOffset()])
+                                       &src_buf.p_data_[src_slice_origin_coord_.GetOffset()])
                                  : src_vector_t{0};
 #endif
             }
@@ -556,7 +571,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
             {
                 src_vector.template AsType<src_vector_t>()(Number<0>{}) =
                     is_src_valid ? *reinterpret_cast<const src_vector_t*>(
-                                       &p_src[src_slice_origin_coord_.GetOffset()])
+                                       &src_buf.p_data_[src_slice_origin_coord_.GetOffset()])
                                  : src_vector_t{0};
             }
@@ -614,9 +629,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
         }
     }
 
-    template <typename DstBuffer, typename DstSliceOriginIdx>
+    template <typename SrcBuffer, typename DstBuffer, typename DstSliceOriginIdx>
     __device__ void Run(const SrcDesc& src_desc,
-                        const SrcData* p_src,
+                        const SrcBuffer& src_buf,
                         const DstDesc&,
                         const DstSliceOriginIdx&,
                         DstBuffer& dst_buf)
@@ -629,7 +644,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
             make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                        generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
 
-        Run(src_desc, p_src, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_iterator_hacks);
+        Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_iterator_hacks);
     }
 
     __device__ static constexpr auto GetSrcCoordinateResetStep()
@@ -716,8 +731,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
 
 // Assume:
 // 1. src_desc and dst_desc are not known at compile-time
-// 2. src_slice_origin and dst_slice_origin are not known at compile-time,
-// 3. Use thread buffer
+// 2. SrcBuffer and DstBuffer are DynamicBuffer
+// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
+// 4. Use thread buffer
 template <typename SliceLengths,
           InMemoryDataOperation DstInMemOp,
           typename SrcData,
@@ -780,11 +796,15 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
         dst_slice_origin_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
     }
 
-    template <typename SrcIteratorHacks>
+    template <typename SrcBuffer, typename SrcIteratorHacks>
     __device__ void RunRead(const SrcDesc& src_desc,
-                            const SrcData* p_src,
+                            const SrcBuffer& src_buf,
                             const SrcIteratorHacks& src_iterator_hacks)
     {
+        static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
+                              remove_cv_t<remove_reference_t<SrcData>>>::value,
+                      "wrong! SrcBuffer and SrcData data type are inconsistent");
+
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
@@ -882,14 +902,14 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
 #if CK_USE_AMD_BUFFER_ADDRESSING
                 src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
                     amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
-                        p_src,
+                        src_buf.p_data_,
                         src_slice_origin_coord_.GetOffset(),
                         is_src_valid,
                         src_desc.GetElementSpaceSize());
 #else
                 src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
                     is_src_valid ? *reinterpret_cast<const src_vector_t*>(
-                                       &p_src[src_slice_origin_coord_.GetOffset()])
+                                       &src_buf.p_data_[src_slice_origin_coord_.GetOffset()])
                                  : src_vector_t{0};
 #endif
             }
@@ -897,7 +917,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
             {
                 src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
                     is_src_valid ? *reinterpret_cast<const src_vector_t*>(
-                                       &p_src[src_slice_origin_coord_.GetOffset()])
+                                       &src_buf.p_data_[src_slice_origin_coord_.GetOffset()])
                                  : src_vector_t{0};
             }
@@ -958,10 +978,15 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
         }
     }
 
-    template <typename DstIteratorHacks>
-    __device__ void
-    RunWrite(const DstDesc& dst_desc, DstData* p_dst, const DstIteratorHacks& dst_iterator_hacks)
+    template <typename DstBuffer, typename DstIteratorHacks>
+    __device__ void RunWrite(const DstDesc& dst_desc,
+                             DstBuffer& dst_buf,
+                             const DstIteratorHacks& dst_iterator_hacks)
     {
+        static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
+                              remove_cv_t<remove_reference_t<DstData>>>::value,
+                      "wrong! DstBuffer and DstData data type are inconsistent");
+
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
@@ -1070,8 +1095,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
                 using dst_vector_t = typename decltype(dst_tmp_vector)::type;
 
                 // copy data from dst_tmp_vector to dst_buf
-                *reinterpret_cast<dst_vector_t*>(p_dst + dst_slice_origin_coord_.GetOffset()) =
-                    dst_tmp_vector.template AsType<dst_vector_t>()[Number<0>{}];
+                dst_buf.template Set<dst_vector_t>(
+                    dst_slice_origin_coord_.GetOffset(),
+                    dst_tmp_vector.template AsType<dst_vector_t>()[Number<0>{}]);
 
                 constexpr auto move_on_dim = [&]() constexpr
                 {
@@ -1122,7 +1148,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
         }
     }
 
-    __device__ void RunRead(const SrcDesc& src_desc, const SrcData* p_src)
+    template <typename SrcBuffer>
+    __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
     {
         constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();
@@ -1132,10 +1159,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
             make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                        generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
 
-        RunRead(src_desc, p_src, src_iterator_hacks);
+        RunRead(src_desc, src_buf, src_iterator_hacks);
     }
 
-    __device__ void RunWrite(const DstDesc& dst_desc, DstData* p_dst)
+    template <typename DstBuffer>
+    __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
     {
         constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();
@@ -1145,7 +1173,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
             make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                        generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
 
-        RunWrite(dst_desc, p_dst, dst_iterator_hacks);
+        RunWrite(dst_desc, dst_buf, dst_iterator_hacks);
     }
 
     __device__ static constexpr auto GetSrcCoordinateResetStep()
...
@@ -44,7 +44,11 @@ struct DynamicBuffer
                                       bool>::type = false>
     __host__ __device__ constexpr const auto Get(index_t i) const
     {
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
         return *reinterpret_cast<const X*>(&p_data_[i]);
+#else
+        return *reinterpret_cast<const X*>(&p_data_[i]);
+#endif
     }
 
     template <typename X,
@@ -54,7 +58,32 @@ struct DynamicBuffer
                                       bool>::type = false>
     __host__ __device__ void Set(index_t i, const X& x)
     {
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
         *reinterpret_cast<X*>(&p_data_[i]) = x;
+#else
+        if constexpr(is_same<typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type,
+                             int8_t>::value)
+        {
+            static_assert(is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
+                              is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value,
+                          "wrong! not implemented for this combination, please add implementation");
+
+            if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
+                         is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value)
+            {
+#if 0
+                *reinterpret_cast<int32x4_t*>(&p_data_[i]) = as_type<int32x4_t>(x);
+#else
+                *reinterpret_cast<int32x4_t*>(&p_data_[i]) =
+                    *reinterpret_cast<const int32x4_t*>(&x);
+#endif
+            }
+        }
+        else
+        {
+            *reinterpret_cast<X*>(&p_data_[i]) = x;
+        }
+#endif
     }
 
     __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }
...
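The workaround path rewrites a 16-byte int8 vector store as an int32x4 store of the same bytes, side-stepping the ds_write lowering that miscompiles for i8 vectors. A host-side illustration of the reinterpretation (clang vector extensions; memcpy is used here for well-defined behavior off-device, where the device code uses reinterpret_cast):

    #include <cstdint>
    #include <cstring>
    #include <cstdio>

    typedef int8_t  int8x16_t __attribute__((ext_vector_type(16)));
    typedef int32_t int32x4_t __attribute__((ext_vector_type(4)));

    int main()
    {
        alignas(16) int8_t lds_like[16] = {}; // stands in for an LDS slot

        int8x16_t x;
        for(int i = 0; i < 16; ++i)
            x[i] = static_cast<int8_t>(i);

        // same bytes, different vector type: 4 dwords instead of 16 bytes
        int32x4_t x_as_dwords;
        std::memcpy(&x_as_dwords, &x, sizeof(x));
        std::memcpy(lds_like, &x_as_dwords, sizeof(x_as_dwords));

        std::printf("%d %d\n", lds_like[0], lds_like[15]); // prints: 0 15
        return 0;
    }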
@@ -147,6 +147,11 @@
 #define CK_WORKAROUND_SWDEV_XXXXXX 1
 #endif
 
+// workaround for compiler crash when using buffer load/store for i8
+#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
+#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
+#endif
+
 namespace ck {
 
 enum AddressSpace
...
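Because the definition is wrapped in #ifndef, the workaround defaults to on (1) but can be disabled for a single build by defining CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE=0 on the compile line, which makes it easy to check whether a newer compiler still needs it without editing the header.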
@@ -650,6 +650,105 @@ using int32x16_t = typename vector_type<int32_t, 16>::type;
 using int32x32_t = typename vector_type<int32_t, 32>::type;
 using int32x64_t = typename vector_type<int32_t, 64>::type;
 
+template <>
+struct vector_type<int8_t, 16>
+{
+    using d1_t = int8_t;
+    typedef int8_t d2_t __attribute__((ext_vector_type(2)));
+    typedef int8_t d4_t __attribute__((ext_vector_type(4)));
+    typedef int8_t d8_t __attribute__((ext_vector_type(8)));
+    typedef int8_t d16_t __attribute__((ext_vector_type(16)));
+
+    using type = d16_t;
+
+    union
+    {
+        d16_t d16_;
+        StaticallyIndexedArray<d1_t, 16> d1x16_;
+        StaticallyIndexedArray<d2_t, 8> d2x8_;
+        StaticallyIndexedArray<d4_t, 4> d4x4_;
+        StaticallyIndexedArray<d8_t, 2> d8x2_;
+        StaticallyIndexedArray<d16_t, 1> d16x1_;
+        StaticallyIndexedArray<int32x4_t, 1> int32x4_; // hack
+    } data_;
+
+    __host__ __device__ constexpr vector_type() : data_{type{0}} {}
+
+    __host__ __device__ constexpr vector_type(type v) : data_{v} {}
+
+    template <typename X>
+    __host__ __device__ constexpr const auto& AsType() const
+    {
+#if 0
+        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
+                          is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
+                          is_same<X, d16_t>::value,
+                      "wrong!");
+#endif
+
+        if constexpr(is_same<X, d1_t>::value)
+        {
+            return data_.d1x16_;
+        }
+        else if constexpr(is_same<X, d2_t>::value)
+        {
+            return data_.d2x8_;
+        }
+        else if constexpr(is_same<X, d4_t>::value)
+        {
+            return data_.d4x4_;
+        }
+        else if constexpr(is_same<X, d8_t>::value)
+        {
+            return data_.d8x2_;
+        }
+        else if constexpr(is_same<X, d16_t>::value)
+        {
+            return data_.d16x1_;
+        }
+        else if constexpr(is_same<X, int32x4_t>::value) // hack
+        {
+            return data_.int32x4_;
+        }
+    }
+
+    template <typename X>
+    __host__ __device__ constexpr auto& AsType()
+    {
+#if 0
+        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
+                          is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
+                          is_same<X, d16_t>::value,
+                      "wrong!");
+#endif
+
+        if constexpr(is_same<X, d1_t>::value)
+        {
+            return data_.d1x16_;
+        }
+        else if constexpr(is_same<X, d2_t>::value)
+        {
+            return data_.d2x8_;
+        }
+        else if constexpr(is_same<X, d4_t>::value)
+        {
+            return data_.d4x4_;
+        }
+        else if constexpr(is_same<X, d8_t>::value)
+        {
+            return data_.d8x2_;
+        }
+        else if constexpr(is_same<X, d16_t>::value)
+        {
+            return data_.d16x1_;
+        }
+        else if constexpr(is_same<X, int32x4_t>::value) // hack
+        {
+            return data_.int32x4_;
+        }
+    }
+};
+
 // i8
 using int8x2_t = typename vector_type<int8_t, 2>::type;
 using int8x4_t = typename vector_type<int8_t, 4>::type;
...
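The int32x4_ union member marked // hack is what lets DynamicBuffer::Set obtain an int32x4 view of an int8x16 value through AsType without any conversion. A standalone illustration of the same union punning (well-defined as an extension under clang/GCC, though formally UB in ISO C++; StaticallyIndexedArray is replaced by bare vector members for brevity):

    #include <cstdint>
    #include <cstdio>

    typedef int8_t  d16_t     __attribute__((ext_vector_type(16)));
    typedef int32_t int32x4_t __attribute__((ext_vector_type(4)));

    union vec16_views
    {
        d16_t     d16_;     // 16 bytes viewed as int8 lanes
        int32x4_t int32x4_; // the "hack" view used by the ds_write workaround
    };

    int main()
    {
        vec16_views v{};
        for(int i = 0; i < 16; ++i)
            v.d16_[i] = 1; // fill bytes with 0x01

        // each dword reads back as 0x01010101 == 16843009
        std::printf("%d\n", v.int32x4_[0]);
        return 0;
    }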
@@ -48,8 +48,8 @@ int main(int argc, char* argv[])
     using ConvStrides   = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;
 
     using LeftPads  = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
 #elif 0
     constexpr index_t N = 1;
     constexpr index_t C = 16;
@@ -62,8 +62,8 @@ int main(int argc, char* argv[])
     using ConvStrides   = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;
 
     using LeftPads  = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
 #elif 0
     constexpr index_t N = 1;
     constexpr index_t C = 16;
...
@@ -10,13 +10,14 @@ cmake
 -D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \
 -D CMAKE_BUILD_TYPE=Release \
 -D DEVICE_BACKEND="AMD" \
 -D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx1030 -gline-tables-only -save-temps=$CWD -ftemplate-backtrace-limit=0" \
 -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
 -D CMAKE_PREFIX_PATH="/opt/rocm" \
 -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
 ${MY_PROJECT_SOURCE}
 
-#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -save-temps=$CWD" \
+#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx1030 -gline-tables-only -save-temps=$CWD -ftemplate-backtrace-limit=0" \
+#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx1030 -gline-tables-only -save-temps=$CWD -ftemplate-backtrace-limit=0 -mllvm -print-before=amdgpu-codegenprepare -mllvm -print-module-scope" \
 #-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -gline-tables-only -save-temps=$CWD" \
 #-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-spill-vgpr-to-agpr=0" \
 #-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -save-temps=$CWD" \
...