"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "0b44909c1a17b97522c57fb72a6309f22d5231dc"
Array.hpp 10.2 KB
Newer Older
1
2
3
#ifndef CK_ARRAY_HPP
#define CK_ARRAY_HPP

Chao Liu's avatar
Chao Liu committed
4
5
#include "Sequence.hpp"
#include "functional2.hpp"
6

7
8
namespace ck {

Chao Liu's avatar
Chao Liu committed
9
template <class TData, index_t NSize>
10
11
12
13
struct Array
{
    using Type = Array<TData, NSize>;

Chao Liu's avatar
Chao Liu committed
14
    static constexpr index_t nSize = NSize;
15

Chao Liu's avatar
Chao Liu committed
16
    index_t mData[nSize];
17
18

    template <class... Xs>
Chao Liu's avatar
Chao Liu committed
19
    __host__ __device__ constexpr Array(Xs... xs) : mData{static_cast<TData>(xs)...}
20
21
22
    {
    }

23
24
    __host__ __device__ constexpr index_t GetSize() const { return NSize; }

Chao Liu's avatar
Chao Liu committed
25
26
27
28
29
30
    template <index_t I>
    __host__ __device__ constexpr TData operator[](Number<I>) const
    {
        return mData[I];
    }

Chao Liu's avatar
Chao Liu committed
31
    __host__ __device__ constexpr TData operator[](index_t i) const { return mData[i]; }
32

Chao Liu's avatar
Chao Liu committed
33
34
35
36
37
38
39
    template <index_t I>
    __host__ __device__ TData& operator()(Number<I>)
    {
        return mData[I];
    }

    __host__ __device__ TData& operator()(index_t i) { return mData[i]; }
40

Chao Liu's avatar
Chao Liu committed
41
    template <index_t I>
Chao Liu's avatar
Chao Liu committed
42
    __host__ __device__ constexpr void Set(Number<I>, TData x)
Chao Liu's avatar
Chao Liu committed
43
    {
Chao Liu's avatar
Chao Liu committed
44
45
        static_assert(I < NSize, "wrong!");

Chao Liu's avatar
Chao Liu committed
46
47
48
        mData[I] = x;
    }

Chao Liu's avatar
Chao Liu committed
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
    __host__ __device__ constexpr void Set(index_t I, TData x) { mData[I] = x; }

    struct lambda_PushBack // emulate constexpr lambda
    {
        const Array<TData, NSize>& old_array;
        Array<TData, NSize + 1>& new_array;

        __host__ __device__ constexpr lambda_PushBack(const Array<TData, NSize>& old_array_,
                                                      Array<TData, NSize + 1>& new_array_)
            : old_array(old_array_), new_array(new_array_)
        {
        }

        template <index_t I>
        __host__ __device__ constexpr void operator()(Number<I>) const
        {
            new_array.Set(Number<I>{}, old_array[I]);
        }
    };

Chao Liu's avatar
Chao Liu committed
69
    __host__ __device__ constexpr auto PushBack(TData x) const
70
71
72
    {
        Array<TData, NSize + 1> new_array;

Chao Liu's avatar
Chao Liu committed
73
        static_for<0, NSize, 1>{}(lambda_PushBack(*this, new_array));
74

Chao Liu's avatar
Chao Liu committed
75
        new_array.Set(Number<NSize>{}, x);
76
77
78

        return new_array;
    }
79
};
80

Chao Liu's avatar
Chao Liu committed
81
82
83
84
85
86
87
88
89
// Converts a compile-time Sequence<Is...> into an Array<index_t, N> that is
// constructed from the same values in the same order, N being the number of
// values in the sequence.
template <index_t... Is>
__host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
{
    using ResultArray = Array<index_t, sizeof...(Is)>;

    return ResultArray{Is...};
}

// Returns an Array<TData, NSize> whose elements are all value-initialized
// (zero for arithmetic element types).
//
// The result is built with the requested element type TData. Going through an
// index_t sequence instead would always yield Array<index_t, NSize> and silently
// ignore TData. For TData == index_t the returned type and values are the same
// either way.
template <class TData, index_t NSize>
__host__ __device__ constexpr auto make_zero_array()
{
    // Array's constructor called with no values value-initializes every element.
    return Array<TData, NSize>{};
}

95
template <class TData, index_t NSize, index_t... IRs>
Chao Liu's avatar
Chao Liu committed
96
__host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData, NSize>& old_array,
Chao Liu's avatar
Chao Liu committed
97
                                                               Sequence<IRs...> /*new2old*/)
98
99
100
{
    static_assert(NSize == sizeof...(IRs), "NSize not consistent");

Chao Liu's avatar
Chao Liu committed
101
    static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");
102

103
    return Array<TData, NSize>{old_array[IRs]...};
104
105
}

Chao Liu's avatar
Chao Liu committed
106
template <class TData, index_t NSize, class MapOld2New>
Chao Liu's avatar
Chao Liu committed
107
struct lambda_reorder_array_given_old2new
Chao Liu's avatar
Chao Liu committed
108
{
Chao Liu's avatar
Chao Liu committed
109
110
    const Array<TData, NSize>& old_array;
    Array<TData, NSize>& new_array;
Chao Liu's avatar
Chao Liu committed
111

Chao Liu's avatar
Chao Liu committed
112
113
114
    __host__ __device__ constexpr lambda_reorder_array_given_old2new(
        const Array<TData, NSize>& old_array_, Array<TData, NSize>& new_array_)
        : old_array(old_array_), new_array(new_array_)
Chao Liu's avatar
Chao Liu committed
115
116
117
118
119
120
    {
    }

    template <index_t IOldDim>
    __host__ __device__ constexpr void operator()(Number<IOldDim>) const
    {
Chao Liu's avatar
Chao Liu committed
121
        TData old_data = old_array[IOldDim];
Chao Liu's avatar
Chao Liu committed
122
123
124

        constexpr index_t INewDim = MapOld2New::Get(Number<IOldDim>{});

Chao Liu's avatar
Chao Liu committed
125
        new_array.Set(Number<INewDim>{}, old_data);
Chao Liu's avatar
Chao Liu committed
126
127
128
129
130
    }
};

// Reorders an array with an old-to-new index map: each element of old_array is
// stored in the result at the position the map assigns to its index.
//
// The map must have exactly NSize entries and must satisfy is_valid_sequence_map.
template <class TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
                                                               Sequence<IRs...> /*old2new*/)
{
    using MapOld2New = Sequence<IRs...>;
    using Scatter    = lambda_reorder_array_given_old2new<TData, NSize, MapOld2New>;

    static_assert(sizeof...(IRs) == NSize, "NSize not consistent");
    static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");

    Array<TData, NSize> reordered;

    // One Scatter call per source index in [0, NSize).
    static_for<0, NSize, 1>{}(Scatter(old_array, reordered));

    return reordered;
}
Chao Liu's avatar
Chao Liu committed
144

145
template <class TData, index_t NSize, class ExtractSeq>
Chao Liu's avatar
Chao Liu committed
146
__host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_array, ExtractSeq)
147
148
149
150
151
152
153
{
    Array<TData, ExtractSeq::GetSize()> new_array;

    constexpr index_t new_size = ExtractSeq::GetSize();

    static_assert(new_size <= NSize, "wrong! too many extract");

Chao Liu's avatar
Chao Liu committed
154
    static_for<0, new_size, 1>{}([&](auto I) { new_array(I) = old_array[ExtractSeq::Get(I)]; });
155
156
157
158

    return new_array;
}

Chao Liu's avatar
Chao Liu committed
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
// Emulates a constexpr lambda for element-wise array math. Called with
// Number<IDim_>, it stores f(x[IDim_], y[IDim_]) into element IDim_ of z.
//
// F is the binary operation, X and Y are the operands (indexed with operator[]
// taking a Number), and Z is the destination (written through Set).
template <class F, class X, class Y, class Z>
struct lambda_array_math
{
    const F& f;
    const X& x;
    const Y& y;
    Z& z;

    __host__ __device__ constexpr lambda_array_math(const F& f_, const X& x_, const Y& y_, Z& z_)
        : f(f_), x(x_), y(y_), z(z_)
    {
    }

    template <index_t IDim_>
    __host__ __device__ constexpr void operator()(Number<IDim_>) const
    {
        using IDim = Number<IDim_>;

        z.Set(IDim{}, f(x[IDim{}], y[IDim{}]));
    }
};

181
// Array = Array + Array
Chao Liu's avatar
Chao Liu committed
182
template <class TData, index_t NSize>
Chao Liu's avatar
Chao Liu committed
183
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData, NSize> b)
Chao Liu's avatar
Chao Liu committed
184
185
186
{
    Array<TData, NSize> result;

187
    auto f = math::plus<index_t>{};
Chao Liu's avatar
Chao Liu committed
188

Chao Liu's avatar
Chao Liu committed
189
190
191
    static_for<0, NSize, 1>{}(
        lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
            f, a, b, result));
Chao Liu's avatar
Chao Liu committed
192
193
194

    return result;
}
Chao Liu's avatar
Chao Liu committed
195

196
197
198
199
200
201
// Array = Array - Array
// Element-wise difference: result[i] = a[i] - b[i] for every i in [0, NSize).
template <class TData, index_t NSize>
__host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData, NSize> b)
{
    Array<TData, NSize> result;

    // Subtract in the arrays' own element type. A functor fixed to index_t would
    // convert every non-index_t element before the subtraction.
    auto f = math::minus<TData>{};

    static_for<0, NSize, 1>{}(
        lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
            f, a, b, result));

    return result;
}

