From 0ee3aea16af66fd33282ce7a505533377fb3a74f Mon Sep 17 00:00:00 2001
From: Illia Silin <98187287+illsilin@users.noreply.github.com>
Date: Wed, 26 Oct 2022 09:25:27 -0700
Subject: [PATCH 01/80] fix the script parsing the QA results (#495)

---
 script/process_perf_data.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/script/process_perf_data.py b/script/process_perf_data.py
index de1703cfc..638e4ef56 100644
--- a/script/process_perf_data.py
+++ b/script/process_perf_data.py
@@ -81,7 +81,7 @@ def parse_logfile(logfile):
     StrideA=[]
     StrideB=[]
     StrideC=[]
-    if 'perf_gemm' in logfile:
+    if 'perf_gemm.log' in logfile:
         for line in open(logfile):
             if 'Best Perf' in line:
                 lst=line.split()
@@ -120,14 +120,14 @@ def parse_logfile(logfile):
         res = [x for _,x in sorted(zip(tests,tflops))]
         #sorted_kernels = [x for _,x in sorted(zip(tests,kernels))]
         test_list=list(range(1,len(tests)+1))
-    #parse conv_fwd performance tests:
-    elif 'conv_fwd' in logfile:
+    #parse conv_fwd and conv_bwd performance tests:
+    elif 'conv_fwd' in logfile or 'conv_bwd_data' in logfile:
         for line in open(logfile):
             if 'tflops:' in line:
                 lst=line.split()
                 res.append(lst[1])
     #parse all other performance tests:
-    elif 'resnet50' in logfile or 'batched_gemm' in logfile or 'grouped_gemm' in logfile or 'conv_bwd_data' in logfile or 'gemm_bilinear' in logfile or 'reduction' in logfile:
+    elif 'resnet50' in logfile or 'batched_gemm' in logfile or 'grouped_gemm' in logfile or 'gemm_bilinear' in logfile or 'reduction' in logfile:
         for line in open(logfile):
             if 'Best Perf' in line:
                 lst=line.split()
@@ -149,7 +149,7 @@ def store_new_test_result(table_name, test_results, testlist, branch_name, node_
     df=pd.DataFrame(data=[params],columns=['Branch_ID','Node_ID','GPU_arch','Compute Units','ROCM_version','HIP_version','Environment','Datetime'])
     df_add=pd.DataFrame(data=[test_results],columns=testlist)
     df=pd.concat([df,df_add],axis=1)
-    print("new test results dataframe:",df)
+    #print("new test results dataframe:",df)
     df.to_sql(table_name,connection,if_exists='append',index=False)
     return 0
 
@@ -165,7 +165,7 @@ def compare_test_to_baseline(baseline,test,testlist):
                 print("test # ",i,"shows regression by {:.3f}%".format(
                     (float(test[i])-base_list[i])/base_list[i]*100))
                 regression=1
-            ave_perf=ave_perf+float(test[i])/base_list[i]
+            if base_list[i]>0: ave_perf=ave_perf+float(test[i])/base_list[i]
     if regression==0:
         print("no regressions found")
     ave_perf=ave_perf/len(base_list)
@@ -248,7 +248,7 @@ def main():
     conn = sqlEngine.connect()
 
     #save gemm performance tests:
-    if 'perf_gemm' in filename:
+    if 'perf_gemm.log' in filename:
         #write the ck_gemm_test_params table only needed once the test set changes
         #post_test_params(test_list,conn)
         for i in range(1,len(results)+1):
-- 
GitLab

From 57106048aeb20f55461e7c25e689aa0a945beb7a Mon Sep 17 00:00:00 2001
From: Anthony Chang
Date: Fri, 28 Oct 2022 02:25:12 +0800
Subject: [PATCH 02/80] Gemm standalone bench executable (#480)

* prototype

4 layouts

fix default stride

all problem sizes

tidy

move file

update build script

restore old file

fix build

* refactor standalone test to use gemm test harness
* simplify gemm test
* update build script
* remove redundant
* early return when cmd arg doesn't match
* tidy
* report failure when result not validated
* tidy
* Apply suggestions from code review

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
---
 test/gemm/CMakeLists.txt                    |  10 ++
 test/gemm/gemm_bf16.cpp                     |  57 +------
 test/gemm/gemm_fp16.cpp                     |  57 +------
 test/gemm/gemm_fp32.cpp                     |  57 +------
 test/gemm/gemm_fp64.cpp                     |  57 +------
 test/gemm/gemm_int8.cpp                     |  57 +------
 test/gemm/gemm_standalone_xdl_fp16.cpp      | 162 ++++++++++++++++++++
 test/gemm/gemm_util.hpp                     | 107 ++++++++-----
 test/gemm/instance/gemm_f16_nn_instance.cpp |  86 +++++++++++
 test/gemm/instance/gemm_f16_nn_instance.hpp |  41 +++++
 test/gemm/instance/gemm_f16_nt_instance.cpp |  86 +++++++++++
 test/gemm/instance/gemm_f16_nt_instance.hpp |  41 +++++
 test/gemm/instance/gemm_f16_tn_instance.cpp |  86 +++++++++++
 test/gemm/instance/gemm_f16_tn_instance.hpp |  41 +++++
 test/gemm/instance/gemm_f16_tt_instance.cpp |  86 +++++++++++
 test/gemm/instance/gemm_f16_tt_instance.hpp |  41 +++++
 test/gemm/run_gemm_test.inc                 |  41 +++++
 17 files changed, 816 insertions(+), 297 deletions(-)
 create mode 100644 test/gemm/gemm_standalone_xdl_fp16.cpp
 create mode 100644 test/gemm/instance/gemm_f16_nn_instance.cpp
 create mode 100644 test/gemm/instance/gemm_f16_nn_instance.hpp
 create mode 100644 test/gemm/instance/gemm_f16_nt_instance.cpp
 create mode 100644 test/gemm/instance/gemm_f16_nt_instance.hpp
 create mode 100644 test/gemm/instance/gemm_f16_tn_instance.cpp
 create mode 100644 test/gemm/instance/gemm_f16_tn_instance.hpp
 create mode 100644 test/gemm/instance/gemm_f16_tt_instance.cpp
 create mode 100644 test/gemm/instance/gemm_f16_tt_instance.hpp
 create mode 100644 test/gemm/run_gemm_test.inc

diff --git a/test/gemm/CMakeLists.txt b/test/gemm/CMakeLists.txt
index 8069dac15..c427586bb 100644
--- a/test/gemm/CMakeLists.txt
+++ b/test/gemm/CMakeLists.txt
@@ -13,3 +13,13 @@ target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_instance)
 add_test_executable(test_gemm_int8 gemm_int8.cpp)
 target_link_libraries(test_gemm_int8 PRIVATE utility)
 target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance)
+
+add_library(gemm_standalone_xdl_fp16_instances STATIC
+    instance/gemm_f16_nn_instance.cpp
+    instance/gemm_f16_nt_instance.cpp
+    instance/gemm_f16_tn_instance.cpp
+    instance/gemm_f16_tt_instance.cpp
+)
+add_test_executable(test_gemm_standalone_xdl_fp16 gemm_standalone_xdl_fp16.cpp)
+target_link_libraries(test_gemm_standalone_xdl_fp16 PRIVATE gemm_standalone_xdl_fp16_instances utility)
+target_include_directories(test_gemm_standalone_xdl_fp16 PRIVATE instance/)
diff --git a/test/gemm/gemm_bf16.cpp b/test/gemm/gemm_bf16.cpp
index 6130ec9bc..5290d4663 100644
--- a/test/gemm/gemm_bf16.cpp
+++ b/test/gemm/gemm_bf16.cpp
@@ -24,56 +24,11 @@
 
 #include "test/gemm/gemm_util.hpp"
 
-int main()
-{
-    using ADataType   = ck::bhalf_t;
-    using BDataType   = ck::bhalf_t;
-    using CDataType   = ck::bhalf_t;
-    using AccDataType = float;
+using ADataType   = ck::bhalf_t;
+using BDataType   = ck::bhalf_t;
+using CDataType   = ck::bhalf_t;
+using AccDataType = float;
 
-    using Row = ck::tensor_layout::gemm::RowMajor;
-    using Col = ck::tensor_layout::gemm::ColumnMajor;
+#include "run_gemm_test.inc"
 
-    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-
-    auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
-        bool pass = true;
-
-        using DeviceOp = ck::tensor_operation::device::DeviceGemm<decltype(a_layout),
-                                                                  decltype(b_layout),
-                                                                  decltype(c_layout),
-                                                                  ADataType,
-                                                                  BDataType,
-                                                                  CDataType,
-                                                                  PassThrough,
-                                                                  PassThrough,
-                                                                  PassThrough>;
-
-        const auto gemmPtrs =
-            ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
-                DeviceOp>::GetInstances();
-
-        for(auto& gemmPtr : gemmPtrs)
-        {
-            pass &= ck::gemm_util::TestGemm<std::unique_ptr<DeviceOp>,
-                                            ADataType,
-                                            BDataType,
-                                            CDataType,
-                                            AccDataType,
-                                            decltype(a_layout),
-                                            decltype(b_layout),
-                                            decltype(c_layout),
-                                            PassThrough,
-                                            PassThrough,
-                                            PassThrough>{}(gemmPtr);
-        }
-
-        return pass;
-    };
-
-    bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) &&
-                test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{});
-
-    std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
-    return pass ? 0 : 1;
-}
+int main() { return run_gemm_test(); }
diff --git a/test/gemm/gemm_fp16.cpp b/test/gemm/gemm_fp16.cpp
index 05e696cad..92e225def 100644
--- a/test/gemm/gemm_fp16.cpp
+++ b/test/gemm/gemm_fp16.cpp
@@ -24,56 +24,11 @@
 
 #include "test/gemm/gemm_util.hpp"
 
-int main()
-{
-    using ADataType   = ck::half_t;
-    using BDataType   = ck::half_t;
-    using CDataType   = ck::half_t;
-    using AccDataType = float;
+using ADataType   = ck::half_t;
+using BDataType   = ck::half_t;
+using CDataType   = ck::half_t;
+using AccDataType = float;
 
-    using Row = ck::tensor_layout::gemm::RowMajor;
-    using Col = ck::tensor_layout::gemm::ColumnMajor;
+#include "run_gemm_test.inc"
 
-    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-
-    auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
-        bool pass = true;
-
-        using DeviceOp = ck::tensor_operation::device::DeviceGemm<decltype(a_layout),
-                                                                  decltype(b_layout),
-                                                                  decltype(c_layout),
-                                                                  ADataType,
-                                                                  BDataType,
-                                                                  CDataType,
-                                                                  PassThrough,
-                                                                  PassThrough,
-                                                                  PassThrough>;
-
-        const auto gemmPtrs =
-            ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
-                DeviceOp>::GetInstances();
-
-        for(auto& gemmPtr : gemmPtrs)
-        {
-            pass &= ck::gemm_util::TestGemm<std::unique_ptr<DeviceOp>,
-                                            ADataType,
-                                            BDataType,
-                                            CDataType,
-                                            AccDataType,
-                                            decltype(a_layout),
-                                            decltype(b_layout),
-                                            decltype(c_layout),
-                                            PassThrough,
-                                            PassThrough,
-                                            PassThrough>{}(gemmPtr);
-        }
-
-        return pass;
-    };
-
-    bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) &&
-                test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{});
-
-    std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
-    return pass ? 0 : 1;
-}
+int main() { return run_gemm_test(); }
diff --git a/test/gemm/gemm_fp32.cpp b/test/gemm/gemm_fp32.cpp
index 3e141d7b3..5d8c4881b 100644
--- a/test/gemm/gemm_fp32.cpp
+++ b/test/gemm/gemm_fp32.cpp
@@ -24,56 +24,11 @@
 
 #include "test/gemm/gemm_util.hpp"
 
-int main()
-{
-    using ADataType   = float;
-    using BDataType   = float;
-    using CDataType   = float;
-    using AccDataType = float;
+using ADataType   = float;
+using BDataType   = float;
+using CDataType   = float;
+using AccDataType = float;
 
-    using Row = ck::tensor_layout::gemm::RowMajor;
-    using Col = ck::tensor_layout::gemm::ColumnMajor;
+#include "run_gemm_test.inc"
 
-    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-
-    auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
-        bool pass = true;
-
-        using DeviceOp = ck::tensor_operation::device::DeviceGemm<decltype(a_layout),
-                                                                  decltype(b_layout),
-                                                                  decltype(c_layout),
-                                                                  ADataType,
-                                                                  BDataType,
-                                                                  CDataType,
-                                                                  PassThrough,
-                                                                  PassThrough,
-                                                                  PassThrough>;
-
-        const auto gemmPtrs =
-            ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
-                DeviceOp>::GetInstances();
-
-        for(auto& gemmPtr : gemmPtrs)
-        {
-            pass &= ck::gemm_util::TestGemm<std::unique_ptr<DeviceOp>,
-                                            ADataType,
-                                            BDataType,
-                                            CDataType,
-                                            AccDataType,
-                                            decltype(a_layout),
-                                            decltype(b_layout),
-                                            decltype(c_layout),
-                                            PassThrough,
-                                            PassThrough,
-                                            PassThrough>{}(gemmPtr);
-        }
-
-        return pass;
-    };
-
-    bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) &&
-                test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{});
-
-    std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
-    return pass ? 0 : 1;
-}
+int main() { return run_gemm_test(); }
diff --git a/test/gemm/gemm_fp64.cpp b/test/gemm/gemm_fp64.cpp
index 96dc459a3..85d7f95bf 100644
--- a/test/gemm/gemm_fp64.cpp
+++ b/test/gemm/gemm_fp64.cpp
@@ -24,56 +24,11 @@
 
 #include "test/gemm/gemm_util.hpp"
 
