Commit a3b4c5cb authored by wangshaojie6
Browse files

merge develop branch and add gridwise pipeline v3

parents 48918ab9 1677cf70
...@@ -13,84 +13,106 @@ ...@@ -13,84 +13,106 @@
#include "host_tensor_generator.hpp" #include "host_tensor_generator.hpp"
#include "host_gemm.hpp" #include "host_gemm.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_gemm_xdl.hpp" #include "device_gemm_xdl_cshuffle.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "reference_gemm.hpp" #include "reference_gemm.hpp"
#include "gemm_specialization.hpp" #include "gemm_specialization.hpp"
template <ck::index_t... Is> struct RequantReluRequant
using S = ck::Sequence<Is...>; {
// FIXME: We just need one scale for Relu / Leaky Relu / PRelu
RequantReluRequant(float scaleGemm, float scaleRelu)
: scaleGemm_(scaleGemm), scaleRelu_(scaleRelu)
{
}
using F32 = float; __host__ __device__ constexpr void operator()(float& y, const float& x) const
{
float gemm_requant = scaleGemm_ * x;
float relu = gemm_requant > 0 ? gemm_requant : 0;
float relu_requant = scaleRelu_ * relu;
y = relu_requant > 127 ? 127 : relu_requant < -128 ? -128 : relu_requant;
}
using Row = ck::tensor_layout::gemm::RowMajor; float scaleGemm_;
using Col = ck::tensor_layout::gemm::ColumnMajor; float scaleRelu_;
};
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using RequantReluRequant = ck::tensor_operation::element_wise::RequantReluRequant;
using ADataType = int8_t; using ADataType = int8_t;
using BDataType = int8_t; using BDataType = int8_t;
using CDataType = int8_t; using CDataType = int8_t;
using AccDataType = int32_t; using AccDataType = int32_t;
using CShuffleDataType = int32_t; using CShuffleDataType = float;
using ALayout = ck::tensor_layout::gemm::RowMajor; using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor; using CLayout = ck::tensor_layout::gemm::RowMajor;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off // clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle< using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<
ADataType, // ADataType ALayout, // typename ALayout,
BDataType, // BDataType BLayout, // typename BLayout,
CDataType, // CDataType CLayout, // typename CLayout,
AccDataType, // AccDataType ADataType, // typename ADataType,
CShuffleDataType, // CShuffleDataType BDataType, // typename BDataType,
ALayout, // ALayout CDataType, // typename CDataType,
BLayout, // BLayout AccDataType, // typename GemmAccDataType,
CLayout, // CLayout CShuffleDataType, // typename CShuffleDataType,
PassThrough, // AElementwiseOperation PassThrough, // typename AElementwiseOperation,
PassThrough, // BElementwiseOperation PassThrough, // typename BElementwiseOperation,
RequantReluRequant, // CElementwiseOperation RequantReluRequant, // typename CElementwiseOperation,
256, // BlockSize GemmDefault, // GemmSpecialization GemmSpec,
256, // MPerBlock 1, // index_t NumGemmKPrefetchStage,
128, // NPerBlock 256, // index_t BlockSize,
64, // KPerBlock 256, // index_t MPerBlock,
16, // AK1 128, // index_t NPerBlock,
16, // BK1 64, // index_t KPerBlock,
32, // MPerXDL 16, // index_t AK1,
32, // NPerXDL 16, // index_t BK1,
4, // MXdlPerWave 32, // index_t MPerXDL,
2, // NXdlPerWave 32, // index_t NPerXDL,
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 4, // index_t MXdlPerWave,
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder 2, // index_t NXdlPerWave,
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
2, // ABlockTransferSrcVectorDim S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder,
16, // ABlockTransferSrcScalarPerVector S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder,
16, // ABlockTransferDstScalarPerVector_K1 2, // index_t ABlockTransferSrcVectorDim,
true, // ABlockLdsAddExtraM 16, // index_t ABlockTransferSrcScalarPerVector,
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 16, // index_t ABlockTransferDstScalarPerVector_AK1,
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder 1, // bool ABlockLdsExtraM,
S<1, 0, 2>, // BBlockTransferSrcAccessOrder S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
2, // BBlockTransferSrcVectorDim S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder,
16, // BBlockTransferSrcScalarPerVector S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder,
16, // BBlockTransferDstScalarPerVector_K1 2, // index_t BBlockTransferSrcVectorDim,
true, // BBlockLdsAddExtraN 8, // index_t BBlockTransferSrcScalarPerVector,
1, // CShuffleMXdlPerWavePerShuffle 8, // index_t BBlockTransferDstScalarPerVector_BK1,
1, // CShuffleNXdlPerWavePerShuffle 1, // bool BBlockLdsExtraN,
S<1, 1, 64, 1, 1, 4>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl 1, // index_t CShuffleMXdlPerWavePerShuffle,
16>; // CBlockTransferScalarPerVector_NWaveNPerXdl 1, // index_t CShuffleNXdlPerWavePerShuffle,
S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, RequantReluRequant>; BDataType,
CDataType,
float,
PassThrough,
PassThrough,
RequantReluRequant>;
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
// GEMM shape // GEMM shape
ck::index_t M = 3840; ck::index_t M = 3840;
...@@ -108,13 +130,13 @@ int main(int argc, char* argv[]) ...@@ -108,13 +130,13 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 10) else if(argc == 10)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); M = std::stoi(argv[4]);
N = std::stoi(argv[5]); N = std::stoi(argv[5]);
...@@ -128,7 +150,7 @@ int main(int argc, char* argv[]) ...@@ -128,7 +150,7 @@ int main(int argc, char* argv[])
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0); exit(0);
} }
...@@ -202,7 +224,7 @@ int main(int argc, char* argv[]) ...@@ -202,7 +224,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
float ave_time = invoker.Run(argument, nrepeat); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = std::size_t num_btype =
...@@ -227,7 +249,7 @@ int main(int argc, char* argv[]) ...@@ -227,7 +249,7 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
} }
return 0; return 0;
......
add_example_executable(example_gemm_reduce_xdl_fp16 gemm_reduce_xdl_fp16.cpp) add_example_executable(example_gemm_reduce_xdl_max_fp16 gemm_reduce_xdl_max_fp16.cpp)
add_example_executable(example_gemm_reduce_xdl_mean_squaremean_fp16 gemm_reduce_xdl_mean_squaremean_fp16.cpp)
add_example_executable(example_convnd_bwd_data_xdl convnd_bwd_data_xdl.cpp) add_example_executable(example_convnd_bwd_data_xdl convnd_bwd_data_xdl.cpp)
target_link_libraries(example_convnd_bwd_data_xdl PRIVATE conv_fwd_util) target_link_libraries(example_convnd_bwd_data_xdl PRIVATE conv_util)
add_example_executable(example_broadcast_add_2d_amn_bn broadcast_add_2d_amn_bn.cpp)
add_example_executable(example_broadcast_add_3d_am_bmnk broadcast_add_3d_am_bmnk.cpp)
add_example_executable(example_elementwise_add_1d elementwise_add_1d.cpp)
add_example_executable(example_elementwise_add_4d elementwise_add_4d.cpp)
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
add_example_executable(example_convnd_bwd_weight_xdl convnd_bwd_weight_xdl.cpp)
target_link_libraries(example_convnd_bwd_weight_xdl PRIVATE conv_util)
\ No newline at end of file
This diff is collapsed.
add_example_executable(example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp)
This diff is collapsed.
add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp)
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment