"tests/git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "bb25d249f1761668fda52e1e73cc2a4b178e9e87"
Commit 4e57b30a authored by Chao Liu's avatar Chao Liu
Browse files

rename

parent c03045ce
...@@ -10,7 +10,7 @@ template <index_t NDimHidden, typename VisibleDimensionIds> ...@@ -10,7 +10,7 @@ template <index_t NDimHidden, typename VisibleDimensionIds>
struct TensorCoordinate; struct TensorCoordinate;
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack> template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
struct TensorCoordinateIterator; struct TensorCoordinateStep;
// Transforms: Tuple<transforms...> // Transforms: Tuple<transforms...>
// LowerDimensionIdss : Tuple<Sequence<...>, ...> // LowerDimensionIdss : Tuple<Sequence<...>, ...>
...@@ -252,17 +252,16 @@ struct TensorCoordinate ...@@ -252,17 +252,16 @@ struct TensorCoordinate
}; };
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack> template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
struct TensorCoordinateIterator struct TensorCoordinateStep
{ {
// TODO make these private // TODO make these private
using VisibleIndex = MultiIndex<NDimVisible>; using VisibleIndex = MultiIndex<NDimVisible>;
public: public:
__host__ __device__ constexpr TensorCoordinateIterator() = default; __host__ __device__ constexpr TensorCoordinateStep() = default;
__host__ __host__ __device__ constexpr TensorCoordinateStep(const VisibleIndex& idx_diff_visible,
__device__ constexpr TensorCoordinateIterator(const VisibleIndex& idx_diff_visible, const MultiIndex<NTransform>& do_transforms)
const MultiIndex<NTransform>& do_transforms)
: idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms} : idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms}
{ {
} }
...@@ -423,8 +422,9 @@ __host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tens ...@@ -423,8 +422,9 @@ __host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tens
// UpdateLowerIndexHack: Sequence<...> // UpdateLowerIndexHack: Sequence<...>
// HACK: control UpdateLowerIndex // HACK: control UpdateLowerIndex
template <typename TensorDesc, typename VisibleIndex, typename UpdateLowerIndexHack> template <typename TensorDesc, typename VisibleIndex, typename UpdateLowerIndexHack>
__host__ __device__ constexpr auto make_tensor_coordinate_iterator( __host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
const TensorDesc&, const VisibleIndex& idx_diff_visible, UpdateLowerIndexHack) const VisibleIndex& idx_diff_visible,
UpdateLowerIndexHack)
{ {
static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(), static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
"wrong! # of dimension inconsistent"); "wrong! # of dimension inconsistent");
...@@ -471,24 +471,24 @@ __host__ __device__ constexpr auto make_tensor_coordinate_iterator( ...@@ -471,24 +471,24 @@ __host__ __device__ constexpr auto make_tensor_coordinate_iterator(
set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low); set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
}); });
return TensorCoordinateIterator<ntransform, ndim_visible, UpdateLowerIndexHack>{ return TensorCoordinateStep<ntransform, ndim_visible, UpdateLowerIndexHack>{idx_diff_visible,
idx_diff_visible, do_transforms}; do_transforms};
} }
template <typename TensorDesc, typename VisibleIndex> template <typename TensorDesc, typename VisibleIndex>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
make_tensor_coordinate_iterator(const TensorDesc&, const VisibleIndex& idx_diff_visible) const VisibleIndex& idx_diff_visible)
{ {
constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
return make_tensor_coordinate_iterator( return make_tensor_coordinate_step(
TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen<ntransform, 0>::type{}); TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen<ntransform, 0>::type{});
} }
template <typename TensorDesc, typename TensorCoord, typename TensorCoordIterator> template <typename TensorDesc, typename TensorCoord, typename TensorCoordStep>
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tensor_desc, __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tensor_desc,
TensorCoord& coord, TensorCoord& coord,
const TensorCoordIterator& coord_iterator) const TensorCoordStep& coord_step)
{ {
constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
...@@ -497,9 +497,8 @@ __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tens ...@@ -497,9 +497,8 @@ __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tens
auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>(); auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();
// initialize visible index diff // initialize visible index diff
set_container_subset(idx_diff_hidden, set_container_subset(
TensorDesc::GetVisibleDimensionIds(), idx_diff_hidden, TensorDesc::GetVisibleDimensionIds(), coord_step.GetVisibleIndexDiff());
coord_iterator.GetVisibleIndexDiff());
// this is what needs to be updated // this is what needs to be updated
auto& idx_hidden = coord.GetHiddenIndex(); auto& idx_hidden = coord.GetHiddenIndex();
...@@ -508,13 +507,13 @@ __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tens ...@@ -508,13 +507,13 @@ __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tens
auto idx_hidden_pick_visible = auto idx_hidden_pick_visible =
get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds()); get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds());
idx_hidden_pick_visible += coord_iterator.GetIndexDiff(); idx_hidden_pick_visible += coord_step.GetIndexDiff();
set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible); set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible);
// update rest of hidden index // update rest of hidden index
static_for<ntransform - 1, -1, -1>{}([&](auto itran) { static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
if(coord_iterator.do_transforms_[itran]) if(coord_step.do_transforms_[itran])
{ {
const auto& tran = tensor_desc.GetTransforms().At(itran); const auto& tran = tensor_desc.GetTransforms().At(itran);
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
...@@ -527,7 +526,7 @@ __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tens ...@@ -527,7 +526,7 @@ __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tens
MultiIndex<dims_low.Size()> idx_diff_low; MultiIndex<dims_low.Size()> idx_diff_low;
// HACK: control UpdateLowerIndex for Merge using hack // HACK: control UpdateLowerIndex for Merge using hack
constexpr index_t Hack = decltype(coord_iterator.update_lower_index_hack_)::At(itran); constexpr index_t Hack = decltype(coord_step.update_lower_index_hack_)::At(itran);
tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{}); tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
...@@ -591,7 +590,7 @@ using TensorCoordinate_t = decltype(make_tensor_coordinate( ...@@ -591,7 +590,7 @@ using TensorCoordinate_t = decltype(make_tensor_coordinate(
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{})); TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
template <typename TensorDesc> template <typename TensorDesc>
using TensorCoordinateIterator_t = decltype(make_tensor_coordinate_iterator( using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step(
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{})); TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
} // namespace ck } // namespace ck
......
...@@ -77,15 +77,14 @@ struct BlockwiseTensorSliceTransfer_v4 ...@@ -77,15 +77,14 @@ struct BlockwiseTensorSliceTransfer_v4
} }
} }
template <typename SrcBuffer, typename SrcIteratorHacks> template <typename SrcBuffer, typename SrcStepHacks>
__device__ void RunRead(const SrcDesc& src_desc, __device__ void
const SrcBuffer& src_buf, RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
const SrcIteratorHacks& src_iterator_hacks)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks); threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
} }
} }
...@@ -118,18 +117,18 @@ struct BlockwiseTensorSliceTransfer_v4 ...@@ -118,18 +117,18 @@ struct BlockwiseTensorSliceTransfer_v4
} }
} }
// SrcMoveSliceWindowIteratorHack to control index calculation move slice window // SrcMoveSliceWindowStepHack to control index calculation move slice window
template <typename SrcMoveSliceWindowIteratorHack> template <typename SrcMoveSliceWindowStepHack>
__device__ void __device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc, MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& step, const Index& step,
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack) const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.MoveSrcSliceWindow( threadwise_transfer_.MoveSrcSliceWindow(
src_desc, step, src_move_slice_window_iterator_hack); src_desc, step, src_move_slice_window_step_hack);
} }
} }
......
...@@ -75,15 +75,14 @@ struct BlockwiseTensorSliceTransfer_v4r1 ...@@ -75,15 +75,14 @@ struct BlockwiseTensorSliceTransfer_v4r1
} }
} }
template <typename SrcBuffer, typename SrcIteratorHacks> template <typename SrcBuffer, typename SrcStepHacks>
__device__ void RunRead(const SrcDesc& src_desc, __device__ void
const SrcBuffer& src_buf, RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
const SrcIteratorHacks& src_iterator_hacks)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks); threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
} }
} }
...@@ -106,18 +105,18 @@ struct BlockwiseTensorSliceTransfer_v4r1 ...@@ -106,18 +105,18 @@ struct BlockwiseTensorSliceTransfer_v4r1
} }
} }
// SrcMoveSliceWindowIteratorHack to control index calculation move slice window // SrcMoveSliceWindowStepHack to control index calculation move slice window
template <typename SrcMoveSliceWindowIteratorHack> template <typename SrcMoveSliceWindowStepHack>
__device__ void __device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc, MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& step, const Index& step,
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack) const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.MoveSrcSliceWindow( threadwise_transfer_.MoveSrcSliceWindow(
src_desc, step, src_move_slice_window_iterator_hack); src_desc, step, src_move_slice_window_step_hack);
} }
} }
......
...@@ -84,11 +84,11 @@ template <index_t BlockSize, ...@@ -84,11 +84,11 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks, typename AGridStepHacks,
typename BGridIteratorHacks, typename BGridStepHacks,
typename CGridIteratorHacks, typename CGridStepHacks,
typename AGridMoveSliceWindowIteratorHacks, typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowIteratorHacks> typename BGridMoveSliceWindowStepHacks>
struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1 struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -496,9 +496,9 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN ...@@ -496,9 +496,9 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf); a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf); b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
...@@ -515,18 +515,18 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN ...@@ -515,18 +515,18 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
// even iteration // even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step, a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{}); AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step, b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{}); BGridMoveSliceWindowStepHacks{});
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1, blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1,
...@@ -541,18 +541,18 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN ...@@ -541,18 +541,18 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
// odd iteration // odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step, a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{}); AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step, b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{}); BGridMoveSliceWindowStepHacks{});
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run( blockwise_gemm.Run(
...@@ -571,18 +571,18 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN ...@@ -571,18 +571,18 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
{ {
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step, a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{}); AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step, b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{}); BGridMoveSliceWindowStepHacks{});
__syncthreads(); __syncthreads();
// LDS double buffer: load last data from device mem // LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run( blockwise_gemm.Run(
...@@ -650,7 +650,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN ...@@ -650,7 +650,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
c_thread_buf, c_thread_buf,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_buf, c_grid_buf,
CGridIteratorHacks{}); CGridStepHacks{});
} }
} }
}; };
......
...@@ -145,11 +145,11 @@ template <index_t BlockSize, ...@@ -145,11 +145,11 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks, typename AGridStepHacks,
typename BGridIteratorHacks, typename BGridStepHacks,
typename CGridIteratorHacks, typename CGridStepHacks,
typename AGridMoveSliceWindowIteratorHacks, typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowIteratorHacks> typename BGridMoveSliceWindowStepHacks>
struct GridwiseGemmDlops_km_kn_mn_v1r2 struct GridwiseGemmDlops_km_kn_mn_v1r2
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -475,15 +475,15 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2 ...@@ -475,15 +475,15 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy // hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k_m0_m1_global_iterator_hacks = AGridIteratorHacks{}; constexpr auto a_k_m0_m1_global_step_hacks = AGridStepHacks{};
constexpr auto b_k_n0_n1_global_iterator_hacks = BGridIteratorHacks{}; constexpr auto b_k_n0_n1_global_step_hacks = BGridStepHacks{};
// hack to control index calculation when move slice window for A and B matrix for // hack to control index calculation when move slice window for A and B matrix for
// threadwise copy // threadwise copy
constexpr auto a_k_m0_m1_global_move_slice_window_iterator_hack = constexpr auto a_k_m0_m1_global_move_slice_window_step_hack =
AGridMoveSliceWindowIteratorHacks{}; AGridMoveSliceWindowStepHacks{};
constexpr auto b_k_n0_n1_global_move_slice_window_iterator_hack = constexpr auto b_k_n0_n1_global_move_slice_window_step_hack =
BGridMoveSliceWindowIteratorHacks{}; BGridMoveSliceWindowStepHacks{};
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize()); p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
...@@ -500,9 +500,9 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2 ...@@ -500,9 +500,9 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf); a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf); b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);
...@@ -517,22 +517,20 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2 ...@@ -517,22 +517,20 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
do do
{ {
// even iteration // even iteration
a_blockwise_copy.MoveSrcSliceWindow( a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
a_k_m0_m1_grid_desc, a_block_slice_copy_step,
a_block_slice_copy_step, a_k_m0_m1_global_move_slice_window_step_hack);
a_k_m0_m1_global_move_slice_window_iterator_hack); b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
b_blockwise_copy.MoveSrcSliceWindow( b_block_slice_copy_step,
b_k_n0_n1_grid_desc, b_k_n0_n1_global_move_slice_window_step_hack);
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_iterator_hack);
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
...@@ -545,22 +543,20 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2 ...@@ -545,22 +543,20 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf); b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);
// odd iteration // odd iteration
a_blockwise_copy.MoveSrcSliceWindow( a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
a_k_m0_m1_grid_desc, a_block_slice_copy_step,
a_block_slice_copy_step, a_k_m0_m1_global_move_slice_window_step_hack);
a_k_m0_m1_global_move_slice_window_iterator_hack); b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
b_blockwise_copy.MoveSrcSliceWindow( b_block_slice_copy_step,
b_k_n0_n1_grid_desc, b_k_n0_n1_global_move_slice_window_step_hack);
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_iterator_hack);
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run( blockwise_gemm.Run(
...@@ -579,18 +575,18 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2 ...@@ -579,18 +575,18 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
{ {
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc, a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
a_block_slice_copy_step, a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_iterator_hack); a_k_m0_m1_global_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc, b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
b_block_slice_copy_step, b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_iterator_hack); b_k_n0_n1_global_move_slice_window_step_hack);
__syncthreads(); __syncthreads();
// LDS double buffer: load last data from device mem // LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run( blockwise_gemm.Run(
...@@ -657,7 +653,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2 ...@@ -657,7 +653,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
c_thread_buf, c_thread_buf,
c_m0_m10_m11_n0_n10_n11_grid_desc, c_m0_m10_m11_n0_n10_n11_grid_desc,
c_grid_buf, c_grid_buf,
CGridIteratorHacks{}); CGridStepHacks{});
} }
} }
}; };
......
...@@ -141,11 +141,11 @@ template <index_t BlockSize, ...@@ -141,11 +141,11 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks, typename AGridStepHacks,
typename BGridIteratorHacks, typename BGridStepHacks,
typename CGridIteratorHacks, typename CGridStepHacks,
typename AGridMoveSliceWindowIteratorHacks, typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowIteratorHacks> typename BGridMoveSliceWindowStepHacks>
struct GridwiseGemmDlops_km_kn_mn_v1r3 struct GridwiseGemmDlops_km_kn_mn_v1r3
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -494,8 +494,8 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -494,8 +494,8 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{}); b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
...@@ -514,18 +514,16 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -514,18 +514,16 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// even iteration // even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
a_block_slice_copy_step, a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{}); AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
b_block_slice_copy_step, b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{}); BGridMoveSliceWindowStepHacks{});
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
b_blockwise_copy.RunRead(
b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
...@@ -540,18 +538,16 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -540,18 +538,16 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// odd iteration // odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
a_block_slice_copy_step, a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{}); AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
b_block_slice_copy_step, b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{}); BGridMoveSliceWindowStepHacks{});
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
b_blockwise_copy.RunRead(
b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run( blockwise_gemm.Run(
...@@ -568,18 +564,16 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -568,18 +564,16 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// LDS double buffer: tail // LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{ {
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, a_blockwise_copy.MoveSrcSliceWindow(
a_block_slice_copy_step, a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, AGridMoveSliceWindowStepHacks{});
AGridMoveSliceWindowIteratorHacks{}); b_blockwise_copy.MoveSrcSliceWindow(
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, BGridMoveSliceWindowStepHacks{});
b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{});
__syncthreads(); __syncthreads();
// LDS double buffer: load last data from device mem // LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{}); b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run( blockwise_gemm.Run(
...@@ -647,7 +641,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -647,7 +641,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
c_thread_buf, c_thread_buf,
c_m0_m10_m11_n0_n10_n11_grid_desc, c_m0_m10_m11_n0_n10_n11_grid_desc,
c_grid_buf, c_grid_buf,
CGridIteratorHacks{}); CGridStepHacks{});
} }
} }
}; };
......
...@@ -42,11 +42,11 @@ template <index_t BlockSize, ...@@ -42,11 +42,11 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks, typename AGlobalStepHacks,
typename BGlobalIteratorHacks, typename BGlobalStepHacks,
typename CGlobalIteratorHacks, typename CGlobalStepHacks,
typename AGlobalMoveSliceWindowIteratorHacks, typename AGlobalMoveSliceWindowStepHacks,
typename BGlobalMoveSliceWindowIteratorHacks> typename BGlobalMoveSliceWindowStepHacks>
struct GridwiseGemmDlops_km_kn_mn_v3 struct GridwiseGemmDlops_km_kn_mn_v3
{ {
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
...@@ -239,15 +239,14 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -239,15 +239,14 @@ struct GridwiseGemmDlops_km_kn_mn_v3
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0); constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy // hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_e_k_global_iterator_hacks = AGlobalIteratorHacks{}; constexpr auto a_e_k_global_step_hacks = AGlobalStepHacks{};
constexpr auto b_e_n_ho_wo_global_iterator_hacks = BGlobalIteratorHacks{}; constexpr auto b_e_n_ho_wo_global_step_hacks = BGlobalStepHacks{};
// hack to control index calculation when move slice window for A and B matrix for // hack to control index calculation when move slice window for A and B matrix for
// threadwise copy // threadwise copy
constexpr auto a_e_k_global_move_slice_window_iterator_hack = constexpr auto a_e_k_global_move_slice_window_step_hack = AGlobalMoveSliceWindowStepHacks{};
AGlobalMoveSliceWindowIteratorHacks{}; constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack = BGlobalMoveSliceWindowStepHacks{};
BGlobalMoveSliceWindowIteratorHacks{};
// double regsiter buffer for b // double regsiter buffer for b
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum_t::Vgpr,
...@@ -257,14 +256,14 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -257,14 +256,14 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// LDS double buffer: preload data // LDS double buffer: preload data
{ {
a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_iterator_hacks); a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_step_hacks);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
b_global_buf, b_global_buf,
b_e_n_ho_wo_thread_desc, b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_even_buf, b_thread_even_buf,
b_e_n_ho_wo_global_iterator_hacks); b_e_n_ho_wo_global_step_hacks);
a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf); a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);
} }
...@@ -288,7 +287,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -288,7 +287,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
b_e_n_ho_wo_thread_desc, b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_odd_buf, b_thread_odd_buf,
b_e_n_ho_wo_global_iterator_hacks); b_e_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window // TODO: @Zhang Jing: blockwise gemm should be able to move slice window
...@@ -304,7 +303,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -304,7 +303,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
b_e_n_ho_wo_thread_desc, b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_even_buf, b_thread_even_buf,
b_e_n_ho_wo_global_iterator_hacks); b_e_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
...@@ -327,7 +326,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -327,7 +326,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
b_e_n_ho_wo_thread_desc, b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_odd_buf, b_thread_odd_buf,
b_e_n_ho_wo_global_iterator_hacks); b_e_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
...@@ -346,7 +345,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -346,7 +345,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// output: register to global memory // output: register to global memory
{ {
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor // hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{}; constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = CGlobalStepHacks{};
const index_t k_thread_data_on_global = const index_t k_thread_data_on_global =
k_block_data_on_global + k_thread_id * KPerThread; k_block_data_on_global + k_thread_id * KPerThread;
...@@ -370,7 +369,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -370,7 +369,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
c_thread_buf, c_thread_buf,
c_k_n_ho_wo_global_desc, c_k_n_ho_wo_global_desc,
c_global_buf, c_global_buf,
c_k_n_ho_wo_global_tensor_iterator_hacks); c_k_n_ho_wo_global_tensor_step_hacks);
} }
} }
......
...@@ -126,11 +126,11 @@ template <index_t BlockSize, ...@@ -126,11 +126,11 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks, typename AGridStepHacks,
typename BGridIteratorHacks, typename BGridStepHacks,
typename CGridIteratorHacks, typename CGridStepHacks,
typename AGridMoveSliceWindowIteratorHacks, typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowIteratorHacks, typename BGridMoveSliceWindowStepHacks,
bool CAccessOrderMRepeatNRepeat> bool CAccessOrderMRepeatNRepeat>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{ {
...@@ -416,15 +416,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -416,15 +416,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy // hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k0_m_k1_grid_iterator_hacks = AGridIteratorHacks{}; constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
constexpr auto b_k0_n_k1_grid_iterator_hacks = BGridIteratorHacks{}; constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};
// hack to control index calculation when move slice window for A and B matrix for // hack to control index calculation when move slice window for A and B matrix for
// threadwise copy // threadwise copy
constexpr auto a_k0_m_k1_grid_move_slice_window_iterator_hack = constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
AGridMoveSliceWindowIteratorHacks{}; constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};
constexpr auto b_k0_n_k1_grid_move_slice_window_iterator_hack =
BGridMoveSliceWindowIteratorHacks{};
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
...@@ -433,10 +431,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -433,10 +431,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// preload data into LDS // preload data into LDS
{ {
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks); b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
b_blockwise_copy.RunRead(
b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks);
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
...@@ -449,18 +445,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -449,18 +445,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{ {
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc, a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
a_block_slice_copy_step, a_block_slice_copy_step,
a_k0_m_k1_grid_move_slice_window_iterator_hack); a_k0_m_k1_grid_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc, b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
b_block_slice_copy_step, b_block_slice_copy_step,
b_k0_n_k1_grid_move_slice_window_iterator_hack); b_k0_n_k1_grid_move_slice_window_step_hack);
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks);
block_sync_lds(); block_sync_lds();
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
...@@ -526,7 +520,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -526,7 +520,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const index_t n_thread_data_on_grid = const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks = CGridIteratorHacks{}; constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{};
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat); constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat); constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
...@@ -557,7 +551,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -557,7 +551,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_blk_buf_, c_blk_buf_,
c_m0_m1_m2_n_grid_desc, c_m0_m1_m2_n_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks); c_m0_m1_m2_n_grid_tensor_step_hacks);
} }
#else #else
{ {
...@@ -579,7 +573,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -579,7 +573,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const index_t n_thread_data_on_grid = const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks = CGridIteratorHacks{}; constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{};
auto c_thread_copy = auto c_thread_copy =
ThreadwiseTensorSliceTransfer_v1r3<FloatC, ThreadwiseTensorSliceTransfer_v1r3<FloatC,
...@@ -610,7 +604,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -610,7 +604,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc, c_m0_m1_m2_n_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks); c_m0_m1_m2_n_grid_tensor_step_hacks);
return c_thread_idx_; return c_thread_idx_;
}; };
...@@ -625,7 +619,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -625,7 +619,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc, c_m0_m1_m2_n_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks); c_m0_m1_m2_n_grid_tensor_step_hacks);
}; };
auto nrepeat_plus_copy = [&](auto c_thread_idx_) { auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
...@@ -638,7 +632,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -638,7 +632,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc, c_m0_m1_m2_n_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks); c_m0_m1_m2_n_grid_tensor_step_hacks);
}; };
auto mrepeat_minus_copy = [&](auto c_thread_idx_) { auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
...@@ -651,7 +645,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -651,7 +645,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc, c_m0_m1_m2_n_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks); c_m0_m1_m2_n_grid_tensor_step_hacks);
}; };
auto nrepeat_minus_copy = [&](auto c_thread_idx_) { auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
...@@ -664,7 +658,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -664,7 +658,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc, c_m0_m1_m2_n_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks); c_m0_m1_m2_n_grid_tensor_step_hacks);
}; };
static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or
......
...@@ -11,7 +11,7 @@ namespace ck { ...@@ -11,7 +11,7 @@ namespace ck {
// 1. Desc is known at compile-time // 1. Desc is known at compile-time
// 2. Buffer is StaticBuffer // 2. Buffer is StaticBuffer
// 3. OriginIdx is known at compile-time // 3. OriginIdx is known at compile-time
// 4. use #-iterator // 4. use #-step
template <typename Data, template <typename Data,
typename Desc, typename Desc,
typename SliceLengths, typename SliceLengths,
......
...@@ -41,8 +41,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -41,8 +41,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
using SrcCoordIterator = decltype(make_tensor_coordinate_iterator(SrcDesc{}, Index{})); using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordIterator = decltype(make_tensor_coordinate_iterator(DstDesc{}, Index{})); using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(const SrcDesc& src_desc, __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(const SrcDesc& src_desc,
const Index& src_slice_origin, const Index& src_slice_origin,
...@@ -72,10 +72,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -72,10 +72,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
} }
template <typename SrcBuffer, typename SrcIteratorHacks> template <typename SrcBuffer, typename SrcStepHacks>
__device__ void RunRead(const SrcDesc& src_desc, __device__ void
const SrcBuffer& src_buf, RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
const SrcIteratorHacks& src_iterator_hacks)
{ {
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
...@@ -108,31 +107,31 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -108,31 +107,31 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr auto ordered_src_access_lengths = constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order); container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// make forward iterators // make forward steps
const auto src_forward_iterators = generate_tuple( const auto src_forward_steps = generate_tuple(
[&](auto i) { [&](auto i) {
Index forward_step; Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) { static_for<0, nDim, 1>{}([&](auto j) {
forward_step(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0; forward_step_idx(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0;
}); });
return make_tensor_coordinate_iterator( return make_tensor_coordinate_step(
src_desc, forward_step, src_iterator_hacks[I0][i]); src_desc, forward_step_idx, src_step_hacks[I0][i]);
}, },
Number<nDim>{}); Number<nDim>{});
// make backward iterators // make backward steps
const auto src_backward_iterators = generate_tuple( const auto src_backward_steps = generate_tuple(
[&](auto i) { [&](auto i) {
Index backward_step; Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) { static_for<0, nDim, 1>{}([&](auto j) {
backward_step(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0; backward_step_idx(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0;
}); });
return make_tensor_coordinate_iterator( return make_tensor_coordinate_step(
src_desc, backward_step, src_iterator_hacks[I1][i]); src_desc, backward_step_idx, src_step_hacks[I1][i]);
}, },
Number<nDim>{}); Number<nDim>{});
...@@ -220,12 +219,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -220,12 +219,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
if constexpr(forward_sweep[i]) if constexpr(forward_sweep[i])
{ {
move_tensor_coordinate( move_tensor_coordinate(
src_desc, src_coord_, src_forward_iterators[src_dim_access_order[i]]); src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
} }
else else
{ {
move_tensor_coordinate( move_tensor_coordinate(
src_desc, src_coord_, src_backward_iterators[src_dim_access_order[i]]); src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
} }
} }
}); });
...@@ -234,17 +233,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -234,17 +233,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
// move src coordinate back to slice origin (or not) // move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun) if constexpr(SrcResetCoordinateAfterRun)
{ {
const auto src_reset_iterator = const auto src_reset_step =
make_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep()); make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
move_tensor_coordinate(src_desc, src_coord_, src_reset_iterator); move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
} }
} }
template <typename DstBuffer, typename DstIteratorHacks> template <typename DstBuffer, typename DstStepHacks>
__device__ void RunWrite(const DstDesc& dst_desc, __device__ void
DstBuffer& dst_buf, RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks)
const DstIteratorHacks& dst_iterator_hacks)
{ {
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
...@@ -277,35 +275,31 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -277,35 +275,31 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr auto ordered_dst_access_lengths = constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// make forward iterators // make forward steps
const auto dst_forward_iterators = generate_tuple( const auto dst_forward_steps = generate_tuple(
[&](auto i) { [&](auto i) {
Index forward_step; Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) { static_for<0, nDim, 1>{}([&](auto j) {
forward_step(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0; forward_step_idx(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0;
}); });
const auto forward_iterator = make_tensor_coordinate_iterator( return make_tensor_coordinate_step(
dst_desc, forward_step, dst_iterator_hacks[I0][i]); dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
return forward_iterator;
}, },
Number<nDim>{}); Number<nDim>{});
// make backward iterators // make backward steps
const auto dst_backward_iterators = generate_tuple( const auto dst_backward_steps = generate_tuple(
[&](auto i) { [&](auto i) {
Index backward_step; Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) { static_for<0, nDim, 1>{}([&](auto j) {
backward_step(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0; backward_step_idx(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0;
}); });
const auto backward_iterator = make_tensor_coordinate_iterator( return make_tensor_coordinate_step(
dst_desc, backward_step, dst_iterator_hacks[I1][i]); dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
return backward_iterator;
}, },
Number<nDim>{}); Number<nDim>{});
...@@ -395,12 +389,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -395,12 +389,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
if constexpr(forward_sweep[i]) if constexpr(forward_sweep[i])
{ {
move_tensor_coordinate( move_tensor_coordinate(
dst_desc, dst_coord_, dst_forward_iterators[dst_dim_access_order[i]]); dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
} }
else else
{ {
move_tensor_coordinate( move_tensor_coordinate(
dst_desc, dst_coord_, dst_backward_iterators[dst_dim_access_order[i]]); dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
} }
} }
}); });
...@@ -409,10 +403,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -409,10 +403,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
// move dst coordinate back to slice origin (or not) // move dst coordinate back to slice origin (or not)
if constexpr(DstResetCoordinateAfterRun) if constexpr(DstResetCoordinateAfterRun)
{ {
const auto dst_reset_iterator = const auto dst_reset_step =
make_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep()); make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator); move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
} }
} }
...@@ -423,11 +417,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -423,11 +417,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{}; constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
constexpr auto src_iterator_hacks = constexpr auto src_step_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}), make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{})); generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
RunRead(src_desc, src_buf, src_iterator_hacks); RunRead(src_desc, src_buf, src_step_hacks);
} }
template <typename DstBuffer> template <typename DstBuffer>
...@@ -437,11 +431,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -437,11 +431,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{}; constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
constexpr auto dst_iterator_hacks = constexpr auto dst_step_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}), make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{})); generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
RunWrite(dst_desc, dst_buf, dst_iterator_hacks); RunWrite(dst_desc, dst_buf, dst_step_hacks);
} }
__device__ static constexpr auto GetSrcCoordinateResetStep() __device__ static constexpr auto GetSrcCoordinateResetStep()
...@@ -564,17 +558,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -564,17 +558,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1
: src_slice_origin_step_idx + GetSrcCoordinateResetStep(); : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time? // is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_iterator(src_desc, adjusted_step_idx); const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step); move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
} }
// src_slice_origin_step_idx need to be known at compile-time, for performance reason // src_slice_origin_step_idx need to be known at compile-time, for performance reason
template <typename SrcMoveSliceWindowIteratorHack> template <typename SrcMoveSliceWindowStepHack>
__device__ void __device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc, MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx, const Index& src_slice_origin_step_idx,
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack) const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
{ {
// if src coord was not reset by RunRead(), then need to adjust the step here // if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx = const auto adjusted_step_idx =
...@@ -582,8 +576,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -582,8 +576,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
: src_slice_origin_step_idx + GetSrcCoordinateResetStep(); : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time? // is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_iterator( const auto adjusted_step = make_tensor_coordinate_step(
src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack); src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step); move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
} }
...@@ -597,7 +591,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -597,7 +591,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
: dst_slice_origin_step_idx + GetDstCoordinateResetStep(); : dst_slice_origin_step_idx + GetDstCoordinateResetStep();
// is it OK to construct a new step every time? // is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_iterator(dst_desc, adjusted_step_idx); const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
} }
...@@ -620,7 +614,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -620,7 +614,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
// 2. SrcBuffer is DynamicBuffer // 2. SrcBuffer is DynamicBuffer
// 3. src_ref_idx is known at run-time // 3. src_ref_idx is known at run-time
// 4. SrcRefToOriginDisplacement is known at compile-time // 4. SrcRefToOriginDisplacement is known at compile-time
// 5. use #-iterator // 5. use #-step
// 2. dst: // 2. dst:
// 1. DstDesc is known at compile-time // 1. DstDesc is known at compile-time
// 2. DstBuffer is StaticBuffer // 2. DstBuffer is StaticBuffer
...@@ -649,7 +643,7 @@ struct ThreadwiseTensorSliceTransfer_v4r1 ...@@ -649,7 +643,7 @@ struct ThreadwiseTensorSliceTransfer_v4r1
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using SrcCoordIterator = decltype(make_tensor_coordinate_iterator(SrcDesc{}, Index{})); using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
__device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx) __device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx)
: src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx)) : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
...@@ -732,12 +726,12 @@ struct ThreadwiseTensorSliceTransfer_v4r1 ...@@ -732,12 +726,12 @@ struct ThreadwiseTensorSliceTransfer_v4r1
constexpr auto src_ref_to_data_disp_idx = constexpr auto src_ref_to_data_disp_idx =
src_ref_to_origin_disp_idx + data_to_origin_disp_idx; src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
constexpr auto src_ref_to_data_disp_coord_iterator = constexpr auto src_ref_to_data_disp_coord_step =
make_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx); make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);
auto src_data_coord = src_ref_coord_; auto src_data_coord = src_ref_coord_;
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator); move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
vector_type_maker_t<SrcData, src_vector_desc.GetElementSpaceSize()> src_vector; vector_type_maker_t<SrcData, src_vector_desc.GetElementSpaceSize()> src_vector;
...@@ -773,7 +767,7 @@ struct ThreadwiseTensorSliceTransfer_v4r1 ...@@ -773,7 +767,7 @@ struct ThreadwiseTensorSliceTransfer_v4r1
constexpr auto src_desc = SrcDesc{}; constexpr auto src_desc = SrcDesc{};
const auto src_slice_move_step_iter = const auto src_slice_move_step_iter =
make_tensor_coordinate_iterator(src_desc, to_multi_index(src_slice_move_step_idx)); make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx));
move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
} }
......
...@@ -113,16 +113,16 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy ...@@ -113,16 +113,16 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
using BKNGridDesc = decltype(b_k_n_grid_desc); using BKNGridDesc = decltype(b_k_n_grid_desc);
using CMNGridDesc = decltype(c_m_n_grid_desc); using CMNGridDesc = decltype(c_m_n_grid_desc);
using AGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}), Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}))); Sequence<0, 0, 0, 0, 0>{})));
using BGridIteratorHacks = using BGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
...@@ -130,21 +130,21 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy ...@@ -130,21 +130,21 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}))); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}), Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}))); Sequence<0, 0, 2, 0, 0>{})));
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using GridwiseGemm = using GridwiseGemm =
GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize, GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
...@@ -184,11 +184,11 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy ...@@ -184,11 +184,11 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
AGridIteratorHacks, AGridStepHacks,
BGridIteratorHacks, BGridStepHacks,
CGridIteratorHacks, CGridStepHacks,
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowIteratorHacks>; BGridMoveSliceWindowStepHacks>;
auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
...@@ -249,16 +249,16 @@ extern "C" __global__ void ...@@ -249,16 +249,16 @@ extern "C" __global__ void
using BKNGridDesc = decltype(b_k_n_grid_desc); using BKNGridDesc = decltype(b_k_n_grid_desc);
using CMNGridDesc = decltype(c_m_n_grid_desc); using CMNGridDesc = decltype(c_m_n_grid_desc);
using AGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}), Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}))); Sequence<0, 0, 0, 0, 0>{})));
using BGridIteratorHacks = using BGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
...@@ -266,21 +266,21 @@ extern "C" __global__ void ...@@ -266,21 +266,21 @@ extern "C" __global__ void
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}))); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}), Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}))); Sequence<0, 0, 2, 0, 0>{})));
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using GridwiseGemm = using GridwiseGemm =
GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize, GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
...@@ -320,11 +320,11 @@ extern "C" __global__ void ...@@ -320,11 +320,11 @@ extern "C" __global__ void
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
AGridIteratorHacks, AGridStepHacks,
BGridIteratorHacks, BGridStepHacks,
CGridIteratorHacks, CGridStepHacks,
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowIteratorHacks>; BGridMoveSliceWindowStepHacks>;
constexpr auto a_k_m0_m1_grid_desc_tmp = constexpr auto a_k_m0_m1_grid_desc_tmp =
GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
......
...@@ -110,12 +110,12 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc ...@@ -110,12 +110,12 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc); using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
using CMNGridDesc = decltype(c_m_n_grid_desc); using CMNGridDesc = decltype(c_m_n_grid_desc);
using AGridIteratorHacks = decltype(make_tuple( using AGridStepHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple( make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
using BGridIteratorHacks = using BGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
...@@ -123,25 +123,25 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc ...@@ -123,25 +123,25 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}), Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}))); Sequence<0, 0, 2, 0, 0>{})));
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using GridwiseGemm = using GridwiseGemm =
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize, GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
...@@ -179,11 +179,11 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc ...@@ -179,11 +179,11 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
AGridIteratorHacks, AGridStepHacks,
BGridIteratorHacks, BGridStepHacks,
CGridIteratorHacks, CGridStepHacks,
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowIteratorHacks, BGridMoveSliceWindowStepHacks,
false>; false>;
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
...@@ -243,12 +243,12 @@ extern "C" __global__ void ...@@ -243,12 +243,12 @@ extern "C" __global__ void
constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1]; constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1];
constexpr auto c_m_n_grid_desc = descs[I2]; constexpr auto c_m_n_grid_desc = descs[I2];
using AGridIteratorHacks = decltype(make_tuple( using AGridStepHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple( make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
using BGridIteratorHacks = using BGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
...@@ -256,25 +256,25 @@ extern "C" __global__ void ...@@ -256,25 +256,25 @@ extern "C" __global__ void
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}), Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}))); Sequence<0, 0, 2, 0, 0>{})));
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp); using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp);
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp); using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
...@@ -316,11 +316,11 @@ extern "C" __global__ void ...@@ -316,11 +316,11 @@ extern "C" __global__ void
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
AGridIteratorHacks, AGridStepHacks,
BGridIteratorHacks, BGridStepHacks,
CGridIteratorHacks, CGridStepHacks,
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowIteratorHacks, BGridMoveSliceWindowStepHacks,
false>; false>;
constexpr auto c_m0_m1_m2_n_grid_desc_tmp = constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
......
...@@ -110,12 +110,12 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky ...@@ -110,12 +110,12 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc); using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
using CMNGridDesc = decltype(c_m_n_grid_desc); using CMNGridDesc = decltype(c_m_n_grid_desc);
using BGridIteratorHacks = decltype(make_tuple( using BGridStepHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple( make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
using AGridIteratorHacks = using AGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
...@@ -123,25 +123,25 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky ...@@ -123,25 +123,25 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}), Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}))); Sequence<0, 0, 2, 0, 0>{})));
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
using GridwiseGemm = using GridwiseGemm =
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize, GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
...@@ -179,11 +179,11 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky ...@@ -179,11 +179,11 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
AGridIteratorHacks, AGridStepHacks,
BGridIteratorHacks, BGridStepHacks,
CGridIteratorHacks, CGridStepHacks,
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowIteratorHacks, BGridMoveSliceWindowStepHacks,
false>; false>;
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
...@@ -247,12 +247,12 @@ extern "C" __global__ void ...@@ -247,12 +247,12 @@ extern "C" __global__ void
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp); using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
using CMNGridDesc = decltype(c_m_n_grid_desc); using CMNGridDesc = decltype(c_m_n_grid_desc);
using BGridIteratorHacks = decltype(make_tuple( using BGridStepHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple( make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
using AGridIteratorHacks = using AGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
...@@ -260,25 +260,25 @@ extern "C" __global__ void ...@@ -260,25 +260,25 @@ extern "C" __global__ void
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}), Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}))); Sequence<0, 0, 2, 0, 0>{})));
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
using GridwiseGemm = using GridwiseGemm =
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize, GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
...@@ -316,11 +316,11 @@ extern "C" __global__ void ...@@ -316,11 +316,11 @@ extern "C" __global__ void
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
AGridIteratorHacks, AGridStepHacks,
BGridIteratorHacks, BGridStepHacks,
CGridIteratorHacks, CGridStepHacks,
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowIteratorHacks, BGridMoveSliceWindowStepHacks,
false>; false>;
constexpr auto c_m0_m1_m2_n_grid_desc_tmp = constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
......
...@@ -111,7 +111,7 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N, ...@@ -111,7 +111,7 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1); using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1); using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);
using AGridIteratorHacks = using AGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0 Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10 Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
...@@ -123,7 +123,7 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N, ...@@ -123,7 +123,7 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11 Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
using BGridIteratorHacks = decltype(make_tuple( using BGridStepHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
...@@ -135,7 +135,7 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N, ...@@ -135,7 +135,7 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
using CGridIteratorHacks = decltype(make_tuple( using CGridStepHacks = decltype(make_tuple(
make_tuple( make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
...@@ -151,9 +151,9 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N, ...@@ -151,9 +151,9 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0>; using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
using BGridMoveSliceWindowIteratorHacks = using BGridMoveSliceWindowStepHacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
using GridwiseContraction = using GridwiseContraction =
...@@ -191,11 +191,11 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N, ...@@ -191,11 +191,11 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
AGridIteratorHacks, AGridStepHacks,
BGridIteratorHacks, BGridStepHacks,
CGridIteratorHacks, CGridStepHacks,
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowIteratorHacks>; BGridMoveSliceWindowStepHacks>;
if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0) if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
{ {
...@@ -254,7 +254,7 @@ extern "C" __global__ void ...@@ -254,7 +254,7 @@ extern "C" __global__ void
using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1); using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1); using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);
using AGridIteratorHacks = using AGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0 Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10 Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
...@@ -266,7 +266,7 @@ extern "C" __global__ void ...@@ -266,7 +266,7 @@ extern "C" __global__ void
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11 Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
using BGridIteratorHacks = decltype(make_tuple( using BGridStepHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
...@@ -278,7 +278,7 @@ extern "C" __global__ void ...@@ -278,7 +278,7 @@ extern "C" __global__ void
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
using CGridIteratorHacks = decltype(make_tuple( using CGridStepHacks = decltype(make_tuple(
make_tuple( make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
...@@ -294,9 +294,9 @@ extern "C" __global__ void ...@@ -294,9 +294,9 @@ extern "C" __global__ void
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0>; using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
using BGridMoveSliceWindowIteratorHacks = using BGridMoveSliceWindowStepHacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
using GridwiseContraction = using GridwiseContraction =
...@@ -334,11 +334,11 @@ extern "C" __global__ void ...@@ -334,11 +334,11 @@ extern "C" __global__ void
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
AGridIteratorHacks, AGridStepHacks,
BGridIteratorHacks, BGridStepHacks,
CGridIteratorHacks, CGridStepHacks,
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowIteratorHacks>; BGridMoveSliceWindowStepHacks>;
using AGridDesc_GK0_GM0_GM10_GM11_GK1 = using AGridDesc_GK0_GM0_GM10_GM11_GK1 =
decltype(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1( decltype(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
......
...@@ -207,7 +207,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( ...@@ -207,7 +207,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
const auto in_gemmm_gemmn_grid_desc = descs[I2]; const auto in_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix // HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmm Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmm
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
...@@ -215,7 +215,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( ...@@ -215,7 +215,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmm Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmm
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
constexpr auto out_gemmk0_gemmn_gemmk1_grid_iterator_hacks = make_tuple( constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmn Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmn
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
...@@ -223,7 +223,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( ...@@ -223,7 +223,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
constexpr auto in_m0_m1_m2_n_grid_iterator_hacks = make_tuple( constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple(
make_tuple( make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: NRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: NRepeat
...@@ -243,10 +243,10 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( ...@@ -243,10 +243,10 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 7-: N1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 7-: N1
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
...@@ -287,11 +287,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( ...@@ -287,11 +287,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
Sequence<1, 3, 7, 0, 2, 4, 5, 6>, Sequence<1, 3, 7, 0, 2, 4, 5, 6>,
6, 6,
GemmCThreadTransferDstScalarPerVector, GemmCThreadTransferDstScalarPerVector,
decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks), decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(out_gemmk0_gemmn_gemmk1_grid_iterator_hacks), decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(in_m0_m1_m2_n_grid_iterator_hacks), decltype(in_m0_m1_m2_n_grid_step_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks), decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false // CAccessOrderMRepeatNRepeat false // CAccessOrderMRepeatNRepeat
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()), >(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
...@@ -299,11 +299,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( ...@@ -299,11 +299,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
wei_gemmk0_gemmm_gemmk1_grid_desc, wei_gemmk0_gemmm_gemmk1_grid_desc,
out_gemmk0_gemmn_gemmk1_grid_desc, out_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc, in_gemmm_gemmn_grid_desc,
wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks, wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
out_gemmk0_gemmn_gemmk1_grid_iterator_hacks, out_gemmk0_gemmn_gemmk1_grid_step_hacks,
in_m0_m1_m2_n_grid_iterator_hacks, in_m0_m1_m2_n_grid_step_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat); nrepeat);
{ {
......
...@@ -179,7 +179,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk ...@@ -179,7 +179,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
const auto in_gemmm_gemmn_grid_desc = descs[I2]; const auto in_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix // HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto out_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple( constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmm Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmm
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
...@@ -187,7 +187,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk ...@@ -187,7 +187,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks = constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmn Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmn
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
...@@ -195,7 +195,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk ...@@ -195,7 +195,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
constexpr auto in_m0_m1_m2_n_grid_iterator_hacks = make_tuple( constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple(
make_tuple( make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: MRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: MRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat
...@@ -215,10 +215,10 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk ...@@ -215,10 +215,10 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 6-: M2 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 6-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
...@@ -263,11 +263,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk ...@@ -263,11 +263,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
#endif #endif
7, 7,
GemmCThreadTransferDstScalarPerVector, GemmCThreadTransferDstScalarPerVector,
decltype(out_gemmk0_gemmm_gemmk1_grid_iterator_hacks), decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks), decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(in_m0_m1_m2_n_grid_iterator_hacks), decltype(in_m0_m1_m2_n_grid_step_hacks),
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks), decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
true // CAccessOrderMRepeatNRepeat true // CAccessOrderMRepeatNRepeat
>(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), >(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()), static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
...@@ -275,11 +275,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk ...@@ -275,11 +275,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
out_gemmk0_gemmm_gemmk1_grid_desc, out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc, wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc, in_gemmm_gemmn_grid_desc,
out_gemmk0_gemmm_gemmk1_grid_iterator_hacks, out_gemmk0_gemmm_gemmk1_grid_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks, wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
in_m0_m1_m2_n_grid_iterator_hacks, in_m0_m1_m2_n_grid_step_hacks,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat); nrepeat);
{ {
......
...@@ -89,7 +89,7 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( ...@@ -89,7 +89,7 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
in_right_pads); in_right_pads);
// HACK: hacks that control index calculation when iterating over A, B, C matrix // HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks = constexpr auto wei_gemmk_gemmm0_gemmn1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
...@@ -99,7 +99,7 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( ...@@ -99,7 +99,7 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0>{}));
constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks = constexpr auto in_gemmk_gemmn0_gemmn1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
...@@ -107,7 +107,7 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( ...@@ -107,7 +107,7 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}));
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks = constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
...@@ -121,10 +121,10 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( ...@@ -121,10 +121,10 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{})); Sequence<0, 0, 2, 0, 0>{}));
constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks = constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{}; Sequence<0, 0, 0, 0, 0>{};
constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks = constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
const auto wei_gemmk_gemmm_grid_desc = descs[I0]; const auto wei_gemmk_gemmm_grid_desc = descs[I0];
...@@ -171,22 +171,22 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( ...@@ -171,22 +171,22 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
5, // CThreadTransferSrcDstVectorDim 5, // CThreadTransferSrcDstVectorDim
GemmCThreadTransferDstScalarPerVector_N11, GemmCThreadTransferDstScalarPerVector_N11,
decltype(wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks), decltype(wei_gemmk_gemmm0_gemmn1_grid_step_hacks),
decltype(in_gemmk_gemmn0_gemmn1_grid_iterator_hacks), decltype(in_gemmk_gemmn0_gemmn1_grid_step_hacks),
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks), decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks),
decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks), decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks),
decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks)>( decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks)>(
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()), static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
wei_gemmk_gemmm_grid_desc, wei_gemmk_gemmm_grid_desc,
in_gemmk_gemmn_grid_desc, in_gemmk_gemmn_grid_desc,
out_gemmm_gemmn_grid_desc, out_gemmm_gemmn_grid_desc,
wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks, wei_gemmk_gemmm0_gemmn1_grid_step_hacks,
in_gemmk_gemmn0_gemmn1_grid_iterator_hacks, in_gemmk_gemmn0_gemmn1_grid_step_hacks,
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks, out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks,
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks, wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks,
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks, in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks,
nrepeat); nrepeat);
float perf = static_cast<float>(calculate_convolution_flops( float perf = static_cast<float>(calculate_convolution_flops(
......
...@@ -155,7 +155,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk( ...@@ -155,7 +155,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
const auto out_gemmm_gemmn_grid_desc = descs[I2]; const auto out_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix // HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks = make_tuple( constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0 make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM1
...@@ -165,7 +165,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk( ...@@ -165,7 +165,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GemmM1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GemmM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks = constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0 make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmN0 Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmN1 Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmN1
...@@ -175,7 +175,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk( ...@@ -175,7 +175,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmN1 Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmN1
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1 Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks = constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmM0 make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmM0
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM10 Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM10
Sequence<0, 0, 0, 0, 0>{}, // 2+: GemmM11 Sequence<0, 0, 0, 0, 0>{}, // 2+: GemmM11
...@@ -189,10 +189,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk( ...@@ -189,10 +189,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
Sequence<0, 0, 0, 0, 0>{}, // 4-: GemmN10 Sequence<0, 0, 0, 0, 0>{}, // 4-: GemmN10
Sequence<0, 0, 0, 0, 0>{})); // 5-: GemmN11 Sequence<0, 0, 0, 0, 0>{})); // 5-: GemmN11
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks = constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{};
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks = constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0>{};
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
...@@ -231,22 +231,22 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk( ...@@ -231,22 +231,22 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
Sequence<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder Sequence<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
5, // CThreadTransferSrcDstVectorDim 5, // CThreadTransferSrcDstVectorDim
GemmCThreadTransferDstScalarPerVector_N11, GemmCThreadTransferDstScalarPerVector_N11,
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks), decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks),
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks), decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks),
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks), decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks),
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks), decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks),
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks)>( decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks)>(
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()), static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
in_gemmk0_gemmm_gemmk1_grid_desc, in_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc, wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc, out_gemmm_gemmn_grid_desc,
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks, in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks,
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks, wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks,
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks, out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks,
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks, in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks,
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks, wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks,
nrepeat); nrepeat);
{ {
......
...@@ -92,12 +92,12 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( ...@@ -92,12 +92,12 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
const auto out_gemmm_gemmn_grid_desc = descs[I2]; const auto out_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix // HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple( constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple( make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks = constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
...@@ -105,7 +105,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( ...@@ -105,7 +105,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
constexpr auto out_m0_m1_m2_n_grid_iterator_hacks = constexpr auto out_m0_m1_m2_n_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
...@@ -123,10 +123,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( ...@@ -123,10 +123,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{})); Sequence<0, 0, 2, 0, 0>{}));
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{}; Sequence<0, 0, 0, 0, 0>{};
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
...@@ -167,22 +167,22 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( ...@@ -167,22 +167,22 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
Sequence<3, 0, 1, 2, 7, 5, 4, 6>, Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
7, 7,
GemmCThreadTransferDstScalarPerVector, GemmCThreadTransferDstScalarPerVector,
decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks), decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks), decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(out_m0_m1_m2_n_grid_iterator_hacks), decltype(out_m0_m1_m2_n_grid_step_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks), decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()), false>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
wei_gemmk0_gemmm_gemmk1_grid_desc, wei_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc, in_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc, out_gemmm_gemmn_grid_desc,
wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks, wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_iterator_hacks, in_gemmk0_gemmn_gemmk1_grid_step_hacks,
out_m0_m1_m2_n_grid_iterator_hacks, out_m0_m1_m2_n_grid_step_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat); nrepeat);
float perf = static_cast<float>(calculate_convolution_flops( float perf = static_cast<float>(calculate_convolution_flops(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment