Commit 25e35b59 authored by Chao Liu's avatar Chao Liu
Browse files

rename, clean

parent 8a60a329
add_example_executable(example_gemm_bias_add_fastgelu_xdl_fp16 gemm_bias_add_fastgelu_xdl_fp16.cpp)
add_example_executable(example_gemm_bias_fastgelu_xdl_fp16 gemm_bias_fastgelu_xdl_fp16.cpp)
...@@ -39,7 +39,7 @@ endfunction(add_example_executable_no_testing EXAMPLE_NAME) ...@@ -39,7 +39,7 @@ endfunction(add_example_executable_no_testing EXAMPLE_NAME)
add_subdirectory(01_gemm) add_subdirectory(01_gemm)
add_subdirectory(02_gemm_alpha_beta) add_subdirectory(02_gemm_alpha_beta)
add_subdirectory(03_gemm_bias_relu) add_subdirectory(03_gemm_bias_relu)
add_subdirectory(03_gemm_bias_fastgelu) add_subdirectory(03_gemm_bias_add_fastgelu)
add_subdirectory(04_gemm_bias_relu_add) add_subdirectory(04_gemm_bias_relu_add)
add_subdirectory(06_conv2d_fwd_bias_relu) add_subdirectory(06_conv2d_fwd_bias_relu)
add_subdirectory(07_conv2d_fwd_bias_relu_add) add_subdirectory(07_conv2d_fwd_bias_relu_add)
......
...@@ -33,7 +33,7 @@ struct DeviceGemmMultipleD : public BaseOperator ...@@ -33,7 +33,7 @@ struct DeviceGemmMultipleD : public BaseOperator
ck::index_t StrideE, ck::index_t StrideE,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op); CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
......
...@@ -489,7 +489,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType: ...@@ -489,7 +489,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideE)}, e_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideE)},
e_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_)}, block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
cde_element_op_{cde_element_op} cde_element_op_{cde_element_op}
...@@ -500,7 +500,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType: ...@@ -500,7 +500,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
block_2_etile_map_)) block_2_etile_map_))
{ {
e_grid_desc_mblock_mperblock_nblock_nperblock_ = e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_); e_grid_desc_m_n_);
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
...@@ -512,7 +512,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType: ...@@ -512,7 +512,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]); DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) = ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
d_grid_desc_m_n); d_grid_desc_m_n);
}); });
} }
...@@ -538,13 +538,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType: ...@@ -538,13 +538,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
StaticallyIndexedArray< StaticallyIndexedArray<
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor> NumDTensor>
ds_grid_desc_mblock_mperblock_nblock_nperblock_; ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different
// type from E
EGridDesc_M_N e_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_; e_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_etile_map_; typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_; CDEElementwiseOperation cde_element_op_;
...@@ -625,10 +626,10 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType: ...@@ -625,10 +626,10 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
ck::StaticallyIndexedArray< ck::StaticallyIndexedArray<
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>, NumDTensor>,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2ETileMap,
has_main_loop>; has_main_loop>;
return launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
...@@ -782,7 +783,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType: ...@@ -782,7 +783,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
return str.str(); return str.str();
} }
}; // namespace device };
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -14,18 +14,24 @@ ...@@ -14,18 +14,24 @@
namespace ck { namespace ck {
// input : A[AK0, M, AK1]
// input : B[AK0, N, AK1]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
template <typename FloatAB, template <typename FloatAB,
typename FloatGemmAcc, typename FloatGemmAcc,
typename FloatCShuffle, typename FloatCShuffle,
typename DsDataType, typename DsDataType,
typename FloatC, typename FloatE,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CDEElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum EGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N, typename EGridDesc_M_N,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
...@@ -55,8 +61,8 @@ template <typename FloatAB, ...@@ -55,8 +61,8 @@ template <typename FloatAB,
index_t BBlockLdsExtraN, index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched> LoopScheduler LoopSched>
struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
{ {
...@@ -153,12 +159,12 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle ...@@ -153,12 +159,12 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
} }
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap> template <typename Block2ETileMap>
__host__ __device__ static constexpr bool __host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDesc_M_N& c_grid_desc_m_n, const EGridDesc_M_N& e_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map) const Block2ETileMap& block_2_etile_map)
{ {
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0, (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
...@@ -168,7 +174,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle ...@@ -168,7 +174,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
return false; return false;
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
...@@ -182,7 +188,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle ...@@ -182,7 +188,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
return false; return false;
} }
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
{ {
return false; return false;
} }
...@@ -199,58 +205,59 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle ...@@ -199,58 +205,59 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
{ {
const auto M = c_grid_desc_m_n.GetLength(I0); const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1); const auto N = e_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock; const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock; const auto NBlock = N / NPerBlock;
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c_grid_desc_m_n, e_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})), make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))), make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return c_grid_desc_mblock_mperblock_nblock_nperblock; return e_grid_desc_mblock_mperblock_nblock_nperblock;
} }
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{ {
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>( return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
c_grid_desc_m_n); e_grid_desc_m_n);
} }
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>; MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using DefaultBlock2CTileMap = using DefaultBlock2ETileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
using DsGridPointer = decltype(MakeDsGridPointer()); using DsGridPointer = decltype(MakeDsGridPointer());
template <bool HasMainKBlockLoop, typename Block2CTileMap> template <bool HasMainKBlockLoop, typename Block2ETileMap>
__device__ static void __device__ static void
Run(const FloatAB* __restrict__ p_a_grid, Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
DsGridPointer p_ds_grid, DsGridPointer p_ds_grid,
FloatC* __restrict__ p_c_grid, FloatE* __restrict__ p_e_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op, const CDEElementwiseOperation& cde_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const StaticallyIndexedArray<CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, const StaticallyIndexedArray<EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>& NumDTensor>&
ds_grid_desc_mblock_mperblock_nblock_nperblock, // FIXME: use tuple ds_grid_desc_mblock_mperblock_nblock_nperblock, // FIXME: Ds desc may be of different
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& // type from E
c_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
const Block2CTileMap& block_2_ctile_map) e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap& block_2_etile_map)
{ {
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
...@@ -266,17 +273,17 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle ...@@ -266,17 +273,17 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
}, },
Number<NumDTensor>{}); Number<NumDTensor>{});
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// divide block work by [M, N] // divide block work by [M, N]
const auto block_work_idx = const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex( if(!block_2_etile_map.ValidCTileIndex(
block_work_idx, block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{ {
return; return;
} }
...@@ -537,27 +544,27 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle ...@@ -537,27 +544,27 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
// shuffle: blockwise copy C from LDS to global // shuffle: blockwise copy C from LDS to global
#if 1 #if 1
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r3< auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v6r3<
ThisThreadBlock, // ThreadGroup ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation, CDEElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp, EGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1, Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1, 1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename Src0Data, FloatCShuffle, // typename Src0Data,
remove_cvref_t<decltype(DsDataType{}[I0])>, // typename Src1Data, remove_cvref_t<decltype(DsDataType{}[I0])>, // typename Src1Data,
remove_cvref_t<decltype(DsDataType{}[I1])>, // typename Src2Data, remove_cvref_t<decltype(DsDataType{}[I1])>, // typename Src2Data,
FloatC, // typename DstData, FloatE, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I0]), decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I0]),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I1]), decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I1]),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder, Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim, 3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrc0ResetCoordinateAfterRun, true, // bool ThreadTransferSrc0ResetCoordinateAfterRun,
false, // bool ThreadTransferSrc1ResetCoordinateAfterRun, false, // bool ThreadTransferSrc1ResetCoordinateAfterRun,
false, // bool ThreadTransferSrc2ResetCoordinateAfterRun, false, // bool ThreadTransferSrc2ResetCoordinateAfterRun,
...@@ -568,37 +575,37 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle ...@@ -568,37 +575,37 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
ds_grid_desc_mblock_mperblock_nblock_nperblock[I1], ds_grid_desc_mblock_mperblock_nblock_nperblock[I1],
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_element_op}; cde_element_op};
#else #else
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation, CDEElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp, EGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1, Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1, 1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename Src0Data, FloatCShuffle, // typename Src0Data,
FloatC, // typename DstData, FloatE, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder, Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim, 3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrc0ResetCoordinateAfterRun, true, // bool ThreadTransferSrc0ResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun> false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0), make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_element_op}; cde_element_op};
#endif #endif
// space filling curve for threadwise C in VGPR // space filling curve for threadwise C in VGPR before shuffle
constexpr auto sfc_c_vgpr = constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>, SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
...@@ -611,8 +618,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle ...@@ -611,8 +618,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
M4, M4,
1>>{}; 1>>{};
// space filling curve for shuffled blockwise C in global mem // space filling curve for shuffled blockwise C/D/E
constexpr auto sfc_c_global = constexpr auto sfc_cde_block =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>, SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>, Sequence<0, 2, 1, 3>,
Sequence<1, Sequence<1,
...@@ -622,7 +629,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle ...@@ -622,7 +629,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) { static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS // make sure it's safe to write to LDS
...@@ -640,37 +647,37 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle ...@@ -640,37 +647,37 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
// each block copy its data from LDS to global // each block copy its data from LDS to global
#if 1 #if 1
c_shuffle_block_copy_lds_to_global.Run( cde_block_copy_lds_and_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf, c_shuffle_block_buf,
ds_grid_desc_mblock_mperblock_nblock_nperblock[I0], ds_grid_desc_mblock_mperblock_nblock_nperblock[I0],
ds_grid_buf[I0], ds_grid_buf[I0],
ds_grid_desc_mblock_mperblock_nblock_nperblock[I1], ds_grid_desc_mblock_mperblock_nblock_nperblock[I1],
ds_grid_buf[I1], ds_grid_buf[I1],
c_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf); e_grid_buf);
#else #else
c_shuffle_block_copy_lds_to_global.Run( cde_block_copy_lds_and_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf, c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf); e_grid_buf);
#endif #endif
if constexpr(access_id < num_access - 1) if constexpr(access_id < num_access - 1)
{ {
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); constexpr auto c_global_step = sfc_cde_block.GetForwardStep(access_id);
// move on Ds // move on Ds
c_shuffle_block_copy_lds_to_global.MoveSrc1SliceWindow( cde_block_copy_lds_and_global.MoveSrc1SliceWindow(
ds_grid_desc_mblock_mperblock_nblock_nperblock[I0], c_global_step); ds_grid_desc_mblock_mperblock_nblock_nperblock[I0], c_global_step);
c_shuffle_block_copy_lds_to_global.MoveSrc2SliceWindow( cde_block_copy_lds_and_global.MoveSrc2SliceWindow(
ds_grid_desc_mblock_mperblock_nblock_nperblock[I1], c_global_step); ds_grid_desc_mblock_mperblock_nblock_nperblock[I1], c_global_step);
// move on C // move on E
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( cde_block_copy_lds_and_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); e_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
} }
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment