Commit ece8f9b8 authored by aska-0096's avatar aska-0096
Browse files

add ldsbypass gridwise pipeline infra

parent a59e8d48
...@@ -3,8 +3,6 @@ ...@@ -3,8 +3,6 @@
#pragma once #pragma once
#include <iostream>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
...@@ -18,14 +16,16 @@ enum struct PipelineVersion ...@@ -18,14 +16,16 @@ enum struct PipelineVersion
template <PipelineVersion PipelineVer, template <PipelineVersion PipelineVer,
index_t NumPrefetch = 1, index_t NumPrefetch = 1,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default,
bool AEnableLds = true,
bool BEnableLds = true>
constexpr auto GridwiseGemmPipeline_Selector() constexpr auto GridwiseGemmPipeline_Selector()
{ {
if constexpr(PipelineVer == PipelineVersion::v1) if constexpr(PipelineVer == PipelineVersion::v1)
{ {
if constexpr(LoopSched == LoopScheduler::Default) if constexpr(LoopSched == LoopScheduler::Default)
{ {
return GridwiseGemmPipeline_v1<NumPrefetch>{}; return GridwiseGemmPipeline_v1<NumPrefetch, AEnableLds, BEnableLds>{};
} }
else if constexpr(LoopSched == LoopScheduler::Interwave) else if constexpr(LoopSched == LoopScheduler::Interwave)
{ {
......
...@@ -8,12 +8,12 @@ ...@@ -8,12 +8,12 @@
namespace ck { namespace ck {
template <index_t NumPrefetch> template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
struct GridwiseGemmPipeline_v1; struct GridwiseGemmPipeline_v1;
// 1-stage prefetch // 1-stage prefetch
template <> template <>
struct GridwiseGemmPipeline_v1<1> struct GridwiseGemmPipeline_v1<1, true, true>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -107,7 +107,7 @@ struct GridwiseGemmPipeline_v1<1> ...@@ -107,7 +107,7 @@ struct GridwiseGemmPipeline_v1<1>
// 2-stage prefetch // 2-stage prefetch
template <> template <>
struct GridwiseGemmPipeline_v1<2> struct GridwiseGemmPipeline_v1<2, true, true>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -253,6 +253,303 @@ struct GridwiseGemmPipeline_v1<2> ...@@ -253,6 +253,303 @@ struct GridwiseGemmPipeline_v1<2>
} }
}; };
template <>
struct GridwiseGemmPipeline_v1<1, false, true>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
auto a_block_buf_switch = a_block_buf;
// preload data into LDS
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.Run(
a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
a_blockwise_copy.Run(
a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch);
block_sync_lds();
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
a_block_buf = a_block_buf_switch;
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
}
}
};
template <>
struct GridwiseGemmPipeline_v1<1, true, false>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
auto b_block_buf_switch = b_block_buf;
// preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);
block_sync_lds();
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_block_buf = b_block_buf_switch;
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
}
}
};
template <>
struct GridwiseGemmPipeline_v1<1, false, false>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
auto b_block_buf_switch = b_block_buf;
auto a_block_buf_switch = a_block_buf;
// preload data into LDS
a_blockwise_copy.Run(
a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf);
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
a_blockwise_copy.Run(
a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch);
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_block_buf = a_block_buf_switch;
b_block_buf = b_block_buf_switch;
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
}
}
};
template <index_t NumPrefetch> template <index_t NumPrefetch>
struct GridwiseGemmPipelineInterwave_v1; struct GridwiseGemmPipelineInterwave_v1;
...@@ -348,7 +645,7 @@ struct GridwiseGemmPipelineInterwave_v1<1> ...@@ -348,7 +645,7 @@ struct GridwiseGemmPipelineInterwave_v1<1>
// Note: 2 stage prefetch not optimized for inter-wave loop scheduler // Note: 2 stage prefetch not optimized for inter-wave loop scheduler
template <> template <>
struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2> struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2, true, true>
{ {
}; };
...@@ -358,7 +655,7 @@ constexpr auto GridwiseGemmPipeline_v1_Selector() ...@@ -358,7 +655,7 @@ constexpr auto GridwiseGemmPipeline_v1_Selector()
{ {
if constexpr(LoopSched == LoopScheduler::Default) if constexpr(LoopSched == LoopScheduler::Default)
{ {
return GridwiseGemmPipeline_v1<NumPrefetch>{}; return GridwiseGemmPipeline_v1<NumPrefetch, true, true>{};
} }
else if constexpr(LoopSched == LoopScheduler::Interwave) else if constexpr(LoopSched == LoopScheduler::Interwave)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment