#pragma once
#include <utility>

#include "constant_integral.hip.hpp"
#include "functional.hip.hpp"

// Compile-time integer sequences.
//
// Sequence<Is...> carries its elements in its type. Every operation below
// returns a *new* Sequence type computed at compile time; none of these
// helpers changes an existing object.
//
// NOTE(review): index_t and Number<N> are expected to come from
// constant_integral.hip.hpp, with Number<N> deducible as
// `template <index_t N> f(Number<N>)` (this is how the code uses it) - confirm.
// NOTE(review): Sequence::Split returns std::pair built with std::make_pair,
// a constexpr host function; nvcc presumably needs --expt-relaxed-constexpr
// for device-side use - confirm build flags.

template <index_t... Is>
struct Sequence;

namespace sequence_detail {

// Returns true iff every argument is true (true for an empty argument list).
__host__ __device__ constexpr bool all_true() { return true; }

template <class... Bs>
__host__ __device__ constexpr bool all_true(bool b, Bs... bs)
{
    return b && all_true(bs...);
}

// arithmetic_sequence_gen<Done, Begin, End, Increment>::type is
// Sequence<Begin, Begin + Increment, Begin + 2 * Increment, ...> containing every
// such value strictly below End (End is exclusive). Done is the recursion
// terminator; use make_arithmetic_sequence_t instead of naming this directly.
template <bool Done, index_t Begin, index_t End, index_t Increment, index_t... Acc>
struct arithmetic_sequence_gen
{
    static_assert(Increment > 0, "wrong! Increment must be positive");

    // A non-positive Increment also terminates the recursion, so the
    // static_assert above is reported instead of an endless instantiation.
    using type = typename arithmetic_sequence_gen<(!(Increment > 0) || Begin + Increment >= End),
                                                  Begin + Increment,
                                                  End,
                                                  Increment,
                                                  Acc...,
                                                  Begin>::type;
};

template <index_t Begin, index_t End, index_t Increment, index_t... Acc>
struct arithmetic_sequence_gen<true, Begin, End, Increment, Acc...>
{
    using type = Sequence<Acc...>;
};

// Sequence<Begin, Begin + Increment, ...> with every element < End.
template <index_t Begin, index_t End, index_t Increment = 1>
using make_arithmetic_sequence_t =
    typename arithmetic_sequence_gen<(Begin >= End), Begin, End, Increment>::type;

// uniform_sequence_gen<N, Value>::type is Sequence<Value, ..., Value> (N copies).
template <index_t N, index_t Value, index_t... Acc>
struct uniform_sequence_gen
{
    using type = typename uniform_sequence_gen<N - 1, Value, Acc..., Value>::type;
};

template <index_t Value, index_t... Acc>
struct uniform_sequence_gen<0, Value, Acc...>
{
    using type = Sequence<Acc...>;
};

// Maps each index j of an N-element index sequence to N - 1 - j; applied to
// Sequence<0, ..., N-1> this yields Sequence<N-1, ..., 0>.
template <index_t... Js>
__host__ __device__ constexpr auto reverse_indices(Sequence<Js...>)
{
    return Sequence<(index_t(sizeof...(Js)) - 1 - Js)...>{};
}

// Returns Sequence<seq[J0], seq[J1], ...>: Seq's elements at positions Js, in the
// order given. Out-of-range positions are rejected by Sequence::Get's static_assert.
template <class Seq, index_t... Js>
__host__ __device__ constexpr auto extract(Seq, Sequence<Js...>)
{
    return Sequence<Seq{}.Get(Number<Js>{})...>{};
}

// Returns Seq with the element at position I replaced by X.
// Js must be Sequence<0, ..., N-1> for an N-element Seq.
template <class Seq, index_t I, index_t X, index_t... Js>
__host__ __device__ constexpr auto replace_at(Seq, Number<I>, Number<X>, Sequence<Js...>)
{
    return Sequence<(Js == I ? X : Seq{}.Get(Number<Js>{}))...>{};
}

// Position of the first element of Seq equal to value, or Seq::GetSize() if absent.
template <class Seq>
__host__ __device__ constexpr index_t find_first(Seq, index_t value)
{
    for(index_t i = 0; i < Seq::GetSize(); ++i)
    {
        if(Seq{}[i] == value)
        {
            return i;
        }
    }
    return Seq::GetSize();
}

// Implements Sequence::ReorderGivenOld2New: result[old2new[i]] = seq[i].
// New position j takes the old element i for which old2new[i] == j.
// Js must be Sequence<0, ..., N-1>.
template <class Seq, class Old2New, index_t... Js>
__host__ __device__ constexpr auto reorder_given_old2new(Seq, Old2New, Sequence<Js...>)
{
    // Every new position must be the image of some old position; with
    // matching sizes this means old2new is a permutation of 0..N-1.
    static_assert(all_true((find_first(Old2New{}, Js) < Old2New::GetSize())...),
                  "wrong! old2new is not a permutation");

    return Sequence<Seq{}.Get(Number<find_first(Old2New{}, Js)>{})...>{};
}

// Inclusive prefix of Seq ending at position k:
// Reduce(...Reduce(Reduce(seq[0], seq[1]), seq[2])..., seq[k]).
// Reduce must be default-constructible with a constexpr operator().
template <class Seq, class Reduce>
__host__ __device__ constexpr index_t inclusive_prefix(index_t k)
{
    index_t acc = Seq{}[0];
    for(index_t i = 1; i <= k; ++i)
    {
        acc = Reduce{}(acc, Seq{}[i]);
    }
    return acc;
}

// Sequence<prefix(K0), prefix(K1), ...> for the positions Ks of Seq.
template <class Seq, class Reduce, index_t... Ks>
__host__ __device__ constexpr auto scan_impl(Sequence<Ks...>)
{
    return Sequence<inclusive_prefix<Seq, Reduce>(Ks)...>{};
}

} // namespace sequence_detail

