#ifndef DEVICE_REDUCE_COMMON_HPP #define DEVICE_REDUCE_COMMON_HPP #include #include "common_header.hpp" #include "reduction_enums.hpp" #include "reduction_operator.hpp" namespace ck { namespace tensor_operation { namespace device { // template // using DeviceReducePtr = std::unique_ptr>; template std::pair get_2d_lengths(const std::vector& inLengths) { static_assert(Rank <= 6, "bigger Rank size not supported!"); size_t tensor_total_length = 1; size_t reduce_total_length = 1; static_for<0, ReduceDims::Size(), 1>{}( [&](auto i) { reduce_total_length *= inLengths[ReduceDims::At(i)]; }); static_for<0, Rank, 1>{}([&](auto i) { tensor_total_length *= inLengths[i.value]; }); return std::make_pair(tensor_total_length / reduce_total_length, reduce_total_length); }; template constexpr bool belong() { bool inside = false; static_for<0, Seq::Size(), 1>{}([&](auto i) { inside = (inside || (x == Seq::At(i))); }); return (inside); }; template constexpr auto get_invariant_dims() { static_assert(Rank <= 6, "bigger Rank size not supported!"); if constexpr(start >= Rank) return Sequence<>{}; else { if constexpr(!belong()) return merge_sequences(Sequence{}, get_invariant_dims()); else return get_invariant_dims(); }; }; // helper functions using variadic template arguments template static auto make_tuple_from_array_and_index_seq(const std::vector& lengths, Sequence) { return make_tuple(static_cast(lengths[Ns])...); }; template static auto make_tuple_from_array(const std::vector& lengths, Number) { static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; return make_tuple_from_array_and_index_seq(lengths, index_seq); }; } // namespace device } // namespace tensor_operation } // namespace ck #endif