Commit 23ce8e68 authored by wangshaojie6's avatar wangshaojie6
Browse files

add prefetch 3 for pipeline v2

parent 56598b1b
......@@ -162,6 +162,10 @@ int main(int argc, char* argv[])
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
std::cout << "a device buf: " << a_m_k_device_buf.GetDeviceBuffer() << std::endl;
std::cout << "b device buf: " << b_k_n_device_buf.GetDeviceBuffer() << std::endl;
std::cout << "c device buf: " << c_m_n_device_buf.GetDeviceBuffer() << std::endl;
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
......
......@@ -286,5 +286,419 @@ struct GridwiseGemmPipeline_v2<2>
}
};
// 3-stage prefetch
template <>
struct GridwiseGemmPipeline_v2<3>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
__host__ __device__ static constexpr bool IsSupported(index_t num_loop)
{
// TODO: improve applicability
return num_loop > 3;
}
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 3;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
static_for<0, 3, 1>{}([&](auto i_pre){
// global read i_pre
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, Number<i_pre>{});
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<i_pre>{});
// move to i_pre + 1
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
});
// Initialize C
c_thread_buf.Clear();
index_t i = 0;
// main body
if constexpr(HasMainLoop)
{
do
{
static_for<0, 3, 1>{}([&](auto i_main){
// LDS write i_main
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<i_main>{});
// global Read i_main + 3
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, Number<i_main>{});
// LDS write i_main
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<i_main>{});
// global Read i_main + 3
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<i_main>{});
// move to i_main + 3
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
block_sync_lds();
// GEMM i_main
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
});
i += 3;
} while(i < (num_loop - 3));
}
// tail
if (i == num_loop - 3)
{
static_for<0, I3, 1>{}([&](auto i_res){
// Write num_loop - 3
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<i_res>{});
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<i_res>{});
block_sync_lds();
// GEMM num_loop - 3
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
});
}
// tail
else if (i == num_loop - 2)
{
static_for<0, I2, 1>{}([&](auto i_res){
// Write num_loop
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<i_res>{});
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<i_res>{});
block_sync_lds();
// GEMM num_loop
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
});
}
// tail
else if (i == num_loop - 1)
{
static_for<0, I1, 1>{}([&](auto i_res){
// Write num_loop
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<i_res>{});
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<i_res>{});
block_sync_lds();
// GEMM num_loop
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
});
}
}
};
// 4-stage prefetch
template <>
struct GridwiseGemmPipeline_v2<4>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
__host__ __device__ static constexpr bool IsSupported(index_t num_loop)
{
// TODO: improve applicability
return num_loop > 4;
}
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 4;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
// global read 0
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
// move to 1
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// global read 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
// move to 2
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// global read 2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I2);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I2);
// move to 3
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// global read 3
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I3);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I3);
// Initialize C
c_thread_buf.Clear();
index_t i = 0;
// main body
if constexpr(HasMainLoop)
{
do
{
// move to i + 4
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write i
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
// global Read i + 4
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
// LDS write i
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
// global Read i + 4
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
block_sync_lds();
// GEMM i
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
// move to i + 5
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write i + 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1);
// global read i + 5
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
// LDS write i + 1
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1);
// global read i + 5
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
block_sync_lds();
// GEMM i + 1
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
// move to i + 6
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write i + 2
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I2);
// global read i + 6
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I2);
// LDS write i + 2
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I2);
// global read i + 6
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I2);
block_sync_lds();
// GEMM i + 2
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
// move to i + 7
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write i + 3
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I3);
// global read i + 7
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I3);
// LDS write i + 3
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I3);
// global read i + 7
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I3);
block_sync_lds();
// GEMM i + 3
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
i += 4;
} while(i < (num_loop - 4));
}
// tail
if (i == num_loop - 4)
{
static_for<0, I4, 1>{}([&](auto i_res){
// Write num_loop - 3
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<i_res>{});
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<i_res>{});
block_sync_lds();
// GEMM num_loop - 3
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
});
}
// tail
if (i == num_loop - 3)
{
static_for<0, I3, 1>{}([&](auto i_res){
// Write num_loop - 3
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<i_res>{});
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<i_res>{});
block_sync_lds();
// GEMM num_loop - 3
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
});
}
// tail
else if (i == num_loop - 2)
{
static_for<0, I2, 1>{}([&](auto i_res){
// Write num_loop
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<i_res>{});
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<i_res>{});
block_sync_lds();
// GEMM num_loop
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
});
}
// tail
else if (i == num_loop - 1)
{
static_for<0, I1, 1>{}([&](auto i_res){
// Write num_loop
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<i_res>{});
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<i_res>{});
block_sync_lds();
// GEMM num_loop
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
});
}
}
};
} // namespace ck
......@@ -111,7 +111,7 @@ template <index_t BlockSize,
index_t CShuffleNRepeatPerShuffle,
index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t NumGemmKPrefetchStage = 2>
index_t NumGemmKPrefetchStage = 3>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{
static constexpr auto I0 = Number<0>{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment