// SPDX-License-Identifier: MIT // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include #include #include #include #include #include #include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/device_permute.hpp" #include "ck/utility/type.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" using F16 = ck::half_t; struct ExecutionConfig final { bool do_verification = true; bool time_kernel = false; }; struct Problem final { std::array shape = {4, 16, 32, 32}; std::array axes = {0, 2, 3, 1}; }; template using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; namespace detail { template struct is_iterator : std::false_type { }; template struct is_iterator()), decltype(++std::declval>()), decltype(std::declval>()++)>> : std::true_type { }; template inline constexpr bool is_iterator_v = is_iterator::value; struct Placeholder final { template constexpr inline operator T() const noexcept; }; template struct is_output_iterator : std::false_type { }; template struct is_output_iterator< Iterator, std::void_t() = std::declval())>> : std::bool_constant> { }; template inline constexpr bool is_output_iterator_v = is_output_iterator::value; template struct is_bidirectional_iterator : std::false_type { }; template struct is_bidirectional_iterator< Iterator, std::void_t>()), decltype(std::declval>()--)>> : std::bool_constant> { }; template inline constexpr bool is_bidirectional_iterator_v = is_bidirectional_iterator::value; template struct is_random_access_iterator : std::false_type { }; template struct is_random_access_iterator() + 1), decltype(std::declval() - 1), decltype(std::declval()[1])>> : std::bool_constant> { }; template inline constexpr bool is_random_access_iterator_v = is_random_access_iterator::value; template struct is_range : std::false_type { }; template struct is_range())), decltype(end(std::declval()))>> : std::bool_constant()))>>> { }; template inline constexpr bool is_range_v = is_range::value; template struct is_sized_range : std::false_type { }; template struct is_sized_range()))>> : std::bool_constant> { }; template inline constexpr bool is_sized_range_v = is_sized_range::value; template struct is_bidirectional_range : std::false_type { }; template struct is_bidirectional_range> : std::bool_constant< is_range_v && is_bidirectional_iterator_v()))>>> { }; template inline constexpr bool is_bidirectional_range_v = is_bidirectional_range::value; template struct is_random_access_range : std::false_type { }; template struct is_random_access_range> : std::bool_constant< is_range_v && is_random_access_iterator_v()))>>> { }; template inline constexpr bool is_random_access_range_v = is_random_access_range::value; } // namespace detail template inline std::enable_if_t, bool> is_valid_axes(const Axes& axes) { using std::empty; if(empty(axes)) { return false; } using std::begin, std::end; std::vector sorted_axes(begin(axes), end(axes)); std::sort(begin(sorted_axes), end(sorted_axes)); const auto last = std::unique(begin(sorted_axes), end(sorted_axes)); return (last == end(sorted_axes)) && (*begin(sorted_axes) == 0) && (*std::prev(last) == size(axes) - 1); } inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Problem& problem) { constexpr int num_execution_config_args = 2; constexpr int num_problem_args = 8; if(!(num_problem_args == size(problem.shape) + size(problem.axes))) { return false; } if(argc == 1) { // use default case } else if(argc == 1 + num_execution_config_args) { config.do_verification = std::stoi(argv[1]); config.time_kernel = std::stoi(argv[2]); } else if(argc == 1 + num_execution_config_args + num_problem_args) { config.do_verification = std::stoi(argv[1]); config.time_kernel = std::stoi(argv[2]); // read shape for(std::size_t idx = 0; idx < size(problem.shape); ++idx) { problem.shape[idx] = std::stoi(argv[idx + 3]); } // read axes for(std::size_t idx = 0; idx < size(problem.axes); ++idx) { problem.axes[idx] = std::stoi(argv[idx + size(problem.shape) + 3]); } if(!is_valid_axes(problem.axes)) { std::cerr << "invalid axes: "; std::copy(begin(problem.axes), end(problem.axes), std::ostream_iterator(std::cerr, " ")); std::cerr << std::endl; return false; } } else { std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl << "arg2: time kernel (0=no, 1=yes)" << std::endl << "arg3 ~ arg6: shape for 4D tensor" << std::endl << "arg7 ~ arg10: axes to permute" << std::endl; return false; } return true; } template inline std::enable_if_t, bool> is_valid_shape(const Shape& shape) { using std::begin, std::end; using std::empty; return !empty(shape) && std::all_of(begin(shape), end(shape), [](auto dim) { return 0 < dim; }); } template inline std::enable_if_t && detail::is_sized_range_v, bool> is_valid_indices(const Shape& shape, const Indices& indices) { if(!is_valid_shape(shape)) { return false; } using std::empty; if(empty(indices)) { return false; } using std::size; if(size(shape) != size(indices)) { return false; } using std::begin, std::end; auto dim = begin(shape); auto idx = begin(indices); for(; dim != end(shape) && idx != end(indices); ++dim, ++idx) { if(*dim <= *idx) { return false; } } return true; } template inline std::enable_if_t && detail::is_sized_range_v && detail::is_sized_range_v && detail::is_output_iterator_v, OutputIterator> transpose_shape(const Shape& shape, const Axes& axes, OutputIterator iter) { using std::size; assert(size(shape) == size(axes)); assert(is_valid_shape(shape) && is_valid_axes(axes)); for(const auto axis : axes) { *iter++ = shape[axis]; } return iter; } template std::enable_if_t && detail::is_sized_range_v && detail::is_bidirectional_range_v && detail::is_sized_range_v, bool> advance_indices(const Shape& shape, Indices& indices) { using std::size; if(!(is_valid_shape(shape) && is_valid_indices(shape, indices) && size(shape) == size(indices))) { return false; } bool carry = true; using std::rbegin, std::rend; auto dim = rbegin(shape); auto idx = rbegin(indices); for(; carry && dim != rend(shape) && idx != rend(indices); ++dim, ++idx) { assert(*idx < *dim); *idx = (*idx + carry); carry = ((*idx == *dim) ? (*idx = 0, true) : false); } return !carry; } template std::enable_if_t && detail::is_sized_range_v && std::is_invocable_v, std::add_lvalue_reference_t>, bool> host_permute(const Tensor& src, const Axes& axes, Functor functor, Tensor& dest) { const auto& shape = src.mDesc.GetLengths(); const auto& transposed_shape = dest.mDesc.GetLengths(); if(!(is_valid_shape(shape) && is_valid_shape(transposed_shape))) { return false; } using std::size; if(!(is_valid_axes(axes) && size(axes) == 4)) { return false; } static_assert(detail::is_sized_range_v> && detail::is_sized_range_v>); if(!(size(shape) == 4 && size(transposed_shape) == 4)) { return false; } static_assert(detail::is_random_access_range_v> && detail::is_random_access_range_v>); { for(std::size_t idx = 0; idx < size(shape); ++idx) { if(transposed_shape[idx] != shape[axes[idx]]) { return false; } } } std::array indices{}; if(!is_valid_indices(shape, indices)) { return false; } do { Dest output = 0; functor(output, src(indices[0], indices[1], indices[2], indices[3])); dest(indices[axes[0]], indices[axes[1]], indices[axes[2]], indices[axes[3]]) = output; } while(advance_indices(shape, indices)); return true; }