Commit e09f6e02 authored by Chao Liu
Browse files

refactor

parent 57271814
......@@ -13,20 +13,20 @@ namespace ck {
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <typename ThreadGroup,
typename ElementwiseOperation,
typename SliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcDatas,
typename DstDatas,
typename SrcDescs,
typename DstDescs,
typename ElementwiseOperation,
typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
typename SliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename DimAccessOrder,
index_t VectorDim,
index_t ScalarPerVector,
typename ThreadTransferSrcResetCoordinateAfterRunFlags,
typename ThreadTransferDstResetCoordinateAfterRunFlags,
InMemoryDataOperationEnum... DstInMemOps>
typename ThreadTransferDstResetCoordinateAfterRunFlags>
struct ThreadGroupTensorSliceTransfer_v7
{
static constexpr index_t nDim =
......@@ -147,13 +147,13 @@ struct ThreadGroupTensorSliceTransfer_v7
SrcDescs,
DstDescs,
ElementwiseOperation,
DstInMemOps,
decltype(thread_slice_lengths),
DimAccessOrder,
VectorDim,
ScalarPerVector,
ThreadTransferSrcResetCoordinateAfterRunFlags,
ThreadTransferDstResetCoordinateAfterRunFlags,
DstInMemOps...>;
ThreadTransferDstResetCoordinateAfterRunFlags>;
ThreadwiseTransfer threadwise_transfer_;
};
......
......@@ -548,26 +548,27 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
ds_grid_desc_mblock_mperblock_nblock_nperblock[I1]);
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock, // ThreadGroup
CDEElementwiseOperation, // ElementwiseOperation,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
ThisThreadBlock, // ThreadGroup
Tuple<FloatCShuffle,
remove_cvref_t<tuple_element_t<0, DsDataType>>,
remove_cvref_t<tuple_element_t<1, DsDataType>>>,
Tuple<FloatE>, // typename DstData,
decltype(c_ds_descs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CDEElementwiseOperation, // ElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
// support arbitrary type
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
Sequence<true, false, false>, // bool ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>, // bool ThreadTransferDstResetCoordinateAfterRunFlags
EGlobalMemoryDataOperation> // DstInMemOp,
Sequence<false>> // bool ThreadTransferDstResetCoordinateAfterRunFlags
{c_ds_descs,
make_tuple(make_multi_index(0, 0, 0, 0),
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
......
......@@ -24,13 +24,13 @@ template <typename SrcDatas,
typename SrcDescs,
typename DstDescs,
typename ElementwiseOperation,
typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
typename SliceLengths,
typename DimAccessOrder,
index_t VectorDim,
index_t ScalarPerVector,
typename SrcResetCoordinateAfterRunFlags, // Sequence<...>
typename DstResetCoordinateAfterRunFlags, // Sequence<...>
InMemoryDataOperationEnum... DstInMemOps>
typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
typename DstResetCoordinateAfterRunFlags> // Sequence<bool ...>
struct ThreadwiseTensorSliceTransfer_v7
{
static constexpr auto I0 = Number<0>{};
......@@ -165,7 +165,8 @@ struct ThreadwiseTensorSliceTransfer_v7
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
dst_coords_[i]);
constexpr auto DstInMemOp = make_tuple(DstInMemOps...)[i];
constexpr InMemoryDataOperationEnum DstInMemOp =
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
dst_coords_[i].GetOffset(),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment