#pragma once #include "constant_integral.hip.hpp" #include "functional.hip.hpp" struct EmptySequence; template struct Sequence { using Type = Sequence; static constexpr index_t mSize = sizeof...(Is); const index_t mData[mSize] = {Is...}; __host__ __device__ static constexpr index_t GetSize() { return mSize; } template __host__ __device__ constexpr index_t Get(Number) const { return mData[I]; } __host__ __device__ index_t operator[](index_t i) const { return mData[i]; } template __host__ __device__ constexpr auto ReorderGivenNew2Old(Sequence /*new2old*/) const { static_assert(mSize == sizeof...(IRs), "mSize not consistent"); constexpr auto old = Type{}; return Sequence{})...>{}; } template __host__ __device__ constexpr auto ReorderGivenOld2New(Sequence /*old2new*/) const { // TODO: don't know how to implement this printf("Sequence::ReorderGivenOld2New not implemented"); assert(false); } __host__ __device__ constexpr auto Reverse() const { // not implemented } __host__ __device__ constexpr index_t Front() const { return mData[0]; } __host__ __device__ constexpr index_t Back() const { return mData[mSize - 1]; } template __host__ __device__ constexpr auto PushFront(Number) const { return Sequence{}; } template __host__ __device__ constexpr auto PushBack(Number) const { return Sequence{}; } __host__ __device__ constexpr auto PopFront() const; __host__ __device__ constexpr auto PopBack() const; template __host__ __device__ constexpr auto Append(Sequence) const { return Sequence{}; } __host__ __device__ constexpr auto Append(EmptySequence) const; template __host__ __device__ constexpr auto Extract(Number...) const { return Sequence{})...>{}; } template struct split_impl { template __host__ __device__ constexpr auto operator()(FirstSeq, SecondSeq) const { constexpr index_t new_first = FirstSeq{}.PushBack(Number{}); constexpr index_t new_second = SecondSeq{}.PopFront(); static_if<(N > 0)>{}([&](auto fwd) { return split_impl{}(new_first, fwd(new_second)); }).else_([&](auto fwd) { return std::make_pair(new_first, fwd(new_second)); }); } }; // split one sequence to two sequnces: [0, I) and [I, mSize) // return type is std::pair template __host__ __device__ constexpr auto Split(Number) const; template __host__ __device__ constexpr auto Modify(Number, Number) const { constexpr auto first_second = Split(Number{}); constexpr auto left = first_second.first; constexpr auto right = first_second.second.PopFront(); return left.PushBack(Number{}).Append(right); } }; struct EmptySequence { __host__ __device__ static constexpr index_t GetSize() { return 0; } template __host__ __device__ constexpr auto PushFront(Number) const { return Sequence{}; } template __host__ __device__ constexpr auto PushBack(Number) const { return Sequence{}; } template __host__ __device__ constexpr Seq Append(Seq) const { return Seq{}; } }; template __host__ __device__ constexpr auto Sequence::Append(EmptySequence) const { return Type{}; } // split one sequence to two sequnces: [0, I) and [I, mSize) // return type is std::pair template template __host__ __device__ constexpr auto Sequence::Split(Number) const { static_assert(I <= GetSize(), "wrong! split position is too high!"); static_if<(I == 0)>{}([&](auto fwd) { return std::make_pair(EmptySequence{}, fwd(Type{})); }); static_if<(I == GetSize())>{}( [&](auto fwd) { return std::make_pair(Type{}, fwd(EmptySequence{})); }); static_if<(I > 0 && I < GetSize())>{}( [&](auto fwd) { return split_impl{}(EmptySequence{}, fwd(Type{})); }); } #if 0 template __host__ __device__ auto make_increasing_sequence(Number, Number, Number) { static_assert(IBegin < IEnd, (IEnd - IBegin) % Increment == 0, "wrong!"); // not implemented } #endif template __host__ __device__ constexpr auto operator+(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs + Ys)...>{}; } template __host__ __device__ constexpr auto operator-(Sequence seq_x, Sequence seq_y) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); static_for<0, seq_x.GetSize(), 1>{}( [&](auto I) { static_assert(seq_x.Get(I) >= seq_y.Get(I), "wrong! going to undeflow"); }); return Sequence<(Xs - Ys)...>{}; } template __host__ __device__ constexpr auto operator*(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs * Ys)...>{}; } template __host__ __device__ constexpr auto operator/(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs / Ys)...>{}; } template __host__ __device__ constexpr auto operator%(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs % Ys)...>{}; } template __host__ __device__ constexpr auto operator+(Sequence, Number) { return Sequence<(Xs + Y)...>{}; } template __host__ __device__ constexpr auto operator-(Sequence, Number) { constexpr auto seq_x = Sequence{}; #if 0 static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) { constexpr auto I = decltype(Iter){}; static_assert(seq_x.Get(I) >= Y, "wrong! going to underflow"); }); #endif return Sequence<(Xs - Y)...>{}; } template __host__ __device__ constexpr auto operator*(Sequence, Number) { return Sequence<(Xs * Y)...>{}; } template __host__ __device__ constexpr auto operator/(Sequence, Number) { return Sequence<(Xs / Y)...>{}; } template __host__ __device__ constexpr auto operator%(Sequence, Number) { return Sequence<(Xs % Y)...>{}; } template __host__ __device__ constexpr auto operator+(Number, Sequence) { return Sequence<(Y + Xs)...>{}; } template __host__ __device__ constexpr auto operator-(Number, Sequence) { constexpr auto seq_x = Sequence{}; static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) { constexpr auto I = decltype(Iter){}; static_assert(seq_x.Get(I) <= Y, "wrong! going to underflow"); }); return Sequence<(Y - Xs)...>{}; } template __host__ __device__ constexpr auto operator*(Number, Sequence) { return Sequence<(Y * Xs)...>{}; } template __host__ __device__ constexpr auto operator/(Number, Sequence) { return Sequence<(Y / Xs)...>{}; } template __host__ __device__ constexpr auto operator%(Number, Sequence) { return Sequence<(Y % Xs)...>{}; } template __host__ __device__ constexpr auto sequence_pop_front(Sequence) { static_assert(sizeof...(Is) > 0, "empty Sequence!"); return Sequence{}; } #if 0 // TODO: for some reason, compiler cannot instantiate this template template __host__ __device__ constexpr auto sequence_pop_back(Sequence) { static_assert(sizeof...(Is) > 0, "empty Sequence!"); return Sequence{}; } #else // TODO: delete these very ugly mess template __host__ __device__ constexpr auto sequence_pop_back(Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto sequence_pop_back(Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto sequence_pop_back(Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto sequence_pop_back(Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto sequence_pop_back(Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto sequence_pop_back(Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto sequence_pop_back(Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto sequence_pop_back(Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto sequence_pop_back(Sequence) { return Sequence{}; } #endif template __host__ __device__ constexpr auto transform_sequences(F f, Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence) { static_assert(Sequence::mSize == Sequence::mSize, "Dim not the same"); return Sequence{}; } template __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence, Sequence) { static_assert(Sequence::mSize == Sequence::mSize && Sequence::mSize == Sequence::mSize, "Dim not the same"); return Sequence{}; } template __host__ __device__ constexpr auto Sequence::PopFront() const { return sequence_pop_front(Type{}); } template __host__ __device__ constexpr auto Sequence::PopBack() const { return sequence_pop_back(Type{}); } template struct accumulate_on_sequence_impl { template __host__ __device__ constexpr index_t operator()(IDim) const { return Seq{}.Get(IDim{}); } }; template __host__ __device__ constexpr index_t accumulate_on_sequence(Seq, Reduce, Number /*initial_value*/) { constexpr index_t a = static_const_reduce_n{}(accumulate_on_sequence_impl{}, Reduce{}); return Reduce{}(a, I); } template struct scan_sequence_impl { template __host__ __device__ constexpr auto operator()(ScanedSeq, RemainSeq, Reduce) const { static_assert(RemainSeq{}.GetSize() == NRemain, "wrong! RemainSeq and NRemain not consistent!"); constexpr index_t a = Reduce{}(ScanedSeq{}.Back(), RemainSeq{}.Front()); constexpr auto scaned_seq = ScanedSeq{}.PushBack(Number{}); static_if<(NRemain > 1)>{}([&](auto fwd) { return scan_sequence_impl{}( scaned_seq, RemainSeq{}.PopFront(), fwd(Reduce{})); }).else_([&](auto fwd) { return fwd(scaned_seq); }); } }; template __host__ __device__ constexpr auto scan_sequence(Seq, Reduce) { constexpr auto scaned_seq = Sequence{}; constexpr auto remain_seq = Seq{}.PopFront(); constexpr index_t remain_size = Seq::GetSize() - 1; return scan_sequence_impl{}(scaned_seq, remain_seq, Reduce{}); } template __host__ __device__ constexpr auto reverse_scan_sequence(Seq, Reduce) { return scan_seqeunce(Seq{}.Reverse(), Reduce{}).Reverse(); }