Sequence.hpp 16 KB
Newer Older
1
2
3
#ifndef CK_SEQUENCE_HPP
#define CK_SEQUENCE_HPP

Chao Liu's avatar
Chao Liu committed
4
5
#include "integral_constant.hpp"
#include "functional.hpp"
6

7
8
namespace ck {

Chao Liu's avatar
Chao Liu committed
9
10
11
// ---------------------------------------------------------------------------
// Forward declarations.
// The member functions of Sequence and the free functions further down refer
// to these names before their definitions appear.
// ---------------------------------------------------------------------------

// compile-time loop; used by accumulate_on_sequence, not defined in this file
template <index_t, index_t, index_t>
struct static_for;

// the compile-time integer list itself
template <index_t...>
struct Sequence;

// splits Seq at position I into SeqType0 / SeqType1
template <typename Seq, index_t I>
struct sequence_split;

// ::type is the input Sequence in reverse order
template <typename>
struct sequence_reverse;

// ::type is the inverse of an index map
template <typename>
struct sequence_map_inverse;

// ::value tells whether a Sequence may be used as a reorder map
template <typename>
struct is_valid_sequence_map;

// drops the first element
template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>);

// drops the last element
template <typename Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq);

Chao Liu's avatar
Chao Liu committed
33
template <index_t... Is>
34
35
struct Sequence
{
Chao Liu's avatar
Chao Liu committed
36
37
    using Type      = Sequence;
    using data_type = index_t;
38

39
    static constexpr index_t mSize = sizeof...(Is);
40

Chao Liu's avatar
Chao Liu committed
41
    __host__ __device__ static constexpr auto GetSize() { return Number<mSize>{}; }
42

Chao Liu's avatar
Chao Liu committed
43
    __host__ __device__ static constexpr index_t GetImpl(index_t I)
44
    {
Chao Liu's avatar
Chao Liu committed
45
46
47
48
49
50
        // the last dummy element is to prevent compiler complain about empty array, when mSize = 0
        const index_t mData[mSize + 1] = {Is..., 0};
        return mData[I];
    }

    template <index_t I>
Chao Liu's avatar
Chao Liu committed
51
    __host__ __device__ static constexpr auto Get(Number<I>)
Chao Liu's avatar
Chao Liu committed
52
    {
Chao Liu's avatar
Chao Liu committed
53
54
55
        static_assert(I < mSize, "wrong! I too large");

        return Number<GetImpl(Number<I>{})>{};
Chao Liu's avatar
Chao Liu committed
56
57
    }

Chao Liu's avatar
Chao Liu committed
58
59
    __host__ __device__ static constexpr auto Get(index_t I) { return GetImpl(I); }

Chao Liu's avatar
Chao Liu committed
60
61
    template <index_t I>
    __host__ __device__ constexpr auto operator[](Number<I>) const
Chao Liu's avatar
Chao Liu committed
62
    {
Chao Liu's avatar
Chao Liu committed
63
        return Get(Number<I>{});
64
65
    }

Chao Liu's avatar
Chao Liu committed
66
67
68
    // make sure I is constepxr if you want a constexpr return type
    __host__ __device__ constexpr index_t operator[](index_t I) const { return GetImpl(I); }

69
    template <index_t... IRs>
70
    __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
71
    {
Chao Liu's avatar
Chao Liu committed
72
        static_assert(sizeof...(Is) == sizeof...(IRs),
Chao Liu's avatar
Chao Liu committed
73
                      "wrong! reorder map should have the same size as Sequence to be rerodered");
Chao Liu's avatar
Chao Liu committed
74

Chao Liu's avatar
Chao Liu committed
75
76
77
        static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");

        return Sequence<Type::Get(Number<IRs>{})...>{};
78
79
    }

Chao Liu's avatar
Chao Liu committed
80
81
82
83
84
85
86
87
88
89
90
91
    // MapOld2New is Sequence<...>
    template <class MapOld2New>
    __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
    {
        static_assert(MapOld2New::GetSize() == GetSize(),
                      "wrong! reorder map should have the same size as Sequence to be rerodered");

        static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");

        return ReorderGivenNew2Old(typename sequence_map_inverse<MapOld2New>::type{});
    }

92
93
94
95
    __host__ __device__ static constexpr auto Reverse()
    {
        return typename sequence_reverse<Type>::type{};
    }
Chao Liu's avatar
Chao Liu committed
96

Chao Liu's avatar
Chao Liu committed
97
    __host__ __device__ static constexpr auto Front()
98
    {
Chao Liu's avatar
Chao Liu committed
99
100
        static_assert(mSize > 0, "wrong!");
        return Get(Number<0>{});
101
    }
102

Chao Liu's avatar
Chao Liu committed
103
    __host__ __device__ static constexpr auto Back()
104
    {
Chao Liu's avatar
Chao Liu committed
105
106
        static_assert(mSize > 0, "wrong!");
        return Get(Number<mSize - 1>{});
107
    }
108

109
    __host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); }
Chao Liu's avatar
Chao Liu committed
110

111
    __host__ __device__ static constexpr auto PopBack() { return sequence_pop_back(Type{}); }
Chao Liu's avatar
Chao Liu committed
112
113
114

    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushFront(Sequence<Xs...>)
115
    {
Chao Liu's avatar
Chao Liu committed
116
        return Sequence<Xs..., Is...>{};
117
118
    }

Chao Liu's avatar
Chao Liu committed
119
120
    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushFront(Number<Xs>...)
121
    {
Chao Liu's avatar
Chao Liu committed
122
        return Sequence<Xs..., Is...>{};
123
124
    }

Chao Liu's avatar
Chao Liu committed
125
126
127
128
129
    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushBack(Sequence<Xs...>)
    {
        return Sequence<Is..., Xs...>{};
    }
130

Chao Liu's avatar
Chao Liu committed
131
    template <index_t... Xs>
Chao Liu's avatar
Chao Liu committed
132
    __host__ __device__ static constexpr auto PushBack(Number<Xs>...)
133
    {
Chao Liu's avatar
Chao Liu committed
134
135
        return Sequence<Is..., Xs...>{};
    }
Chao Liu's avatar
Chao Liu committed
136

Chao Liu's avatar
Chao Liu committed
137
    template <index_t... Ns>
138
    __host__ __device__ static constexpr auto Extract(Number<Ns>...)
