sequence.hpp 27.9 KB
Newer Older
1
2
3
#ifndef CK_SEQUENCE_HPP
#define CK_SEQUENCE_HPP

Chao Liu's avatar
Chao Liu committed
4
#include "integral_constant.hpp"
Chao Liu's avatar
Chao Liu committed
5
#include "type.hpp"
Chao Liu's avatar
Chao Liu committed
6
#include "functional.hpp"
Chao Liu's avatar
Chao Liu committed
7
#include "math.hpp"
8

9
10
namespace ck {

Chao Liu's avatar
Chao Liu committed
11
12
13
// ---------------------------------------------------------------------------
// Forward declarations: names that Sequence's member functions refer to before
// their definitions are available.
// ---------------------------------------------------------------------------

// used by Sequence::Print
template <index_t, index_t, index_t>
struct static_for;

template <index_t... Is>
struct Sequence;

// splits Seq at position I into left_type / right_type
template <typename Seq, index_t I>
struct sequence_split;

template <typename Seq>
struct sequence_reverse;

template <typename SeqMap>
struct sequence_map_inverse;

template <typename SeqMap>
struct is_valid_sequence_map;

template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>);

template <typename Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq);

Chao Liu's avatar
Chao Liu committed
35
template <index_t... Is>
36
37
struct Sequence
{
Chao Liu's avatar
Chao Liu committed
38
39
    using Type      = Sequence;
    using data_type = index_t;
40

41
    static constexpr index_t mSize = sizeof...(Is);
42

Chao Liu's avatar
Chao Liu committed
43
    __host__ __device__ static constexpr auto Size() { return Number<mSize>{}; }
44

Chao Liu's avatar
Chao Liu committed
45
46
47
    __host__ __device__ static constexpr auto GetSize() { return Size(); }

    __host__ __device__ static constexpr index_t At(index_t I)
48
    {
Chao Liu's avatar
Chao Liu committed
49
50
51
52
53
54
        // the last dummy element is to prevent compiler complain about empty array, when mSize = 0
        const index_t mData[mSize + 1] = {Is..., 0};
        return mData[I];
    }

    template <index_t I>
Chao Liu's avatar
Chao Liu committed
55
    __host__ __device__ static constexpr auto At(Number<I>)
Chao Liu's avatar
Chao Liu committed
56
    {
Chao Liu's avatar
Chao Liu committed
57
58
        static_assert(I < mSize, "wrong! I too large");

Chao Liu's avatar
Chao Liu committed
59
        return Number<At(I)>{};
Chao Liu's avatar
Chao Liu committed
60
61
    }

Chao Liu's avatar
Chao Liu committed
62
    template <index_t I>
Chao Liu's avatar
Chao Liu committed
63
    __host__ __device__ static constexpr auto Get(Number<I>)
Chao Liu's avatar
Chao Liu committed
64
    {
Chao Liu's avatar
Chao Liu committed
65
        return At(Number<I>{});
66
67
    }

Chao Liu's avatar
Chao Liu committed
68
69
70
71
72
    template <typename I>
    __host__ __device__ constexpr auto operator[](I i) const
    {
        return At(i);
    }
Chao Liu's avatar
Chao Liu committed
73

74
    template <index_t... IRs>
75
    __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
76
    {
Chao Liu's avatar
Chao Liu committed
77
        static_assert(sizeof...(Is) == sizeof...(IRs),
Chao Liu's avatar
Chao Liu committed
78
                      "wrong! reorder map should have the same size as Sequence to be rerodered");
Chao Liu's avatar
Chao Liu committed
79

Chao Liu's avatar
Chao Liu committed
80
81
        static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");

Chao Liu's avatar
Chao Liu committed
82
        return Sequence<Type::At(Number<IRs>{})...>{};
83
84
    }

Chao Liu's avatar
Chao Liu committed
85
    // MapOld2New is Sequence<...>
Chao Liu's avatar
Chao Liu committed
86
    template <typename MapOld2New>
Chao Liu's avatar
Chao Liu committed
87
88
    __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
    {
Chao Liu's avatar
Chao Liu committed
89
        static_assert(MapOld2New::Size() == Size(),
Chao Liu's avatar
Chao Liu committed
90
91
92
93
94
95
96
                      "wrong! reorder map should have the same size as Sequence to be rerodered");

        static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");

        return ReorderGivenNew2Old(typename sequence_map_inverse<MapOld2New>::type{});
    }

97
98
99
100
    __host__ __device__ static constexpr auto Reverse()
    {
        return typename sequence_reverse<Type>::type{};
    }
Chao Liu's avatar
Chao Liu committed
101

Chao Liu's avatar
Chao Liu committed
102
    __host__ __device__ static constexpr auto Front()
103
    {
Chao Liu's avatar
Chao Liu committed
104
        static_assert(mSize > 0, "wrong!");
Chao Liu's avatar
Chao Liu committed
105
        return At(Number<0>{});
106
    }
107

Chao Liu's avatar
Chao Liu committed
108
    __host__ __device__ static constexpr auto Back()
109
    {
Chao Liu's avatar
Chao Liu committed
110
        static_assert(mSize > 0, "wrong!");
Chao Liu's avatar
Chao Liu committed
111
        return At(Number<mSize - 1>{});
112
    }
113

114
    __host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); }
Chao Liu's avatar
Chao Liu committed
115

116
    __host__ __device__ static constexpr auto PopBack() { return sequence_pop_back(Type{}); }
Chao Liu's avatar
Chao Liu committed
117
118
119

    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushFront(Sequence<Xs...>)
120
    {
Chao Liu's avatar
Chao Liu committed
121
        return Sequence<Xs..., Is...>{};
122
123
    }

Chao Liu's avatar
Chao Liu committed
124
125
    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushFront(Number<Xs>...)
126
    {
Chao Liu's avatar
Chao Liu committed
127
        return Sequence<Xs..., Is...>{};
128
129
    }

Chao Liu's avatar
Chao Liu committed
130
131
132
133
134
    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushBack(Sequence<Xs...>)
    {
        return Sequence<Is..., Xs...>{};
    }
135

Chao Liu's avatar
Chao Liu committed
136
    template <index_t... Xs>
Chao Liu's avatar
Chao Liu committed
137
    __host__ __device__ static constexpr auto PushBack(Number<Xs>...)
138
    {
Chao Liu's avatar
Chao Liu committed
139
140
        return Sequence<Is..., Xs...>{};
    }
Chao Liu's avatar
Chao Liu committed
141

Chao Liu's avatar
Chao Liu committed
142
    template <index_t... Ns>
143
    __host__ __device__ static constexpr auto Extract(Number<Ns>...)
Chao Liu's avatar
Chao Liu committed
144
    {
Chao Liu's avatar
Chao Liu committed
145
        return Sequence<Type::At(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
146
    }
Chao Liu's avatar
Chao Liu committed
147

Chao Liu's avatar
Chao Liu committed
148
    template <index_t... Ns>
149
    __host__ __device__ static constexpr auto Extract(Sequence<Ns...>)
Chao Liu's avatar
Chao Liu committed
150
    {
Chao Liu's avatar
Chao Liu committed
151
        return Sequence<Type::At(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
152
    }
153
154

    template <index_t I, index_t X>
155
156
    __host__ __device__ static constexpr auto Modify(Number<I>, Number<X>)
    {
Chao Liu's avatar
Chao Liu committed
157
        static_assert(I < Size(), "wrong!");
158
159

        using seq_split          = sequence_split<Type, I>;
Chao Liu's avatar
Chao Liu committed
160
161
        constexpr auto seq_left  = typename seq_split::left_type{};
        constexpr auto seq_right = typename seq_split::right_type{}.PopFront();
162
163
164

        return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
    }
Chao Liu's avatar
Chao Liu committed
165

Chao Liu's avatar
Chao Liu committed
166
    template <typename F>
Chao Liu's avatar
Chao Liu committed
167
168
169
170
    __host__ __device__ static constexpr auto Transform(F f)
    {
        return Sequence<f(Is)...>{};
    }
Chao Liu's avatar
Chao Liu committed
171
172
173
174
175
176
177
178

    __host__ __device__ static void Print()
    {
        printf("{");
        printf("size %d, ", index_t{Size()});
        static_for<0, Size(), 1>{}([&](auto i) { printf("%d ", At(i).value); });
        printf("}");
    }
179
180
};

Chao Liu's avatar
Chao Liu committed
181
// merge sequence
Chao Liu's avatar
Chao Liu committed
182
183
184
185
186
template <typename Seq, typename... Seqs>
struct sequence_merge
{
    using type = typename sequence_merge<Seq, typename sequence_merge<Seqs...>::type>::type;
};
Chao Liu's avatar
Chao Liu committed
187

Chao Liu's avatar
Chao Liu committed
188
189
190
template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
{
Chao Liu's avatar
Chao Liu committed
191
    using type = Sequence<Xs..., Ys...>;
Chao Liu's avatar
Chao Liu committed
192
};
Chao Liu's avatar
Chao Liu committed
193

Chao Liu's avatar
Chao Liu committed
194
195
196
197
198
199
template <typename Seq>
struct sequence_merge<Seq>
{
    using type = Seq;
};

Chao Liu's avatar
Chao Liu committed
200
// generate sequence
Chao Liu's avatar
Chao Liu committed
201
202
template <index_t NSize, typename F>
struct sequence_gen
Chao Liu's avatar
Chao Liu committed
203
{
Chao Liu's avatar
Chao Liu committed
204
205
206
207
208
209
    template <index_t IBegin, index_t NRemain, typename G>
    struct sequence_gen_impl
    {
        static constexpr index_t NRemainLeft  = NRemain / 2;
        static constexpr index_t NRemainRight = NRemain - NRemainLeft;
        static constexpr index_t IMiddle      = IBegin + NRemainLeft;
Chao Liu's avatar
Chao Liu committed
210

Chao Liu's avatar
Chao Liu committed
211
212
213
214
        using type = typename sequence_merge<
            typename sequence_gen_impl<IBegin, NRemainLeft, G>::type,
            typename sequence_gen_impl<IMiddle, NRemainRight, G>::type>::type;
    };
Chao Liu's avatar
Chao Liu committed
215

Chao Liu's avatar
Chao Liu committed
216
217
218
219
220
221
    template <index_t I, typename G>
    struct sequence_gen_impl<I, 1, G>
    {
        static constexpr index_t Is = G{}(Number<I>{});
        using type                  = Sequence<Is>;
    };
Chao Liu's avatar
Chao Liu committed
222

Chao Liu's avatar
Chao Liu committed
223
224
225
226
227
    template <index_t I, typename G>
    struct sequence_gen_impl<I, 0, G>
    {
        using type = Sequence<>;
    };
Chao Liu's avatar
Chao Liu committed
228

Chao Liu's avatar
Chao Liu committed
229
230
231
232
    using type = typename sequence_gen_impl<0, NSize, F>::type;
};

// arithmetic sequence
Chao Liu's avatar
Chao Liu committed
233
template <index_t IBegin, index_t IEnd, index_t Increment>
234
struct arithmetic_sequence_gen
Chao Liu's avatar
Chao Liu committed
235
{
Chao Liu's avatar
Chao Liu committed
236
237
238
239
240
241
242
243
244
    struct F
    {
        __host__ __device__ constexpr index_t operator()(index_t i) const
        {
            return i * Increment + IBegin;
        }
    };

    using type = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type;
Chao Liu's avatar
Chao Liu committed
245
246
247
248
249
250
};

// uniform sequence
template <index_t NSize, index_t I>
struct uniform_sequence_gen
{
Chao Liu's avatar
Chao Liu committed
251
    struct F
Chao Liu's avatar
Chao Liu committed
252
253
254
255
    {
        __host__ __device__ constexpr index_t operator()(index_t) const { return I; }
    };

Chao Liu's avatar
Chao Liu committed
256
    using type = typename sequence_gen<NSize, F>::type;
Chao Liu's avatar
Chao Liu committed
257
258
259
};

// reverse inclusive scan (with init) sequence
Chao Liu's avatar
Chao Liu committed
260
template <typename, typename, index_t>
Chao Liu's avatar
Chao Liu committed
261
struct sequence_reverse_inclusive_scan;
Chao Liu's avatar
Chao Liu committed
262

Chao Liu's avatar
Chao Liu committed
263
template <index_t I, index_t... Is, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
264
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
265
{
Chao Liu's avatar
Chao Liu committed
266
    using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::type;
Chao Liu's avatar
Chao Liu committed
267
268
269

    static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());

Chao Liu's avatar
Chao Liu committed
270
    using type = typename sequence_merge<Sequence<new_reduce>, old_scan>::type;
Chao Liu's avatar
Chao Liu committed
271
272
};

Chao Liu's avatar
Chao Liu committed
273
template <index_t I, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
274
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
275
{
Chao Liu's avatar
Chao Liu committed
276
    using type = Sequence<Reduce{}(I, Init)>;
Chao Liu's avatar
Chao Liu committed
277
278
};

Chao Liu's avatar
Chao Liu committed
279
template <typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
280
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
281
{
Chao Liu's avatar
Chao Liu committed
282
    using type = Sequence<>;
Chao Liu's avatar
Chao Liu committed
283
284
};

Chao Liu's avatar
Chao Liu committed
285
// split sequence
Chao Liu's avatar
Chao Liu committed
286
template <typename Seq, index_t I>
Chao Liu's avatar
Chao Liu committed
287
288
struct sequence_split
{
Chao Liu's avatar
Chao Liu committed
289
    static constexpr index_t NSize = Seq{}.Size();
Chao Liu's avatar
Chao Liu committed
290

Chao Liu's avatar
Chao Liu committed
291
292
    using range0 = typename arithmetic_sequence_gen<0, I, 1>::type;
    using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type;
Chao Liu's avatar
Chao Liu committed
293

Chao Liu's avatar
Chao Liu committed
294
295
    using left_type  = decltype(Seq::Extract(range0{}));
    using right_type = decltype(Seq::Extract(range1{}));
Chao Liu's avatar
Chao Liu committed
296
297
};

Chao Liu's avatar
Chao Liu committed
298
// reverse sequence
Chao Liu's avatar
Chao Liu committed
299
template <typename Seq>
Chao Liu's avatar
Chao Liu committed
300
301
struct sequence_reverse
{
Chao Liu's avatar
Chao Liu committed
302
    static constexpr index_t NSize = Seq{}.Size();
Chao Liu's avatar
Chao Liu committed
303
304

    using seq_split = sequence_split<Seq, NSize / 2>;
Chao Liu's avatar
Chao Liu committed
305
    using type      = typename sequence_merge<
Chao Liu's avatar
Chao Liu committed
306
307
        typename sequence_reverse<typename seq_split::right_type>::type,
        typename sequence_reverse<typename seq_split::left_type>::type>::type;
Chao Liu's avatar
Chao Liu committed
308
309
310
311
312
};

template <index_t I>
struct sequence_reverse<Sequence<I>>
{
Chao Liu's avatar
Chao Liu committed
313
    using type = Sequence<I>;
Chao Liu's avatar
Chao Liu committed
314
315
316
317
318
};

template <index_t I0, index_t I1>
struct sequence_reverse<Sequence<I0, I1>>
{
Chao Liu's avatar
Chao Liu committed
319
    using type = Sequence<I1, I0>;
Chao Liu's avatar
Chao Liu committed
320
};
Chao Liu's avatar
Chao Liu committed
321

