sequence.hpp 28.5 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

4
#pragma once
5

6
7
8
9
#include "ck/utility/integral_constant.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/functional.hpp"
#include "ck/utility/math.hpp"
10

11
12
namespace ck {

Chao Liu's avatar
Chao Liu committed
13
14
15
// Forward declarations. These templates are referenced by Sequence (or by each other)
// before their definitions appear.

// Only used below by Sequence::Print(); its definition is not part of this section.
template <index_t, index_t, index_t>
struct static_for;

// compile-time list of index_t values
template <index_t... Is>
struct Sequence;

// splits a sequence at position I into left_type / right_type
template <typename Seq, index_t I>
struct sequence_split;

// reverses the order of a sequence
template <typename Seq>
struct sequence_reverse;

// inverts a permutation map
template <typename SeqMap>
struct sequence_map_inverse;

// checks that a sequence is a permutation of 0 .. size-1
template <typename SeqMap>
struct is_valid_sequence_map;

// free functions backing Sequence::PopFront() / Sequence::PopBack()
template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>);

template <typename Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq);

Chao Liu's avatar
Chao Liu committed
37
template <index_t... Is>
38
39
struct Sequence
{
Chao Liu's avatar
Chao Liu committed
40
41
    using Type      = Sequence;
    using data_type = index_t;
42

43
    static constexpr index_t mSize = sizeof...(Is);
44

Chao Liu's avatar
Chao Liu committed
45
    __host__ __device__ static constexpr auto Size() { return Number<mSize>{}; }
46

Chao Liu's avatar
Chao Liu committed
47
48
49
    __host__ __device__ static constexpr auto GetSize() { return Size(); }

    __host__ __device__ static constexpr index_t At(index_t I)
50
    {
Chao Liu's avatar
Chao Liu committed
51
52
53
54
55
56
        // the last dummy element is to prevent compiler complain about empty array, when mSize = 0
        const index_t mData[mSize + 1] = {Is..., 0};
        return mData[I];
    }

    template <index_t I>
Chao Liu's avatar
Chao Liu committed
57
    __host__ __device__ static constexpr auto At(Number<I>)
Chao Liu's avatar
Chao Liu committed
58
    {
Chao Liu's avatar
Chao Liu committed
59
60
        static_assert(I < mSize, "wrong! I too large");

Chao Liu's avatar
Chao Liu committed
61
        return Number<At(I)>{};
Chao Liu's avatar
Chao Liu committed
62
63
    }

Chao Liu's avatar
Chao Liu committed
64
    template <index_t I>
Chao Liu's avatar
Chao Liu committed
65
    __host__ __device__ static constexpr auto Get(Number<I>)
Chao Liu's avatar
Chao Liu committed
66
    {
Chao Liu's avatar
Chao Liu committed
67
        return At(Number<I>{});
68
69
    }

Chao Liu's avatar
Chao Liu committed
70
71
72
73
74
    template <typename I>
    __host__ __device__ constexpr auto operator[](I i) const
    {
        return At(i);
    }
Chao Liu's avatar
Chao Liu committed
75

76
    template <index_t... IRs>
77
    __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
78
    {
Chao Liu's avatar
Chao Liu committed
79
        static_assert(sizeof...(Is) == sizeof...(IRs),
Chao Liu's avatar
Chao Liu committed
80
                      "wrong! reorder map should have the same size as Sequence to be rerodered");
Chao Liu's avatar
Chao Liu committed
81

Chao Liu's avatar
Chao Liu committed
82
83
        static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");

Chao Liu's avatar
Chao Liu committed
84
        return Sequence<Type::At(Number<IRs>{})...>{};
85
86
    }

Chao Liu's avatar
Chao Liu committed
87
    // MapOld2New is Sequence<...>
Chao Liu's avatar
Chao Liu committed
88
    template <typename MapOld2New>
Chao Liu's avatar
Chao Liu committed
89
90
    __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
    {
Chao Liu's avatar
Chao Liu committed
91
        static_assert(MapOld2New::Size() == Size(),
Chao Liu's avatar
Chao Liu committed
92
93
94
95
96
97
98
                      "wrong! reorder map should have the same size as Sequence to be rerodered");

        static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");

        return ReorderGivenNew2Old(typename sequence_map_inverse<MapOld2New>::type{});
    }

99
100
101
102
    __host__ __device__ static constexpr auto Reverse()
    {
        return typename sequence_reverse<Type>::type{};
    }
Chao Liu's avatar
Chao Liu committed
103

Chao Liu's avatar
Chao Liu committed
104
    __host__ __device__ static constexpr auto Front()
105
    {
Chao Liu's avatar
Chao Liu committed
106
        static_assert(mSize > 0, "wrong!");
Chao Liu's avatar
Chao Liu committed
107
        return At(Number<0>{});
108
    }
109

Chao Liu's avatar
Chao Liu committed
110
    __host__ __device__ static constexpr auto Back()
111
    {
Chao Liu's avatar
Chao Liu committed
112
        static_assert(mSize > 0, "wrong!");
Chao Liu's avatar
Chao Liu committed
113
        return At(Number<mSize - 1>{});
114
    }
115

116
    __host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); }
Chao Liu's avatar
Chao Liu committed
117

118
    __host__ __device__ static constexpr auto PopBack() { return sequence_pop_back(Type{}); }
Chao Liu's avatar
Chao Liu committed
119
120
121

    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushFront(Sequence<Xs...>)
122
    {
Chao Liu's avatar
Chao Liu committed
123
        return Sequence<Xs..., Is...>{};
124
125
    }

Chao Liu's avatar
Chao Liu committed
126
127
    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushFront(Number<Xs>...)
128
    {
Chao Liu's avatar
Chao Liu committed
129
        return Sequence<Xs..., Is...>{};
130
131
    }

Chao Liu's avatar
Chao Liu committed
132
133
134
135
136
    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushBack(Sequence<Xs...>)
    {
        return Sequence<Is..., Xs...>{};
    }
137

Chao Liu's avatar
Chao Liu committed
138
    template <index_t... Xs>
Chao Liu's avatar
Chao Liu committed
139
    __host__ __device__ static constexpr auto PushBack(Number<Xs>...)
140
    {
Chao Liu's avatar
Chao Liu committed
141
142
        return Sequence<Is..., Xs...>{};
    }
Chao Liu's avatar
Chao Liu committed
143

Chao Liu's avatar
Chao Liu committed
144
    template <index_t... Ns>
145
    __host__ __device__ static constexpr auto Extract(Number<Ns>...)
Chao Liu's avatar
Chao Liu committed
146
    {
Chao Liu's avatar
Chao Liu committed
147
        return Sequence<Type::At(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
148
    }
Chao Liu's avatar
Chao Liu committed
149

Chao Liu's avatar
Chao Liu committed
150
    template <index_t... Ns>
151
    __host__ __device__ static constexpr auto Extract(Sequence<Ns...>)
Chao Liu's avatar
Chao Liu committed
152
    {
Chao Liu's avatar
Chao Liu committed
153
        return Sequence<Type::At(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
154
    }
155
156

    template <index_t I, index_t X>
157
158
    __host__ __device__ static constexpr auto Modify(Number<I>, Number<X>)
    {
Chao Liu's avatar
Chao Liu committed
159
        static_assert(I < Size(), "wrong!");
160
161

        using seq_split          = sequence_split<Type, I>;
Chao Liu's avatar
Chao Liu committed
162
163
        constexpr auto seq_left  = typename seq_split::left_type{};
        constexpr auto seq_right = typename seq_split::right_type{}.PopFront();
164
165
166

        return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
    }
Chao Liu's avatar
Chao Liu committed
167

Chao Liu's avatar
Chao Liu committed
168
    template <typename F>
Chao Liu's avatar
Chao Liu committed
169
170
171
172
    __host__ __device__ static constexpr auto Transform(F f)
    {
        return Sequence<f(Is)...>{};
    }
Chao Liu's avatar
Chao Liu committed
173
174
175
176
177
178
179
180

    __host__ __device__ static void Print()
    {
        printf("{");
        printf("size %d, ", index_t{Size()});
        static_for<0, Size(), 1>{}([&](auto i) { printf("%d ", At(i).value); });
        printf("}");
    }
181
182
};

Chao Liu's avatar
Chao Liu committed
183
// merge sequence
Chao Liu's avatar
Chao Liu committed
184
185
186
187
188
template <typename Seq, typename... Seqs>
struct sequence_merge
{
    using type = typename sequence_merge<Seq, typename sequence_merge<Seqs...>::type>::type;
};
Chao Liu's avatar
Chao Liu committed
189

Chao Liu's avatar
Chao Liu committed
190
191
192
template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
{
Chao Liu's avatar
Chao Liu committed
193
    using type = Sequence<Xs..., Ys...>;
Chao Liu's avatar
Chao Liu committed
194
};
Chao Liu's avatar
Chao Liu committed
195

Chao Liu's avatar
Chao Liu committed
196
197
198
199
200
201
template <typename Seq>
struct sequence_merge<Seq>
{
    using type = Seq;
};

Chao Liu's avatar
Chao Liu committed
202
// generate sequence
Chao Liu's avatar
Chao Liu committed
203
204
template <index_t NSize, typename F>
struct sequence_gen
Chao Liu's avatar
Chao Liu committed
205
{
Chao Liu's avatar
Chao Liu committed
206
207
208
209
210
211
    template <index_t IBegin, index_t NRemain, typename G>
    struct sequence_gen_impl
    {
        static constexpr index_t NRemainLeft  = NRemain / 2;
        static constexpr index_t NRemainRight = NRemain - NRemainLeft;
        static constexpr index_t IMiddle      = IBegin + NRemainLeft;
Chao Liu's avatar
Chao Liu committed
212

Chao Liu's avatar
Chao Liu committed
213
214
215
216
        using type = typename sequence_merge<
            typename sequence_gen_impl<IBegin, NRemainLeft, G>::type,
            typename sequence_gen_impl<IMiddle, NRemainRight, G>::type>::type;
    };
Chao Liu's avatar
Chao Liu committed
217

Chao Liu's avatar
Chao Liu committed
218
219
220
221
222
223
    template <index_t I, typename G>
    struct sequence_gen_impl<I, 1, G>
    {
        static constexpr index_t Is = G{}(Number<I>{});
        using type                  = Sequence<Is>;
    };
Chao Liu's avatar
Chao Liu committed
224

Chao Liu's avatar
Chao Liu committed
225
226
227
228
229
    template <index_t I, typename G>
    struct sequence_gen_impl<I, 0, G>
    {
        using type = Sequence<>;
    };
Chao Liu's avatar
Chao Liu committed
230

Chao Liu's avatar
Chao Liu committed
231
232
233
234
    using type = typename sequence_gen_impl<0, NSize, F>::type;
};

// arithmetic sequence
Chao Liu's avatar
Chao Liu committed
235
template <index_t IBegin, index_t IEnd, index_t Increment>
236
struct arithmetic_sequence_gen
Chao Liu's avatar
Chao Liu committed
237
{
Chao Liu's avatar
Chao Liu committed
238
239
240
241
242
243
244
245
    struct F
    {
        __host__ __device__ constexpr index_t operator()(index_t i) const
        {
            return i * Increment + IBegin;
        }
    };

246
247
248
249
250
251
252
    using type0 = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type;
    using type1 = Sequence<>;

    static constexpr bool kHasContent =
        (Increment > 0 && IBegin < IEnd) || (Increment < 0 && IBegin > IEnd);

    using type = typename conditional<kHasContent, type0, type1>::type;
Chao Liu's avatar
Chao Liu committed
253
254
255
256
257
258
};

// uniform sequence
template <index_t NSize, index_t I>
struct uniform_sequence_gen
{
Chao Liu's avatar
Chao Liu committed
259
    struct F
Chao Liu's avatar
Chao Liu committed
260
261
262
263
    {
        __host__ __device__ constexpr index_t operator()(index_t) const { return I; }
    };

Chao Liu's avatar
Chao Liu committed
264
    using type = typename sequence_gen<NSize, F>::type;
Chao Liu's avatar
Chao Liu committed
265
266
267
};

// reverse inclusive scan (with init) sequence
Chao Liu's avatar
Chao Liu committed
268
template <typename, typename, index_t>
Chao Liu's avatar
Chao Liu committed
269
struct sequence_reverse_inclusive_scan;
Chao Liu's avatar
Chao Liu committed
270

Chao Liu's avatar
Chao Liu committed
271
template <index_t I, index_t... Is, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
272
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
273
{
Chao Liu's avatar
Chao Liu committed
274
    using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::type;
Chao Liu's avatar
Chao Liu committed
275
276
277

    static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());

Chao Liu's avatar
Chao Liu committed
278
    using type = typename sequence_merge<Sequence<new_reduce>, old_scan>::type;
Chao Liu's avatar
Chao Liu committed
279
280
};

Chao Liu's avatar
Chao Liu committed
281
template <index_t I, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
282
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
283
{
Chao Liu's avatar
Chao Liu committed
284
    using type = Sequence<Reduce{}(I, Init)>;
Chao Liu's avatar
Chao Liu committed
285
286
};

Chao Liu's avatar
Chao Liu committed
287
template <typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
288
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
289
{
Chao Liu's avatar
Chao Liu committed
290
    using type = Sequence<>;
Chao Liu's avatar
Chao Liu committed
291
292
};

Chao Liu's avatar
Chao Liu committed
293
// split sequence
Chao Liu's avatar
Chao Liu committed
294
template <typename Seq, index_t I>
Chao Liu's avatar
Chao Liu committed
295
296
struct sequence_split
{
Chao Liu's avatar
Chao Liu committed
297
    static constexpr index_t NSize = Seq{}.Size();
Chao Liu's avatar
Chao Liu committed
298

Chao Liu's avatar
Chao Liu committed
299
300
    using range0 = typename arithmetic_sequence_gen<0, I, 1>::type;
    using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type;
Chao Liu's avatar
Chao Liu committed
301

Chao Liu's avatar
Chao Liu committed
302
303
    using left_type  = decltype(Seq::Extract(range0{}));
    using right_type = decltype(Seq::Extract(range1{}));
Chao Liu's avatar
Chao Liu committed
304
305
};

Chao Liu's avatar
Chao Liu committed
306
// reverse sequence
Chao Liu's avatar
Chao Liu committed
307
template <typename Seq>
Chao Liu's avatar
Chao Liu committed
308
309
struct sequence_reverse
{
Chao Liu's avatar
Chao Liu committed
310
    static constexpr index_t NSize = Seq{}.Size();
Chao Liu's avatar
Chao Liu committed
311
312

    using seq_split = sequence_split<Seq, NSize / 2>;
Chao Liu's avatar
Chao Liu committed
313
    using type      = typename sequence_merge<
Chao Liu's avatar
Chao Liu committed
314
315
        typename sequence_reverse<typename seq_split::right_type>::type,
        typename sequence_reverse<typename seq_split::left_type>::type>::type;
Chao Liu's avatar
Chao Liu committed
316
317
318
319
320
};

template <index_t I>
struct sequence_reverse<Sequence<I>>
{
Chao Liu's avatar
Chao Liu committed
321
    using type = Sequence<I>;
Chao Liu's avatar
Chao Liu committed
322
323
324
325
326
};

template <index_t I0, index_t I1>
struct sequence_reverse<Sequence<I0, I1>>
{
Chao Liu's avatar
Chao Liu committed
327
    using type = Sequence<I1, I0>;
Chao Liu's avatar
Chao Liu committed
328
};
Chao Liu's avatar
Chao Liu committed
329

Chao Liu's avatar
Chao Liu committed
330
#if 1
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
// Element-wise reduction of one or more equally sized sequences:
// type = Sequence<Reduce{}(x0, Reduce{}(y0, ...)), Reduce{}(x1, Reduce{}(y1, ...)), ...>
template <typename Reduce, typename Seq, typename... Seqs>
struct sequence_reduce
{
    // reduce everything after the first sequence, then combine the first sequence with it
    using type = typename sequence_reduce<Reduce,
                                          Seq,
                                          typename sequence_reduce<Reduce, Seqs...>::type>::type;
};

// two sequences: apply Reduce to each pair of elements
template <typename Reduce, index_t... Xs, index_t... Ys>
struct sequence_reduce<Reduce, Sequence<Xs...>, Sequence<Ys...>>
{
    // same check and message as the element-wise Sequence operators
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    using type = Sequence<Reduce{}(Xs, Ys)...>;
};

// single sequence: nothing to reduce
template <typename Reduce, typename Seq>
struct sequence_reduce<Reduce, Seq>
{
    using type = Seq;
};
#endif

Chao Liu's avatar
Chao Liu committed
352
353
template <typename Values, typename Ids, typename Compare>
struct sequence_sort_impl
Chao Liu's avatar
Chao Liu committed
354
{
Chao Liu's avatar
Chao Liu committed
355
356
357
358
359
360
361
    template <typename LeftValues,
              typename LeftIds,
              typename RightValues,
              typename RightIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
Chao Liu's avatar
Chao Liu committed
362
363
    struct sorted_sequence_merge_impl
    {
Chao Liu's avatar
Chao Liu committed
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
        static constexpr bool choose_left = LeftValues::Front() < RightValues::Front();

        static constexpr index_t chosen_value =
            choose_left ? LeftValues::Front() : RightValues::Front();
        static constexpr index_t chosen_id = choose_left ? LeftIds::Front() : RightIds::Front();

        using new_merged_values = decltype(MergedValues::PushBack(Number<chosen_value>{}));
        using new_merged_ids    = decltype(MergedIds::PushBack(Number<chosen_id>{}));

        using new_left_values =
            typename conditional<choose_left, decltype(LeftValues::PopFront()), LeftValues>::type;
        using new_left_ids =
            typename conditional<choose_left, decltype(LeftIds::PopFront()), LeftIds>::type;

        using new_right_values =
            typename conditional<choose_left, RightValues, decltype(RightValues::PopFront())>::type;
        using new_right_ids =
            typename conditional<choose_left, RightIds, decltype(RightIds::PopFront())>::type;

        using merge = sorted_sequence_merge_impl<new_left_values,
                                                 new_left_ids,
                                                 new_right_values,
                                                 new_right_ids,
                                                 new_merged_values,
                                                 new_merged_ids,
                                                 Comp>;
        // this is output
        using merged_values = typename merge::merged_values;
        using merged_ids    = typename merge::merged_ids;
Chao Liu's avatar
Chao Liu committed
393
394
    };

Chao Liu's avatar
Chao Liu committed
395
396
397
398
399
400
401
402
403
404
405
406
    template <typename LeftValues,
              typename LeftIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
    struct sorted_sequence_merge_impl<LeftValues,
                                      LeftIds,
                                      Sequence<>,
                                      Sequence<>,
                                      MergedValues,
                                      MergedIds,
                                      Comp>
Chao Liu's avatar
Chao Liu committed
407
    {
Chao Liu's avatar
Chao Liu committed
408
409
        using merged_values = typename sequence_merge<MergedValues, LeftValues>::type;
        using merged_ids    = typename sequence_merge<MergedIds, LeftIds>::type;
Chao Liu's avatar
Chao Liu committed
410
411
    };

Chao Liu's avatar
Chao Liu committed
412
413
414
415
416
417
418
419
420
421
422
423
    template <typename RightValues,
              typename RightIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
    struct sorted_sequence_merge_impl<Sequence<>,
                                      Sequence<>,
                                      RightValues,
                                      RightIds,
                                      MergedValues,
                                      MergedIds,
                                      Comp>
Chao Liu's avatar
Chao Liu committed
424
    {
Chao Liu's avatar
Chao Liu committed
425
426
        using merged_values = typename sequence_merge<MergedValues, RightValues>::type;
        using merged_ids    = typename sequence_merge<MergedIds, RightIds>::type;
Chao Liu's avatar
Chao Liu committed
427
428
    };

Chao Liu's avatar
Chao Liu committed
429
430
431
432
433
    template <typename LeftValues,
              typename LeftIds,
              typename RightValues,
              typename RightIds,
              typename Comp>
Chao Liu's avatar
Chao Liu committed
434
435
    struct sorted_sequence_merge
    {
Chao Liu's avatar
Chao Liu committed
436
437
438
439
440
441
442
443
444
445
        using merge = sorted_sequence_merge_impl<LeftValues,
                                                 LeftIds,
                                                 RightValues,
                                                 RightIds,
                                                 Sequence<>,
                                                 Sequence<>,
                                                 Comp>;

        using merged_values = typename merge::merged_values;
        using merged_ids    = typename merge::merged_ids;
Chao Liu's avatar
Chao Liu committed
446
447
    };

Chao Liu's avatar
Chao Liu committed
448
449
450
451
    static constexpr index_t nsize = Values::Size();

    using split_unsorted_values = sequence_split<Values, nsize / 2>;
    using split_unsorted_ids    = sequence_split<Ids, nsize / 2>;
Chao Liu's avatar
Chao Liu committed
452

Chao Liu's avatar
Chao Liu committed
453
454
455
456
457
    using left_unsorted_values = typename split_unsorted_values::left_type;
    using left_unsorted_ids    = typename split_unsorted_ids::left_type;
    using left_sort          = sequence_sort_impl<left_unsorted_values, left_unsorted_ids, Compare>;
    using left_sorted_values = typename left_sort::sorted_values;
    using left_sorted_ids    = typename left_sort::sorted_ids;
Chao Liu's avatar
Chao Liu committed
458

Chao Liu's avatar
Chao Liu committed
459
460
461
462
463
464
465
466
467
468
469
470
471
472
    using right_unsorted_values = typename split_unsorted_values::right_type;
    using right_unsorted_ids    = typename split_unsorted_ids::right_type;
    using right_sort = sequence_sort_impl<right_unsorted_values, right_unsorted_ids, Compare>;
    using right_sorted_values = typename right_sort::sorted_values;
    using right_sorted_ids    = typename right_sort::sorted_ids;

    using merged_sorted = sorted_sequence_merge<left_sorted_values,
                                                left_sorted_ids,
                                                right_sorted_values,
                                                right_sorted_ids,
                                                Compare>;

    using sorted_values = typename merged_sorted::merged_values;
    using sorted_ids    = typename merged_sorted::merged_ids;
Chao Liu's avatar
Chao Liu committed
473
474
};

Chao Liu's avatar
Chao Liu committed
475
476
template <index_t ValueX, index_t ValueY, index_t IdX, index_t IdY, typename Compare>
struct sequence_sort_impl<Sequence<ValueX, ValueY>, Sequence<IdX, IdY>, Compare>
Chao Liu's avatar
Chao Liu committed
477
{
Chao Liu's avatar
Chao Liu committed
478
479
480
481
482
483
    static constexpr bool choose_x = Compare{}(ValueX, ValueY);

    using sorted_values =
        typename conditional<choose_x, Sequence<ValueX, ValueY>, Sequence<ValueY, ValueX>>::type;
    using sorted_ids = typename conditional<choose_x, Sequence<IdX, IdY>, Sequence<IdY, IdX>>::type;
};
Chao Liu's avatar
Chao Liu committed
484

Chao Liu's avatar
Chao Liu committed
485
486
487
488
489
template <index_t Value, index_t Id, typename Compare>
struct sequence_sort_impl<Sequence<Value>, Sequence<Id>, Compare>
{
    using sorted_values = Sequence<Value>;
    using sorted_ids    = Sequence<Id>;
Chao Liu's avatar
Chao Liu committed
490
491
};

492
493
494
495
496
497
498
template <typename Compare>
struct sequence_sort_impl<Sequence<>, Sequence<>, Compare>
{
    using sorted_values = Sequence<>;
    using sorted_ids    = Sequence<>;
};

Chao Liu's avatar
Chao Liu committed
499
500
template <typename Values, typename Compare>
struct sequence_sort
Chao Liu's avatar
Chao Liu committed
501
{
Chao Liu's avatar
Chao Liu committed
502
503
504
505
506
507
    using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type;
    using sort         = sequence_sort_impl<Values, unsorted_ids, Compare>;

    // this is output
    using type                = typename sort::sorted_values;
    using sorted2unsorted_map = typename sort::sorted_ids;
Chao Liu's avatar
Chao Liu committed
508
509
};

Chao Liu's avatar
Chao Liu committed
510
template <typename Values, typename Less, typename Equal>
Chao Liu's avatar
Chao Liu committed
511
512
struct sequence_unique_sort
{
Chao Liu's avatar
Chao Liu committed
513
514
515
516
517
    template <typename RemainValues,
              typename RemainIds,
              typename UniquifiedValues,
              typename UniquifiedIds,
              typename Eq>
Chao Liu's avatar
Chao Liu committed
518
519
    struct sorted_sequence_uniquify_impl
    {
Chao Liu's avatar
Chao Liu committed
520
521
522
523
524
525
526
527
528
529
530
531
        static constexpr index_t current_value = RemainValues::Front();
        static constexpr index_t current_id    = RemainIds::Front();

        static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back());

        using new_remain_values = decltype(RemainValues::PopFront());
        using new_remain_ids    = decltype(RemainIds::PopFront());

        using new_uniquified_values =
            typename conditional<is_unique_value,
                                 decltype(UniquifiedValues::PushBack(Number<current_value>{})),
                                 UniquifiedValues>::type;
Chao Liu's avatar
Chao Liu committed
532

Chao Liu's avatar
Chao Liu committed
533
534
535
536
537
538
539
540
541
542
543
544
545
546
        using new_uniquified_ids =
            typename conditional<is_unique_value,
                                 decltype(UniquifiedIds::PushBack(Number<current_id>{})),
                                 UniquifiedIds>::type;

        using uniquify = sorted_sequence_uniquify_impl<new_remain_values,
                                                       new_remain_ids,
                                                       new_uniquified_values,
                                                       new_uniquified_ids,
                                                       Eq>;

        // this is output
        using uniquified_values = typename uniquify::uniquified_values;
        using uniquified_ids    = typename uniquify::uniquified_ids;
Chao Liu's avatar
Chao Liu committed
547
548
    };