Chao Liu's avatar
Chao Liu committed
139
    {
Chao Liu's avatar
Chao Liu committed
140
        return Sequence<Type::Get(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
141
    }
Chao Liu's avatar
Chao Liu committed
142

Chao Liu's avatar
Chao Liu committed
143
    template <index_t... Ns>
144
    __host__ __device__ static constexpr auto Extract(Sequence<Ns...>)
Chao Liu's avatar
Chao Liu committed
145
    {
Chao Liu's avatar
Chao Liu committed
146
        return Sequence<Type::Get(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
147
    }
148
149

    template <index_t I, index_t X>
150
151
152
153
154
155
156
157
158
159
    __host__ __device__ static constexpr auto Modify(Number<I>, Number<X>)
    {
        static_assert(I < GetSize(), "wrong!");

        using seq_split          = sequence_split<Type, I>;
        constexpr auto seq_left  = typename seq_split::SeqType0{};
        constexpr auto seq_right = typename seq_split::SeqType1{}.PopFront();

        return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
    }
Chao Liu's avatar
Chao Liu committed
160
161
162
163
164
165

    template <class F>
    __host__ __device__ static constexpr auto Transform(F f)
    {
        return Sequence<f(Is)...>{};
    }
166
167
};

Chao Liu's avatar
Chao Liu committed
168
// merge sequence
Chao Liu's avatar
Chao Liu committed
169
170
template <class, class>
struct sequence_merge;
Chao Liu's avatar
Chao Liu committed
171

Chao Liu's avatar
Chao Liu committed
172
173
174
template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
{
Chao Liu's avatar
Chao Liu committed
175
    using type = Sequence<Xs..., Ys...>;
Chao Liu's avatar
Chao Liu committed
176
};
Chao Liu's avatar
Chao Liu committed
177

Chao Liu's avatar
Chao Liu committed
178
179
180
// generate sequence
template <index_t IBegin, index_t NRemain, class F>
struct sequence_gen_impl
Chao Liu's avatar
Chao Liu committed
181
{
Chao Liu's avatar
Chao Liu committed
182
183
184
    static constexpr index_t NRemainLeft  = NRemain / 2;
    static constexpr index_t NRemainRight = NRemain - NRemainLeft;
    static constexpr index_t IMiddle      = IBegin + NRemainLeft;
Chao Liu's avatar
Chao Liu committed
185

Chao Liu's avatar
Chao Liu committed
186
187
188
    using type =
        typename sequence_merge<typename sequence_gen_impl<IBegin, NRemainLeft, F>::type,
                                typename sequence_gen_impl<IMiddle, NRemainRight, F>::type>::type;
Chao Liu's avatar
Chao Liu committed
189
190
};

Chao Liu's avatar
Chao Liu committed
191
192
template <index_t I, class F>
struct sequence_gen_impl<I, 1, F>
Chao Liu's avatar
Chao Liu committed
193
{
Chao Liu's avatar
Chao Liu committed
194
195
    static constexpr index_t Is = F{}(Number<I>{});
    using type                  = Sequence<Is>;
Chao Liu's avatar
Chao Liu committed
196
};
Chao Liu's avatar
Chao Liu committed
197

Chao Liu's avatar
Chao Liu committed
198
199
template <index_t I, class F>
struct sequence_gen_impl<I, 0, F>
Chao Liu's avatar
Chao Liu committed
200
{
Chao Liu's avatar
Chao Liu committed
201
    using type = Sequence<>;
Chao Liu's avatar
Chao Liu committed
202
203
};

Chao Liu's avatar
Chao Liu committed
204
205
206
207
208
209
210
template <index_t NSize, class F>
struct sequence_gen
{
    using type = typename sequence_gen_impl<0, NSize, F>::type;
};

// arithmetic sequence
Chao Liu's avatar
Chao Liu committed
211
template <index_t IBegin, index_t IEnd, index_t Increment>
212
struct arithmetic_sequence_gen
Chao Liu's avatar
Chao Liu committed
213
{
Chao Liu's avatar
Chao Liu committed
214
215
216
217
218
219
220
221
222
    struct F
    {
        __host__ __device__ constexpr index_t operator()(index_t i) const
        {
            return i * Increment + IBegin;
        }
    };

    using type = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type;
Chao Liu's avatar
Chao Liu committed
223
224
225
226
227
228
};

// uniform sequence
template <index_t NSize, index_t I>
struct uniform_sequence_gen
{
Chao Liu's avatar
Chao Liu committed
229
    struct F
Chao Liu's avatar
Chao Liu committed
230
231
232
233
    {
        __host__ __device__ constexpr index_t operator()(index_t) const { return I; }
    };

Chao Liu's avatar
Chao Liu committed
234
    using type = typename sequence_gen<NSize, F>::type;
Chao Liu's avatar
Chao Liu committed
235
236
237
};

// reverse inclusive scan (with init) sequence
Chao Liu's avatar
Chao Liu committed
238
template <class, class, index_t>
Chao Liu's avatar
Chao Liu committed
239
struct sequence_reverse_inclusive_scan;
Chao Liu's avatar
Chao Liu committed
240

Chao Liu's avatar
Chao Liu committed
241
242
template <index_t I, index_t... Is, class Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
243
{
Chao Liu's avatar
Chao Liu committed
244
    using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::type;
Chao Liu's avatar
Chao Liu committed
245
246
247

    static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());

Chao Liu's avatar
Chao Liu committed
248
    using type = typename sequence_merge<Sequence<new_reduce>, old_scan>::type;
Chao Liu's avatar
Chao Liu committed
249
250
};

Chao Liu's avatar
Chao Liu committed
251
252
template <index_t I, class Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
253
{
Chao Liu's avatar
Chao Liu committed
254
    using type = Sequence<Reduce{}(I, Init)>;
Chao Liu's avatar
Chao Liu committed
255
256
};

Chao Liu's avatar
Chao Liu committed
257
258
template <class Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
259
{
Chao Liu's avatar
Chao Liu committed
260
    using type = Sequence<>;
Chao Liu's avatar
Chao Liu committed
261
262
};

Chao Liu's avatar
Chao Liu committed
263
// split sequence
Chao Liu's avatar
Chao Liu committed
264
265
266
267
268
template <class Seq, index_t I>
struct sequence_split
{
    static constexpr index_t NSize = Seq{}.GetSize();

Chao Liu's avatar
Chao Liu committed
269
270
    using range0 = typename arithmetic_sequence_gen<0, I, 1>::type;
    using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type;
Chao Liu's avatar
Chao Liu committed
271

Chao Liu's avatar
Chao Liu committed
272
273
    using SeqType0 = decltype(Seq::Extract(range0{}));
    using SeqType1 = decltype(Seq::Extract(range1{}));
Chao Liu's avatar
Chao Liu committed
274
275
};

Chao Liu's avatar
Chao Liu committed
276
// reverse sequence
Chao Liu's avatar
Chao Liu committed
277
278
279
280
281
282
template <class Seq>
struct sequence_reverse
{
    static constexpr index_t NSize = Seq{}.GetSize();

    using seq_split = sequence_split<Seq, NSize / 2>;
Chao Liu's avatar
Chao Liu committed
283
284
285
    using type      = typename sequence_merge<
        typename sequence_reverse<typename seq_split::SeqType1>::type,
        typename sequence_reverse<typename seq_split::SeqType0>::type>::type;
Chao Liu's avatar
Chao Liu committed
286
287
288
289
290
};

template <index_t I>
struct sequence_reverse<Sequence<I>>
{
Chao Liu's avatar
Chao Liu committed
291
    using type = Sequence<I>;
Chao Liu's avatar
Chao Liu committed
292
293
294
295
296
};

template <index_t I0, index_t I1>
struct sequence_reverse<Sequence<I0, I1>>
{
Chao Liu's avatar
Chao Liu committed
297
    using type = Sequence<I1, I0>;
Chao Liu's avatar
Chao Liu committed
298
};
Chao Liu's avatar
Chao Liu committed
299

Chao Liu's avatar
Chao Liu committed
300
301
302
303
304
305
306
307
308
309
310
311
// Placeholder for sorting a Sequence with a Compare functor.
// Not implemented: the struct has no members.
template <typename Seq, typename Compare>
struct sequence_sort
{
};

// Placeholder for sorting a Sequence and removing duplicates.
// Not implemented: the struct has no members.
template <typename Seq, typename Compare>
struct sequence_unique_sort
{
};

Chao Liu's avatar
Chao Liu committed
312
313
314
template <class Seq>
struct is_valid_sequence_map
{
Chao Liu's avatar
Chao Liu committed
315
    // not implemented yet, always return true
Chao Liu's avatar
Chao Liu committed
316
    static constexpr integral_constant<bool, true> value = integral_constant<bool, true>{};
Chao Liu's avatar
Chao Liu committed
317
318
319

    // TODO: add proper check for is_valid, something like:
    // static constexpr bool value =
Chao Liu's avatar
Chao Liu committed
320
    //     is_same<typename arithmetic_sequence_gen<0, Seq::GetSize(), 1>::type,
Chao Liu's avatar
Chao Liu committed
321
    //             typename sequence_sort<Seq>::SortedSeqType>{};
Chao Liu's avatar
Chao Liu committed
322
};
323

Chao Liu's avatar
Chao Liu committed
324
325
326
327
template <class X2Y, class WorkingY2X, index_t XBegin, index_t XRemain>
struct sequence_map_inverse_impl
{
    private:
328
329
    static constexpr auto new_y2x =
        WorkingY2X::Modify(X2Y::Get(Number<XBegin>{}), Number<XBegin>{});
Chao Liu's avatar
Chao Liu committed
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351

    public:
    using type =
        typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::type;
};

template <class X2Y, class WorkingY2X, index_t XBegin>
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
{
    using type = WorkingY2X;
};

template <class X2Y>
struct sequence_map_inverse
{
    using type =
        typename sequence_map_inverse_impl<X2Y,
                                           typename uniform_sequence_gen<X2Y::GetSize(), 0>::type,
                                           0,
                                           X2Y::GetSize()>::type;
};

