sequence.hpp 28.3 KB
Newer Older
1
#pragma once
2

Chao Liu's avatar
Chao Liu committed
3
#include "integral_constant.hpp"
Chao Liu's avatar
Chao Liu committed
4
#include "type.hpp"
Chao Liu's avatar
Chao Liu committed
5
#include "functional.hpp"
Chao Liu's avatar
Chao Liu committed
6
#include "math.hpp"
7

8
9
namespace ck {

Chao Liu's avatar
Chao Liu committed
10
11
12
// Forward declarations used by Sequence's member functions below.

// compile-time loop helper (only declared in this header)
template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for;

// compile-time list of index_t values
template <index_t... Is>
struct Sequence;

// splits Seq at position I into left_type / right_type
template <typename Seq, index_t I>
struct sequence_split;

// ::type is Seq with its elements in reverse order
template <typename Seq>
struct sequence_reverse;

// ::type is the inverse of a permutation map
template <typename SeqMap>
struct sequence_map_inverse;

// ::value tells whether SeqMap is a permutation of 0..Size-1
template <typename SeqMap>
struct is_valid_sequence_map;

// drops the first element; only declared here
template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>);

// drops the last element; only declared here
template <typename Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq);

Chao Liu's avatar
Chao Liu committed
34
template <index_t... Is>
35
36
struct Sequence
{
Chao Liu's avatar
Chao Liu committed
37
38
    using Type      = Sequence;
    using data_type = index_t;
39

40
    static constexpr index_t mSize = sizeof...(Is);
41

Chao Liu's avatar
Chao Liu committed
42
    __host__ __device__ static constexpr auto Size() { return Number<mSize>{}; }
43

Chao Liu's avatar
Chao Liu committed
44
45
46
    __host__ __device__ static constexpr auto GetSize() { return Size(); }

    __host__ __device__ static constexpr index_t At(index_t I)
47
    {
Chao Liu's avatar
Chao Liu committed
48
49
50
51
52
53
        // the last dummy element is to prevent compiler complain about empty array, when mSize = 0
        const index_t mData[mSize + 1] = {Is..., 0};
        return mData[I];
    }

    template <index_t I>
Chao Liu's avatar
Chao Liu committed
54
    __host__ __device__ static constexpr auto At(Number<I>)
Chao Liu's avatar
Chao Liu committed
55
    {
Chao Liu's avatar
Chao Liu committed
56
57
        static_assert(I < mSize, "wrong! I too large");

Chao Liu's avatar
Chao Liu committed
58
        return Number<At(I)>{};
Chao Liu's avatar
Chao Liu committed
59
60
    }

Chao Liu's avatar
Chao Liu committed
61
    template <index_t I>
Chao Liu's avatar
Chao Liu committed
62
    __host__ __device__ static constexpr auto Get(Number<I>)
Chao Liu's avatar
Chao Liu committed
63
    {
Chao Liu's avatar
Chao Liu committed
64
        return At(Number<I>{});
65
66
    }

Chao Liu's avatar
Chao Liu committed
67
68
69
70
71
    template <typename I>
    __host__ __device__ constexpr auto operator[](I i) const
    {
        return At(i);
    }
Chao Liu's avatar
Chao Liu committed
72

73
    template <index_t... IRs>
74
    __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
75
    {
Chao Liu's avatar
Chao Liu committed
76
        static_assert(sizeof...(Is) == sizeof...(IRs),
Chao Liu's avatar
Chao Liu committed
77
                      "wrong! reorder map should have the same size as Sequence to be rerodered");
Chao Liu's avatar
Chao Liu committed
78

Chao Liu's avatar
Chao Liu committed
79
80
        static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");

Chao Liu's avatar
Chao Liu committed
81
        return Sequence<Type::At(Number<IRs>{})...>{};
82
83
    }

Chao Liu's avatar
Chao Liu committed
84
    // MapOld2New is Sequence<...>
Chao Liu's avatar
Chao Liu committed
85
    template <typename MapOld2New>
Chao Liu's avatar
Chao Liu committed
86
87
    __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
    {
Chao Liu's avatar
Chao Liu committed
88
        static_assert(MapOld2New::Size() == Size(),
Chao Liu's avatar
Chao Liu committed
89
90
91
92
93
94
95
                      "wrong! reorder map should have the same size as Sequence to be rerodered");

        static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");

        return ReorderGivenNew2Old(typename sequence_map_inverse<MapOld2New>::type{});
    }

96
97
98
99
    __host__ __device__ static constexpr auto Reverse()
    {
        return typename sequence_reverse<Type>::type{};
    }
Chao Liu's avatar
Chao Liu committed
100

Chao Liu's avatar
Chao Liu committed
101
    __host__ __device__ static constexpr auto Front()
102
    {
Chao Liu's avatar
Chao Liu committed
103
        static_assert(mSize > 0, "wrong!");
Chao Liu's avatar
Chao Liu committed
104
        return At(Number<0>{});
105
    }
106

Chao Liu's avatar
Chao Liu committed
107
    __host__ __device__ static constexpr auto Back()
108
    {
Chao Liu's avatar
Chao Liu committed
109
        static_assert(mSize > 0, "wrong!");
Chao Liu's avatar
Chao Liu committed
110
        return At(Number<mSize - 1>{});
111
    }
112

113
    __host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); }
Chao Liu's avatar
Chao Liu committed
114

115
    __host__ __device__ static constexpr auto PopBack() { return sequence_pop_back(Type{}); }
Chao Liu's avatar
Chao Liu committed
116
117
118

    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushFront(Sequence<Xs...>)
119
    {
Chao Liu's avatar
Chao Liu committed
120
        return Sequence<Xs..., Is...>{};
121
122
    }

Chao Liu's avatar
Chao Liu committed
123
124
    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushFront(Number<Xs>...)
125
    {
Chao Liu's avatar
Chao Liu committed
126
        return Sequence<Xs..., Is...>{};
127
128
    }

Chao Liu's avatar
Chao Liu committed
129
130
131
132
133
    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushBack(Sequence<Xs...>)
    {
        return Sequence<Is..., Xs...>{};
    }
134

Chao Liu's avatar
Chao Liu committed
135
    template <index_t... Xs>
Chao Liu's avatar
Chao Liu committed
136
    __host__ __device__ static constexpr auto PushBack(Number<Xs>...)
137
    {
Chao Liu's avatar
Chao Liu committed
138
139
        return Sequence<Is..., Xs...>{};
    }
Chao Liu's avatar
Chao Liu committed
140

Chao Liu's avatar
Chao Liu committed
141
    template <index_t... Ns>
142
    __host__ __device__ static constexpr auto Extract(Number<Ns>...)
Chao Liu's avatar
Chao Liu committed
143
    {
Chao Liu's avatar
Chao Liu committed
144
        return Sequence<Type::At(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
145
    }
Chao Liu's avatar
Chao Liu committed
146

Chao Liu's avatar
Chao Liu committed
147
    template <index_t... Ns>
148
    __host__ __device__ static constexpr auto Extract(Sequence<Ns...>)
Chao Liu's avatar
Chao Liu committed
149
    {
Chao Liu's avatar
Chao Liu committed
150
        return Sequence<Type::At(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
151
    }
152
153

    template <index_t I, index_t X>
154
155
    __host__ __device__ static constexpr auto Modify(Number<I>, Number<X>)
    {
Chao Liu's avatar
Chao Liu committed
156
        static_assert(I < Size(), "wrong!");
157
158

        using seq_split          = sequence_split<Type, I>;
Chao Liu's avatar
Chao Liu committed
159
160
        constexpr auto seq_left  = typename seq_split::left_type{};
        constexpr auto seq_right = typename seq_split::right_type{}.PopFront();
161
162
163

        return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
    }
Chao Liu's avatar
Chao Liu committed
164

Chao Liu's avatar
Chao Liu committed
165
    template <typename F>
Chao Liu's avatar
Chao Liu committed
166
167
168
169
    __host__ __device__ static constexpr auto Transform(F f)
    {
        return Sequence<f(Is)...>{};
    }
Chao Liu's avatar
Chao Liu committed
170
171
172
173
174
175
176
177

    __host__ __device__ static void Print()
    {
        printf("{");
        printf("size %d, ", index_t{Size()});
        static_for<0, Size(), 1>{}([&](auto i) { printf("%d ", At(i).value); });
        printf("}");
    }
178
179
};

Chao Liu's avatar
Chao Liu committed
180
// merge sequence
Chao Liu's avatar
Chao Liu committed
181
182
183
184
185
template <typename Seq, typename... Seqs>
struct sequence_merge
{
    using type = typename sequence_merge<Seq, typename sequence_merge<Seqs...>::type>::type;
};
Chao Liu's avatar
Chao Liu committed
186

Chao Liu's avatar
Chao Liu committed
187
188
189
template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
{
Chao Liu's avatar
Chao Liu committed
190
    using type = Sequence<Xs..., Ys...>;
Chao Liu's avatar
Chao Liu committed
191
};
Chao Liu's avatar
Chao Liu committed
192

Chao Liu's avatar
Chao Liu committed
193
194
195
196
197
198
template <typename Seq>
struct sequence_merge<Seq>
{
    using type = Seq;
};

Chao Liu's avatar
Chao Liu committed
199
// generate sequence
Chao Liu's avatar
Chao Liu committed
200
201
template <index_t NSize, typename F>
struct sequence_gen
Chao Liu's avatar
Chao Liu committed
202
{
Chao Liu's avatar
Chao Liu committed
203
204
205
206
207
208
    template <index_t IBegin, index_t NRemain, typename G>
    struct sequence_gen_impl
    {
        static constexpr index_t NRemainLeft  = NRemain / 2;
        static constexpr index_t NRemainRight = NRemain - NRemainLeft;
        static constexpr index_t IMiddle      = IBegin + NRemainLeft;
Chao Liu's avatar
Chao Liu committed
209

Chao Liu's avatar
Chao Liu committed
210
211
212
213
        using type = typename sequence_merge<
            typename sequence_gen_impl<IBegin, NRemainLeft, G>::type,
            typename sequence_gen_impl<IMiddle, NRemainRight, G>::type>::type;
    };
Chao Liu's avatar
Chao Liu committed
214

Chao Liu's avatar
Chao Liu committed
215
216
217
218
219
220
    template <index_t I, typename G>
    struct sequence_gen_impl<I, 1, G>
    {
        static constexpr index_t Is = G{}(Number<I>{});
        using type                  = Sequence<Is>;
    };
Chao Liu's avatar
Chao Liu committed
221

Chao Liu's avatar
Chao Liu committed
222
223
224
225
226
    template <index_t I, typename G>
    struct sequence_gen_impl<I, 0, G>
    {
        using type = Sequence<>;
    };
Chao Liu's avatar
Chao Liu committed
227

Chao Liu's avatar
Chao Liu committed
228
229
230
231
    using type = typename sequence_gen_impl<0, NSize, F>::type;
};

// arithmetic sequence
Chao Liu's avatar
Chao Liu committed
232
template <index_t IBegin, index_t IEnd, index_t Increment>
233
struct arithmetic_sequence_gen
Chao Liu's avatar
Chao Liu committed
234
{
Chao Liu's avatar
Chao Liu committed
235
236
237
238
239
240
241
242
    struct F
    {
        __host__ __device__ constexpr index_t operator()(index_t i) const
        {
            return i * Increment + IBegin;
        }
    };

243
244
245
246
247
248
249
    using type0 = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type;
    using type1 = Sequence<>;

    static constexpr bool kHasContent =
        (Increment > 0 && IBegin < IEnd) || (Increment < 0 && IBegin > IEnd);

    using type = typename conditional<kHasContent, type0, type1>::type;
Chao Liu's avatar
Chao Liu committed
250
251
252
253
254
255
};

// uniform sequence
template <index_t NSize, index_t I>
struct uniform_sequence_gen
{
Chao Liu's avatar
Chao Liu committed
256
    struct F
Chao Liu's avatar
Chao Liu committed
257
258
259
260
    {
        __host__ __device__ constexpr index_t operator()(index_t) const { return I; }
    };

Chao Liu's avatar
Chao Liu committed
261
    using type = typename sequence_gen<NSize, F>::type;
Chao Liu's avatar
Chao Liu committed
262
263
264
};

// reverse inclusive scan (with init) sequence
Chao Liu's avatar
Chao Liu committed
265
template <typename, typename, index_t>
Chao Liu's avatar
Chao Liu committed
266
struct sequence_reverse_inclusive_scan;
Chao Liu's avatar
Chao Liu committed
267

Chao Liu's avatar
Chao Liu committed
268
template <index_t I, index_t... Is, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
269
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
270
{
Chao Liu's avatar
Chao Liu committed
271
    using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::type;
Chao Liu's avatar
Chao Liu committed
272
273
274

    static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());

Chao Liu's avatar
Chao Liu committed
275
    using type = typename sequence_merge<Sequence<new_reduce>, old_scan>::type;
Chao Liu's avatar
Chao Liu committed
276
277
};

Chao Liu's avatar
Chao Liu committed
278
template <index_t I, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
279
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
280
{
Chao Liu's avatar
Chao Liu committed
281
    using type = Sequence<Reduce{}(I, Init)>;
Chao Liu's avatar
Chao Liu committed
282
283
};

Chao Liu's avatar
Chao Liu committed
284
template <typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
285
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
286
{
Chao Liu's avatar
Chao Liu committed
287
    using type = Sequence<>;
Chao Liu's avatar
Chao Liu committed
288
289
};

Chao Liu's avatar
Chao Liu committed
290
// split sequence
Chao Liu's avatar
Chao Liu committed
291
template <typename Seq, index_t I>
Chao Liu's avatar
Chao Liu committed
292
293
struct sequence_split
{
Chao Liu's avatar
Chao Liu committed
294
    static constexpr index_t NSize = Seq{}.Size();
Chao Liu's avatar
Chao Liu committed
295

Chao Liu's avatar
Chao Liu committed
296
297
    using range0 = typename arithmetic_sequence_gen<0, I, 1>::type;
    using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type;
Chao Liu's avatar
Chao Liu committed
298

Chao Liu's avatar
Chao Liu committed
299
300
    using left_type  = decltype(Seq::Extract(range0{}));
    using right_type = decltype(Seq::Extract(range1{}));
Chao Liu's avatar
Chao Liu committed
301
302
};

Chao Liu's avatar
Chao Liu committed
303
// reverse sequence
Chao Liu's avatar
Chao Liu committed
304
template <typename Seq>
Chao Liu's avatar
Chao Liu committed
305
306
struct sequence_reverse
{
Chao Liu's avatar
Chao Liu committed
307
    static constexpr index_t NSize = Seq{}.Size();
Chao Liu's avatar
Chao Liu committed
308
309

    using seq_split = sequence_split<Seq, NSize / 2>;
Chao Liu's avatar
Chao Liu committed
310
    using type      = typename sequence_merge<
Chao Liu's avatar
Chao Liu committed
311
312
        typename sequence_reverse<typename seq_split::right_type>::type,
        typename sequence_reverse<typename seq_split::left_type>::type>::type;
Chao Liu's avatar
Chao Liu committed
313
314
315
316
317
};

template <index_t I>
struct sequence_reverse<Sequence<I>>
{
Chao Liu's avatar
Chao Liu committed
318
    using type = Sequence<I>;
Chao Liu's avatar
Chao Liu committed
319
320
321
322
323
};

template <index_t I0, index_t I1>
struct sequence_reverse<Sequence<I0, I1>>
{
Chao Liu's avatar
Chao Liu committed
324
    using type = Sequence<I1, I0>;
Chao Liu's avatar
Chao Liu committed
325
};
Chao Liu's avatar
Chao Liu committed
326

Chao Liu's avatar
Chao Liu committed
327
#if 1
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
// Element-wise reduction of one or more equally sized Sequences with Reduce.
// General case: reduce the trailing Sequences first, then combine with the first one.
template <typename Reduce, typename Seq, typename... Seqs>
struct sequence_reduce
{
    using reduced_tail = typename sequence_reduce<Reduce, Seqs...>::type;
    using type         = typename sequence_reduce<Reduce, Seq, reduced_tail>::type;
};

// A single Sequence reduces to itself.
template <typename Reduce, typename Seq>
struct sequence_reduce<Reduce, Seq>
{
    using type = Seq;
};

// Two Sequences: apply Reduce to each pair of elements at the same position.
template <typename Reduce, index_t... Xs, index_t... Ys>
struct sequence_reduce<Reduce, Sequence<Xs...>, Sequence<Ys...>>
{
    using type = Sequence<Reduce{}(Xs, Ys)...>;
};
#endif

Chao Liu's avatar
Chao Liu committed
349
350
template <typename Values, typename Ids, typename Compare>
struct sequence_sort_impl
Chao Liu's avatar
Chao Liu committed
351
{
Chao Liu's avatar
Chao Liu committed
352
353
354
355
356
357
358
    template <typename LeftValues,
              typename LeftIds,
              typename RightValues,
              typename RightIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
Chao Liu's avatar
Chao Liu committed
359
360
    struct sorted_sequence_merge_impl
    {
Chao Liu's avatar
Chao Liu committed
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
        static constexpr bool choose_left = LeftValues::Front() < RightValues::Front();

        static constexpr index_t chosen_value =
            choose_left ? LeftValues::Front() : RightValues::Front();
        static constexpr index_t chosen_id = choose_left ? LeftIds::Front() : RightIds::Front();

        using new_merged_values = decltype(MergedValues::PushBack(Number<chosen_value>{}));
        using new_merged_ids    = decltype(MergedIds::PushBack(Number<chosen_id>{}));

        using new_left_values =
            typename conditional<choose_left, decltype(LeftValues::PopFront()), LeftValues>::type;
        using new_left_ids =
            typename conditional<choose_left, decltype(LeftIds::PopFront()), LeftIds>::type;

        using new_right_values =
            typename conditional<choose_left, RightValues, decltype(RightValues::PopFront())>::type;
        using new_right_ids =
            typename conditional<choose_left, RightIds, decltype(RightIds::PopFront())>::type;

        using merge = sorted_sequence_merge_impl<new_left_values,
                                                 new_left_ids,
                                                 new_right_values,
                                                 new_right_ids,
                                                 new_merged_values,
                                                 new_merged_ids,
                                                 Comp>;
        // this is output
        using merged_values = typename merge::merged_values;
        using merged_ids    = typename merge::merged_ids;
Chao Liu's avatar
Chao Liu committed
390
391
    };

Chao Liu's avatar
Chao Liu committed
392
393
394
395
396
397
398
399
400
401
402
403
    template <typename LeftValues,
              typename LeftIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
    struct sorted_sequence_merge_impl<LeftValues,
                                      LeftIds,
                                      Sequence<>,
                                      Sequence<>,
                                      MergedValues,
                                      MergedIds,
                                      Comp>
Chao Liu's avatar
Chao Liu committed
404
    {
Chao Liu's avatar
Chao Liu committed
405
406
        using merged_values = typename sequence_merge<MergedValues, LeftValues>::type;
        using merged_ids    = typename sequence_merge<MergedIds, LeftIds>::type;
Chao Liu's avatar
Chao Liu committed
407
408
    };

Chao Liu's avatar
Chao Liu committed
409
410
411
412
413
414
415
416
417
418
419
420
    template <typename RightValues,
              typename RightIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
    struct sorted_sequence_merge_impl<Sequence<>,
                                      Sequence<>,
                                      RightValues,
                                      RightIds,
                                      MergedValues,
                                      MergedIds,
                                      Comp>
Chao Liu's avatar
Chao Liu committed
421
    {
Chao Liu's avatar
Chao Liu committed
422
423
        using merged_values = typename sequence_merge<MergedValues, RightValues>::type;
        using merged_ids    = typename sequence_merge<MergedIds, RightIds>::type;
Chao Liu's avatar
Chao Liu committed
424
425
    };

Chao Liu's avatar
Chao Liu committed
426
427
428
429
430
    template <typename LeftValues,
              typename LeftIds,
              typename RightValues,
              typename RightIds,
              typename Comp>
Chao Liu's avatar
Chao Liu committed
431
432
    struct sorted_sequence_merge
    {
Chao Liu's avatar
Chao Liu committed
433
434
435
436
437
438
439
440
441
442
        using merge = sorted_sequence_merge_impl<LeftValues,
                                                 LeftIds,
                                                 RightValues,
                                                 RightIds,
                                                 Sequence<>,
                                                 Sequence<>,
                                                 Comp>;

        using merged_values = typename merge::merged_values;
        using merged_ids    = typename merge::merged_ids;
Chao Liu's avatar
Chao Liu committed
443
444
    };

Chao Liu's avatar
Chao Liu committed
445
446
447
448
    static constexpr index_t nsize = Values::Size();

    using split_unsorted_values = sequence_split<Values, nsize / 2>;
    using split_unsorted_ids    = sequence_split<Ids, nsize / 2>;
Chao Liu's avatar
Chao Liu committed
449

Chao Liu's avatar
Chao Liu committed
450
451
452
453
454
    using left_unsorted_values = typename split_unsorted_values::left_type;
    using left_unsorted_ids    = typename split_unsorted_ids::left_type;
    using left_sort          = sequence_sort_impl<left_unsorted_values, left_unsorted_ids, Compare>;
    using left_sorted_values = typename left_sort::sorted_values;
    using left_sorted_ids    = typename left_sort::sorted_ids;
Chao Liu's avatar
Chao Liu committed
455

Chao Liu's avatar
Chao Liu committed
456
457
458
459
460
461
462
463
464
465
466
467
468
469
    using right_unsorted_values = typename split_unsorted_values::right_type;
    using right_unsorted_ids    = typename split_unsorted_ids::right_type;
    using right_sort = sequence_sort_impl<right_unsorted_values, right_unsorted_ids, Compare>;
    using right_sorted_values = typename right_sort::sorted_values;
    using right_sorted_ids    = typename right_sort::sorted_ids;

    using merged_sorted = sorted_sequence_merge<left_sorted_values,
                                                left_sorted_ids,
                                                right_sorted_values,
                                                right_sorted_ids,
                                                Compare>;

    using sorted_values = typename merged_sorted::merged_values;
    using sorted_ids    = typename merged_sorted::merged_ids;
Chao Liu's avatar
Chao Liu committed
470
471
};

