#ifndef CK_SEQUENCE_HPP #define CK_SEQUENCE_HPP #include "integral_constant.hpp" #include "type.hpp" #include "functional.hpp" #include "math.hpp" namespace ck { template struct static_for; template struct Sequence; template struct sequence_split; template struct sequence_reverse; template struct sequence_map_inverse; template struct is_valid_sequence_map; template __host__ __device__ constexpr auto sequence_pop_front(Sequence); template __host__ __device__ constexpr auto sequence_pop_back(Seq); template struct Sequence { using Type = Sequence; using data_type = index_t; static constexpr index_t mSize = sizeof...(Is); __host__ __device__ static constexpr auto Size() { return Number{}; } __host__ __device__ static constexpr auto GetSize() { return Size(); } __host__ __device__ static constexpr index_t At(index_t I) { // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 const index_t mData[mSize + 1] = {Is..., 0}; return mData[I]; } template __host__ __device__ static constexpr auto At(Number) { static_assert(I < mSize, "wrong! I too large"); return Number{}; } template __host__ __device__ static constexpr auto Get(Number) { return At(Number{}); } template __host__ __device__ constexpr auto operator[](I i) const { return At(i); } template __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence /*new2old*/) { static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! reorder map should have the same size as Sequence to be rerodered"); static_assert(is_valid_sequence_map>::value, "wrong! invalid reorder map"); return Sequence{})...>{}; } // MapOld2New is Sequence<...> template __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New) { static_assert(MapOld2New::Size() == Size(), "wrong! reorder map should have the same size as Sequence to be rerodered"); static_assert(is_valid_sequence_map::value, "wrong! invalid reorder map"); return ReorderGivenNew2Old(typename sequence_map_inverse::type{}); } __host__ __device__ static constexpr auto Reverse() { return typename sequence_reverse::type{}; } __host__ __device__ static constexpr auto Front() { static_assert(mSize > 0, "wrong!"); return At(Number<0>{}); } __host__ __device__ static constexpr auto Back() { static_assert(mSize > 0, "wrong!"); return At(Number{}); } __host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); } __host__ __device__ static constexpr auto PopBack() { return sequence_pop_back(Type{}); } template __host__ __device__ static constexpr auto PushFront(Sequence) { return Sequence{}; } template __host__ __device__ static constexpr auto PushFront(Number...) { return Sequence{}; } template __host__ __device__ static constexpr auto PushBack(Sequence) { return Sequence{}; } template __host__ __device__ static constexpr auto PushBack(Number...) { return Sequence{}; } template __host__ __device__ static constexpr auto Extract(Number...) { return Sequence{})...>{}; } template __host__ __device__ static constexpr auto Extract(Sequence) { return Sequence{})...>{}; } template __host__ __device__ static constexpr auto Modify(Number, Number) { static_assert(I < Size(), "wrong!"); using seq_split = sequence_split; constexpr auto seq_left = typename seq_split::left_type{}; constexpr auto seq_right = typename seq_split::right_type{}.PopFront(); return seq_left.PushBack(Number{}).PushBack(seq_right); } template __host__ __device__ static constexpr auto Transform(F f) { return Sequence{}; } }; // merge sequence template struct sequence_merge { using type = typename sequence_merge::type>::type; }; template struct sequence_merge, Sequence> { using type = Sequence; }; template struct sequence_merge { using type = Seq; }; // generate sequence template struct sequence_gen { template struct sequence_gen_impl { static constexpr index_t NRemainLeft = NRemain / 2; static constexpr index_t NRemainRight = NRemain - NRemainLeft; static constexpr index_t IMiddle = IBegin + NRemainLeft; using type = typename sequence_merge< typename sequence_gen_impl::type, typename sequence_gen_impl::type>::type; }; template struct sequence_gen_impl { static constexpr index_t Is = G{}(Number{}); using type = Sequence; }; template struct sequence_gen_impl { using type = Sequence<>; }; using type = typename sequence_gen_impl<0, NSize, F>::type; }; // arithmetic sequence template struct arithmetic_sequence_gen { struct F { __host__ __device__ constexpr index_t operator()(index_t i) const { return i * Increment + IBegin; } }; using type = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type; }; // uniform sequence template struct uniform_sequence_gen { struct F { __host__ __device__ constexpr index_t operator()(index_t) const { return I; } }; using type = typename sequence_gen::type; }; // reverse inclusive scan (with init) sequence template struct sequence_reverse_inclusive_scan; template struct sequence_reverse_inclusive_scan, Reduce, Init> { using old_scan = typename sequence_reverse_inclusive_scan, Reduce, Init>::type; static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front()); using type = typename sequence_merge, old_scan>::type; }; template struct sequence_reverse_inclusive_scan, Reduce, Init> { using type = Sequence; }; template struct sequence_reverse_inclusive_scan, Reduce, Init> { using type = Sequence<>; }; // split sequence template struct sequence_split { static constexpr index_t NSize = Seq{}.Size(); using range0 = typename arithmetic_sequence_gen<0, I, 1>::type; using range1 = typename arithmetic_sequence_gen::type; using left_type = decltype(Seq::Extract(range0{})); using right_type = decltype(Seq::Extract(range1{})); }; // reverse sequence template struct sequence_reverse { static constexpr index_t NSize = Seq{}.Size(); using seq_split = sequence_split; using type = typename sequence_merge< typename sequence_reverse::type, typename sequence_reverse::type>::type; }; template struct sequence_reverse> { using type = Sequence; }; template struct sequence_reverse> { using type = Sequence; }; #if 0 template struct sequence_reduce { using type = typename sequence_reduce::type>::type; }; template struct sequence_reduce, Sequence> { using type = Sequence; }; template struct sequence_reduce { using type = Seq; }; #endif template struct sequence_sort_impl { template struct sorted_sequence_merge_impl { static constexpr bool choose_left = LeftValues::Front() < RightValues::Front(); static constexpr index_t chosen_value = choose_left ? LeftValues::Front() : RightValues::Front(); static constexpr index_t chosen_id = choose_left ? LeftIds::Front() : RightIds::Front(); using new_merged_values = decltype(MergedValues::PushBack(Number{})); using new_merged_ids = decltype(MergedIds::PushBack(Number{})); using new_left_values = typename conditional::type; using new_left_ids = typename conditional::type; using new_right_values = typename conditional::type; using new_right_ids = typename conditional::type; using merge = sorted_sequence_merge_impl; // this is output using merged_values = typename merge::merged_values; using merged_ids = typename merge::merged_ids; }; template struct sorted_sequence_merge_impl, Sequence<>, MergedValues, MergedIds, Comp> { using merged_values = typename sequence_merge::type; using merged_ids = typename sequence_merge::type; }; template struct sorted_sequence_merge_impl, Sequence<>, RightValues, RightIds, MergedValues, MergedIds, Comp> { using merged_values = typename sequence_merge::type; using merged_ids = typename sequence_merge::type; }; template struct sorted_sequence_merge { using merge = sorted_sequence_merge_impl, Sequence<>, Comp>; using merged_values = typename merge::merged_values; using merged_ids = typename merge::merged_ids; }; static constexpr index_t nsize = Values::Size(); using split_unsorted_values = sequence_split; using split_unsorted_ids = sequence_split; using left_unsorted_values = typename split_unsorted_values::left_type; using left_unsorted_ids = typename split_unsorted_ids::left_type; using left_sort = sequence_sort_impl; using left_sorted_values = typename left_sort::sorted_values; using left_sorted_ids = typename left_sort::sorted_ids; using right_unsorted_values = typename split_unsorted_values::right_type; using right_unsorted_ids = typename split_unsorted_ids::right_type; using right_sort = sequence_sort_impl; using right_sorted_values = typename right_sort::sorted_values; using right_sorted_ids = typename right_sort::sorted_ids; using merged_sorted = sorted_sequence_merge; using sorted_values = typename merged_sorted::merged_values; using sorted_ids = typename merged_sorted::merged_ids; }; template struct sequence_sort_impl, Sequence, Compare> { static constexpr bool choose_x = Compare{}(ValueX, ValueY); using sorted_values = typename conditional, Sequence>::type; using sorted_ids = typename conditional, Sequence>::type; }; template struct sequence_sort_impl, Sequence, Compare> { using sorted_values = Sequence; using sorted_ids = Sequence; }; template struct sequence_sort { using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type; using sort = sequence_sort_impl; // this is output using type = typename sort::sorted_values; using sorted2unsorted_map = typename sort::sorted_ids; }; template struct sequence_unique_sort { template struct sorted_sequence_uniquify_impl { static constexpr index_t current_value = RemainValues::Front(); static constexpr index_t current_id = RemainIds::Front(); static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back()); using new_remain_values = decltype(RemainValues::PopFront()); using new_remain_ids = decltype(RemainIds::PopFront()); using new_uniquified_values = typename conditional{})), UniquifiedValues>::type; using new_uniquified_ids = typename conditional{})), UniquifiedIds>::type; using uniquify = sorted_sequence_uniquify_impl; // this is output using uniquified_values = typename uniquify::uniquified_values; using uniquified_ids = typename uniquify::uniquified_ids; }; template struct sorted_sequence_uniquify_impl, Sequence<>, UniquifiedValues, UniquifiedIds, Eq> { using uniquified_values = UniquifiedValues; using uniquified_ids = UniquifiedIds; }; template struct sorted_sequence_uniquify { using uniquify = sorted_sequence_uniquify_impl, Sequence, Eq>; using uniquified_values = typename uniquify::uniquified_values; using uniquified_ids = typename uniquify::uniquified_ids; }; using sort = sequence_sort; using sorted_values = typename sort::type; using sorted_ids = typename sort::sorted2unsorted_map; using uniquify = sorted_sequence_uniquify; // this is output using type = typename uniquify::uniquified_values; using sorted2unsorted_map = typename uniquify::uniquified_ids; }; template struct is_valid_sequence_map : is_same::type, typename sequence_sort>::type> { }; template struct sequence_map_inverse { template struct sequence_map_inverse_impl { static constexpr auto new_y2x = WorkingY2X::Modify(X2Y::At(Number{}), Number{}); using type = typename sequence_map_inverse_impl:: type; }; template struct sequence_map_inverse_impl { using type = WorkingY2X; }; using type = typename sequence_map_inverse_impl::type, 0, SeqMap::Size()>::type; }; template __host__ __device__ constexpr auto operator+(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs + Ys)...>{}; } template __host__ __device__ constexpr auto operator-(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs - Ys)...>{}; } template __host__ __device__ constexpr auto operator*(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs * Ys)...>{}; } template __host__ __device__ constexpr auto operator/(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs / Ys)...>{}; } template __host__ __device__ constexpr auto operator%(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs % Ys)...>{}; } template __host__ __device__ constexpr auto operator+(Sequence, Number) { return Sequence<(Xs + Y)...>{}; } template __host__ __device__ constexpr auto operator-(Sequence, Number) { return Sequence<(Xs - Y)...>{}; } template __host__ __device__ constexpr auto operator*(Sequence, Number) { return Sequence<(Xs * Y)...>{}; } template __host__ __device__ constexpr auto operator/(Sequence, Number) { return Sequence<(Xs / Y)...>{}; } template __host__ __device__ constexpr auto operator%(Sequence, Number) { return Sequence<(Xs % Y)...>{}; } template __host__ __device__ constexpr auto operator+(Number, Sequence) { return Sequence<(Y + Xs)...>{}; } template __host__ __device__ constexpr auto operator-(Number, Sequence) { constexpr auto seq_x = Sequence{}; return Sequence<(Y - Xs)...>{}; } template __host__ __device__ constexpr auto operator*(Number, Sequence) { return Sequence<(Y * Xs)...>{}; } template __host__ __device__ constexpr auto operator/(Number, Sequence) { return Sequence<(Y / Xs)...>{}; } template __host__ __device__ constexpr auto operator%(Number, Sequence) { return Sequence<(Y % Xs)...>{}; } template __host__ __device__ constexpr auto sequence_pop_front(Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto sequence_pop_back(Seq) { static_assert(Seq::Size() > 0, "wrong! cannot pop an empty Sequence!"); return sequence_pop_front(Seq::Reverse()).Reverse(); } template __host__ __device__ constexpr auto merge_sequences(Seqs...) { return typename sequence_merge::type{}; } template __host__ __device__ constexpr auto transform_sequences(F f, Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence) { static_assert(Sequence::mSize == Sequence::mSize, "Dim not the same"); return Sequence{}; } template __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence, Sequence) { static_assert(Sequence::mSize == Sequence::mSize && Sequence::mSize == Sequence::mSize, "Dim not the same"); return Sequence{}; } template __host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number) { return typename sequence_reverse_inclusive_scan::type{}; } template __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number) { return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number{}).Reverse(); } template __host__ __device__ constexpr auto pick_sequence_elements_by_ids(Seq, Sequence /* ids */) { return Sequence{})...>{}; } #if 0 template __host__ __device__ constexpr auto pick_sequence_elements_by_mask(Seq, Mask) { // not implemented } #endif template struct lambda_reduce_on_sequence { const Reduce& f; index_t& result; __host__ __device__ constexpr lambda_reduce_on_sequence(const Reduce& f_, index_t& result_) : f(f_), result(result_) { } template __host__ __device__ constexpr index_t operator()(IDim) const { return result = f(result, Seq::At(IDim{})); } }; template __host__ __device__ constexpr index_t reduce_on_sequence(Seq, Reduce f, Number /*initial_value*/) { index_t result = Init; static_for<0, Seq::Size(), 1>{}(lambda_reduce_on_sequence(f, result)); return result; } // TODO: a generic any_of for any container template __host__ __device__ constexpr bool sequence_any_of(Seq, F f /*initial_value*/) { bool flag = false; for(index_t i = 0; i < Seq::Size(); ++i) { flag = flag || f(Seq::At(i)); } return flag; } // TODO: a generic all_of for any container template __host__ __device__ constexpr bool sequence_all_of(Seq, F f /*initial_value*/) { bool flag = true; for(index_t i = 0; i < Seq::Size(); ++i) { flag = flag && f(Seq::At(i)); } return flag; } } // namespace ck #endif