// SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "functional4.hpp" #include "tuple.hpp" #include "is_detected.hpp" namespace ck { template __host__ __device__ constexpr auto generate_tuple(F&& f, Number) { return unpack([&f](auto&&... xs) { return make_tuple(f(xs)...); }, typename arithmetic_sequence_gen<0, N, 1>::type{}); } template __host__ __device__ constexpr auto generate_tie(F&& f, Number) { return unpack([&f](auto&&... xs) { return tie(f(xs)...); }, typename arithmetic_sequence_gen<0, N, 1>::type{}); } // tx and ty are tuple of references, return type of will tuple of referennce (not rvalue) template __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple& tx, const Tuple& ty) { return unpack2( [&](auto&&... zs) { return Tuple{ck::forward(zs)...}; }, tx, ty); } template __host__ __device__ constexpr auto concat_tuple(const Tuple& tx, const Tuple& ty) { return unpack2( [&](auto... zs) { return Tuple{std::forward(zs)...}; }, tx, ty); } // Support any number of tuples to concat (also 1) template __host__ __device__ constexpr auto concat_tuple(const Tuple& tx) { return tx; } template __host__ __device__ constexpr auto concat_tuple(const Tuple& tx, const Tuples&... tuples) { return concat_tuple(tx, concat_tuple(tuples...)); } namespace detail { template __host__ __device__ constexpr auto transform_tuples_impl(F f, const X& x, Sequence) { return make_tuple(f(x.At(Number{}))...); } template __host__ __device__ constexpr auto transform_tuples_impl(F f, const X& x, const Y& y, Sequence) { return make_tuple(f(x.At(Number{}), y.At(Number{}))...); } template __host__ __device__ constexpr auto transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, Sequence) { return make_tuple(f(x.At(Number{}), y.At(Number{}), z.At(Number{}))...); } } // namespace detail template __host__ __device__ constexpr auto transform_tuples(F f, const X& x) { return detail::transform_tuples_impl( f, x, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{}); } template __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y) { return detail::transform_tuples_impl( f, x, y, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{}); } template __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z) { return detail::transform_tuples_impl( f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{}); } // By default unroll to the flatten template __host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<>& element) { return element; } template __host__ __device__ constexpr auto UnrollNestedTuple(const T& element) { return make_tuple(element); } template __host__ __device__ constexpr auto UnrollNestedTuple(const Tuple& tuple) { if constexpr(Depth == MaxDepth) { return tuple; } else { return unpack( [&](auto&&... ts) { return concat_tuple(UnrollNestedTuple(ts)...); }, tuple); } } template __host__ __device__ constexpr auto TupleReverse(const Tuple& tuple) { return generate_tuple( [&](auto i) { using Idx = Number::Size() - i - 1>; return tuple.At(Idx{}); }, Number::Size()>{}); } // Reduce tuple values in specific range using Function template __host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple& tuple) { static_assert(Idx < End, "Wrong parameters for TupleReduce"); if constexpr(Idx + 1 == End) { return tuple.At(Number{}); } else { return f(tuple.At(Number{}), TupleReduce(f, tuple)); } } template using is_tuple = decltype(std::declval().IsTuple()); template __host__ __device__ constexpr auto IsNestedTuple(const Tuple&) { return (is_detected::value || ...); } } // namespace ck