#ifndef CK_SEQUENCE_HPP
#define CK_SEQUENCE_HPP

#include "integral_constant.hpp"
#include "type.hpp"
#include "functional.hpp"
#include "math.hpp"

namespace ck {

// Forward declarations. Sequence's member functions below refer to the
// sequence_* helpers before their definitions appear in this file.
template <index_t, index_t, index_t>
struct static_for;

template <index_t...>
struct Sequence;

template <typename Seq, index_t I>
struct sequence_split;

template <typename>
struct sequence_reverse;

template <typename>
struct sequence_map_inverse;

template <typename>
struct is_valid_sequence_map;

template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>);

template <typename Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq);

template <index_t... Is>
36
37
struct Sequence
{
Chao Liu's avatar
Chao Liu committed
38
39
    using Type      = Sequence;
    using data_type = index_t;
40

41
    static constexpr index_t mSize = sizeof...(Is);
42

Chao Liu's avatar
Chao Liu committed
43
    __host__ __device__ static constexpr auto Size() { return Number<mSize>{}; }
44

Chao Liu's avatar
Chao Liu committed
45
46
47
    __host__ __device__ static constexpr auto GetSize() { return Size(); }

    __host__ __device__ static constexpr index_t At(index_t I)
48
    {
Chao Liu's avatar
Chao Liu committed
49
50
51
52
53
54
        // the last dummy element is to prevent compiler complain about empty array, when mSize = 0
        const index_t mData[mSize + 1] = {Is..., 0};
        return mData[I];
    }

    template <index_t I>
Chao Liu's avatar
Chao Liu committed
55
    __host__ __device__ static constexpr auto At(Number<I>)
Chao Liu's avatar
Chao Liu committed
56
    {
Chao Liu's avatar
Chao Liu committed
57
58
        static_assert(I < mSize, "wrong! I too large");

Chao Liu's avatar
Chao Liu committed
59
        return Number<At(I)>{};
Chao Liu's avatar
Chao Liu committed
60
61
    }

Chao Liu's avatar
Chao Liu committed
62
    template <index_t I>
Chao Liu's avatar
Chao Liu committed
63
    __host__ __device__ static constexpr auto Get(Number<I>)
Chao Liu's avatar
Chao Liu committed
64
    {
Chao Liu's avatar
Chao Liu committed
65
        return At(Number<I>{});
66
67
    }

Chao Liu's avatar
Chao Liu committed
68
69
70
71
72
    template <typename I>
    __host__ __device__ constexpr auto operator[](I i) const
    {
        return At(i);
    }
Chao Liu's avatar
Chao Liu committed
73

74
    template <index_t... IRs>
75
    __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
76
    {
Chao Liu's avatar
Chao Liu committed
77
        static_assert(sizeof...(Is) == sizeof...(IRs),
Chao Liu's avatar
Chao Liu committed
78
                      "wrong! reorder map should have the same size as Sequence to be rerodered");
Chao Liu's avatar
Chao Liu committed
79

Chao Liu's avatar
Chao Liu committed
80
81
        static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");

Chao Liu's avatar
Chao Liu committed
82
        return Sequence<Type::At(Number<IRs>{})...>{};
83
84
    }

Chao Liu's avatar
Chao Liu committed
85
    // MapOld2New is Sequence<...>
Chao Liu's avatar
Chao Liu committed
86
    template <typename MapOld2New>
Chao Liu's avatar
Chao Liu committed
87
88
    __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
    {
Chao Liu's avatar
Chao Liu committed
89
        static_assert(MapOld2New::Size() == Size(),
Chao Liu's avatar
Chao Liu committed
90
91
92
93
94
95
96
                      "wrong! reorder map should have the same size as Sequence to be rerodered");

        static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");

        return ReorderGivenNew2Old(typename sequence_map_inverse<MapOld2New>::type{});
    }

97
98
99
100
    __host__ __device__ static constexpr auto Reverse()
    {
        return typename sequence_reverse<Type>::type{};
    }
Chao Liu's avatar
Chao Liu committed
101

Chao Liu's avatar
Chao Liu committed
102
    __host__ __device__ static constexpr auto Front()
103
    {
Chao Liu's avatar
Chao Liu committed
104
        static_assert(mSize > 0, "wrong!");
Chao Liu's avatar
Chao Liu committed
105
        return At(Number<0>{});
106
    }
107

Chao Liu's avatar
Chao Liu committed
108
    __host__ __device__ static constexpr auto Back()
109
    {
Chao Liu's avatar
Chao Liu committed
110
        static_assert(mSize > 0, "wrong!");
Chao Liu's avatar
Chao Liu committed
111
        return At(Number<mSize - 1>{});
112
    }
113

114
    __host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); }