Chao Liu's avatar
Chao Liu committed
549
550
551
552
553
554
    template <typename UniquifiedValues, typename UniquifiedIds, typename Eq>
    struct sorted_sequence_uniquify_impl<Sequence<>,
                                         Sequence<>,
                                         UniquifiedValues,
                                         UniquifiedIds,
                                         Eq>
Chao Liu's avatar
Chao Liu committed
555
    {
Chao Liu's avatar
Chao Liu committed
556
557
        using uniquified_values = UniquifiedValues;
        using uniquified_ids    = UniquifiedIds;
Chao Liu's avatar
Chao Liu committed
558
559
    };

Chao Liu's avatar
Chao Liu committed
560
    template <typename SortedValues, typename SortedIds, typename Eq>
Chao Liu's avatar
Chao Liu committed
561
562
    struct sorted_sequence_uniquify
    {
Chao Liu's avatar
Chao Liu committed
563
564
565
566
567
568
569
570
        using uniquify = sorted_sequence_uniquify_impl<decltype(SortedValues::PopFront()),
                                                       decltype(SortedIds::PopFront()),
                                                       Sequence<SortedValues::Front()>,
                                                       Sequence<SortedIds::Front()>,
                                                       Eq>;

        using uniquified_values = typename uniquify::uniquified_values;
        using uniquified_ids    = typename uniquify::uniquified_ids;
Chao Liu's avatar
Chao Liu committed
571
572
    };

Chao Liu's avatar
Chao Liu committed
573
574
575
    using sort          = sequence_sort<Values, Less>;
    using sorted_values = typename sort::type;
    using sorted_ids    = typename sort::sorted2unsorted_map;
Chao Liu's avatar
Chao Liu committed
576

Chao Liu's avatar
Chao Liu committed
577
578
579
580
581
    using uniquify = sorted_sequence_uniquify<sorted_values, sorted_ids, Equal>;

    // this is output
    using type                = typename uniquify::uniquified_values;
    using sorted2unsorted_map = typename uniquify::uniquified_ids;
Chao Liu's avatar
Chao Liu committed
582
583
};

Chao Liu's avatar
Chao Liu committed
584
template <typename SeqMap>
Chao Liu's avatar
Chao Liu committed
585
586
struct is_valid_sequence_map : is_same<typename arithmetic_sequence_gen<0, SeqMap::Size(), 1>::type,
                                       typename sequence_sort<SeqMap, math::less<index_t>>::type>
Chao Liu's avatar
Chao Liu committed
587
588
{
};
589

Chao Liu's avatar
Chao Liu committed
590
591
template <typename SeqMap>
struct sequence_map_inverse
Chao Liu's avatar
Chao Liu committed
592
{
Chao Liu's avatar
Chao Liu committed
593
594
595
596
597
    template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
    struct sequence_map_inverse_impl
    {
        static constexpr auto new_y2x =
            WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
Chao Liu's avatar
Chao Liu committed
598

Chao Liu's avatar
Chao Liu committed
599
600
601
602
        using type =
            typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::
                type;
    };
Chao Liu's avatar
Chao Liu committed
603

Chao Liu's avatar
Chao Liu committed
604
605
606
607
608
    template <typename X2Y, typename WorkingY2X, index_t XBegin>
    struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
    {
        using type = WorkingY2X;
    };
Chao Liu's avatar
Chao Liu committed
609
610

    using type =
Chao Liu's avatar
Chao Liu committed
611
612
        typename sequence_map_inverse_impl<SeqMap,
                                           typename uniform_sequence_gen<SeqMap::Size(), 0>::type,
Chao Liu's avatar
Chao Liu committed
613
                                           0,
Chao Liu's avatar
Chao Liu committed
614
                                           SeqMap::Size()>::type;
Chao Liu's avatar
Chao Liu committed
615
616
};

617
618
619
620
621
622
// Two sequences are equal when they have the same length and all elements match pairwise.
// Sequences of different length compare unequal; the length check comes first because
// expanding two packs of different size in one fold expression does not compile.
template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr bool operator==(Sequence<Xs...>, Sequence<Ys...>)
{
    if constexpr(sizeof...(Xs) != sizeof...(Ys))
    {
        return false;
    }
    else
    {
        return ((Xs == Ys) && ...);
    }
}

