Commit 20718690 authored by Chao Liu

adding gemm pipeline

parent 18707866
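For orientation: the new GridwiseGemmPipeline_v2 splits each thread block into two cooperating thread groups, one streaming A/B tiles from global memory into LDS while the other runs the blockwise GEMM on tiles already resident in LDS. A condensed sketch of the dispatch this commit adds (arguments elided; a simplification, not the committed code):

__device__ void Run(/* descriptors, copy objects, buffers, num_loop */)
{
    // producer threads: double-buffered global -> LDS copies of A and B tiles
    if(ABBlockTransferThreadGroup::IsBelong())
    {
        RunABBlockTransferPipeline(/* ... */);
    }
    // consumer threads: blockwise GEMM reading tiles out of LDS
    else if(BlockGemmThreadGroup::IsBelong())
    {
        RunBlockGemmPipeline(/* ... */);
    }
}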
@@ -11,9 +11,10 @@
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "device_gemm_xdl_cshuffle.hpp"
//#include "device_gemm_xdl.hpp"
//#include "device_gemm_xdl_c_shuffle.hpp"
//#include "device_gemm_xdl_cshuffle.hpp"
#include "device_gemm_xdl_cshuffle_v2.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
@@ -42,15 +43,39 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// clang-format off
#if 0
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| DataType| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//< Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// // 1-stage prefetch
< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// // 2-stage prefetch
// < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
#elif 1
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_v2
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 8>, 8>;
// < Row, Col, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 8>, 8>;
#elif 1
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>;
// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>;
#endif
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
......
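The remainder of this example (elided above) drives DeviceGemmInstance through CK's usual device-op protocol. A hedged sketch in the style of CK's GEMM examples — the buffer names, nrepeat, and the problem-size variables are assumptions standing in for the elided setup code:

auto gemm     = DeviceGemmInstance{};
auto invoker  = gemm.MakeInvoker();

// a_device_buf / b_device_buf / c_device_buf: assumed DeviceMem objects already
// populated on the GPU; M, N, K and the strides come from the elided setup
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                  static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                                  static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                                  M, N, K, StrideA, StrideB, StrideC,
                                  AElementOp{}, BElementOp{}, CElementOp{});

if(!gemm.IsSupportedArgument(argument))
{
    throw std::runtime_error("this device GEMM instance does not support this problem");
}

// average kernel time over nrepeat timed launches
float ave_time = invoker.Run(argument, nrepeat);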
@@ -14,7 +14,7 @@
#define CK_USE_LAUNCH_BOUNDS 1
#ifdef CK_USE_LAUNCH_BOUNDS
#define CK_MAX_THREAD_PER_BLOCK 256
#define CK_MAX_THREAD_PER_BLOCK 512
#define CK_MIN_BLOCK_PER_CU 1
#endif
......
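The cap rises from 256 to 512 because the v2 gridwise GEMM below sizes its block as ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize (256 + 256 in the example instance above). These macros reach the kernels through a launch-bounds attribute; a sketch of that existing CK pattern (kernel_sketch is a placeholder name, not part of this commit):

__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_sketch(/* kernel arguments */)
{
    // launching more threads per block than the bound promises fails at
    // runtime, so a 512-thread producer/consumer block needs the raised cap
}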
@@ -71,35 +71,35 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
static __device__ void RunABBlockTransferPipeline(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
ABlockTransfer& a_block_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
BBlockTransfer& b_block_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
index_t num_loop)
{
// global read 0
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_block_copy.RunRead(a_grid_desc, a_grid_buf);
b_block_copy.RunRead(b_grid_desc, b_grid_buf);
// move to 1
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write 0
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
a_block_copy.RunWrite(a_block_desc, a_block_buf);
// global read 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
a_block_copy.RunRead(a_grid_desc, a_grid_buf);
// LDS write 0
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
b_block_copy.RunWrite(b_block_desc, b_block_buf);
// global read 1
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
b_block_copy.RunRead(b_grid_desc, b_grid_buf);
// main body
// FIXME: HasMainLoop = (num_loop) > 2
@@ -116,18 +116,18 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
block_sync_lds();
// move to i + 2
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write i + 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
a_block_copy.RunWrite(a_block_desc, a_block_buf);
// global read i + 2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
a_block_copy.RunRead(a_grid_desc, a_grid_buf);
// LDS write i + 1
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
b_block_copy.RunWrite(b_block_desc, b_block_buf);
// global read i + 2
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
b_block_copy.RunRead(b_grid_desc, b_grid_buf);
++i;
} while(i < (num_loop - 2));
@@ -142,8 +142,8 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
block_sync_lds();
// LDS write num_loop - 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
a_block_copy.RunWrite(a_block_desc, a_block_buf);
b_block_copy.RunWrite(b_block_desc, b_block_buf);
block_sync_lds();
@@ -153,7 +153,7 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
static __device__ void RunBlockGemmPipeline(ABlockBuffer& a_block_buf,
BBlockBuffer& b_block_buf,
const BlockwiseGemm& blockwise_gemm,
const BlockwiseGemm& block_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
@@ -171,7 +171,7 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
block_sync_lds();
// GEMM i
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
@@ -192,7 +192,7 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
block_sync_lds();
// GEMM num_loop - 2
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
@@ -201,46 +201,45 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
block_sync_lds();
// GEMM num_loop - 1
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
static __device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
ABlockTransfer& a_block_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
BBlockTransfer& b_block_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
const BlockwiseGemm& block_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
if(ABBlockTransferThreadGroup::IsBelong())
{
gridwise_gemm_pipeline.RunABBlockTransferPipeline(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
num_loop);
RunABBlockTransferPipeline(a_grid_desc,
a_block_desc,
a_block_copy,
a_grid_buf,
a_block_buf,
a_block_copy_step,
b_grid_desc,
b_block_desc,
b_block_copy,
b_grid_buf,
b_block_buf,
b_block_copy_step,
num_loop);
}
else if(BlockGemmThreadGroup::IsBelong())
{
gridwise_gemm_pipeline.RunBlockGemmPipeline(
a_block_buf, b_block_buf, blockwise_gemm, c_thread_buf, num_loop);
RunBlockGemmPipeline(a_block_buf, b_block_buf, block_gemm, c_thread_buf, num_loop);
}
}
};
......
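Taking the pipeline hunks above together, the steady state of the 2-deep prefetch schedules as follows (a condensed restatement of the loop above; prologue and tail iterations elided):

// main-loop iteration i, with both groups fenced by block_sync_lds():
//
//   AB copy group (producer)              block GEMM group (consumer)
//   ------------------------              ---------------------------
//   move A/B source windows to tile i+2   blockwise GEMM on tile i
//   LDS write tile i+1
//   global read tile i+2
//
// Global reads run two tiles ahead of the math and LDS writes one tile ahead,
// which is what lets the copy threads and the GEMM threads overlap.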
@@ -4,12 +4,11 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "blockwise_tensor_slice_transfer_v6r1.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
#include "gridwise_gemm_pipeline_v2.hpp"
namespace ck {
@@ -118,11 +117,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
using ThisThreadBlock =
AnyThreadBlock<ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize>;
#if 1
using ABBlockTransferThreadGroup = ThisThreadBlock;
using BlockGemmThreadGroup = ThisThreadBlock;
using CShuffleBlockTransferThreadGroup = ThisThreadBlock;
#else
struct ABBlockTransferThreadGroup
{
__device__ static constexpr index_t GetNumOfThread()
@@ -157,7 +151,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
};
using CShuffleBlockTransferThreadGroup = ThisThreadBlock;
#endif
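// Sketch of the interface the partitioned thread groups above provide (an
// assumption: the diff truncates the real definitions). With the ThisThreadBlock
// aliases deleted by this commit, membership tests like this now split the block:
//
// struct ABBlockTransferThreadGroup
// {
//     __device__ static constexpr index_t GetNumOfThread()
//     {
//         return ABBlockTransferThreadGroupSize;
//     }
//
//     __device__ static bool IsBelong()
//     {
//         // e.g. the first ABBlockTransferThreadGroupSize threads do the copies
//         return get_thread_local_1d_id() < ABBlockTransferThreadGroupSize;
//     }
// };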
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
@@ -494,7 +487,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
#if 1
#if 0
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
@@ -667,9 +660,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
CShuffleBlockTransferThreadGroup, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
......
@@ -24,38 +24,38 @@ include_directories(BEFORE
set(PROFILER_SOURCE
src/profiler.cpp
src/profile_gemm.cpp
src/profile_gemm_bias_2d.cpp
src/profile_gemm_bias_relu.cpp
src/profile_gemm_bias_relu_add.cpp
src/profile_gemm_reduce.cpp
src/profile_batched_gemm.cpp
src/profile_conv_fwd.cpp
src/profile_conv_fwd_bias_relu.cpp
src/profile_conv_fwd_bias_relu_add.cpp
src/profile_conv_fwd_bias_relu_atomic_add.cpp
src/profile_convnd_bwd_data.cpp
src/profile_reduce.cpp
src/profile_grouped_gemm.cpp
src/profile_conv_bwd_weight.cpp
src/profile_batched_gemm_reduce.cpp
# src/profile_gemm_bias_2d.cpp
# src/profile_gemm_bias_relu.cpp
# src/profile_gemm_bias_relu_add.cpp
# src/profile_gemm_reduce.cpp
# src/profile_batched_gemm.cpp
# src/profile_conv_fwd.cpp
# src/profile_conv_fwd_bias_relu.cpp
# src/profile_conv_fwd_bias_relu_add.cpp
# src/profile_conv_fwd_bias_relu_atomic_add.cpp
# src/profile_convnd_bwd_data.cpp
# src/profile_reduce.cpp
# src/profile_grouped_gemm.cpp
# src/profile_conv_bwd_weight.cpp
# src/profile_batched_gemm_reduce.cpp
)
add_executable(ckProfiler ${PROFILER_SOURCE})
target_link_libraries(ckProfiler PRIVATE host_tensor)
target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
#target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance)
target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance)
target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
#target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance)
#target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance)
#target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance)
#target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance)
#target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance)
#target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
#target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
#target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
#target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance)
#target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
@@ -26,70 +26,70 @@ int main(int argc, char* argv[])
{
return profile_gemm(argc, argv);
}
else if(strcmp(argv[1], "gemm_bias_2d") == 0)
{
return profile_gemm_bias_2d(argc, argv);
}
else if(strcmp(argv[1], "gemm_bias_relu") == 0)
{
return profile_gemm_bias_relu(argc, argv);
}
else if(strcmp(argv[1], "gemm_bias_relu_add") == 0)
{
return profile_gemm_bias_relu_add(argc, argv);
}
else if(strcmp(argv[1], "gemm_reduce") == 0)
{
return profile_gemm_reduce(argc, argv);
}
else if(strcmp(argv[1], "batched_gemm") == 0)
{
return profile_batched_gemm(argc, argv);
}
else if(strcmp(argv[1], "batched_gemm_reduce") == 0)
{
return profile_batched_gemm_reduce(argc, argv);
}
else if(strcmp(argv[1], "grouped_gemm") == 0)
{
profile_grouped_gemm(argc, argv);
}
else if(strcmp(argv[1], "conv_fwd") == 0)
{
return profile_conv_fwd(argc, argv);
}
else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0)
{
return profile_conv_fwd_bias_relu(argc, argv);
}
else if(strcmp(argv[1], "conv_fwd_bias_relu_add") == 0)
{
return profile_conv_fwd_bias_relu_add(argc, argv);
}
else if(strcmp(argv[1], "conv_fwd_bias_relu_atomic_add") == 0)
{
return profile_conv_fwd_bias_relu_atomic_add(argc, argv);
}
else if(strcmp(argv[1], "conv1d_bwd_data") == 0)
{
return profile_convnd_bwd_data(argc, argv, 1);
}
else if(strcmp(argv[1], "conv2d_bwd_data") == 0)
{
return profile_convnd_bwd_data(argc, argv, 2);
}
else if(strcmp(argv[1], "conv3d_bwd_data") == 0)
{
return profile_convnd_bwd_data(argc, argv, 3);
}
else if(strcmp(argv[1], "reduce") == 0)
{
return profile_reduce(argc, argv);
}
else if(strcmp(argv[1], "conv2d_bwd_weight") == 0)
{
return profile_conv_bwd_weight(argc, argv);
}
// else if(strcmp(argv[1], "gemm_bias_2d") == 0)
// {
// return profile_gemm_bias_2d(argc, argv);
// }
// else if(strcmp(argv[1], "gemm_bias_relu") == 0)
// {
// return profile_gemm_bias_relu(argc, argv);
// }
// else if(strcmp(argv[1], "gemm_bias_relu_add") == 0)
// {
// return profile_gemm_bias_relu_add(argc, argv);
// }
// else if(strcmp(argv[1], "gemm_reduce") == 0)
// {
// return profile_gemm_reduce(argc, argv);
// }
// else if(strcmp(argv[1], "batched_gemm") == 0)
// {
// return profile_batched_gemm(argc, argv);
// }
// else if(strcmp(argv[1], "batched_gemm_reduce") == 0)
// {
// return profile_batched_gemm_reduce(argc, argv);
// }
// else if(strcmp(argv[1], "grouped_gemm") == 0)
// {
// profile_grouped_gemm(argc, argv);
// }
// else if(strcmp(argv[1], "conv_fwd") == 0)
// {
// return profile_conv_fwd(argc, argv);
// }
// else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0)
// {
// return profile_conv_fwd_bias_relu(argc, argv);
// }
// else if(strcmp(argv[1], "conv_fwd_bias_relu_add") == 0)
// {
// return profile_conv_fwd_bias_relu_add(argc, argv);
// }
// else if(strcmp(argv[1], "conv_fwd_bias_relu_atomic_add") == 0)
// {
// return profile_conv_fwd_bias_relu_atomic_add(argc, argv);
// }
// else if(strcmp(argv[1], "conv1d_bwd_data") == 0)
// {
// return profile_convnd_bwd_data(argc, argv, 1);
// }
// else if(strcmp(argv[1], "conv2d_bwd_data") == 0)
// {
// return profile_convnd_bwd_data(argc, argv, 2);
// }
// else if(strcmp(argv[1], "conv3d_bwd_data") == 0)
// {
// return profile_convnd_bwd_data(argc, argv, 3);
// }
// else if(strcmp(argv[1], "reduce") == 0)
// {
// return profile_reduce(argc, argv);
// }
// else if(strcmp(argv[1], "conv2d_bwd_weight") == 0)
// {
// return profile_conv_bwd_weight(argc, argv);
// }
else
{
// clang-format off
......