Array.hip.hpp 8.98 KB
Newer Older
1
#pragma once
2
#include "Sequence.hip.hpp"
Chao Liu's avatar
Chao Liu committed
3
#include "functional2.hip.hpp"
4

Chao Liu's avatar
Chao Liu committed
5
template <class TData, index_t NSize>
6
7
8
9
struct Array
{
    using Type = Array<TData, NSize>;

Chao Liu's avatar
Chao Liu committed
10
    static constexpr index_t nSize = NSize;
11

Chao Liu's avatar
Chao Liu committed
12
    index_t mData[nSize];
13
14

    template <class... Xs>
Chao Liu's avatar
Chao Liu committed
15
    __host__ __device__ constexpr Array(Xs... xs) : mData{static_cast<TData>(xs)...}
16
17
18
    {
    }

19
20
    __host__ __device__ constexpr index_t GetSize() const { return NSize; }

Chao Liu's avatar
Chao Liu committed
21
22
23
24
25
26
    template <index_t I>
    __host__ __device__ constexpr TData operator[](Number<I>) const
    {
        return mData[I];
    }

Chao Liu's avatar
Chao Liu committed
27
    __host__ __device__ constexpr TData operator[](index_t i) const { return mData[i]; }
28

Chao Liu's avatar
Chao Liu committed
29
30
31
32
33
34
35
    template <index_t I>
    __host__ __device__ TData& operator()(Number<I>)
    {
        return mData[I];
    }

    __host__ __device__ TData& operator()(index_t i) { return mData[i]; }
36

Chao Liu's avatar
Chao Liu committed
37
38
39
    template <index_t I>
    __host__ __device__ constexpr TData Get(Number<I>) const
    {
Chao Liu's avatar
Chao Liu committed
40
41
        static_assert(I < NSize, "wrong!");

Chao Liu's avatar
Chao Liu committed
42
43
44
45
        return mData[I];
    }

    template <index_t I>
Chao Liu's avatar
Chao Liu committed
46
    __host__ __device__ constexpr void Set(Number<I>, TData x)
Chao Liu's avatar
Chao Liu committed
47
    {
Chao Liu's avatar
Chao Liu committed
48
49
        static_assert(I < NSize, "wrong!");

Chao Liu's avatar
Chao Liu committed
50
51
52
53
        mData[I] = x;
    }

    __host__ __device__ constexpr auto PushBack(TData x) const
54
55
56
    {
        Array<TData, NSize + 1> new_array;

Chao Liu's avatar
Chao Liu committed
57
        static_for<0, NSize, 1>{}([&](auto I) {
58
            constexpr index_t i = I.Get();
Chao Liu's avatar
Chao Liu committed
59
            new_array(i)        = mData[i];
60
61
        });

Chao Liu's avatar
Chao Liu committed
62
        new_array(NSize) = x;
63
64
65

        return new_array;
    }
66
};
67

Chao Liu's avatar
Chao Liu committed
68
69
70
71
72
73
74
75
76
// Convert a compile-time Sequence<Is...> into an Array<index_t, N> that holds the same
// values in the same order, where N is the number of values in the sequence.
template <index_t... Is>
__host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
{
    using ResultArray = Array<index_t, sizeof...(Is)>;

    return ResultArray{Is...};
}

// Return an array of NSize zeros, built by converting the sequence type
// uniform_sequence_gen<NSize, 0>::SeqType with sequence2array.
//
// NOTE(review): TData does not take part in the result. sequence2array always yields
// Array<index_t, ...>, so the returned element type is index_t regardless of TData.
// Confirm whether callers rely on that before changing it.
template <class TData, index_t NSize>
__host__ __device__ constexpr auto make_zero_array()
{
    using ZeroSeq = typename uniform_sequence_gen<NSize, 0>::SeqType;

    return sequence2array(ZeroSeq{});
}

82
template <class TData, index_t NSize, index_t... IRs>
Chao Liu's avatar
Chao Liu committed
83
84
__host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData, NSize>& old_array,
                                                               Sequence<IRs...> new2old)
85
86
87
88
89
90
91
92
93
94
95
96
97
{
    Array<TData, NSize> new_array;

    static_assert(NSize == sizeof...(IRs), "NSize not consistent");

    static_for<0, NSize, 1>{}([&](auto IDim) {
        constexpr index_t idim = IDim.Get();
        new_array[idim]        = old_array[new2old.Get(IDim)];
    });

    return new_array;
}