Chao Liu's avatar
Chao Liu committed
623
template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
624
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
625
626
627
628
629
630
631
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs + Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
632
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
633
634
635
636
637
638
639
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs - Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
640
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
641
642
643
644
645
646
647
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs * Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
648
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
649
650
651
652
653
654
655
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs / Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
656
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
657
658
659
660
661
662
663
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs % Ys)...>{};
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
664
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
665
{
Chao Liu's avatar
Chao Liu committed
666
    return Sequence<(Xs + Y)...>{};
Chao Liu's avatar
Chao Liu committed
667
668
669
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
670
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
671
{
Chao Liu's avatar
Chao Liu committed
672
    return Sequence<(Xs - Y)...>{};
Chao Liu's avatar
Chao Liu committed
673
674
675
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
676
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
677
{
Chao Liu's avatar
Chao Liu committed
678
    return Sequence<(Xs * Y)...>{};
Chao Liu's avatar
Chao Liu committed
679
680
681
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
682
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
683
{
Chao Liu's avatar
Chao Liu committed
684
    return Sequence<(Xs / Y)...>{};
Chao Liu's avatar
Chao Liu committed
685
686
687
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
688
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
689
{
Chao Liu's avatar
Chao Liu committed
690
    return Sequence<(Xs % Y)...>{};
Chao Liu's avatar
Chao Liu committed
691
692
}

Chao Liu's avatar
Chao Liu committed
693
694
// Adds every element to the compile-time constant Y: result[i] = Y + Xs[i].
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator+(Number<Y>, Sequence<Xs...>)
{
    using sum_seq = Sequence<(Y + Xs)...>;
    return sum_seq{};
}

Chao Liu's avatar
Chao Liu committed
699
700
// Subtracts every element from the compile-time constant Y: result[i] = Y - Xs[i].
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
{
    using difference_seq = Sequence<(Y - Xs)...>;
    return difference_seq{};
}

Chao Liu's avatar
Chao Liu committed
705
706
// Multiplies the compile-time constant Y by every element: result[i] = Y * Xs[i].
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator*(Number<Y>, Sequence<Xs...>)
{
    using product_seq = Sequence<(Y * Xs)...>;
    return product_seq{};
}

Chao Liu's avatar
Chao Liu committed
711
712
// Divides the compile-time constant Y by every element: result[i] = Y / Xs[i].
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator/(Number<Y>, Sequence<Xs...>)
{
    using quotient_seq = Sequence<(Y / Xs)...>;
    return quotient_seq{};
}

Chao Liu's avatar
Chao Liu committed
717
718
// Takes the compile-time constant Y modulo every element: result[i] = Y % Xs[i].
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator%(Number<Y>, Sequence<Xs...>)
{
    using remainder_seq = Sequence<(Y % Xs)...>;
    return remainder_seq{};
}

723
724
725
726
727
728
// Returns the Sequence that remains after dropping the first element I.
// Only matches Sequences with at least one element.
template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
{
    using tail_seq = Sequence<Is...>;
    return tail_seq{};
}

Chao Liu's avatar
Chao Liu committed
729
template <typename Seq>
Chao Liu's avatar
Chao Liu committed
730
__host__ __device__ constexpr auto sequence_pop_back(Seq)
731
{
Chao Liu's avatar
Chao Liu committed
732
    static_assert(Seq::Size() > 0, "wrong! cannot pop an empty Sequence!");
733
    return sequence_pop_front(Seq::Reverse()).Reverse();
734
}
735

Chao Liu's avatar
Chao Liu committed
736
737
738
739
740
741
// Function-style wrapper around sequence_merge: returns a value of
// sequence_merge<Seqs...>::type for the given Sequence arguments.
template <typename... Seqs>
__host__ __device__ constexpr auto merge_sequences(Seqs...)
{
    using merged_seq = typename sequence_merge<Seqs...>::type;
    return merged_seq{};
}

742
743
744
745
746
747
// Applies the unary functor f to every element: result[i] = f(Xs[i]).
// f must be usable in a constant expression, since its results become template arguments.
template <typename F, index_t... Xs>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
{
    using mapped_seq = Sequence<f(Xs)...>;
    return mapped_seq{};
}

Chao Liu's avatar
Chao Liu committed
748
template <typename F, index_t... Xs, index_t... Ys>
749
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
750
{
751
    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize, "Dim not the same");
752
753
754
755

    return Sequence<f(Xs, Ys)...>{};
}

Chao Liu's avatar
Chao Liu committed
756
template <typename F, index_t... Xs, index_t... Ys, index_t... Zs>
757
758
759
760
761
762
763
764
765
766
__host__ __device__ constexpr auto
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
{
    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize &&
                      Sequence<Xs...>::mSize == Sequence<Zs...>::mSize,
                  "Dim not the same");

    return Sequence<f(Xs, Ys, Zs)...>{};
}

Chao Liu's avatar
Chao Liu committed
767
template <typename Seq, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
768
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number<Init>)
769
{
Chao Liu's avatar
Chao Liu committed
770
    return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::type{};
771
772
}

Chao Liu's avatar
Chao Liu committed
773
774
775
776
777
778
779
// Reverse exclusive scan built from the inclusive one: the reverse inclusive scan of
// Seq without its first element, with Init appended at the back.
template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr auto reverse_exclusive_scan_sequence(Seq, Reduce, Number<Init>)
{
    const auto init_value = Number<Init>{};

    const auto tail_scan =
        reverse_inclusive_scan_sequence(Seq::PopFront(), Reduce{}, Number<Init>{});

    return tail_scan.PushBack(init_value);
}

Chao Liu's avatar
Chao Liu committed
780
template <typename Seq, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
781
__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<Init>)
782
{
Chao Liu's avatar
Chao Liu committed
783
    return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
784
}
785

Chao Liu's avatar
Chao Liu committed
786
template <typename Seq, index_t... Is>
787
__host__ __device__ constexpr auto pick_sequence_elements_by_ids(Seq, Sequence<Is...> /* ids */)
Chao Liu's avatar
Chao Liu committed
788
789
790
791
{
    return Sequence<Seq::At(Number<Is>{})...>{};
}

Chao Liu's avatar
Chao Liu committed
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
#if 1
namespace detail {
// Recursive helper for pick_sequence_elements_by_mask.
// Each step looks at the front of RemainSeq / RemainMask: when the mask front is
// non-zero the element is appended to WorkSeq, otherwise WorkSeq is kept as is. The
// recursion then continues on both Sequences with their fronts removed.
template <typename WorkSeq, typename RemainSeq, typename RemainMask>
struct pick_sequence_elements_by_mask_impl
{
    // WorkSeq with the current front element appended
    using work_seq_with_front = decltype(WorkSeq::PushBack(RemainSeq::Front()));

