Sequence.hpp 18.1 KB
Newer Older
1
2
3
#ifndef CK_SEQUENCE_HPP
#define CK_SEQUENCE_HPP

Chao Liu's avatar
Chao Liu committed
4
5
#include "integral_constant.hpp"
#include "functional.hpp"
6

7
8
namespace ck {

Chao Liu's avatar
Chao Liu committed
9
10
11
// Forward declarations of entities that are referenced before their definitions.

// static_for is not defined in this header; only its template signature is needed here.
template <index_t, index_t, index_t>
struct static_for;

template <index_t... Is>
struct Sequence;

template <typename Seq, index_t I>
struct sequence_split;

template <typename Seq>
struct sequence_reverse;

template <typename X2Y>
struct sequence_map_inverse;

template <typename Seq>
struct is_valid_sequence_map;

template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>);

template <typename Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq);

Chao Liu's avatar
Chao Liu committed
33
template <index_t... Is>
34
35
struct Sequence
{
Chao Liu's avatar
Chao Liu committed
36
37
    using Type      = Sequence;
    using data_type = index_t;
38

39
    static constexpr index_t mSize = sizeof...(Is);
40

Chao Liu's avatar
Chao Liu committed
41
    __host__ __device__ static constexpr auto Size() { return Number<mSize>{}; }
42

Chao Liu's avatar
Chao Liu committed
43
44
45
    __host__ __device__ static constexpr auto GetSize() { return Size(); }

    __host__ __device__ static constexpr index_t At(index_t I)
46
    {
Chao Liu's avatar
Chao Liu committed
47
48
49
50
51
52
        // the last dummy element is to prevent compiler complain about empty array, when mSize = 0
        const index_t mData[mSize + 1] = {Is..., 0};
        return mData[I];
    }

    template <index_t I>
Chao Liu's avatar
Chao Liu committed
53
    __host__ __device__ static constexpr auto At(Number<I>)
Chao Liu's avatar
Chao Liu committed
54
    {
Chao Liu's avatar
Chao Liu committed
55
56
        static_assert(I < mSize, "wrong! I too large");

Chao Liu's avatar
Chao Liu committed
57
        return Number<At(I)>{};
Chao Liu's avatar
Chao Liu committed
58
59
    }

Chao Liu's avatar
Chao Liu committed
60
    template <index_t I>
Chao Liu's avatar
Chao Liu committed
61
    __host__ __device__ static constexpr auto Get(Number<I>)
Chao Liu's avatar
Chao Liu committed
62
    {
Chao Liu's avatar
Chao Liu committed
63
        return At(Number<I>{});
64
65
    }

Chao Liu's avatar
Chao Liu committed
66
67
68
69
70
    template <typename I>
    __host__ __device__ constexpr auto operator[](I i) const
    {
        return At(i);
    }
Chao Liu's avatar
Chao Liu committed
71

72
    template <index_t... IRs>
73
    __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
74
    {
Chao Liu's avatar
Chao Liu committed
75
        static_assert(sizeof...(Is) == sizeof...(IRs),
Chao Liu's avatar
Chao Liu committed
76
                      "wrong! reorder map should have the same size as Sequence to be rerodered");
Chao Liu's avatar
Chao Liu committed
77

Chao Liu's avatar
Chao Liu committed
78
79
        static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");

Chao Liu's avatar
Chao Liu committed
80
        return Sequence<Type::At(Number<IRs>{})...>{};
81
82
    }

Chao Liu's avatar
Chao Liu committed
83
    // MapOld2New is Sequence<...>
Chao Liu's avatar
Chao Liu committed
84
    template <typename MapOld2New>
Chao Liu's avatar
Chao Liu committed
85
86
    __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
    {
Chao Liu's avatar
Chao Liu committed
87
        static_assert(MapOld2New::Size() == Size(),
Chao Liu's avatar
Chao Liu committed
88
89
90
91
92
93
94
                      "wrong! reorder map should have the same size as Sequence to be rerodered");

        static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");

        return ReorderGivenNew2Old(typename sequence_map_inverse<MapOld2New>::type{});
    }

95
96
97
98
    __host__ __device__ static constexpr auto Reverse()
    {
        return typename sequence_reverse<Type>::type{};
    }
Chao Liu's avatar
Chao Liu committed
99

Chao Liu's avatar
Chao Liu committed
100
    __host__ __device__ static constexpr auto Front()
101
    {
Chao Liu's avatar
Chao Liu committed
102
        static_assert(mSize > 0, "wrong!");
Chao Liu's avatar
Chao Liu committed
103
        return At(Number<0>{});
104
    }
105

Chao Liu's avatar
Chao Liu committed
106
    __host__ __device__ static constexpr auto Back()
107
    {
Chao Liu's avatar
Chao Liu committed
108
        static_assert(mSize > 0, "wrong!");
Chao Liu's avatar
Chao Liu committed
109
        return At(Number<mSize - 1>{});
110
    }
111

112
    __host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); }
Chao Liu's avatar
Chao Liu committed
113

114
    __host__ __device__ static constexpr auto PopBack() { return sequence_pop_back(Type{}); }
Chao Liu's avatar
Chao Liu committed
115
116
117

    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushFront(Sequence<Xs...>)
118
    {
Chao Liu's avatar
Chao Liu committed
119
        return Sequence<Xs..., Is...>{};
120
121
    }

Chao Liu's avatar
Chao Liu committed
122
123
    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushFront(Number<Xs>...)
124
    {
Chao Liu's avatar
Chao Liu committed
125
        return Sequence<Xs..., Is...>{};
126
127
    }

Chao Liu's avatar
Chao Liu committed
128
129
130
131
132
    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushBack(Sequence<Xs...>)
    {
        return Sequence<Is..., Xs...>{};
    }
133

Chao Liu's avatar
Chao Liu committed
134
    template <index_t... Xs>
Chao Liu's avatar
Chao Liu committed
135
    __host__ __device__ static constexpr auto PushBack(Number<Xs>...)
136
    {
Chao Liu's avatar
Chao Liu committed
137
138
        return Sequence<Is..., Xs...>{};
    }
Chao Liu's avatar
Chao Liu committed
139

Chao Liu's avatar
Chao Liu committed
140
    template <index_t... Ns>
141
    __host__ __device__ static constexpr auto Extract(Number<Ns>...)
Chao Liu's avatar
Chao Liu committed
142
    {
Chao Liu's avatar
Chao Liu committed
143
        return Sequence<Type::At(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
144
    }
Chao Liu's avatar
Chao Liu committed
145

Chao Liu's avatar
Chao Liu committed
146
    template <index_t... Ns>
147
    __host__ __device__ static constexpr auto Extract(Sequence<Ns...>)
Chao Liu's avatar
Chao Liu committed
148
    {
Chao Liu's avatar
Chao Liu committed
149
        return Sequence<Type::At(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
150
    }
151
152

    template <index_t I, index_t X>
153
154
    __host__ __device__ static constexpr auto Modify(Number<I>, Number<X>)
    {
Chao Liu's avatar
Chao Liu committed
155
        static_assert(I < Size(), "wrong!");
156
157
158
159
160
161
162

        using seq_split          = sequence_split<Type, I>;
        constexpr auto seq_left  = typename seq_split::SeqType0{};
        constexpr auto seq_right = typename seq_split::SeqType1{}.PopFront();

        return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
    }
Chao Liu's avatar
Chao Liu committed
163

Chao Liu's avatar
Chao Liu committed
164
    template <typename F>
Chao Liu's avatar
Chao Liu committed
165
166
167
168
    __host__ __device__ static constexpr auto Transform(F f)
    {
        return Sequence<f(Is)...>{};
    }
169
170
};

Chao Liu's avatar
Chao Liu committed
171
// merge sequence
Chao Liu's avatar
Chao Liu committed
172
173
174
175
176
template <typename Seq, typename... Seqs>
struct sequence_merge
{
    using type = typename sequence_merge<Seq, typename sequence_merge<Seqs...>::type>::type;
};
Chao Liu's avatar
Chao Liu committed
177

Chao Liu's avatar
Chao Liu committed
178
179
180
template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
{
Chao Liu's avatar
Chao Liu committed
181
    using type = Sequence<Xs..., Ys...>;
Chao Liu's avatar
Chao Liu committed
182
};
Chao Liu's avatar
Chao Liu committed
183

Chao Liu's avatar
Chao Liu committed
184
185
186
187
188
189
template <typename Seq>
struct sequence_merge<Seq>
{
    using type = Seq;
};

Chao Liu's avatar
Chao Liu committed
190
// generate sequence
Chao Liu's avatar
Chao Liu committed
191
template <index_t IBegin, index_t NRemain, typename F>
Chao Liu's avatar
Chao Liu committed
192
struct sequence_gen_impl
Chao Liu's avatar
Chao Liu committed
193
{
Chao Liu's avatar
Chao Liu committed
194
195
196
    static constexpr index_t NRemainLeft  = NRemain / 2;
    static constexpr index_t NRemainRight = NRemain - NRemainLeft;
    static constexpr index_t IMiddle      = IBegin + NRemainLeft;
Chao Liu's avatar
Chao Liu committed
197

Chao Liu's avatar
Chao Liu committed
198
199
200
    using type =
        typename sequence_merge<typename sequence_gen_impl<IBegin, NRemainLeft, F>::type,
                                typename sequence_gen_impl<IMiddle, NRemainRight, F>::type>::type;
Chao Liu's avatar
Chao Liu committed
201
202
};

Chao Liu's avatar
Chao Liu committed
203
template <index_t I, typename F>
Chao Liu's avatar
Chao Liu committed
204
struct sequence_gen_impl<I, 1, F>
Chao Liu's avatar
Chao Liu committed
205
{
Chao Liu's avatar
Chao Liu committed
206
207
    static constexpr index_t Is = F{}(Number<I>{});
    using type                  = Sequence<Is>;
Chao Liu's avatar
Chao Liu committed
208
};
Chao Liu's avatar
Chao Liu committed
209

Chao Liu's avatar
Chao Liu committed
210
template <index_t I, typename F>
Chao Liu's avatar
Chao Liu committed
211
struct sequence_gen_impl<I, 0, F>
Chao Liu's avatar
Chao Liu committed
212
{
Chao Liu's avatar
Chao Liu committed
213
    using type = Sequence<>;
Chao Liu's avatar
Chao Liu committed
214
215
};

Chao Liu's avatar
Chao Liu committed
216
template <index_t NSize, typename F>
Chao Liu's avatar
Chao Liu committed
217
218
219
220
221
222
struct sequence_gen
{
    using type = typename sequence_gen_impl<0, NSize, F>::type;
};

// arithmetic sequence
Chao Liu's avatar
Chao Liu committed
223
template <index_t IBegin, index_t IEnd, index_t Increment>
224
struct arithmetic_sequence_gen
Chao Liu's avatar
Chao Liu committed
225
{
Chao Liu's avatar
Chao Liu committed
226
227
228
229
230
231
232
233
234
    struct F
    {
        __host__ __device__ constexpr index_t operator()(index_t i) const
        {
            return i * Increment + IBegin;
        }
    };

    using type = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type;
Chao Liu's avatar
Chao Liu committed
235
236
237
238
239
240
};

// uniform sequence
template <index_t NSize, index_t I>
struct uniform_sequence_gen
{
Chao Liu's avatar
Chao Liu committed
241
    struct F
Chao Liu's avatar
Chao Liu committed
242
243
244
245
    {
        __host__ __device__ constexpr index_t operator()(index_t) const { return I; }
    };

Chao Liu's avatar
Chao Liu committed
246
    using type = typename sequence_gen<NSize, F>::type;
Chao Liu's avatar
Chao Liu committed
247
248
249
};

// reverse inclusive scan (with init) sequence
Chao Liu's avatar
Chao Liu committed
250
template <typename, typename, index_t>
Chao Liu's avatar
Chao Liu committed
251
struct sequence_reverse_inclusive_scan;
Chao Liu's avatar
Chao Liu committed
252

Chao Liu's avatar
Chao Liu committed
253
template <index_t I, index_t... Is, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
254
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
255
{
Chao Liu's avatar
Chao Liu committed
256
    using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::type;
Chao Liu's avatar
Chao Liu committed
257
258
259

    static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());

Chao Liu's avatar
Chao Liu committed
260
    using type = typename sequence_merge<Sequence<new_reduce>, old_scan>::type;
Chao Liu's avatar
Chao Liu committed
261
262
};

Chao Liu's avatar
Chao Liu committed
263
template <index_t I, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
264
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
265
{
Chao Liu's avatar
Chao Liu committed
266
    using type = Sequence<Reduce{}(I, Init)>;
Chao Liu's avatar
Chao Liu committed
267
268
};

Chao Liu's avatar
Chao Liu committed
269
template <typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
270
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
271
{
Chao Liu's avatar
Chao Liu committed
272
    using type = Sequence<>;
Chao Liu's avatar
Chao Liu committed
273
274
};

Chao Liu's avatar
Chao Liu committed
275
// split sequence
Chao Liu's avatar
Chao Liu committed
276
template <typename Seq, index_t I>
Chao Liu's avatar
Chao Liu committed
277
278
struct sequence_split
{
Chao Liu's avatar
Chao Liu committed
279
    static constexpr index_t NSize = Seq{}.Size();
Chao Liu's avatar
Chao Liu committed
280

Chao Liu's avatar
Chao Liu committed
281
282
    using range0 = typename arithmetic_sequence_gen<0, I, 1>::type;
    using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type;
Chao Liu's avatar
Chao Liu committed
283

Chao Liu's avatar
Chao Liu committed
284
285
    using SeqType0 = decltype(Seq::Extract(range0{}));
    using SeqType1 = decltype(Seq::Extract(range1{}));
Chao Liu's avatar
Chao Liu committed
286
287
};

Chao Liu's avatar
Chao Liu committed
288
// reverse sequence
Chao Liu's avatar
Chao Liu committed
289
template <typename Seq>
Chao Liu's avatar
Chao Liu committed
290
291
struct sequence_reverse
{
Chao Liu's avatar
Chao Liu committed
292
    static constexpr index_t NSize = Seq{}.Size();
Chao Liu's avatar
Chao Liu committed
293
294

    using seq_split = sequence_split<Seq, NSize / 2>;
Chao Liu's avatar
Chao Liu committed
295
296
297
    using type      = typename sequence_merge<
        typename sequence_reverse<typename seq_split::SeqType1>::type,
        typename sequence_reverse<typename seq_split::SeqType0>::type>::type;
Chao Liu's avatar
Chao Liu committed
298
299
300
301
302
};

template <index_t I>
struct sequence_reverse<Sequence<I>>
{
Chao Liu's avatar
Chao Liu committed
303
    using type = Sequence<I>;
Chao Liu's avatar
Chao Liu committed
304
305
306
307
308
};

template <index_t I0, index_t I1>
struct sequence_reverse<Sequence<I0, I1>>
{
Chao Liu's avatar
Chao Liu committed
309
    using type = Sequence<I1, I0>;
Chao Liu's avatar
Chao Liu committed
310
};
Chao Liu's avatar
Chao Liu committed
311

Chao Liu's avatar
Chao Liu committed
312
template <typename Seq, typename Compare>
Chao Liu's avatar
Chao Liu committed
313
314
struct sequence_sort
{
Chao Liu's avatar
Chao Liu committed
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
    template <typename SeqLeft, typename SeqRight, typename MergedSeq, typename Comp>
    struct sorted_sequence_merge_impl
    {
        static constexpr bool pick_left     = SeqLeft::Front() < SeqRight::Front();
        static constexpr index_t next_value = pick_left ? SeqLeft::Front() : SeqRight::Front();

        using new_merged_seq = decltype(MergedSeq::PushBack(Number<next_value>{}));

        using new_left_seq =
            typename conditional<pick_left, decltype(SeqLeft::PopFront()), SeqLeft>::type;
        using new_right_seq =
            typename conditional<pick_left, SeqRight, decltype(SeqRight::PopFront())>::type;

        using type =
            typename sorted_sequence_merge_impl<new_left_seq, new_right_seq, new_merged_seq, Comp>::
                type;
    };

    template <typename SeqLeft, typename MergedSeq, typename Comp>
    struct sorted_sequence_merge_impl<SeqLeft, Sequence<>, MergedSeq, Comp>
    {
        using type = typename sequence_merge<MergedSeq, SeqLeft>::type;
    };

    template <typename SeqRight, typename MergedSeq, typename Comp>
    struct sorted_sequence_merge_impl<Sequence<>, SeqRight, MergedSeq, Comp>
    {
        using type = typename sequence_merge<MergedSeq, SeqRight>::type;
    };

    template <typename Seq0, typename Seq1, typename Comp>
    struct sorted_sequence_merge
    {
        using type = typename sorted_sequence_merge_impl<Seq0, Seq1, Sequence<>, Comp>::type;
    };

    using split          = sequence_split<Seq, Seq::Size() / 2>;
    using unsorted_left  = typename split::SeqType0;
    using unsorted_right = typename split::SeqType1;

    using sorted_left  = typename sequence_sort<unsorted_left, Compare>::type;
    using sorted_right = typename sequence_sort<unsorted_right, Compare>::type;

    using type = typename sorted_sequence_merge<sorted_left, sorted_right, Compare>::type;
};

template <index_t X, index_t Y, typename Compare>
struct sequence_sort<Sequence<X, Y>, Compare>
{
    static constexpr bool x_first = Compare{}(X, Y);

    using type = typename conditional<x_first, Sequence<X, Y>, Sequence<Y, X>>::type;
};

template <index_t X, typename Compare>
struct sequence_sort<Sequence<X>, Compare>
{
    using type = Sequence<X>;
Chao Liu's avatar
Chao Liu committed
373
374
};

Chao Liu's avatar
Chao Liu committed
375
template <typename Seq, typename Less, typename Equal>
Chao Liu's avatar
Chao Liu committed
376
377
struct sequence_unique_sort
{
Chao Liu's avatar
Chao Liu committed
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
    template <typename WorkInputSeq, typename WorkOutputSeq, typename Eq>
    struct sorted_sequence_uniquify_impl
    {
        static constexpr index_t new_value = WorkInputSeq::Front();
        using new_work_input_seq           = decltype(WorkInputSeq::PopFront());

        using new_working_output_seq =
            typename conditional<new_value == WorkOutputSeq::Back(),
                                 WorkOutputSeq,
                                 decltype(WorkOutputSeq::PopBack(Number<new_value>{}))>::type;
    };

    template <typename WorkInputSeq, typename Eq>
    struct sorted_sequence_uniquify_impl<WorkInputSeq, Sequence<>, Eq>
    {
        using type = WorkInputSeq;
    };

    template <typename SortedSeq, typename Eq>
    struct sorted_sequence_uniquify
    {
        using type = typename sorted_sequence_uniquify_impl<SortedSeq, Sequence<>, Eq>::type;
    };

    using sorted_seq = typename sequence_sort<Seq, Less>::type;

    using type = typename sorted_sequence_uniquify<sorted_seq, Equal>::type;
Chao Liu's avatar
Chao Liu committed
405
406
};

Chao Liu's avatar
Chao Liu committed
407
template <typename Seq>
Chao Liu's avatar
Chao Liu committed
408
409
struct is_valid_sequence_map
{
Chao Liu's avatar
Chao Liu committed
410
    // not implemented yet, always return true
Chao Liu's avatar
Chao Liu committed
411
    static constexpr integral_constant<bool, true> value = integral_constant<bool, true>{};
Chao Liu's avatar
Chao Liu committed
412
413
414

