#pragma once #include "Sequence.hip.hpp" #include "functional2.hip.hpp" template struct Array { using Type = Array; static constexpr index_t nSize = NSize; index_t mData[nSize]; template __host__ __device__ constexpr Array(Xs... xs) : mData{static_cast(xs)...} { } __host__ __device__ constexpr index_t GetSize() const { return NSize; } template __host__ __device__ constexpr TData operator[](Number) const { return mData[I]; } __host__ __device__ constexpr TData operator[](index_t i) const { return mData[i]; } template __host__ __device__ TData& operator()(Number) { return mData[I]; } __host__ __device__ TData& operator()(index_t i) { return mData[i]; } template __host__ __device__ constexpr TData Get(Number) const { static_assert(I < NSize, "wrong!"); return mData[I]; } template __host__ __device__ constexpr void Set(Number, TData x) { static_assert(I < NSize, "wrong!"); mData[I] = x; } __host__ __device__ constexpr auto PushBack(TData x) const { Array new_array; static_for<0, NSize, 1>{}([&](auto I) { constexpr index_t i = I.Get(); new_array(i) = mData[i]; }); new_array(NSize) = x; return new_array; } }; template __host__ __device__ constexpr auto sequence2array(Sequence) { return Array{Is...}; } template __host__ __device__ constexpr auto make_zero_array() { constexpr auto zero_sequence = typename uniform_sequence_gen::SeqType{}; constexpr auto zero_array = sequence2array(zero_sequence); return zero_array; } template __host__ __device__ constexpr auto reorder_array_given_new2old(const Array& old_array, Sequence new2old) { Array new_array; static_assert(NSize == sizeof...(IRs), "NSize not consistent"); static_for<0, NSize, 1>{}([&](auto IDim) { constexpr index_t idim = IDim.Get(); new_array[idim] = old_array[new2old.Get(IDim)]; }); return new_array; } template struct lambda_reorder_array_given_old2new { const Array& old_array; Array& new_array; __host__ __device__ constexpr lambda_reorder_array_given_old2new( const Array& old_array_, Array& new_array_) : old_array(old_array_), new_array(new_array_) { } template __host__ __device__ constexpr void operator()(Number) const { TData old_data = old_array[IOldDim]; constexpr index_t INewDim = MapOld2New::Get(Number{}); new_array.Set(Number{}, old_data); } }; template __host__ __device__ constexpr auto reorder_array_given_old2new(const Array& old_array, Sequence old2new) { Array new_array; static_assert(NSize == sizeof...(IRs), "NSize not consistent"); static_for<0, NSize, 1>{}( lambda_reorder_array_given_old2new>(old_array, new_array)); return new_array; } template __host__ __device__ constexpr auto extract_array(const Array& old_array, ExtractSeq) { Array new_array; constexpr index_t new_size = ExtractSeq::GetSize(); static_assert(new_size <= NSize, "wrong! too many extract"); static_for<0, new_size, 1>{}([&](auto I) { constexpr index_t i = I.Get(); new_array(i) = old_array[ExtractSeq::Get(I)]; }); return new_array; } // Array = Array + Array template __host__ __device__ constexpr auto operator+(Array a, Array b) { Array result; static_for<0, NSize, 1>{}([&](auto I) { constexpr index_t i = I.Get(); result(i) = a[i] + b[i]; }); return result; } // Array = Array - Array template __host__ __device__ constexpr auto operator-(Array a, Array b) { Array result; static_for<0, NSize, 1>{}([&](auto I) { constexpr index_t i = I.Get(); result(i) = a[i] - b[i]; }); return result; } // Array = Array + Sequence template __host__ __device__ constexpr auto operator+(Array a, Sequence b) { static_assert(sizeof...(Is) == NSize, "wrong! size not the same"); Array result; static_for<0, NSize, 1>{}([&](auto I) { constexpr index_t i = I.Get(); result(i) = a[i] + b.Get(I); }); return result; } // Array = Array - Sequence template __host__ __device__ constexpr auto operator-(Array a, Sequence b) { static_assert(sizeof...(Is) == NSize, "wrong! size not the same"); Array result; static_for<0, NSize, 1>{}([&](auto I) { constexpr index_t i = I.Get(); result(i) = a[i] - b.Get(I); }); return result; } // Array = Array * Sequence template __host__ __device__ constexpr auto operator*(Array a, Sequence b) { static_assert(sizeof...(Is) == NSize, "wrong! size not the same"); Array result; static_for<0, NSize, 1>{}([&](auto I) { constexpr index_t i = I.Get(); result(i) = a[i] * b.Get(I); }); return result; } // Array = Sequence - Array template __host__ __device__ constexpr auto operator-(Sequence a, Array b) { static_assert(sizeof...(Is) == NSize, "wrong! size not the same"); Array result; static_for<0, NSize, 1>{}([&](auto I) { constexpr index_t i = I.Get(); result(i) = a.Get(I) - b[i]; }); return result; } template __host__ __device__ constexpr TData accumulate_on_array(const Array& a, Reduce f, TData init) { TData result = init; static_assert(NSize > 0, "wrong"); static_for<0, NSize, 1>{}([&](auto I) { constexpr index_t i = I.Get(); result = f(result, a[i]); }); return result; } template __host__ __device__ void print_Array(const char* s, Array a) { constexpr index_t nsize = a.GetSize(); static_assert(nsize > 0 && nsize <= 10, "wrong!"); static_if{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, a[0]); }); static_if{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, a[0], a[1]); }); static_if{}( [&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, a[0], a[1], a[2]); }); static_if{}( [&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3]); }); static_if{}([&](auto) { printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4]); }); static_if{}([&](auto) { printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5]); }); static_if{}([&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5], a[6]); }); static_if{}([&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7]); }); static_if{}([&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8]); }); static_if{}([&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9]); }); }