Chao Liu's avatar
Chao Liu committed
322
#if 1
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
// Combines several Sequences position by position with Reduce.
// General case: reduce everything after the first Sequence, then combine the first with it.
template <typename Reduce, typename Seq, typename... Seqs>
struct sequence_reduce
{
    using tail_type = typename sequence_reduce<Reduce, Seqs...>::type;
    using type      = typename sequence_reduce<Reduce, Seq, tail_type>::type;
};

// Two Sequences: apply Reduce to each pair of values at the same position.
template <typename Reduce, index_t... Xs, index_t... Ys>
struct sequence_reduce<Reduce, Sequence<Xs...>, Sequence<Ys...>>
{
    using type = Sequence<Reduce{}(Xs, Ys)...>;
};

// A single argument is returned unchanged.
template <typename Reduce, typename Seq>
struct sequence_reduce<Reduce, Seq>
{
    using type = Seq;
};
#endif

Chao Liu's avatar
Chao Liu committed
344
345
template <typename Values, typename Ids, typename Compare>
struct sequence_sort_impl
Chao Liu's avatar
Chao Liu committed
346
{
Chao Liu's avatar
Chao Liu committed
347
348
349
350
351
352
353
    template <typename LeftValues,
              typename LeftIds,
              typename RightValues,
              typename RightIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
Chao Liu's avatar
Chao Liu committed
354
355
    struct sorted_sequence_merge_impl
    {
Chao Liu's avatar
Chao Liu committed
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
        static constexpr bool choose_left = LeftValues::Front() < RightValues::Front();

        static constexpr index_t chosen_value =
            choose_left ? LeftValues::Front() : RightValues::Front();
        static constexpr index_t chosen_id = choose_left ? LeftIds::Front() : RightIds::Front();

        using new_merged_values = decltype(MergedValues::PushBack(Number<chosen_value>{}));
        using new_merged_ids    = decltype(MergedIds::PushBack(Number<chosen_id>{}));

        using new_left_values =
            typename conditional<choose_left, decltype(LeftValues::PopFront()), LeftValues>::type;
        using new_left_ids =
            typename conditional<choose_left, decltype(LeftIds::PopFront()), LeftIds>::type;

        using new_right_values =
            typename conditional<choose_left, RightValues, decltype(RightValues::PopFront())>::type;
        using new_right_ids =
            typename conditional<choose_left, RightIds, decltype(RightIds::PopFront())>::type;

        using merge = sorted_sequence_merge_impl<new_left_values,
                                                 new_left_ids,
                                                 new_right_values,
                                                 new_right_ids,
                                                 new_merged_values,
                                                 new_merged_ids,
                                                 Comp>;
        // this is output
        using merged_values = typename merge::merged_values;
        using merged_ids    = typename merge::merged_ids;
Chao Liu's avatar
Chao Liu committed
385
386
    };

Chao Liu's avatar
Chao Liu committed
387
388
389
390
391
392
393
394
395
396
397
398
    template <typename LeftValues,
              typename LeftIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
    struct sorted_sequence_merge_impl<LeftValues,
                                      LeftIds,
                                      Sequence<>,
                                      Sequence<>,
                                      MergedValues,
                                      MergedIds,
                                      Comp>
Chao Liu's avatar
Chao Liu committed
399
    {
Chao Liu's avatar
Chao Liu committed
400
401
        using merged_values = typename sequence_merge<MergedValues, LeftValues>::type;
        using merged_ids    = typename sequence_merge<MergedIds, LeftIds>::type;
Chao Liu's avatar
Chao Liu committed
402
403
    };

Chao Liu's avatar
Chao Liu committed
404
405
406
407
408
409
410
411
412
413
414
415
    template <typename RightValues,
              typename RightIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
    struct sorted_sequence_merge_impl<Sequence<>,
                                      Sequence<>,
                                      RightValues,
                                      RightIds,
                                      MergedValues,
                                      MergedIds,
                                      Comp>
Chao Liu's avatar
Chao Liu committed
416
    {
Chao Liu's avatar
Chao Liu committed
417
418
        using merged_values = typename sequence_merge<MergedValues, RightValues>::type;
        using merged_ids    = typename sequence_merge<MergedIds, RightIds>::type;
Chao Liu's avatar
Chao Liu committed
419
420
    };

Chao Liu's avatar
Chao Liu committed
421
422
423
424
425
    template <typename LeftValues,
              typename LeftIds,
              typename RightValues,
              typename RightIds,
              typename Comp>
Chao Liu's avatar
Chao Liu committed
426
427
    struct sorted_sequence_merge
    {
Chao Liu's avatar
Chao Liu committed
428
429
430
431
432
433
434
435
436
437
        using merge = sorted_sequence_merge_impl<LeftValues,
                                                 LeftIds,
                                                 RightValues,
                                                 RightIds,
                                                 Sequence<>,
                                                 Sequence<>,
                                                 Comp>;

        using merged_values = typename merge::merged_values;
        using merged_ids    = typename merge::merged_ids;
Chao Liu's avatar
Chao Liu committed
438
439
    };

Chao Liu's avatar
Chao Liu committed
440
441
442
443
    static constexpr index_t nsize = Values::Size();

    using split_unsorted_values = sequence_split<Values, nsize / 2>;
    using split_unsorted_ids    = sequence_split<Ids, nsize / 2>;
Chao Liu's avatar
Chao Liu committed
444

Chao Liu's avatar
Chao Liu committed
445
446
447
448
449
    using left_unsorted_values = typename split_unsorted_values::left_type;
    using left_unsorted_ids    = typename split_unsorted_ids::left_type;
    using left_sort          = sequence_sort_impl<left_unsorted_values, left_unsorted_ids, Compare>;
    using left_sorted_values = typename left_sort::sorted_values;
    using left_sorted_ids    = typename left_sort::sorted_ids;
Chao Liu's avatar
Chao Liu committed
450

Chao Liu's avatar
Chao Liu committed
451
452
453
454
455
456
457
458
459
460
461
462
463
464
    using right_unsorted_values = typename split_unsorted_values::right_type;
    using right_unsorted_ids    = typename split_unsorted_ids::right_type;
    using right_sort = sequence_sort_impl<right_unsorted_values, right_unsorted_ids, Compare>;
    using right_sorted_values = typename right_sort::sorted_values;
    using right_sorted_ids    = typename right_sort::sorted_ids;

    using merged_sorted = sorted_sequence_merge<left_sorted_values,
                                                left_sorted_ids,
                                                right_sorted_values,
                                                right_sorted_ids,
                                                Compare>;

    using sorted_values = typename merged_sorted::merged_values;
    using sorted_ids    = typename merged_sorted::merged_ids;
Chao Liu's avatar
Chao Liu committed
465
466
};

