Array.hip.hpp 3.55 KB
Newer Older
1
#pragma once
2
3
#include "Sequence.hip.hpp"
#include "functional.hip.hpp"
4

Chao Liu's avatar
Chao Liu committed
5
template <class TData, index_t NSize>
// Fixed-size array of NSize elements of type TData, callable from host and
// device code.
//
// Elements live inline in `mData`. The variadic constructor converts every
// argument with static_cast<TData>; with fewer than NSize arguments the
// remaining elements are value-initialized (braced array initialization),
// and with no arguments the whole array is value-initialized.
struct Array
{
    using Type = Array<TData, NSize>;

    static constexpr index_t nSize = NSize;

    // Element storage. This must be TData (it was index_t): operator[] hands
    // out TData references into this array, which only binds correctly when
    // the element type matches, and the constructor's static_cast<TData>
    // values would otherwise be narrowed back to index_t.
    TData mData[nSize];

    template <class... Xs>
    __host__ __device__ constexpr Array(Xs... xs) : mData{static_cast<TData>(xs)...}
    {
    }

    // Number of elements; always the template parameter NSize.
    __host__ __device__ constexpr index_t GetSize() const { return NSize; }

    // Unchecked element access: i must be in [0, NSize).
    __host__ __device__ const TData& operator[](index_t i) const { return mData[i]; }

    __host__ __device__ TData& operator[](index_t i) { return mData[i]; }

    // Returns a new Array<TData, NSize + 1> containing a copy of this array's
    // elements followed by x. This array is left unchanged.
    __host__ __device__ auto PushBack(TData x) const
    {
        Array<TData, NSize + 1> new_array;

        static_for<0, NSize, 1>{}([&](auto I) {
            constexpr index_t i = I.Get();
            new_array[i]        = mData[i];
        });

        new_array[NSize] = x;

        return new_array;
    }
};
39

Chao Liu's avatar
Chao Liu committed
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
template <index_t... Is>
// Turns the compile-time values of Sequence<Is...> into an Array<index_t, N>
// (N = number of values), keeping their order.
__host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
{
    using ResultArray = Array<index_t, sizeof...(Is)>;

    return ResultArray{Is...};
}

template <class TData, index_t NSize>
// Builds an Array<TData, NSize> in which every element is explicitly set to
// static_cast<TData>(0).
__host__ __device__ constexpr auto make_zero_array()
{
    const TData zero = static_cast<TData>(0);

    Array<TData, NSize> zeros;

    static_for<0, NSize, 1>{}([&](auto IDim) {
        const index_t d = IDim.Get();
        zeros[d]        = zero;
    });

    return zeros;
}

59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
template <class TData, index_t NSize, index_t... IRs>
// Gathers old_array into a new order: entry d of the result is
// old_array[new2old.Get(d)]. new2old must list exactly NSize indices.
__host__ __device__ auto reorder_array_given_new2old(const Array<TData, NSize>& old_array,
                                                     Sequence<IRs...> new2old)
{
    static_assert(NSize == sizeof...(IRs), "NSize not consistent");

    Array<TData, NSize> reordered;

    static_for<0, NSize, 1>{}([&](auto INew) {
        const index_t dst = INew.Get();
        reordered[dst]    = old_array[new2old.Get(INew)];
    });

    return reordered;
}

template <class TData, index_t NSize, index_t... IRs>
// Scatters old_array into a new order: old_array[d] is written to position
// old2new.Get(d) of the result. old2new must list exactly NSize indices.
__host__ __device__ auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
                                                     Sequence<IRs...> old2new)
{
    static_assert(NSize == sizeof...(IRs), "NSize not consistent");

    Array<TData, NSize> reordered;

    static_for<0, NSize, 1>{}([&](auto IOld) {
        const index_t src = IOld.Get();
        const index_t dst = old2new.Get(IOld);
        reordered[dst]    = old_array[src];
    });

    return reordered;
}
Chao Liu's avatar
Chao Liu committed
90

91
92
93
94
95
96
97
98
99
100
101
template <class TData, index_t NSize, class ExtractSeq>
// Selects elements of old_array by the compile-time index list ExtractSeq:
// entry k of the result is old_array[ExtractSeq::Get(k)]. The result has
// ExtractSeq::GetSize() entries, which is required to be at most NSize.
__host__ __device__ auto extract_array(const Array<TData, NSize>& old_array, ExtractSeq)
{
    constexpr index_t new_size = ExtractSeq::GetSize();

    static_assert(new_size <= NSize, "wrong! too many extract");

    Array<TData, new_size> extracted;

    static_for<0, new_size, 1>{}([&](auto IPick) {
        const index_t k = IPick.Get();
        extracted[k]    = old_array[ExtractSeq::Get(IPick)];
    });

    return extracted;
}

Chao Liu's avatar
Chao Liu committed
108
template <class TData, index_t NSize>
// Element-wise sum of two equally sized arrays: sum[d] = a[d] + b[d].
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData, NSize> b)
{
    Array<TData, NSize> sum;

    static_for<0, NSize, 1>{}([&](auto IDim) {
        const index_t d = IDim.Get();
        sum[d]          = a[d] + b[d];
    });

    return sum;
}
Chao Liu's avatar
Chao Liu committed
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136

// Array = Array * Sequence
//
// Element-wise product of a runtime Array and a compile-time Sequence of the
// same length: result[i] = a[i] * b.Get(i).
template <class TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is...> b)
{
    static_assert(sizeof...(Is) == NSize, "wrong! size not the same");

    Array<TData, NSize> result;

    static_for<0, NSize, 1>{}([&](auto I) {
        constexpr index_t i = I.Get();

        // Multiply, not add: this previously computed a[i] + b.Get(I), which
        // made operator* silently behave like an element-wise sum.
        result[i] = a[i] * b.Get(I);
    });

    return result;
}