    using new_work_seq =
        typename conditional<RemainMask::Front(), work_seq_with_front, WorkSeq>::type;

    using remaining_seq  = decltype(RemainSeq::PopFront());
    using remaining_mask = decltype(RemainMask::PopFront());

    using type = typename pick_sequence_elements_by_mask_impl<new_work_seq,
                                                              remaining_seq,
                                                              remaining_mask>::type;
};

// Recursion end: nothing left to inspect, WorkSeq is the result.
template <typename WorkSeq>
struct pick_sequence_elements_by_mask_impl<WorkSeq, Sequence<>, Sequence<>>
{
    using type = WorkSeq;
};

} // namespace detail

815
816
817
// Keeps the elements of Seq whose corresponding Mask entry is non-zero, preserving
// their order. Seq and Mask must have the same size.
template <typename Seq, typename Mask>
__host__ __device__ constexpr auto pick_sequence_elements_by_mask(Seq, Mask)
{
    static_assert(Seq::Size() == Mask::Size(), "wrong!");

    // start the recursion with an empty working Sequence
    using picked_seq =
        typename detail::pick_sequence_elements_by_mask_impl<Sequence<>, Seq, Mask>::type;

    return picked_seq{};
}

Chao Liu's avatar
Chao Liu committed
823
824
825
namespace detail {
template <typename WorkSeq, typename RemainValues, typename RemainIds>
struct modify_sequence_elements_by_ids_impl
Chao Liu's avatar
Chao Liu committed
826
{
Chao Liu's avatar
Chao Liu committed
827
    using new_work_seq = decltype(WorkSeq::Modify(RemainIds::Front(), RemainValues::Front()));
Chao Liu's avatar
Chao Liu committed
828

Chao Liu's avatar
Chao Liu committed
829
830
831
832
833
    using type =
        typename modify_sequence_elements_by_ids_impl<new_work_seq,
                                                      decltype(RemainValues::PopFront()),
                                                      decltype(RemainIds::PopFront())>::type;
};
Chao Liu's avatar
Chao Liu committed
834

Chao Liu's avatar
Chao Liu committed
835
836
837
838
template <typename WorkSeq>
struct modify_sequence_elements_by_ids_impl<WorkSeq, Sequence<>, Sequence<>>
{
    using type = WorkSeq;
Chao Liu's avatar
Chao Liu committed
839
};
Chao Liu's avatar
Chao Liu committed
840
841
842
843
844
845
846
847
848
849
} // namespace detail

// Returns Seq after applying Seq::Modify for every (id, value) pair taken from Ids and
// Values, in order. Values and Ids must have the same size, which must not exceed
// the size of Seq.
template <typename Seq, typename Values, typename Ids>
__host__ __device__ constexpr auto modify_sequence_elements_by_ids(Seq, Values, Ids)
{
    static_assert(Values::Size() == Ids::Size() && Seq::Size() >= Values::Size(), "wrong!");

    using modified_seq =
        typename detail::modify_sequence_elements_by_ids_impl<Seq, Values, Ids>::type;

    return modified_seq{};
}
#endif
Chao Liu's avatar
Chao Liu committed
850

Chao Liu's avatar
Chao Liu committed
851
template <typename Seq, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
852
__host__ __device__ constexpr index_t
Chao Liu's avatar
Chao Liu committed
853
reduce_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
Chao Liu's avatar
Chao Liu committed
854
855
856
{
    index_t result = Init;

Chao Liu's avatar
Chao Liu committed
857
858
859
860
    for(index_t i = 0; i < Seq::Size(); ++i)
    {
        result = f(result, Seq::At(i));
    }
Chao Liu's avatar
Chao Liu committed
861
862
863
864

    return result;
}

Chao Liu's avatar
Chao Liu committed
865
866
// TODO: a generic any_of for any container
template <typename Seq, typename F>
Chao Liu's avatar
Chao Liu committed
867
__host__ __device__ constexpr bool sequence_any_of(Seq, F f)
Chao Liu's avatar
Chao Liu committed
868
869
870
871
872
873
874
875
876
877
878
879
880
{
    bool flag = false;

    for(index_t i = 0; i < Seq::Size(); ++i)
    {
        flag = flag || f(Seq::At(i));
    }

    return flag;
}

// TODO: a generic all_of for any container
template <typename Seq, typename F>
Chao Liu's avatar
Chao Liu committed
881
__host__ __device__ constexpr bool sequence_all_of(Seq, F f)
Chao Liu's avatar
Chao Liu committed
882
883
884
885
886
887
888
889
890
891
892
{
    bool flag = true;

    for(index_t i = 0; i < Seq::Size(); ++i)
    {
        flag = flag && f(Seq::At(i));
    }

    return flag;
}

893
894
895
896
897
898
// Shorthand for the nested `type` of sequence_merge applied to two Sequences.
template <typename SeqA, typename SeqB>
using sequence_merge_t = typename sequence_merge<SeqA, SeqB>::type;

// Shorthand for the nested `type` of uniform_sequence_gen<Length, Value>.
template <index_t Length, index_t Value>
using uniform_sequence_gen_t = typename uniform_sequence_gen<Length, Value>::type;

899
} // namespace ck