Chao Liu's avatar
Chao Liu committed
115

116
    __host__ __device__ static constexpr auto PopBack() { return sequence_pop_back(Type{}); }
Chao Liu's avatar
Chao Liu committed
117
118
119

    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushFront(Sequence<Xs...>)
120
    {
Chao Liu's avatar
Chao Liu committed
121
        return Sequence<Xs..., Is...>{};
122
123
    }

Chao Liu's avatar
Chao Liu committed
124
125
    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushFront(Number<Xs>...)
126
    {
Chao Liu's avatar
Chao Liu committed
127
        return Sequence<Xs..., Is...>{};
128
129
    }

Chao Liu's avatar
Chao Liu committed
130
131
132
133
134
    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushBack(Sequence<Xs...>)
    {
        return Sequence<Is..., Xs...>{};
    }
135

Chao Liu's avatar
Chao Liu committed
136
    template <index_t... Xs>
Chao Liu's avatar
Chao Liu committed
137
    __host__ __device__ static constexpr auto PushBack(Number<Xs>...)
138
    {
Chao Liu's avatar
Chao Liu committed
139
140
        return Sequence<Is..., Xs...>{};
    }
Chao Liu's avatar
Chao Liu committed
141

Chao Liu's avatar
Chao Liu committed
142
    template <index_t... Ns>
143
    __host__ __device__ static constexpr auto Extract(Number<Ns>...)
Chao Liu's avatar
Chao Liu committed
144
    {
Chao Liu's avatar
Chao Liu committed
145
        return Sequence<Type::At(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
146
    }
Chao Liu's avatar
Chao Liu committed
147

Chao Liu's avatar
Chao Liu committed
148
    template <index_t... Ns>
149
    __host__ __device__ static constexpr auto Extract(Sequence<Ns...>)
Chao Liu's avatar
Chao Liu committed
150
    {
Chao Liu's avatar
Chao Liu committed
151
        return Sequence<Type::At(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
152
    }
153
154

    template <index_t I, index_t X>
155
156
    __host__ __device__ static constexpr auto Modify(Number<I>, Number<X>)
    {
Chao Liu's avatar
Chao Liu committed
157
        static_assert(I < Size(), "wrong!");
158
159

        using seq_split          = sequence_split<Type, I>;
Chao Liu's avatar
Chao Liu committed
160
161
        constexpr auto seq_left  = typename seq_split::left_type{};
        constexpr auto seq_right = typename seq_split::right_type{}.PopFront();
162
163
164

        return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
    }
Chao Liu's avatar
Chao Liu committed
165

Chao Liu's avatar
Chao Liu committed
166
    template <typename F>
Chao Liu's avatar
Chao Liu committed
167
168
169
170
    __host__ __device__ static constexpr auto Transform(F f)
    {
        return Sequence<f(Is)...>{};
    }
171
172
};

// merge sequence
Chao Liu's avatar
Chao Liu committed
174
175
176
177
178
template <typename Seq, typename... Seqs>
struct sequence_merge
{
    using type = typename sequence_merge<Seq, typename sequence_merge<Seqs...>::type>::type;
};
Chao Liu's avatar
Chao Liu committed
179

Chao Liu's avatar
Chao Liu committed
180
181
182
template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
{
Chao Liu's avatar
Chao Liu committed
183
    using type = Sequence<Xs..., Ys...>;
Chao Liu's avatar
Chao Liu committed
184
};
Chao Liu's avatar
Chao Liu committed
185

Chao Liu's avatar
Chao Liu committed
186
187
188
189
190
191
template <typename Seq>
struct sequence_merge<Seq>
{
    using type = Seq;
};

// generate sequence
Chao Liu's avatar
Chao Liu committed
193
194
template <index_t NSize, typename F>
struct sequence_gen
Chao Liu's avatar
Chao Liu committed
195
{
Chao Liu's avatar
Chao Liu committed
196
197
198
199
200
201
    template <index_t IBegin, index_t NRemain, typename G>
    struct sequence_gen_impl
    {
        static constexpr index_t NRemainLeft  = NRemain / 2;
        static constexpr index_t NRemainRight = NRemain - NRemainLeft;
        static constexpr index_t IMiddle      = IBegin + NRemainLeft;
Chao Liu's avatar
Chao Liu committed
202

Chao Liu's avatar
Chao Liu committed
203
204
205
206
        using type = typename sequence_merge<
            typename sequence_gen_impl<IBegin, NRemainLeft, G>::type,
            typename sequence_gen_impl<IMiddle, NRemainRight, G>::type>::type;
    };
Chao Liu's avatar
Chao Liu committed
207

Chao Liu's avatar
Chao Liu committed
208
209
210
211
212
213
    template <index_t I, typename G>
    struct sequence_gen_impl<I, 1, G>
    {
        static constexpr index_t Is = G{}(Number<I>{});
        using type                  = Sequence<Is>;
    };
Chao Liu's avatar
Chao Liu committed
214

Chao Liu's avatar
Chao Liu committed
215
216
217
218
219
    template <index_t I, typename G>
    struct sequence_gen_impl<I, 0, G>
    {
        using type = Sequence<>;
    };
Chao Liu's avatar
Chao Liu committed
220

Chao Liu's avatar
Chao Liu committed
221
222
223
224
    using type = typename sequence_gen_impl<0, NSize, F>::type;
};

// arithmetic sequence
Chao Liu's avatar
Chao Liu committed
225
template <index_t IBegin, index_t IEnd, index_t Increment>
226
struct arithmetic_sequence_gen
Chao Liu's avatar
Chao Liu committed
227
{
Chao Liu's avatar
Chao Liu committed
228
229
230
231
232
233
234
235
236
    struct F
    {
        __host__ __device__ constexpr index_t operator()(index_t i) const
        {
            return i * Increment + IBegin;
        }
    };

    using type = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type;
Chao Liu's avatar
Chao Liu committed
237
238
239
240
241
242
};

// uniform sequence
template <index_t NSize, index_t I>
struct uniform_sequence_gen
{
Chao Liu's avatar
Chao Liu committed
243
    struct F
Chao Liu's avatar
Chao Liu committed
244
245
246
247
    {
        __host__ __device__ constexpr index_t operator()(index_t) const { return I; }
    };

Chao Liu's avatar
Chao Liu committed
248
    using type = typename sequence_gen<NSize, F>::type;
Chao Liu's avatar
Chao Liu committed
249
250
251
};

// reverse inclusive scan (with init) sequence
Chao Liu's avatar
Chao Liu committed
252
template <typename, typename, index_t>
Chao Liu's avatar
Chao Liu committed
253
struct sequence_reverse_inclusive_scan;
Chao Liu's avatar
Chao Liu committed
254

Chao Liu's avatar
Chao Liu committed
255
template <index_t I, index_t... Is, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
256
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
257
{
Chao Liu's avatar
Chao Liu committed
258
    using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::type;
Chao Liu's avatar
Chao Liu committed
259
260
261

    static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());

Chao Liu's avatar
Chao Liu committed
262
    using type = typename sequence_merge<Sequence<new_reduce>, old_scan>::type;
Chao Liu's avatar
Chao Liu committed
263
264
};

Chao Liu's avatar
Chao Liu committed
265
template <index_t I, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
266
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
267
{
Chao Liu's avatar
Chao Liu committed
268
    using type = Sequence<Reduce{}(I, Init)>;
Chao Liu's avatar
Chao Liu committed
269
270
};

Chao Liu's avatar
Chao Liu committed
271
template <typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
272
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
273
{
Chao Liu's avatar
Chao Liu committed
274
    using type = Sequence<>;
Chao Liu's avatar
Chao Liu committed
275
276
};

// split sequence
Chao Liu's avatar
Chao Liu committed
278
template <typename Seq, index_t I>
Chao Liu's avatar
Chao Liu committed
279
280
struct sequence_split
{
Chao Liu's avatar
Chao Liu committed
281
    static constexpr index_t NSize = Seq{}.Size();
Chao Liu's avatar
Chao Liu committed
282

Chao Liu's avatar
Chao Liu committed
283
284
    using range0 = typename arithmetic_sequence_gen<0, I, 1>::type;
    using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type;
Chao Liu's avatar
Chao Liu committed
285

Chao Liu's avatar
Chao Liu committed
286
287
    using left_type  = decltype(Seq::Extract(range0{}));
    using right_type = decltype(Seq::Extract(range1{}));
Chao Liu's avatar
Chao Liu committed
288
289
};

