Commit e09f6e02 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 57271814
...@@ -13,20 +13,20 @@ namespace ck { ...@@ -13,20 +13,20 @@ namespace ck {
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <typename ThreadGroup, template <typename ThreadGroup,
typename ElementwiseOperation,
typename SliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcDatas, typename SrcDatas,
typename DstDatas, typename DstDatas,
typename SrcDescs, typename SrcDescs,
typename DstDescs, typename DstDescs,
typename ElementwiseOperation,
typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
typename SliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename DimAccessOrder, typename DimAccessOrder,
index_t VectorDim, index_t VectorDim,
index_t ScalarPerVector, index_t ScalarPerVector,
typename ThreadTransferSrcResetCoordinateAfterRunFlags, typename ThreadTransferSrcResetCoordinateAfterRunFlags,
typename ThreadTransferDstResetCoordinateAfterRunFlags, typename ThreadTransferDstResetCoordinateAfterRunFlags>
InMemoryDataOperationEnum... DstInMemOps>
struct ThreadGroupTensorSliceTransfer_v7 struct ThreadGroupTensorSliceTransfer_v7
{ {
static constexpr index_t nDim = static constexpr index_t nDim =
...@@ -147,13 +147,13 @@ struct ThreadGroupTensorSliceTransfer_v7 ...@@ -147,13 +147,13 @@ struct ThreadGroupTensorSliceTransfer_v7
SrcDescs, SrcDescs,
DstDescs, DstDescs,
ElementwiseOperation, ElementwiseOperation,
DstInMemOps,
decltype(thread_slice_lengths), decltype(thread_slice_lengths),
DimAccessOrder, DimAccessOrder,
VectorDim, VectorDim,
ScalarPerVector, ScalarPerVector,
ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferSrcResetCoordinateAfterRunFlags,
ThreadTransferDstResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags>;
DstInMemOps...>;
ThreadwiseTransfer threadwise_transfer_; ThreadwiseTransfer threadwise_transfer_;
}; };
......
...@@ -549,25 +549,26 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle ...@@ -549,25 +549,26 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock, // ThreadGroup ThisThreadBlock, // ThreadGroup
Tuple<FloatCShuffle,
remove_cvref_t<tuple_element_t<0, DsDataType>>,
remove_cvref_t<tuple_element_t<1, DsDataType>>>,
Tuple<FloatE>, // typename DstData,
decltype(c_ds_descs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation, // ElementwiseOperation, CDEElementwiseOperation, // ElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
// support arbitray type
Sequence<1, Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1, 1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Tuple<FloatCShuffle,
remove_cvref_t<tuple_element_t<0, DsDataType>>,
remove_cvref_t<tuple_element_t<1, DsDataType>>>,
Tuple<FloatE>, // typename DstData,
decltype(c_ds_descs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder, Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim, 3, // index_t VectorDim,
CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
Sequence<true, false, false>, // bool ThreadTransferSrcResetCoordinateAfterRunFlags Sequence<true, false, false>, // bool ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>, // bool ThreadTransferDstResetCoordinateAfterRunFlags Sequence<false>> // bool ThreadTransferDstResetCoordinateAfterRunFlags
EGlobalMemoryDataOperation> // DstInMemOp,
{c_ds_descs, {c_ds_descs,
make_tuple(make_multi_index(0, 0, 0, 0), make_tuple(make_multi_index(0, 0, 0, 0),
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
......
...@@ -24,13 +24,13 @@ template <typename SrcDatas, ...@@ -24,13 +24,13 @@ template <typename SrcDatas,
typename SrcDescs, typename SrcDescs,
typename DstDescs, typename DstDescs,
typename ElementwiseOperation, typename ElementwiseOperation,
typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
typename SliceLengths, typename SliceLengths,
typename DimAccessOrder, typename DimAccessOrder,
index_t VectorDim, index_t VectorDim,
index_t ScalarPerVector, index_t ScalarPerVector,
typename SrcResetCoordinateAfterRunFlags, // Sequence<...> typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
typename DstResetCoordinateAfterRunFlags, // Sequence<...> typename DstResetCoordinateAfterRunFlags> // Sequence<bool ...>
InMemoryDataOperationEnum... DstInMemOps>
struct ThreadwiseTensorSliceTransfer_v7 struct ThreadwiseTensorSliceTransfer_v7
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -165,7 +165,8 @@ struct ThreadwiseTensorSliceTransfer_v7 ...@@ -165,7 +165,8 @@ struct ThreadwiseTensorSliceTransfer_v7
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
dst_coords_[i]); dst_coords_[i]);
constexpr auto DstInMemOp = make_tuple(DstInMemOps...)[i]; constexpr InMemoryDataOperationEnum DstInMemOp =
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
dst_bufs(i).template Update<DstInMemOp, dst_vector_t>( dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
dst_coords_[i].GetOffset(), dst_coords_[i].GetOffset(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment