sequence.hpp 28.8 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
// SPDX-License-Identifier: MIT
2
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
Chao Liu's avatar
Chao Liu committed
3

4
#pragma once
5

6
7
#include <ostream>

8
9
10
11
#include "ck/utility/integral_constant.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/functional.hpp"
#include "ck/utility/math.hpp"
12

13
14
namespace ck {

Chao Liu's avatar
Chao Liu committed
15
16
17
template <index_t, index_t, index_t>
struct static_for;

18
19
20
template <index_t...>
struct Sequence;

Chao Liu's avatar
Chao Liu committed
21
template <typename Seq, index_t I>
22
23
struct sequence_split;

Chao Liu's avatar
Chao Liu committed
24
template <typename>
25
struct sequence_reverse;
Chao Liu's avatar
Chao Liu committed
26

Chao Liu's avatar
Chao Liu committed
27
template <typename>
Chao Liu's avatar
Chao Liu committed
28
29
struct sequence_map_inverse;

Chao Liu's avatar
Chao Liu committed
30
template <typename>
31
32
33
34
35
struct is_valid_sequence_map;

template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>);

Chao Liu's avatar
Chao Liu committed
36
template <typename Seq>
37
38
__host__ __device__ constexpr auto sequence_pop_back(Seq);

Chao Liu's avatar
Chao Liu committed
39
template <index_t... Is>
40
41
struct Sequence
{
Chao Liu's avatar
Chao Liu committed
42
43
    using Type      = Sequence;
    using data_type = index_t;
44

45
    static constexpr index_t mSize = sizeof...(Is);
46

Chao Liu's avatar
Chao Liu committed
47
    __host__ __device__ static constexpr auto Size() { return Number<mSize>{}; }
48

Chao Liu's avatar
Chao Liu committed
49
50
51
    __host__ __device__ static constexpr auto GetSize() { return Size(); }

    __host__ __device__ static constexpr index_t At(index_t I)
52
    {
Chao Liu's avatar
Chao Liu committed
53
54
55
56
57
58
        // the last dummy element is to prevent compiler complain about empty array, when mSize = 0
        const index_t mData[mSize + 1] = {Is..., 0};
        return mData[I];
    }

    template <index_t I>
Chao Liu's avatar
Chao Liu committed
59
    __host__ __device__ static constexpr auto At(Number<I>)
Chao Liu's avatar
Chao Liu committed
60
    {
Chao Liu's avatar
Chao Liu committed
61
62
        static_assert(I < mSize, "wrong! I too large");

Chao Liu's avatar
Chao Liu committed
63
        return Number<At(I)>{};
Chao Liu's avatar
Chao Liu committed
64
65
    }

Chao Liu's avatar
Chao Liu committed
66
    template <index_t I>
Chao Liu's avatar
Chao Liu committed
67
    __host__ __device__ static constexpr auto Get(Number<I>)
Chao Liu's avatar
Chao Liu committed
68
    {
Chao Liu's avatar
Chao Liu committed
69
        return At(Number<I>{});
70
71
    }

Chao Liu's avatar
Chao Liu committed
72
73
74
75
76
    template <typename I>
    __host__ __device__ constexpr auto operator[](I i) const
    {
        return At(i);
    }
Chao Liu's avatar
Chao Liu committed
77

78
    template <index_t... IRs>
79
    __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
80
    {
Chao Liu's avatar
Chao Liu committed
81
        static_assert(sizeof...(Is) == sizeof...(IRs),
Chao Liu's avatar
Chao Liu committed
82
                      "wrong! reorder map should have the same size as Sequence to be rerodered");
Chao Liu's avatar
Chao Liu committed
83

Chao Liu's avatar
Chao Liu committed
84
85
        static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");

Chao Liu's avatar
Chao Liu committed
86
        return Sequence<Type::At(Number<IRs>{})...>{};
87
88
    }

Chao Liu's avatar
Chao Liu committed
89
    // MapOld2New is Sequence<...>
Chao Liu's avatar
Chao Liu committed
90
    template <typename MapOld2New>
Chao Liu's avatar
Chao Liu committed
91
92
    __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
    {
Chao Liu's avatar
Chao Liu committed
93
        static_assert(MapOld2New::Size() == Size(),
Chao Liu's avatar
Chao Liu committed
94
95
96
97
98
99
100
                      "wrong! reorder map should have the same size as Sequence to be rerodered");

        static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");

        return ReorderGivenNew2Old(typename sequence_map_inverse<MapOld2New>::type{});
    }

101
102
103
104
    __host__ __device__ static constexpr auto Reverse()
    {
        return typename sequence_reverse<Type>::type{};
    }
Chao Liu's avatar
Chao Liu committed
105

Chao Liu's avatar
Chao Liu committed
106
    __host__ __device__ static constexpr auto Front()
107
    {
Chao Liu's avatar
Chao Liu committed
108
        static_assert(mSize > 0, "wrong!");
Chao Liu's avatar
Chao Liu committed
109
        return At(Number<0>{});
110
    }
111

Chao Liu's avatar
Chao Liu committed
112
    __host__ __device__ static constexpr auto Back()
113
    {
Chao Liu's avatar
Chao Liu committed
114
        static_assert(mSize > 0, "wrong!");
Chao Liu's avatar
Chao Liu committed
115
        return At(Number<mSize - 1>{});
116
    }
117

118
    __host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); }
Chao Liu's avatar
Chao Liu committed
119

120
    __host__ __device__ static constexpr auto PopBack() { return sequence_pop_back(Type{}); }
Chao Liu's avatar
Chao Liu committed
121
122
123

    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushFront(Sequence<Xs...>)
124
    {
Chao Liu's avatar
Chao Liu committed
125
        return Sequence<Xs..., Is...>{};
126
127
    }

Chao Liu's avatar
Chao Liu committed
128
129
    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushFront(Number<Xs>...)
130
    {
Chao Liu's avatar
Chao Liu committed
131
        return Sequence<Xs..., Is...>{};
132
133
    }

Chao Liu's avatar
Chao Liu committed
134
135
136
137
138
    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushBack(Sequence<Xs...>)
    {
        return Sequence<Is..., Xs...>{};
    }
139

Chao Liu's avatar
Chao Liu committed
140
    template <index_t... Xs>
Chao Liu's avatar
Chao Liu committed
141
    __host__ __device__ static constexpr auto PushBack(Number<Xs>...)
142
    {
Chao Liu's avatar
Chao Liu committed
143
144
        return Sequence<Is..., Xs...>{};
    }
Chao Liu's avatar
Chao Liu committed
145

Chao Liu's avatar
Chao Liu committed
146
    template <index_t... Ns>
147
    __host__ __device__ static constexpr auto Extract(Number<Ns>...)
Chao Liu's avatar
Chao Liu committed
148
    {
Chao Liu's avatar
Chao Liu committed
149
        return Sequence<Type::At(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
150
    }
Chao Liu's avatar
Chao Liu committed
151

Chao Liu's avatar
Chao Liu committed
152
    template <index_t... Ns>
153
    __host__ __device__ static constexpr auto Extract(Sequence<Ns...>)
Chao Liu's avatar
Chao Liu committed
154
    {
Chao Liu's avatar
Chao Liu committed
155
        return Sequence<Type::At(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
156
    }
157
158

    template <index_t I, index_t X>
159
160
    __host__ __device__ static constexpr auto Modify(Number<I>, Number<X>)
    {
Chao Liu's avatar
Chao Liu committed
161
        static_assert(I < Size(), "wrong!");
162
163

        using seq_split          = sequence_split<Type, I>;
Chao Liu's avatar
Chao Liu committed
164
165
        constexpr auto seq_left  = typename seq_split::left_type{};
        constexpr auto seq_right = typename seq_split::right_type{}.PopFront();
166
167
168

        return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
    }
Chao Liu's avatar
Chao Liu committed
169

Chao Liu's avatar
Chao Liu committed
170
    template <typename F>
Chao Liu's avatar
Chao Liu committed
171
172
173
174
    __host__ __device__ static constexpr auto Transform(F f)
    {
        return Sequence<f(Is)...>{};
    }
Chao Liu's avatar
Chao Liu committed
175
176
177
178
179
180
181
182

    __host__ __device__ static void Print()
    {
        printf("{");
        printf("size %d, ", index_t{Size()});
        static_for<0, Size(), 1>{}([&](auto i) { printf("%d ", At(i).value); });
        printf("}");
    }
183
184
};

Chao Liu's avatar
Chao Liu committed
185
// merge sequence
Chao Liu's avatar
Chao Liu committed
186
187
188
189
190
// Concatenates any number of Sequences: the tail is merged first, then the head is
// prepended through the two-Sequence specialization.
template <typename Seq, typename... Seqs>
struct sequence_merge
{
    using merged_tail = typename sequence_merge<Seqs...>::type;

    using type = typename sequence_merge<Seq, merged_tail>::type;
};
Chao Liu's avatar
Chao Liu committed
191

Chao Liu's avatar
Chao Liu committed
192
193
194
template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
{
Chao Liu's avatar
Chao Liu committed
195
    using type = Sequence<Xs..., Ys...>;
Chao Liu's avatar
Chao Liu committed
196
};
Chao Liu's avatar
Chao Liu committed
197

Chao Liu's avatar
Chao Liu committed
198
199
200
201
202
203
// Single argument: nothing to merge, the input is returned unchanged.
template <typename OnlySeq>
struct sequence_merge<OnlySeq>
{
    using type = OnlySeq;
};

Chao Liu's avatar
Chao Liu committed
204
// generate sequence
Chao Liu's avatar
Chao Liu committed
205
206
template <index_t NSize, typename F>
struct sequence_gen
Chao Liu's avatar
Chao Liu committed
207
{
Chao Liu's avatar
Chao Liu committed
208
209
210
211
212
213
    template <index_t IBegin, index_t NRemain, typename G>
    struct sequence_gen_impl
    {
        static constexpr index_t NRemainLeft  = NRemain / 2;
        static constexpr index_t NRemainRight = NRemain - NRemainLeft;
        static constexpr index_t IMiddle      = IBegin + NRemainLeft;
Chao Liu's avatar
Chao Liu committed
214

Chao Liu's avatar
Chao Liu committed
215
216
217
218
        using type = typename sequence_merge<
            typename sequence_gen_impl<IBegin, NRemainLeft, G>::type,
            typename sequence_gen_impl<IMiddle, NRemainRight, G>::type>::type;
    };
Chao Liu's avatar
Chao Liu committed
219

Chao Liu's avatar
Chao Liu committed
220
221
222
223
224
225
    template <index_t I, typename G>
    struct sequence_gen_impl<I, 1, G>
    {
        static constexpr index_t Is = G{}(Number<I>{});
        using type                  = Sequence<Is>;
    };
Chao Liu's avatar
Chao Liu committed
226

Chao Liu's avatar
Chao Liu committed
227
228
229
230
231
    template <index_t I, typename G>
    struct sequence_gen_impl<I, 0, G>
    {
        using type = Sequence<>;
    };
Chao Liu's avatar
Chao Liu committed
232

Chao Liu's avatar
Chao Liu committed
233
234
235
236
    using type = typename sequence_gen_impl<0, NSize, F>::type;
};

// arithmetic sequence
Chao Liu's avatar
Chao Liu committed
237
template <index_t IBegin, index_t IEnd, index_t Increment>
238
struct arithmetic_sequence_gen
Chao Liu's avatar
Chao Liu committed
239
{
Chao Liu's avatar
Chao Liu committed
240
241
242
243
244
245
246
247
    struct F
    {
        __host__ __device__ constexpr index_t operator()(index_t i) const
        {
            return i * Increment + IBegin;
        }
    };

248
249
250
251
252
253
254
    using type0 = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type;
    using type1 = Sequence<>;

    static constexpr bool kHasContent =
        (Increment > 0 && IBegin < IEnd) || (Increment < 0 && IBegin > IEnd);

    using type = typename conditional<kHasContent, type0, type1>::type;
Chao Liu's avatar
Chao Liu committed
255
256
257
258
259
260
};

// uniform sequence
template <index_t NSize, index_t I>
struct uniform_sequence_gen
{
Chao Liu's avatar
Chao Liu committed
261
    struct F
Chao Liu's avatar
Chao Liu committed
262
263
264
265
    {
        __host__ __device__ constexpr index_t operator()(index_t) const { return I; }
    };

Chao Liu's avatar
Chao Liu committed
266
    using type = typename sequence_gen<NSize, F>::type;
Chao Liu's avatar
Chao Liu committed
267
268
269
};

// reverse inclusive scan (with init) sequence
Chao Liu's avatar
Chao Liu committed
270
template <typename, typename, index_t>
Chao Liu's avatar
Chao Liu committed
271
struct sequence_reverse_inclusive_scan;
Chao Liu's avatar
Chao Liu committed
272

Chao Liu's avatar
Chao Liu committed
273
template <index_t I, index_t... Is, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
274
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
275
{
Chao Liu's avatar
Chao Liu committed
276
    using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::type;
Chao Liu's avatar
Chao Liu committed
277
278
279

    static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());

Chao Liu's avatar
Chao Liu committed
280
    using type = typename sequence_merge<Sequence<new_reduce>, old_scan>::type;
Chao Liu's avatar
Chao Liu committed
281
282
};

Chao Liu's avatar
Chao Liu committed
283
template <index_t I, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
284
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
285
{
Chao Liu's avatar
Chao Liu committed
286
    using type = Sequence<Reduce{}(I, Init)>;
Chao Liu's avatar
Chao Liu committed
287
288
};

Chao Liu's avatar
Chao Liu committed
289
template <typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
290
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
291
{
Chao Liu's avatar
Chao Liu committed
292
    using type = Sequence<>;
Chao Liu's avatar
Chao Liu committed
293
294
};

Chao Liu's avatar
Chao Liu committed
295
// split sequence
Chao Liu's avatar
Chao Liu committed
296
template <typename Seq, index_t I>
Chao Liu's avatar
Chao Liu committed
297
298
struct sequence_split
{
Chao Liu's avatar
Chao Liu committed
299
    static constexpr index_t NSize = Seq{}.Size();
Chao Liu's avatar
Chao Liu committed
300

Chao Liu's avatar
Chao Liu committed
301
302
    using range0 = typename arithmetic_sequence_gen<0, I, 1>::type;
    using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type;
Chao Liu's avatar
Chao Liu committed
303

Chao Liu's avatar
Chao Liu committed
304
305
    using left_type  = decltype(Seq::Extract(range0{}));
    using right_type = decltype(Seq::Extract(range1{}));
Chao Liu's avatar
Chao Liu committed
306
307
};

Chao Liu's avatar
Chao Liu committed
308
// reverse sequence
Chao Liu's avatar
Chao Liu committed
309
template <typename Seq>
Chao Liu's avatar
Chao Liu committed
310
311
struct sequence_reverse
{
Chao Liu's avatar
Chao Liu committed
312
    static constexpr index_t NSize = Seq{}.Size();
Chao Liu's avatar
Chao Liu committed
313
314

    using seq_split = sequence_split<Seq, NSize / 2>;
Chao Liu's avatar
Chao Liu committed
315
    using type      = typename sequence_merge<
Chao Liu's avatar
Chao Liu committed
316
317
        typename sequence_reverse<typename seq_split::right_type>::type,
        typename sequence_reverse<typename seq_split::left_type>::type>::type;
Chao Liu's avatar
Chao Liu committed
318
319
320
321
322
};

template <index_t I>
struct sequence_reverse<Sequence<I>>
{
Chao Liu's avatar
Chao Liu committed
323
    using type = Sequence<I>;
Chao Liu's avatar
Chao Liu committed
324
325
326
327
328
};

template <index_t I0, index_t I1>
struct sequence_reverse<Sequence<I0, I1>>
{
Chao Liu's avatar
Chao Liu committed
329
    using type = Sequence<I1, I0>;
Chao Liu's avatar
Chao Liu committed
330
};
Chao Liu's avatar
Chao Liu committed
331

Chao Liu's avatar
Chao Liu committed
332
#if 1
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
// Element-wise reduction over any number of Sequences: the tail is reduced first, then
// combined with the head through the two-Sequence specialization.
template <typename Reduce, typename Seq, typename... Seqs>
struct sequence_reduce
{
    using reduced_tail = typename sequence_reduce<Reduce, Seqs...>::type;

    using type = typename sequence_reduce<Reduce, Seq, reduced_tail>::type;
};

template <typename Reduce, index_t... Xs, index_t... Ys>
struct sequence_reduce<Reduce, Sequence<Xs...>, Sequence<Ys...>>
{
    using type = Sequence<Reduce{}(Xs, Ys)...>;
};

// Single Sequence: nothing to reduce against, it is returned unchanged.
template <typename Reduce, typename OnlySeq>
struct sequence_reduce<Reduce, OnlySeq>
{
    using type = OnlySeq;
};
#endif

Chao Liu's avatar
Chao Liu committed
354
355
template <typename Values, typename Ids, typename Compare>
struct sequence_sort_impl
Chao Liu's avatar
Chao Liu committed
356
{
Chao Liu's avatar
Chao Liu committed
357
358
359
360
361
362
363
    template <typename LeftValues,
              typename LeftIds,
              typename RightValues,
              typename RightIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
Chao Liu's avatar
Chao Liu committed
364
365
    struct sorted_sequence_merge_impl
    {
Chao Liu's avatar
Chao Liu committed
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
        static constexpr bool choose_left = LeftValues::Front() < RightValues::Front();

        static constexpr index_t chosen_value =
            choose_left ? LeftValues::Front() : RightValues::Front();
        static constexpr index_t chosen_id = choose_left ? LeftIds::Front() : RightIds::Front();

        using new_merged_values = decltype(MergedValues::PushBack(Number<chosen_value>{}));
        using new_merged_ids    = decltype(MergedIds::PushBack(Number<chosen_id>{}));

        using new_left_values =
            typename conditional<choose_left, decltype(LeftValues::PopFront()), LeftValues>::type;
        using new_left_ids =
            typename conditional<choose_left, decltype(LeftIds::PopFront()), LeftIds>::type;

        using new_right_values =
            typename conditional<choose_left, RightValues, decltype(RightValues::PopFront())>::type;
        using new_right_ids =
            typename conditional<choose_left, RightIds, decltype(RightIds::PopFront())>::type;

        using merge = sorted_sequence_merge_impl<new_left_values,
                                                 new_left_ids,
                                                 new_right_values,
                                                 new_right_ids,
                                                 new_merged_values,
                                                 new_merged_ids,
                                                 Comp>;
        // this is output
        using merged_values = typename merge::merged_values;
        using merged_ids    = typename merge::merged_ids;
Chao Liu's avatar
Chao Liu committed
395
396
    };

Chao Liu's avatar
Chao Liu committed
397
398
399
400
401
402
403
404
405
406
407
408
    template <typename LeftValues,
              typename LeftIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
    struct sorted_sequence_merge_impl<LeftValues,
                                      LeftIds,
                                      Sequence<>,
                                      Sequence<>,
                                      MergedValues,
                                      MergedIds,
                                      Comp>
Chao Liu's avatar
Chao Liu committed
409
    {
Chao Liu's avatar
Chao Liu committed
410
411
        using merged_values = typename sequence_merge<MergedValues, LeftValues>::type;
        using merged_ids    = typename sequence_merge<MergedIds, LeftIds>::type;
Chao Liu's avatar
Chao Liu committed
412
413
    };

Chao Liu's avatar
Chao Liu committed
414
415
416
417
418
419
420
421
422
423
424
425
    template <typename RightValues,
              typename RightIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
    struct sorted_sequence_merge_impl<Sequence<>,
                                      Sequence<>,
                                      RightValues,
                                      RightIds,
                                      MergedValues,
                                      MergedIds,
                                      Comp>
Chao Liu's avatar
Chao Liu committed
426
    {
Chao Liu's avatar
Chao Liu committed
427
428
        using merged_values = typename sequence_merge<MergedValues, RightValues>::type;
        using merged_ids    = typename sequence_merge<MergedIds, RightIds>::type;
Chao Liu's avatar
Chao Liu committed
429
430
    };

Chao Liu's avatar
Chao Liu committed
431
432
433
434
435
    template <typename LeftValues,
              typename LeftIds,
              typename RightValues,
              typename RightIds,
              typename Comp>
Chao Liu's avatar
Chao Liu committed
436
437
    struct sorted_sequence_merge
    {
Chao Liu's avatar
Chao Liu committed
438
439
440
441
442
443
444
445
446
447
        using merge = sorted_sequence_merge_impl<LeftValues,
                                                 LeftIds,
                                                 RightValues,
                                                 RightIds,
                                                 Sequence<>,
                                                 Sequence<>,
                                                 Comp>;

        using merged_values = typename merge::merged_values;
        using merged_ids    = typename merge::merged_ids;
Chao Liu's avatar
Chao Liu committed
448
449
    };

Chao Liu's avatar
Chao Liu committed
450
451
452
453
    static constexpr index_t nsize = Values::Size();

    using split_unsorted_values = sequence_split<Values, nsize / 2>;
    using split_unsorted_ids    = sequence_split<Ids, nsize / 2>;
Chao Liu's avatar
Chao Liu committed
454

Chao Liu's avatar
Chao Liu committed
455
456
457
458
459
    using left_unsorted_values = typename split_unsorted_values::left_type;
    using left_unsorted_ids    = typename split_unsorted_ids::left_type;
    using left_sort          = sequence_sort_impl<left_unsorted_values, left_unsorted_ids, Compare>;
    using left_sorted_values = typename left_sort::sorted_values;
    using left_sorted_ids    = typename left_sort::sorted_ids;
Chao Liu's avatar
Chao Liu committed
460

Chao Liu's avatar
Chao Liu committed
461
462
463
464
465
466
467
468
469
470
471
472
473
474
    using right_unsorted_values = typename split_unsorted_values::right_type;
    using right_unsorted_ids    = typename split_unsorted_ids::right_type;
    using right_sort = sequence_sort_impl<right_unsorted_values, right_unsorted_ids, Compare>;
    using right_sorted_values = typename right_sort::sorted_values;
    using right_sorted_ids    = typename right_sort::sorted_ids;

    using merged_sorted = sorted_sequence_merge<left_sorted_values,
                                                left_sorted_ids,
                                                right_sorted_values,
                                                right_sorted_ids,
                                                Compare>;

    using sorted_values = typename merged_sorted::merged_values;
    using sorted_ids    = typename merged_sorted::merged_ids;
Chao Liu's avatar
Chao Liu committed
475
476
};

Chao Liu's avatar
Chao Liu committed
477
478
// Two elements: keep the given order when Compare accepts (ValueX, ValueY), otherwise swap
// both the values and their ids.
template <index_t ValueX, index_t ValueY, index_t IdX, index_t IdY, typename Compare>
struct sequence_sort_impl<Sequence<ValueX, ValueY>, Sequence<IdX, IdY>, Compare>
{
    static constexpr bool choose_x = Compare{}(ValueX, ValueY);

    using sorted_values = typename conditional<choose_x,
                                               Sequence<ValueX, ValueY>,
                                               Sequence<ValueY, ValueX>>::type;

    using sorted_ids = typename conditional<choose_x,
                                            Sequence<IdX, IdY>,
                                            Sequence<IdY, IdX>>::type;
};
Chao Liu's avatar
Chao Liu committed
486

Chao Liu's avatar
Chao Liu committed
487
488
489
490
491
template <index_t Value, index_t Id, typename Compare>
struct sequence_sort_impl<Sequence<Value>, Sequence<Id>, Compare>
{
    using sorted_values = Sequence<Value>;
    using sorted_ids    = Sequence<Id>;
Chao Liu's avatar
Chao Liu committed
492
493
};

494
495
496
497
498
499
500
// No elements: both outputs are empty.
template <typename Comparator>
struct sequence_sort_impl<Sequence<>, Sequence<>, Comparator>
{
    using sorted_values = Sequence<>;
    using sorted_ids    = Sequence<>;
};

Chao Liu's avatar
Chao Liu committed
501
502
template <typename Values, typename Compare>
struct sequence_sort
Chao Liu's avatar
Chao Liu committed
503
{
Chao Liu's avatar
Chao Liu committed
504
505
506
507
508
509
    using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type;
    using sort         = sequence_sort_impl<Values, unsorted_ids, Compare>;

    // this is output
    using type                = typename sort::sorted_values;
    using sorted2unsorted_map = typename sort::sorted_ids;
Chao Liu's avatar
Chao Liu committed
510
511
};

Chao Liu's avatar
Chao Liu committed
512
template <typename Values, typename Less, typename Equal>
Chao Liu's avatar
Chao Liu committed
513
514
struct sequence_unique_sort
{
Chao Liu's avatar
Chao Liu committed
515
516
517
518
519
    template <typename RemainValues,
              typename RemainIds,
              typename UniquifiedValues,
              typename UniquifiedIds,
              typename Eq>
Chao Liu's avatar
Chao Liu committed
520
521
    struct sorted_sequence_uniquify_impl
    {
Chao Liu's avatar
Chao Liu committed
522
523
524
525
526
527
528
529
530
531
532
533
        static constexpr index_t current_value = RemainValues::Front();
        static constexpr index_t current_id    = RemainIds::Front();

        static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back());

        using new_remain_values = decltype(RemainValues::PopFront());
        using new_remain_ids    = decltype(RemainIds::PopFront());

        using new_uniquified_values =
            typename conditional<is_unique_value,
                                 decltype(UniquifiedValues::PushBack(Number<current_value>{})),
                                 UniquifiedValues>::type;
Chao Liu's avatar
Chao Liu committed
534

Chao Liu's avatar
Chao Liu committed
535
536
537
538
539
540
541
542
543
544
545
546
547
548
        using new_uniquified_ids =
            typename conditional<is_unique_value,
                                 decltype(UniquifiedIds::PushBack(Number<current_id>{})),
                                 UniquifiedIds>::type;

        using uniquify = sorted_sequence_uniquify_impl<new_remain_values,
                                                       new_remain_ids,
                                                       new_uniquified_values,
                                                       new_uniquified_ids,
                                                       Eq>;

        // this is output
        using uniquified_values = typename uniquify::uniquified_values;
        using uniquified_ids    = typename uniquify::uniquified_ids;
Chao Liu's avatar
Chao Liu committed
549
550
    };

Chao Liu's avatar
Chao Liu committed
551
552
553
554
555
556
    template <typename UniquifiedValues, typename UniquifiedIds, typename Eq>
    struct sorted_sequence_uniquify_impl<Sequence<>,
                                         Sequence<>,
                                         UniquifiedValues,
                                         UniquifiedIds,
                                         Eq>
Chao Liu's avatar
Chao Liu committed
557
    {
Chao Liu's avatar
Chao Liu committed
558
559
        using uniquified_values = UniquifiedValues;
        using uniquified_ids    = UniquifiedIds;
Chao Liu's avatar
Chao Liu committed
560
561
    };

Chao Liu's avatar
Chao Liu committed
562
    template <typename SortedValues, typename SortedIds, typename Eq>
Chao Liu's avatar
Chao Liu committed
563
564
    struct sorted_sequence_uniquify
    {
Chao Liu's avatar
Chao Liu committed
565
566
567
568
569
570
571
572
        using uniquify = sorted_sequence_uniquify_impl<decltype(SortedValues::PopFront()),
                                                       decltype(SortedIds::PopFront()),
                                                       Sequence<SortedValues::Front()>,
                                                       Sequence<SortedIds::Front()>,
                                                       Eq>;

        using uniquified_values = typename uniquify::uniquified_values;
        using uniquified_ids    = typename uniquify::uniquified_ids;
Chao Liu's avatar
Chao Liu committed
573
574
    };

Chao Liu's avatar
Chao Liu committed
575
576
577
    using sort          = sequence_sort<Values, Less>;
    using sorted_values = typename sort::type;
    using sorted_ids    = typename sort::sorted2unsorted_map;
Chao Liu's avatar
Chao Liu committed
578

Chao Liu's avatar
Chao Liu committed
579
580
581
582
583
    using uniquify = sorted_sequence_uniquify<sorted_values, sorted_ids, Equal>;

    // this is output
    using type                = typename uniquify::uniquified_values;
    using sorted2unsorted_map = typename uniquify::uniquified_ids;
Chao Liu's avatar
Chao Liu committed
584
585
};

Chao Liu's avatar
Chao Liu committed
586
template <typename SeqMap>
Chao Liu's avatar
Chao Liu committed
587
588
struct is_valid_sequence_map : is_same<typename arithmetic_sequence_gen<0, SeqMap::Size(), 1>::type,
                                       typename sequence_sort<SeqMap, math::less<index_t>>::type>
Chao Liu's avatar
Chao Liu committed
589
590
{
};
591

Chao Liu's avatar
Chao Liu committed
592
593
template <typename SeqMap>
struct sequence_map_inverse
Chao Liu's avatar
Chao Liu committed
594
{
Chao Liu's avatar
Chao Liu committed
595
596
597
598
599
    template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
    struct sequence_map_inverse_impl
    {
        static constexpr auto new_y2x =
            WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
Chao Liu's avatar
Chao Liu committed
600

Chao Liu's avatar
Chao Liu committed
601
602
603
604
        using type =
            typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::
                type;
    };
Chao Liu's avatar
Chao Liu committed
605

Chao Liu's avatar
Chao Liu committed
606
607
608
609
610
    template <typename X2Y, typename WorkingY2X, index_t XBegin>
    struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
    {
        using type = WorkingY2X;
    };
Chao Liu's avatar
Chao Liu committed
611
612

    using type =
Chao Liu's avatar
Chao Liu committed
613
614
        typename sequence_map_inverse_impl<SeqMap,
                                           typename uniform_sequence_gen<SeqMap::Size(), 0>::type,
Chao Liu's avatar
Chao Liu committed
615
                                           0,
Chao Liu's avatar
Chao Liu committed
616
                                           SeqMap::Size()>::type;
Chao Liu's avatar
Chao Liu committed
617
618
};

619
620
621
622
623
624
// Two Sequences are equal when they have the same length and the same value at every
// position. Sequences of different length compare unequal.
template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr bool operator==(Sequence<Xs...>, Sequence<Ys...>)
{
    // the fold below is only well-formed for packs of equal length, so the length check
    // has to discard it at compile time
    if constexpr(sizeof...(Xs) != sizeof...(Ys))
    {
        return false;
    }
    else
    {
        return ((Xs == Ys) && ...);
    }
}

Chao Liu's avatar
Chao Liu committed
625
template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
626
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
627
628
629
630
631
632
633
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs + Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
634
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
635
636
637
638
639
640
641
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs - Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
642
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
643
644
645
646
647
648
649
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs * Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
650
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
651
652
653
654
655
656
657
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs / Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
658
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
659
660
661
662
663
664
665
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs % Ys)...>{};
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
666
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
667
{
Chao Liu's avatar
Chao Liu committed
668
    return Sequence<(Xs + Y)...>{};
Chao Liu's avatar
Chao Liu committed
669
670
671
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
672
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
673
{
Chao Liu's avatar
Chao Liu committed
674
    return Sequence<(Xs - Y)...>{};
Chao Liu's avatar
Chao Liu committed
675
676
677
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
678
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
679
{
Chao Liu's avatar
Chao Liu committed
680
    return Sequence<(Xs * Y)...>{};
Chao Liu's avatar
Chao Liu committed
681
682
683
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
684
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
685
{
Chao Liu's avatar
Chao Liu committed
686
    return Sequence<(Xs / Y)...>{};
Chao Liu's avatar
Chao Liu committed
687
688
689
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
690
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
691
{
Chao Liu's avatar
Chao Liu committed
692
    return Sequence<(Xs % Y)...>{};
Chao Liu's avatar
Chao Liu committed
693
694
}

Chao Liu's avatar
Chao Liu committed
695
696
// Add every element of the Sequence to the compile-time constant Y (Y on the left).
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator+(Number<Y>, Sequence<Xs...>)
{
    using result_t = Sequence<(Y + Xs)...>;
    return result_t{};
}

Chao Liu's avatar
Chao Liu committed
701
702
// Subtract every element of the Sequence from the compile-time constant Y,
// producing Sequence<Y - Xs...>.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
{
    using result_t = Sequence<(Y - Xs)...>;
    return result_t{};
}

Chao Liu's avatar
Chao Liu committed
707
708
// Multiply the compile-time constant Y by every element of the Sequence (Y on the left).
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator*(Number<Y>, Sequence<Xs...>)
{
    using result_t = Sequence<(Y * Xs)...>;
    return result_t{};
}

Chao Liu's avatar
Chao Liu committed
713
714
// Divide the compile-time constant Y by every element of the Sequence,
// producing Sequence<Y / Xs...> (integer division).
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator/(Number<Y>, Sequence<Xs...>)
{
    using result_t = Sequence<(Y / Xs)...>;
    return result_t{};
}

Chao Liu's avatar
Chao Liu committed
719
720
// Take the remainder of the compile-time constant Y modulo every element of the
// Sequence, producing Sequence<Y % Xs...>.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator%(Number<Y>, Sequence<Xs...>)
{
    using result_t = Sequence<(Y % Xs)...>;
    return result_t{};
}

725
726
727
728
729
730
// Drop the first element I of a non-empty Sequence and return the remaining
// elements as a new Sequence.
template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
{
    using tail_t = Sequence<Is...>;
    return tail_t{};
}

Chao Liu's avatar
Chao Liu committed
731
template <typename Seq>
Chao Liu's avatar
Chao Liu committed
732
__host__ __device__ constexpr auto sequence_pop_back(Seq)
733
{
Chao Liu's avatar
Chao Liu committed
734
    static_assert(Seq::Size() > 0, "wrong! cannot pop an empty Sequence!");
735
    return sequence_pop_front(Seq::Reverse()).Reverse();
736
}
737

Chao Liu's avatar
Chao Liu committed
738
739
740
741
742
743
// Value-level wrapper around sequence_merge: returns an instance of
// sequence_merge<Seqs...>::type for the given Sequence arguments.
template <typename... Seqs>
__host__ __device__ constexpr auto merge_sequences(Seqs...)
{
    using merged_t = typename sequence_merge<Seqs...>::type;
    return merged_t{};
}

744
745
746
747
748
749
// Apply the unary functor f to every element of a Sequence and return the
// results as a new Sequence. f(Xs) is used as a template argument, so it has to
// be usable in a constant expression.
template <typename F, index_t... Xs>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
{
    using result_t = Sequence<f(Xs)...>;
    return result_t{};
}

Chao Liu's avatar
Chao Liu committed
750
template <typename F, index_t... Xs, index_t... Ys>
751
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
752
{
753
    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize, "Dim not the same");
754
755
756
757

    return Sequence<f(Xs, Ys)...>{};
}

Chao Liu's avatar
Chao Liu committed
758
template <typename F, index_t... Xs, index_t... Ys, index_t... Zs>
759
760
761
762
763
764
765
766
767
768
__host__ __device__ constexpr auto
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
{
    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize &&
                      Sequence<Xs...>::mSize == Sequence<Zs...>::mSize,
                  "Dim not the same");

    return Sequence<f(Xs, Ys, Zs)...>{};
}

Chao Liu's avatar
Chao Liu committed
769
template <typename Seq, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
770
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number<Init>)
771
{
Chao Liu's avatar
Chao Liu committed
772
    return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::type{};
773
774
}

Chao Liu's avatar
Chao Liu committed
775
776
777
778
779
780
781
// Reverse exclusive scan, built from the reverse inclusive scan: scan the
// Sequence with its first element removed, then append Init at the back.
template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr auto reverse_exclusive_scan_sequence(Seq, Reduce, Number<Init>)
{
    const auto tail_scan =
        reverse_inclusive_scan_sequence(Seq::PopFront(), Reduce{}, Number<Init>{});

    return tail_scan.PushBack(Number<Init>{});
}

Chao Liu's avatar
Chao Liu committed
782
template <typename Seq, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
783
__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<Init>)
784
{
Chao Liu's avatar
Chao Liu committed
785
    return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
786
}
787

Chao Liu's avatar
Chao Liu committed
788
template <typename Seq, index_t... Is>
789
__host__ __device__ constexpr auto pick_sequence_elements_by_ids(Seq, Sequence<Is...> /* ids */)
Chao Liu's avatar
Chao Liu committed
790
791
792
793
{
    return Sequence<Seq::At(Number<Is>{})...>{};
}

Chao Liu's avatar
Chao Liu committed
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
#if 1
namespace detail {
// Recursive helper for pick_sequence_elements_by_mask.
// Each step looks at the front of RemainSeq / RemainMask: when the mask front is
// non-zero the sequence front is appended to WorkSeq, otherwise WorkSeq is kept.
// The recursion then continues with both fronts popped.
template <typename WorkSeq, typename RemainSeq, typename RemainMask>
struct pick_sequence_elements_by_mask_impl
{
    // WorkSeq with the current front element of RemainSeq appended
    using work_seq_with_front = decltype(WorkSeq::PushBack(RemainSeq::Front()));

    using new_work_seq =
        typename conditional<RemainMask::Front(), work_seq_with_front, WorkSeq>::type;

    using remain_seq_tail  = decltype(RemainSeq::PopFront());
    using remain_mask_tail = decltype(RemainMask::PopFront());