Chao Liu's avatar
Chao Liu committed
472
473
template <index_t ValueX, index_t ValueY, index_t IdX, index_t IdY, typename Compare>
struct sequence_sort_impl<Sequence<ValueX, ValueY>, Sequence<IdX, IdY>, Compare>
Chao Liu's avatar
Chao Liu committed
474
{
Chao Liu's avatar
Chao Liu committed
475
476
477
478
479
480
    static constexpr bool choose_x = Compare{}(ValueX, ValueY);

    using sorted_values =
        typename conditional<choose_x, Sequence<ValueX, ValueY>, Sequence<ValueY, ValueX>>::type;
    using sorted_ids = typename conditional<choose_x, Sequence<IdX, IdY>, Sequence<IdY, IdX>>::type;
};
Chao Liu's avatar
Chao Liu committed
481

Chao Liu's avatar
Chao Liu committed
482
483
484
485
486
template <index_t Value, index_t Id, typename Compare>
struct sequence_sort_impl<Sequence<Value>, Sequence<Id>, Compare>
{
    using sorted_values = Sequence<Value>;
    using sorted_ids    = Sequence<Id>;
Chao Liu's avatar
Chao Liu committed
487
488
};

489
490
491
492
493
494
495
template <typename Compare>
struct sequence_sort_impl<Sequence<>, Sequence<>, Compare>
{
    using sorted_values = Sequence<>;
    using sorted_ids    = Sequence<>;
};

Chao Liu's avatar
Chao Liu committed
496
497
template <typename Values, typename Compare>
struct sequence_sort
Chao Liu's avatar
Chao Liu committed
498
{
Chao Liu's avatar
Chao Liu committed
499
500
501
502
503
504
    using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type;
    using sort         = sequence_sort_impl<Values, unsorted_ids, Compare>;

    // this is output
    using type                = typename sort::sorted_values;
    using sorted2unsorted_map = typename sort::sorted_ids;
Chao Liu's avatar
Chao Liu committed
505
506
};

Chao Liu's avatar
Chao Liu committed
507
template <typename Values, typename Less, typename Equal>
Chao Liu's avatar
Chao Liu committed
508
509
struct sequence_unique_sort
{
Chao Liu's avatar
Chao Liu committed
510
511
512
513
514
    template <typename RemainValues,
              typename RemainIds,
              typename UniquifiedValues,
              typename UniquifiedIds,
              typename Eq>
Chao Liu's avatar
Chao Liu committed
515
516
    struct sorted_sequence_uniquify_impl
    {
Chao Liu's avatar
Chao Liu committed
517
518
519
520
521
522
523
524
525
526
527
528
        static constexpr index_t current_value = RemainValues::Front();
        static constexpr index_t current_id    = RemainIds::Front();

        static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back());

        using new_remain_values = decltype(RemainValues::PopFront());
        using new_remain_ids    = decltype(RemainIds::PopFront());

        using new_uniquified_values =
            typename conditional<is_unique_value,
                                 decltype(UniquifiedValues::PushBack(Number<current_value>{})),
                                 UniquifiedValues>::type;
Chao Liu's avatar
Chao Liu committed
529

Chao Liu's avatar
Chao Liu committed
530
531
532
533
534
535
536
537
538
539
540
541
542
543
        using new_uniquified_ids =
            typename conditional<is_unique_value,
                                 decltype(UniquifiedIds::PushBack(Number<current_id>{})),
                                 UniquifiedIds>::type;

        using uniquify = sorted_sequence_uniquify_impl<new_remain_values,
                                                       new_remain_ids,
                                                       new_uniquified_values,
                                                       new_uniquified_ids,
                                                       Eq>;

        // this is output
        using uniquified_values = typename uniquify::uniquified_values;
        using uniquified_ids    = typename uniquify::uniquified_ids;
Chao Liu's avatar
Chao Liu committed
544
545
    };

Chao Liu's avatar
Chao Liu committed
546
547
548
549
550
551
    template <typename UniquifiedValues, typename UniquifiedIds, typename Eq>
    struct sorted_sequence_uniquify_impl<Sequence<>,
                                         Sequence<>,
                                         UniquifiedValues,
                                         UniquifiedIds,
                                         Eq>
