#ifndef CK_SEQUENCE_HPP #define CK_SEQUENCE_HPP #include "integral_constant.hpp" #include "functional.hpp" namespace ck { template struct is_valid_sequence_map; template struct Sequence { using Type = Sequence; using data_type = index_t; static constexpr index_t mSize = sizeof...(Is); __host__ __device__ static constexpr auto GetSize() { return Number{}; } __host__ __device__ static constexpr index_t GetImpl(index_t I) { // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 const index_t mData[mSize + 1] = {Is..., 0}; return mData[I]; } template __host__ __device__ static constexpr auto Get(Number) { static_assert(I < mSize, "wrong! I too large"); return Number{})>{}; } template __host__ __device__ constexpr auto operator[](Number) const { return Get(Number{}); } // make sure I is constepxr if you want a constexpr return type __host__ __device__ constexpr index_t operator[](index_t I) const { return GetImpl(I); } template __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence /*new2old*/) { static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! reorder map should have the same size as Sequence to be rerodered"); static_assert(is_valid_sequence_map>::value, "wrong! invalid reorder map"); return Sequence{})...>{}; } __host__ __device__ static constexpr auto Reverse(); __host__ __device__ static constexpr auto Front() { static_assert(mSize > 0, "wrong!"); return Get(Number<0>{}); } __host__ __device__ static constexpr auto Back() { static_assert(mSize > 0, "wrong!"); return Get(Number{}); } __host__ __device__ static constexpr auto PopFront(); __host__ __device__ static constexpr auto PopBack(); template __host__ __device__ static constexpr auto PushFront(Sequence) { return Sequence{}; } template __host__ __device__ static constexpr auto PushFront(Number...) { return Sequence{}; } template __host__ __device__ static constexpr auto PushBack(Sequence) { return Sequence{}; } template __host__ __device__ static constexpr auto PushBack(Number...) { return Sequence{}; } template __host__ __device__ static constexpr auto Extract(Number...) { return Sequence{})...>{}; } template __host__ __device__ static constexpr auto Extract(Sequence) { return Sequence{})...>{}; } template __host__ __device__ static constexpr auto Modify(Number, Number); template __host__ __device__ static constexpr auto Transform(F f) { return Sequence{}; } }; // merge sequence template struct sequence_merge; template struct sequence_merge, Sequence> { using type = Sequence; }; // arithmetic sqeuence template struct arithmetic_sequence_gen_impl { static constexpr index_t NSizeLeft = NSize / 2; using type = typename sequence_merge< typename arithmetic_sequence_gen_impl::type, typename arithmetic_sequence_gen_impl::type>::type; }; template struct arithmetic_sequence_gen_impl { using type = Sequence; }; template struct arithmetic_sequence_gen_impl { using type = Sequence<>; }; template struct arithmetic_sequence_gen { using type = typename arithmetic_sequence_gen_impl::type; }; // uniform sequence template struct uniform_sequence_gen { struct return_constant { __host__ __device__ constexpr index_t operator()(index_t) const { return I; } }; using type = decltype( typename arithmetic_sequence_gen<0, NSize, 1>::type{}.Transform(return_constant{})); }; // reverse inclusive scan (with init) sequence template struct sequence_reverse_inclusive_scan; template struct sequence_reverse_inclusive_scan, Reduce, Init> { using old_scan = typename sequence_reverse_inclusive_scan, Reduce, Init>::type; static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front()); using type = typename sequence_merge, old_scan>::type; }; template struct sequence_reverse_inclusive_scan, Reduce, Init> { using type = Sequence; }; template struct sequence_reverse_inclusive_scan, Reduce, Init> { using type = Sequence<>; }; // split sequence template struct sequence_split { static constexpr index_t NSize = Seq{}.GetSize(); using range0 = typename arithmetic_sequence_gen<0, I, 1>::type; using range1 = typename arithmetic_sequence_gen::type; using SeqType0 = decltype(Seq::Extract(range0{})); using SeqType1 = decltype(Seq::Extract(range1{})); }; // reverse sequence template struct sequence_reverse { static constexpr index_t NSize = Seq{}.GetSize(); using seq_split = sequence_split; using type = typename sequence_merge< typename sequence_reverse::type, typename sequence_reverse::type>::type; }; template struct sequence_reverse> { using type = Sequence; }; template struct sequence_reverse> { using type = Sequence; }; template struct is_valid_sequence_map { static constexpr integral_constant value = integral_constant{}; // TODO: add proper check for is_valid, something like: // static constexpr bool value = // is_same::type, // typename sequence_sort::SortedSeqType>{}; }; template __host__ __device__ constexpr auto operator+(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs + Ys)...>{}; } template __host__ __device__ constexpr auto operator-(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs - Ys)...>{}; } template __host__ __device__ constexpr auto operator*(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs * Ys)...>{}; } template __host__ __device__ constexpr auto operator/(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs / Ys)...>{}; } template __host__ __device__ constexpr auto operator%(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs % Ys)...>{}; } template __host__ __device__ constexpr auto operator+(Sequence, Number) { return Sequence<(Xs + Y)...>{}; } template __host__ __device__ constexpr auto operator-(Sequence, Number) { return Sequence<(Xs - Y)...>{}; } template __host__ __device__ constexpr auto operator*(Sequence, Number) { return Sequence<(Xs * Y)...>{}; } template __host__ __device__ constexpr auto operator/(Sequence, Number) { return Sequence<(Xs / Y)...>{}; } template __host__ __device__ constexpr auto operator%(Sequence, Number) { return Sequence<(Xs % Y)...>{}; } template __host__ __device__ constexpr auto operator+(Number, Sequence) { return Sequence<(Y + Xs)...>{}; } template __host__ __device__ constexpr auto operator-(Number, Sequence) { constexpr auto seq_x = Sequence{}; return Sequence<(Y - Xs)...>{}; } template __host__ __device__ constexpr auto operator*(Number, Sequence) { return Sequence<(Y * Xs)...>{}; } template __host__ __device__ constexpr auto operator/(Number, Sequence) { return Sequence<(Y / Xs)...>{}; } template __host__ __device__ constexpr auto operator%(Number, Sequence) { return Sequence<(Y % Xs)...>{}; } template __host__ __device__ constexpr auto sequence_pop_front(Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto sequence_pop_back(Seq) { static_assert(Seq{}.GetSize() > 0, "wrong! cannot pop an empty Sequence!"); return sequence_pop_front(Seq{}.Reverse()).Reverse(); } template __host__ __device__ constexpr auto transform_sequences(F f, Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence) { static_assert(Sequence::mSize == Sequence::mSize, "Dim not the same"); return Sequence{}; } template __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence, Sequence) { static_assert(Sequence::mSize == Sequence::mSize && Sequence::mSize == Sequence::mSize, "Dim not the same"); return Sequence{}; } template __host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number) { return typename sequence_reverse_inclusive_scan::type{}; } template __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number) { return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number{}).Reverse(); } template __host__ __device__ constexpr auto Sequence::PopFront() { return sequence_pop_front(Type{}); } template __host__ __device__ constexpr auto Sequence::PopBack() { return sequence_pop_back(Type{}); } template __host__ __device__ constexpr auto Sequence::Reverse() { return typename sequence_reverse>::type{}; } template template __host__ __device__ constexpr auto Sequence::Modify(Number, Number) { static_assert(I < GetSize(), "wrong!"); using seq_split = sequence_split; constexpr auto seq_left = typename seq_split::SeqType0{}; constexpr auto seq_right = typename seq_split::SeqType1{}.PopFront(); return seq_left.PushBack(Number{}).PushBack(seq_right); } template __host__ __device__ void print_Sequence(const char* s, Sequence) { constexpr index_t nsize = Sequence::GetSize(); static_assert(nsize <= 10, "wrong!"); static_if{}([&](auto) { printf("%s size %u, {}\n", s, nsize, Xs...); }); static_if{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, Xs...); }); static_if{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, Xs...); }); static_if{}([&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, Xs...); }); static_if{}([&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, Xs...); }); static_if{}( [&](auto) { printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, Xs...); }); static_if{}( [&](auto) { printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, Xs...); }); static_if{}( [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u}\n", s, nsize, Xs...); }); static_if{}( [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); }); static_if{}( [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); }); static_if{}( [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); }); } } // namespace ck #endif