// reverse sequence
Chao Liu's avatar
Chao Liu committed
291
template <typename Seq>
Chao Liu's avatar
Chao Liu committed
292
293
struct sequence_reverse
{
Chao Liu's avatar
Chao Liu committed
294
    static constexpr index_t NSize = Seq{}.Size();
Chao Liu's avatar
Chao Liu committed
295
296

    using seq_split = sequence_split<Seq, NSize / 2>;
Chao Liu's avatar
Chao Liu committed
297
    using type      = typename sequence_merge<
Chao Liu's avatar
Chao Liu committed
298
299
        typename sequence_reverse<typename seq_split::right_type>::type,
        typename sequence_reverse<typename seq_split::left_type>::type>::type;
Chao Liu's avatar
Chao Liu committed
300
301
302
303
304
};

template <index_t I>
struct sequence_reverse<Sequence<I>>
{
Chao Liu's avatar
Chao Liu committed
305
    using type = Sequence<I>;
Chao Liu's avatar
Chao Liu committed
306
307
308
309
310
};

template <index_t I0, index_t I1>
struct sequence_reverse<Sequence<I0, I1>>
{
Chao Liu's avatar
Chao Liu committed
311
    using type = Sequence<I1, I0>;
Chao Liu's avatar
Chao Liu committed
312
};
template <typename Values, typename Ids, typename Compare>
struct sequence_sort_impl
Chao Liu's avatar
Chao Liu committed
316
{
Chao Liu's avatar
Chao Liu committed
317
318
319
320
321
322
323
    template <typename LeftValues,
              typename LeftIds,
              typename RightValues,
              typename RightIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
Chao Liu's avatar
Chao Liu committed
324
325
    struct sorted_sequence_merge_impl
    {
Chao Liu's avatar
Chao Liu committed
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
        static constexpr bool choose_left = LeftValues::Front() < RightValues::Front();

        static constexpr index_t chosen_value =
            choose_left ? LeftValues::Front() : RightValues::Front();
        static constexpr index_t chosen_id = choose_left ? LeftIds::Front() : RightIds::Front();

        using new_merged_values = decltype(MergedValues::PushBack(Number<chosen_value>{}));
        using new_merged_ids    = decltype(MergedIds::PushBack(Number<chosen_id>{}));

        using new_left_values =
            typename conditional<choose_left, decltype(LeftValues::PopFront()), LeftValues>::type;
        using new_left_ids =
            typename conditional<choose_left, decltype(LeftIds::PopFront()), LeftIds>::type;

        using new_right_values =
            typename conditional<choose_left, RightValues, decltype(RightValues::PopFront())>::type;
        using new_right_ids =
            typename conditional<choose_left, RightIds, decltype(RightIds::PopFront())>::type;

        using merge = sorted_sequence_merge_impl<new_left_values,
                                                 new_left_ids,
                                                 new_right_values,
                                                 new_right_ids,
                                                 new_merged_values,
                                                 new_merged_ids,
                                                 Comp>;
        // this is output
        using merged_values = typename merge::merged_values;
        using merged_ids    = typename merge::merged_ids;
Chao Liu's avatar
Chao Liu committed
355
356
    };

Chao Liu's avatar
Chao Liu committed
357
358
359
360
361
362
363
364
365
366
367
368
    template <typename LeftValues,
              typename LeftIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
    struct sorted_sequence_merge_impl<LeftValues,
                                      LeftIds,
                                      Sequence<>,
                                      Sequence<>,
                                      MergedValues,
                                      MergedIds,
                                      Comp>
Chao Liu's avatar
Chao Liu committed
369
    {
Chao Liu's avatar
Chao Liu committed
370
371
        using merged_values = typename sequence_merge<MergedValues, LeftValues>::type;
        using merged_ids    = typename sequence_merge<MergedIds, LeftIds>::type;
Chao Liu's avatar
Chao Liu committed
372
373
    };

Chao Liu's avatar
Chao Liu committed
374
375
376
377
378
379
380
381
382
383
384
385
    template <typename RightValues,
              typename RightIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
    struct sorted_sequence_merge_impl<Sequence<>,
                                      Sequence<>,
                                      RightValues,
                                      RightIds,
                                      MergedValues,
                                      MergedIds,
                                      Comp>
Chao Liu's avatar
Chao Liu committed
386
    {
Chao Liu's avatar
Chao Liu committed
387
388
        using merged_values = typename sequence_merge<MergedValues, RightValues>::type;
        using merged_ids    = typename sequence_merge<MergedIds, RightIds>::type;
Chao Liu's avatar
Chao Liu committed
389
390
    };

Chao Liu's avatar
Chao Liu committed
391
392
393
394
395
    template <typename LeftValues,
              typename LeftIds,
              typename RightValues,
              typename RightIds,
              typename Comp>
Chao Liu's avatar
Chao Liu committed
396
397
    struct sorted_sequence_merge
    {
Chao Liu's avatar
Chao Liu committed
398
399
400
401
402
403
404
405
406
407
        using merge = sorted_sequence_merge_impl<LeftValues,
                                                 LeftIds,
                                                 RightValues,
                                                 RightIds,
                                                 Sequence<>,
                                                 Sequence<>,
                                                 Comp>;

        using merged_values = typename merge::merged_values;
        using merged_ids    = typename merge::merged_ids;
Chao Liu's avatar
Chao Liu committed
408
409
    };

Chao Liu's avatar
Chao Liu committed
410
411
412
413
    static constexpr index_t nsize = Values::Size();

    using split_unsorted_values = sequence_split<Values, nsize / 2>;
    using split_unsorted_ids    = sequence_split<Ids, nsize / 2>;
Chao Liu's avatar
Chao Liu committed
414

Chao Liu's avatar
Chao Liu committed
415
416
417
418
419
    using left_unsorted_values = typename split_unsorted_values::left_type;
    using left_unsorted_ids    = typename split_unsorted_ids::left_type;
    using left_sort          = sequence_sort_impl<left_unsorted_values, left_unsorted_ids, Compare>;
    using left_sorted_values = typename left_sort::sorted_values;
    using left_sorted_ids    = typename left_sort::sorted_ids;
Chao Liu's avatar
Chao Liu committed
420

Chao Liu's avatar
Chao Liu committed
421
422
423
424
425
426
427
428
429
430
431
432
433
434
    using right_unsorted_values = typename split_unsorted_values::right_type;
    using right_unsorted_ids    = typename split_unsorted_ids::right_type;
    using right_sort = sequence_sort_impl<right_unsorted_values, right_unsorted_ids, Compare>;
    using right_sorted_values = typename right_sort::sorted_values;
    using right_sorted_ids    = typename right_sort::sorted_ids;

    using merged_sorted = sorted_sequence_merge<left_sorted_values,
                                                left_sorted_ids,
                                                right_sorted_values,
                                                right_sorted_ids,
                                                Compare>;

    using sorted_values = typename merged_sorted::merged_values;
    using sorted_ids    = typename merged_sorted::merged_ids;
Chao Liu's avatar
Chao Liu committed
435
436
};

Chao Liu's avatar
Chao Liu committed
437
438
template <index_t ValueX, index_t ValueY, index_t IdX, index_t IdY, typename Compare>
struct sequence_sort_impl<Sequence<ValueX, ValueY>, Sequence<IdX, IdY>, Compare>
Chao Liu's avatar
Chao Liu committed
439
{
Chao Liu's avatar
Chao Liu committed
440
441
442
443
444
445
    static constexpr bool choose_x = Compare{}(ValueX, ValueY);

    using sorted_values =
        typename conditional<choose_x, Sequence<ValueX, ValueY>, Sequence<ValueY, ValueX>>::type;
    using sorted_ids = typename conditional<choose_x, Sequence<IdX, IdY>, Sequence<IdY, IdX>>::type;
};
Chao Liu's avatar
Chao Liu committed
446

Chao Liu's avatar
Chao Liu committed
447
448
449
450
451
template <index_t Value, index_t Id, typename Compare>
struct sequence_sort_impl<Sequence<Value>, Sequence<Id>, Compare>
{
    using sorted_values = Sequence<Value>;
    using sorted_ids    = Sequence<Id>;
Chao Liu's avatar
Chao Liu committed
452
453
};

template <typename Values, typename Compare>
struct sequence_sort
Chao Liu's avatar
Chao Liu committed
456
{
Chao Liu's avatar
Chao Liu committed
457
458
459
460
461
462
    using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type;
    using sort         = sequence_sort_impl<Values, unsorted_ids, Compare>;

    // this is output
    using type                = typename sort::sorted_values;
    using sorted2unsorted_map = typename sort::sorted_ids;
Chao Liu's avatar
Chao Liu committed
463
464
};

template <typename Values, typename Less, typename Equal>
Chao Liu's avatar
Chao Liu committed
466
467
struct sequence_unique_sort
{
Chao Liu's avatar
Chao Liu committed
468
469
470
471
472
    template <typename RemainValues,
              typename RemainIds,
              typename UniquifiedValues,
              typename UniquifiedIds,
              typename Eq>
Chao Liu's avatar
Chao Liu committed
473
474
    struct sorted_sequence_uniquify_impl
    {
Chao Liu's avatar
Chao Liu committed
475
476
477
478
479
480
481
482
483
484
485
486
        static constexpr index_t current_value = RemainValues::Front();
        static constexpr index_t current_id    = RemainIds::Front();

        static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back());

        using new_remain_values = decltype(RemainValues::PopFront());
        using new_remain_ids    = decltype(RemainIds::PopFront());

        using new_uniquified_values =
            typename conditional<is_unique_value,
                                 decltype(UniquifiedValues::PushBack(Number<current_value>{})),
                                 UniquifiedValues>::type;
Chao Liu's avatar
Chao Liu committed
487

Chao Liu's avatar
Chao Liu committed
488
489
490
491
492
493
494
495
496
497
498
499
500
501
        using new_uniquified_ids =
            typename conditional<is_unique_value,
                                 decltype(UniquifiedIds::PushBack(Number<current_id>{})),
                                 UniquifiedIds>::type;

        using uniquify = sorted_sequence_uniquify_impl<new_remain_values,
                                                       new_remain_ids,
                                                       new_uniquified_values,
                                                       new_uniquified_ids,
                                                       Eq>;

        // this is output
        using uniquified_values = typename uniquify::uniquified_values;
        using uniquified_ids    = typename uniquify::uniquified_ids;
Chao Liu's avatar
Chao Liu committed
502
503
    };

Chao Liu's avatar
Chao Liu committed
504
505
506
507
508
509
    template <typename UniquifiedValues, typename UniquifiedIds, typename Eq>
    struct sorted_sequence_uniquify_impl<Sequence<>,
                                         Sequence<>,
                                         UniquifiedValues,
                                         UniquifiedIds,
                                         Eq>
Chao Liu's avatar
Chao Liu committed
510
    {
Chao Liu's avatar
Chao Liu committed
511
512
        using uniquified_values = UniquifiedValues;
        using uniquified_ids    = UniquifiedIds;
Chao Liu's avatar
Chao Liu committed
513
514
    };

Chao Liu's avatar
Chao Liu committed
515
    template <typename SortedValues, typename SortedIds, typename Eq>
Chao Liu's avatar
Chao Liu committed
516
517
    struct sorted_sequence_uniquify
    {
Chao Liu's avatar
Chao Liu committed
518
519
520
521
522
523
524
525
        using uniquify = sorted_sequence_uniquify_impl<decltype(SortedValues::PopFront()),
                                                       decltype(SortedIds::PopFront()),
                                                       Sequence<SortedValues::Front()>,
                                                       Sequence<SortedIds::Front()>,
                                                       Eq>;

        using uniquified_values = typename uniquify::uniquified_values;
        using uniquified_ids    = typename uniquify::uniquified_ids;
Chao Liu's avatar
Chao Liu committed
526
527
    };

Chao Liu's avatar
Chao Liu committed
528
529
530
    using sort          = sequence_sort<Values, Less>;
    using sorted_values = typename sort::type;
    using sorted_ids    = typename sort::sorted2unsorted_map;
Chao Liu's avatar
Chao Liu committed
531

Chao Liu's avatar
Chao Liu committed
532
533
534
535
536
    using uniquify = sorted_sequence_uniquify<sorted_values, sorted_ids, Equal>;

    // this is output
    using type                = typename uniquify::uniquified_values;
    using sorted2unsorted_map = typename uniquify::uniquified_ids;
Chao Liu's avatar
Chao Liu committed
537
538
};

template <typename SeqMap>
Chao Liu's avatar
Chao Liu committed
540
541
struct is_valid_sequence_map
{
Chao Liu's avatar
Chao Liu committed
542
543
544
    static constexpr bool value =
        is_same<typename arithmetic_sequence_gen<0, SeqMap::Size(), 1>::type,
                typename sequence_sort<SeqMap, math::less<index_t>>::type>{};
Chao Liu's avatar
Chao Liu committed
545
};
template <typename SeqMap>
struct sequence_map_inverse
Chao Liu's avatar
Chao Liu committed
549
{
Chao Liu's avatar
Chao Liu committed
550
551
552
553
554
    template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
    struct sequence_map_inverse_impl
    {
        static constexpr auto new_y2x =
            WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
Chao Liu's avatar
Chao Liu committed
555

Chao Liu's avatar
Chao Liu committed
556
557
558
559
        using type =
            typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::
                type;
    };
Chao Liu's avatar
Chao Liu committed
560

Chao Liu's avatar
Chao Liu committed
561
562
563
564
565
    template <typename X2Y, typename WorkingY2X, index_t XBegin>
    struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
    {
        using type = WorkingY2X;
    };
Chao Liu's avatar
Chao Liu committed
566
567

    using type =
Chao Liu's avatar
Chao Liu committed
568
569
        typename sequence_map_inverse_impl<SeqMap,
                                           typename uniform_sequence_gen<SeqMap::Size(), 0>::type,
Chao Liu's avatar
Chao Liu committed
570
                                           0,
Chao Liu's avatar
Chao Liu committed
571
                                           SeqMap::Size()>::type;
Chao Liu's avatar
Chao Liu committed
572
573
};

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
575
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
576
577
578
579
580
581
582
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs + Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
583
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
584
585
586
587
588
589
590
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs - Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
591
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
592
593
594
595
596
597
598
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs * Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
599
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
600
601
602
603
604
605
606
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs / Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
607
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
608
609
610
611
612
613
614
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs % Ys)...>{};
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
615
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
616
{
Chao Liu's avatar
Chao Liu committed
617
    return Sequence<(Xs + Y)...>{};
Chao Liu's avatar
Chao Liu committed
618
619
620
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
621
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
622
{
Chao Liu's avatar
Chao Liu committed
623
    return Sequence<(Xs - Y)...>{};
Chao Liu's avatar
Chao Liu committed
624
625
626
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
627
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
628
{
Chao Liu's avatar
Chao Liu committed
629
    return Sequence<(Xs * Y)...>{};
Chao Liu's avatar
Chao Liu committed
630
631
632
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
633
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
634
{
Chao Liu's avatar
Chao Liu committed
635
    return Sequence<(Xs / Y)...>{};
Chao Liu's avatar
Chao Liu committed
636
637
638
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
639
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
640
{
Chao Liu's avatar
Chao Liu committed
641
    return Sequence<(Xs % Y)...>{};
Chao Liu's avatar
Chao Liu committed
642
643
}

Chao Liu's avatar
Chao Liu committed
644
645
// Adds every element of a Sequence to the compile-time scalar Y (scalar on the left).
// Returns Sequence<(Y + Xs)...>.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator+(Number<Y>, Sequence<Xs...>)
{
    using shifted_t = Sequence<(Y + Xs)...>;

    return shifted_t{};
}

Chao Liu's avatar
Chao Liu committed
650
651
// Subtracts every element of a Sequence from the compile-time scalar Y.
// Returns Sequence<(Y - Xs)...>. The scalar is the left operand here, so the
// result differs from operator-(Sequence<Xs...>, Number<Y>), which yields (Xs - Y).
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
{
    // The result is fully determined by the template arguments, so no local
    // Sequence object is needed (an unused one was removed from here).
    return Sequence<(Y - Xs)...>{};
}

Chao Liu's avatar
Chao Liu committed
658
659
// Multiplies the compile-time scalar Y by every element of a Sequence (scalar on the left).
// Returns Sequence<(Y * Xs)...>.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator*(Number<Y>, Sequence<Xs...>)
{
    using scaled_t = Sequence<(Y * Xs)...>;

    return scaled_t{};
}

Chao Liu's avatar
Chao Liu committed
664
665
// Divides the compile-time scalar Y by every element of a Sequence (scalar on the left).
// Returns Sequence<(Y / Xs)...>.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator/(Number<Y>, Sequence<Xs...>)
{
    using quotient_t = Sequence<(Y / Xs)...>;

    return quotient_t{};
}

Chao Liu's avatar
Chao Liu committed
670
671
// Takes the compile-time scalar Y modulo every element of a Sequence (scalar on the left).
// Returns Sequence<(Y % Xs)...>.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator%(Number<Y>, Sequence<Xs...>)
{
    using remainder_t = Sequence<(Y % Xs)...>;

    return remainder_t{};
}

676
677
678
679
680
681
// Drops the first element of a non-empty Sequence.
// Given Sequence<I, Is...>, returns Sequence<Is...>. An empty Sequence cannot
// match the parameter pattern, so it is rejected at overload resolution.
template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
{
    using tail_t = Sequence<Is...>;

    return tail_t{};
}

Chao Liu's avatar
Chao Liu committed
682
template <typename Seq>
Chao Liu's avatar
Chao Liu committed
683
__host__ __device__ constexpr auto sequence_pop_back(Seq)
684
{
Chao Liu's avatar
Chao Liu committed
685
    static_assert(Seq::Size() > 0, "wrong! cannot pop an empty Sequence!");
686
    return sequence_pop_front(Seq::Reverse()).Reverse();
687
}
688

Chao Liu's avatar
Chao Liu committed
689
template <typename F, index_t... Xs>
Chao Liu's avatar
Chao Liu committed
690
691
692
693
694
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
{
    return Sequence<f(Xs)...>{};
}

Chao Liu's avatar
Chao Liu committed
695
696
697
698
699
700
701
// Function-style wrapper around the sequence_merge metafunction.
// Returns a default-constructed object of sequence_merge<Seqs...>::type.
// NOTE(review): presumably this concatenates the given Sequences in order; sequence_merge
// is defined outside this section — confirm there.
template <typename... Seqs>
__host__ __device__ constexpr auto merge_sequences(Seqs...)
{
    using merged_t = typename sequence_merge<Seqs...>::type;

    return merged_t{};
}

template <typename F, index_t... Xs, index_t... Ys>
702
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
703
{
704
    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize, "Dim not the same");
705
706
707
708

    return Sequence<f(Xs, Ys)...>{};
}

Chao Liu's avatar
Chao Liu committed
709
template <typename F, index_t... Xs, index_t... Ys, index_t... Zs>
710
711
712
713
714
715
716
717
718
719
__host__ __device__ constexpr auto
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
{
    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize &&
                      Sequence<Xs...>::mSize == Sequence<Zs...>::mSize,
                  "Dim not the same");

    return Sequence<f(Xs, Ys, Zs)...>{};
}

Chao Liu's avatar
Chao Liu committed
720
template <typename Seq, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
721
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number<Init>)
722
{
Chao Liu's avatar
Chao Liu committed
723
    return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::type{};
724
725
}

Chao Liu's avatar
Chao Liu committed
726
template <typename Seq, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
727
__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<Init>)
728
{
Chao Liu's avatar
Chao Liu committed
729
    return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
730
}
731

Chao Liu's avatar
Chao Liu committed
732
733
734
735
736
737
// Gathers elements of Seq at the positions listed in Sequence<Is...>.
// Returns Sequence<Seq::At(Number<Is>{})...>, i.e. one output element per index,
// in the order the indices are given.
template <typename Seq, index_t... Is>
__host__ __device__ constexpr auto pick_sequence_elements(Seq, Sequence<Is...>)
{
    using picked_t = Sequence<Seq::At(Number<Is>{})...>;

    return picked_t{};
}

Chao Liu's avatar
Chao Liu committed
738
template <typename Seq, typename Reduce>
Chao Liu's avatar
Chao Liu committed
739
740
741
742
743
744
745
746
747
748
struct lambda_accumulate_on_sequence
{
    const Reduce& f;
    index_t& result;

    __host__ __device__ constexpr lambda_accumulate_on_sequence(const Reduce& f_, index_t& result_)
        : f(f_), result(result_)
    {
    }

Chao Liu's avatar
Chao Liu committed
749
    template <typename IDim>
Chao Liu's avatar
Chao Liu committed
750
751
    __host__ __device__ constexpr index_t operator()(IDim) const
    {
Chao Liu's avatar
Chao Liu committed
752
        return result = f(result, Seq::At(IDim{}));
Chao Liu's avatar
Chao Liu committed
753
754
755
    }
};

Chao Liu's avatar
Chao Liu committed
756
template <typename Seq, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
757
758
759
760
761
762
763
764
765
766
__host__ __device__ constexpr index_t
accumulate_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
{
    index_t result = Init;

    static_for<0, Seq::mSize, 1>{}(lambda_accumulate_on_sequence<Seq, Reduce>(f, result));

    return result;
}

767
768
} // namespace ck
#endif