Chao Liu's avatar
Chao Liu committed
552
    {
Chao Liu's avatar
Chao Liu committed
553
554
        using uniquified_values = UniquifiedValues;
        using uniquified_ids    = UniquifiedIds;
Chao Liu's avatar
Chao Liu committed
555
556
    };

Chao Liu's avatar
Chao Liu committed
557
    template <typename SortedValues, typename SortedIds, typename Eq>
Chao Liu's avatar
Chao Liu committed
558
559
    struct sorted_sequence_uniquify
    {
Chao Liu's avatar
Chao Liu committed
560
561
562
563
564
565
566
567
        using uniquify = sorted_sequence_uniquify_impl<decltype(SortedValues::PopFront()),
                                                       decltype(SortedIds::PopFront()),
                                                       Sequence<SortedValues::Front()>,
                                                       Sequence<SortedIds::Front()>,
                                                       Eq>;

        using uniquified_values = typename uniquify::uniquified_values;
        using uniquified_ids    = typename uniquify::uniquified_ids;
Chao Liu's avatar
Chao Liu committed
568
569
    };

Chao Liu's avatar
Chao Liu committed
570
571
572
    using sort          = sequence_sort<Values, Less>;
    using sorted_values = typename sort::type;
    using sorted_ids    = typename sort::sorted2unsorted_map;
Chao Liu's avatar
Chao Liu committed
573

Chao Liu's avatar
Chao Liu committed
574
575
576
577
578
    using uniquify = sorted_sequence_uniquify<sorted_values, sorted_ids, Equal>;

    // this is output
    using type                = typename uniquify::uniquified_values;
    using sorted2unsorted_map = typename uniquify::uniquified_ids;
Chao Liu's avatar
Chao Liu committed
579
580
};

Chao Liu's avatar
Chao Liu committed
581
template <typename SeqMap>
Chao Liu's avatar
Chao Liu committed
582
583
struct is_valid_sequence_map : is_same<typename arithmetic_sequence_gen<0, SeqMap::Size(), 1>::type,
                                       typename sequence_sort<SeqMap, math::less<index_t>>::type>
Chao Liu's avatar
Chao Liu committed
584
585
{
};
586

Chao Liu's avatar
Chao Liu committed
587
588
template <typename SeqMap>
struct sequence_map_inverse
Chao Liu's avatar
Chao Liu committed
589
{
Chao Liu's avatar
Chao Liu committed
590
591
592
593
594
    template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
    struct sequence_map_inverse_impl
    {
        static constexpr auto new_y2x =
            WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
Chao Liu's avatar
Chao Liu committed
595

Chao Liu's avatar
Chao Liu committed
596
597
598
599
        using type =
            typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::
                type;
    };
Chao Liu's avatar
Chao Liu committed
600

Chao Liu's avatar
Chao Liu committed
601
602
603
604
605
    template <typename X2Y, typename WorkingY2X, index_t XBegin>
    struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
    {
        using type = WorkingY2X;
    };
Chao Liu's avatar
Chao Liu committed
606
607

    using type =
Chao Liu's avatar
Chao Liu committed
608
609
        typename sequence_map_inverse_impl<SeqMap,
                                           typename uniform_sequence_gen<SeqMap::Size(), 0>::type,
Chao Liu's avatar
Chao Liu committed
610
                                           0,
Chao Liu's avatar
Chao Liu committed
611
                                           SeqMap::Size()>::type;
Chao Liu's avatar
Chao Liu committed
612
613
};

614
615
616
617
618
619
// Two Sequences are equal when they have the same length and the same value at
// every position. Sequences of different length compare unequal instead of
// failing to compile on a mismatched pack expansion.
template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr bool operator==(Sequence<Xs...>, Sequence<Ys...>)
{
    if constexpr(sizeof...(Xs) == sizeof...(Ys))
    {
        // an empty fold over && is true, so two empty Sequences are equal
        return ((Xs == Ys) && ...);
    }
    else
    {
        return false;
    }
}