Chao Liu's avatar
Chao Liu committed
352
template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
353
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
354
355
356
357
358
359
360
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs + Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
361
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
362
363
364
365
366
367
368
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs - Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
369
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
370
371
372
373
374
375
376
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs * Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
377
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
378
379
380
381
382
383
384
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs / Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
385
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
386
387
388
389
390
391
392
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs % Ys)...>{};
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
393
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
394
{
Chao Liu's avatar
Chao Liu committed
395
    return Sequence<(Xs + Y)...>{};
Chao Liu's avatar
Chao Liu committed
396
397
398
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
399
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
400
{
Chao Liu's avatar
Chao Liu committed
401
    return Sequence<(Xs - Y)...>{};
Chao Liu's avatar
Chao Liu committed
402
403
404
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
405
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
406
{
Chao Liu's avatar
Chao Liu committed
407
    return Sequence<(Xs * Y)...>{};
Chao Liu's avatar
Chao Liu committed
408
409
410
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
411
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
412
{
Chao Liu's avatar
Chao Liu committed
413
    return Sequence<(Xs / Y)...>{};
Chao Liu's avatar
Chao Liu committed
414
415
416
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
417
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
418
{
Chao Liu's avatar
Chao Liu committed
419
    return Sequence<(Xs % Y)...>{};
Chao Liu's avatar
Chao Liu committed
420
421
}

Chao Liu's avatar
Chao Liu committed
422
423
// ---------------------------------------------------------------------------
// Arithmetic between a single Number (left operand) and a Sequence (right operand).
// Y is combined with every element Xs of the Sequence; the result is a Sequence of the
// same length.
// ---------------------------------------------------------------------------

// Y + x for every element x
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator+(Number<Y>, Sequence<Xs...>)
{
    return Sequence<(Y + Xs)...>{};
}

// Y - x for every element x
// (the unused local Sequence object that used to be declared here has been removed)
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
{
    return Sequence<(Y - Xs)...>{};
}

// Y * x for every element x
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator*(Number<Y>, Sequence<Xs...>)
{
    return Sequence<(Y * Xs)...>{};
}

// Y / x (integer division) for every element x
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator/(Number<Y>, Sequence<Xs...>)
{
    return Sequence<(Y / Xs)...>{};
}

// Y % x for every element x
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator%(Number<Y>, Sequence<Xs...>)
{
    return Sequence<(Y % Xs)...>{};
}

454
455
456
457
458
459
// Returns the Sequence without its first element. An empty Sequence does not match the
// parameter pattern, so the call is rejected at compile time.
template <index_t First, index_t... Rest>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<First, Rest...>)
{
    using Remaining = Sequence<Rest...>;
    return Remaining{};
}

Chao Liu's avatar
Chao Liu committed
460
461
// Returns Seq without its last element; Seq must not be empty.
//
// The leading size-1 elements are extracted by position. The earlier
// reverse / pop-front / reverse formulation ended by reversing an empty Sequence when Seq had
// exactly one element, and sequence_reverse cannot be instantiated for Sequence<>.
template <class Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq)
{
    static_assert(Seq::GetSize() > 0, "wrong! cannot pop an empty Sequence!");

    constexpr index_t size = Seq::GetSize();

    // clamp at 0 so that a failed static_assert above is the only diagnostic for an empty Seq
    // (size - 1 would otherwise be out of range for the index generator)
    constexpr index_t n_keep = size > 0 ? size - 1 : 0;

    // positions 0 .. n_keep-1
    using kept_positions = typename arithmetic_sequence_gen<0, n_keep, 1>::type;

    return Seq::Extract(kept_positions{});
}
466

Chao Liu's avatar
Chao Liu committed
467
468
469
470
471
472
// ---------------------------------------------------------------------------
// transform_sequences: apply f element-wise to one, two or three Sequences of equal length
// and collect the results in a new Sequence. The calls to f are used as template arguments,
// so they have to be usable in constant expressions.
// ---------------------------------------------------------------------------

// one input: f(x)
template <typename F, index_t... Xs>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
{
    using Result = Sequence<f(Xs)...>;
    return Result{};
}

// two inputs: f(x, y)
template <typename F, index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "Dim not the same");

    using Result = Sequence<f(Xs, Ys)...>;
    return Result{};
}

