// SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include #include #include #include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp" template using S = ck::Sequence; static inline constexpr ck::index_t NDimSpatial = 2; using FP32 = float; struct ExecutionConfig final { bool do_verification = true; int init_method = 1; bool time_kernel = true; }; #define DefaultConvParams \ ck::utils::conv::ConvParam \ { \ NDimSpatial, 1, 32, 1, 1, {4, 4}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, { 0, 0 } \ } inline void print_help_msg() { std::cerr << "arg1: verification (0=no, 1=yes)\n" << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" << "arg3: time kernel (0=no, 1=yes)\n" << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; } inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, ck::utils::conv::ConvParam& conv_params) { constexpr int num_execution_config_args = 3; // arguments for do_verification, init_method, time_kernel constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_ constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args; constexpr int threshold_to_catch_all_args = threshold_to_catch_partial_args + num_conv_param_leading_args; if(argc == 1) { // use default config = ExecutionConfig{}; } // catch only ExecutionConfig arguments else if(argc == threshold_to_catch_partial_args) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); config.time_kernel = std::stoi(argv[3]); } // catch both ExecutionConfig & ConvParam arguments else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0)) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); config.time_kernel = std::stoi(argv[3]); const ck::index_t num_dim_spatial = std::stoi(argv[4]); conv_params = ck::utils::conv::parse_conv_param( num_dim_spatial, threshold_to_catch_partial_args, argv); } else { print_help_msg(); return false; } return true; }