#ifndef CK_SEQUENCE_HPP #define CK_SEQUENCE_HPP #include "integral_constant.hpp" #include "functional.hpp" namespace ck { template struct static_for; template struct Sequence; template struct sequence_split; template struct sequence_reverse; template struct sequence_map_inverse; template struct is_valid_sequence_map; template __host__ __device__ constexpr auto sequence_pop_front(Sequence); template __host__ __device__ constexpr auto sequence_pop_back(Seq); template struct Sequence { using Type = Sequence; using data_type = index_t; static constexpr index_t mSize = sizeof...(Is); __host__ __device__ static constexpr auto GetSize() { return Number{}; } __host__ __device__ static constexpr index_t GetImpl(index_t I) { // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 const index_t mData[mSize + 1] = {Is..., 0}; return mData[I]; } template __host__ __device__ static constexpr auto Get(Number) { static_assert(I < mSize, "wrong! I too large"); return Number{})>{}; } __host__ __device__ static constexpr auto Get(index_t I) { return GetImpl(I); } template __host__ __device__ constexpr auto operator[](Number) const { return Get(Number{}); } // make sure I is constepxr if you want a constexpr return type __host__ __device__ constexpr index_t operator[](index_t I) const { return GetImpl(I); } template __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence /*new2old*/) { static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! reorder map should have the same size as Sequence to be rerodered"); static_assert(is_valid_sequence_map>::value, "wrong! invalid reorder map"); return Sequence{})...>{}; } // MapOld2New is Sequence<...> template __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New) { static_assert(MapOld2New::GetSize() == GetSize(), "wrong! reorder map should have the same size as Sequence to be rerodered"); static_assert(is_valid_sequence_map::value, "wrong! invalid reorder map"); return ReorderGivenNew2Old(typename sequence_map_inverse::type{}); } __host__ __device__ static constexpr auto Reverse() { return typename sequence_reverse::type{}; } __host__ __device__ static constexpr auto Front() { static_assert(mSize > 0, "wrong!"); return Get(Number<0>{}); } __host__ __device__ static constexpr auto Back() { static_assert(mSize > 0, "wrong!"); return Get(Number{}); } __host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); } __host__ __device__ static constexpr auto PopBack() { return sequence_pop_back(Type{}); } template __host__ __device__ static constexpr auto PushFront(Sequence) { return Sequence{}; } template __host__ __device__ static constexpr auto PushFront(Number...) { return Sequence{}; } template __host__ __device__ static constexpr auto PushBack(Sequence) { return Sequence{}; } template __host__ __device__ static constexpr auto PushBack(Number...) { return Sequence{}; } template __host__ __device__ static constexpr auto Extract(Number...) { return Sequence{})...>{}; } template __host__ __device__ static constexpr auto Extract(Sequence) { return Sequence{})...>{}; } template __host__ __device__ static constexpr auto Modify(Number, Number) { static_assert(I < GetSize(), "wrong!"); using seq_split = sequence_split; constexpr auto seq_left = typename seq_split::SeqType0{}; constexpr auto seq_right = typename seq_split::SeqType1{}.PopFront(); return seq_left.PushBack(Number{}).PushBack(seq_right); } template __host__ __device__ static constexpr auto Transform(F f) { return Sequence{}; } }; // merge sequence template struct sequence_merge; template struct sequence_merge, Sequence> { using type = Sequence; }; // generate sequence template struct sequence_gen_impl { static constexpr index_t NRemainLeft = NRemain / 2; static constexpr index_t NRemainRight = NRemain - NRemainLeft; static constexpr index_t IMiddle = IBegin + NRemainLeft; using type = typename sequence_merge::type, typename sequence_gen_impl::type>::type; }; template struct sequence_gen_impl { static constexpr index_t Is = F{}(Number{}); using type = Sequence; }; template struct sequence_gen_impl { using type = Sequence<>; }; template struct sequence_gen { using type = typename sequence_gen_impl<0, NSize, F>::type; }; // arithmetic sequence template struct arithmetic_sequence_gen { struct F { __host__ __device__ constexpr index_t operator()(index_t i) const { return i * Increment + IBegin; } }; using type = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type; }; // uniform sequence template struct uniform_sequence_gen { struct F { __host__ __device__ constexpr index_t operator()(index_t) const { return I; } }; using type = typename sequence_gen::type; }; // reverse inclusive scan (with init) sequence template struct sequence_reverse_inclusive_scan; template struct sequence_reverse_inclusive_scan, Reduce, Init> { using old_scan = typename sequence_reverse_inclusive_scan, Reduce, Init>::type; static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front()); using type = typename sequence_merge, old_scan>::type; }; template struct sequence_reverse_inclusive_scan, Reduce, Init> { using type = Sequence; }; template struct sequence_reverse_inclusive_scan, Reduce, Init> { using type = Sequence<>; }; // split sequence template struct sequence_split { static constexpr index_t NSize = Seq{}.GetSize(); using range0 = typename arithmetic_sequence_gen<0, I, 1>::type; using range1 = typename arithmetic_sequence_gen::type; using SeqType0 = decltype(Seq::Extract(range0{})); using SeqType1 = decltype(Seq::Extract(range1{})); }; // reverse sequence template struct sequence_reverse { static constexpr index_t NSize = Seq{}.GetSize(); using seq_split = sequence_split; using type = typename sequence_merge< typename sequence_reverse::type, typename sequence_reverse::type>::type; }; template struct sequence_reverse> { using type = Sequence; }; template struct sequence_reverse> { using type = Sequence; }; template struct sequence_sort { // not implemented }; template struct sequence_unique_sort { // not implemented }; template struct is_valid_sequence_map { // not implemented yet, always return true static constexpr integral_constant value = integral_constant{}; // TODO: add proper check for is_valid, something like: // static constexpr bool value = // is_same::type, // typename sequence_sort::SortedSeqType>{}; }; template struct sequence_map_inverse_impl { private: static constexpr auto new_y2x = WorkingY2X::Modify(X2Y::Get(Number{}), Number{}); public: using type = typename sequence_map_inverse_impl::type; }; template struct sequence_map_inverse_impl { using type = WorkingY2X; }; template struct sequence_map_inverse { using type = typename sequence_map_inverse_impl::type, 0, X2Y::GetSize()>::type; }; template __host__ __device__ constexpr auto operator+(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs + Ys)...>{}; } template __host__ __device__ constexpr auto operator-(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs - Ys)...>{}; } template __host__ __device__ constexpr auto operator*(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs * Ys)...>{}; } template __host__ __device__ constexpr auto operator/(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs / Ys)...>{}; } template __host__ __device__ constexpr auto operator%(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs % Ys)...>{}; } template __host__ __device__ constexpr auto operator+(Sequence, Number) { return Sequence<(Xs + Y)...>{}; } template __host__ __device__ constexpr auto operator-(Sequence, Number) { return Sequence<(Xs - Y)...>{}; } template __host__ __device__ constexpr auto operator*(Sequence, Number) { return Sequence<(Xs * Y)...>{}; } template __host__ __device__ constexpr auto operator/(Sequence, Number) { return Sequence<(Xs / Y)...>{}; } template __host__ __device__ constexpr auto operator%(Sequence, Number) { return Sequence<(Xs % Y)...>{}; } template __host__ __device__ constexpr auto operator+(Number, Sequence) { return Sequence<(Y + Xs)...>{}; } template __host__ __device__ constexpr auto operator-(Number, Sequence) { constexpr auto seq_x = Sequence{}; return Sequence<(Y - Xs)...>{}; } template __host__ __device__ constexpr auto operator*(Number, Sequence) { return Sequence<(Y * Xs)...>{}; } template __host__ __device__ constexpr auto operator/(Number, Sequence) { return Sequence<(Y / Xs)...>{}; } template __host__ __device__ constexpr auto operator%(Number, Sequence) { return Sequence<(Y % Xs)...>{}; } template __host__ __device__ constexpr auto sequence_pop_front(Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto sequence_pop_back(Seq) { static_assert(Seq::GetSize() > 0, "wrong! cannot pop an empty Sequence!"); return sequence_pop_front(Seq::Reverse()).Reverse(); } template __host__ __device__ constexpr auto transform_sequences(F f, Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence) { static_assert(Sequence::mSize == Sequence::mSize, "Dim not the same"); return Sequence{}; } template __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence, Sequence) { static_assert(Sequence::mSize == Sequence::mSize && Sequence::mSize == Sequence::mSize, "Dim not the same"); return Sequence{}; } template __host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number) { return typename sequence_reverse_inclusive_scan::type{}; } template __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number) { return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number{}).Reverse(); } template struct lambda_accumulate_on_sequence { const Reduce& f; index_t& result; __host__ __device__ constexpr lambda_accumulate_on_sequence(const Reduce& f_, index_t& result_) : f(f_), result(result_) { } template __host__ __device__ constexpr index_t operator()(IDim) const { return result = f(result, Seq::Get(IDim{})); } }; template __host__ __device__ constexpr index_t accumulate_on_sequence(Seq, Reduce f, Number /*initial_value*/) { index_t result = Init; static_for<0, Seq::mSize, 1>{}(lambda_accumulate_on_sequence(f, result)); return result; } template __host__ __device__ void print_Sequence(const char* s, Sequence) { constexpr index_t nsize = Sequence::GetSize(); static_assert(nsize <= 10, "wrong!"); static_if{}([&](auto) { printf("%s size %u, {}\n", s, nsize, Xs...); }); static_if{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, Xs...); }); static_if{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, Xs...); }); static_if{}([&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, Xs...); }); static_if{}([&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, Xs...); }); static_if{}( [&](auto) { printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, Xs...); }); static_if{}( [&](auto) { printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, Xs...); }); static_if{}( [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u}\n", s, nsize, Xs...); }); static_if{}( [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); }); static_if{}( [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); }); static_if{}( [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); }); } } // namespace ck #endif