Commit 20718690 authored by Chao Liu

adding gemm pipeline

parent 18707866
@@ -11,9 +11,10 @@
 #include "host_tensor.hpp"
 #include "host_tensor_generator.hpp"
 #include "device_tensor.hpp"
-#include "device_gemm_xdl.hpp"
-#include "device_gemm_xdl_c_shuffle.hpp"
-#include "device_gemm_xdl_cshuffle.hpp"
+//#include "device_gemm_xdl.hpp"
+//#include "device_gemm_xdl_c_shuffle.hpp"
+//#include "device_gemm_xdl_cshuffle.hpp"
+#include "device_gemm_xdl_cshuffle_v2.hpp"
 #include "element_wise_operation.hpp"
 #include "reference_gemm.hpp"
 #include "gemm_specialization.hpp"
@@ -42,15 +43,39 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough;
 using BElementOp = ck::tensor_operation::element_wise::PassThrough;
 using CElementOp = ck::tensor_operation::element_wise::PassThrough;

 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;

 // clang-format off
+#if 0
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
-//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
-//######| | | | Type| Type| Type| DataType| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector|
-//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
+//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
 //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-< Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
+//< Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
+// // 1-stage prefetch
+< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
+// // 2-stage prefetch
+// < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
+#elif 1
+using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_v2
+//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
+//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 8>, 8>;
+// < Row, Col, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 8>, 8>;
+#elif 1
+using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
+//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
+// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>;
+// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>;
+#endif
 // clang-format on

 using ReferenceGemmInstance = ck::tensor_operation::host::
......
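Not shown in this hunk: the example drives whichever DeviceGemmInstance survives the #if/#elif selection through CK's usual device-op interface. A minimal sketch of that call sequence, assuming the example's customary buffer names (a_m_k_device_buf and friends) and an integer repeat-count Run overload; exact signatures vary across CK revisions:

// Sketch only; these names follow CK's gemm_xdl example and are not part of this diff.
auto gemm     = DeviceGemmInstance{};
auto invoker  = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
                                  static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
                                  static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
                                  M, N, K, StrideA, StrideB, StrideC,
                                  AElementOp{}, BElementOp{}, CElementOp{});

// Each template instance only supports problems compatible with its tile and
// vector-width configuration, so check before launching.
if(!gemm.IsSupportedArgument(argument))
{
    throw std::runtime_error("this DeviceGemmInstance does not support the given problem size");
}

float ave_time = invoker.Run(argument, nrepeat); // averaged kernel time in ms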
@@ -14,7 +14,7 @@
 #define CK_USE_LAUNCH_BOUNDS 1

 #ifdef CK_USE_LAUNCH_BOUNDS
-#define CK_MAX_THREAD_PER_BLOCK 256
+#define CK_MAX_THREAD_PER_BLOCK 512
 #define CK_MIN_BLOCK_PER_CU 1
 #endif
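Raising CK_MAX_THREAD_PER_BLOCK from 256 to 512 matches the v2 instance above, which runs two cooperating 256-thread groups (ABBlockTransfer plus BlockGemm) in one workgroup. These macros feed the kernel's launch-bounds attribute; a sketch of the usual pattern, with a hypothetical kernel name:

// Illustrative only: the kernel name and parameter list are placeholders.
// __launch_bounds__(max_threads, min_blocks_per_CU) lets the compiler budget
// registers so a 512-thread workgroup can still meet the occupancy target.
#ifdef CK_USE_LAUNCH_BOUNDS
__global__ void __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
kernel_gemm_xdl_cshuffle_v2(/* kernel arguments */);
#else
__global__ void kernel_gemm_xdl_cshuffle_v2(/* kernel arguments */);
#endif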
......
@@ -71,35 +71,35 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
     static __device__ void RunABBlockTransferPipeline(const AGridDesc& a_grid_desc,
                                                       const ABlockDesc& a_block_desc,
-                                                      ABlockTransfer& a_blockwise_copy,
+                                                      ABlockTransfer& a_block_copy,
                                                       const AGridBuffer& a_grid_buf,
                                                       ABlockBuffer& a_block_buf,
                                                       const ABlockTransferStep& a_block_copy_step,
                                                       const BGridDesc& b_grid_desc,
                                                       const BBlockDesc& b_block_desc,
-                                                      BBlockTransfer& b_blockwise_copy,
+                                                      BBlockTransfer& b_block_copy,
                                                       const BGridBuffer& b_grid_buf,
                                                       BBlockBuffer& b_block_buf,
                                                       const BBlockTransferStep& b_block_copy_step,
                                                       index_t num_loop)
     {
         // global read 0
-        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
-        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
+        a_block_copy.RunRead(a_grid_desc, a_grid_buf);
+        b_block_copy.RunRead(b_grid_desc, b_grid_buf);

         // move to 1
-        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
-        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+        a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+        b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

         // LDS write 0
-        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
+        a_block_copy.RunWrite(a_block_desc, a_block_buf);
         // global Read 1
-        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+        a_block_copy.RunRead(a_grid_desc, a_grid_buf);

         // LDS write 0
-        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
+        b_block_copy.RunWrite(b_block_desc, b_block_buf);
         // global Read 1
-        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
+        b_block_copy.RunRead(b_grid_desc, b_grid_buf);

         // main body
         // FIXME: HasMainLoop = (num_loop) > 2
@@ -116,18 +116,18 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
             block_sync_lds();

             // move to i + 2
-            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
-            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+            a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+            b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

             // LDS write i + 1
-            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
+            a_block_copy.RunWrite(a_block_desc, a_block_buf);
             // global read i + 2
-            a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+            a_block_copy.RunRead(a_grid_desc, a_grid_buf);

             // LDS write i + 1
-            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
+            b_block_copy.RunWrite(b_block_desc, b_block_buf);
             // global read i + 2
-            b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
+            b_block_copy.RunRead(b_grid_desc, b_grid_buf);

             ++i;
         } while(i < (num_loop - 2));
@@ -142,8 +142,8 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
         block_sync_lds();

         // LDS write num_loop - 1
-        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
-        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
+        a_block_copy.RunWrite(a_block_desc, a_block_buf);
+        b_block_copy.RunWrite(b_block_desc, b_block_buf);

         block_sync_lds();
@@ -153,7 +153,7 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
     static __device__ void RunBlockGemmPipeline(ABlockBuffer& a_block_buf,
                                                 BBlockBuffer& b_block_buf,
-                                                const BlockwiseGemm& blockwise_gemm,
+                                                const BlockwiseGemm& block_gemm,
                                                 CThreadBuffer& c_thread_buf,
                                                 index_t num_loop)
     {
@@ -171,7 +171,7 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
             block_sync_lds();

             // GEMM i
-            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+            block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

             block_sync_lds();
@@ -192,7 +192,7 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
         block_sync_lds();

         // GEMM num_loop - 2
-        blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+        block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

         block_sync_lds();
@@ -201,46 +201,45 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
         block_sync_lds();

         // GEMM num_loop - 1
-        blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+        block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
         }
     }

     static __device__ void Run(const AGridDesc& a_grid_desc,
                                const ABlockDesc& a_block_desc,
-                               ABlockTransfer& a_blockwise_copy,
+                               ABlockTransfer& a_block_copy,
                                const AGridBuffer& a_grid_buf,
                                ABlockBuffer& a_block_buf,
                                const ABlockTransferStep& a_block_copy_step,
                                const BGridDesc& b_grid_desc,
                                const BBlockDesc& b_block_desc,
-                               BBlockTransfer& b_blockwise_copy,
+                               BBlockTransfer& b_block_copy,
                                const BGridBuffer& b_grid_buf,
                                BBlockBuffer& b_block_buf,
                                const BBlockTransferStep& b_block_copy_step,
-                               const BlockwiseGemm& blockwise_gemm,
+                               const BlockwiseGemm& block_gemm,
                                CThreadBuffer& c_thread_buf,
                                index_t num_loop)
     {
         if(ABBlockTransferThreadGroup::IsBelong())
         {
-            gridwise_gemm_pipeline.RunABBlockTransferPipeline(a_grid_desc_ak0_m_ak1,
-                                                              a_block_desc_ak0_m_ak1,
-                                                              a_blockwise_copy,
-                                                              a_grid_buf,
-                                                              a_block_buf,
-                                                              a_block_slice_copy_step,
-                                                              b_grid_desc_bk0_n_bk1,
-                                                              b_block_desc_bk0_n_bk1,
-                                                              b_blockwise_copy,
-                                                              b_grid_buf,
-                                                              b_block_buf,
-                                                              b_block_slice_copy_step,
-                                                              num_loop);
+            RunABBlockTransferPipeline(a_grid_desc,
+                                       a_block_desc,
+                                       a_block_copy,
+                                       a_grid_buf,
+                                       a_block_buf,
+                                       a_block_copy_step,
+                                       b_grid_desc,
+                                       b_block_desc,
+                                       b_block_copy,
+                                       b_grid_buf,
+                                       b_block_buf,
+                                       b_block_copy_step,
+                                       num_loop);
         }
         else if(BlockGemmThreadGroup::IsBelong())
         {
-            gridwise_gemm_pipeline.RunBlockGemmPipeline(
-                a_block_buf, b_block_buf, blockwise_gemm, c_thread_buf, num_loop);
+            RunBlockGemmPipeline(a_block_buf, b_block_buf, block_gemm, c_thread_buf, num_loop);
         }
     }
 };
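Taken together, the two Run*Pipeline functions above implement a two-deep software pipeline split across thread groups: the transfer group keeps tile i+2 in flight from global memory and lands tile i+1 in LDS while the GEMM group consumes tile i, with block_sync_lds() as the hand-off. A host-side trace of that schedule (a sketch; num_loop is assumed > 2, per the FIXME above):

#include <cstdio>

// Prints the per-iteration overlap encoded by RunABBlockTransferPipeline and
// RunBlockGemmPipeline. Purely illustrative host code, not part of the kernel.
int main()
{
    const int num_loop = 5; // assumed > 2 (see the FIXME on HasMainLoop)

    // prologue: global read 0, move window, LDS write 0, global read 1
    std::printf("xfer: global_read 0 | lds_write 0 | global_read 1\n");

    // steady state: GEMM on tile i overlaps the LDS write of tile i + 1
    // and the global read of tile i + 2
    for(int i = 0; i < num_loop - 2; ++i)
        std::printf("sync | gemm %d | xfer: lds_write %d, global_read %d\n", i, i + 1, i + 2);

    // tail: last LDS write, then the final two GEMMs drain the LDS buffers
    std::printf("sync | gemm %d | xfer: lds_write %d\n", num_loop - 2, num_loop - 1);
    std::printf("sync | gemm %d\n", num_loop - 1);
    return 0;
}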
......
@@ -4,12 +4,11 @@
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer_v4r1.hpp"
-#include "blockwise_tensor_slice_transfer_v6r1.hpp"
 #include "thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "thread_group_tensor_slice_transfer_v6r1.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "gridwise_gemm_pipeline_v1.hpp"
+#include "gridwise_gemm_pipeline_v2.hpp"

 namespace ck {
@@ -118,11 +117,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
     using ThisThreadBlock =
         AnyThreadBlock<ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize>;

-#if 1
-    using ABBlockTransferThreadGroup = ThisThreadBlock;
-    using BlockGemmThreadGroup = ThisThreadBlock;
-    using CShuffleBlockTransferThreadGroup = ThisThreadBlock;
-#else
     struct ABBlockTransferThreadGroup
     {
         __device__ static constexpr index_t GetNumOfThread()

@@ -157,7 +151,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
     };

     using CShuffleBlockTransferThreadGroup = ThisThreadBlock;
-#endif

     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
@@ -494,7 +487,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
             KPerBlock);

-#if 1
+#if 0
         // gridwise GEMM pipeline
         const auto gridwise_gemm_pipeline =
             GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
@@ -667,9 +660,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
         // shuffle: blockwise copy C from LDS to global
         auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
-            ThisThreadBlock,            // ThreadGroup
+            CShuffleBlockTransferThreadGroup, // ThreadGroup
             CElementwiseOperation,      // ElementwiseOperation,
             CGlobalMemoryDataOperation, // DstInMemOp,
             Sequence<1,
                      CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                      1,
......
@@ -24,38 +24,38 @@ include_directories(BEFORE
 set(PROFILER_SOURCE
     src/profiler.cpp
     src/profile_gemm.cpp
-    src/profile_gemm_bias_2d.cpp
-    src/profile_gemm_bias_relu.cpp
-    src/profile_gemm_bias_relu_add.cpp
-    src/profile_gemm_reduce.cpp
-    src/profile_batched_gemm.cpp
-    src/profile_conv_fwd.cpp
-    src/profile_conv_fwd_bias_relu.cpp
-    src/profile_conv_fwd_bias_relu_add.cpp
-    src/profile_conv_fwd_bias_relu_atomic_add.cpp
-    src/profile_convnd_bwd_data.cpp
-    src/profile_reduce.cpp
-    src/profile_grouped_gemm.cpp
-    src/profile_conv_bwd_weight.cpp
-    src/profile_batched_gemm_reduce.cpp
+    # src/profile_gemm_bias_2d.cpp
+    # src/profile_gemm_bias_relu.cpp
+    # src/profile_gemm_bias_relu_add.cpp
+    # src/profile_gemm_reduce.cpp
+    # src/profile_batched_gemm.cpp
+    # src/profile_conv_fwd.cpp
+    # src/profile_conv_fwd_bias_relu.cpp
+    # src/profile_conv_fwd_bias_relu_add.cpp
+    # src/profile_conv_fwd_bias_relu_atomic_add.cpp
+    # src/profile_convnd_bwd_data.cpp
+    # src/profile_reduce.cpp
+    # src/profile_grouped_gemm.cpp
+    # src/profile_conv_bwd_weight.cpp
+    # src/profile_batched_gemm_reduce.cpp
 )

 add_executable(ckProfiler ${PROFILER_SOURCE})

 target_link_libraries(ckProfiler PRIVATE host_tensor)
-target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
+#target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
 target_link_libraries(ckProfiler PRIVATE device_gemm_instance)
-target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance)
-target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance)
-target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance)
-target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
-target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
-target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
-target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
-target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance)
-target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance)
-target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
-target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
-target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
-target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance)
-target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
+#target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance)
+#target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance)
+#target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance)
+#target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
+#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
+#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
+#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
+#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance)
+#target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance)
+#target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
+#target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
+#target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
+#target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance)
+#target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
@@ -26,70 +26,70 @@ int main(int argc, char* argv[])
     {
         return profile_gemm(argc, argv);
     }
-    else if(strcmp(argv[1], "gemm_bias_2d") == 0)
-    {
-        return profile_gemm_bias_2d(argc, argv);
-    }
-    else if(strcmp(argv[1], "gemm_bias_relu") == 0)
-    {
-        return profile_gemm_bias_relu(argc, argv);
-    }
-    else if(strcmp(argv[1], "gemm_bias_relu_add") == 0)
-    {
-        return profile_gemm_bias_relu_add(argc, argv);
-    }
-    else if(strcmp(argv[1], "gemm_reduce") == 0)
-    {
-        return profile_gemm_reduce(argc, argv);
-    }
-    else if(strcmp(argv[1], "batched_gemm") == 0)
-    {
-        return profile_batched_gemm(argc, argv);
-    }
-    else if(strcmp(argv[1], "batched_gemm_reduce") == 0)
-    {
-        return profile_batched_gemm_reduce(argc, argv);
-    }
-    else if(strcmp(argv[1], "grouped_gemm") == 0)
-    {
-        profile_grouped_gemm(argc, argv);
-    }
-    else if(strcmp(argv[1], "conv_fwd") == 0)
-    {
-        return profile_conv_fwd(argc, argv);
-    }
-    else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0)
-    {
-        return profile_conv_fwd_bias_relu(argc, argv);
-    }
-    else if(strcmp(argv[1], "conv_fwd_bias_relu_add") == 0)
-    {
-        return profile_conv_fwd_bias_relu_add(argc, argv);
-    }
-    else if(strcmp(argv[1], "conv_fwd_bias_relu_atomic_add") == 0)
-    {
-        return profile_conv_fwd_bias_relu_atomic_add(argc, argv);
-    }
-    else if(strcmp(argv[1], "conv1d_bwd_data") == 0)
-    {
-        return profile_convnd_bwd_data(argc, argv, 1);
-    }
-    else if(strcmp(argv[1], "conv2d_bwd_data") == 0)
-    {
-        return profile_convnd_bwd_data(argc, argv, 2);
-    }
-    else if(strcmp(argv[1], "conv3d_bwd_data") == 0)
-    {
-        return profile_convnd_bwd_data(argc, argv, 3);
-    }
-    else if(strcmp(argv[1], "reduce") == 0)
-    {
-        return profile_reduce(argc, argv);
-    }
-    else if(strcmp(argv[1], "conv2d_bwd_weight") == 0)
-    {
-        return profile_conv_bwd_weight(argc, argv);
-    }
+    // else if(strcmp(argv[1], "gemm_bias_2d") == 0)
+    // {
+    //     return profile_gemm_bias_2d(argc, argv);
+    // }
+    // else if(strcmp(argv[1], "gemm_bias_relu") == 0)
+    // {
+    //     return profile_gemm_bias_relu(argc, argv);
+    // }
+    // else if(strcmp(argv[1], "gemm_bias_relu_add") == 0)
+    // {
+    //     return profile_gemm_bias_relu_add(argc, argv);
+    // }
+    // else if(strcmp(argv[1], "gemm_reduce") == 0)
+    // {
+    //     return profile_gemm_reduce(argc, argv);
+    // }
+    // else if(strcmp(argv[1], "batched_gemm") == 0)
+    // {
+    //     return profile_batched_gemm(argc, argv);
+    // }
+    // else if(strcmp(argv[1], "batched_gemm_reduce") == 0)
+    // {
+    //     return profile_batched_gemm_reduce(argc, argv);
+    // }
+    // else if(strcmp(argv[1], "grouped_gemm") == 0)
+    // {
+    //     profile_grouped_gemm(argc, argv);
+    // }
+    // else if(strcmp(argv[1], "conv_fwd") == 0)
+    // {
+    //     return profile_conv_fwd(argc, argv);
+    // }
+    // else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0)
+    // {
+    //     return profile_conv_fwd_bias_relu(argc, argv);
+    // }
+    // else if(strcmp(argv[1], "conv_fwd_bias_relu_add") == 0)
+    // {
+    //     return profile_conv_fwd_bias_relu_add(argc, argv);
+    // }
+    // else if(strcmp(argv[1], "conv_fwd_bias_relu_atomic_add") == 0)
+    // {
+    //     return profile_conv_fwd_bias_relu_atomic_add(argc, argv);
+    // }
+    // else if(strcmp(argv[1], "conv1d_bwd_data") == 0)
+    // {
+    //     return profile_convnd_bwd_data(argc, argv, 1);
+    // }
+    // else if(strcmp(argv[1], "conv2d_bwd_data") == 0)
+    // {
+    //     return profile_convnd_bwd_data(argc, argv, 2);
+    // }
+    // else if(strcmp(argv[1], "conv3d_bwd_data") == 0)
+    // {
+    //     return profile_convnd_bwd_data(argc, argv, 3);
+    // }
+    // else if(strcmp(argv[1], "reduce") == 0)
+    // {
+    //     return profile_reduce(argc, argv);
+    // }
+    // else if(strcmp(argv[1], "conv2d_bwd_weight") == 0)
+    // {
+    //     return profile_conv_bwd_weight(argc, argv);
+    // }
     else
     {
         // clang-format off
......