-int main()
-{
-    using ADataType   = double;
-    using BDataType   = double;
-    using CDataType   = double;
-    using AccDataType = double;
+using ADataType   = double;
+using BDataType   = double;
+using CDataType   = double;
+using AccDataType = double;
 
-    using Row = ck::tensor_layout::gemm::RowMajor;
-    using Col = ck::tensor_layout::gemm::ColumnMajor;
+#include "run_gemm_test.inc"
 
-    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-
-    auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
-        bool pass = true;
-
-        using DeviceOp = ck::tensor_operation::device::DeviceGemm<decltype(a_layout),
-                                                                  decltype(b_layout),
-                                                                  decltype(c_layout),
-                                                                  ADataType,
-                                                                  BDataType,
-                                                                  CDataType,
-                                                                  PassThrough,
-                                                                  PassThrough,
-                                                                  PassThrough>;
-
-        const auto gemmPtrs =
-            ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
-                DeviceOp>::GetInstances();
-
-        for(auto& gemmPtr : gemmPtrs)
-        {
-            pass &= ck::gemm_util::TestGemm<std::unique_ptr<DeviceOp>,
-                                            ADataType,
-                                            BDataType,
-                                            CDataType,
-                                            AccDataType,
-                                            decltype(a_layout),
-                                            decltype(b_layout),
-                                            decltype(c_layout),
-                                            PassThrough,
-                                            PassThrough,
-                                            PassThrough>{}(gemmPtr);
-        }
-
-        return pass;
-    };
-
-    bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) &&
-                test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{});
-
-    std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
-    return pass ? 0 : 1;
-}
+int main() { return run_gemm_test(); }
diff --git a/test/gemm/gemm_int8.cpp b/test/gemm/gemm_int8.cpp
index c7d79782a..e73b22ce9 100644
--- a/test/gemm/gemm_int8.cpp
+++ b/test/gemm/gemm_int8.cpp
@@ -24,56 +24,11 @@
 
 #include "test/gemm/gemm_util.hpp"
 
