// SPDX-License-Identifier: MIT // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include #include #include #include #include #include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/device_elementwise.hpp" #include "ck/utility/type.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" using F16 = ck::half_t; struct ExecutionConfig final { bool do_verification = true; bool time_kernel = false; }; struct Problem final { std::array shape = {4, 16, 32, 32}; std::array axes = {0, 2, 3, 1}; }; template using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; namespace detail { template struct is_iterator : std::false_type { }; template struct is_iterator()), decltype(++std::declval>()), decltype(--std::declval>()), decltype(std::declval>()++), decltype(std::declval>()--)>> : std::true_type { }; template inline constexpr bool is_iterator_v = is_iterator::value; template struct is_random_access_iterator : std::false_type { }; template struct is_random_access_iterator() + 1), decltype(std::declval() - 1), decltype(std::declval()[1])>> : std::bool_constant> { }; template inline constexpr bool is_random_access_iterator_v = is_random_access_iterator::value; template struct is_range : std::false_type { }; template struct is_range())), decltype(end(std::declval()))>> : std::bool_constant()))>>> { }; template inline constexpr bool is_range_v = is_range::value; template struct is_random_access_range : std::false_type { }; template struct is_random_access_range> : std::bool_constant< is_range_v && is_random_access_iterator_v()))>>> { }; template inline constexpr bool is_random_access_range_v = is_random_access_range::value; } // namespace detail template inline std::enable_if_t>, bool> is_valid_axes(const Axes& axes) { using std::empty; if(empty(axes)) { return false; } using std::begin, std::end; std::vector copy(begin(axes), end(axes)); std::sort(begin(copy), end(copy)); const auto last = std::unique(begin(copy), end(copy)); return (last == end(copy)) && (*begin(copy) == 0) && (*std::prev(last) == size(axes) - 1); } inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Problem& problem) { constexpr int num_execution_config_args = 2; constexpr int num_problem_args = 8; assert(num_problem_args == size(problem.shape) + size(problem.axes)); if(argc == 1) { // use default case } else if(argc == 1 + num_execution_config_args) { config.do_verification = std::stoi(argv[1]); config.time_kernel = std::stoi(argv[2]); } else if(argc == 1 + num_execution_config_args + num_problem_args) { config.do_verification = std::stoi(argv[1]); config.time_kernel = std::stoi(argv[2]); // read shape for(std::size_t idx = 0; idx < size(problem.shape); ++idx) { problem.shape[idx] = std::stoi(argv[idx + 3]); } // read axes for(std::size_t idx = 0; idx < size(problem.axes); ++idx) { problem.axes[idx] = std::stoi(argv[idx + size(problem.shape) + 3]); } if(!is_valid_axes(problem.axes)) { std::cerr << "invalid axes: "; std::copy(begin(problem.axes), end(problem.axes), std::ostream_iterator(std::cerr, " ")); std::cerr << std::endl; return false; } } else { std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl << "arg2: time kernel (0=no, 1=yes)" << std::endl << "arg3 ~ arg6: shape for 4D tensor" << std::endl << "arg7 ~ arg10: axes to permute" << std::endl; return false; } return true; }