Chao Liu's avatar
Chao Liu committed
98
template <class TData, index_t NSize, class MapOld2New>
Chao Liu's avatar
Chao Liu committed
99
struct lambda_reorder_array_given_old2new
Chao Liu's avatar
Chao Liu committed
100
{
Chao Liu's avatar
Chao Liu committed
101
102
    const Array<TData, NSize>& old_array;
    Array<TData, NSize>& new_array;
Chao Liu's avatar
Chao Liu committed
103

Chao Liu's avatar
Chao Liu committed
104
105
106
    __host__ __device__ constexpr lambda_reorder_array_given_old2new(
        const Array<TData, NSize>& old_array_, Array<TData, NSize>& new_array_)
        : old_array(old_array_), new_array(new_array_)
Chao Liu's avatar
Chao Liu committed
107
108
109
110
111
112
    {
    }

    template <index_t IOldDim>
    __host__ __device__ constexpr void operator()(Number<IOldDim>) const
    {
Chao Liu's avatar
Chao Liu committed
113
        TData old_data = old_array[IOldDim];
Chao Liu's avatar
Chao Liu committed
114
115
116

        constexpr index_t INewDim = MapOld2New::Get(Number<IOldDim>{});

Chao Liu's avatar
Chao Liu committed
117
        new_array.Set(Number<INewDim>{}, old_data);
Chao Liu's avatar
Chao Liu committed
118
119
120
121
122
123
124
125
126
127
128
129
    }
};

// Reorder an array with an old-to-new index map: element i of old_array is stored at
// position Sequence<IRs...>::Get(Number<i>) of the result, for i in [0, NSize).
//
// old_array : source array (not modified)
// old2new   : Sequence with exactly NSize entries (checked at compile time); only its
//             type is used
// returns   : the reordered Array<TData, NSize>
template <class TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
                                                               Sequence<IRs...> old2new)
{
    static_assert(NSize == sizeof...(IRs), "NSize not consistent");

    using Scatter = lambda_reorder_array_given_old2new<TData, NSize, Sequence<IRs...>>;

    Array<TData, NSize> new_array;

    // The functor writes into new_array through the reference it holds.
    static_for<0, NSize, 1>{}(Scatter(old_array, new_array));

    return new_array;
}
Chao Liu's avatar
Chao Liu committed
134

135
template <class TData, index_t NSize, class ExtractSeq>
Chao Liu's avatar
Chao Liu committed
136
__host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_array, ExtractSeq)
137
138
139
140
141
142
143
144
145
{
    Array<TData, ExtractSeq::GetSize()> new_array;

    constexpr index_t new_size = ExtractSeq::GetSize();

    static_assert(new_size <= NSize, "wrong! too many extract");

    static_for<0, new_size, 1>{}([&](auto I) {
        constexpr index_t i = I.Get();
Chao Liu's avatar
Chao Liu committed
146
        new_array(i)        = old_array[ExtractSeq::Get(I)];
147
148
149
150
151
    });

    return new_array;
}

152
// Array = Array + Array
Chao Liu's avatar
Chao Liu committed
153
template <class TData, index_t NSize>
Chao Liu's avatar
Chao Liu committed
154
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData, NSize> b)
Chao Liu's avatar
Chao Liu committed
155
156
157
158
159
{
    Array<TData, NSize> result;

    static_for<0, NSize, 1>{}([&](auto I) {
        constexpr index_t i = I.Get();
Chao Liu's avatar
Chao Liu committed
160

Chao Liu's avatar
Chao Liu committed
161
        result(i) = a[i] + b[i];
Chao Liu's avatar
Chao Liu committed
162
163
164
165
    });

    return result;
}
Chao Liu's avatar
Chao Liu committed
166

167
168
169
170
171
172
173
174
175
// Array = Array - Array
// Element-wise difference (a - b) of two arrays with the same element type and size.
template <class TData, index_t NSize>
__host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData, NSize> b)
{
    Array<TData, NSize> diff;

    static_for<0, NSize, 1>{}([&](auto IDim) {
        constexpr index_t idim = IDim.Get();
        diff(idim)             = a[idim] - b[idim];
    });

    return diff;
}