// three inputs: f(x, y, z)
template <typename F, index_t... Xs, index_t... Ys, index_t... Zs>
__host__ __device__ constexpr auto
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys) && sizeof...(Xs) == sizeof...(Zs),
                  "Dim not the same");

    using Result = Sequence<f(Xs, Ys, Zs)...>;
    return Result{};
}

Chao Liu's avatar
Chao Liu committed
492
493
// Function form of sequence_reverse_inclusive_scan: scans Seq from its last element towards
// its first, starting from Init and combining with Reduce.
template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number<Init>)
{
    using Scanned = typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::type;
    return Scanned{};
}

// Front-to-back inclusive scan, expressed through the reverse scan:
// reverse the input, scan it back-to-front, reverse the result.
template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<Init>)
{
    using Flipped = decltype(Seq::Reverse());
    using FlippedScan =
        decltype(reverse_inclusive_scan_sequence(Flipped{}, Reduce{}, Number<Init>{}));

    return FlippedScan::Reverse();
}
503

Chao Liu's avatar
Chao Liu committed
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
// Functor used by accumulate_on_sequence. It holds references to the reduction functor and to
// the running result; each call folds one element of Seq into that result.
template <typename Seq, typename Reduce>
struct lambda_accumulate_on_sequence
{
    const Reduce& f; // reduction functor, called as f(running result, element)
    index_t& result; // running result, owned by the caller

    __host__ __device__ constexpr lambda_accumulate_on_sequence(const Reduce& reduce_op,
                                                                index_t& accumulator)
        : f(reduce_op), result(accumulator)
    {
    }

    // Folds the element selected by IDim into the running result and returns the new value.
    // The function is const, but it writes through the `result` reference.
    template <typename IDim>
    __host__ __device__ constexpr index_t operator()(IDim) const
    {
        result = f(result, Seq::Get(IDim{}));
        return result;
    }
};

// Reduces all elements of Seq into a single index_t: starts from Init and folds in one
// element per static_for step, using f(running result, element).
template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr index_t
accumulate_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
{
    using Accumulator = lambda_accumulate_on_sequence<Seq, Reduce>;

    index_t running = Init;

    // the functor keeps a reference to `running` and updates it on every step
    static_for<0, Seq::mSize, 1>{}(Accumulator(f, running));

    return running;
}

533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
// Prints a label, the element count and the elements of a Sequence with one printf call.
//   s - text printed in front of the size
// Sequences with more than 10 elements are rejected at compile time.
// There is one static_if per supported size, each with a format string holding exactly that
// many element fields. static_if is defined outside this file; presumably only the branch
// whose condition is true invokes its lambda - TODO confirm.
// Size and elements are all printed with %u.
template <index_t... Xs>
__host__ __device__ void print_Sequence(const char* s, Sequence<Xs...>)
{
    constexpr index_t nsize = Sequence<Xs...>::GetSize();

    // only sizes 0..10 have a matching format string below
    static_assert(nsize <= 10, "wrong!");

    // empty Sequence: Xs... expands to nothing
    static_if<nsize == 0>{}([&](auto) { printf("%s size %u, {}\n", s, nsize, Xs...); });

    static_if<nsize == 1>{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, Xs...); });

    static_if<nsize == 2>{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, Xs...); });

    static_if<nsize == 3>{}([&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, Xs...); });

    static_if<nsize == 4>{}([&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, Xs...); });

    static_if<nsize == 5>{}(
        [&](auto) { printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, Xs...); });

    static_if<nsize == 6>{}(
        [&](auto) { printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, Xs...); });

    static_if<nsize == 7>{}(
        [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u}\n", s, nsize, Xs...); });

    static_if<nsize == 8>{}(
        [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });

    static_if<nsize == 9>{}(
        [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });

    static_if<nsize == 10>{}(
        [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
}
568
569
570

} // namespace ck
#endif