Chao Liu's avatar
Chao Liu committed
620
template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
621
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
622
623
624
625
626
627
628
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs + Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
629
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
630
631
632
633
634
635
636
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs - Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
637
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
638
639
640
641
642
643
644
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs * Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
645
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
646
647
648
649
650
651
652
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs / Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
653
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
654
655
656
657
658
659
660
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs % Ys)...>{};
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
661
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
662
{
Chao Liu's avatar
Chao Liu committed
663
    return Sequence<(Xs + Y)...>{};
Chao Liu's avatar
Chao Liu committed
664
665
666
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
667
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
668
{
Chao Liu's avatar
Chao Liu committed
669
    return Sequence<(Xs - Y)...>{};
Chao Liu's avatar
Chao Liu committed
670
671
672
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
673
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
674
{
Chao Liu's avatar
Chao Liu committed
675
    return Sequence<(Xs * Y)...>{};
Chao Liu's avatar
Chao Liu committed
676
677
678
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
679
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
680
{
Chao Liu's avatar
Chao Liu committed
681
    return Sequence<(Xs / Y)...>{};
Chao Liu's avatar
Chao Liu committed
682
683
684
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
685
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
686
{
Chao Liu's avatar
Chao Liu committed
687
    return Sequence<(Xs % Y)...>{};
Chao Liu's avatar
Chao Liu committed
688
689
}

Chao Liu's avatar
Chao Liu committed
690
691
// Adds every element to the compile-time constant Y: result[i] = Y + Xs[i].
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator+(Number<Y>, Sequence<Xs...>)
{
    using sum_t = Sequence<(Y + Xs)...>;

    return sum_t{};
}

Chao Liu's avatar
Chao Liu committed
696
697
// Subtracts every element from the compile-time constant Y: result[i] = Y - Xs[i].
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
{
    using difference_t = Sequence<(Y - Xs)...>;

    return difference_t{};
}

Chao Liu's avatar
Chao Liu committed
702
703
// Multiplies the compile-time constant Y by every element: result[i] = Y * Xs[i].
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator*(Number<Y>, Sequence<Xs...>)
{
    using product_t = Sequence<(Y * Xs)...>;

    return product_t{};
}

Chao Liu's avatar
Chao Liu committed
708
709
// Divides the compile-time constant Y by every element: result[i] = Y / Xs[i].
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator/(Number<Y>, Sequence<Xs...>)
{
    using quotient_t = Sequence<(Y / Xs)...>;

    return quotient_t{};
}

Chao Liu's avatar
Chao Liu committed
714
715
// Takes the compile-time constant Y modulo every element: result[i] = Y % Xs[i].
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator%(Number<Y>, Sequence<Xs...>)
{
    using remainder_t = Sequence<(Y % Xs)...>;

    return remainder_t{};
}

720
721
722
723
724
725
// Drops the first element: Sequence<I, Is...> -> Sequence<Is...>.
// Only Sequences with at least one element match this signature.
template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
{
    using tail_t = Sequence<Is...>;

    return tail_t{};
}

Chao Liu's avatar
Chao Liu committed
726
template <typename Seq>
Chao Liu's avatar
Chao Liu committed
727
__host__ __device__ constexpr auto sequence_pop_back(Seq)
728
{
Chao Liu's avatar
Chao Liu committed
729
    static_assert(Seq::Size() > 0, "wrong! cannot pop an empty Sequence!");
730
    return sequence_pop_front(Seq::Reverse()).Reverse();
731
}
732

Chao Liu's avatar
Chao Liu committed
733
734
735
736
737
738
// Function-style front end for sequence_merge: returns a default-constructed
// sequence_merge<Seqs...>::type for the given Sequence arguments.
template <typename... Seqs>
__host__ __device__ constexpr auto merge_sequences(Seqs...)
{
    using merged_t = typename sequence_merge<Seqs...>::type;

    return merged_t{};
}

739
740
741
742
743
744
// Applies the unary functor f to every element: result[i] = f(Xs[i]).
// f(Xs) is used as a template argument, so the call must be usable in a constant expression.
template <typename F, index_t... Xs>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
{
    using transformed_t = Sequence<f(Xs)...>;

    return transformed_t{};
}

Chao Liu's avatar
Chao Liu committed
745
template <typename F, index_t... Xs, index_t... Ys>
746
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
747
{
748
    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize, "Dim not the same");
749
750
751
752

    return Sequence<f(Xs, Ys)...>{};
}

Chao Liu's avatar
Chao Liu committed
753
template <typename F, index_t... Xs, index_t... Ys, index_t... Zs>
754
755
756
757
758
759
760
761
762
763
__host__ __device__ constexpr auto
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
{
    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize &&
                      Sequence<Xs...>::mSize == Sequence<Zs...>::mSize,
                  "Dim not the same");

    return Sequence<f(Xs, Ys, Zs)...>{};
}

Chao Liu's avatar
Chao Liu committed
764
template <typename Seq, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
765
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number<Init>)
766
{
Chao Liu's avatar
Chao Liu committed
767
    return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::type{};
768
769
}

Chao Liu's avatar
Chao Liu committed
770
771
772
773
774
775
776
// Exclusive counterpart of reverse_inclusive_scan_sequence: runs the reverse inclusive
// scan on Seq without its first element, then appends Init at the back.
template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr auto reverse_exclusive_scan_sequence(Seq, Reduce, Number<Init>)
{
    constexpr auto init_value = Number<Init>{};

    constexpr auto scanned_tail =
        reverse_inclusive_scan_sequence(Seq::PopFront(), Reduce{}, init_value);

    return scanned_tail.PushBack(init_value);
}

Chao Liu's avatar
Chao Liu committed
777
template <typename Seq, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
778
__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<Init>)
779
{
Chao Liu's avatar
Chao Liu committed
780
    return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
781
}
782

Chao Liu's avatar
Chao Liu committed
783
template <typename Seq, index_t... Is>
784
__host__ __device__ constexpr auto pick_sequence_elements_by_ids(Seq, Sequence<Is...> /* ids */)
Chao Liu's avatar
Chao Liu committed
785
786
787
788
{
    return Sequence<Seq::At(Number<Is>{})...>{};
}

Chao Liu's avatar
Chao Liu committed
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
#if 1
namespace detail {
// Recursive step: consumes the front of RemainSeq and RemainMask together.
// The front element of RemainSeq is appended to WorkSeq when the front of RemainMask is
// true; the recursion then continues on both tails.
template <typename WorkSeq, typename RemainSeq, typename RemainMask>
struct pick_sequence_elements_by_mask_impl
{
    // WorkSeq extended by the current front element (used when the mask selects it)
    using work_seq_with_front = decltype(WorkSeq::PushBack(RemainSeq::Front()));