    // TODO: add proper check for is_valid, something like:
    // static constexpr bool value =
Chao Liu's avatar
Chao Liu committed
415
    //     is_same<typename arithmetic_sequence_gen<0, Seq::Size(), 1>::type,
Chao Liu's avatar
Chao Liu committed
416
    //             typename sequence_sort<Seq>::SortedSeqType>{};
Chao Liu's avatar
Chao Liu committed
417
};
418

Chao Liu's avatar
Chao Liu committed
419
template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
Chao Liu's avatar
Chao Liu committed
420
421
422
struct sequence_map_inverse_impl
{
    private:
Chao Liu's avatar
Chao Liu committed
423
    static constexpr auto new_y2x = WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
Chao Liu's avatar
Chao Liu committed
424
425
426
427
428
429

    public:
    using type =
        typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::type;
};

Chao Liu's avatar
Chao Liu committed
430
template <typename X2Y, typename WorkingY2X, index_t XBegin>
Chao Liu's avatar
Chao Liu committed
431
432
433
434
435
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
{
    using type = WorkingY2X;
};

Chao Liu's avatar
Chao Liu committed
436
template <typename X2Y>
Chao Liu's avatar
Chao Liu committed
437
438
439
440
struct sequence_map_inverse
{
    using type =
        typename sequence_map_inverse_impl<X2Y,
Chao Liu's avatar
Chao Liu committed
441
                                           typename uniform_sequence_gen<X2Y::Size(), 0>::type,
Chao Liu's avatar
Chao Liu committed
442
                                           0,
Chao Liu's avatar
Chao Liu committed
443
                                           X2Y::Size()>::type;
Chao Liu's avatar
Chao Liu committed
444
445
};

