Commit f26fb605 authored by wangshaojie6's avatar wangshaojie6
Browse files

Merge branch 'develop' into bwd_weight_bf16_splitk

parents 32d06c66 1677cf70
...@@ -212,30 +212,50 @@ def runCKProfiler(Map conf=[:]){ ...@@ -212,30 +212,50 @@ def runCKProfiler(Map conf=[:]){
{ {
cmake_build(conf) cmake_build(conf)
dir("script"){ dir("script"){
def perf_log = "perf_gemm_${gpu_arch}.log" //run gemm performance tests
sh "rm -f ${perf_log}" def gemm_log = "perf_gemm_${gpu_arch}.log"
sh "echo Branch name: ${env.BRANCH_NAME} > ${perf_log}" sh "rm -f ${gemm_log}"
sh "./profile_gemm.sh gemm 0 0 0 1 0 5 | tee -a ${perf_log}" sh "echo Branch name: ${env.BRANCH_NAME} > ${gemm_log}"
sh "./profile_gemm.sh gemm 1 0 0 1 0 5 | tee -a ${perf_log}" sh "echo Node name: ${NODE_NAME} >> ${gemm_log}"
sh "./profile_gemm.sh gemm 2 0 0 1 0 5 | tee -a ${perf_log}" sh "echo GPU_arch: ${gpu_arch} >> ${gemm_log}"
sh "./profile_gemm.sh gemm 3 0 0 1 0 5 | tee -a ${perf_log}" sh "hipcc --version | grep -e 'HIP version' >> ${gemm_log}"
sh "./profile_gemm.sh gemm 0 1 0 1 0 5 | tee -a ${perf_log}" sh "/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${gemm_log}"
sh "./profile_gemm.sh gemm 1 1 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 0 0 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 2 1 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 1 0 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 3 1 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 2 0 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 0 2 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 3 0 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 1 2 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 0 1 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 2 2 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 1 1 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 3 2 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 2 1 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 0 3 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 3 1 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 1 3 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 0 2 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 2 3 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 1 2 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 3 3 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 2 2 0 1 0 5 | tee -a ${gemm_log}"
//results will be parsed, stored, and analyzed within the python script sh "./profile_gemm.sh gemm 3 2 0 1 0 5 | tee -a ${gemm_log}"
//the script will return 0 if the performance criteria are met sh "./profile_gemm.sh gemm 0 3 0 1 0 5 | tee -a ${gemm_log}"
//or return 1 if the criteria are not met sh "./profile_gemm.sh gemm 1 3 0 1 0 5 | tee -a ${gemm_log}"
archiveArtifacts "${perf_log}" sh "./profile_gemm.sh gemm 2 3 0 1 0 5 | tee -a ${gemm_log}"
sh "python3 parse_perf_data.py ${perf_log} " sh "./profile_gemm.sh gemm 3 3 0 1 0 5 | tee -a ${gemm_log}"
//results will be parsed, stored, and analyzed within the python script
//the script will return 0 if the performance criteria are met
//or return 1 if the criteria are not met
archiveArtifacts "${gemm_log}"
sh "python3 parse_perf_data.py ${gemm_log} "
//run resnet50 test
def resnet_log = "perf_resnet50_${gpu_arch}.log"
sh "rm -f ${resnet_log}"
sh "echo Branch name: ${env.BRANCH_NAME} > ${resnet_log}"
sh "echo Node name: ${NODE_NAME} >> ${resnet_log}"
sh "echo GPU_arch: ${gpu_arch} >> ${resnet_log}"
sh "hipcc --version | grep -e 'HIP version' >> ${resnet_log}"
sh "/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${resnet_log}"
//first run tests with N=256
sh "./profile_conv.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 256 | tee -a ${resnet_log}"
//then run with N=4
sh "./profile_conv.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 4 | tee -a ${resnet_log}"
archiveArtifacts "${resnet_log}"
//the script will put the results from N=256 and N=4 runs into separate tables
sh "python3 parse_perf_data.py ${resnet_log} "
} }
} }
} }
......
...@@ -4,4 +4,5 @@ add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) ...@@ -4,4 +4,5 @@ add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp)
add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
add_example_executable(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
...@@ -170,9 +170,7 @@ int main(int argc, char* argv[]) ...@@ -170,9 +170,7 @@ int main(int argc, char* argv[])
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
std::cout << "wrong! device_gemm with the specified compilation parameters does " std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
"not support this GEMM problem"
<< std::endl;
return 0; return 0;
} }
......
...@@ -169,9 +169,7 @@ int main(int argc, char* argv[]) ...@@ -169,9 +169,7 @@ int main(int argc, char* argv[])
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
std::cout << "wrong! device_gemm with the specified compilation parameters does " std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
"not support this GEMM problem"
<< std::endl;
return 0; return 0;
} }
......
...@@ -167,9 +167,7 @@ int main(int argc, char* argv[]) ...@@ -167,9 +167,7 @@ int main(int argc, char* argv[])
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
std::cout << "wrong! device_gemm with the specified compilation parameters does " std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
"not support this GEMM problem"
<< std::endl;
return 0; return 0;
} }
......
...@@ -193,9 +193,9 @@ int main(int argc, char* argv[]) ...@@ -193,9 +193,9 @@ int main(int argc, char* argv[])
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
throw std::runtime_error( std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"); return 0;
} }
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
......
...@@ -166,9 +166,9 @@ int main(int argc, char* argv[]) ...@@ -166,9 +166,9 @@ int main(int argc, char* argv[])
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
throw std::runtime_error( std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"); return 0;
} }
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
......
...@@ -21,8 +21,6 @@ template <ck::index_t... Is> ...@@ -21,8 +21,6 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using F64 = double; using F64 = double;
using F32 = float;
using F16 = ck::half_t;
using ADataType = double; using ADataType = double;
using BDataType = double; using BDataType = double;
...@@ -195,9 +193,9 @@ int main(int argc, char* argv[]) ...@@ -195,9 +193,9 @@ int main(int argc, char* argv[])
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
throw std::runtime_error( std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"); return 0;
} }
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
...@@ -233,7 +231,7 @@ int main(int argc, char* argv[]) ...@@ -233,7 +231,7 @@ int main(int argc, char* argv[])
show_2d_matrix(std::cout << "c_host :", c_m_n_host_result) << std::endl; show_2d_matrix(std::cout << "c_host :", c_m_n_host_result) << std::endl;
} }
#endif #endif
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
} }
return 0; return 0;
......
...@@ -194,9 +194,9 @@ int main(int argc, char* argv[]) ...@@ -194,9 +194,9 @@ int main(int argc, char* argv[])
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
throw std::runtime_error( std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"); return 0;
} }
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
......
...@@ -224,10 +224,10 @@ int main(int argc, char* argv[]) ...@@ -224,10 +224,10 @@ int main(int argc, char* argv[])
{ {
case 0: break; case 0: break;
case 1: case 1:
input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}); input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 2});
weights.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5}); weights.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-2, 2});
bias.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5}); bias.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
residual.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5}); residual.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
break; break;
default: default:
input.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0}); input.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
......
add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp)
add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
target_link_libraries(example_convnd_fwd_xdl_fp64 PRIVATE conv_util) target_link_libraries(example_convnd_fwd_xdl_fp64 PRIVATE conv_util)
target_link_libraries(example_convnd_fwd_xdl_fp32 PRIVATE conv_util) target_link_libraries(example_convnd_fwd_xdl_fp32 PRIVATE conv_util)
target_link_libraries(example_convnd_fwd_xdl_int8 PRIVATE conv_util) target_link_libraries(example_convnd_fwd_xdl_int8 PRIVATE conv_util)
......
...@@ -147,8 +147,6 @@ class SimpleAppArgs ...@@ -147,8 +147,6 @@ class SimpleAppArgs
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
using namespace ck::host_reduce;
const std::vector<int> reduceDims{0, 1, 2}; const std::vector<int> reduceDims{0, 1, 2};
const std::vector<int> invariantDims{3}; const std::vector<int> invariantDims{3};
...@@ -254,7 +252,9 @@ int main(int argc, char* argv[]) ...@@ -254,7 +252,9 @@ int main(int argc, char* argv[])
ReductionHost<InDataType, ReductionHost<InDataType,
AccDataType, AccDataType,
OutDataType, OutDataType,
ReduceOpId, ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
Rank, Rank,
NumReduceDim, NumReduceDim,
PropagateNan, PropagateNan,
......
...@@ -108,8 +108,6 @@ int main(int argc, char* argv[]) ...@@ -108,8 +108,6 @@ int main(int argc, char* argv[])
const std::vector<size_t> outLengths = {64, 320, 80}; const std::vector<size_t> outLengths = {64, 320, 80};
using namespace ck::host_reduce;
if(argc == 1) if(argc == 1)
{ {
do_verify = true; do_verify = true;
...@@ -191,7 +189,9 @@ int main(int argc, char* argv[]) ...@@ -191,7 +189,9 @@ int main(int argc, char* argv[])
ReductionHost<InOutDataType, ReductionHost<InOutDataType,
AccDataType, AccDataType,
InOutDataType, InOutDataType,
ReduceOpId, ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
5, // Rank 5, // Rank
2, // NumReduceDim 2, // NumReduceDim
PropagateNan, PropagateNan,
......
...@@ -8,10 +8,12 @@ ...@@ -8,10 +8,12 @@
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "host_tensor_generator.hpp" #include "host_tensor_generator.hpp"
#include "host_reduce_util.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
#include "reduction_enums.hpp" #include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "reduction_functions_accumulate.hpp"
#include "device_pool2d_fwd_nhwc_nhwc.hpp" #include "device_pool2d_fwd_nhwc_nhwc.hpp"
template <typename InDataType, template <typename InDataType,
...@@ -29,19 +31,24 @@ static void pool_host_verify(const Tensor<InDataType>& in, ...@@ -29,19 +31,24 @@ static void pool_host_verify(const Tensor<InDataType>& in,
const std::array<ck::index_t, 2>& in_left_pads, const std::array<ck::index_t, 2>& in_left_pads,
const std::array<ck::index_t, 2>& /*in_right_pads*/) const std::array<ck::index_t, 2>& /*in_right_pads*/)
{ {
using namespace ck::host_reduce;
const int32_t divider = window_spatial_lengths[0] * window_spatial_lengths[1]; const int32_t divider = window_spatial_lengths[0] * window_spatial_lengths[1];
const auto PreUnaryOp = PreUnaryOpFn<AccDataType, ReduceOpId>(divider); using ReduceOperation = typename ck::reduce_binary_operator<AccDataType, ReduceOpId>::opType;
const auto PosUnaryOp = PosUnaryOpFn<AccDataType, ReduceOpId>(divider); using InElementwiseOperation = typename ck::
reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation = typename ck::
reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation;
const InElementwiseOperation in_elementwise_op(divider);
const AccElementwiseOperation acc_elementwise_op(divider);
if constexpr(!OutputIndex) if constexpr(!OutputIndex)
{ {
auto opReduce = ReduceOpFn<AccDataType, ReduceOpId>(); using Accumulation =
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>(); auto accuVal = ReduceOperation::GetIdentityValue();
for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y) for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
{ {
...@@ -54,14 +61,14 @@ static void pool_host_verify(const Tensor<InDataType>& in, ...@@ -54,14 +61,14 @@ static void pool_host_verify(const Tensor<InDataType>& in,
{ {
AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi)); AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
PreUnaryOp(currVal); in_elementwise_op(currVal, currVal);
binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal); Accumulation::Calculate(accuVal, currVal);
} }
} }
} }
PosUnaryOp(accuVal); acc_elementwise_op(accuVal, accuVal);
out(n, c, ho, wo) = accuVal; out(n, c, ho, wo) = accuVal;
}; };
...@@ -74,10 +81,12 @@ static void pool_host_verify(const Tensor<InDataType>& in, ...@@ -74,10 +81,12 @@ static void pool_host_verify(const Tensor<InDataType>& in,
} }
else else
{ {
auto opReduce = ReduceOpFn2<AccDataType, ReduceOpId>(); using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { AccDataType,
auto accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>(); IndexDataType>;
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOperation::GetIdentityValue();
IndexDataType accuIndex = 0; IndexDataType accuIndex = 0;
for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y) for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
...@@ -92,15 +101,14 @@ static void pool_host_verify(const Tensor<InDataType>& in, ...@@ -92,15 +101,14 @@ static void pool_host_verify(const Tensor<InDataType>& in,
AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi)); AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
IndexDataType currIndex = y * window_spatial_lengths[1] + x; IndexDataType currIndex = y * window_spatial_lengths[1] + x;
PreUnaryOp(currVal); in_elementwise_op(currVal, currVal);
binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>( Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
opReduce, accuVal, currVal, accuIndex, currIndex);
} }
} }
} }
PosUnaryOp(accuVal); acc_elementwise_op(accuVal, accuVal);
out(n, c, ho, wo) = accuVal; out(n, c, ho, wo) = accuVal;
out_indices(n, c, ho, wo) = accuIndex; out_indices(n, c, ho, wo) = accuIndex;
...@@ -139,8 +147,6 @@ bool pool_test(bool do_verification, ...@@ -139,8 +147,6 @@ bool pool_test(bool do_verification,
ck::index_t in_right_pad_h, ck::index_t in_right_pad_h,
ck::index_t in_right_pad_w) ck::index_t in_right_pad_w)
{ {
using namespace ck::host_reduce;
using DevicePoolFwdInstance = using DevicePoolFwdInstance =
ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C< ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<
InDataType, // InDataType InDataType, // InDataType
......
...@@ -27,8 +27,6 @@ static constexpr bool PropagateNan = false; ...@@ -27,8 +27,6 @@ static constexpr bool PropagateNan = false;
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
using namespace ck::host_reduce;
bool do_verification; bool do_verification;
int init_method; int init_method;
bool time_kernel; bool time_kernel;
......
...@@ -27,8 +27,6 @@ static constexpr bool PropagateNan = false; ...@@ -27,8 +27,6 @@ static constexpr bool PropagateNan = false;
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
using namespace ck::host_reduce;
bool do_verification; bool do_verification;
int init_method; int init_method;
bool time_kernel; bool time_kernel;
......
...@@ -78,7 +78,7 @@ int main(int argc, char* argv[]) ...@@ -78,7 +78,7 @@ int main(int argc, char* argv[])
exit(0); exit(0);
} }
int group_count = 4; int group_count = rand() % 16 + 1;
// GEMM shape // GEMM shape
std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes; std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes;
...@@ -189,12 +189,17 @@ int main(int argc, char* argv[]) ...@@ -189,12 +189,17 @@ int main(int argc, char* argv[])
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{}; auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
// do GEMM
auto argument = auto argument =
gemm.MakeArgument(p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op); gemm.MakeArgument(p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op);
DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
throw std::runtime_error( throw std::runtime_error(
......
add_example_executable(example_gemm_reduce_xdl_max_fp16 gemm_reduce_xdl_max_fp16.cpp) add_example_executable(example_gemm_reduce_xdl_max_fp16 gemm_reduce_xdl_max_fp16.cpp)
add_example_executable(example_gemm_reduce_xdl_sum_squaresum_fp16 gemm_reduce_xdl_sum_squaresum_fp16.cpp) add_example_executable(example_gemm_reduce_xdl_mean_squaremean_fp16 gemm_reduce_xdl_mean_squaremean_fp16.cpp)
...@@ -29,10 +29,10 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -29,10 +29,10 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using ADataType = F16; using ADataType = F16;
using BDataType = F16; using BDataType = F16;
using CDataType = F16; using CDataType = F16;
using GemmAccDataType = F32;
using ReduceAccDataType = F32; using ReduceAccDataType = F32;
using DDataType = F64; using DDataType = F64;
using DPtrsGlobal = ck::Tuple<DDataType*>; using DPtrsGlobal = ck::Tuple<DDataType*>;
using AccDataType = F32;
using ALayout = ck::tensor_layout::gemm::RowMajor; using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor;
...@@ -52,15 +52,34 @@ static constexpr auto GemmSpecialization = ...@@ -52,15 +52,34 @@ static constexpr auto GemmSpecialization =
// clang-format off // clang-format off
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| //######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| //######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, ReduceAccDataType, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DsReduceOp, DsElementOp, DsElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; < Row, Col, Row, F16, F16, F16, F32, F32, ReduceAccDataType, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DsReduceOp, DsElementOp, DsElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>; BDataType,
CDataType,
GemmAccDataType,
AElementOp,
BElementOp,
CElementOp>;
template <typename ADataType, typename BDataType, typename CDataType, typename DDataType>
void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K)
{
std::size_t gemm_flop = std::size_t(2) * M * N * K;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(DDataType) * M;
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time;
std::cout << "gemm + reduceMax Perf: " << gemm_reduce_time << " ms, " << tflops << " TFlops, "
<< gemm_gb_per_sec << " GB/s, " << std::endl;
}
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
...@@ -193,21 +212,10 @@ int main(int argc, char* argv[]) ...@@ -193,21 +212,10 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
// init D // [CAUSION]: launch_and_time_kernel will not initialize D.
// If we evaluate kernel multiple time but without initialize D. Verification will fail
d_device_buf.SetValue(ck::NumericLimits<DDataType>::Lowest()); d_device_buf.SetValue(ck::NumericLimits<DDataType>::Lowest());
invoker.Run(argument, StreamConfig{nullptr, false});
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
bool pass = true; bool pass = true;
...@@ -228,7 +236,7 @@ int main(int argc, char* argv[]) ...@@ -228,7 +236,7 @@ int main(int argc, char* argv[])
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
ReduceAccDataType d_acc = d_reduce_op.GetReductionZeroVal(); ReduceAccDataType d_acc = d_reduce_op.GetIdentityValue();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
d_reduce_op(d_acc, c_m_n_host_result(m, n)); d_reduce_op(d_acc, c_m_n_host_result(m, n));
...@@ -246,5 +254,13 @@ int main(int argc, char* argv[]) ...@@ -246,5 +254,13 @@ int main(int argc, char* argv[])
1e-3); 1e-3);
} }
if(time_kernel)
{
float gemm_reduceMax_ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, DDataType>(
gemm_reduceMax_ave_time, M, N, K);
}
return pass ? 0 : 1; return pass ? 0 : 1;
} }
...@@ -29,10 +29,10 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -29,10 +29,10 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using ADataType = F16; using ADataType = F16;
using BDataType = F16; using BDataType = F16;
using CDataType = F16; using CDataType = F16;
using GemmAccDataType = F32;
using ReduceAccDataType = F32; using ReduceAccDataType = F32;
using DDataType = F32; using DDataType = F32;
using DPtrsGlobal = ck::Tuple<DDataType*, DDataType*>; using DPtrsGlobal = ck::Tuple<DDataType*, DDataType*>;
using AccDataType = F32;
using ALayout = ck::tensor_layout::gemm::RowMajor; using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor;
...@@ -47,10 +47,12 @@ using DxsReduceOp = ck::Tuple<D0ReduceOp, D1ReduceOp>; ...@@ -47,10 +47,12 @@ using DxsReduceOp = ck::Tuple<D0ReduceOp, D1ReduceOp>;
using UnaryIdenticElementOp = using UnaryIdenticElementOp =
ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, false>; ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, false>;
using UnaryDivElementOp =
ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, true>;
using UnarySquareElementOp = using UnarySquareElementOp =
ck::tensor_operation::element_wise::UnarySquare<ReduceAccDataType, ReduceAccDataType, false>; ck::tensor_operation::element_wise::UnarySquare<ReduceAccDataType, ReduceAccDataType, false>;
using DxsInElementOp = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>; using DxsInElementOp = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOp = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>; using DxsOutElementOp = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
using DGlobalMemOp = using DGlobalMemOp =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd, ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
...@@ -61,15 +63,35 @@ static constexpr auto GemmSpecialization = ...@@ -61,15 +63,35 @@ static constexpr auto GemmSpecialization =
// clang-format off // clang-format off
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| //######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| //######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; < Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>; BDataType,
CDataType,
GemmAccDataType,
AElementOp,
BElementOp,
CElementOp>;
template <typename ADataType, typename BDataType, typename CDataType, typename DDataType>
void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K)
{
std::size_t gemm_flop = std::size_t(2) * M * N * K;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(DDataType) * M +
sizeof(DDataType) * M;
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time;
std::cout << "gemm + reduce_mean + reduce_mean_square Perf: " << gemm_reduce_time << " ms, "
<< tflops << " TFlops, " << gemm_gb_per_sec << " GB/s, " << std::endl;
}
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
...@@ -182,6 +204,9 @@ int main(int argc, char* argv[]) ...@@ -182,6 +204,9 @@ int main(int argc, char* argv[])
auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()), auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer())); static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()));
auto dxs_in_element_op = DxsInElementOp{};
auto dxs_out_element_op = DxsOutElementOp{M, M};
// do GEMM // do GEMM
auto gemm = DeviceGemmReduceInstance{}; auto gemm = DeviceGemmReduceInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
...@@ -198,8 +223,8 @@ int main(int argc, char* argv[]) ...@@ -198,8 +223,8 @@ int main(int argc, char* argv[])
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
DxsInElementOp{}, dxs_in_element_op,
DxsOutElementOp{}); dxs_out_element_op);
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
...@@ -214,19 +239,7 @@ int main(int argc, char* argv[]) ...@@ -214,19 +239,7 @@ int main(int argc, char* argv[])
// if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result // if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result
// will not be correct. need to set time_kernel = false for correctness test // will not be correct. need to set time_kernel = false for correctness test
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); invoker.Run(argument, StreamConfig{nullptr, false});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
bool pass = true; bool pass = true;
if(do_verification) if(do_verification)
...@@ -248,8 +261,8 @@ int main(int argc, char* argv[]) ...@@ -248,8 +261,8 @@ int main(int argc, char* argv[])
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
float d0_acc = d0_reduce_op.GetReductionZeroVal(); float d0_acc = d0_reduce_op.GetIdentityValue();
float d1_acc = d1_reduce_op.GetReductionZeroVal(); float d1_acc = d1_reduce_op.GetIdentityValue();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
...@@ -257,12 +270,14 @@ int main(int argc, char* argv[]) ...@@ -257,12 +270,14 @@ int main(int argc, char* argv[])
float d0_val = 0; float d0_val = 0;
float d1_val = 0; float d1_val = 0;
UnaryIdenticElementOp{}(d0_val, c_val); dxs_in_element_op(ck::Number<0>{})(d0_val, c_val);
UnarySquareElementOp{}(d1_val, c_val); dxs_in_element_op(ck::Number<1>{})(d1_val, c_val);
d0_reduce_op(d0_acc, d0_val); d0_reduce_op(d0_acc, d0_val);
d1_reduce_op(d1_acc, d1_val); d1_reduce_op(d1_acc, d1_val);
} }
dxs_out_element_op(ck::Number<0>{})(d0_acc, d0_acc);
dxs_out_element_op(ck::Number<1>{})(d1_acc, d1_acc);
d0_m_host_result(m) = ck::type_convert<DDataType>(d0_acc); d0_m_host_result(m) = ck::type_convert<DDataType>(d0_acc);
d1_m_host_result(m) = ck::type_convert<DDataType>(d1_acc); d1_m_host_result(m) = ck::type_convert<DDataType>(d1_acc);
} }
...@@ -282,5 +297,12 @@ int main(int argc, char* argv[]) ...@@ -282,5 +297,12 @@ int main(int argc, char* argv[])
1e-5); 1e-5);
} }
if(time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, DDataType>(ave_time, M, N, K);
}
return pass ? 0 : 1; return pass ? 0 : 1;
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment