// SPDX-License-Identifier: MIT // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include #include #include #include #include #include #include #include #include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/device_permute.hpp" #include "ck/utility/type.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/fill.hpp" #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" using F16 = ck::half_t; using F32 = float; using F64 = double; struct Problem final { static constexpr std::size_t NumDim = 3; using Shape = std::array; using Axes = Shape; Problem() = delete; explicit Problem(const Shape& default_shape, const Axes& default_axes) : shape(default_shape), axes(default_axes) { } Shape shape; Axes axes; }; template using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; namespace detail { template struct enlarge_array_size; template struct enlarge_array_size, Difference> { using type = std::array; }; template using enlarge_array_size_t = typename enlarge_array_size::type; template struct get_array_size; template struct get_array_size> : std::integral_constant { }; template inline constexpr std::size_t get_array_size_v = get_array_size::value; template struct is_iterator : std::false_type { }; template struct is_iterator()), decltype(++std::declval>()), decltype(std::declval>()++)>> : std::true_type { }; template inline constexpr bool is_iterator_v = is_iterator::value; struct Placeholder final { template constexpr inline operator T() const noexcept; }; template struct is_output_iterator : std::false_type { }; template struct is_output_iterator< Iterator, std::void_t() = std::declval())>> : std::bool_constant> { }; template inline constexpr bool is_output_iterator_v = is_output_iterator::value; template struct is_bidirectional_iterator : std::false_type { }; template struct is_bidirectional_iterator< Iterator, std::void_t>()), decltype(std::declval>()--)>> : std::bool_constant> { }; template inline constexpr bool is_bidirectional_iterator_v = is_bidirectional_iterator::value; template struct is_random_access_iterator : std::false_type { }; template struct is_random_access_iterator() + 1), decltype(std::declval() - 1), decltype(std::declval()[1])>> : std::bool_constant> { }; template inline constexpr bool is_random_access_iterator_v = is_random_access_iterator::value; template struct is_range : std::false_type { }; template struct is_range())), decltype(end(std::declval()))>> : std::bool_constant()))>>> { }; template inline constexpr bool is_range_v = is_range::value; template struct is_sized_range : std::false_type { }; template struct is_sized_range()))>> : std::bool_constant> { }; template inline constexpr bool is_sized_range_v = is_sized_range::value; template struct is_bidirectional_range : std::false_type { }; template struct is_bidirectional_range> : std::bool_constant< is_range_v && is_bidirectional_iterator_v()))>>> { }; template inline constexpr bool is_bidirectional_range_v = is_bidirectional_range::value; template struct is_random_access_range : std::false_type { }; template struct is_random_access_range> : std::bool_constant< is_range_v && is_random_access_iterator_v()))>>> { }; template inline constexpr bool is_random_access_range_v = is_random_access_range::value; } // namespace detail namespace ranges { template inline auto copy(InputRange&& range, OutputIterator iter) -> decltype(std::copy(std::begin(std::forward(range)), std::end(std::forward(range)), iter)) { return std::copy(std::begin(std::forward(range)), std::end(std::forward(range)), iter); } } // namespace ranges template inline std::enable_if_t, bool> is_valid_axes(const Axes& axes) { using std::empty; if(empty(axes)) { return false; } using std::begin, std::end; std::vector sorted_axes(begin(axes), end(axes)); std::sort(begin(sorted_axes), end(sorted_axes)); const auto last = std::unique(begin(sorted_axes), end(sorted_axes)); return (last == end(sorted_axes)) && (*begin(sorted_axes) == 0) && (*std::prev(last) == size(axes) - 1); } template inline std::enable_if_t, bool> is_valid_shape(const Shape& shape) { using std::begin, std::end; using std::empty; return !empty(shape) && std::all_of(begin(shape), end(shape), [](auto dim) { return 0 < dim; }); } template inline std::enable_if_t && detail::is_sized_range_v, bool> is_valid_indices(const Shape& shape, const Indices& indices) { if(!is_valid_shape(shape)) { return false; } using std::empty; if(empty(indices)) { return false; } using std::size; if(size(shape) != size(indices)) { return false; } using std::begin, std::end; auto dim = begin(shape); auto idx = begin(indices); for(; dim != end(shape) && idx != end(indices); ++dim, ++idx) { if(*dim <= *idx) { return false; } } return true; } template std::array transpose(const std::array& shape, const std::array& axes) { assert(is_valid_shape(shape) && is_valid_axes(axes)); std::array transposed; auto iter = std::begin(transposed); for(const auto axis : axes) { *iter++ = shape[axis]; } return transposed; } auto extend_shape(const Problem::Shape& shape, std::size_t new_dim) { detail::enlarge_array_size_t extended_shape; using std::begin, std::end; std::copy(begin(shape), end(shape), begin(extended_shape)); extended_shape.back() = new_dim; return extended_shape; } auto extend_axes(const Problem::Axes& axes) { detail::enlarge_array_size_t extended_axes; using std::begin, std::end; std::copy(begin(axes), end(axes), begin(extended_axes)); extended_axes.back() = detail::get_array_size_v; return extended_axes; } template std::enable_if_t && detail::is_sized_range_v && detail::is_bidirectional_range_v && detail::is_sized_range_v, bool> advance_indices(const Shape& shape, Indices& indices) { using std::size; if(!(is_valid_shape(shape) && is_valid_indices(shape, indices) && size(shape) == size(indices))) { return false; } bool carry = true; using std::rbegin, std::rend; auto dim = rbegin(shape); auto idx = rbegin(indices); for(; carry && dim != rend(shape) && idx != rend(indices); ++dim, ++idx) { assert(*idx < *dim); *idx = (*idx + carry); carry = ((*idx == *dim) ? (*idx = 0, true) : false); } return !carry; } template std::enable_if_t && detail::is_sized_range_v && std::is_invocable_v, std::add_lvalue_reference_t>, bool> host_permute(const Tensor& src, const Axes& axes, Functor functor, Tensor& dest) { const auto& shape = src.mDesc.GetLengths(); const auto& transposed_shape = dest.mDesc.GetLengths(); if(!(is_valid_shape(shape) && is_valid_shape(transposed_shape))) { return false; } using std::size; if(!is_valid_axes(axes)) { return false; } static_assert(detail::is_sized_range_v> && detail::is_sized_range_v>); if(size(shape) != size(transposed_shape)) { return false; } static_assert(detail::is_random_access_range_v> && detail::is_random_access_range_v>); { for(std::size_t idx = 0; idx < size(shape); ++idx) { if(transposed_shape[idx] != shape[axes[idx]]) { return false; } } } std::vector indices(size(shape), 0); if(!is_valid_indices(shape, indices)) { return false; } switch(size(shape)) { case 3: { do { Dest output = 0; functor(output, src(indices[0], indices[1], indices[2])); dest(indices[axes[0]], indices[axes[1]], indices[axes[2]]) = output; } while(advance_indices(shape, indices)); } break; case 4: { do { Dest output = 0; functor(output, src(indices[0], indices[1], indices[2], indices[3])); dest(indices[axes[0]], indices[axes[1]], indices[axes[2]], indices[axes[3]]) = output; } while(advance_indices(shape, indices)); } break; default: return false; } return true; }