Chao Liu's avatar
Chao Liu committed
446
template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
447
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
448
449
450
451
452
453
454
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs + Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
455
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
456
457
458
459
460
461
462
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs - Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
463
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
464
465
466
467
468
469
470
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs * Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
471
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
472
473
474
475
476
477
478
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs / Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
479
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
480
481
482
483
484
485
486
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs % Ys)...>{};
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
487
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
488
{
Chao Liu's avatar
Chao Liu committed
489
    return Sequence<(Xs + Y)...>{};
Chao Liu's avatar
Chao Liu committed
490
491
492
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
493
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
494
{
Chao Liu's avatar
Chao Liu committed
495
    return Sequence<(Xs - Y)...>{};
Chao Liu's avatar
Chao Liu committed
496
497
498
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
499
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
500
{
Chao Liu's avatar
Chao Liu committed
501
    return Sequence<(Xs * Y)...>{};
Chao Liu's avatar
Chao Liu committed
502
503
504
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
505
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
506
{
Chao Liu's avatar
Chao Liu committed
507
    return Sequence<(Xs / Y)...>{};
Chao Liu's avatar
Chao Liu committed
508
509
510
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
511
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
512
{
Chao Liu's avatar
Chao Liu committed
513
    return Sequence<(Xs % Y)...>{};
Chao Liu's avatar
Chao Liu committed
514
515
}

Chao Liu's avatar
Chao Liu committed
516
517
// Constant on the left: result value k is Y + Vs[k].
template <index_t Y, index_t... Vs>
__host__ __device__ constexpr auto operator+(Number<Y>, Sequence<Vs...>)
{
    using sum_type = Sequence<(Y + Vs)...>;
    return sum_type{};
}

Chao Liu's avatar
Chao Liu committed
522
523
// Constant on the left: result value k is Y - Xs[k].
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
{
    return Sequence<(Y - Xs)...>{};
}

Chao Liu's avatar
Chao Liu committed
530
531
// Constant on the left: result value k is Y * Vs[k].
template <index_t Y, index_t... Vs>
__host__ __device__ constexpr auto operator*(Number<Y>, Sequence<Vs...>)
{
    using product_type = Sequence<(Y * Vs)...>;
    return product_type{};
}

Chao Liu's avatar
Chao Liu committed
536
537
// Constant on the left: result value k is Y / Vs[k].
template <index_t Y, index_t... Vs>
__host__ __device__ constexpr auto operator/(Number<Y>, Sequence<Vs...>)
{
    using quotient_type = Sequence<(Y / Vs)...>;
    return quotient_type{};
}

Chao Liu's avatar
Chao Liu committed
542
543
// Constant on the left: result value k is Y % Vs[k].
template <index_t Y, index_t... Vs>
__host__ __device__ constexpr auto operator%(Number<Y>, Sequence<Vs...>)
{
    using remainder_type = Sequence<(Y % Vs)...>;
    return remainder_type{};
}

548
549
550
551
552
553
// Drops the first value. Only non-empty Sequences match the parameter pattern.
template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
{
    using remaining = Sequence<Is...>;
    return remaining{};
}

Chao Liu's avatar
Chao Liu committed
554
template <typename Seq>
Chao Liu's avatar
Chao Liu committed
555
__host__ __device__ constexpr auto sequence_pop_back(Seq)
556
{
Chao Liu's avatar
Chao Liu committed
557
    static_assert(Seq::Size() > 0, "wrong! cannot pop an empty Sequence!");
558
    return sequence_pop_front(Seq::Reverse()).Reverse();
559
}
560

Chao Liu's avatar
Chao Liu committed
561
template <typename F, index_t... Xs>
Chao Liu's avatar
Chao Liu committed
562
563
564
565
566
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
{
    return Sequence<f(Xs)...>{};
}

Chao Liu's avatar
Chao Liu committed
567
568
569
570
571
572
573
// Function-style front end of sequence_merge: concatenates the given Sequences.
template <typename... Seqs>
__host__ __device__ constexpr auto merge_sequences(Seqs...)
{
    using merged = typename sequence_merge<Seqs...>::type;
    return merged{};
}

template <typename F, index_t... Xs, index_t... Ys>
574
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
575
{
576
    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize, "Dim not the same");
577
578
579
580

    return Sequence<f(Xs, Ys)...>{};
}

Chao Liu's avatar
Chao Liu committed
581
template <typename F, index_t... Xs, index_t... Ys, index_t... Zs>
582
583
584
585
586
587
588
589
590
591
__host__ __device__ constexpr auto
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
{
    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize &&
                      Sequence<Xs...>::mSize == Sequence<Zs...>::mSize,
                  "Dim not the same");

    return Sequence<f(Xs, Ys, Zs)...>{};
}

Chao Liu's avatar
Chao Liu committed
592
template <typename Seq, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
593
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number<Init>)
594
{
Chao Liu's avatar
Chao Liu committed
595
    return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::type{};
596
597
}

Chao Liu's avatar
Chao Liu committed
598
template <typename Seq, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
599
__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<Init>)
600
{
Chao Liu's avatar
Chao Liu committed
601
    return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
602
}
603

Chao Liu's avatar
Chao Liu committed
604
template <typename Seq, typename Reduce>
Chao Liu's avatar
Chao Liu committed
605
606
607
608
609
610
611
612
613
614
struct lambda_accumulate_on_sequence
{
    const Reduce& f;
    index_t& result;

    __host__ __device__ constexpr lambda_accumulate_on_sequence(const Reduce& f_, index_t& result_)
        : f(f_), result(result_)
    {
    }

Chao Liu's avatar
Chao Liu committed
615
    template <typename IDim>
Chao Liu's avatar
Chao Liu committed
616
617
    __host__ __device__ constexpr index_t operator()(IDim) const
    {
Chao Liu's avatar
Chao Liu committed
618
        return result = f(result, Seq::At(IDim{}));
Chao Liu's avatar
Chao Liu committed
619
620
621
    }
};

Chao Liu's avatar
Chao Liu committed
622
template <typename Seq, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
623
624
625
626
627
628
629
630
631
632
__host__ __device__ constexpr index_t
accumulate_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
{
    index_t result = Init;

    static_for<0, Seq::mSize, 1>{}(lambda_accumulate_on_sequence<Seq, Reduce>(f, result));

    return result;
}

633
634
} // namespace ck
#endif