// Placeholder for "no elements": appending a sequence to it yields that sequence.
struct EmptySequence
{
    template <class Seq>
    __host__ __device__ constexpr Seq Append(Seq) const
    {
        return {};
    }
};

// A compile-time list of index_t values Is...
template <index_t... Is>
struct Sequence
{
    using Type = Sequence<Is...>;

    static constexpr index_t mSize = sizeof...(Is);

    // Storage backing run-time indexing via operator[]. It is sized at least 1
    // so that the empty Sequence<> (e.g. from PopFront/PopBack/Split) is
    // well-formed; Get/Front/Back reject reads of that padding slot at
    // compile time.
    const index_t mData[mSize > 0 ? mSize : 1] = {Is...};

    __host__ __device__ static constexpr index_t GetSize() { return mSize; }

    // Element at compile-time position I.
    template <index_t I>
    __host__ __device__ constexpr index_t Get(Number<I>) const
    {
        static_assert(I < mSize, "wrong! index out of range");
        return mData[I];
    }

    // Element at run-time position i (unchecked, like a raw array access).
    __host__ __device__ constexpr index_t operator[](index_t i) const { return mData[i]; }

    // Gather: result[j] = this[new2old[j]].
    template <index_t... IRs>
    __host__ __device__ constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/) const
    {
        static_assert(mSize == sizeof...(IRs), "mSize not consistent");

        return sequence_detail::extract(Type{}, Sequence<IRs...>{});
    }

    // Scatter: result[old2new[i]] = this[i]. old2new must be a permutation
    // of 0..mSize-1 (checked at compile time).
    template <index_t... IRs>
    __host__ __device__ constexpr auto ReorderGivenOld2New(Sequence<IRs...> /*old2new*/) const
    {
        static_assert(mSize == sizeof...(IRs), "mSize not consistent");

        return sequence_detail::reorder_given_old2new(
            Type{}, Sequence<IRs...>{}, sequence_detail::make_arithmetic_sequence_t<0, mSize>{});
    }

    // Elements in reverse order.
    __host__ __device__ constexpr auto Reverse() const
    {
        return sequence_detail::extract(
            Type{},
            sequence_detail::reverse_indices(sequence_detail::make_arithmetic_sequence_t<0, mSize>{}));
    }

    __host__ __device__ constexpr index_t Front() const
    {
        static_assert(mSize > 0, "empty Sequence!");
        return mData[0];
    }

    __host__ __device__ constexpr index_t Back() const
    {
        static_assert(mSize > 0, "empty Sequence!");
        return mData[mSize - 1];
    }

    template <index_t I>
    __host__ __device__ constexpr auto PushFront(Number<I>) const
    {
        return Sequence<I, Is...>{};
    }

    template <index_t I>
    __host__ __device__ constexpr auto PushBack(Number<I>) const
    {
        return Sequence<Is..., I>{};
    }

    // Defined after sequence_pop_front / sequence_pop_back below.
    __host__ __device__ constexpr auto PopFront() const;

    __host__ __device__ constexpr auto PopBack() const;

    // Concatenation: Sequence<Is..., Xs...>.
    template <index_t... Xs>
    __host__ __device__ constexpr auto Append(Sequence<Xs...>) const
    {
        return Sequence<Is..., Xs...>{};
    }

    __host__ __device__ constexpr auto Append(EmptySequence) const { return Type{}; }

    // Sequence<this[N0], this[N1], ...> for the given compile-time positions.
    template <index_t... Ns>
    __host__ __device__ constexpr auto Extract(Number<Ns>...) const
    {
        return sequence_detail::extract(Type{}, Sequence<Ns...>{});
    }

    // Splits this sequence into [0, I) and [I, mSize).
    // Returns std::pair<Sequence<...>, Sequence<...>>; an empty part is Sequence<>.
    template <index_t I>
    __host__ __device__ constexpr auto Split(Number<I>) const
    {
        static_assert(I <= mSize, "wrong! split position is too high!");

        return std::make_pair(
            sequence_detail::extract(Type{}, sequence_detail::make_arithmetic_sequence_t<0, I>{}),
            sequence_detail::extract(Type{},
                                     sequence_detail::make_arithmetic_sequence_t<I, mSize>{}));
    }

    // Copy of this sequence with the element at position I replaced by X.
    template <index_t I, index_t X>
    __host__ __device__ constexpr auto Modify(Number<I>, Number<X>) const
    {
        static_assert(I < mSize, "wrong! modify position is out of range!");

        return sequence_detail::replace_at(
            Type{}, Number<I>{}, Number<X>{}, sequence_detail::make_arithmetic_sequence_t<0, mSize>{});
    }
};

