#pragma once #include "integral_constant.hip.hpp" #include "functional.hip.hpp" template struct Sequence { using Type = Sequence; static constexpr index_t mSize = sizeof...(Is); __host__ __device__ static constexpr index_t GetSize() { return mSize; } template __host__ __device__ static constexpr index_t Get(Number) { static_assert(I < mSize, "wrong! I too large"); // the last dummy element is to prevent compiler complain about empty Sequence const index_t mData[mSize + 1] = {Is..., 0}; return mData[I]; } template __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence /*new2old*/) { #if 0 // require sequence_sort, which is not implemented yet static_assert(is_same>::SortedSeqType, arithmetic_sequence_gen<0, mSize, 1>::SeqType>::value, "wrong! invalid new2old map"); #endif return Sequence{})...>{}; } #if 0 // require sequence_sort, which is not implemented yet template __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New /*old2new*/) { static_assert(is_same::SortedSeqType, arithmetic_sequence_gen<0, mSize, 1>::SeqType>::value, "wrong! invalid old2new map"); constexpr auto map_new2old = typename sequence_map_inverse::SeqMapType{}; return ReorderGivenNew2Old(map_new2old); } #endif __host__ __device__ static constexpr auto Reverse(); __host__ __device__ static constexpr index_t Front() { const index_t mData[mSize + 1] = {Is..., 0}; return mData[0]; } __host__ __device__ static constexpr index_t Back() { const index_t mData[mSize + 1] = {Is..., 0}; return mData[mSize - 1]; } template __host__ __device__ static constexpr auto PushFront(Number) { return Sequence{}; } template __host__ __device__ static constexpr auto PushBack(Number) { return Sequence{}; } __host__ __device__ static constexpr auto PopFront(); __host__ __device__ static constexpr auto PopBack(); template __host__ __device__ static constexpr auto Append(Sequence) { return Sequence{}; } template __host__ __device__ static constexpr auto Extract(Number...) { return Sequence{})...>{}; } template __host__ __device__ static constexpr auto Extract(Sequence) { return Sequence{})...>{}; } template __host__ __device__ static constexpr auto Modify(Number, Number); }; template struct sequence_merge; template struct sequence_merge, Sequence> { using SeqType = Sequence; }; template struct arithmetic_sequence_gen_impl { static constexpr index_t NSizeLeft = NSize / 2; using SeqType = typename sequence_merge< typename arithmetic_sequence_gen_impl::SeqType, typename arithmetic_sequence_gen_impl::SeqType>::SeqType; }; template struct arithmetic_sequence_gen_impl { using SeqType = Sequence; }; template struct arithmetic_sequence_gen_impl { using SeqType = Sequence<>; }; template struct arithmetic_sequence_gen { using SeqType = typename arithmetic_sequence_gen_impl::SeqType; }; template struct sequence_reverse_inclusive_scan; template struct sequence_reverse_inclusive_scan, Reduce> { using old_scan = typename sequence_reverse_inclusive_scan, Reduce>::SeqType; static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front()); using SeqType = typename sequence_merge, old_scan>::SeqType; }; template struct sequence_reverse_inclusive_scan, Reduce> { using SeqType = Sequence; }; template struct sequence_reverse_inclusive_scan, Reduce> { using SeqType = Sequence<>; }; template struct sequence_extract; template struct sequence_extract> { using SeqType = Sequence{})...>; }; template struct sequence_split { static constexpr index_t NSize = Seq{}.GetSize(); using range0 = typename arithmetic_sequence_gen<0, I, 1>::SeqType; using range1 = typename arithmetic_sequence_gen::SeqType; using SeqType0 = typename sequence_extract::SeqType; using SeqType1 = typename sequence_extract::SeqType; }; template struct sequence_reverse { static constexpr index_t NSize = Seq{}.GetSize(); using seq_split = sequence_split; using SeqType = typename sequence_merge< typename sequence_reverse::SeqType, typename sequence_reverse::SeqType>::SeqType; }; template struct sequence_reverse> { using SeqType = Sequence; }; template struct sequence_reverse> { using SeqType = Sequence; }; #if 0 // not fully implemented template struct sequence_sort_merge_impl; template struct sequence_sort_merge_impl, Sequence, Sequence, Sequence> { }; template struct sequence_sort; template struct sequence_sort> { using OriginalSeqType = Sequence; using SortedSeqType = xxxxx; using MapSorted2OriginalType = xxx; }; template struct sequence_map_inverse_impl; // impl for valid map, no impl for invalid map template struct sequence_map_inverse_impl, true> { using SeqMapType = sequence_sort>::MapSorted2OriginalType; }; template struct sequence_map_inverse; template struct sequence_map_inverse> { // TODO: make sure the map to be inversed is valid: [0, sizeof...(Is)) static constexpr bool is_valid_sequence_map = is_same>::SortedSeqType, typename arithmetic_sequence_gen<0, sizeof...(Is), 1>::SeqType>::value; // make compiler fails, if is_valid_map != true using SeqMapType = typename sequence_map_inverse_impl, is_valid_map>::SeqMapType; }; #endif template __host__ __device__ constexpr auto operator+(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs + Ys)...>{}; } template __host__ __device__ constexpr auto operator-(Sequence seq_x, Sequence seq_y) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); static_for<0, seq_x.GetSize(), 1>{}( [&](auto I) { static_assert(seq_x.Get(I) >= seq_y.Get(I), "wrong! going to undeflow"); }); return Sequence<(Xs - Ys)...>{}; } template __host__ __device__ constexpr auto operator*(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs * Ys)...>{}; } template __host__ __device__ constexpr auto operator/(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs / Ys)...>{}; } template __host__ __device__ constexpr auto operator%(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs % Ys)...>{}; } template __host__ __device__ constexpr auto operator+(Sequence, Number) { return Sequence<(Xs + Y)...>{}; } template __host__ __device__ constexpr auto operator-(Sequence, Number) { #if 0 // TODO: turn it on. Doesn't compile constexpr auto seq_x = Sequence{}; static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) { constexpr auto I = decltype(Iter){}; static_assert(seq_x.Get(I) >= Y, "wrong! going to underflow"); }); #endif return Sequence<(Xs - Y)...>{}; } template __host__ __device__ constexpr auto operator*(Sequence, Number) { return Sequence<(Xs * Y)...>{}; } template __host__ __device__ constexpr auto operator/(Sequence, Number) { return Sequence<(Xs / Y)...>{}; } template __host__ __device__ constexpr auto operator%(Sequence, Number) { return Sequence<(Xs % Y)...>{}; } template __host__ __device__ constexpr auto operator+(Number, Sequence) { return Sequence<(Y + Xs)...>{}; } template __host__ __device__ constexpr auto operator-(Number, Sequence) { constexpr auto seq_x = Sequence{}; static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) { constexpr auto I = decltype(Iter){}; static_assert(seq_x.Get(I) <= Y, "wrong! going to underflow"); }); return Sequence<(Y - Xs)...>{}; } template __host__ __device__ constexpr auto operator*(Number, Sequence) { return Sequence<(Y * Xs)...>{}; } template __host__ __device__ constexpr auto operator/(Number, Sequence) { return Sequence<(Y / Xs)...>{}; } template __host__ __device__ constexpr auto operator%(Number, Sequence) { return Sequence<(Y % Xs)...>{}; } template __host__ __device__ constexpr auto sequence_pop_front(Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto sequence_pop_back(Seq) { static_assert(Seq{}.GetSize() > 0, "wrong! cannot pop an empty Sequence!"); return sequence_pop_front(Seq{}.Reverse()).Reverse(); } template __host__ __device__ constexpr auto transform_sequences(F f, Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence) { static_assert(Sequence::mSize == Sequence::mSize, "Dim not the same"); return Sequence{}; } template __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence, Sequence) { static_assert(Sequence::mSize == Sequence::mSize && Sequence::mSize == Sequence::mSize, "Dim not the same"); return Sequence{}; } template __host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce) { return typename sequence_reverse_inclusive_scan::SeqType{}; } template __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce) { return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}).Reverse(); } template struct accumulate_on_sequence_impl { template __host__ __device__ constexpr index_t operator()(IDim) const { return Seq{}.Get(IDim{}); } }; template __host__ __device__ constexpr index_t accumulate_on_sequence(Seq, Reduce, Number /*initial_value*/) { constexpr index_t a = static_const_reduce_n{}(accumulate_on_sequence_impl{}, Reduce{}); return Reduce{}(a, I); } template __host__ __device__ constexpr auto Sequence::PopFront() { return sequence_pop_front(Type{}); } template __host__ __device__ constexpr auto Sequence::PopBack() { return sequence_pop_back(Type{}); } template __host__ __device__ constexpr auto Sequence::Reverse() { return typename sequence_reverse>::SeqType{}; } template template __host__ __device__ constexpr auto Sequence::Modify(Number, Number) { static_assert(I < GetSize(), "wrong!"); using seq_split = sequence_split; constexpr auto seq_left = typename seq_split::SeqType0{}; constexpr auto seq_right = typename seq_split::SeqType1{}.PopFront(); return seq_left.PushBack(Number{}).Append(seq_right); }