// Array = Array + Sequence
// Element-wise sum of an array and a compile-time sequence of the same length.
template <class TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is...> b)
{
    static_assert(sizeof...(Is) == NSize, "wrong! size not the same");

    Array<TData, NSize> sum;

    static_for<0, NSize, 1>{}([&](auto IDim) {
        constexpr index_t idim = IDim.Get();
        sum(idim)              = a[idim] + b.Get(IDim);
    });

    return sum;
}

// Array = Array - Sequence
// Element-wise difference (a - b) of an array and a compile-time sequence of the same length.
template <class TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is...> b)
{
    static_assert(sizeof...(Is) == NSize, "wrong! size not the same");

    Array<TData, NSize> diff;

    static_for<0, NSize, 1>{}([&](auto IDim) {
        constexpr index_t idim = IDim.Get();
        diff(idim)             = a[idim] - b.Get(IDim);
    });

    return diff;
}

Chao Liu's avatar
Chao Liu committed
216
217
218
219
220
221
222
223
224
225
226
// Array = Array * Sequence
// Element-wise product of an array and a compile-time sequence of the same length.
template <class TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is...> b)
{
    static_assert(sizeof...(Is) == NSize, "wrong! size not the same");

    Array<TData, NSize> product;

    static_for<0, NSize, 1>{}([&](auto IDim) {
        constexpr index_t idim = IDim.Get();
        product(idim)          = a[idim] * b.Get(IDim);
    });

    return product;
}
232

233
234
235
// Array = Sequence - Array
// Element-wise difference (a - b) of a compile-time sequence and an array of the same length.
template <class TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSize> b)
{
    static_assert(sizeof...(Is) == NSize, "wrong! size not the same");

    Array<TData, NSize> diff;

    static_for<0, NSize, 1>{}([&](auto IDim) {
        constexpr index_t idim = IDim.Get();
        diff(idim)             = a.Get(IDim) - b[idim];
    });

    return diff;
}

// Left fold over the elements of a: starting from init, repeatedly applies
//     acc = f(acc, a[i])   for i = 0 .. NSize - 1
// and returns the final accumulator. NSize must be positive (checked at compile time).
template <class TData, index_t NSize, class Reduce>
__host__ __device__ constexpr TData
accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init)
{
    static_assert(NSize > 0, "wrong");

    TData acc = init;

    static_for<0, NSize, 1>{}([&](auto IDim) {
        constexpr index_t idim = IDim.Get();
        acc                    = f(acc, a[idim]);
    });

    return acc;
}
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348

// Debug helper: prints the label s, the array size and every element of a on one line,
// in the form "<s> size <n>, {e0 e1 ...}".
//
// Only sizes 1..10 are supported (enforced by the static_assert below). There is one
// static_if branch per size, and each branch issues a single printf call for the whole
// line; exactly one of the nsize == K conditions is true for a given instantiation.
//
// NOTE(review): static_if is defined outside this file; presumably it invokes the lambda
// only when its condition is true -- confirm.
// NOTE(review): the size and every element are formatted with %u, which presumably
// assumes index_t and T are unsigned int compatible -- confirm before using this with
// other element types.
template <class T, index_t NSize>
__host__ __device__ void print_Array(const char* s, Array<T, NSize> a)
{
    constexpr index_t nsize = a.GetSize();

    static_assert(nsize > 0 && nsize <= 10, "wrong!");

    static_if<nsize == 1>{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, a[0]); });

    static_if<nsize == 2>{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, a[0], a[1]); });

    static_if<nsize == 3>{}(
        [&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, a[0], a[1], a[2]); });

    static_if<nsize == 4>{}(
        [&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3]); });

    static_if<nsize == 5>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4]);
    });

    static_if<nsize == 6>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5]);
    });

    static_if<nsize == 7>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u}\n",
               s,
               nsize,
               a[0],
               a[1],
               a[2],
               a[3],
               a[4],
               a[5],
               a[6]);
    });

    static_if<nsize == 8>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u %u}\n",
               s,
               nsize,
               a[0],
               a[1],
               a[2],
               a[3],
               a[4],
               a[5],
               a[6],
               a[7]);
    });

    static_if<nsize == 9>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n",
               s,
               nsize,
               a[0],
               a[1],
               a[2],
               a[3],
               a[4],
               a[5],
               a[6],
               a[7],
               a[8]);
    });

    static_if<nsize == 10>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n",
               s,
               nsize,
               a[0],
               a[1],
               a[2],
               a[3],
               a[4],
               a[5],
               a[6],
               a[7],
               a[8],
               a[9]);
    });
}