#pragma once

#include "Sequence.hip.hpp"
#include "functional.hip.hpp"

template <class TData, index_t NSize>
6
7
8
9
struct Array
{
    using Type = Array<TData, NSize>;

Chao Liu's avatar
Chao Liu committed
10
    static constexpr index_t nSize = NSize;
11

Chao Liu's avatar
Chao Liu committed
12
    index_t mData[nSize];
13
14

    template <class... Xs>
Chao Liu's avatar
Chao Liu committed
15
    __host__ __device__ Array(Xs... xs) : mData{static_cast<TData>(xs)...}
16
17
18
    {
    }

19
20
21
    __host__ __device__ const TData& operator[](index_t i) const { return mData[i]; }

    __host__ __device__ TData& operator[](index_t i) { return mData[i]; }
22
};
// Gather-style reorder: element i of the result is old_array[new2old[i]].
// new2old must have exactly NSize entries (checked at compile time); it is
// presumably a permutation of [0, NSize) — not verified here.
template <class TData, index_t NSize, index_t... IRs>
__host__ __device__ auto reorder_array_given_new2old(const Array<TData, NSize>& old_array,
                                                     Sequence<IRs...> new2old)
{
    static_assert(NSize == sizeof...(IRs), "NSize not consistent");

    Array<TData, NSize> gathered;

    // Compile-time unrolled loop over destination positions.
    static_for<0, NSize, 1>{}([&](auto Dst) {
        constexpr index_t dst = Dst.Get();
        const auto src        = new2old.Get(Dst);
        gathered[dst]         = old_array[src];
    });

    return gathered;
}

// Scatter-style reorder: old_array[i] is written to position old2new[i] of the
// result. old2new must have exactly NSize entries (checked at compile time).
// NOTE(review): if old2new is not a permutation of [0, NSize), some result
// slots are never written and keep their default-constructed value.
template <class TData, index_t NSize, index_t... IRs>
__host__ __device__ auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
                                                     Sequence<IRs...> old2new)
{
    static_assert(NSize == sizeof...(IRs), "NSize not consistent");

    Array<TData, NSize> scattered;

    // Compile-time unrolled loop over source positions.
    static_for<0, NSize, 1>{}([&](auto Src) {
        constexpr index_t src = Src.Get();
        const auto dst        = old2new.Get(Src);
        scattered[dst]        = old_array[src];
    });

    return scattered;
}