Commit 702c3379 authored by root

fixed clang format errors

parent 599497b0
@@ -39,7 +39,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    static constexpr index_t WaveSize = get_warp_size();
    static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
@@ -57,7 +57,7 @@ struct ThreadGroupTensorSliceTransfer_v6r1
            is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

-        //static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
+        // static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
        //               "wrong! ThreadGroup::GetNumOfThread() too small");

        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
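Aside: the assertion in this hunk requires SliceLengths to equal ThreadClusterLengths times the per-thread slice lengths, so the per-thread sub-slices tiled by the thread cluster exactly cover the copy window. A minimal standalone sketch of that condition, using made-up sizes and plain int arrays rather than the CK Sequence/Number types:

// Hypothetical sizes for illustration only (not taken from the diff).
#include <cstdio>

int main()
{
    constexpr int SliceLengths[2]         = {8, 64}; // assumed copy window
    constexpr int ThreadClusterLengths[2] = {4, 32}; // assumed thread layout
    constexpr int thread_slice_lengths[2] = {SliceLengths[0] / ThreadClusterLengths[0],
                                             SliceLengths[1] / ThreadClusterLengths[1]};

    // Same covering condition as the static_assert above, on plain ints.
    static_assert(thread_slice_lengths[0] * ThreadClusterLengths[0] == SliceLengths[0] &&
                      thread_slice_lengths[1] * ThreadClusterLengths[1] == SliceLengths[1],
                  "wrong! threads should be mapped to cover entire slicing window");

    std::printf("each thread copies a %d x %d sub-slice\n",
                thread_slice_lengths[0],
                thread_slice_lengths[1]);
    return 0;
}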
@@ -10,7 +10,6 @@
#include "gridwise_gemm_xdl_waveletmodel_cshuffle.hpp"
#include "gemm_specialization.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
@@ -438,7 +437,7 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle
    {
        using Argument = DeviceOp::Argument;

-        float Run(const Argument& arg, const StreamConfig& stream_config= StreamConfig{})
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
#if 0
            {
@@ -485,11 +484,11 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle
                    typename GridwiseGemm::DefaultBlock2CTileMap,
                    true>;

                ave_time =
                    launch_and_time_kernel(stream_config,
                                           kernel,
                                           dim3(grid_size),
                                           dim3(TileLoadThreadGroupSize + TileMathThreadGroupSize),
                                           0,
                                           arg.p_a_grid_,
                                           arg.p_b_grid_,
@@ -516,8 +515,8 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle
                    typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    typename GridwiseGemm::DefaultBlock2CTileMap,
                    false>;

                ave_time =
                    launch_and_time_kernel(stream_config,
                                           kernel,
                                           dim3(grid_size),
                                           dim3(TileLoadThreadGroupSize + TileMathThreadGroupSize),
@@ -539,7 +538,7 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle
        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
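Aside: both launches above size the thread block as TileLoadThreadGroupSize + TileMathThreadGroupSize because one block hosts both waves of the wavelet model, a load wave and a math wave. A hypothetical sketch (plain C++, made-up group sizes; the real partition is defined by the TileLoadThreadGroup/TileMathThreadGroup types inside the gridwise kernel, which this diff does not show) of splitting such a block by thread id:

// Hypothetical illustration only: counts how many threads of a combined block
// would belong to each wave if the split were done by thread id.
#include <cstdio>

int main()
{
    constexpr int TileLoadThreadGroupSize = 256; // assumed
    constexpr int TileMathThreadGroupSize = 256; // assumed
    constexpr int BlockSize = TileLoadThreadGroupSize + TileMathThreadGroupSize;

    int load_threads = 0;
    int math_threads = 0;
    for(int tid = 0; tid < BlockSize; ++tid)
    {
        if(tid < TileLoadThreadGroupSize)
            ++load_threads; // would run the load-wave pipeline
        else
            ++math_threads; // would run the math-wave pipeline
    }

    std::printf("block of %d threads: %d load-wave, %d math-wave\n",
                BlockSize,
                load_threads,
                math_threads);
    return 0;
}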
@@ -7,22 +7,22 @@ namespace ck {
template <typename TileLoadThreadGroup, index_t NumGemmKPrefetchStage>
struct GridwiseGemmLoadWave;

-//1-stage prefetch
-template<typename TileLoadThreadGroup>
+// 1-stage prefetch
+template <typename TileLoadThreadGroup>
struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
{
    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */)
    {
        // TODO: improve applicability
        return true;
    }

    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return num_loop > 1;
    }

    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
@@ -36,43 +36,43 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
              typename BBlockBuffer,
              typename BBlockTransferStep>
    static __device__ void RunLoadWavePipeline(const AGridDesc& a_grid_desc,
                                               const ABlockDesc& a_block_desc,
                                               ABlockTransfer& a_blockwise_copy,
                                               const AGridBuffer& a_grid_buf,
                                               ABlockBuffer& a_block_buf,
                                               const ABlockTransferStep& a_block_copy_step,
                                               const BGridDesc& b_grid_desc,
                                               const BBlockDesc& b_block_desc,
                                               BBlockTransfer& b_blockwise_copy,
                                               const BGridBuffer& b_grid_buf,
                                               BBlockBuffer& b_block_buf,
                                               const BBlockTransferStep& b_block_copy_step,
                                               index_t num_loop)
    {
        // global read 0
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

-        //move to 1
+        // move to 1
        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

-        //LDS write 0
+        // LDS write 0
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        if constexpr(HasMainLoop)
        {
-            index_t i=0;
+            index_t i = 0;
            do
            {
-                //sync for Load threads()
+                // sync for Load threads()
                block_sync_lds();

                // global read i + 1
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                // move to i + 2
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

@@ -81,10 +81,9 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
                // sync with math threads()
                block_sync_lds();

-                //LDS write i+1
+                // LDS write i+1
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                ++i;
            } while(i < (num_loop - 1));
@@ -92,12 +91,10 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
        // tail
        {
            block_sync_lds();

            // GEMM num_loop
        }
    }
};
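Aside: a hypothetical sketch (plain C++, made-up tile count, not the CK API) of the 1-stage prefetch schedule in RunLoadWavePipeline above. Tile 0 is read and written in the prologue, the main loop fetches tile i + 1 from global memory while tile i is consumed out of LDS, and the tail only synchronizes before the last GEMM, which is why CalculateHasMainLoop returns num_loop > 1.

// Prints the prefetch schedule for an assumed number of K tiles.
#include <cstdio>

int main()
{
    const int num_loop = 4; // assumed number of K tiles

    std::printf("prologue: global read 0, move window to 1, LDS write 0\n");

    if(num_loop > 1) // HasMainLoop
    {
        int i = 0;
        do
        {
            std::printf("iteration %d: global read %d, LDS write %d\n", i, i + 1, i + 1);
            ++i;
        } while(i < (num_loop - 1));
    }

    std::printf("tail: sync, math wave consumes tile %d\n", num_loop - 1);
    return 0;
}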
@@ -105,29 +102,26 @@ template <typename TileMathThreadGroup, index_t NumGemmKPrefetchStage>
struct GridwiseGemmMathWave;

// 1- stage prefetch
template <typename TileMathThreadGroup>
struct GridwiseGemmMathWave<TileMathThreadGroup, 1>
{
-    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */)
-    {
-        return true;
-    }
+    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }

    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return num_loop > 1;
    }

    template <bool HasMainLoop,
              typename ABlockBuffer,
              typename BBlockBuffer,
              typename BlockwiseGemm,
              typename CThreadBuffer>
    static __device__ void RunMathWavePipeline(ABlockBuffer& a_block_buf,
                                               BBlockBuffer& b_block_buf,
                                               const BlockwiseGemm& block_gemm,
                                               CThreadBuffer& c_thread_buf,
                                               index_t num_loop)
    {
        // Initialize C
        c_thread_buf.Clear();
@@ -155,7 +149,6 @@ struct GridwiseGemmMathWave<TileMathThreadGroup, 1>
            // GEMM num_loop - 1
            block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        }
    }
};
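Aside: a hypothetical sketch (plain C++, made-up sizes, not the CK API) of what the math wave accumulates. Each block_gemm.Run consumes one K tile delivered by the load wave, so after num_loop tiles c_thread_buf holds the reduction over the whole K dimension; the "GEMM" here is shrunk to a scalar dot product for brevity.

#include <cstdio>

int main()
{
    constexpr int num_loop = 4; // assumed number of K tiles
    constexpr int KPerTile = 8; // assumed K elements per tile
    constexpr int K        = num_loop * KPerTile;

    double a[K];
    double b[K];
    for(int k = 0; k < K; ++k)
    {
        a[k] = 1.0;
        b[k] = 0.5;
    }

    double c = 0.0; // analogue of c_thread_buf.Clear()
    for(int i = 0; i < num_loop; ++i)     // one block_gemm.Run per LDS tile
        for(int k = 0; k < KPerTile; ++k) // partial reduction over this tile
            c += a[i * KPerTile + k] * b[i * KPerTile + k];

    std::printf("c = %.1f (expected %.1f)\n", c, 0.5 * K);
    return 0;
}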
@@ -249,8 +249,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
        }();

        using BlockwiseGemm =
-            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize
-                FloatAB,
+            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize FloatAB,
                FloatAcc,
                decltype(a_k0_m_k1_block_desc),
                decltype(b_k0_n_k1_block_desc),