-int main()
-{
-    using ADataType   = int8_t;
-    using BDataType   = int8_t;
-    using CDataType   = int8_t;
-    using AccDataType = int32_t;
+using ADataType   = int8_t;
+using BDataType   = int8_t;
+using CDataType   = int8_t;
+using AccDataType = int32_t;
 
-    using Row = ck::tensor_layout::gemm::RowMajor;
-    using Col = ck::tensor_layout::gemm::ColumnMajor;
+#include "run_gemm_test.inc"
 
-    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-
-    auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
-        bool pass = true;
-
-        using DeviceOp = ck::tensor_operation::device::DeviceGemm<decltype(a_layout),
-                                                                  decltype(b_layout),
-                                                                  decltype(c_layout),
-                                                                  ADataType,
-                                                                  BDataType,
-                                                                  CDataType,
-                                                                  PassThrough,
-                                                                  PassThrough,
-                                                                  PassThrough>;
-
-        const auto gemmPtrs =
-            ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
-                DeviceOp>::GetInstances();
-
-        for(auto& gemmPtr : gemmPtrs)
-        {
-            pass &= ck::gemm_util::TestGemm<std::unique_ptr<DeviceOp>,
-                                            ADataType,
-                                            BDataType,
-                                            CDataType,
-                                            AccDataType,
-                                            decltype(a_layout),
-                                            decltype(b_layout),
-                                            decltype(c_layout),
-                                            PassThrough,
-                                            PassThrough,
-                                            PassThrough>{}(gemmPtr);
-        }
-
-        return pass;
-    };
-
-    bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) &&
-                test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{});
-
-    std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
-    return pass ? 0 : 1;
-}
+int main() { return run_gemm_test(); }
diff --git a/test/gemm/gemm_standalone_xdl_fp16.cpp b/test/gemm/gemm_standalone_xdl_fp16.cpp
new file mode 100644
index 000000000..8f5a5c557
--- /dev/null
+++ b/test/gemm/gemm_standalone_xdl_fp16.cpp
@@ -0,0 +1,162 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "gemm_util.hpp"
+
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
+
+#include "gemm_f16_nn_instance.hpp"
+#include "gemm_f16_nt_instance.hpp"
+#include "gemm_f16_tn_instance.hpp"
+#include "gemm_f16_tt_instance.hpp"
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using F16         = ck::half_t;
+using ADataType   = F16;
+using BDataType   = F16;
+using AccDataType = float;
+using CDataType   = F16;
+
+using ALayout = Row;
+using BLayout = Col;
+using CLayout = Row;
+
+using AElementOp = PassThrough;
+using BElementOp = PassThrough;
+using CElementOp = PassThrough;
+
+using ck::gemm_util::GemmParams;
+using ck::tensor_operation::device::BaseOperator;
+using ck::tensor_operation::device::DeviceGemm;
+using namespace ck::tensor_operation::device::instance;
+
+using DeviceGemmNN =
+    DeviceGemm<Col, Col, Row, ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
+using DeviceGemmNT =
+    DeviceGemm<Col, Row, Row, ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
+using DeviceGemmTN =
+    DeviceGemm<Row, Col, Row, ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
+using DeviceGemmTT =
+    DeviceGemm<Row, Row, Row, ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
+
+struct LayoutConfig
+{
+    bool ARowMajor;
+    bool BRowMajor;
+    bool CRowMajor;
+};
+
+int main(int argc, char* argv[])
+{
+    // DeviceGemm is templated on layout and precision types, so it is not an option to hold
+    // the instances in a single vector. Instead we store abstract BaseOperator pointers and
+    // dynamic_cast() them upon invocation.
+    // Since DeviceGemm does not expose its template arguments, an extra bookkeeping class
+    // LayoutConfig is used to determine which type a BaseOperator instance should be cast to.
+    using OpFactoryFn = void (*)(std::vector<std::unique_ptr<BaseOperator>>&);
+
+    std::vector<std::tuple<GemmParams, LayoutConfig, OpFactoryFn>> problems = {
+        // clang-format off
+        // 104 tiles
+        {GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256},
+        {GemmParams{2048, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128},
+        {GemmParams{1024, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x128},
+        {GemmParams{1024,  832, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x64},
+        {GemmParams{2048, 3328, 4096}, LayoutConfig{false, true,  true}, add_gemm_f16_nt_256x256},
+        {GemmParams{2048, 1664, 4096}, LayoutConfig{false, true,  true}, add_gemm_f16_nt_256x128},
+        {GemmParams{1024, 1664, 4096}, LayoutConfig{false, true,  true}, add_gemm_f16_nt_128x128},
+        {GemmParams{1024,  832, 4096}, LayoutConfig{false, true,  true}, add_gemm_f16_nt_128x64},
+        {GemmParams{2048, 3328, 4096}, LayoutConfig{true,  false, true}, add_gemm_f16_tn_256x256},
+        {GemmParams{2048, 1664, 4096}, LayoutConfig{true,  false, true}, add_gemm_f16_tn_256x128},
+        {GemmParams{1024, 1664, 4096}, LayoutConfig{true,  false, true}, add_gemm_f16_tn_128x128},
+        {GemmParams{1024,  832, 4096}, LayoutConfig{true,  false, true}, add_gemm_f16_tn_128x64},
+        {GemmParams{2048, 3328, 4096}, LayoutConfig{true,  true,  true}, add_gemm_f16_tt_256x256},
+        {GemmParams{2048, 1664, 4096}, LayoutConfig{true,  true,  true}, add_gemm_f16_tt_256x128},
+        {GemmParams{1024, 1664, 4096}, LayoutConfig{true,  true,  true}, add_gemm_f16_tt_128x128},
+        {GemmParams{1024,  832, 4096}, LayoutConfig{true,  true,  true}, add_gemm_f16_tt_128x64},
+        // 110 tiles
+        {GemmParams{2560, 2816, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256},
+        {GemmParams{2560, 1408, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128},
+        {GemmParams{1280, 1408, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x128},
+        {GemmParams{1280,  704, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x64},
+        {GemmParams{2560, 2816, 4096}, LayoutConfig{false, true,  true}, add_gemm_f16_nt_256x256},
+        {GemmParams{2560, 1408, 4096}, LayoutConfig{false, true,  true}, add_gemm_f16_nt_256x128},
+        {GemmParams{1280, 1408, 4096}, LayoutConfig{false, true,  true}, add_gemm_f16_nt_128x128},
+        {GemmParams{1280,  704, 4096}, LayoutConfig{false, true,  true}, add_gemm_f16_nt_128x64},
+        {GemmParams{2560, 2816, 4096}, LayoutConfig{true,  false, true}, add_gemm_f16_tn_256x256},
+        {GemmParams{2560, 1408, 4096}, LayoutConfig{true,  false, true}, add_gemm_f16_tn_256x128},
+        {GemmParams{1280, 1408, 4096}, LayoutConfig{true,  false, true}, add_gemm_f16_tn_128x128},
+        {GemmParams{1280,  704, 4096}, LayoutConfig{true,  false, true}, add_gemm_f16_tn_128x64},
+        {GemmParams{2560, 2816, 4096}, LayoutConfig{true,  true,  true}, add_gemm_f16_tt_256x256},
+        {GemmParams{2560, 1408, 4096}, LayoutConfig{true,  true,  true}, add_gemm_f16_tt_256x128},
+        {GemmParams{1280, 1408, 4096}, LayoutConfig{true,  true,  true}, add_gemm_f16_tt_128x128},
+        {GemmParams{1280,  704, 4096}, LayoutConfig{true,  true,  true}, add_gemm_f16_tt_128x64},
+        // clang-format on
+    };
+
+    bool do_verification = true;
+    bool time_kernel     = true;
+
+    if(argc == 1)
+    {
+        // use default
+    }
+    else if(argc == 3)
+    {
+        do_verification = std::stoi(argv[1]);
+        time_kernel     = std::stoi(argv[2]);
+    }
+    else
+    {
+        std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
+                  << "arg2: time kernel (0=no, 1=yes)" << std::endl;
+        return 0;
+    }
+
+    bool pass = true;
+    for(auto& p : problems)
+    {
+        GemmParams& problem_size          = std::get<0>(p);
+        const LayoutConfig& layout_config = std::get<1>(p);
+        const auto& factory               = std::get<2>(p);
+        std::vector<std::unique_ptr<BaseOperator>> ops;
+        factory(ops);
+
+        // overwrite strides with the packed strides implied by each layout
+        problem_size.StrideA = layout_config.ARowMajor ? problem_size.K : problem_size.M;
+        problem_size.StrideB = layout_config.BRowMajor ? problem_size.N : problem_size.K;
+        problem_size.StrideC = layout_config.CRowMajor ? problem_size.N : problem_size.M;
+
+        if(!layout_config.ARowMajor && !layout_config.BRowMajor)
+        {
+            auto op_ptr = dynamic_cast<DeviceGemmNN*>(ops[0].get());
+            pass &= ck::gemm_util::TestGemm<ADataType, BDataType, CDataType, AccDataType,
+                                            AElementOp, BElementOp, CElementOp>{}(
+                op_ptr, problem_size, do_verification, time_kernel);
+        }
+        else if(!layout_config.ARowMajor && layout_config.BRowMajor)
+        {
+            auto op_ptr = dynamic_cast<DeviceGemmNT*>(ops[0].get());
+            pass &= ck::gemm_util::TestGemm<ADataType, BDataType, CDataType, AccDataType,
+                                            AElementOp, BElementOp, CElementOp>{}(
+                op_ptr, problem_size, do_verification, time_kernel);
+        }
+        else if(layout_config.ARowMajor && !layout_config.BRowMajor)
+        {
+            auto op_ptr = dynamic_cast<DeviceGemmTN*>(ops[0].get());
+            pass &= ck::gemm_util::TestGemm<ADataType, BDataType, CDataType, AccDataType,
+                                            AElementOp, BElementOp, CElementOp>{}(
+                op_ptr, problem_size, do_verification, time_kernel);
+        }
+        else if(layout_config.ARowMajor && layout_config.BRowMajor)
+        {
+            auto op_ptr = dynamic_cast<DeviceGemmTT*>(ops[0].get());
+            pass &= ck::gemm_util::TestGemm<ADataType, BDataType, CDataType, AccDataType,
+                                            AElementOp, BElementOp, CElementOp>{}(
+                op_ptr, problem_size, do_verification, time_kernel);
+        }
+    }
+
+    std::cout << (pass ? "ALL TESTS PASSED" : "SOME TESTS FAILED") << std::endl;
+    return pass ? 0 : 1;
+}
diff --git a/test/gemm/gemm_util.hpp b/test/gemm/gemm_util.hpp
index 2df605be1..6291215b3 100644
--- a/test/gemm/gemm_util.hpp
+++ b/test/gemm/gemm_util.hpp
@@ -16,21 +16,13 @@ namespace gemm_util {
 
 struct GemmParams
 {
-    GemmParams()
-        : M(1024), N(1024), K(1024), StrideA(1024), StrideB(1024), StrideC(1024), alpha(1), beta(0)
-    {
-    }
-
-    ck::index_t M;
-    ck::index_t N;
-    ck::index_t K;
+    ck::index_t M = 1024;
+    ck::index_t N = 1024;
+    ck::index_t K = 1024;
 
-    ck::index_t StrideA;
-    ck::index_t StrideB;
-    ck::index_t StrideC;
-
-    float alpha;
-    float beta;
+    ck::index_t StrideA = 1024;
+    ck::index_t StrideB = 1024;
+    ck::index_t StrideC = 1024;
 };
 
 template <typename DeviceGemmPtr_,
@@ -47,7 +39,8 @@
                    Tensor<CDataType>& C,
                    AElementwiseOperation a_element_op,
                    BElementwiseOperation b_element_op,
-                   CElementwiseOperation c_element_op)
+                   CElementwiseOperation c_element_op,
+                   bool time_kernel)
 {
     DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
     DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
@@ -94,7 +87,20 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
     {
         a_m_k_device_buf.ToDevice(A.mData.data());
         b_k_n_device_buf.ToDevice(B.mData.data());
-        invoker_ptr->Run(argument_ptr.get());
+        float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+
+        std::size_t flop = std::size_t(2) * params.M * params.N * params.K;
+        std::size_t num_btype = sizeof(ADataType) * params.M * params.K +
+                                sizeof(BDataType) * params.K * params.N +
+                                sizeof(CDataType) * params.M * params.N;
+
+        float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+        float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+        std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
+                  << " GB/s, " << std::endl;
+
         c_m_n_device_buf.FromDevice(C.mData.data());
 
         return true;
@@ -109,19 +115,15 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
     }
 }
 
-template <typename DeviceGemmPtr_,
-          typename ADataType,
-          typename BDataType,
-          typename CDataType,
-          typename AccDataType,
-          typename ALayout,
-          typename BLayout,
-          typename CLayout,
-          typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation>
+template <typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename AccDataType,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
 struct TestGemm
 {
+    template <typename ALayout, typename BLayout, typename CLayout>
     auto PrepareGemmTensor(const ck::gemm_util::GemmParams& params)
     {
         auto f_host_tensor_descriptor =
@@ -156,25 +158,42 @@ struct TestGemm
         f_generate_tensor_value(a_m_k, ADataType{});
         f_generate_tensor_value(b_k_n, BDataType{});
 
+        std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
+        std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
+        std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
+
         return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result);
     }
 
-    auto operator()(const DeviceGemmPtr_& gemmPtr)
+    template