Commit 523a4045 authored by Chao Liu's avatar Chao Liu
Browse files

clang-format

parent 2d31e921
...@@ -17,43 +17,44 @@ namespace ck { ...@@ -17,43 +17,44 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <typename ADataType, template <
typename BDataType, typename ADataType,
typename CDataType, typename BDataType,
typename AccDataType, typename CDataType,
typename ALayout, typename AccDataType,
typename BLayout, typename ALayout,
typename CLayout, typename BLayout,
typename AElementwiseOperation, typename CLayout,
typename BElementwiseOperation, typename AElementwiseOperation,
typename CElementwiseOperation, typename BElementwiseOperation,
ck::index_t BlockSize, typename CElementwiseOperation,
ck::index_t MPerBlock, ck::index_t BlockSize,
ck::index_t NPerBlock, ck::index_t MPerBlock,
ck::index_t K0PerBlock, ck::index_t NPerBlock,
ck::index_t K1, ck::index_t K0PerBlock,
ck::index_t MPerXDL, ck::index_t K1,
ck::index_t NPerXDL, ck::index_t MPerXDL,
ck::index_t MXdlPerWave, ck::index_t NPerXDL,
ck::index_t NXdlPerWave, ck::index_t MXdlPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1, ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferThreadClusterArrangeOrder,
ck::index_t ABlockTransferSrcVectorDim, typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcScalarPerVector, ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferDstScalarPerVector_K1, ck::index_t ABlockTransferSrcScalarPerVector,
bool ABlockLdsAddExtraM, ck::index_t ABlockTransferDstScalarPerVector_K1,
typename BBlockTransferThreadClusterLengths_K0_N_K1, bool ABlockLdsAddExtraM,
typename BBlockTransferThreadClusterArrangeOrder, typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferSrcAccessOrder, typename BBlockTransferThreadClusterArrangeOrder,
ck::index_t BBlockTransferSrcVectorDim, typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcScalarPerVector, ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferDstScalarPerVector_K1, ck::index_t BBlockTransferSrcScalarPerVector,
bool BBlockLdsAddExtraN, ck::index_t BBlockTransferDstScalarPerVector_K1,
index_t CShuffleMXdlPerWavePerShuffle, bool BBlockLdsAddExtraN,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleMXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, index_t CShuffleNXdlPerWavePerShuffle,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl> typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
struct DeviceGemmShuffleXdl struct DeviceGemmShuffleXdl
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{ {
...@@ -175,7 +176,6 @@ struct DeviceGemmShuffleXdl ...@@ -175,7 +176,6 @@ struct DeviceGemmShuffleXdl
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
CBlockTransferScalarPerVector_NWaveNPerXdl>; CBlockTransferScalarPerVector_NWaveNPerXdl>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -215,8 +215,9 @@ struct DeviceGemmShuffleXdl ...@@ -215,8 +215,9 @@ struct DeviceGemmShuffleXdl
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
{ {
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( GridwiseGemm::
c_grid_desc_m_n_); MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01); block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
} }
...@@ -295,21 +296,22 @@ struct DeviceGemmShuffleXdl ...@@ -295,21 +296,22 @@ struct DeviceGemmShuffleXdl
remove_reference_t<typename GridwiseGemm::Block2CTileMap>, remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
true>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(
nrepeat, kernel,
dim3(grid_size), nrepeat,
dim3(BlockSize), dim3(grid_size),
0, dim3(BlockSize),
arg.p_a_grid_, 0,
arg.p_b_grid_, arg.p_a_grid_,
arg.p_c_grid_, arg.p_b_grid_,
arg.a_grid_desc_k0_m_k1_, arg.p_c_grid_,
arg.b_grid_desc_k0_n_k1_, arg.a_grid_desc_k0_m_k1_,
arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, arg.b_grid_desc_k0_n_k1_,
arg.a_element_op_, arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
arg.b_element_op_, arg.a_element_op_,
arg.c_element_op_, arg.b_element_op_,
arg.block_2_ctile_map_); arg.c_element_op_,
arg.block_2_ctile_map_);
} }
else else
{ {
...@@ -328,21 +330,22 @@ struct DeviceGemmShuffleXdl ...@@ -328,21 +330,22 @@ struct DeviceGemmShuffleXdl
remove_reference_t<typename GridwiseGemm::Block2CTileMap>, remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
false>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(
nrepeat, kernel,
dim3(grid_size), nrepeat,
dim3(BlockSize), dim3(grid_size),
0, dim3(BlockSize),
arg.p_a_grid_, 0,
arg.p_b_grid_, arg.p_a_grid_,
arg.p_c_grid_, arg.p_b_grid_,
arg.a_grid_desc_k0_m_k1_, arg.p_c_grid_,
arg.b_grid_desc_k0_n_k1_, arg.a_grid_desc_k0_m_k1_,
arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, arg.b_grid_desc_k0_n_k1_,
arg.a_element_op_, arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
arg.b_element_op_, arg.a_element_op_,
arg.c_element_op_, arg.b_element_op_,
arg.block_2_ctile_map_); arg.c_element_op_,
arg.block_2_ctile_map_);
} }
return ave_time; return ave_time;
...@@ -355,7 +358,6 @@ struct DeviceGemmShuffleXdl ...@@ -355,7 +358,6 @@ struct DeviceGemmShuffleXdl
} }
}; };
static constexpr bool IsValidCompilationParameter() static constexpr bool IsValidCompilationParameter()
{ {
// TODO: properly implement this check // TODO: properly implement this check
...@@ -438,7 +440,6 @@ struct DeviceGemmShuffleXdl ...@@ -438,7 +440,6 @@ struct DeviceGemmShuffleXdl
c_element_op); c_element_op);
} }
// polymorphic // polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment