#ifndef CK_CONTAINER_HELPER_HPP #define CK_CONTAINER_HELPER_HPP #include "sequence.hpp" #include "sequence_helper.hpp" #include "array.hpp" #include "tuple.hpp" #include "tuple_helper.hpp" #include "statically_indexed_array.hpp" #include "container_element_picker.hpp" namespace ck { template __host__ __device__ constexpr auto container_push_back(const Array& a, const TData& x) { Array r; static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; }); r(Number{}) = x; return r; } template __host__ __device__ constexpr auto container_push_front(const Tuple& a, const T& x) { return container_concat(make_tuple(x), a); } template __host__ __device__ constexpr auto container_push_back(const Tuple& a, const T& x) { return container_concat(a, make_tuple(x)); } template __host__ __device__ constexpr auto container_reorder_given_new2old(const Array& old_array, Sequence /*new2old*/) { static_assert(NSize == sizeof...(IRs), "wrong! size not consistent"); static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); return make_array(old_array[Number{}]...); } template __host__ __device__ constexpr auto container_reorder_given_old2new(const Array& old_array, Sequence old2new) { return container_reorder_given_new2old( old_array, typename sequence_map_inverse::type{}); } template __host__ __device__ constexpr auto container_reorder_given_new2old(const Tuple& old_tuple, Sequence /*new2old*/) { static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent"); static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); return make_tuple(old_tuple[Number{}]...); } template __host__ __device__ constexpr auto container_reorder_given_old2new(const Tuple& old_tuple, Sequence old2new) { return container_reorder_given_new2old( old_tuple, typename sequence_map_inverse::type{}); } template __host__ __device__ constexpr auto container_reorder_given_new2old(Sequence /* old_seq */, Sequence /*new2old*/) { static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent"); static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); return Sequence::At(Number{})...>{}; } template __host__ __device__ constexpr auto container_reorder_given_old2new(Sequence old_seq, Sequence /* old2new */) { static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent"); static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); constexpr auto new2old = typename sequence_map_inverse>::type{}; return container_reorder_give_new2old(old_seq, new2old); } #if !CK_WORKAROUND_SWDEV_275126 // rocm-4.1 compiler would crash for recursive lambda template __host__ __device__ constexpr auto container_reduce(const Container& x, Reduce reduce, Init init, Number = Number<0>{}, Number = Number{}, Number = Number<1>{}) { static_assert((IEnd - IBegin) % IStep == 0, "wrong!"); // f is recursive function, fs is a dummy of f // i is index, y_old is current scan, r_old is current reduction auto f = [&](auto fs, auto i, auto r_old) { auto r_new = reduce(x[i], r_old); if constexpr(i.value < IEnd - IStep) { // recursively call f/fs return fs(fs, i + Number{}, r_new); } else { return r_new; } }; // start recursion return f(f, Number{}, init); } #else // i is index, y_old is current scan, r_old is current reduction template __host__ __device__ constexpr auto container_reduce_impl( const Container& x, Reduce reduce, ROld r_old, Number i, Number, Number) { auto r_new = reduce(x[i], r_old); if constexpr(i.value < IEnd - IStep) { return container_reduce_impl( x, reduce, r_new, i + Number{}, Number{}, Number{}); } else { return r_new; } } // rocm-4.1 compiler would crash for recursive lambda // container reduce with initial value template __host__ __device__ constexpr auto container_reduce(const Container& x, Reduce reduce, Init init, Number = Number<0>{}, Number = Number{}, Number = Number<1>{}) { static_assert((IEnd - IBegin) % IStep == 0, "wrong!"); return container_reduce_impl( x, reduce, init, Number{}, Number{}, Number{}); } #endif template __host__ __device__ constexpr auto container_reverse_inclusive_scan(const Array& x, Reduce f, TData init) { Array y; TData r = init; static_for{}([&](auto i) { r = f(r, x[i]); y(i) = r; }); r = f(r, x[Number<0>{}]); y(Number<0>{}) = r; return y; } template __host__ __device__ constexpr auto container_reverse_exclusive_scan(const Array& x, Reduce f, TData init) { Array y; TData r = init; static_for{}([&](auto i) { y(i) = r; r = f(r, x[i]); }); y(Number<0>{}) = r; return y; } #if !CK_WORKAROUND_SWDEV_275126 // rocm4.1 compiler would crash with recursive lambda template __host__ __device__ constexpr auto container_reverse_exclusive_scan(const Tuple& x, Reduce reduce, Init init) { constexpr index_t NSize = sizeof...(Xs); // f is recursive function, fs is a dummy of f // i is index, y_old is current scan, r_old is current reduction auto f = [&](auto fs, auto i, auto y_old, auto r_old) { auto r_new = reduce(x[i], r_old); auto y_new = container_push_front(y_old, r_new); if constexpr(i.value > 1) { // recursively call f/fs return fs(fs, i - Number<1>{}, y_new, r_new); } else { return y_new; } }; // start recursion return f(f, Number{}, make_tuple(init), init); } #else // i is index, y_old is current scan, r_old is current reduction template __host__ __device__ constexpr auto container_reverse_exclusive_scan_impl( const Tuple& x, Reduce reduce, Number i, YOld y_old, ROld r_old) { auto r_new = reduce(x[i], r_old); auto y_new = container_push_front(y_old, r_new); if constexpr(i.value > 1) { // recursively call f/fs return container_reverse_exclusive_scan_impl(x, reduce, i - Number<1>{}, y_new, r_new); } else { return y_new; } } template __host__ __device__ constexpr auto container_reverse_exclusive_scan(const Tuple& x, Reduce reduce, Init init) { constexpr index_t NSize = sizeof...(Xs); return container_reverse_exclusive_scan_impl( x, reduce, Number{}, make_tuple(init), init); } #endif // TODO: update to like container_reverse_exclusive_scan to deal with Tuple of Numebr<> template __host__ __device__ constexpr auto container_reverse_inclusive_scan(const Tuple& x, Reduce f, TData init) { constexpr index_t NSize = sizeof...(Xs); Tuple y; TData r = init; static_for{}([&](auto i) { r = f(r, x[i]); y(i) = r; }); r = f(r, x[Number<0>{}]); y(Number<0>{}) = r; return y; } template __host__ __device__ constexpr auto container_concat(const X& x, const Ys&... ys) { return container_concat(x, container_concat(ys...)); } template __host__ __device__ constexpr auto container_concat(const Array& ax, const Array& ay) { return unpack2( [&](auto&&... zs) { return make_array(std::forward(zs)...); }, ax, ay); } template __host__ __device__ constexpr auto container_concat(const Tuple& tx, const Tuple& ty) { return unpack2( [&](auto&&... zs) { return make_tuple(std::forward(zs)...); }, tx, ty); } template __host__ __device__ constexpr auto container_concat(const Container& x) { return x; } template __host__ __device__ constexpr auto get_container_subset(const Array& arr, Sequence) { static_assert(N >= sizeof...(Is), "wrong! size"); return make_array(arr[Number{}]...); } template __host__ __device__ constexpr auto get_container_subset(const Tuple& tup, Sequence) { static_assert(sizeof...(Ts) >= sizeof...(Is), "wrong! size"); return make_tuple(tup[Number{}]...); } template __host__ __device__ constexpr void set_container_subset(Array& y, Sequence picks, const Array& x) { static_assert(N >= sizeof...(Is), "wrong! size"); static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; }); } template __host__ __device__ constexpr void set_container_subset(Tuple& y, Sequence picks, const Tuple& x) { static_assert(sizeof...(Ys) >= sizeof...(Is) && sizeof...(Is) == sizeof...(Xs), "wrong! size"); static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; }); } template __host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence) { using Seq = Sequence; return generate_tuple( [&](auto i) { constexpr index_t tmp = Seq::At(i); return Number{}; }, Seq::Size()); } } // namespace ck #endif