#pragma once #include "constant_integral.hip.hpp" #include "functional.hip.hpp" template struct Sequence { using Type = Sequence; static constexpr index_t mSize = sizeof...(Is); const index_t mData[mSize] = {Is...}; __host__ __device__ static constexpr index_t GetSize() { return mSize; } template __host__ __device__ constexpr index_t Get(Number) const { return mData[I]; } __host__ __device__ index_t operator[](index_t i) const { return mData[i]; } template __host__ __device__ constexpr auto ReorderGivenNew2Old(Sequence /*new2old*/) const { static_assert(mSize == sizeof...(IRs), "mSize not consistent"); constexpr auto old = Type{}; return Sequence{})...>{}; } template __host__ __device__ constexpr auto ReorderGivenOld2New(Sequence /*old2new*/) const { // don't know how to implement this printf("Sequence::ReorderGivenOld2New not implemented"); assert(false); } template __host__ __device__ constexpr auto PushFront(Number) const { return Sequence{}; } template __host__ __device__ constexpr auto PushBack(Number) const { return Sequence{}; } __host__ __device__ constexpr auto PopFront() const; __host__ __device__ constexpr auto PopBack() const; template __host__ __device__ constexpr auto Transform(F f) const { return Sequence{}; } }; template __host__ __device__ constexpr auto sequence_pop_front(Sequence) { static_assert(sizeof...(Is) > 0, "empty Sequence!"); return Sequence{}; } template __host__ __device__ constexpr auto sequence_pop_back(Sequence) { static_assert(sizeof...(Is) > 0, "empty Sequence!"); return Sequence{}; } #if 1 // this is ugly, only for 2 sequences template __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence) { static_assert(Sequence::mSize == Sequence::mSize, "Dim not the same"); return Sequence{}; } // this is ugly, only for 3 sequences template __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence, Sequence) { static_assert(Sequence::mSize == Sequence::mSize && Sequence::mSize == Sequence::mSize, "Dim not the same"); return Sequence{}; } #else template struct transform_sequences_impl { template __host__ __device__ constexpr auto operator()(F f, Y y, Xs... xs) const { static_assert(NRemain > 1, "wrong! should have NRemain > 1"); constexpr index_t N = f(Xs{}.Get(Number<0>{})...); constexpr auto y_new = y.PushBack(Number{}); return transform_sequences_impl{}(f, y_new, xs.PopFront()...); } }; template <> struct transform_sequences_impl<1> { template __host__ __device__ constexpr auto operator()(F f, Y, Xs...) const { constexpr index_t N = f(Xs{}.Get(Number<0>{})...); return Y{}.PushBack(Number{}); } }; template __host__ __device__ constexpr auto transform_sequences(F f, X x, Xs... xs) { constexpr index_t nSize = X::GetSize(); constexpr auto I0 = Number<0>{}; constexpr auto y0 = Sequence{}; return transform_sequences_impl{}(f, y0, x.PopFront(), xs.PopFront()...); } #endif template __host__ __device__ constexpr auto Sequence::PopFront() const { return sequence_pop_front(Type{}); } template __host__ __device__ constexpr auto Sequence::PopBack() const { return sequence_pop_back(Type{}); } template struct accumulate_on_sequence_f { template __host__ __device__ constexpr index_t operator()(IDim) const { return Seq{}.Get(IDim{}); } }; template __host__ __device__ constexpr index_t accumulate_on_sequence(Seq, Reduce, Number) { constexpr index_t a = static_const_reduce_n{}(accumulate_on_sequence_f{}, Reduce{}); return Reduce{}(a, I); }