#pragma once #include "constant_integral.hip.hpp" #include "functional.hip.hpp" template struct Sequence { using Type = Sequence; static constexpr index_t mSize = sizeof...(Is); const index_t mData[mSize] = {Is...}; __host__ __device__ static constexpr index_t GetSize() { return mSize; } template __host__ __device__ constexpr index_t Get(Number) const { return mData[I]; } __host__ __device__ index_t operator[](index_t i) const { return mData[i]; } template __host__ __device__ constexpr auto ReorderGivenNew2Old(Sequence /*new2old*/) const { static_assert(mSize == sizeof...(IRs), "mSize not consistent"); constexpr auto old = Type{}; return Sequence{})...>{}; } template __host__ __device__ constexpr auto ReorderGivenOld2New(Sequence /*old2new*/) const { // TODO: don't know how to implement this printf("Sequence::ReorderGivenOld2New not implemented"); assert(false); } __host__ __device__ constexpr auto Reverse() const { // not implemented } __host__ __device__ constexpr index_t Front() const { return mData[0]; } __host__ __device__ constexpr index_t Back() const { return mData[mSize - 1]; } template __host__ __device__ constexpr auto PushFront(Number) const { return Sequence{}; } template __host__ __device__ constexpr auto PushBack(Number) const { return Sequence{}; } __host__ __device__ constexpr auto PopFront() const; __host__ __device__ constexpr auto PopBack() const; template __host__ __device__ constexpr auto Append(Sequence) const { return Sequence{}; } template __host__ __device__ constexpr auto Extract(Number...) const { return Sequence{})...>{}; } template __host__ __device__ constexpr auto Extract(Sequence) const { return Sequence{})...>{}; } }; template struct sequence_merge; template struct sequence_merge, Sequence> { using Type = Sequence; }; template struct increasing_sequence_gen { static constexpr index_t NSizeLeft = NSize / 2; using Type = sequence_merge::Type, typename increasing_sequence_gen::Type>; }; template struct increasing_sequence_gen { using Type = Sequence; }; template struct increasing_sequence_gen { using Type = Sequence<>; }; template __host__ __device__ constexpr auto make_increasing_sequence(Number, Number, Number) { static_assert(IBegin <= IEnd && Increment > 0, "wrong!"); constexpr index_t NSize = (IEnd - IBegin) / Increment; return increasing_sequence_gen{}; } template __host__ __device__ constexpr auto operator+(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs + Ys)...>{}; } template __host__ __device__ constexpr auto operator-(Sequence seq_x, Sequence seq_y) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); static_for<0, seq_x.GetSize(), 1>{}( [&](auto I) { static_assert(seq_x.Get(I) >= seq_y.Get(I), "wrong! going to undeflow"); }); return Sequence<(Xs - Ys)...>{}; } template __host__ __device__ constexpr auto operator*(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs * Ys)...>{}; } template __host__ __device__ constexpr auto operator/(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs / Ys)...>{}; } template __host__ __device__ constexpr auto operator%(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs % Ys)...>{}; } template __host__ __device__ constexpr auto operator+(Sequence, Number) { return Sequence<(Xs + Y)...>{}; } template __host__ __device__ constexpr auto operator-(Sequence, Number) { constexpr auto seq_x = Sequence{}; #if 0 // doesn't compile static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) { constexpr auto I = decltype(Iter){}; static_assert(seq_x.Get(I) >= Y, "wrong! going to underflow"); }); #endif return Sequence<(Xs - Y)...>{}; } template __host__ __device__ constexpr auto operator*(Sequence, Number) { return Sequence<(Xs * Y)...>{}; } template __host__ __device__ constexpr auto operator/(Sequence, Number) { return Sequence<(Xs / Y)...>{}; } template __host__ __device__ constexpr auto operator%(Sequence, Number) { return Sequence<(Xs % Y)...>{}; } template __host__ __device__ constexpr auto operator+(Number, Sequence) { return Sequence<(Y + Xs)...>{}; } template __host__ __device__ constexpr auto operator-(Number, Sequence) { constexpr auto seq_x = Sequence{}; static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) { constexpr auto I = decltype(Iter){}; static_assert(seq_x.Get(I) <= Y, "wrong! going to underflow"); }); return Sequence<(Y - Xs)...>{}; } template __host__ __device__ constexpr auto operator*(Number, Sequence) { return Sequence<(Y * Xs)...>{}; } template __host__ __device__ constexpr auto operator/(Number, Sequence) { return Sequence<(Y / Xs)...>{}; } template __host__ __device__ constexpr auto operator%(Number, Sequence) { return Sequence<(Y % Xs)...>{}; } template __host__ __device__ constexpr auto sequence_pop_front(Sequence) { static_assert(sizeof...(Is) > 0, "empty Sequence!"); return Sequence{}; } #if 0 // TODO: for some reason, compiler cannot instantiate this template template __host__ __device__ constexpr auto sequence_pop_back(Sequence) { static_assert(sizeof...(Is) > 0, "empty Sequence!"); return Sequence{}; } #else // TODO: delete these very ugly mess template __host__ __device__ constexpr auto sequence_pop_back(Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto sequence_pop_back(Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto sequence_pop_back(Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto sequence_pop_back(Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto sequence_pop_back(Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto sequence_pop_back(Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto sequence_pop_back(Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto sequence_pop_back(Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto sequence_pop_back(Sequence) { return Sequence{}; } #endif template __host__ __device__ constexpr auto transform_sequences(F f, Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence) { static_assert(Sequence::mSize == Sequence::mSize, "Dim not the same"); return Sequence{}; } template __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence, Sequence) { static_assert(Sequence::mSize == Sequence::mSize && Sequence::mSize == Sequence::mSize, "Dim not the same"); return Sequence{}; } template __host__ __device__ constexpr auto Sequence::PopFront() const { return sequence_pop_front(Type{}); } template __host__ __device__ constexpr auto Sequence::PopBack() const { return sequence_pop_back(Type{}); } template struct accumulate_on_sequence_impl { template __host__ __device__ constexpr index_t operator()(IDim) const { return Seq{}.Get(IDim{}); } }; template __host__ __device__ constexpr index_t accumulate_on_sequence(Seq, Reduce, Number /*initial_value*/) { constexpr index_t a = static_const_reduce_n{}(accumulate_on_sequence_impl{}, Reduce{}); return Reduce{}(a, I); } template struct scan_sequence_impl { template __host__ __device__ constexpr auto operator()(ScanedSeq, RemainSeq, Reduce) const { static_assert(RemainSeq{}.GetSize() == NRemain, "wrong! RemainSeq and NRemain not consistent!"); constexpr index_t a = Reduce{}(ScanedSeq{}.Back(), RemainSeq{}.Front()); constexpr auto scaned_seq = ScanedSeq{}.PushBack(Number{}); static_if<(NRemain > 1)>{}([&](auto fwd) { return scan_sequence_impl{}( scaned_seq, RemainSeq{}.PopFront(), fwd(Reduce{})); }).else_([&](auto fwd) { return fwd(scaned_seq); }); } }; template __host__ __device__ constexpr auto scan_sequence(Seq, Reduce) { constexpr auto scaned_seq = Sequence{}; constexpr auto remain_seq = Seq{}.PopFront(); constexpr index_t remain_size = Seq::GetSize() - 1; return scan_sequence_impl{}(scaned_seq, remain_seq, Reduce{}); } template __host__ __device__ constexpr auto reverse_scan_sequence(Seq, Reduce) { return scan_seqeunce(Seq{}.Reverse(), Reduce{}).Reverse(); }