// Array = Array + Sequence
// Element-wise sum of an array and a compile-time sequence of the same length:
// result[i] = a[i] + b[i] for every i in [0, NSize).
template <class TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is...> b)
{
    static_assert(sizeof...(Is) == NSize, "wrong! size not the same");

    Array<TData, NSize> result;

    // Add in the array's element type (the sequence values are converted to it).
    // A functor fixed to index_t would instead convert the array elements.
    auto f = math::plus<TData>{};

    static_for<0, NSize, 1>{}(
        lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
            f, a, b, result));

    return result;
}

// Array = Array - Sequence
// Element-wise difference of an array and a compile-time sequence of the same
// length: result[i] = a[i] - b[i] for every i in [0, NSize).
template <class TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is...> b)
{
    static_assert(sizeof...(Is) == NSize, "wrong! size not the same");

    Array<TData, NSize> result;

    // Subtract in the array's element type (the sequence values are converted to
    // it). A functor fixed to index_t would instead convert the array elements.
    auto f = math::minus<TData>{};

    static_for<0, NSize, 1>{}(
        lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
            f, a, b, result));

    return result;
}

Chao Liu's avatar
Chao Liu committed
245
246
247
248
249
250
251
252
// Array = Array * Sequence
// Element-wise product of an array and a compile-time sequence of the same
// length: result[i] = a[i] * b[i] for every i in [0, NSize).
template <class TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is...> b)
{
    static_assert(sizeof...(Is) == NSize, "wrong! size not the same");

    Array<TData, NSize> result;

    // Multiply in the array's element type (the sequence values are converted to
    // it). A functor fixed to index_t would instead convert the array elements.
    auto f = math::multiplies<TData>{};

    static_for<0, NSize, 1>{}(
        lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
            f, a, b, result));

    return result;
}
261

262
263
264
// Array = Sequence - Array
// Element-wise difference of a compile-time sequence and an array of the same
// length: result[i] = a[i] - b[i] for every i in [0, NSize).
template <class TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSize> b)
{
    static_assert(sizeof...(Is) == NSize, "wrong! size not the same");

    Array<TData, NSize> result;

    // Subtract in the array's element type (the sequence values are converted to
    // it). A functor fixed to index_t would instead convert the array elements.
    auto f = math::minus<TData>{};

    static_for<0, NSize, 1>{}(
        lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
            f, a, b, result));

    return result;
}

// Folds the elements of a into a single value: starting from init, each step
// replaces the running value with f(running value, a[I]) for I over [0, NSize).
// The array must not be empty.
template <class TData, index_t NSize, class Reduce>
__host__ __device__ constexpr TData
accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init)
{
    static_assert(NSize > 0, "wrong");

    TData acc = init;

    static_for<0, NSize, 1>{}([&](auto I) { acc = f(acc, a[I]); });

    return acc;
}
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374

// Debug helper: prints the label s, the array size and every element of a.
//
// Only sizes 1 through 10 are supported (checked at compile time). There is one
// static_if branch per size, and each branch emits the whole line with a single
// printf call. Presumably only the branch whose condition is true runs.
// NOTE(review): the size and the elements are all printed with %u, so this
// assumes index_t and T are compatible with that conversion - confirm.
template <class T, index_t NSize>
__host__ __device__ void print_Array(const char* s, Array<T, NSize> a)
{
    constexpr index_t nsize = NSize;

    static_assert(nsize > 0 && nsize <= 10, "wrong!");

    // 1 element
    static_if<nsize == 1>{}([&](auto) {
        printf("%s size %u, {%u}\n", s, nsize, a[0]);
    });

    // 2 elements
    static_if<nsize == 2>{}([&](auto) {
        printf("%s size %u, {%u %u}\n", s, nsize, a[0], a[1]);
    });

    // 3 elements
    static_if<nsize == 3>{}([&](auto) {
        printf("%s size %u, {%u %u %u}\n", s, nsize, a[0], a[1], a[2]);
    });

    // 4 elements
    static_if<nsize == 4>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3]);
    });

    // 5 elements
    static_if<nsize == 5>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u}\n",
               s, nsize,
               a[0], a[1], a[2], a[3], a[4]);
    });

    // 6 elements
    static_if<nsize == 6>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u}\n",
               s, nsize,
               a[0], a[1], a[2], a[3], a[4], a[5]);
    });

    // 7 elements
    static_if<nsize == 7>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u}\n",
               s, nsize,
               a[0], a[1], a[2], a[3], a[4], a[5], a[6]);
    });

    // 8 elements
    static_if<nsize == 8>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u %u}\n",
               s, nsize,
               a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7]);
    });

    // 9 elements
    static_if<nsize == 9>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n",
               s, nsize,
               a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8]);
    });

    // 10 elements
    static_if<nsize == 10>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n",
               s, nsize,
               a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9]);
    });
}
375
376
377

} // namespace ck
#endif