    using type = typename pick_sequence_elements_by_mask_impl<new_work_seq,
                                                              remain_seq_tail,
                                                              remain_mask_tail>::type;
};

// Recursion end: nothing left to inspect, the accumulated WorkSeq is the result.
template <typename WorkSeq>
struct pick_sequence_elements_by_mask_impl<WorkSeq, Sequence<>, Sequence<>>
{
    using type = WorkSeq;
};

} // namespace detail

817
818
819
// Keep the elements of Seq whose corresponding Mask entry is non-zero.
// Seq and Mask must have the same size.
template <typename Seq, typename Mask>
__host__ __device__ constexpr auto pick_sequence_elements_by_mask(Seq, Mask)
{
    static_assert(Seq::Size() == Mask::Size(), "wrong!");

    // start from an empty working Sequence
    using picker_t = detail::pick_sequence_elements_by_mask_impl<Sequence<>, Seq, Mask>;
    return typename picker_t::type{};
}

Chao Liu's avatar
Chao Liu committed
825
826
827
namespace detail {
template <typename WorkSeq, typename RemainValues, typename RemainIds>
struct modify_sequence_elements_by_ids_impl
Chao Liu's avatar
Chao Liu committed
828
{
Chao Liu's avatar
Chao Liu committed
829
    using new_work_seq = decltype(WorkSeq::Modify(RemainIds::Front(), RemainValues::Front()));
Chao Liu's avatar
Chao Liu committed
830

Chao Liu's avatar
Chao Liu committed
831
832
833
834
835
    using type =
        typename modify_sequence_elements_by_ids_impl<new_work_seq,
                                                      decltype(RemainValues::PopFront()),
                                                      decltype(RemainIds::PopFront())>::type;
};
Chao Liu's avatar
Chao Liu committed
836

Chao Liu's avatar
Chao Liu committed
837
838
839
840
template <typename WorkSeq>
struct modify_sequence_elements_by_ids_impl<WorkSeq, Sequence<>, Sequence<>>
{
    using type = WorkSeq;
Chao Liu's avatar
Chao Liu committed
841
};
Chao Liu's avatar
Chao Liu committed
842
843
844
845
846
847
848
849
850
851
} // namespace detail

// Return a copy of Seq in which the positions listed in Ids have been modified
// with the corresponding entries of Values (via Seq::Modify, one pair at a time).
// Values and Ids must have the same size, which must not exceed the size of Seq.
template <typename Seq, typename Values, typename Ids>
__host__ __device__ constexpr auto modify_sequence_elements_by_ids(Seq, Values, Ids)
{
    static_assert(Values::Size() == Ids::Size() && Seq::Size() >= Values::Size(), "wrong!");

    using modifier_t = detail::modify_sequence_elements_by_ids_impl<Seq, Values, Ids>;
    return typename modifier_t::type{};
}
#endif
Chao Liu's avatar
Chao Liu committed
852

Chao Liu's avatar
Chao Liu committed
853
template <typename Seq, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
854
__host__ __device__ constexpr index_t
Chao Liu's avatar
Chao Liu committed
855
reduce_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
Chao Liu's avatar
Chao Liu committed
856
857
858
{
    index_t result = Init;

Chao Liu's avatar
Chao Liu committed
859
860
861
862
    for(index_t i = 0; i < Seq::Size(); ++i)
    {
        result = f(result, Seq::At(i));
    }
Chao Liu's avatar
Chao Liu committed
863
864
865
866

    return result;
}

Chao Liu's avatar
Chao Liu committed
867
868
// TODO: a generic any_of for any container
template <typename Seq, typename F>
Chao Liu's avatar
Chao Liu committed
869
__host__ __device__ constexpr bool sequence_any_of(Seq, F f)
Chao Liu's avatar
Chao Liu committed
870
871
872
873
874
875
876
877
878
879
880
881
882
{
    bool flag = false;

    for(index_t i = 0; i < Seq::Size(); ++i)
    {
        flag = flag || f(Seq::At(i));
    }

    return flag;
}

// TODO: a generic all_of for any container
template <typename Seq, typename F>
Chao Liu's avatar
Chao Liu committed
883
__host__ __device__ constexpr bool sequence_all_of(Seq, F f)
Chao Liu's avatar
Chao Liu committed
884
885
886
887
888
889
890
891
892
893
894
{
    bool flag = true;

    for(index_t i = 0; i < Seq::Size(); ++i)
    {
        flag = flag && f(Seq::At(i));
    }

    return flag;
}

895
896
897
898
899
900
// Convenience alias for the result type of merging two Sequences:
// shorthand for `typename sequence_merge<Sx, Sy>::type`.
template <typename Sx, typename Sy>
using sequence_merge_t = typename sequence_merge<Sx, Sy>::type;

// Convenience alias: shorthand for `typename uniform_sequence_gen<NSize, I>::type`.
template <index_t NSize, index_t I>
using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;

901
} // namespace ck
902
903
904
905
906
907
908
909
910
911
912

// Stream a Sequence as "{a, b, c}".
//
// Elements are written in order, separated by ", " and wrapped in braces.
// An empty Sequence is written as "{}"; the elements are expanded directly from
// the pack, so no element is ever indexed at position Size() - 1 (which does not
// exist when the Sequence is empty).
//
// Returns the stream to allow chaining.
template <ck::index_t... Is>
std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
{
    os << "{";

    // unused when the pack is empty, hence [[maybe_unused]]
    [[maybe_unused]] bool is_first = true;

    // comma fold: emit a separator before every element except the first
    ((os << (is_first ? "" : ", ") << Is, is_first = false), ...);

    os << "}";
    return os;
}