Commit 702c3379 authored by root

Fixed clang-format errors

parent 599497b0
@@ -39,7 +39,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    static constexpr index_t WaveSize = get_warp_size();
    static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
...
@@ -57,7 +57,7 @@ struct ThreadGroupTensorSliceTransfer_v6r1
            is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        // static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
        //               "wrong! ThreadGroup::GetNumOfThread() too small");

        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
...
@@ -10,7 +10,6 @@
#include "gridwise_gemm_xdl_waveletmodel_cshuffle.hpp"
#include "gemm_specialization.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
@@ -438,7 +437,7 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
#if 0
            {
@@ -485,11 +484,11 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle
                    typename GridwiseGemm::DefaultBlock2CTileMap,
                    true>;

                ave_time =
                    launch_and_time_kernel(stream_config,
                                           kernel,
                                           dim3(grid_size),
                                           dim3(TileLoadThreadGroupSize + TileMathThreadGroupSize),
                                           0,
                                           arg.p_a_grid_,
                                           arg.p_b_grid_,
@@ -516,8 +515,8 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle
                    typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    typename GridwiseGemm::DefaultBlock2CTileMap,
                    false>;

                ave_time =
                    launch_and_time_kernel(stream_config,
                                           kernel,
                                           dim3(grid_size),
                                           dim3(TileLoadThreadGroupSize + TileMathThreadGroupSize),
@@ -539,7 +538,7 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle
        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
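
The Invoker hunks above launch one fused kernel whose thread block hosts both wave groups, so the block dimension passed to launch_and_time_kernel is the sum of the two group sizes. A minimal standalone sketch of that launch-shape arithmetic (plain C++, not CK's API; the 256/256 group sizes and the grid size are hypothetical example values, not taken from this commit):

#include <cstdio>

int main()
{
    const unsigned tile_load_group_size = 256; // hypothetical example value
    const unsigned tile_math_group_size = 256; // hypothetical example value
    const unsigned grid_size            = 120; // hypothetical number of C tiles

    // mirrors dim3(grid_size), dim3(TileLoadThreadGroupSize + TileMathThreadGroupSize)
    const unsigned block_dim = tile_load_group_size + tile_math_group_size;
    std::printf("launch <<<%u, %u>>>\n", grid_size, block_dim);
    return 0;
}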
...
@@ -7,22 +7,22 @@ namespace ck {
template <typename TileLoadThreadGroup, index_t NumGemmKPrefetchStage>
struct GridwiseGemmLoadWave;

// 1-stage prefetch
template <typename TileLoadThreadGroup>
struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
{
    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */)
    {
        // TODO: improve applicability
        return true;
    }

    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return num_loop > 1;
    }

    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
@@ -36,43 +36,43 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
              typename BBlockBuffer,
              typename BBlockTransferStep>
    static __device__ void RunLoadWavePipeline(const AGridDesc& a_grid_desc,
                                               const ABlockDesc& a_block_desc,
                                               ABlockTransfer& a_blockwise_copy,
                                               const AGridBuffer& a_grid_buf,
                                               ABlockBuffer& a_block_buf,
                                               const ABlockTransferStep& a_block_copy_step,
                                               const BGridDesc& b_grid_desc,
                                               const BBlockDesc& b_block_desc,
                                               BBlockTransfer& b_blockwise_copy,
                                               const BGridBuffer& b_grid_buf,
                                               BBlockBuffer& b_block_buf,
                                               const BBlockTransferStep& b_block_copy_step,
                                               index_t num_loop)
    {
        // global read 0
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        // move to 1
        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // LDS write 0
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        if constexpr(HasMainLoop)
        {
            index_t i = 0;

            do
            {
                // sync for Load threads()
                block_sync_lds();

                // global read i + 1
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                // move to i + 2
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
@@ -81,10 +81,9 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
                // sync with math threads()
                block_sync_lds();

                // LDS write i+1
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                ++i;
            } while(i < (num_loop - 1));
@@ -92,12 +91,10 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
        // tail
        {
            block_sync_lds();

            // GEMM num_loop
        }
    }
};
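
Taken together, the hunks above give the producer schedule of the load wave: one prefetch before the loop, then each main-loop iteration issues the next global read and the matching LDS write, separated by two LDS barriers. A self-contained sketch of that schedule (plain C++ with stubbed copy and sync calls standing in for the CK objects; not the real device code):

#include <cstdio>

// stand-ins for the blockwise copy objects and the LDS barrier used above
static void global_read(int i) { std::printf("global read tile %d\n", i); }
static void lds_write(int i) { std::printf("LDS write tile %d\n", i); }
static void block_sync_lds() { std::printf("block_sync_lds\n"); }

static void run_load_wave(int num_loop)
{
    global_read(0); // global read 0, then move the source window to tile 1
    lds_write(0);   // LDS write 0

    for(int i = 0; i < num_loop - 1; ++i) // main loop, as in the do/while above
    {
        block_sync_lds();   // sync for Load threads()
        global_read(i + 1); // global read i + 1, move window to i + 2
        block_sync_lds();   // sync with math threads()
        lds_write(i + 1);   // LDS write i + 1
    }

    block_sync_lds(); // tail
}

int main() { run_load_wave(4); }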
@@ -105,29 +102,26 @@ template <typename TileMathThreadGroup, index_t NumGemmKPrefetchStage>
struct GridwiseGemmMathWave;

// 1- stage prefetch
template <typename TileMathThreadGroup>
struct GridwiseGemmMathWave<TileMathThreadGroup, 1>
{
    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }

    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return num_loop > 1;
    }

    template <bool HasMainLoop,
              typename ABlockBuffer,
              typename BBlockBuffer,
              typename BlockwiseGemm,
              typename CThreadBuffer>
    static __device__ void RunMathWavePipeline(ABlockBuffer& a_block_buf,
                                               BBlockBuffer& b_block_buf,
                                               const BlockwiseGemm& block_gemm,
                                               CThreadBuffer& c_thread_buf,
                                               index_t num_loop)
    {
        // Initialize C
        c_thread_buf.Clear();
@@ -155,7 +149,6 @@ struct GridwiseGemmMathWave<TileMathThreadGroup, 1>
            // GEMM num_loop - 1
            block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        }
    }
};
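
The main loop of RunMathWavePipeline is collapsed in this diff; only the initial c_thread_buf.Clear() and the tail GEMM are visible. The sketch below is therefore partly an assumption about the elided middle part: it shows the natural consumer counterpart to the load-wave sketch earlier, alternating LDS barriers with block GEMMs on the tiles the load wave has staged (plain C++ stubs, not the CK device code):

#include <cstdio>

static void block_gemm(int i) { std::printf("block GEMM on tile %d\n", i); }
static void block_sync_lds() { std::printf("block_sync_lds\n"); }

static void run_math_wave(int num_loop)
{
    // Initialize C: c_thread_buf.Clear();
    for(int i = 0; i < num_loop - 1; ++i) // assumed main loop (elided in the hunk)
    {
        block_sync_lds(); // wait until tile i is in LDS
        block_gemm(i);
        block_sync_lds(); // allow the load wave to overwrite tile i
    }

    block_sync_lds();             // assumed tail barrier (not visible in the hunk)
    block_gemm(num_loop - 1);     // GEMM num_loop - 1 (the tail shown above)
}

int main() { run_math_wave(4); }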
...
@@ -11,9 +11,9 @@
namespace ck {

template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename AGridDesc_AK0_M_AK1,
@@ -25,7 +25,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_WAVELET_MAX_THREAD_PER_BLOCK, CK_WAVELET_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_xdl_waveletmodel_cshuffle(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
@@ -40,18 +40,18 @@ __global__ void
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                  p_b_grid,
                                                  p_c_grid,
                                                  p_shared,
                                                  a_element_op,
                                                  b_element_op,
                                                  c_element_op,
                                                  a_grid_desc_ak0_m_ak1,
                                                  b_grid_desc_bk0_n_bk1,
                                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  block_2_ctile_map);
#else
    ignore = p_a_grid;
    ignore = p_b_grid;
@@ -121,64 +121,57 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    // K1 should be Number<...>
    static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
    static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
    static constexpr auto AK1 = Number<AK1Value>{};
    static constexpr auto BK1 = Number<BK1Value>{};

    struct TileLoadThreadGroup
    {
        __device__ static constexpr index_t GetNumOfThread() { return TileLoadThreadGroupSize; }

        __device__ static constexpr bool IsBelong()
        {
            return (get_thread_local_1d_id() >= TileLoadThreadGroupSize);
        }

        __device__ static index_t GetThreadId()
        {
            return get_thread_local_1d_id() - TileMathThreadGroupSize;
        }
    };

    struct TileMathThreadGroup
    {
        __device__ static constexpr index_t GetNumOfThread() { return TileMathThreadGroupSize; }

        __device__ static constexpr bool IsBelong()
        {
            return get_thread_local_1d_id() < TileMathThreadGroupSize;
        }

        __device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }
    };
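
The two thread groups above partition one thread block by thread id: the math group takes ids [0, TileMathThreadGroupSize) and the load group takes the ids above it, with GetThreadId() re-basing the load group to zero. The load group's IsBelong() compares against TileLoadThreadGroupSize, which selects the same range whenever the two group sizes are equal. A standalone sketch of that partition (plain C++; the 256/256 sizes are hypothetical example values):

#include <cstdio>

int main()
{
    const int math_size = 256; // hypothetical TileMathThreadGroupSize
    const int load_size = 256; // hypothetical TileLoadThreadGroupSize

    for(int tid : {0, 255, 256, 511})
    {
        const bool is_math        = tid < math_size;                       // TileMathThreadGroup::IsBelong()
        const bool is_load        = !is_math && tid < math_size + load_size;
        const int  group_local_id = is_math ? tid : tid - math_size;       // GetThreadId()

        std::printf("tid %3d -> %s group, local id %d\n",
                    tid,
                    is_math ? "math" : (is_load ? "load" : "none"),
                    group_local_id);
    }
    return 0;
}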
    using CShuffleBlockTransferThreadGroup = ThisThreadBlock<TileMathThreadGroupSize>;

    // load and math+store Wave pipelines.
    // TODO: build pipelines blocks scheduling parallel tasks
    using GridwiseGemmLoad = GridwiseGemmLoadWave<TileLoadThreadGroup, NumGemmKPrefetchStage>;
    using GridwiseGemmMath = GridwiseGemmMathWave<TileMathThreadGroup, NumGemmKPrefetchStage>;

    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        // A matrix in LDS memory, dst of blockwise copy
        return make_naive_tensor_descriptor(
            make_tuple(AK0, Number<MPerBlock>{}, AK1),
            make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
    }

    __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        // B matrix in LDS memory, dst of blockwise copy
        return make_naive_tensor_descriptor(
            make_tuple(BK0, Number<NPerBlock>{}, BK1),
            make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
    }

    __host__ __device__ static constexpr auto
@@ -196,7 +189,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
    }

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
@@ -350,17 +343,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
            c_grid_desc_mblock_mperblock_nblock_nperblock,
        const Block2CTileMap& block_2_ctile_map)
    {
        // build loadWave and MathWave pipelines
        // loadWave and MathWave synchronized through LDS
        //
        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1, BK1);
@@ -392,344 +384,345 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

        if(TileLoadThreadGroup::IsBelong())
        {
            // LoadWave
            const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
            const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());

            // A matrix blockwise copy
            auto a_blockwise_copy =
                ThreadGroupTensorSliceTransfer_v4r1<TileLoadThreadGroup,
                                                    AElementwiseOperation,
                                                    ck::tensor_operation::element_wise::PassThrough,
                                                    InMemoryDataOperationEnum::Set,
                                                    Sequence<AK0, MPerBlock, AK1>,
                                                    ABlockTransferThreadClusterLengths_AK0_M_AK1,
                                                    ABlockTransferThreadClusterArrangeOrder,
                                                    FloatAB,
                                                    FloatAB,
                                                    decltype(a_grid_desc_ak0_m_ak1),
                                                    decltype(a_block_desc_ak0_m_ak1),
                                                    ABlockTransferSrcAccessOrder,
                                                    Sequence<1, 0, 2>,
                                                    ABlockTransferSrcVectorDim,
                                                    2,
                                                    ABlockTransferSrcScalarPerVector,
                                                    ABlockTransferDstScalarPerVector_AK1,
                                                    1,
                                                    1,
                                                    AThreadTransferSrcResetCoordinateAfterRun,
                                                    true,
                                                    NumGemmKPrefetchStage>(
                    a_grid_desc_ak0_m_ak1,
                    make_multi_index(0, m_block_data_idx_on_grid, 0),
                    a_element_op,
                    a_block_desc_ak0_m_ak1,
                    make_multi_index(0, 0, 0),
                    ck::tensor_operation::element_wise::PassThrough{});

            // B matrix blockwise copy
            auto b_blockwise_copy =
                ThreadGroupTensorSliceTransfer_v4r1<TileLoadThreadGroup,
                                                    BElementwiseOperation,
                                                    ck::tensor_operation::element_wise::PassThrough,
                                                    InMemoryDataOperationEnum::Set,
                                                    Sequence<BK0, NPerBlock, BK1>,
                                                    BBlockTransferThreadClusterLengths_BK0_N_BK1,
                                                    BBlockTransferThreadClusterArrangeOrder,
                                                    FloatAB,
                                                    FloatAB,
                                                    decltype(b_grid_desc_bk0_n_bk1),
                                                    decltype(b_block_desc_bk0_n_bk1),
                                                    BBlockTransferSrcAccessOrder,
                                                    Sequence<1, 0, 2>,
                                                    BBlockTransferSrcVectorDim,
                                                    2,
                                                    BBlockTransferSrcScalarPerVector,
                                                    BBlockTransferDstScalarPerVector_BK1,
                                                    1,
                                                    1,
                                                    BThreadTransferSrcResetCoordinateAfterRun,
                                                    true,
                                                    NumGemmKPrefetchStage>(
                    b_grid_desc_bk0_n_bk1,
                    make_multi_index(0, n_block_data_idx_on_grid, 0),
                    b_element_op,
                    b_block_desc_bk0_n_bk1,
                    make_multi_index(0, 0, 0),
                    ck::tensor_operation::element_wise::PassThrough{});

            GridwiseGemmLoad::template RunLoadWavePipeline<HasMainKBlockLoop>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                a_blockwise_copy,
                a_grid_buf,
                a_block_buf,
                a_block_slice_copy_step,
                b_grid_desc_bk0_n_bk1,
                b_block_desc_bk0_n_bk1,
                b_blockwise_copy,
                b_grid_buf,
                b_block_buf,
                b_block_slice_copy_step,
                num_k_block_main_loop);

            block_sync_lds();
            block_sync_lds();
        }
        else if(TileMathThreadGroup::IsBelong())
        {
            // branch early for math wave
            constexpr index_t KPack =
                math::max(math::lcm(AK1, BK1),
                          MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

            auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
                TileMathThreadGroupSize,
                FloatAB,
                FloatGemmAcc,
                decltype(a_block_desc_ak0_m_ak1),
                decltype(b_block_desc_bk0_n_bk1),
                MPerXdl,
                NPerXdl,
                MXdlPerWave,
                NXdlPerWave,
                KPack>{};

            auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
            auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            // TODO re-architect LDS+math stages
            GridwiseGemmMath::template RunMathWavePipeline<HasMainKBlockLoop>(
                a_block_buf, b_block_buf, blockwise_gemm, c_thread_buf, num_k_block_main_loop);

            // GEMM definition
            // c_mtx += transpose(a_mtx) * b_mtx
            // a_mtx[K0PerBlock, MPerBlock] is in LDS
            // b_mtx[K0PerBlock, NPerBlock] is in LDS
            // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
            // register
            // sanity check

            // shuffle C and write out
            {
                static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                                  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                              "wrong!");

                constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
                constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

                // TODO: hacky, fix it!
                constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                    blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

                // TODO: hacky, fix it!
                // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
                constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                    blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

                constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
                constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
                constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
                constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
                constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
                constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
                constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
                constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

                constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                    GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

                auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                    static_cast<FloatCShuffle*>(p_shared),
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

                constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                    make_tuple(
                        make_freeze_transform(I0),
                        make_unmerge_transform(make_tuple(
                            Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
                            M1,                                      // M1 = MWave
                            M2,                                      // M2 * M3 * M4 = MPerXdl
                            M3,
                            M4)),
                        make_freeze_transform(I0),
                        make_unmerge_transform(make_tuple(
                            Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
                            N1,                                      // N1 = NWave
                            N2))),                                   // N2 = NPerXdl
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                    make_tuple(Sequence<>{},
                               Sequence<0, 2, 4, 5, 6>{},
                               Sequence<>{},
                               Sequence<1, 3, 7>{}));

                // calculate origin of thread output tensor on global memory
                // blockwise GEMM c matrix starting index
                const auto c_thread_mtx_on_block =
                    blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

                const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
                const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

                const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                    make_single_stage_tensor_adaptor(
                        make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                        make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                        make_tuple(Sequence<0>{}));

                const auto m_thread_data_on_block_idx =
                    m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                        make_multi_index(m_thread_data_on_block));

                const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                    make_single_stage_tensor_adaptor(
                        make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                        make_tuple(Sequence<0, 1, 2>{}),
                        make_tuple(Sequence<0>{}));

                const auto n_thread_data_on_block_idx =
                    n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                        make_multi_index(n_thread_data_on_block));

                // shuffle: threadwise copy C from VGPR to LDS
                auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
                    FloatGemmAcc,
                    FloatCShuffle,
                    decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                    decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                    ck::tensor_operation::element_wise::PassThrough,
                    Sequence<CShuffleMXdlPerWavePerShuffle,
                             CShuffleNXdlPerWavePerShuffle,
                             I1,
                             I1,
                             M2,
                             I1,
                             M4,
                             I1>,
                    Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                    7,
                    1,
                    InMemoryDataOperationEnum::Set,
                    1,
                    true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                          make_multi_index(0,
                                           0,
                                           m_thread_data_on_block_idx[I1],
                                           n_thread_data_on_block_idx[I1],
                                           m_thread_data_on_block_idx[I2],
                                           m_thread_data_on_block_idx[I3],
                                           m_thread_data_on_block_idx[I4],
                                           n_thread_data_on_block_idx[I2]),
                          ck::tensor_operation::element_wise::PassThrough{}};

                // shuffle: blockwise copy C from LDS to global
                auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                    CShuffleBlockTransferThreadGroup, // ThreadGroup
                    CElementwiseOperation,            // ElementwiseOperation,
                    CGlobalMemoryDataOperation,       // DstInMemOp,
                    Sequence<1,
                             CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                             1,
                             CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
                    CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                    Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                    FloatCShuffle,        // typename SrcData,
                    FloatC,               // typename DstData,
                    decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                    decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                    Sequence<0, 1, 2, 3>,                           // typename DimAccessOrder,
                    3,                                              // index_t VectorDim,
                    CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
                    true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
                    false> // bool ThreadTransferDstResetCoordinateAfterRun>
                    {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                     make_multi_index(0, 0, 0, 0),
                     c_grid_desc_mblock_mperblock_nblock_nperblock,
                     make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
                     c_element_op};

                // space filling curve for threadwise C in VGPR
                constexpr auto sfc_c_vgpr =
                    SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
                                      Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                      Sequence<CShuffleMXdlPerWavePerShuffle,
                                               CShuffleNXdlPerWavePerShuffle,
                                               1,
                                               1,
                                               M2,
                                               1,
                                               M4,
                                               1>>{};

                // space filling curve for shuffled blockwise C in global mem
                constexpr auto sfc_c_global =
                    SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                                      Sequence<0, 2, 1, 3>,
                                      Sequence<1,
                                               CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                               1,
                                               CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

                constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

                static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");

                // TODO
                // 1. we do not need to do LDS swizzle to align global writes writing cache
                //    lines
                //    v_mfma cmat, amat, bmat, cmat - c-mat register layout are 1xN
                //    elments (N is vertical or strided dimension) v_mfma cmat, bmat, amat,
                //    cmat - c-mat register layout are Mx1 elments (M is coalescing
                //    dimension) by enumerating M index in amat, bmat you can align cmat
                //    register(s) to contiguous M elements for example
                //    1st mfma instruction output space : 0 4 8 12 16 ....
                //    2nd mfma instruction output space : 1 5 9 13 17 ....
                //    3rd mfma instruction output space : 2 6 10 14 18 ....
                //    4th mfma instruction output space : 3 7 11 15 19 ....
                //    you can pack 4 registers output space into 2WORD and do global write
                //    (no LDS swizzling required)
                // 2. avoid using s_barrier in this case where not all 256 threads required to
                //    swizzle c layout

                static_for<0, num_access, 1>{}([&](auto access_id) {
                    // make sure it's safe to write to LDS
                    block_sync_lds();

                    // each thread write its data from VGPR to LDS
                    c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                                  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                                  c_thread_buf,
                                                  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                                  c_shuffle_block_buf);

                    // make sure it's safe to read from LDS
                    block_sync_lds();

                    // each block copy its data from LDS to global
                    c_shuffle_block_copy_lds_to_global.Run(
                        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                        c_shuffle_block_buf,
                        c_grid_desc_mblock_mperblock_nblock_nperblock,
                        c_grid_buf);

                    if constexpr(access_id < num_access - 1)
                    {
                        constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                        // move on C
                        c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                            c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
                    }
                });
            }
        }
    }
}; // GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle

} // namespace ck
@@ -249,8 +249,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
        }();

        using BlockwiseGemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize FloatAB,
                                                                FloatAcc,
                                                                decltype(a_k0_m_k1_block_desc),
                                                                decltype(b_k0_n_k1_block_desc),
...