Commit 925a8d78 authored by Chao Liu
Browse files

refactor

parent 681ede91
...@@ -18,7 +18,6 @@ template <index_t BlockSize, ...@@ -18,7 +18,6 @@ template <index_t BlockSize,
typename DstElementwiseOperation, typename DstElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp, InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths, typename BlockSliceLengths,
typename ThreadSliceLengths,
typename ThreadClusterLengths, typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder, typename ThreadClusterArrangeOrder,
typename SrcData, typename SrcData,
...@@ -39,6 +38,8 @@ struct BlockwiseTensorSliceTransfer_v4r1 ...@@ -39,6 +38,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
{ {
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension(); static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v4r1( __device__ constexpr BlockwiseTensorSliceTransfer_v4r1(
...@@ -58,14 +59,13 @@ struct BlockwiseTensorSliceTransfer_v4r1 ...@@ -58,14 +59,13 @@ struct BlockwiseTensorSliceTransfer_v4r1
{ {
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() && static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() && nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
nDim == ThreadClusterLengths::Size() && nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() && nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent"); "wrong! nDim not consistent");
static_assert( static_assert(
is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{}, is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window"); "wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
...@@ -77,7 +77,7 @@ struct BlockwiseTensorSliceTransfer_v4r1 ...@@ -77,7 +77,7 @@ struct BlockwiseTensorSliceTransfer_v4r1
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id())); make_multi_index(get_thread_local_1d_id()));
const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{}; const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
threadwise_transfer_.SetSrcSliceOrigin(src_desc, threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_idx_begin); src_block_slice_origin + thread_data_idx_begin);
...@@ -165,7 +165,7 @@ struct BlockwiseTensorSliceTransfer_v4r1 ...@@ -165,7 +165,7 @@ struct BlockwiseTensorSliceTransfer_v4r1
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer = using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v3r1<ThreadSliceLengths, ThreadwiseTensorSliceTransfer_v3r1<decltype(thread_slice_lengths),
SrcElementwiseOperation, SrcElementwiseOperation,
DstElementwiseOperation, DstElementwiseOperation,
DstInMemOp, DstInMemOp,
......
...@@ -17,7 +17,6 @@ template <index_t BlockSize, ...@@ -17,7 +17,6 @@ template <index_t BlockSize,
typename ElementwiseOperation, typename ElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp, InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths, typename BlockSliceLengths,
typename ThreadSliceLengths,
typename ThreadClusterLengths, typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder, typename ThreadClusterArrangeOrder,
typename SrcData, typename SrcData,
...@@ -33,6 +32,8 @@ struct BlockwiseTensorSliceTransfer_v6r1 ...@@ -33,6 +32,8 @@ struct BlockwiseTensorSliceTransfer_v6r1
{ {
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension(); static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc, __device__ constexpr BlockwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc,
...@@ -49,14 +50,13 @@ struct BlockwiseTensorSliceTransfer_v6r1 ...@@ -49,14 +50,13 @@ struct BlockwiseTensorSliceTransfer_v6r1
{ {
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() && static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() && nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
nDim == ThreadClusterLengths::Size() && nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() && nDim == ThreadClusterArrangeOrder::Size() &&
nDim == DimAccessOrder::Size(), nDim == DimAccessOrder::Size(),
"wrong! nDim not consistent"); "wrong! nDim not consistent");
static_assert( static_assert(
is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{}, is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window"); "wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
...@@ -68,7 +68,7 @@ struct BlockwiseTensorSliceTransfer_v6r1 ...@@ -68,7 +68,7 @@ struct BlockwiseTensorSliceTransfer_v6r1
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id())); make_multi_index(get_thread_local_1d_id()));
const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{}; const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
threadwise_transfer_.SetSrcSliceOrigin(src_desc, threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_idx_begin); src_block_slice_origin + thread_data_idx_begin);
...@@ -118,7 +118,7 @@ struct BlockwiseTensorSliceTransfer_v6r1 ...@@ -118,7 +118,7 @@ struct BlockwiseTensorSliceTransfer_v6r1
SrcDesc, SrcDesc,
DstDesc, DstDesc,
ElementwiseOperation, ElementwiseOperation,
ThreadSliceLengths, decltype(thread_slice_lengths),
DimAccessOrder, DimAccessOrder,
VectorDim, VectorDim,
ScalarPerVector, ScalarPerVector,
......
...@@ -17,7 +17,6 @@ template <index_t BlockSize, ...@@ -17,7 +17,6 @@ template <index_t BlockSize,
typename ElementwiseOperation, typename ElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp, InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths, typename BlockSliceLengths,
typename ThreadSliceLengths,
typename ThreadClusterLengths, typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder, typename ThreadClusterArrangeOrder,
typename Src0Data, typename Src0Data,
...@@ -39,6 +38,8 @@ struct BlockwiseTensorSliceTransfer_v6r3 ...@@ -39,6 +38,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
{ {
static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension(); static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc, __device__ constexpr BlockwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc,
...@@ -65,14 +66,13 @@ struct BlockwiseTensorSliceTransfer_v6r3 ...@@ -65,14 +66,13 @@ struct BlockwiseTensorSliceTransfer_v6r3
nDim == remove_reference_t<remove_cv_t<Src1Desc>>::GetNumOfDimension() && nDim == remove_reference_t<remove_cv_t<Src1Desc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<Src2Desc>>::GetNumOfDimension() && nDim == remove_reference_t<remove_cv_t<Src2Desc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() && nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
nDim == ThreadClusterLengths::Size() && nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() && nDim == ThreadClusterArrangeOrder::Size() &&
nDim == DimAccessOrder::Size(), nDim == DimAccessOrder::Size(),
"wrong! nDim not consistent"); "wrong! nDim not consistent");
static_assert( static_assert(
is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{}, is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window"); "wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
...@@ -84,7 +84,7 @@ struct BlockwiseTensorSliceTransfer_v6r3 ...@@ -84,7 +84,7 @@ struct BlockwiseTensorSliceTransfer_v6r3
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id())); make_multi_index(get_thread_local_1d_id()));
const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{}; const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
threadwise_transfer_.SetSrc0SliceOrigin( threadwise_transfer_.SetSrc0SliceOrigin(
src0_desc, src0_block_slice_origin + thread_data_idx_begin); src0_desc, src0_block_slice_origin + thread_data_idx_begin);
...@@ -165,7 +165,7 @@ struct BlockwiseTensorSliceTransfer_v6r3 ...@@ -165,7 +165,7 @@ struct BlockwiseTensorSliceTransfer_v6r3
Src2Desc, Src2Desc,
DstDesc, DstDesc,
ElementwiseOperation, ElementwiseOperation,
ThreadSliceLengths, decltype(thread_slice_lengths),
DimAccessOrder, DimAccessOrder,
VectorDim, VectorDim,
ScalarPerVector, ScalarPerVector,
......
...@@ -56,50 +56,46 @@ __global__ void ...@@ -56,50 +56,46 @@ __global__ void
block_2_ctile_map); block_2_ctile_map);
} }
template <index_t BlockSize, template <
typename FloatAB, index_t BlockSize,
typename FloatAcc, typename FloatAB,
typename FloatC, typename FloatAcc,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation, typename FloatC,
typename AGridDesc_K0_M_K1, InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename BGridDesc_K0_N_K1, typename AGridDesc_K0_M_K1,
typename CGridDesc_M_N, typename BGridDesc_K0_N_K1,
typename AElementwiseOperation, typename CGridDesc_M_N,
typename BElementwiseOperation, typename AElementwiseOperation,
typename CElementwiseOperation, typename BElementwiseOperation,
index_t MPerBlock, typename CElementwiseOperation,
index_t NPerBlock, index_t MPerBlock,
index_t K0PerBlock, index_t NPerBlock,
index_t MPerXdl, index_t K0PerBlock,
index_t NPerXdl, index_t MPerXdl,
index_t K1Value, index_t NPerXdl,
index_t MRepeat, index_t K1Value,
index_t NRepeat, index_t MRepeat,
typename ABlockTransferThreadSliceLengths_K0_M_K1, index_t NRepeat,
typename ABlockTransferThreadClusterLengths_K0_M_K1, typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1, index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun, bool AThreadTransferSrcResetCoordinateAfterRun,
bool ABlockLdsExtraM, bool ABlockLdsExtraM,
typename BBlockTransferThreadSliceLengths_K0_N_K1, typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterLengths_K0_N_K1, typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferThreadClusterArrangeOrder, typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcAccessOrder, index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferDstScalarPerVector_K1,
index_t BBlockTransferDstScalarPerVector_K1, bool BThreadTransferSrcResetCoordinateAfterRun,
bool BThreadTransferSrcResetCoordinateAfterRun, bool BBlockLdsExtraN,
bool BBlockLdsExtraN, index_t CShuffleMRepeatPerShuffle,
index_t MRepeatPerShuffle_CCopy, index_t CShuffleNRepeatPerShuffle,
index_t NRepeatPerShuffle_CCopy, typename CBlockTransferClusterLengths_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl,
index_t MRepeatThread_CCopy, index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
index_t MThread_CCopy,
index_t NRepeatThread_CCopy,
index_t NThread_CCopy,
index_t NScalarPerVector_CCopy>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -363,7 +359,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 ...@@ -363,7 +359,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, MPerBlock, K1>, Sequence<K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
...@@ -394,7 +389,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 ...@@ -394,7 +389,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, NPerBlock, K1>, Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
...@@ -500,54 +494,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 ...@@ -500,54 +494,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
} }
// shuffle and write out // shuffle C and write out
{ {
#if 0 static_assert(MRepeat % CShuffleMRepeatPerShuffle == 0 &&
// TODO: make it tunable NRepeat % CShuffleNRepeatPerShuffle == 0,
constexpr index_t MRepeatPerShuffle_CCopy = 1;
constexpr index_t NRepeatPerShuffle_CCopy = 1;
// TODO: this is hardcoded, only works for BlockSize = 256. fix it!
constexpr index_t MRepeatThread_CCopy = 1;
constexpr index_t MThread_CCopy = 32;
constexpr index_t NRepeatThread_CCopy = 1;
constexpr index_t NThread_CCopy = 8;
// vector length for blockwise copy from LDS to global
constexpr index_t NScalarPerVector_CCopy = 8;
#elif 0
// TODO: make it tunable
constexpr index_t MRepeatPerShuffle_CCopy = 1;
constexpr index_t NRepeatPerShuffle_CCopy = 2;
// TODO: this is hardcoded, only works for BlockSize = 256. fix it!
constexpr index_t MRepeatThread_CCopy = 1;
constexpr index_t MThread_CCopy = 16;
constexpr index_t NRepeatThread_CCopy = 2;
constexpr index_t NThread_CCopy = 8;
// vector length for blockwise copy from LDS to global
constexpr index_t NScalarPerVector_CCopy = 8;
#endif
static_assert(MRepeat % MRepeatPerShuffle_CCopy == 0 &&
NRepeat % NRepeatPerShuffle_CCopy == 0,
"wrong!"); "wrong!");
constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl); constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl); constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl);
constexpr index_t MPerBlock_CCopy = MWave * MPerXdl;
constexpr index_t NPerBlock_CCopy = NWave * NPerXdl;
constexpr index_t MPerThread_CCopy = MPerBlock_CCopy / MThread_CCopy;
constexpr index_t NPerThread_CCopy = NPerBlock_CCopy / NThread_CCopy;
constexpr index_t MRepeatPerThread_CCopy =
MRepeatPerShuffle_CCopy / MRepeatThread_CCopy;
constexpr index_t NRepeatPerThread_CCopy =
NRepeatPerShuffle_CCopy / NRepeatThread_CCopy;
// TODO: hacky, fix it! // TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
...@@ -568,10 +523,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 ...@@ -568,10 +523,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
constexpr auto c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl = constexpr auto c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl =
make_naive_tensor_descriptor_packed(make_tuple(I1, make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<MRepeatPerShuffle_CCopy>{}, Number<CShuffleMRepeatPerShuffle>{},
Number<MWave * MPerXdl>{}, Number<MWave * MPerXdl>{},
I1, I1,
Number<NRepeatPerShuffle_CCopy>{}, Number<CShuffleNRepeatPerShuffle>{},
Number<NWave * NPerXdl>{})); Number<NWave * NPerXdl>{}));
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
...@@ -583,12 +538,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 ...@@ -583,12 +538,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl, c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
make_tuple(make_freeze_transform(I0), // freeze mblock make_tuple(make_freeze_transform(I0), // freeze mblock
make_pass_through_transform( make_pass_through_transform(
Number<MRepeatPerShuffle_CCopy>{}), // M0 (MRepeat) per shuffle Number<CShuffleMRepeatPerShuffle>{}), // M0 (MRepeat) per shuffle
make_unmerge_transform( make_unmerge_transform(
make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
make_freeze_transform(I0), // freeze nblock make_freeze_transform(I0), // freeze nblock
make_pass_through_transform( make_pass_through_transform(
Number<NRepeatPerShuffle_CCopy>{}), // N0 (NRepeat) per shuffle Number<CShuffleNRepeatPerShuffle>{}), // N0 (NRepeat) per shuffle
make_unmerge_transform( make_unmerge_transform(
make_tuple(N1, N2))), // N1 = NWave, N2 = NPerXdl make_tuple(N1, N2))), // N1 = NWave, N2 = NPerXdl
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
...@@ -635,61 +590,58 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 ...@@ -635,61 +590,58 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
make_multi_index(n_thread_data_on_block)); make_multi_index(n_thread_data_on_block));
// VGPR to LDS // VGPR to LDS
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< auto c_thread_copy_vgpr_to_lds =
FloatAcc, ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
FloatC, FloatC,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Sequence<MRepeatPerShuffle_CCopy, NRepeatPerShuffle_CCopy, I1, I1, M2, I1, M4, I1>, Sequence<CShuffleMRepeatPerShuffle,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, CShuffleNRepeatPerShuffle,
7, I1,
1, I1,
InMemoryDataOperationEnum_t::Set, M2,
1, I1,
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, M4,
make_multi_index(0, I1>,
0, Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
m_thread_data_on_block_idx[I1], 7,
n_thread_data_on_block_idx[I1], 1,
m_thread_data_on_block_idx[I2], InMemoryDataOperationEnum_t::Set,
m_thread_data_on_block_idx[I3], 1,
m_thread_data_on_block_idx[I4], true>{
n_thread_data_on_block_idx[I2]), c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
ck::tensor_operation::element_wise::PassThrough{}}; make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1< auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1<
BlockSize, // index_t BlockSize, BlockSize, // index_t BlockSize,
CElementwiseOperation, // ElementwiseOperation, CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp, CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1, Sequence<1,
MRepeatPerShuffle_CCopy, CShuffleMRepeatPerShuffle,
MPerBlock_CCopy, MWave * MPerXdl,
1,
NRepeatPerShuffle_CCopy,
NPerBlock_CCopy>, // BlockSliceLengths,
Sequence<1,
MRepeatPerShuffle_CCopy,
MPerThread_CCopy,
1,
NRepeatPerShuffle_CCopy,
NPerThread_CCopy>, // ThreadSliceLengths,
Sequence<1,
MRepeatPerThread_CCopy,
MThread_CCopy,
1, 1,
NRepeatPerThread_CCopy, CShuffleNRepeatPerShuffle,
NThread_CCopy>, // ThreadClusterLengths, NWave * NPerXdl>, // BlockSliceLengths,
CBlockTransferClusterLengths_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl,
Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder, Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder,
FloatC, // typename SrcData, FloatC, // typename SrcData,
FloatC, // typename DstData, FloatC, // typename DstData,
decltype(c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl), decltype(c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl),
decltype(c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl), decltype(c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl),
Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder,
5, // index_t VectorDim, 5, // index_t VectorDim,
NScalarPerVector_CCopy, // index_t ScalarPerVector, CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun, true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun> false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl, {c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
make_multi_index(0, 0, 0, 0, 0, 0), make_multi_index(0, 0, 0, 0, 0, 0),
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl, c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
...@@ -697,22 +649,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 ...@@ -697,22 +649,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
c_element_op}; c_element_op};
constexpr auto mrepeat_forward_step = constexpr auto mrepeat_forward_step =
make_multi_index(0, MRepeatPerShuffle_CCopy, 0, 0, 0, 0); make_multi_index(0, CShuffleMRepeatPerShuffle, 0, 0, 0, 0);
constexpr auto nrepeat_forward_step = constexpr auto nrepeat_forward_step =
make_multi_index(0, 0, 0, 0, NRepeatPerShuffle_CCopy, 0); make_multi_index(0, 0, 0, 0, CShuffleNRepeatPerShuffle, 0);
constexpr auto nrepeat_backward_step = constexpr auto nrepeat_backward_step =
make_multi_index(0, 0, 0, 0, -NRepeatPerShuffle_CCopy, 0); make_multi_index(0, 0, 0, 0, -CShuffleNRepeatPerShuffle, 0);
static_for<0, MRepeat, MRepeatPerShuffle_CCopy>{}([&](auto mrepeat_iter) { static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mrepeat_iter) {
constexpr auto mrepeat = mrepeat_iter; constexpr auto mrepeat = mrepeat_iter;
static_for<0, NRepeat, NRepeatPerShuffle_CCopy>{}([&](auto nrepeat_iter) { static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nrepeat_iter) {
constexpr bool nrepeat_forward_sweep = constexpr bool nrepeat_forward_sweep =
(mrepeat % (2 * MRepeatPerShuffle_CCopy) == 0); (mrepeat % (2 * CShuffleMRepeatPerShuffle) == 0);
constexpr index_t nrepeat_value = constexpr index_t nrepeat_value =
nrepeat_forward_sweep ? nrepeat_iter nrepeat_forward_sweep
: (NRepeat - nrepeat_iter - NRepeatPerShuffle_CCopy); ? nrepeat_iter
: (NRepeat - nrepeat_iter - CShuffleNRepeatPerShuffle);
constexpr auto nrepeat = Number<nrepeat_value>{}; constexpr auto nrepeat = Number<nrepeat_value>{};
...@@ -739,7 +692,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 ...@@ -739,7 +692,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
// move on nrepeat dimension // move on nrepeat dimension
if constexpr(nrepeat_forward_sweep && if constexpr(nrepeat_forward_sweep &&
(nrepeat < NRepeat - NRepeatPerShuffle_CCopy)) (nrepeat < NRepeat - CShuffleNRepeatPerShuffle))
{ {
c_block_copy_lds_to_global.MoveDstSliceWindow( c_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl, c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
...@@ -754,7 +707,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 ...@@ -754,7 +707,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
}); });
// move on mrepeat dimension // move on mrepeat dimension
if constexpr(mrepeat < MRepeat - MRepeatPerShuffle_CCopy) if constexpr(mrepeat < MRepeat - CShuffleMRepeatPerShuffle)
{ {
c_block_copy_lds_to_global.MoveDstSliceWindow( c_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl, c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
......
...@@ -68,52 +68,48 @@ __global__ void ...@@ -68,52 +68,48 @@ __global__ void
block_2_ctile_map); block_2_ctile_map);
} }
template <index_t BlockSize, template <
typename FloatAB, index_t BlockSize,
typename FloatAcc, typename FloatAB,
typename FloatC, typename FloatAcc,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation, typename FloatC,
typename AGridDesc_K0_M_K1, InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename BGridDesc_K0_N_K1, typename AGridDesc_K0_M_K1,
typename CGridDesc_M_N, typename BGridDesc_K0_N_K1,
typename C0GridDesc_M_N, typename CGridDesc_M_N,
typename C1GridDesc_M_N, typename C0GridDesc_M_N,
typename AElementwiseOperation, typename C1GridDesc_M_N,
typename BElementwiseOperation, typename AElementwiseOperation,
typename CElementwiseOperation, typename BElementwiseOperation,
index_t MPerBlock, typename CElementwiseOperation,
index_t NPerBlock, index_t MPerBlock,
index_t K0PerBlock, index_t NPerBlock,
index_t MPerXdl, index_t K0PerBlock,
index_t NPerXdl, index_t MPerXdl,
index_t K1Value, index_t NPerXdl,
index_t MRepeat, index_t K1Value,
index_t NRepeat, index_t MRepeat,
typename ABlockTransferThreadSliceLengths_K0_M_K1, index_t NRepeat,
typename ABlockTransferThreadClusterLengths_K0_M_K1, typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1, index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun, bool AThreadTransferSrcResetCoordinateAfterRun,
bool ABlockLdsExtraM, bool ABlockLdsExtraM,
typename BBlockTransferThreadSliceLengths_K0_N_K1, typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterLengths_K0_N_K1, typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferThreadClusterArrangeOrder, typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcAccessOrder, index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferDstScalarPerVector_K1,
index_t BBlockTransferDstScalarPerVector_K1, bool BThreadTransferSrcResetCoordinateAfterRun,
bool BThreadTransferSrcResetCoordinateAfterRun, bool BBlockLdsExtraN,
bool BBlockLdsExtraN, index_t CShuffleMRepeatPerShuffle,
index_t MRepeatPerShuffle_CCopy, index_t CShuffleNRepeatPerShuffle,
index_t NRepeatPerShuffle_CCopy, typename CBlockTransferClusterLengths_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl,
index_t MRepeatThread_CCopy, index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
index_t MThread_CCopy,
index_t NRepeatThread_CCopy,
index_t NThread_CCopy,
index_t NScalarPerVector_CCopy>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -402,7 +398,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 ...@@ -402,7 +398,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, MPerBlock, K1>, Sequence<K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
...@@ -433,7 +428,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 ...@@ -433,7 +428,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, NPerBlock, K1>, Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
...@@ -539,54 +533,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 ...@@ -539,54 +533,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
} }
// shuffle and write out // shuffle C and write out
{ {
#if 0 static_assert(MRepeat % CShuffleMRepeatPerShuffle == 0 &&
// TODO: make it tunable NRepeat % CShuffleNRepeatPerShuffle == 0,
constexpr index_t MRepeatPerShuffle_CCopy = 1;
constexpr index_t NRepeatPerShuffle_CCopy = 1;
// TODO: this is hardcoded, only works for BlockSize = 256. fix it!
constexpr index_t MRepeatThread_CCopy = 1;
constexpr index_t MThread_CCopy = 32;
constexpr index_t NRepeatThread_CCopy = 1;
constexpr index_t NThread_CCopy = 8;
// vector length for blockwise copy from LDS to global
constexpr index_t NScalarPerVector_CCopy = 8;
#elif 0
// TODO: make it tunable
constexpr index_t MRepeatPerShuffle_CCopy = 1;
constexpr index_t NRepeatPerShuffle_CCopy = 2;
// TODO: this is hardcoded, only works for BlockSize = 256. fix it!
constexpr index_t MRepeatThread_CCopy = 1;
constexpr index_t MThread_CCopy = 16;
constexpr index_t NRepeatThread_CCopy = 2;
constexpr index_t NThread_CCopy = 8;
// vector length for blockwise copy from LDS to global
constexpr index_t NScalarPerVector_CCopy = 8;
#endif
static_assert(MRepeat % MRepeatPerShuffle_CCopy == 0 &&
NRepeat % NRepeatPerShuffle_CCopy == 0,
"wrong!"); "wrong!");
constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl); constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl); constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl);
constexpr index_t MPerBlock_CCopy = MWave * MPerXdl;
constexpr index_t NPerBlock_CCopy = NWave * NPerXdl;
constexpr index_t MPerThread_CCopy = MPerBlock_CCopy / MThread_CCopy;
constexpr index_t NPerThread_CCopy = NPerBlock_CCopy / NThread_CCopy;
constexpr index_t MRepeatPerThread_CCopy =
MRepeatPerShuffle_CCopy / MRepeatThread_CCopy;
constexpr index_t NRepeatPerThread_CCopy =
NRepeatPerShuffle_CCopy / NRepeatThread_CCopy;
// TODO: hacky, fix it! // TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
...@@ -607,10 +562,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 ...@@ -607,10 +562,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
constexpr auto c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl = constexpr auto c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl =
make_naive_tensor_descriptor_packed(make_tuple(I1, make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<MRepeatPerShuffle_CCopy>{}, Number<CShuffleMRepeatPerShuffle>{},
Number<MWave * MPerXdl>{}, Number<MWave * MPerXdl>{},
I1, I1,
Number<NRepeatPerShuffle_CCopy>{}, Number<CShuffleNRepeatPerShuffle>{},
Number<NWave * NPerXdl>{})); Number<NWave * NPerXdl>{}));
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
...@@ -622,12 +577,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 ...@@ -622,12 +577,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl, c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
make_tuple(make_freeze_transform(I0), // freeze mblock make_tuple(make_freeze_transform(I0), // freeze mblock
make_pass_through_transform( make_pass_through_transform(
Number<MRepeatPerShuffle_CCopy>{}), // M0 (MRepeat) per shuffle Number<CShuffleMRepeatPerShuffle>{}), // M0 (MRepeat) per shuffle
make_unmerge_transform( make_unmerge_transform(
make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
make_freeze_transform(I0), // freeze nblock make_freeze_transform(I0), // freeze nblock
make_pass_through_transform( make_pass_through_transform(
Number<NRepeatPerShuffle_CCopy>{}), // N0 (NRepeat) per shuffle Number<CShuffleNRepeatPerShuffle>{}), // N0 (NRepeat) per shuffle
make_unmerge_transform( make_unmerge_transform(
make_tuple(N1, N2))), // N1 = NWave, N2 = NPerXdl make_tuple(N1, N2))), // N1 = NWave, N2 = NPerXdl
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
...@@ -674,51 +629,48 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 ...@@ -674,51 +629,48 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
make_multi_index(n_thread_data_on_block)); make_multi_index(n_thread_data_on_block));
// VGPR to LDS // VGPR to LDS
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< auto c_thread_copy_vgpr_to_lds =
FloatAcc, ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
FloatC, FloatC,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Sequence<MRepeatPerShuffle_CCopy, NRepeatPerShuffle_CCopy, I1, I1, M2, I1, M4, I1>, Sequence<CShuffleMRepeatPerShuffle,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, CShuffleNRepeatPerShuffle,
7, I1,
1, I1,
InMemoryDataOperationEnum_t::Set, M2,
1, I1,
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, M4,
make_multi_index(0, I1>,
0, Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
m_thread_data_on_block_idx[I1], 7,
n_thread_data_on_block_idx[I1], 1,
m_thread_data_on_block_idx[I2], InMemoryDataOperationEnum_t::Set,
m_thread_data_on_block_idx[I3], 1,
m_thread_data_on_block_idx[I4], true>{
n_thread_data_on_block_idx[I2]), c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
ck::tensor_operation::element_wise::PassThrough{}}; make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r3< auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r3<
BlockSize, // index_t BlockSize, BlockSize, // index_t BlockSize,
CElementwiseOperation, // ElementwiseOperation, CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp, CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1, Sequence<1,
MRepeatPerShuffle_CCopy, CShuffleMRepeatPerShuffle,
MPerBlock_CCopy, MWave * MPerXdl,
1,
NRepeatPerShuffle_CCopy,
NPerBlock_CCopy>, // BlockSliceLengths,
Sequence<1,
MRepeatPerShuffle_CCopy,
MPerThread_CCopy,
1,
NRepeatPerShuffle_CCopy,
NPerThread_CCopy>, // ThreadSliceLengths,
Sequence<1,
MRepeatPerThread_CCopy,
MThread_CCopy,
1, 1,
NRepeatPerThread_CCopy, CShuffleNRepeatPerShuffle,
NThread_CCopy>, // ThreadClusterLengths, NWave * NPerXdl>, // BlockSliceLengths,
CBlockTransferClusterLengths_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl,
Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder, Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder,
FloatC, // typename Src0Data, FloatC, // typename Src0Data,
FloatC, // typename Src1Data, FloatC, // typename Src1Data,
...@@ -728,13 +680,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 ...@@ -728,13 +680,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
decltype(c0_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl), decltype(c0_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl),
decltype(c1_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl), decltype(c1_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl),
decltype(c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl), decltype(c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl),
Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder,
5, // index_t VectorDim, 5, // index_t VectorDim,
NScalarPerVector_CCopy, // index_t ScalarPerVector, CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
true, // bool ThreadTransferSrc0ResetCoordinateAfterRun, true, // bool ThreadTransferSrc0ResetCoordinateAfterRun,
false, // bool ThreadTransferSrc1ResetCoordinateAfterRun, false, // bool ThreadTransferSrc1ResetCoordinateAfterRun,
false, // bool ThreadTransferSrc2ResetCoordinateAfterRun, false, // bool ThreadTransferSrc2ResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun> false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl, {c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
make_multi_index(0, 0, 0, 0, 0, 0), make_multi_index(0, 0, 0, 0, 0, 0),
c0_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl, c0_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
...@@ -746,22 +698,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 ...@@ -746,22 +698,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
c_element_op}; c_element_op};
constexpr auto mrepeat_forward_step = constexpr auto mrepeat_forward_step =
make_multi_index(0, MRepeatPerShuffle_CCopy, 0, 0, 0, 0); make_multi_index(0, CShuffleMRepeatPerShuffle, 0, 0, 0, 0);
constexpr auto nrepeat_forward_step = constexpr auto nrepeat_forward_step =
make_multi_index(0, 0, 0, 0, NRepeatPerShuffle_CCopy, 0); make_multi_index(0, 0, 0, 0, CShuffleNRepeatPerShuffle, 0);
constexpr auto nrepeat_backward_step = constexpr auto nrepeat_backward_step =
make_multi_index(0, 0, 0, 0, -NRepeatPerShuffle_CCopy, 0); make_multi_index(0, 0, 0, 0, -CShuffleNRepeatPerShuffle, 0);
static_for<0, MRepeat, MRepeatPerShuffle_CCopy>{}([&](auto mrepeat_iter) { static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mrepeat_iter) {
constexpr auto mrepeat = mrepeat_iter; constexpr auto mrepeat = mrepeat_iter;
static_for<0, NRepeat, NRepeatPerShuffle_CCopy>{}([&](auto nrepeat_iter) { static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nrepeat_iter) {
constexpr bool nrepeat_forward_sweep = constexpr bool nrepeat_forward_sweep =
(mrepeat % (2 * MRepeatPerShuffle_CCopy) == 0); (mrepeat % (2 * CShuffleMRepeatPerShuffle) == 0);
constexpr index_t nrepeat_value = constexpr index_t nrepeat_value =
nrepeat_forward_sweep ? nrepeat_iter nrepeat_forward_sweep
: (NRepeat - nrepeat_iter - NRepeatPerShuffle_CCopy); ? nrepeat_iter
: (NRepeat - nrepeat_iter - CShuffleNRepeatPerShuffle);
constexpr auto nrepeat = Number<nrepeat_value>{}; constexpr auto nrepeat = Number<nrepeat_value>{};
...@@ -792,7 +745,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 ...@@ -792,7 +745,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
// move on nrepeat dimension // move on nrepeat dimension
if constexpr(nrepeat_forward_sweep && if constexpr(nrepeat_forward_sweep &&
(nrepeat < NRepeat - NRepeatPerShuffle_CCopy)) (nrepeat < NRepeat - CShuffleNRepeatPerShuffle))
{ {
c_block_copy_lds_to_global.MoveSrc1SliceWindow( c_block_copy_lds_to_global.MoveSrc1SliceWindow(
c0_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl, c0_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
...@@ -823,7 +776,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 ...@@ -823,7 +776,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
}); });
// move on mrepeat dimension // move on mrepeat dimension
if constexpr(mrepeat < MRepeat - MRepeatPerShuffle_CCopy) if constexpr(mrepeat < MRepeat - CShuffleMRepeatPerShuffle)
{ {
c_block_copy_lds_to_global.MoveSrc1SliceWindow( c_block_copy_lds_to_global.MoveSrc1SliceWindow(
c0_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl, c0_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
......
...@@ -18,45 +18,41 @@ namespace device { ...@@ -18,45 +18,41 @@ namespace device {
// out[N, Ho, Wo, K] = // out[N, Ho, Wo, K] =
// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) + residual[N, Ho, Wo, K] // activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) + residual[N, Ho, Wo, K]
template <typename InDataType, template <
typename WeiDataType, typename InDataType,
typename OutDataType, typename WeiDataType,
typename AccDataType, typename OutDataType,
typename InElementwiseOperation, typename AccDataType,
typename WeiElementwiseOperation, typename InElementwiseOperation,
typename OutElementwiseOperation, typename WeiElementwiseOperation,
ck::index_t BlockSize, typename OutElementwiseOperation,
ck::index_t MPerBlock, ck::index_t BlockSize,
ck::index_t NPerBlock, ck::index_t MPerBlock,
ck::index_t K0PerBlock, ck::index_t NPerBlock,
ck::index_t K1, ck::index_t K0PerBlock,
ck::index_t MPerXDL, ck::index_t K1,
ck::index_t NPerXDL, ck::index_t MPerXDL,
ck::index_t MXdlPerWave, ck::index_t NPerXDL,
ck::index_t NXdlPerWave, ck::index_t MXdlPerWave,
typename ABlockTransferThreadSliceLengths_K0_M_K1, ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1, typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim, ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector, ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_K1, ck::index_t ABlockTransferDstScalarPerVector_K1,
bool ABlockLdsAddExtraM, bool ABlockLdsAddExtraM,
typename BBlockTransferThreadSliceLengths_K0_N_K1, typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterLengths_K0_N_K1, typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferThreadClusterArrangeOrder, typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcAccessOrder, ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcVectorDim, ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferSrcScalarPerVector, ck::index_t BBlockTransferDstScalarPerVector_K1,
ck::index_t BBlockTransferDstScalarPerVector_K1, bool BBlockLdsAddExtraN,
bool BBlockLdsAddExtraN, index_t CShuffleMRepeatPerShuffle,
index_t MRepeatPerShuffle_CCopy, index_t CShuffleNRepeatPerShuffle,
index_t NRepeatPerShuffle_CCopy, typename CBlockTransferClusterLengths_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl,
index_t MRepeatThread_CCopy, index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
index_t MThread_CCopy,
index_t NRepeatThread_CCopy,
index_t NThread_CCopy,
index_t NScalarPerVector_CCopy>
struct struct
DeviceConv2dFwdXdl_Output_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K DeviceConv2dFwdXdl_Output_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvFwdBiasActivationAdd<InElementwiseOperation, : public DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
...@@ -257,7 +253,6 @@ struct ...@@ -257,7 +253,6 @@ struct
K1, K1,
MXdlPerWave, MXdlPerWave,
NXdlPerWave, NXdlPerWave,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1,
Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder,
Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder,
...@@ -266,7 +261,6 @@ struct ...@@ -266,7 +261,6 @@ struct
ABlockTransferDstScalarPerVector_K1, ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun, false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM, ABlockLdsAddExtraM,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterLengths_K0_N_K1,
Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder,
Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder,
...@@ -275,13 +269,10 @@ struct ...@@ -275,13 +269,10 @@ struct
BBlockTransferDstScalarPerVector_K1, BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun, false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN, BBlockLdsAddExtraN,
MRepeatPerShuffle_CCopy, CShuffleMRepeatPerShuffle,
NRepeatPerShuffle_CCopy, CShuffleNRepeatPerShuffle,
MRepeatThread_CCopy, CBlockTransferClusterLengths_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl,
MThread_CCopy, CBlockTransferScalarPerVector_NWaveNPerXdl>;
NRepeatThread_CCopy,
NThread_CCopy,
NScalarPerVector_CCopy>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
......
...@@ -17,45 +17,41 @@ namespace tensor_operation { ...@@ -17,45 +17,41 @@ namespace tensor_operation {
namespace device { namespace device {
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] // out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template <typename InDataType, template <
typename WeiDataType, typename InDataType,
typename OutDataType, typename WeiDataType,
typename AccDataType, typename OutDataType,
typename InElementwiseOperation, typename AccDataType,
typename WeiElementwiseOperation, typename InElementwiseOperation,
typename OutElementwiseOperation, typename WeiElementwiseOperation,
ck::index_t BlockSize, typename OutElementwiseOperation,
ck::index_t MPerBlock, ck::index_t BlockSize,
ck::index_t NPerBlock, ck::index_t MPerBlock,
ck::index_t K0PerBlock, ck::index_t NPerBlock,
ck::index_t K1, ck::index_t K0PerBlock,
ck::index_t MPerXdl, ck::index_t K1,
ck::index_t NPerXdl, ck::index_t MPerXdl,
ck::index_t MXdlPerWave, ck::index_t NPerXdl,
ck::index_t NXdlPerWave, ck::index_t MXdlPerWave,
typename ABlockTransferThreadSliceLengths_K0_M_K1, ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1, typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim, ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector, ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_K1, ck::index_t ABlockTransferDstScalarPerVector_K1,
bool ABlockLdsAddExtraM, bool ABlockLdsAddExtraM,
typename BBlockTransferThreadSliceLengths_K0_N_K1, typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterLengths_K0_N_K1, typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferThreadClusterArrangeOrder, typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcAccessOrder, ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcVectorDim, ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferSrcScalarPerVector, ck::index_t BBlockTransferDstScalarPerVector_K1,
ck::index_t BBlockTransferDstScalarPerVector_K1, bool BBlockLdsAddExtraN,
bool BBlockLdsAddExtraN, index_t CShuffleMRepeatPerShuffle,
index_t MRepeatPerShuffle_CCopy, index_t CShuffleNRepeatPerShuffle,
index_t NRepeatPerShuffle_CCopy, typename CBlockTransferClusterLengths_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl,
index_t MRepeatThread_CCopy, index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
index_t MThread_CCopy,
index_t NRepeatThread_CCopy,
index_t NThread_CCopy,
index_t NScalarPerVector_CCopy>
struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation> : public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
{ {
...@@ -238,7 +234,6 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N ...@@ -238,7 +234,6 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
K1, K1,
MXdlPerWave, MXdlPerWave,
NXdlPerWave, NXdlPerWave,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1,
Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder,
Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder,
...@@ -247,7 +242,6 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N ...@@ -247,7 +242,6 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
ABlockTransferDstScalarPerVector_K1, ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun, false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM, ABlockLdsAddExtraM,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterLengths_K0_N_K1,
Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder,
Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder,
...@@ -256,13 +250,10 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N ...@@ -256,13 +250,10 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
BBlockTransferDstScalarPerVector_K1, BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun, false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN, BBlockLdsAddExtraN,
MRepeatPerShuffle_CCopy, CShuffleMRepeatPerShuffle,
NRepeatPerShuffle_CCopy, CShuffleNRepeatPerShuffle,
MRepeatThread_CCopy, CBlockTransferClusterLengths_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl,
MThread_CCopy, CBlockTransferScalarPerVector_NWaveNPerXdl>;
NRepeatThread_CCopy,
NThread_CCopy,
NScalarPerVector_CCopy>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
......
...@@ -33,11 +33,11 @@ using OutElementOp = ck::tensor_operation::element_wise::PassThrough_v2; ...@@ -33,11 +33,11 @@ using OutElementOp = ck::tensor_operation::element_wise::PassThrough_v2;
using DeviceConvFwdInstance = ck::tensor_operation::device:: using DeviceConvFwdInstance = ck::tensor_operation::device::
DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
// clang-format off // clang-format off
// | InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| MRepeatPer| NRepeatPer| MRepeat| MThread| NRepeat| NThread| NScalarPer| // | InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| Shuffle| Shuffle| Thread| _CCopy| Thread| _CCopy| Vector| // | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat_MWaveMPerXdl| ScalarPerVector|
// | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | _CCopy| _CCopy| _CCopy| | _CCopy| | _CCopy| // | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat_NWaveNPerXdl| _NWaveNPerXdl|
// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, 1, 32, 1, 8, 8>; <InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>;
// clang-format on // clang-format on
template <typename TIn, template <typename TIn,
......
...@@ -33,11 +33,11 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd_v2; ...@@ -33,11 +33,11 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd_v2;
// clang-format off // clang-format off
using DeviceConvFwdInstance = ck::tensor_operation::device:: using DeviceConvFwdInstance = ck::tensor_operation::device::
DeviceConv2dFwdXdl_Output_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K DeviceConv2dFwdXdl_Output_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
// | InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| MRepeatPer| NRepeatPer| MRepeat| MThread| NRepeat| NThread| NScalarPer| // | InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| Shuffle| Shuffle| Thread| _CCopy| Thread| _CCopy| Vector| // | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat_MWaveMPerXdl| ScalarPerVector|
// | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | _CCopy| _CCopy| _CCopy| | _CCopy| | _CCopy| // | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat_NWaveNPerXdl| _NWaveNPerXdl|
// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, 1, 32, 1, 8, 8>; <InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>;
// clang-format on // clang-format on
template <typename TIn, template <typename TIn,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment