#pragma once #include "constant_integral.hip.hpp" #include "functional.hip.hpp" template struct Sequence { using Type = Sequence; static constexpr index_t mSize = sizeof...(Is); const index_t mData[mSize + 1] = { Is..., 0}; // the last element is dummy, to prevent compiler complain on empty Sequence __host__ __device__ static constexpr index_t GetSize() { return mSize; } template __host__ __device__ constexpr index_t Get(Number) const { return mData[I]; } __host__ __device__ index_t operator[](index_t i) const { return mData[i]; } template __host__ __device__ constexpr auto ReorderGivenNew2Old(Sequence /*new2old*/) const { static_assert(mSize == sizeof...(IRs), "mSize not consistent"); constexpr auto old = Type{}; return Sequence{})...>{}; } template __host__ __device__ constexpr auto ReorderGivenOld2New(Sequence /*old2new*/) const { // TODO: don't know how to implement this printf("Sequence::ReorderGivenOld2New not implemented"); assert(false); } __host__ __device__ constexpr auto Reverse() const; __host__ __device__ constexpr index_t Front() const { return mData[0]; } __host__ __device__ constexpr index_t Back() const { return mData[mSize - 1]; } template __host__ __device__ constexpr auto PushFront(Number) const { return Sequence{}; } template __host__ __device__ constexpr auto PushBack(Number) const { return Sequence{}; } __host__ __device__ constexpr auto PopFront() const; __host__ __device__ constexpr auto PopBack() const; template __host__ __device__ constexpr auto Append(Sequence) const { return Sequence{}; } template __host__ __device__ constexpr auto Extract(Number...) const { return Sequence{})...>{}; } template __host__ __device__ constexpr auto Extract(Sequence) const { return Sequence{})...>{}; } }; template struct sequence_merge; template struct sequence_merge, Sequence> { using SeqType = Sequence; }; template struct increasing_sequence_gen_impl { static constexpr index_t NSizeLeft = NSize / 2; using SeqType = typename sequence_merge< typename increasing_sequence_gen_impl::SeqType, typename increasing_sequence_gen_impl::SeqType>::SeqType; }; template struct increasing_sequence_gen_impl { using SeqType = Sequence; }; template struct increasing_sequence_gen_impl { using SeqType = Sequence<>; }; template struct increasing_sequence_gen { using SeqType = typename increasing_sequence_gen_impl::SeqType; }; template __host__ __device__ constexpr auto make_increasing_sequence(Number, Number, Number) { return typename increasing_sequence_gen::SeqType{}; } template struct sequence_reverse_inclusive_scan; template struct sequence_reverse_inclusive_scan, Reduce> { using old_scan = typename sequence_reverse_inclusive_scan, Reduce>::SeqType; static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front()); using SeqType = typename sequence_merge, old_scan>::SeqType; }; template struct sequence_reverse_inclusive_scan, Reduce> { using SeqType = Sequence; }; template struct sequence_extract; template struct sequence_extract> { using SeqType = Sequence{})...>; }; template struct sequence_split { static constexpr index_t NSize = Seq{}.GetSize(); using range0 = typename increasing_sequence_gen<0, I, 1>::SeqType; using range1 = typename increasing_sequence_gen::SeqType; using SeqType0 = typename sequence_extract::SeqType; using SeqType1 = typename sequence_extract::SeqType; }; template struct sequence_reverse { static constexpr index_t NSize = Seq{}.GetSize(); using seq_split = sequence_split; using SeqType = typename sequence_merge< typename sequence_reverse::SeqType, typename sequence_reverse::SeqType>::SeqType; }; template struct sequence_reverse> { using SeqType = Sequence; }; template struct sequence_reverse> { using SeqType = Sequence; }; template __host__ __device__ constexpr auto operator+(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs + Ys)...>{}; } template __host__ __device__ constexpr auto operator-(Sequence seq_x, Sequence seq_y) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); static_for<0, seq_x.GetSize(), 1>{}( [&](auto I) { static_assert(seq_x.Get(I) >= seq_y.Get(I), "wrong! going to undeflow"); }); return Sequence<(Xs - Ys)...>{}; } template __host__ __device__ constexpr auto operator*(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs * Ys)...>{}; } template __host__ __device__ constexpr auto operator/(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs / Ys)...>{}; } template __host__ __device__ constexpr auto operator%(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs % Ys)...>{}; } template __host__ __device__ constexpr auto operator+(Sequence, Number) { return Sequence<(Xs + Y)...>{}; } template __host__ __device__ constexpr auto operator-(Sequence, Number) { #if 0 // doesn't compile constexpr auto seq_x = Sequence{}; static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) { constexpr auto I = decltype(Iter){}; static_assert(seq_x.Get(I) >= Y, "wrong! going to underflow"); }); #endif return Sequence<(Xs - Y)...>{}; } template __host__ __device__ constexpr auto operator*(Sequence, Number) { return Sequence<(Xs * Y)...>{}; } template __host__ __device__ constexpr auto operator/(Sequence, Number) { return Sequence<(Xs / Y)...>{}; } template __host__ __device__ constexpr auto operator%(Sequence, Number) { return Sequence<(Xs % Y)...>{}; } template __host__ __device__ constexpr auto operator+(Number, Sequence) { return Sequence<(Y + Xs)...>{}; } template __host__ __device__ constexpr auto operator-(Number, Sequence) { constexpr auto seq_x = Sequence{}; static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) { constexpr auto I = decltype(Iter){}; static_assert(seq_x.Get(I) <= Y, "wrong! going to underflow"); }); return Sequence<(Y - Xs)...>{}; } template __host__ __device__ constexpr auto operator*(Number, Sequence) { return Sequence<(Y * Xs)...>{}; } template __host__ __device__ constexpr auto operator/(Number, Sequence) { return Sequence<(Y / Xs)...>{}; } template __host__ __device__ constexpr auto operator%(Number, Sequence) { return Sequence<(Y % Xs)...>{}; } template __host__ __device__ constexpr auto sequence_pop_front(Sequence) { static_assert(sizeof...(Is) > 0, "empty Sequence!"); return Sequence{}; } template __host__ __device__ constexpr auto sequence_pop_back(Seq) { static_assert(Seq{}.GetSize() > 0, "empty Sequence!"); return sequence_pop_front(Seq{}.Reverse()).Reverse(); } template __host__ __device__ constexpr auto transform_sequences(F f, Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence) { static_assert(Sequence::mSize == Sequence::mSize, "Dim not the same"); return Sequence{}; } template __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence, Sequence) { static_assert(Sequence::mSize == Sequence::mSize && Sequence::mSize == Sequence::mSize, "Dim not the same"); return Sequence{}; } template __host__ __device__ constexpr auto Sequence::PopFront() const { return sequence_pop_front(Type{}); } template __host__ __device__ constexpr auto Sequence::PopBack() const { return sequence_pop_back(Type{}); } template struct accumulate_on_sequence_impl { template __host__ __device__ constexpr index_t operator()(IDim) const { return Seq{}.Get(IDim{}); } }; template __host__ __device__ constexpr index_t accumulate_on_sequence(Seq, Reduce, Number /*initial_value*/) { constexpr index_t a = static_const_reduce_n{}(accumulate_on_sequence_impl{}, Reduce{}); return Reduce{}(a, I); } template __host__ __device__ constexpr auto Sequence::Reverse() const { return typename sequence_reverse>::SeqType{}; } template __host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce) { return typename sequence_reverse_inclusive_scan::SeqType{}; } template __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce) { return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}).Reverse(); }