    using new_work_seq =
        typename conditional<RemainMask::Front(), work_seq_with_front, WorkSeq>::type;

    using remain_seq_tail  = decltype(RemainSeq::PopFront());
    using remain_mask_tail = decltype(RemainMask::PopFront());

    using type = typename pick_sequence_elements_by_mask_impl<new_work_seq,
                                                              remain_seq_tail,
                                                              remain_mask_tail>::type;
};

// Termination: both inputs are exhausted, the accumulated WorkSeq is the result.
template <typename WorkSeq>
struct pick_sequence_elements_by_mask_impl<WorkSeq, Sequence<>, Sequence<>>
{
    using type = WorkSeq;
};

} // namespace detail

812
813
814
// Keeps the elements of Seq whose corresponding Mask entry is true, preserving order.
// Seq and Mask must have the same size.
template <typename Seq, typename Mask>
__host__ __device__ constexpr auto pick_sequence_elements_by_mask(Seq, Mask)
{
    static_assert(Seq::Size() == Mask::Size(), "wrong!");

    // start from an empty working Sequence and let the helper append selected elements
    using picked_t =
        typename detail::pick_sequence_elements_by_mask_impl<Sequence<>, Seq, Mask>::type;

    return picked_t{};
}

Chao Liu's avatar
Chao Liu committed
820
821
822
namespace detail {
template <typename WorkSeq, typename RemainValues, typename RemainIds>
struct modify_sequence_elements_by_ids_impl
Chao Liu's avatar
Chao Liu committed
823
{
Chao Liu's avatar
Chao Liu committed
824
    using new_work_seq = decltype(WorkSeq::Modify(RemainIds::Front(), RemainValues::Front()));
Chao Liu's avatar
Chao Liu committed
825

Chao Liu's avatar
Chao Liu committed
826
827
828
829
830
    using type =
        typename modify_sequence_elements_by_ids_impl<new_work_seq,
                                                      decltype(RemainValues::PopFront()),
                                                      decltype(RemainIds::PopFront())>::type;
};
Chao Liu's avatar
Chao Liu committed
831

Chao Liu's avatar
Chao Liu committed
832
833
834
835
template <typename WorkSeq>
struct modify_sequence_elements_by_ids_impl<WorkSeq, Sequence<>, Sequence<>>
{
    using type = WorkSeq;
Chao Liu's avatar
Chao Liu committed
836
};
Chao Liu's avatar
Chao Liu committed
837
838
839
840
841
842
843
844
845
846
} // namespace detail

// Returns Seq with its elements modified through the (Ids[k], Values[k]) pairs, applied
// front to back by detail::modify_sequence_elements_by_ids_impl.
// Values and Ids must have the same size, which must not exceed the size of Seq.
template <typename Seq, typename Values, typename Ids>
__host__ __device__ constexpr auto modify_sequence_elements_by_ids(Seq, Values, Ids)
{
    static_assert(Values::Size() == Ids::Size() && Seq::Size() >= Values::Size(), "wrong!");

    using modified_t = typename detail::modify_sequence_elements_by_ids_impl<Seq, Values, Ids>::type;

    return modified_t{};
}
#endif
Chao Liu's avatar
Chao Liu committed
847

Chao Liu's avatar
Chao Liu committed
848
template <typename Seq, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
849
__host__ __device__ constexpr index_t
Chao Liu's avatar
Chao Liu committed
850
reduce_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
Chao Liu's avatar
Chao Liu committed
851
852
853
{
    index_t result = Init;

Chao Liu's avatar
Chao Liu committed
854
855
856
857
    for(index_t i = 0; i < Seq::Size(); ++i)
    {
        result = f(result, Seq::At(i));
    }
Chao Liu's avatar
Chao Liu committed
858
859
860
861

    return result;
}

Chao Liu's avatar
Chao Liu committed
862
863
// TODO: a generic any_of for any container
template <typename Seq, typename F>
Chao Liu's avatar
Chao Liu committed
864
__host__ __device__ constexpr bool sequence_any_of(Seq, F f)
Chao Liu's avatar
Chao Liu committed
865
866
867
868
869
870
871
872
873
874
875
876
877
{
    bool flag = false;

    for(index_t i = 0; i < Seq::Size(); ++i)
    {
        flag = flag || f(Seq::At(i));
    }

    return flag;
}

// TODO: a generic all_of for any container
template <typename Seq, typename F>
Chao Liu's avatar
Chao Liu committed
878
__host__ __device__ constexpr bool sequence_all_of(Seq, F f)
Chao Liu's avatar
Chao Liu committed
879
880
881
882
883
884
885
886
887
888
889
{
    bool flag = true;

    for(index_t i = 0; i < Seq::Size(); ++i)
    {
        flag = flag && f(Seq::At(i));
    }

    return flag;
}

890
891
892
893
894
895
// Shorthand for the nested ::type of sequence_merge applied to two Sequences.
template <typename Sx, typename Sy>
using sequence_merge_t = typename sequence_merge<Sx, Sy>::type;

// Shorthand for the nested ::type of uniform_sequence_gen<NSize, I>.
// NOTE(review): presumably a Sequence of NSize copies of I -- confirm against
// uniform_sequence_gen.
template <index_t NSize, index_t I>
using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;

896
} // namespace ck