#pragma once #include "integral_constant.hip.hpp" #include "functional.hip.hpp" template struct Sequence { using Type = Sequence; static constexpr index_t mSize = sizeof...(Is); __host__ __device__ static constexpr index_t GetSize() { return mSize; } template __host__ __device__ static constexpr index_t Get(Number) { static_assert(I < mSize, "wrong! I too large"); // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 const index_t mData[mSize + 1] = {Is..., 0}; return mData[I]; } template __host__ __device__ constexpr index_t operator[](Number) const { static_assert(I < mSize, "wrong! I too large"); const index_t mData[mSize + 1] = {Is..., 0}; return mData[I]; } // make sure I is constepxr __host__ __device__ constexpr index_t operator[](index_t I) const { const index_t mData[mSize + 1] = {Is..., 0}; return mData[I]; } template __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence /*new2old*/) { #if 0 // require sequence_sort, which is not implemented yet static_assert(is_same>::SortedSeqType, arithmetic_sequence_gen<0, mSize, 1>::SeqType>::value, "wrong! invalid new2old map"); #endif static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! new2old map should have the same size as Sequence to be rerodered"); return Sequence{})...>{}; } #if 0 // require sequence_sort, which is not implemented yet template __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New /*old2new*/) { #if 0 static_assert(is_same::SortedSeqType, arithmetic_sequence_gen<0, mSize, 1>::SeqType>::value, "wrong! invalid old2new map"); #endif constexpr auto map_new2old = typename sequence_map_inverse::SeqMapType{}; return ReorderGivenNew2Old(map_new2old); } #endif __host__ __device__ static constexpr auto Reverse(); __host__ __device__ static constexpr index_t Front() { const index_t mData[mSize + 1] = {Is..., 0}; return mData[0]; } __host__ __device__ static constexpr index_t Back() { const index_t mData[mSize + 1] = {Is..., 0}; return mData[mSize - 1]; } template __host__ __device__ static constexpr auto PushFront(Number) { return Sequence{}; } template __host__ __device__ static constexpr auto PushBack(Number) { return Sequence{}; } __host__ __device__ static constexpr auto PopFront(); __host__ __device__ static constexpr auto PopBack(); template __host__ __device__ static constexpr auto Append(Sequence) { return Sequence{}; } template __host__ __device__ static constexpr auto Extract(Number...) { return Sequence{})...>{}; } template __host__ __device__ static constexpr auto Extract(Sequence) { return Sequence{})...>{}; } template __host__ __device__ static constexpr auto Modify(Number, Number); }; // merge sequence template struct sequence_merge; template struct sequence_merge, Sequence> { using SeqType = Sequence; }; // arithmetic sqeuence template struct arithmetic_sequence_gen_impl { static constexpr index_t NSizeLeft = NSize / 2; using SeqType = typename sequence_merge< typename arithmetic_sequence_gen_impl::SeqType, typename arithmetic_sequence_gen_impl::SeqType>::SeqType; }; template struct arithmetic_sequence_gen_impl { using SeqType = Sequence; }; template struct arithmetic_sequence_gen_impl { using SeqType = Sequence<>; }; template struct arithmetic_sequence_gen { using SeqType = typename arithmetic_sequence_gen_impl::SeqType; }; // transform sequence template struct sequence_transform; template struct sequence_transform> { using SeqType = Sequence; }; // uniform sequence template struct uniform_sequence_gen { struct return_constant { __host__ __device__ constexpr index_t operator()(index_t) const { return I; } }; using SeqType = typename sequence_transform< return_constant, typename arithmetic_sequence_gen<0, NSize, 1>::SeqType>::SeqType; }; // reverse inclusive scan (with init) sequence template struct sequence_reverse_inclusive_scan; template struct sequence_reverse_inclusive_scan, Reduce, Init> { using old_scan = typename sequence_reverse_inclusive_scan, Reduce, Init>::SeqType; static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front()); using SeqType = typename sequence_merge, old_scan>::SeqType; }; template struct sequence_reverse_inclusive_scan, Reduce, Init> { using SeqType = Sequence; }; template struct sequence_reverse_inclusive_scan, Reduce, Init> { using SeqType = Sequence<>; }; // extract sequence template struct sequence_extract; template struct sequence_extract> { using SeqType = Sequence{})...>; }; // split sequence template struct sequence_split { static constexpr index_t NSize = Seq{}.GetSize(); using range0 = typename arithmetic_sequence_gen<0, I, 1>::SeqType; using range1 = typename arithmetic_sequence_gen::SeqType; using SeqType0 = typename sequence_extract::SeqType; using SeqType1 = typename sequence_extract::SeqType; }; // reverse sequence template struct sequence_reverse { static constexpr index_t NSize = Seq{}.GetSize(); using seq_split = sequence_split; using SeqType = typename sequence_merge< typename sequence_reverse::SeqType, typename sequence_reverse::SeqType>::SeqType; }; template struct sequence_reverse> { using SeqType = Sequence; }; template struct sequence_reverse> { using SeqType = Sequence; }; #if 0 // not fully implemented template struct sequence_sort_merge_impl; template struct sequence_sort_merge_impl, Sequence, Sequence, Sequence> { }; template struct sequence_sort; template struct sequence_sort> { using OriginalSeqType = Sequence; using SortedSeqType = xxxxx; using MapSorted2OriginalType = xxx; }; template struct sequence_map_inverse_impl; // impl for valid map, no impl for invalid map template struct sequence_map_inverse_impl, true> { using SeqMapType = sequence_sort>::MapSorted2OriginalType; }; template struct sequence_map_inverse; template struct sequence_map_inverse> { // TODO: make sure the map to be inversed is valid: [0, sizeof...(Is)) static constexpr bool is_valid_sequence_map = is_same>::SortedSeqType, typename arithmetic_sequence_gen<0, sizeof...(Is), 1>::SeqType>::value; // make compiler fails, if is_valid_map != true using SeqMapType = typename sequence_map_inverse_impl, is_valid_map>::SeqMapType; }; #endif template struct is_valid_sequence_map { static constexpr bool value = #if 0 // sequence_sort is not implemented yet is_same::SeqType, typename sequence_sort::SortedSeqType>::value; #else true; #endif }; template __host__ __device__ constexpr auto operator+(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs + Ys)...>{}; } template __host__ __device__ constexpr auto operator-(Sequence seq_x, Sequence seq_y) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs - Ys)...>{}; } template __host__ __device__ constexpr auto operator*(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs * Ys)...>{}; } template __host__ __device__ constexpr auto operator/(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs / Ys)...>{}; } template __host__ __device__ constexpr auto operator%(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs % Ys)...>{}; } template __host__ __device__ constexpr auto operator+(Sequence, Number) { return Sequence<(Xs + Y)...>{}; } template __host__ __device__ constexpr auto operator-(Sequence, Number) { return Sequence<(Xs - Y)...>{}; } template __host__ __device__ constexpr auto operator*(Sequence, Number) { return Sequence<(Xs * Y)...>{}; } template __host__ __device__ constexpr auto operator/(Sequence, Number) { return Sequence<(Xs / Y)...>{}; } template __host__ __device__ constexpr auto operator%(Sequence, Number) { return Sequence<(Xs % Y)...>{}; } template __host__ __device__ constexpr auto operator+(Number, Sequence) { return Sequence<(Y + Xs)...>{}; } template __host__ __device__ constexpr auto operator-(Number, Sequence) { constexpr auto seq_x = Sequence{}; return Sequence<(Y - Xs)...>{}; } template __host__ __device__ constexpr auto operator*(Number, Sequence) { return Sequence<(Y * Xs)...>{}; } template __host__ __device__ constexpr auto operator/(Number, Sequence) { return Sequence<(Y / Xs)...>{}; } template __host__ __device__ constexpr auto operator%(Number, Sequence) { return Sequence<(Y % Xs)...>{}; } template __host__ __device__ constexpr auto sequence_pop_front(Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto sequence_pop_back(Seq) { static_assert(Seq{}.GetSize() > 0, "wrong! cannot pop an empty Sequence!"); return sequence_pop_front(Seq{}.Reverse()).Reverse(); } template __host__ __device__ constexpr auto transform_sequences(F f, Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence) { static_assert(Sequence::mSize == Sequence::mSize, "Dim not the same"); return Sequence{}; } template __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence, Sequence) { static_assert(Sequence::mSize == Sequence::mSize && Sequence::mSize == Sequence::mSize, "Dim not the same"); return Sequence{}; } template __host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number) { return typename sequence_reverse_inclusive_scan::SeqType{}; } template __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number) { return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number{}).Reverse(); } template __host__ __device__ constexpr auto Sequence::PopFront() { return sequence_pop_front(Type{}); } template __host__ __device__ constexpr auto Sequence::PopBack() { return sequence_pop_back(Type{}); } template __host__ __device__ constexpr auto Sequence::Reverse() { return typename sequence_reverse>::SeqType{}; } template template __host__ __device__ constexpr auto Sequence::Modify(Number, Number) { static_assert(I < GetSize(), "wrong!"); using seq_split = sequence_split; constexpr auto seq_left = typename seq_split::SeqType0{}; constexpr auto seq_right = typename seq_split::SeqType1{}.PopFront(); return seq_left.PushBack(Number{}).Append(seq_right); } template __host__ __device__ void print_Sequence(const char* s, Sequence) { constexpr index_t nsize = Sequence::GetSize(); static_assert(nsize <= 10, "wrong!"); static_if{}([&](auto) { printf("%s size %u, {}\n", s, nsize, Xs...); }); static_if{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, Xs...); }); static_if{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, Xs...); }); static_if{}([&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, Xs...); }); static_if{}([&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, Xs...); }); static_if{}( [&](auto) { printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, Xs...); }); static_if{}( [&](auto) { printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, Xs...); }); static_if{}( [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u}\n", s, nsize, Xs...); }); static_if{}( [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); }); static_if{}( [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); }); static_if{}( [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); }); }