Chao Liu's avatar
Chao Liu committed
467
468
template <index_t ValueX, index_t ValueY, index_t IdX, index_t IdY, typename Compare>
struct sequence_sort_impl<Sequence<ValueX, ValueY>, Sequence<IdX, IdY>, Compare>
Chao Liu's avatar
Chao Liu committed
469
{
Chao Liu's avatar
Chao Liu committed
470
471
472
473
474
475
    static constexpr bool choose_x = Compare{}(ValueX, ValueY);

    using sorted_values =
        typename conditional<choose_x, Sequence<ValueX, ValueY>, Sequence<ValueY, ValueX>>::type;
    using sorted_ids = typename conditional<choose_x, Sequence<IdX, IdY>, Sequence<IdY, IdX>>::type;
};
Chao Liu's avatar
Chao Liu committed
476

Chao Liu's avatar
Chao Liu committed
477
478
479
480
481
template <index_t Value, index_t Id, typename Compare>
struct sequence_sort_impl<Sequence<Value>, Sequence<Id>, Compare>
{
    using sorted_values = Sequence<Value>;
    using sorted_ids    = Sequence<Id>;
Chao Liu's avatar
Chao Liu committed
482
483
};

484
485
486
487
488
489
490
template <typename Compare>
struct sequence_sort_impl<Sequence<>, Sequence<>, Compare>
{
    using sorted_values = Sequence<>;
    using sorted_ids    = Sequence<>;
};

Chao Liu's avatar
Chao Liu committed
491
492
template <typename Values, typename Compare>
struct sequence_sort
Chao Liu's avatar
Chao Liu committed
493
{
Chao Liu's avatar
Chao Liu committed
494
495
496
497
498
499
    using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type;
    using sort         = sequence_sort_impl<Values, unsorted_ids, Compare>;

    // this is output
    using type                = typename sort::sorted_values;
    using sorted2unsorted_map = typename sort::sorted_ids;
Chao Liu's avatar
Chao Liu committed
500
501
};

Chao Liu's avatar
Chao Liu committed
502
template <typename Values, typename Less, typename Equal>
Chao Liu's avatar
Chao Liu committed
503
504
struct sequence_unique_sort
{
Chao Liu's avatar
Chao Liu committed
505
506
507
508
509
    template <typename RemainValues,
              typename RemainIds,
              typename UniquifiedValues,
              typename UniquifiedIds,
              typename Eq>
Chao Liu's avatar
Chao Liu committed
510
511
    struct sorted_sequence_uniquify_impl
    {
Chao Liu's avatar
Chao Liu committed
512
513
514
515
516
517
518
519
520
521
522
523
        static constexpr index_t current_value = RemainValues::Front();
        static constexpr index_t current_id    = RemainIds::Front();

        static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back());

        using new_remain_values = decltype(RemainValues::PopFront());
        using new_remain_ids    = decltype(RemainIds::PopFront());

        using new_uniquified_values =
            typename conditional<is_unique_value,
                                 decltype(UniquifiedValues::PushBack(Number<current_value>{})),
                                 UniquifiedValues>::type;
Chao Liu's avatar
Chao Liu committed
524

Chao Liu's avatar
Chao Liu committed
525
526
527
528
529
530
531
532
533
534
535
536
537
538
        using new_uniquified_ids =
            typename conditional<is_unique_value,
                                 decltype(UniquifiedIds::PushBack(Number<current_id>{})),
                                 UniquifiedIds>::type;

        using uniquify = sorted_sequence_uniquify_impl<new_remain_values,
                                                       new_remain_ids,
                                                       new_uniquified_values,
                                                       new_uniquified_ids,
                                                       Eq>;