// Sequence<IBegin, IBegin + Increment, ...> over the half-open range
// [IBegin, IEnd). Increment must be positive and divide (IEnd - IBegin).
template <index_t IBegin, index_t IEnd, index_t Increment>
__host__ __device__ constexpr auto
make_increasing_sequence(Number<IBegin>, Number<IEnd>, Number<Increment>)
{
    static_assert(Increment > 0, "wrong! Increment must be positive");
    // `Increment == 0 ||` keeps the modulo out of a constant expression when
    // the assert above has already failed.
    static_assert(IBegin <= IEnd && (Increment == 0 || (IEnd - IBegin) % Increment == 0),
                  "wrong!");

    return sequence_detail::make_arithmetic_sequence_t<IBegin, IEnd, Increment>{};
}

// Sequence of NSize copies of I.
template <index_t NSize, index_t I>
__host__ __device__ constexpr auto make_uniform_sequence(Number<NSize>, Number<I>)
{
    return typename sequence_detail::uniform_sequence_gen<NSize, I>::type{};
}

// Element-wise arithmetic between two sequences of equal size.
template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs + Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Sequence<Ys...>)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
    // Reject element-wise results below zero (index_t may be unsigned).
    static_assert(sequence_detail::all_true((Xs >= Ys)...), "wrong! negative result");

    return Sequence<(Xs - Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Sequence<Ys...>)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs * Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Sequence<Ys...>)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs / Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Sequence<Ys...>)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs % Ys)...>{};
}

// Sequence (op) scalar: the scalar is applied to every element.
template <index_t... Xs, index_t Y>
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Number<Y>)
{
    return Sequence<(Xs + Y)...>{};
}

template <index_t... Xs, index_t Y>
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>)
{
    static_assert(sequence_detail::all_true((Xs >= Y)...), "wrong! negative result");

    return Sequence<(Xs - Y)...>{};
}

template <index_t... Xs, index_t Y>
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Number<Y>)
{
    return Sequence<(Xs * Y)...>{};
}

template <index_t... Xs, index_t Y>
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Number<Y>)
{
    return Sequence<(Xs / Y)...>{};
}

template <index_t... Xs, index_t Y>
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Number<Y>)
{
    return Sequence<(Xs % Y)...>{};
}

// Scalar (op) sequence: the scalar is the left operand for every element.
template <index_t X, index_t... Ys>
__host__ __device__ constexpr auto operator+(Number<X>, Sequence<Ys...>)
{
    return Sequence<(X + Ys)...>{};
}

