Commit 4cdf78fc authored by wangshaojie6's avatar wangshaojie6
Browse files

add gridwisepipeline v2 prefetch 2

parent 0869af27
......@@ -48,7 +48,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::
// ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 8, 16, 2, 1, 1, 1, S<1, 8>, S<1, 8>, S<2, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<2, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 1, 2, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 2, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
......
......@@ -35,6 +35,7 @@ using AccDataType = float;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::RowMajor;
//using BLayout = Col;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
......@@ -134,8 +135,8 @@ int main(int argc, char* argv[])
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
......
......@@ -98,7 +98,7 @@
#define CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR 0
// experimental feature: buffer load/store/atomic-add/ OOB trick
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1
......
......@@ -173,8 +173,8 @@ struct BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_
"wrong!");
// TODO: remove this restriction
static_assert(BM0 == 2, "wrong");
static_assert(BM0 == 2 && BN0 == 2, "wrong");
//static_assert(BM0 == 2, "wrong");
//static_assert(BM0 == 2 && BN0 == 2, "wrong");
}
__device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id)
......@@ -214,10 +214,10 @@ struct BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_
"wrong! Desc should be known at compile-time");
// TODO: remove this restriction
static_assert(BM0 == 2 && BN0 == 2 &&
CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I0) == BM0 &&
CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0,
"wrong");
//static_assert(BM0 == 2 && BN0 == 2 &&
// CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I0) == BM0 &&
// CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0,
// "wrong");
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize());
......
......@@ -473,7 +473,7 @@ struct DeviceGemmDl
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx90a")
if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx90a")
{
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
......
......@@ -3,12 +3,17 @@
namespace ck {
struct GridwiseGemmPipeline_v2
template <index_t NumPrefetch>
struct GridwiseGemmPipeline_v2;
// 1-stage prefetch
template <>
struct GridwiseGemmPipeline_v2<1>
{
__host__ __device__ static constexpr bool IsSupported(index_t num_loop)
{
// TODO: improve applicability
return num_loop % 2 == 0;
return num_loop > 2;
}
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
......@@ -121,4 +126,164 @@ struct GridwiseGemmPipeline_v2
}
};
// 2-stage prefetch
template <>
struct GridwiseGemmPipeline_v2<2>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ static constexpr bool IsSupported(index_t num_loop)
{
// TODO: improve applicability
return num_loop > 2;
}
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 2;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
// global read 0
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
// move to 1
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// global read 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
// Initialize C
c_thread_buf.Clear();
index_t i = 0;
// main body
if constexpr(HasMainLoop)
{
do
{
// move to i + 2
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write i
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
// global Read i + 2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
// LDS write i
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
// global Read i + 2
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
block_sync_lds();
// GEMM i
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
// move to i + 3
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write i + 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1);
// global read i + 3
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
// LDS write i + 1
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1);
// global read i + 3
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
block_sync_lds();
// GEMM i + 1
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
i += 2;
} while(i < (num_loop - 2));
}
// tail
if (i > num_loop - 2)
{
// LDS write num_loop - 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
block_sync_lds();
// GEMM num_loop - 1
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
// tail
else if (i == num_loop - 2)
{
// Write num_loop - 2
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
block_sync_lds();
// GEMM num_loop - 2
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
// LDS write num_loop - 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1);
block_sync_lds();
// GEMM num_loop - 1
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
};
} // namespace ck
......@@ -131,7 +131,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
#if 0
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
#else
using GridwiseGemmPipe = GridwiseGemmPipeline_v2;
using GridwiseGemmPipe = GridwiseGemmPipeline_v2<NumGemmKPrefetchStage>;
#endif
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment