#pragma once
#include "Sequence.hip.hpp"
#include "functional.hip.hpp"

// Fixed-size array of NSize elements of type TData, usable from host and device code.
template <class TData, index_t NSize>
struct Array
{
    using Type = Array<TData, NSize>;

    static constexpr index_t nSize = NSize;

    // Element storage. Declared as TData (not index_t) because operator[] hands out
    // TData references into this buffer; a mismatched element type would make the
    // non-const operator[] fail to compile and the const one return a dangling reference.
    TData mData[nSize];

    // Builds the array from up to NSize values, each converted to TData.
    // Elements without a matching argument are value-initialized (array aggregate-init
    // rules), so a default-constructed Array is zero-filled.
    template <class... Xs>
    __host__ __device__ constexpr Array(Xs... xs) : mData{static_cast<TData>(xs)...}
    {
    }

    __host__ __device__ constexpr index_t GetSize() const { return NSize; }

    // Element access; i is not bounds-checked.
    __host__ __device__ constexpr const TData& operator[](index_t i) const { return mData[i]; }

    __host__ __device__ constexpr TData& operator[](index_t i) { return mData[i]; }

    // Returns a new Array<TData, NSize + 1>: a copy of this array with x appended.
    __host__ __device__ auto PushBack(TData x) const
    {
        Array<TData, NSize + 1> new_array;

        static_for<0, NSize, 1>{}([&](auto I) {
            constexpr index_t i = I.Get();
            new_array[i]        = mData[i];
        });

        new_array[NSize] = x;

        return new_array;
    }
};

// Converts the compile-time Sequence<Is...> into Array<index_t, sizeof...(Is)>{Is...}.
template <index_t... Is>
__host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
{
    return Array<index_t, sizeof...(Is)>{Is...};
}

// Returns an Array<TData, NSize> whose elements are all static_cast<TData>(0).
template <class TData, index_t NSize>
__host__ __device__ constexpr auto make_zero_array()
{
    Array<TData, NSize> a;

    static_for<0, NSize, 1>{}([&](auto I) {
        constexpr index_t i = I.Get();
        a[i]                = static_cast<TData>(0);
    });

    return a;
}

// Gather: new_array[i] = old_array[new2old[i]] for every i in [0, NSize).
template <class TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData, NSize>& old_array,
                                                               Sequence<IRs...> new2old)
{
    static_assert(NSize == sizeof...(IRs), "NSize not consistent");

    Array<TData, NSize> new_array;

    static_for<0, NSize, 1>{}([&](auto IDim) {
        constexpr index_t idim = IDim.Get();
        new_array[idim]        = old_array[new2old.Get(IDim)];
    });

    return new_array;
}

// Scatter: new_array[old2new[i]] = old_array[i] for every i in [0, NSize).
// NOTE(review): presumably old2new is a permutation of [0, NSize); that is not checked
// here, and any slot it never names keeps its default-constructed value.
template <class TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
                                                               Sequence<IRs...> old2new)
{
    static_assert(NSize == sizeof...(IRs), "NSize not consistent");

    Array<TData, NSize> new_array;

    static_for<0, NSize, 1>{}([&](auto IDim) {
        constexpr index_t idim         = IDim.Get();
        new_array[old2new.Get(IDim)] = old_array[idim];
    });

    return new_array;
}

// Selects elements by position: new_array[i] = old_array[ExtractSeq::Get(i)] for
// i in [0, ExtractSeq::GetSize()). The result has ExtractSeq::GetSize() elements.
template <class TData, index_t NSize, class ExtractSeq>
__host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_array, ExtractSeq)
{
    constexpr index_t new_size = ExtractSeq::GetSize();

    static_assert(new_size <= NSize, "wrong! too many extract");

    Array<TData, new_size> new_array;

    static_for<0, new_size, 1>{}([&](auto I) {
        constexpr index_t i = I.Get();
        new_array[i]        = old_array[ExtractSeq::Get(I)];
    });

    return new_array;
}

// Array = Array + Array (element-wise)
template <class TData, index_t NSize>
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData, NSize> b)
{
    Array<TData, NSize> result;

    static_for<0, NSize, 1>{}([&](auto I) {
        constexpr index_t i = I.Get();
        result[i]           = a[i] + b[i];
    });

    return result;
}

