// SPDX-License-Identifier: MIT // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include #include #include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp" using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; template using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto ConvBwdWeightDefault = ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; template struct CommonLayoutSetting { using InputLayout = InputLay; using WeightLayout = WeightLay; using OutputLayout = OutputLay; }; template struct CommonLayoutSettingSelector; namespace ctl = ck::tensor_layout::convolution; template <> struct CommonLayoutSettingSelector<1> final : CommonLayoutSetting { }; template <> struct CommonLayoutSettingSelector<2> final : CommonLayoutSetting { }; template <> struct CommonLayoutSettingSelector<3> final : CommonLayoutSetting { }; template using InputLayout = typename CommonLayoutSettingSelector::InputLayout; template using WeightLayout = typename CommonLayoutSettingSelector::WeightLayout; template using OutputLayout = typename CommonLayoutSettingSelector::OutputLayout; struct ExecutionConfig final { bool do_verification = true; int init_method = 1; bool time_kernel = false; }; #define DefaultConvParam \ ck::utils::conv::ConvParam \ { \ 2, 4, 1, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, { 1, 1 } \ } inline void print_help_msg() { std::cerr << "arg1: verification (0=no, 1=yes)\n" << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" << "arg3: time kernel (0=no, 1=yes)\n" << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; } inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, ck::utils::conv::ConvParam& conv_param) { constexpr int num_execution_config_args = 3; // arguments for do_verification, init_method, time_kernel constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_ constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args; constexpr int threshold_to_catch_all_args = threshold_to_catch_partial_args + num_conv_param_leading_args; if(argc == 1) { // use default } // catch only ExecutionConfig arguments else if(argc == threshold_to_catch_partial_args) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); config.time_kernel = std::stoi(argv[3]); } // catch both ExecutionConfig & ConvParam arguments else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0)) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); config.time_kernel = std::stoi(argv[3]); const ck::index_t num_dim_spatial = std::stoi(argv[4]); conv_param = ck::utils::conv::parse_conv_param( num_dim_spatial, threshold_to_catch_partial_args, argv); } else { print_help_msg(); return false; } return true; }