        // this is output
        using uniquified_values = typename uniquify::uniquified_values;
        using uniquified_ids    = typename uniquify::uniquified_ids;
Chao Liu's avatar
Chao Liu committed
539
540
    };

Chao Liu's avatar
Chao Liu committed
541
542
543
544
545
546
    template <typename UniquifiedValues, typename UniquifiedIds, typename Eq>
    struct sorted_sequence_uniquify_impl<Sequence<>,
                                         Sequence<>,
                                         UniquifiedValues,
                                         UniquifiedIds,
                                         Eq>
Chao Liu's avatar
Chao Liu committed
547
    {
Chao Liu's avatar
Chao Liu committed
548
549
        using uniquified_values = UniquifiedValues;
        using uniquified_ids    = UniquifiedIds;
Chao Liu's avatar
Chao Liu committed
550
551
    };

Chao Liu's avatar
Chao Liu committed
552
    template <typename SortedValues, typename SortedIds, typename Eq>
Chao Liu's avatar
Chao Liu committed
553
554
    struct sorted_sequence_uniquify
    {
Chao Liu's avatar
Chao Liu committed
555
556
557
558
559
560
561
562
        using uniquify = sorted_sequence_uniquify_impl<decltype(SortedValues::PopFront()),
                                                       decltype(SortedIds::PopFront()),
                                                       Sequence<SortedValues::Front()>,
                                                       Sequence<SortedIds::Front()>,
                                                       Eq>;

        using uniquified_values = typename uniquify::uniquified_values;
        using uniquified_ids    = typename uniquify::uniquified_ids;
Chao Liu's avatar
Chao Liu committed
563
564
    };

Chao Liu's avatar
Chao Liu committed
565
566
567
    using sort          = sequence_sort<Values, Less>;
    using sorted_values = typename sort::type;
    using sorted_ids    = typename sort::sorted2unsorted_map;
Chao Liu's avatar
Chao Liu committed
568

Chao Liu's avatar
Chao Liu committed
569
570
571
572
573
    using uniquify = sorted_sequence_uniquify<sorted_values, sorted_ids, Equal>;

    // this is output
    using type                = typename uniquify::uniquified_values;
    using sorted2unsorted_map = typename uniquify::uniquified_ids;
Chao Liu's avatar
Chao Liu committed
574
575
};

Chao Liu's avatar
Chao Liu committed
576
template <typename SeqMap>
Chao Liu's avatar
Chao Liu committed
577
578
struct is_valid_sequence_map : is_same<typename arithmetic_sequence_gen<0, SeqMap::Size(), 1>::type,
                                       typename sequence_sort<SeqMap, math::less<index_t>>::type>
Chao Liu's avatar
Chao Liu committed
579
580
{
};
581

Chao Liu's avatar
Chao Liu committed
582
583
template <typename SeqMap>
struct sequence_map_inverse
Chao Liu's avatar
Chao Liu committed
584
{
Chao Liu's avatar
Chao Liu committed
585
586
587
588
589
    template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
    struct sequence_map_inverse_impl
    {
        static constexpr auto new_y2x =
            WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
Chao Liu's avatar
Chao Liu committed
590

Chao Liu's avatar
Chao Liu committed
591
592
593
594
        using type =
            typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::
                type;
    };
Chao Liu's avatar
Chao Liu committed
595

Chao Liu's avatar
Chao Liu committed
596
597
598
599
600
    template <typename X2Y, typename WorkingY2X, index_t XBegin>
    struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
    {
        using type = WorkingY2X;
    };
Chao Liu's avatar
Chao Liu committed
601
602

    using type =
Chao Liu's avatar
Chao Liu committed
603
604
        typename sequence_map_inverse_impl<SeqMap,
                                           typename uniform_sequence_gen<SeqMap::Size(), 0>::type,
Chao Liu's avatar
Chao Liu committed
605
                                           0,
Chao Liu's avatar
Chao Liu committed
606
                                           SeqMap::Size()>::type;
Chao Liu's avatar
Chao Liu committed
607
608
};

609
610
611
612
613
614
// Two Sequences are equal when they have the same length and the same value at every
// position. Sequences of different length compare unequal; the length is checked first
// because folding two packs of different length would not compile.
template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr bool operator==(Sequence<Xs...>, Sequence<Ys...>)
{
    if constexpr(sizeof...(Xs) != sizeof...(Ys))
    {
        return false;
    }
    else
    {
        // an empty fold over && is true, so two empty Sequences are equal
        return ((Xs == Ys) && ...);
    }
}

Chao Liu's avatar
Chao Liu committed
615
template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
616
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
617
618
619
620
621
622
623
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs + Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
624
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
625
626
627
628
629
630
631
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs - Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
632
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
633
634
635
636
637
638
639
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs * Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
640
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
641
642
643
644
645
646
647
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs / Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
648
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
649
650
651
652
653
654
655
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs % Ys)...>{};
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
656
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
657
{
Chao Liu's avatar
Chao Liu committed
658
    return Sequence<(Xs + Y)...>{};
Chao Liu's avatar
Chao Liu committed
659
660
661
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
662
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
663
{
Chao Liu's avatar
Chao Liu committed
664
    return Sequence<(Xs - Y)...>{};
Chao Liu's avatar
Chao Liu committed
665
666
667
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
668
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
669
{
Chao Liu's avatar
Chao Liu committed
670
    return Sequence<(Xs * Y)...>{};
Chao Liu's avatar
Chao Liu committed
671
672
673
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
674
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
675
{
Chao Liu's avatar
Chao Liu committed
676
    return Sequence<(Xs / Y)...>{};
Chao Liu's avatar
Chao Liu committed
677
678
679
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
680
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
681
{
Chao Liu's avatar
Chao Liu committed
682
    return Sequence<(Xs % Y)...>{};
Chao Liu's avatar
Chao Liu committed
683
684
}

Chao Liu's avatar
Chao Liu committed
685
686
// Scalar-on-the-left addition: result[i] = Y + Xs[i].
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator+(Number<Y>, Sequence<Xs...>)
{
    using shifted_type = Sequence<(Y + Xs)...>;
    return shifted_type{};
}

Chao Liu's avatar
Chao Liu committed
691
692
// Scalar-on-the-left subtraction: result[i] = Y - Xs[i].
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
{
    using difference_type = Sequence<(Y - Xs)...>;
    return difference_type{};
}

Chao Liu's avatar
Chao Liu committed
697
698
// Scalar-on-the-left multiplication: result[i] = Y * Xs[i].
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator*(Number<Y>, Sequence<Xs...>)
{
    using scaled_type = Sequence<(Y * Xs)...>;
    return scaled_type{};
}

Chao Liu's avatar
Chao Liu committed
703
704
// Scalar-on-the-left division: result[i] = Y / Xs[i].
// Every element of the Sequence acts as a divisor and must be non-zero.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator/(Number<Y>, Sequence<Xs...>)
{
    // report a zero divisor explicitly instead of relying on the invalid
    // template argument in the return statement
    static_assert(((Xs != 0) && ...), "wrong! division by zero");

    return Sequence<(Y / Xs)...>{};
}

Chao Liu's avatar
Chao Liu committed
709
710
// Scalar-on-the-left modulo: result[i] = Y % Xs[i].
// Every element of the Sequence acts as a divisor and must be non-zero.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator%(Number<Y>, Sequence<Xs...>)
{
    // report a zero divisor explicitly instead of relying on the invalid
    // template argument in the return statement
    static_assert(((Xs != 0) && ...), "wrong! modulo by zero");

    return Sequence<(Y % Xs)...>{};
}

715
716
717
718
719
720
// Drops the first element I of a non-empty Sequence and returns the rest.
template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
{
    using tail_type = Sequence<Is...>;
    return tail_type{};
}

Chao Liu's avatar
Chao Liu committed
721
template <typename Seq>
Chao Liu's avatar
Chao Liu committed
722
__host__ __device__ constexpr auto sequence_pop_back(Seq)
723
{
Chao Liu's avatar
Chao Liu committed
724
    static_assert(Seq::Size() > 0, "wrong! cannot pop an empty Sequence!");
725
    return sequence_pop_front(Seq::Reverse()).Reverse();
726
}
727

Chao Liu's avatar
Chao Liu committed
728
729
730
731
732
733
// Function-style wrapper around sequence_merge: returns a value of
// sequence_merge<Seqs...>::type for the given Sequences.
template <typename... Seqs>
__host__ __device__ constexpr auto merge_sequences(Seqs...)
{
    using merged_type = typename sequence_merge<Seqs...>::type;
    return merged_type{};
}

734
735
736
737
738
739
// Applies the unary functor f to every element: result[i] = f(Xs[i]).
// f(Xs) is used as a template argument, so the call must be usable in a
// constant expression.
template <typename F, index_t... Xs>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
{
    using transformed_type = Sequence<f(Xs)...>;
    return transformed_type{};
}

Chao Liu's avatar
Chao Liu committed
740
template <typename F, index_t... Xs, index_t... Ys>
741
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
742
{
743
    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize, "Dim not the same");
744
745
746
747

    return Sequence<f(Xs, Ys)...>{};
}

Chao Liu's avatar
Chao Liu committed
748
template <typename F, index_t... Xs, index_t... Ys, index_t... Zs>
749
750
751
752
753
754
755
756
757
758
__host__ __device__ constexpr auto
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
{
    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize &&
                      Sequence<Xs...>::mSize == Sequence<Zs...>::mSize,
                  "Dim not the same");

    return Sequence<f(Xs, Ys, Zs)...>{};
}

Chao Liu's avatar
Chao Liu committed
759
template <typename Seq, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
760
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number<Init>)
761
{
Chao Liu's avatar
Chao Liu committed
762
    return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::type{};
763
764
}

Chao Liu's avatar
Chao Liu committed
765
766
767
768
769
770
771
// Reverse exclusive scan built from the inclusive one: run the reverse inclusive
// scan on Seq with its front element removed, then append Init at the back.
template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr auto reverse_exclusive_scan_sequence(Seq, Reduce, Number<Init>)
{
    auto scanned_tail =
        reverse_inclusive_scan_sequence(Seq::PopFront(), Reduce{}, Number<Init>{});

    return scanned_tail.PushBack(Number<Init>{});
}

Chao Liu's avatar
Chao Liu committed
772
template <typename Seq, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
773
__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<Init>)
774
{
Chao Liu's avatar
Chao Liu committed
775
    return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
776
}
777

Chao Liu's avatar
Chao Liu committed
778
template <typename Seq, index_t... Is>
779
__host__ __device__ constexpr auto pick_sequence_elements_by_ids(Seq, Sequence<Is...> /* ids */)
Chao Liu's avatar
Chao Liu committed
780
781
782
783
{
    return Sequence<Seq::At(Number<Is>{})...>{};
}

Chao Liu's avatar
Chao Liu committed
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
#if 1
namespace detail {
template <typename WorkSeq, typename RemainSeq, typename RemainMask>
struct pick_sequence_elements_by_mask_impl
{
    using new_work_seq = typename conditional<RemainMask::Front(),
                                              decltype(WorkSeq::PushBack(RemainSeq::Front())),
                                              WorkSeq>::type;

    using type =
        typename pick_sequence_elements_by_mask_impl<new_work_seq,
                                                     decltype(RemainSeq::PopFront()),
                                                     decltype(RemainMask::PopFront())>::type;
};

template <typename WorkSeq>
struct pick_sequence_elements_by_mask_impl<WorkSeq, Sequence<>, Sequence<>>
{
    using type = WorkSeq;
};

} // namespace detail

807
808
809
// Keeps the elements of Seq whose corresponding Mask entry is true, preserving
// their order. Seq and Mask must have the same size.
template <typename Seq, typename Mask>
__host__ __device__ constexpr auto pick_sequence_elements_by_mask(Seq, Mask)
{
    static_assert(Seq::Size() == Mask::Size(), "wrong!");

    // start from an empty working Sequence and let the helper walk Seq/Mask
    using picked_type =
        typename detail::pick_sequence_elements_by_mask_impl<Sequence<>, Seq, Mask>::type;

    return picked_type{};
}

Chao Liu's avatar
Chao Liu committed
815
816
817
namespace detail {
template <typename WorkSeq, typename RemainValues, typename RemainIds>
struct modify_sequence_elements_by_ids_impl
Chao Liu's avatar
Chao Liu committed
818
{
Chao Liu's avatar
Chao Liu committed
819
    using new_work_seq = decltype(WorkSeq::Modify(RemainIds::Front(), RemainValues::Front()));
Chao Liu's avatar
Chao Liu committed
820

Chao Liu's avatar
Chao Liu committed
821
822
823
824
825
    using type =
        typename modify_sequence_elements_by_ids_impl<new_work_seq,
                                                      decltype(RemainValues::PopFront()),
                                                      decltype(RemainIds::PopFront())>::type;
};
Chao Liu's avatar
Chao Liu committed
826

Chao Liu's avatar
Chao Liu committed
827
828
829
830
template <typename WorkSeq>
struct modify_sequence_elements_by_ids_impl<WorkSeq, Sequence<>, Sequence<>>
{
    using type = WorkSeq;
Chao Liu's avatar
Chao Liu committed
831
};
Chao Liu's avatar
Chao Liu committed
832
833
834
835
836
837
838
839
840
841
} // namespace detail

// Returns Seq after applying Seq::Modify for each (id, value) pair taken from
// Ids and Values, in order.
// Values and Ids must have the same size, which must not exceed Seq's size.
template <typename Seq, typename Values, typename Ids>
__host__ __device__ constexpr auto modify_sequence_elements_by_ids(Seq, Values, Ids)
{
    static_assert(Values::Size() == Ids::Size() && Seq::Size() >= Values::Size(), "wrong!");

    using modified_type =
        typename detail::modify_sequence_elements_by_ids_impl<Seq, Values, Ids>::type;

    return modified_type{};
}
#endif
Chao Liu's avatar
Chao Liu committed
842

Chao Liu's avatar
Chao Liu committed
843
template <typename Seq, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
844
__host__ __device__ constexpr index_t
Chao Liu's avatar
Chao Liu committed
845
reduce_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
Chao Liu's avatar
Chao Liu committed
846
847
848
{
    index_t result = Init;

Chao Liu's avatar
Chao Liu committed
849
850
851
852
    for(index_t i = 0; i < Seq::Size(); ++i)
    {
        result = f(result, Seq::At(i));
    }
Chao Liu's avatar
Chao Liu committed
853
854
855
856

    return result;
}

Chao Liu's avatar
Chao Liu committed
857
858
// TODO: a generic any_of for any container
template <typename Seq, typename F>
Chao Liu's avatar
Chao Liu committed
859
__host__ __device__ constexpr bool sequence_any_of(Seq, F f)
Chao Liu's avatar
Chao Liu committed
860
861
862
863
864
865
866
867
868
869
870
871
872
{
    bool flag = false;

    for(index_t i = 0; i < Seq::Size(); ++i)
    {
        flag = flag || f(Seq::At(i));
    }

    return flag;
}

// TODO: a generic all_of for any container
template <typename Seq, typename F>
Chao Liu's avatar
Chao Liu committed
873
__host__ __device__ constexpr bool sequence_all_of(Seq, F f)
Chao Liu's avatar
Chao Liu committed
874
875
876
877
878
879
880
881
882
883
884
{
    bool flag = true;

    for(index_t i = 0; i < Seq::Size(); ++i)
    {
        flag = flag && f(Seq::At(i));
    }

    return flag;
}

885
886
} // namespace ck
#endif