template <index_t X, index_t... Ys>
__host__ __device__ constexpr auto operator-(Number<X>, Sequence<Ys...>)
{
    static_assert(sequence_detail::all_true((X >= Ys)...), "wrong! negative result");

    return Sequence<(X - Ys)...>{};
}

template <index_t X, index_t... Ys>
__host__ __device__ constexpr auto operator*(Number<X>, Sequence<Ys...>)
{
    return Sequence<(X * Ys)...>{};
}

template <index_t X, index_t... Ys>
__host__ __device__ constexpr auto operator/(Number<X>, Sequence<Ys...>)
{
    return Sequence<(X / Ys)...>{};
}

template <index_t X, index_t... Ys>
__host__ __device__ constexpr auto operator%(Number<X>, Sequence<Ys...>)
{
    return Sequence<(X % Ys)...>{};
}

// Drops the first element. An empty Sequence<> matches no overload and is
// therefore rejected at compile time.
template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
{
    return Sequence<Is...>{};
}

// Drops the last element (works for any non-empty size).
template <index_t... Is>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<Is...>)
{
    static_assert(sizeof...(Is) > 0, "empty Sequence!");

    // Guarded so an empty input reports the static_assert, not an underflow.
    constexpr index_t n_keep = sizeof...(Is) > 0 ? index_t(sizeof...(Is)) - 1 : 0;

    return sequence_detail::extract(Sequence<Is...>{},
                                    sequence_detail::make_arithmetic_sequence_t<0, n_keep>{});
}

// Element-wise map of f over one, two or three equally sized sequences.
// f must be usable in a constant expression (e.g. a functor with a constexpr
// operator()).
template <class F, index_t... Xs>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
{
    return Sequence<f(Xs)...>{};
}

template <class F, index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
{
    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize, "Dim not the same");

    return Sequence<f(Xs, Ys)...>{};
}

template <class F, index_t... Xs, index_t... Ys, index_t... Zs>
__host__ __device__ constexpr auto
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
{
    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize &&
                      Sequence<Xs...>::mSize == Sequence<Zs...>::mSize,
                  "Dim not the same");

    return Sequence<f(Xs, Ys, Zs)...>{};
}

template <index_t... Is>
__host__ __device__ constexpr auto Sequence<Is...>::PopFront() const
{
    return sequence_pop_front(Type{});
}

template <index_t... Is>
__host__ __device__ constexpr auto Sequence<Is...>::PopBack() const
{
    return sequence_pop_back(Type{});
}

// Functor returning Seq's element at a compile-time position; kept for
// callers that may still use it.
template <class Seq>
struct accumulate_on_sequence_impl
{
    template <class IDim>
    __host__ __device__ constexpr index_t operator()(IDim) const
    {
        return Seq{}.Get(IDim{});
    }
};

// Folds Seq with Reduce and then combines the result with the initial value I:
//   acc = Reduce(s[N-1], ... Reduce(s[2], Reduce(s[1], s[0])) ...)
//   result = Reduce(acc, I)
// For an empty Seq the result is I. Reduce must be default-constructible
// with a constexpr operator() (e.g. plus / multiplies functors).
template <class Seq, class Reduce, index_t I>
__host__ __device__ constexpr index_t
accumulate_on_sequence(Seq, Reduce, Number<I> /*initial_value*/)
{
    const index_t n = Seq::GetSize();
    if(n == 0)
    {
        return I;
    }

    index_t acc = Seq{}[0];
    for(index_t i = 1; i < n; ++i)
    {
        acc = Reduce{}(Seq{}[i], acc);
    }
    return Reduce{}(acc, I);
}

// Inclusive scan: result[0] = s[0], result[i] = Reduce(result[i-1], s[i]).
// An empty Seq yields Sequence<>.
template <class Seq, class Reduce>
__host__ __device__ constexpr auto scan_sequence(Seq, Reduce)
{
    return sequence_detail::scan_impl<Seq, Reduce>(
        sequence_detail::make_arithmetic_sequence_t<0, Seq::GetSize()>{});
}

// Inclusive scan from the back: result[N-1] = s[N-1],
// result[i] = Reduce(result[i+1], s[i]).
template <class Seq, class Reduce>
__host__ __device__ constexpr auto reverse_scan_sequence(Seq, Reduce)
{
    return scan_sequence(Seq{}.Reverse(), Reduce{}).Reverse();
}