// Array = Array - Array (element-wise)
template <class TData, index_t NSize>
__host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData, NSize> b)
{
    Array<TData, NSize> result;

    static_for<0, NSize, 1>{}([&](auto I) {
        constexpr index_t i = I.Get();
        result[i]           = a[i] - b[i];
    });

    return result;
}

// Array = Array + Sequence (element-wise; sizes must match)
template <class TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is...> b)
{
    static_assert(sizeof...(Is) == NSize, "wrong! size not the same");

    Array<TData, NSize> result;

    static_for<0, NSize, 1>{}([&](auto I) {
        constexpr index_t i = I.Get();
        result[i]           = a[i] + b.Get(I);
    });

    return result;
}

// Array = Array - Sequence (element-wise; sizes must match)
template <class TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is...> b)
{
    static_assert(sizeof...(Is) == NSize, "wrong! size not the same");

    Array<TData, NSize> result;

    static_for<0, NSize, 1>{}([&](auto I) {
        constexpr index_t i = I.Get();
        result[i]           = a[i] - b.Get(I);
    });

    return result;
}

// Array = Array * Sequence (element-wise; sizes must match)
template <class TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is...> b)
{
    static_assert(sizeof...(Is) == NSize, "wrong! size not the same");

    Array<TData, NSize> result;

    static_for<0, NSize, 1>{}([&](auto I) {
        constexpr index_t i = I.Get();
        result[i]           = a[i] * b.Get(I);
    });

    return result;
}

// Array = Sequence - Array (element-wise; sizes must match)
template <class TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSize> b)
{
    static_assert(sizeof...(Is) == NSize, "wrong! size not the same");

    Array<TData, NSize> result;

    static_for<0, NSize, 1>{}([&](auto I) {
        constexpr index_t i = I.Get();
        result[i]           = a.Get(I) - b[i];
    });

    return result;
}

// Left fold over the array: returns f(...f(f(init, a[0]), a[1])..., a[NSize - 1]).
// Requires a non-empty array.
template <class TData, index_t NSize, class Reduce>
__host__ __device__ constexpr TData
accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init)
{
    TData result = init;

    static_assert(NSize > 0, "wrong");

    static_for<0, NSize, 1>{}([&](auto I) {
        constexpr index_t i = I.Get();
        result              = f(result, a[i]);
    });

    return result;
}

// Debug helper: prints "<s> size <n>, {a[0] a[1] ...}" with a single printf call, for
// arrays of 1 to 10 elements (larger sizes are rejected at compile time).
// NOTE(review): every value, including the size, is formatted with %u, so this
// presumably expects T and index_t to be unsigned-int compatible — confirm before
// printing arrays of other element types.
template <class T, index_t NSize>
__host__ __device__ void print_Array(const char* s, Array<T, NSize> a)
{
    constexpr index_t nsize = NSize;

    static_assert(nsize > 0 && nsize <= 10, "wrong!");

    static_if<nsize == 1>{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, a[0]); });

    static_if<nsize == 2>{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, a[0], a[1]); });

    static_if<nsize == 3>{}(
        [&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, a[0], a[1], a[2]); });

    static_if<nsize == 4>{}(
        [&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3]); });

    static_if<nsize == 5>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4]);
    });

    static_if<nsize == 6>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5]);
    });

    static_if<nsize == 7>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u}\n",
               s,
               nsize,
               a[0],
               a[1],
               a[2],
               a[3],
               a[4],
               a[5],
               a[6]);
    });

    static_if<nsize == 8>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u %u}\n",
               s,
               nsize,
               a[0],
               a[1],
               a[2],
               a[3],
               a[4],
               a[5],
               a[6],
               a[7]);
    });

    static_if<nsize == 9>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n",
               s,
               nsize,
               a[0],
               a[1],
               a[2],
               a[3],
               a[4],
               a[5],
               a[6],
               a[7],
               a[8]);
    });

    static_if<nsize == 10>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n",
               s,
               nsize,
               a[0],
               a[1],
               a[2],
               a[3],
               a[4],
               a[5],
               a[6],
               a[7],
               a[8],
               a[9]);
    });
}