#ifndef CK_FUNCTIONAL3_HPP
#define CK_FUNCTIONAL3_HPP

#include "composable_kernel/utility/functional.hpp"
#include "composable_kernel/utility/functional2.hpp"
#include "composable_kernel/utility/Sequence.hpp"
#include "composable_kernel/utility/Array.hpp"

namespace ck {

// Multi-dimensional loop helpers.
//
// static_ford<Lengths>{}(f) calls f once for every multi-index in the box [0, Lengths),
// passing the multi-index as a Sequence<...> (a compile-time type).
// ford<Lengths>{}(f) visits the same multi-indices using runtime for-loops, passing the
// multi-index as an Array<index_t, N> value.
// In both, the first dimension of Lengths is the outermost loop and the last dimension
// varies fastest.

// Recursive worker for static_ford.
// RemainLengths: Sequence<...> of the lengths of the dimensions not yet iterated.
template <class RemainLengths>
struct static_ford_impl
{
    // F signature: F(Sequence<...> multi_id)
    // CurrentMultiIndex: Sequence<...> holding the indices already fixed for the outer
    // dimensions.
    template <class F, class CurrentMultiIndex>
    __host__ __device__ constexpr void operator()(F f, CurrentMultiIndex) const
    {
        static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");

        // Iterate the front remaining dimension. Append its index I to the multi-index,
        // then recurse on the dimensions after it.
        static_for<0, RemainLengths::Front(), 1>{}([=](auto I) {
            static_ford_impl<decltype(RemainLengths{}.PopFront())>{}(
                f, CurrentMultiIndex::PushBack(I));
        });
    }
};

// Recursion terminator for static_ford: every dimension is fixed, so invoke f with the
// complete multi-index.
template <>
struct static_ford_impl<Sequence<>>
{
    // F signature: F(Sequence<...> multi_id)
    // CurrentMultiIndex: Sequence<...>
    template <class F, class CurrentMultiIndex>
    __host__ __device__ constexpr void operator()(F f, CurrentMultiIndex) const
    {
        f(CurrentMultiIndex{});
    }
};

// Compile-time nested loop over all multi-indices in [0, Lengths).
// Lengths is Sequence<...> and must be non-empty.
template <class Lengths>
struct static_ford
{
    // F signature: F(Sequence<...> multi_id)
    template <class F>
    __host__ __device__ constexpr void operator()(F f) const
    {
        static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");

        static_ford_impl<Lengths>{}(f, Sequence<>{});
    }
};

// Recursive worker for ford.
// RemainDim: number of dimensions still to iterate. It must equal RemainLengths::GetSize().
// This primary template handles RemainDim > 1; RemainDim == 1 and RemainDim == 0 are
// specialized below.
template <index_t RemainDim>
struct ford_impl
{
    // F signature: F(Array<...> multi_id)
    // CurrentMultiIndex: Array<...>
    // RemainLengths: Sequence<...>
    template <class F, class CurrentMultiIndex, class RemainLengths>
    __host__ __device__ constexpr void
    operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
    {
        static_assert(RemainLengths::GetSize() == RemainDim, "wrong!");
        static_assert(RemainDim > 1, "wrong!");

        constexpr index_t next_length = RemainLengths{}.Front();

        for(index_t i = 0; i < next_length; ++i)
        {
            // PushBack yields a one-longer Array with i appended for this dimension.
            ford_impl<RemainDim - 1>{}(
                f, current_multi_id.PushBack(i), RemainLengths{}.PopFront());
        }
    }
};

// Last dimension: append each index and invoke f with the complete multi-index.
template <>
struct ford_impl<1>
{
    // F signature: F(Array<...> multi_id)
    // CurrentMultiIndex: Array<...>
    // RemainLengths: Sequence<...>
    template <class F, class CurrentMultiIndex, class RemainLengths>
    __host__ __device__ constexpr void
    operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
    {
        static_assert(RemainLengths::GetSize() == 1, "wrong!");

        constexpr index_t last_length = RemainLengths{}.Front();

        for(index_t i = 0; i < last_length; ++i)
        {
            f(current_multi_id.PushBack(i));
        }
    }
};

// No dimensions left: the multi-index is already complete, so invoke f directly.
// This case is reached only when ford is given a 1-D Lengths. Without this
// specialization, ford<Sequence<N>> would instantiate the primary template with
// RemainDim == 0 and fail its static_asserts.
template <>
struct ford_impl<0>
{
    // F signature: F(Array<...> multi_id)
    // CurrentMultiIndex: Array<...>
    // RemainLengths: Sequence<> (empty)
    template <class F, class CurrentMultiIndex, class RemainLengths>
    __host__ __device__ constexpr void
    operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
    {
        static_assert(RemainLengths::GetSize() == 0, "wrong!");

        f(current_multi_id);
    }
};

// Runtime nested loop over all multi-indices in [0, Lengths).
// Lengths is Sequence<...> and must be non-empty.
template <class Lengths>
struct ford
{
    // F signature: F(Array<...> multi_id)
    template <class F>
    __host__ __device__ constexpr void operator()(F f) const
    {
        // Same precondition as static_ford. Report it here rather than through Front()
        // on an empty Sequence.
        static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");

        constexpr index_t first_length = Lengths{}.Front();

        for(index_t i = 0; i < first_length; ++i)
        {
            // Seed the multi-index with the first dimension's index and recurse on the rest.
            // For a 1-D Lengths this dispatches to ford_impl<0>.
            ford_impl<Lengths::GetSize() - 1>{}(
                f, Array<index_t, 1>{i}, Lengths{}.PopFront());
        }
    }
};

} // namespace ck
#endif