Commit 1efaa52c authored by ramjana's avatar ramjana
Browse files

GEMM pipeline update for wavelet progrmmaing model

parent 712e464c
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
#arg2: initialization (0=no init, 1=integer value, 2=decimal value) #arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: run kernel # of times (>1) #arg3: run kernel # of times (>1)
./bin/example_gemm_xdl 0 1 5 ./bin/example_gemm_xdl 0 1 5
#arg10 : gemm_pipeline (0=default, 1=waveletmodel)
``` ```
Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_gemm_xdl.hpp" #include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_cshuffle.hpp" #include "device_gemm_xdl_cshuffle.hpp"
#include "device_gemm_xdl_waveletmodel_cshuffle.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "reference_gemm.hpp" #include "reference_gemm.hpp"
#include "gemm_specialization.hpp" #include "gemm_specialization.hpp"
...@@ -49,6 +50,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle ...@@ -49,6 +50,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; < Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
using DeviceGemmInstance_WaveletModel = ck::tensor_operation::device::DeviceGemm_Xdl_WaveletModel_CShuffle
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
...@@ -149,6 +158,8 @@ int main(int argc, char* argv[]) ...@@ -149,6 +158,8 @@ int main(int argc, char* argv[])
auto c_element_op = CElementOp{}; auto c_element_op = CElementOp{};
// do GEMM // do GEMM
//replace DeviceGemmInstance_WaveletModel for wavelet gemm pipeline
//auto gemm = DeviceGemmInstance_WaveletModel{};
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()), auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
......
#pragma once
#include "common_header.hpp"
namespace ck {
template <typename TileLoadThreadGroup, index_t NumGemmKPrefetchStage>
struct GridwiseGemmLoadWave;
//1-stage prefetch
template<typename TileLoadThreadGroup>
struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
{
__host__ __device__ static constexpr bool IsSupported(index_t num_loop)
{
// TODO: improve applicability
return num_loop % 2 == 0;
}
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop / 2 > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep>
static __device__ void RunLoadWavePipeline(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_block_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_block_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
index_t num_loop)
{
// global read 0
a_block_copy.RunRead(a_grid_desc, a_grid_buf);
b_block_copy.RunRead(b_grid_desc, b_grid_buf);
//move to 1
a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
//LDS write 0
a_block_copy.RunWrite(a_block_desc, a_block_buf);
// global Read 1
a_block_copy.RunRead(a_grid_desc, a_grid_buf);
// LDS write 0
b_block_copy.RunWrite(b_block_desc, b_block_buf);
// global Read 1
b_block_copy.RunRead(b_grid_desc, b_grid_buf);
if constexpr(HasMainLoop)
{
index_t i=0;
do
{
//sync for Load threads()
block_sync_lds();
//?? what is this for
// sync with math threads()
block_sync_lds();
// move to i + 2
a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write i + 1
a_block_copy.RunWrite(a_block_desc, a_block_buf);
// global read i + 2
a_block_copy.RunRead(a_grid_desc, a_grid_buf);
// LDS write i + 1
b_block_copy.RunWrite(b_block_desc, b_block_buf);
// global read i + 2
b_block_copy.RunRead(b_grid_desc, b_grid_buf);
++i;
} while(i < (num_loop - 2));
}
// tail
{
block_sync_lds();
//what is this for??
block_sync_lds();
// move to i + 2
a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
block_sync_lds();
// GEMM num_loop
}
}
};
template <typename TileMathThreadGroup, index_t NumGemmKPrefetchStage>
struct GridwiseGemmMathWave;
// 1- stage prefetch
template <typename TileMathThreadGroup>
struct GridwiseGemmMathWave<TileMathThreadGroup, 1>
{
__host__ __device__ static constexpr bool IsSupported(index_t num_loop)
{
// TODO: improve applicability
return num_loop % 2 == 0;
}
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop / 2 > 1;
}
template <bool HasMainLoop,
typename ABlockBuffer,
typename BBlockBuffer,
typename BlockwiseGemm,
typename CThreadBuffer>
static __device__ void RunMathWavePipeline(ABlockBuffer& a_block_buf,
BBlockBuffer& b_block_buf,
const BlockwiseGemm& block_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
block_sync_lds();
// GEMM i
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
++i;
} while(i < (num_loop - 2));
}
// tail
{
block_sync_lds();
// GEMM num_loop - 2
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
// LDS write num_loop - 1
block_sync_lds();
// GEMM num_loop - 1
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
};
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment