sequence.hpp 28.9 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
// SPDX-License-Identifier: MIT
arai713's avatar
arai713 committed
2
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
Chao Liu's avatar
Chao Liu committed
3

4
#pragma once
5

arai713's avatar
arai713 committed
6
#ifndef CK_CODE_GEN_RTC
7
#include <ostream>
arai713's avatar
arai713 committed
8
#endif
9

10
11
12
13
#include "ck/utility/integral_constant.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/functional.hpp"
#include "ck/utility/math.hpp"
14

15
16
namespace ck {

Chao Liu's avatar
Chao Liu committed
17
18
19
// Forward declarations of templates and helpers that Sequence refers to before their
// definitions appear.

template <index_t, index_t, index_t>
struct static_for;

template <index_t... Is>
struct Sequence;

template <typename Seq, index_t I>
struct sequence_split;

template <typename Seq>
struct sequence_reverse;

template <typename SeqMap>
struct sequence_map_inverse;

template <typename SeqMap>
struct is_valid_sequence_map;

template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>);

template <typename Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq);

Chao Liu's avatar
Chao Liu committed
41
template <index_t... Is>
42
43
struct Sequence
{
Chao Liu's avatar
Chao Liu committed
44
45
    using Type      = Sequence;
    using data_type = index_t;
46

47
    static constexpr index_t mSize = sizeof...(Is);
48

Chao Liu's avatar
Chao Liu committed
49
    __host__ __device__ static constexpr auto Size() { return Number<mSize>{}; }
50

Chao Liu's avatar
Chao Liu committed
51
52
53
    __host__ __device__ static constexpr auto GetSize() { return Size(); }

    __host__ __device__ static constexpr index_t At(index_t I)
54
    {
Chao Liu's avatar
Chao Liu committed
55
56
57
58
59
60
        // the last dummy element is to prevent compiler complain about empty array, when mSize = 0
        const index_t mData[mSize + 1] = {Is..., 0};
        return mData[I];
    }

    template <index_t I>
Chao Liu's avatar
Chao Liu committed
61
    __host__ __device__ static constexpr auto At(Number<I>)
Chao Liu's avatar
Chao Liu committed
62
    {
Chao Liu's avatar
Chao Liu committed
63
64
        static_assert(I < mSize, "wrong! I too large");

Chao Liu's avatar
Chao Liu committed
65
        return Number<At(I)>{};
Chao Liu's avatar
Chao Liu committed
66
67
    }

Chao Liu's avatar
Chao Liu committed
68
    template <index_t I>
Chao Liu's avatar
Chao Liu committed
69
    __host__ __device__ static constexpr auto Get(Number<I>)
Chao Liu's avatar
Chao Liu committed
70
    {
Chao Liu's avatar
Chao Liu committed
71
        return At(Number<I>{});
72
73
    }

Chao Liu's avatar
Chao Liu committed
74
75
76
77
78
    template <typename I>
    __host__ __device__ constexpr auto operator[](I i) const
    {
        return At(i);
    }
Chao Liu's avatar
Chao Liu committed
79

80
    template <index_t... IRs>
81
    __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
82
    {
Chao Liu's avatar
Chao Liu committed
83
        static_assert(sizeof...(Is) == sizeof...(IRs),
Chao Liu's avatar
Chao Liu committed
84
                      "wrong! reorder map should have the same size as Sequence to be rerodered");
Chao Liu's avatar
Chao Liu committed
85

Chao Liu's avatar
Chao Liu committed
86
87
        static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");

Chao Liu's avatar
Chao Liu committed
88
        return Sequence<Type::At(Number<IRs>{})...>{};
89
90
    }

Chao Liu's avatar
Chao Liu committed
91
    // MapOld2New is Sequence<...>
Chao Liu's avatar
Chao Liu committed
92
    template <typename MapOld2New>
Chao Liu's avatar
Chao Liu committed
93
94
    __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
    {
Chao Liu's avatar
Chao Liu committed
95
        static_assert(MapOld2New::Size() == Size(),
Chao Liu's avatar
Chao Liu committed
96
97
98
99
100
101
102
                      "wrong! reorder map should have the same size as Sequence to be rerodered");

        static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");

        return ReorderGivenNew2Old(typename sequence_map_inverse<MapOld2New>::type{});
    }

103
104
105
106
    __host__ __device__ static constexpr auto Reverse()
    {
        return typename sequence_reverse<Type>::type{};
    }
Chao Liu's avatar
Chao Liu committed
107

Chao Liu's avatar
Chao Liu committed
108
    __host__ __device__ static constexpr auto Front()
109
    {
Chao Liu's avatar
Chao Liu committed
110
        static_assert(mSize > 0, "wrong!");
Chao Liu's avatar
Chao Liu committed
111
        return At(Number<0>{});
112
    }
113

Chao Liu's avatar
Chao Liu committed
114
    __host__ __device__ static constexpr auto Back()
115
    {
Chao Liu's avatar
Chao Liu committed
116
        static_assert(mSize > 0, "wrong!");
Chao Liu's avatar
Chao Liu committed
117
        return At(Number<mSize - 1>{});
118
    }
119

120
    __host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); }
Chao Liu's avatar
Chao Liu committed
121

122
    __host__ __device__ static constexpr auto PopBack() { return sequence_pop_back(Type{}); }
Chao Liu's avatar
Chao Liu committed
123
124
125

    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushFront(Sequence<Xs...>)
126
    {
Chao Liu's avatar
Chao Liu committed
127
        return Sequence<Xs..., Is...>{};
128
129
    }

Chao Liu's avatar
Chao Liu committed
130
131
    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushFront(Number<Xs>...)
132
    {
Chao Liu's avatar
Chao Liu committed
133
        return Sequence<Xs..., Is...>{};
134
135
    }

Chao Liu's avatar
Chao Liu committed
136
137
138
139
140
    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushBack(Sequence<Xs...>)
    {
        return Sequence<Is..., Xs...>{};
    }
141

Chao Liu's avatar
Chao Liu committed
142
    template <index_t... Xs>
Chao Liu's avatar
Chao Liu committed
143
    __host__ __device__ static constexpr auto PushBack(Number<Xs>...)
144
    {
Chao Liu's avatar
Chao Liu committed
145
146
        return Sequence<Is..., Xs...>{};
    }
Chao Liu's avatar
Chao Liu committed
147

Chao Liu's avatar
Chao Liu committed
148
    template <index_t... Ns>
149
    __host__ __device__ static constexpr auto Extract(Number<Ns>...)
Chao Liu's avatar
Chao Liu committed
150
    {
Chao Liu's avatar
Chao Liu committed
151
        return Sequence<Type::At(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
152
    }
Chao Liu's avatar
Chao Liu committed
153

Chao Liu's avatar
Chao Liu committed
154
    template <index_t... Ns>
155
    __host__ __device__ static constexpr auto Extract(Sequence<Ns...>)
Chao Liu's avatar
Chao Liu committed
156
    {
Chao Liu's avatar
Chao Liu committed
157
        return Sequence<Type::At(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
158
    }
159
160

    template <index_t I, index_t X>
161
162
    __host__ __device__ static constexpr auto Modify(Number<I>, Number<X>)
    {
Chao Liu's avatar
Chao Liu committed
163
        static_assert(I < Size(), "wrong!");
164
165

        using seq_split          = sequence_split<Type, I>;
Chao Liu's avatar
Chao Liu committed
166
167
        constexpr auto seq_left  = typename seq_split::left_type{};
        constexpr auto seq_right = typename seq_split::right_type{}.PopFront();
168
169
170

        return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
    }
Chao Liu's avatar
Chao Liu committed
171

Chao Liu's avatar
Chao Liu committed
172
    template <typename F>
Chao Liu's avatar
Chao Liu committed
173
174
175
176
    __host__ __device__ static constexpr auto Transform(F f)
    {
        return Sequence<f(Is)...>{};
    }
Chao Liu's avatar
Chao Liu committed
177
178
179
180
181
182
183
184

    __host__ __device__ static void Print()
    {
        printf("{");
        printf("size %d, ", index_t{Size()});
        static_for<0, Size(), 1>{}([&](auto i) { printf("%d ", At(i).value); });
        printf("}");
    }
185
186
};

Chao Liu's avatar
Chao Liu committed
187
// merge sequence
Chao Liu's avatar
Chao Liu committed
188
189
190
191
192
template <typename Seq, typename... Seqs>
struct sequence_merge
{
    using type = typename sequence_merge<Seq, typename sequence_merge<Seqs...>::type>::type;
};
Chao Liu's avatar
Chao Liu committed
193

Chao Liu's avatar
Chao Liu committed
194
195
196
template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
{
Chao Liu's avatar
Chao Liu committed
197
    using type = Sequence<Xs..., Ys...>;
Chao Liu's avatar
Chao Liu committed
198
};
Chao Liu's avatar
Chao Liu committed
199

Chao Liu's avatar
Chao Liu committed
200
201
202
203
204
205
template <typename Seq>
struct sequence_merge<Seq>
{
    using type = Seq;
};

Chao Liu's avatar
Chao Liu committed
206
// generate sequence
Chao Liu's avatar
Chao Liu committed
207
208
template <index_t NSize, typename F>
struct sequence_gen
Chao Liu's avatar
Chao Liu committed
209
{
Chao Liu's avatar
Chao Liu committed
210
211
212
213
214
215
    template <index_t IBegin, index_t NRemain, typename G>
    struct sequence_gen_impl
    {
        static constexpr index_t NRemainLeft  = NRemain / 2;
        static constexpr index_t NRemainRight = NRemain - NRemainLeft;
        static constexpr index_t IMiddle      = IBegin + NRemainLeft;
Chao Liu's avatar
Chao Liu committed
216

Chao Liu's avatar
Chao Liu committed
217
218
219
220
        using type = typename sequence_merge<
            typename sequence_gen_impl<IBegin, NRemainLeft, G>::type,
            typename sequence_gen_impl<IMiddle, NRemainRight, G>::type>::type;
    };
Chao Liu's avatar
Chao Liu committed
221

Chao Liu's avatar
Chao Liu committed
222
223
224
225
226
227
    template <index_t I, typename G>
    struct sequence_gen_impl<I, 1, G>
    {
        static constexpr index_t Is = G{}(Number<I>{});
        using type                  = Sequence<Is>;
    };
Chao Liu's avatar
Chao Liu committed
228

Chao Liu's avatar
Chao Liu committed
229
230
231
232
233
    template <index_t I, typename G>
    struct sequence_gen_impl<I, 0, G>
    {
        using type = Sequence<>;
    };
Chao Liu's avatar
Chao Liu committed
234

Chao Liu's avatar
Chao Liu committed
235
236
237
238
    using type = typename sequence_gen_impl<0, NSize, F>::type;
};

// arithmetic sequence
Chao Liu's avatar
Chao Liu committed
239
template <index_t IBegin, index_t IEnd, index_t Increment>
240
struct arithmetic_sequence_gen
Chao Liu's avatar
Chao Liu committed
241
{
Chao Liu's avatar
Chao Liu committed
242
243
244
245
246
247
248
249
    struct F
    {
        __host__ __device__ constexpr index_t operator()(index_t i) const
        {
            return i * Increment + IBegin;
        }
    };

250
251
252
253
254
255
256
    using type0 = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type;
    using type1 = Sequence<>;

    static constexpr bool kHasContent =
        (Increment > 0 && IBegin < IEnd) || (Increment < 0 && IBegin > IEnd);

    using type = typename conditional<kHasContent, type0, type1>::type;
Chao Liu's avatar
Chao Liu committed
257
258
259
260
261
262
};

// uniform sequence
template <index_t NSize, index_t I>
struct uniform_sequence_gen
{
Chao Liu's avatar
Chao Liu committed
263
    struct F
Chao Liu's avatar
Chao Liu committed
264
265
266
267
    {
        __host__ __device__ constexpr index_t operator()(index_t) const { return I; }
    };

Chao Liu's avatar
Chao Liu committed
268
    using type = typename sequence_gen<NSize, F>::type;
Chao Liu's avatar
Chao Liu committed
269
270
271
};

// reverse inclusive scan (with init) sequence
Chao Liu's avatar
Chao Liu committed
272
template <typename, typename, index_t>
Chao Liu's avatar
Chao Liu committed
273
struct sequence_reverse_inclusive_scan;
Chao Liu's avatar
Chao Liu committed
274

Chao Liu's avatar
Chao Liu committed
275
template <index_t I, index_t... Is, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
276
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
277
{
Chao Liu's avatar
Chao Liu committed
278
    using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::type;
Chao Liu's avatar
Chao Liu committed
279
280
281

    static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());

Chao Liu's avatar
Chao Liu committed
282
    using type = typename sequence_merge<Sequence<new_reduce>, old_scan>::type;
Chao Liu's avatar
Chao Liu committed
283
284
};

Chao Liu's avatar
Chao Liu committed
285
template <index_t I, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
286
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
287
{
Chao Liu's avatar
Chao Liu committed
288
    using type = Sequence<Reduce{}(I, Init)>;
Chao Liu's avatar
Chao Liu committed
289
290
};

Chao Liu's avatar
Chao Liu committed
291
template <typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
292
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
293
{
Chao Liu's avatar
Chao Liu committed
294
    using type = Sequence<>;
Chao Liu's avatar
Chao Liu committed
295
296
};

Chao Liu's avatar
Chao Liu committed
297
// split sequence
Chao Liu's avatar
Chao Liu committed
298
template <typename Seq, index_t I>
Chao Liu's avatar
Chao Liu committed
299
300
struct sequence_split
{
Chao Liu's avatar
Chao Liu committed
301
    static constexpr index_t NSize = Seq{}.Size();
Chao Liu's avatar
Chao Liu committed
302

Chao Liu's avatar
Chao Liu committed
303
304
    using range0 = typename arithmetic_sequence_gen<0, I, 1>::type;
    using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type;
Chao Liu's avatar
Chao Liu committed
305

Chao Liu's avatar
Chao Liu committed
306
307
    using left_type  = decltype(Seq::Extract(range0{}));
    using right_type = decltype(Seq::Extract(range1{}));
Chao Liu's avatar
Chao Liu committed
308
309
};

Chao Liu's avatar
Chao Liu committed
310
// reverse sequence
Chao Liu's avatar
Chao Liu committed
311
template <typename Seq>
Chao Liu's avatar
Chao Liu committed
312
313
struct sequence_reverse
{
Chao Liu's avatar
Chao Liu committed
314
    static constexpr index_t NSize = Seq{}.Size();
Chao Liu's avatar
Chao Liu committed
315
316

    using seq_split = sequence_split<Seq, NSize / 2>;
Chao Liu's avatar
Chao Liu committed
317
    using type      = typename sequence_merge<
Chao Liu's avatar
Chao Liu committed
318
319
        typename sequence_reverse<typename seq_split::right_type>::type,
        typename sequence_reverse<typename seq_split::left_type>::type>::type;
Chao Liu's avatar
Chao Liu committed
320
321
322
323
324
};

template <index_t I>
struct sequence_reverse<Sequence<I>>
{
Chao Liu's avatar
Chao Liu committed
325
    using type = Sequence<I>;
Chao Liu's avatar
Chao Liu committed
326
327
328
329
330
};

template <index_t I0, index_t I1>
struct sequence_reverse<Sequence<I0, I1>>
{
Chao Liu's avatar
Chao Liu committed
331
    using type = Sequence<I1, I0>;
Chao Liu's avatar
Chao Liu committed
332
};
Chao Liu's avatar
Chao Liu committed
333

Chao Liu's avatar
Chao Liu committed
334
#if 1
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
// sequence_reduce<Reduce, S0, S1, ...>::type combines several Sequences element by element.
// General case: reduce everything after the first Sequence, then combine the first one with
// that partial result.
template <typename Reduce, typename Seq, typename... Seqs>
struct sequence_reduce
{
    using tail_reduced = typename sequence_reduce<Reduce, Seqs...>::type;

    using type = typename sequence_reduce<Reduce, Seq, tail_reduced>::type;
};

// two Sequences: apply Reduce to each pair of elements at the same position
template <typename Reduce, index_t... Xs, index_t... Ys>
struct sequence_reduce<Reduce, Sequence<Xs...>, Sequence<Ys...>>
{
    using type = Sequence<Reduce{}(Xs, Ys)...>;
};

// a single Sequence is returned unchanged
template <typename Reduce, typename Seq>
struct sequence_reduce<Reduce, Seq>
{
    using type = Seq;
};
#endif

Chao Liu's avatar
Chao Liu committed
356
357
template <typename Values, typename Ids, typename Compare>
struct sequence_sort_impl
Chao Liu's avatar
Chao Liu committed
358
{
Chao Liu's avatar
Chao Liu committed
359
360
361
362
363
364
365
    template <typename LeftValues,
              typename LeftIds,
              typename RightValues,
              typename RightIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
Chao Liu's avatar
Chao Liu committed
366
367
    struct sorted_sequence_merge_impl
    {
Chao Liu's avatar
Chao Liu committed
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
        static constexpr bool choose_left = LeftValues::Front() < RightValues::Front();

        static constexpr index_t chosen_value =
            choose_left ? LeftValues::Front() : RightValues::Front();
        static constexpr index_t chosen_id = choose_left ? LeftIds::Front() : RightIds::Front();

        using new_merged_values = decltype(MergedValues::PushBack(Number<chosen_value>{}));
        using new_merged_ids    = decltype(MergedIds::PushBack(Number<chosen_id>{}));

        using new_left_values =
            typename conditional<choose_left, decltype(LeftValues::PopFront()), LeftValues>::type;
        using new_left_ids =
            typename conditional<choose_left, decltype(LeftIds::PopFront()), LeftIds>::type;

        using new_right_values =
            typename conditional<choose_left, RightValues, decltype(RightValues::PopFront())>::type;
        using new_right_ids =
            typename conditional<choose_left, RightIds, decltype(RightIds::PopFront())>::type;

        using merge = sorted_sequence_merge_impl<new_left_values,
                                                 new_left_ids,
                                                 new_right_values,
                                                 new_right_ids,
                                                 new_merged_values,
                                                 new_merged_ids,
                                                 Comp>;
        // this is output
        using merged_values = typename merge::merged_values;
        using merged_ids    = typename merge::merged_ids;
Chao Liu's avatar
Chao Liu committed
397
398
    };

Chao Liu's avatar
Chao Liu committed
399
400
401
402
403
404
405
406
407
408
409
410
    template <typename LeftValues,
              typename LeftIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
    struct sorted_sequence_merge_impl<LeftValues,
                                      LeftIds,
                                      Sequence<>,
                                      Sequence<>,
                                      MergedValues,
                                      MergedIds,
                                      Comp>
Chao Liu's avatar
Chao Liu committed
411
    {
Chao Liu's avatar
Chao Liu committed
412
413
        using merged_values = typename sequence_merge<MergedValues, LeftValues>::type;
        using merged_ids    = typename sequence_merge<MergedIds, LeftIds>::type;
Chao Liu's avatar
Chao Liu committed
414
415
    };

Chao Liu's avatar
Chao Liu committed
416
417
418
419
420
421
422
423
424
425
426
427
    template <typename RightValues,
              typename RightIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
    struct sorted_sequence_merge_impl<Sequence<>,
                                      Sequence<>,
                                      RightValues,
                                      RightIds,
                                      MergedValues,
                                      MergedIds,
                                      Comp>
Chao Liu's avatar
Chao Liu committed
428
    {
Chao Liu's avatar
Chao Liu committed
429
430
        using merged_values = typename sequence_merge<MergedValues, RightValues>::type;
        using merged_ids    = typename sequence_merge<MergedIds, RightIds>::type;
Chao Liu's avatar
Chao Liu committed
431
432
    };

Chao Liu's avatar
Chao Liu committed
433
434
435
436
437
    template <typename LeftValues,
              typename LeftIds,
              typename RightValues,
              typename RightIds,
              typename Comp>
Chao Liu's avatar
Chao Liu committed
438
439
    struct sorted_sequence_merge
    {
Chao Liu's avatar
Chao Liu committed
440
441
442
443
444
445
446
447
448
449
        using merge = sorted_sequence_merge_impl<LeftValues,
                                                 LeftIds,
                                                 RightValues,
                                                 RightIds,
                                                 Sequence<>,
                                                 Sequence<>,
                                                 Comp>;

        using merged_values = typename merge::merged_values;
        using merged_ids    = typename merge::merged_ids;
Chao Liu's avatar
Chao Liu committed
450
451
    };

Chao Liu's avatar
Chao Liu committed
452
453
454
455
    static constexpr index_t nsize = Values::Size();

    using split_unsorted_values = sequence_split<Values, nsize / 2>;
    using split_unsorted_ids    = sequence_split<Ids, nsize / 2>;
Chao Liu's avatar
Chao Liu committed
456

Chao Liu's avatar
Chao Liu committed
457
458
459
460
461
    using left_unsorted_values = typename split_unsorted_values::left_type;
    using left_unsorted_ids    = typename split_unsorted_ids::left_type;
    using left_sort          = sequence_sort_impl<left_unsorted_values, left_unsorted_ids, Compare>;
    using left_sorted_values = typename left_sort::sorted_values;
    using left_sorted_ids    = typename left_sort::sorted_ids;
Chao Liu's avatar
Chao Liu committed
462

Chao Liu's avatar
Chao Liu committed
463
464
465
466
467
468
469
470
471
472
473
474
475
476
    using right_unsorted_values = typename split_unsorted_values::right_type;
    using right_unsorted_ids    = typename split_unsorted_ids::right_type;
    using right_sort = sequence_sort_impl<right_unsorted_values, right_unsorted_ids, Compare>;
    using right_sorted_values = typename right_sort::sorted_values;
    using right_sorted_ids    = typename right_sort::sorted_ids;

    using merged_sorted = sorted_sequence_merge<left_sorted_values,
                                                left_sorted_ids,
                                                right_sorted_values,
                                                right_sorted_ids,
                                                Compare>;

    using sorted_values = typename merged_sorted::merged_values;
    using sorted_ids    = typename merged_sorted::merged_ids;
Chao Liu's avatar
Chao Liu committed
477
478
};

Chao Liu's avatar
Chao Liu committed
479
480
template <index_t ValueX, index_t ValueY, index_t IdX, index_t IdY, typename Compare>
struct sequence_sort_impl<Sequence<ValueX, ValueY>, Sequence<IdX, IdY>, Compare>
Chao Liu's avatar
Chao Liu committed
481
{
Chao Liu's avatar
Chao Liu committed
482
483
484
485
486
487
    static constexpr bool choose_x = Compare{}(ValueX, ValueY);

    using sorted_values =
        typename conditional<choose_x, Sequence<ValueX, ValueY>, Sequence<ValueY, ValueX>>::type;
    using sorted_ids = typename conditional<choose_x, Sequence<IdX, IdY>, Sequence<IdY, IdX>>::type;
};
Chao Liu's avatar
Chao Liu committed
488

Chao Liu's avatar
Chao Liu committed
489
490
491
492
493
template <index_t Value, index_t Id, typename Compare>
struct sequence_sort_impl<Sequence<Value>, Sequence<Id>, Compare>
{
    using sorted_values = Sequence<Value>;
    using sorted_ids    = Sequence<Id>;
Chao Liu's avatar
Chao Liu committed
494
495
};

496
497
498
499
500
501
502
template <typename Compare>
struct sequence_sort_impl<Sequence<>, Sequence<>, Compare>
{
    using sorted_values = Sequence<>;
    using sorted_ids    = Sequence<>;
};

Chao Liu's avatar
Chao Liu committed
503
504
template <typename Values, typename Compare>
struct sequence_sort
Chao Liu's avatar
Chao Liu committed
505
{
Chao Liu's avatar
Chao Liu committed
506
507
508
509
510
511
    using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type;
    using sort         = sequence_sort_impl<Values, unsorted_ids, Compare>;

    // this is output
    using type                = typename sort::sorted_values;
    using sorted2unsorted_map = typename sort::sorted_ids;
Chao Liu's avatar
Chao Liu committed
512
513
};

Chao Liu's avatar
Chao Liu committed
514
template <typename Values, typename Less, typename Equal>
Chao Liu's avatar
Chao Liu committed
515
516
struct sequence_unique_sort
{
Chao Liu's avatar
Chao Liu committed
517
518
519
520
521
    template <typename RemainValues,
              typename RemainIds,
              typename UniquifiedValues,
              typename UniquifiedIds,
              typename Eq>
Chao Liu's avatar
Chao Liu committed
522
523
    struct sorted_sequence_uniquify_impl
    {
Chao Liu's avatar
Chao Liu committed
524
525
526
527
528
529
530
531
532
533
534
535
        static constexpr index_t current_value = RemainValues::Front();
        static constexpr index_t current_id    = RemainIds::Front();

        static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back());

        using new_remain_values = decltype(RemainValues::PopFront());
        using new_remain_ids    = decltype(RemainIds::PopFront());

        using new_uniquified_values =
            typename conditional<is_unique_value,
                                 decltype(UniquifiedValues::PushBack(Number<current_value>{})),
                                 UniquifiedValues>::type;
Chao Liu's avatar
Chao Liu committed
536

Chao Liu's avatar
Chao Liu committed
537
538
539
540
541
542
543
544
545
546
547
548
549
550
        using new_uniquified_ids =
            typename conditional<is_unique_value,
                                 decltype(UniquifiedIds::PushBack(Number<current_id>{})),
                                 UniquifiedIds>::type;

        using uniquify = sorted_sequence_uniquify_impl<new_remain_values,
                                                       new_remain_ids,
                                                       new_uniquified_values,
                                                       new_uniquified_ids,
                                                       Eq>;

        // this is output
        using uniquified_values = typename uniquify::uniquified_values;
        using uniquified_ids    = typename uniquify::uniquified_ids;
Chao Liu's avatar
Chao Liu committed
551
552
    };

Chao Liu's avatar
Chao Liu committed
553
554
555
556
557
558
    template <typename UniquifiedValues, typename UniquifiedIds, typename Eq>
    struct sorted_sequence_uniquify_impl<Sequence<>,
                                         Sequence<>,
                                         UniquifiedValues,
                                         UniquifiedIds,
                                         Eq>
Chao Liu's avatar
Chao Liu committed
559
    {
Chao Liu's avatar
Chao Liu committed
560
561
        using uniquified_values = UniquifiedValues;
        using uniquified_ids    = UniquifiedIds;
Chao Liu's avatar
Chao Liu committed
562
563
    };

Chao Liu's avatar
Chao Liu committed
564
    template <typename SortedValues, typename SortedIds, typename Eq>
Chao Liu's avatar
Chao Liu committed
565
566
    struct sorted_sequence_uniquify
    {
Chao Liu's avatar
Chao Liu committed
567
568
569
570
571
572
573
574
        using uniquify = sorted_sequence_uniquify_impl<decltype(SortedValues::PopFront()),
                                                       decltype(SortedIds::PopFront()),
                                                       Sequence<SortedValues::Front()>,
                                                       Sequence<SortedIds::Front()>,
                                                       Eq>;

        using uniquified_values = typename uniquify::uniquified_values;
        using uniquified_ids    = typename uniquify::uniquified_ids;
Chao Liu's avatar
Chao Liu committed
575
576
    };

Chao Liu's avatar
Chao Liu committed
577
578
579
    using sort          = sequence_sort<Values, Less>;
    using sorted_values = typename sort::type;
    using sorted_ids    = typename sort::sorted2unsorted_map;
Chao Liu's avatar
Chao Liu committed
580

Chao Liu's avatar
Chao Liu committed
581
582
583
584
585
    using uniquify = sorted_sequence_uniquify<sorted_values, sorted_ids, Equal>;

    // this is output
    using type                = typename uniquify::uniquified_values;
    using sorted2unsorted_map = typename uniquify::uniquified_ids;
Chao Liu's avatar
Chao Liu committed
586
587
};

Chao Liu's avatar
Chao Liu committed
588
template <typename SeqMap>
Chao Liu's avatar
Chao Liu committed
589
590
struct is_valid_sequence_map : is_same<typename arithmetic_sequence_gen<0, SeqMap::Size(), 1>::type,
                                       typename sequence_sort<SeqMap, math::less<index_t>>::type>
Chao Liu's avatar
Chao Liu committed
591
592
{
};
593

Chao Liu's avatar
Chao Liu committed
594
595
template <typename SeqMap>
struct sequence_map_inverse
Chao Liu's avatar
Chao Liu committed
596
{
Chao Liu's avatar
Chao Liu committed
597
598
599
600
601
    template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
    struct sequence_map_inverse_impl
    {
        static constexpr auto new_y2x =
            WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
Chao Liu's avatar
Chao Liu committed
602

Chao Liu's avatar
Chao Liu committed
603
604
605
606
        using type =
            typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::
                type;
    };
Chao Liu's avatar
Chao Liu committed
607

Chao Liu's avatar
Chao Liu committed
608
609
610
611
612
    template <typename X2Y, typename WorkingY2X, index_t XBegin>
    struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
    {
        using type = WorkingY2X;
    };
Chao Liu's avatar
Chao Liu committed
613
614

    using type =
Chao Liu's avatar
Chao Liu committed
615
616
        typename sequence_map_inverse_impl<SeqMap,
                                           typename uniform_sequence_gen<SeqMap::Size(), 0>::type,
Chao Liu's avatar
Chao Liu committed
617
                                           0,
Chao Liu's avatar
Chao Liu committed
618
                                           SeqMap::Size()>::type;
Chao Liu's avatar
Chao Liu committed
619
620
};

621
622
623
624
625
626
// Two Sequences compare equal when they have the same size and equal elements at every
// position. Sequences of different size compare unequal.
template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr bool operator==(Sequence<Xs...>, Sequence<Ys...>)
{
    // The fold below expands both packs together and is only well-formed when they have the
    // same length, so the size check has to discard it at compile time.
    if constexpr(sizeof...(Xs) != sizeof...(Ys))
    {
        return false;
    }
    else
    {
        return ((Xs == Ys) && ...);
    }
}

Chao Liu's avatar
Chao Liu committed
627
template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
628
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
629
630
631
632
633
634
635
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs + Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
636
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
637
638
639
640
641
642
643
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs - Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
644
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
645
646
647
648
649
650
651
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs * Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
652
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
653
654
655
656
657
658
659
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs / Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
660
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
661
662
663
664
665
666
667
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs % Ys)...>{};
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
668
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
669
{
Chao Liu's avatar
Chao Liu committed
670
    return Sequence<(Xs + Y)...>{};
Chao Liu's avatar
Chao Liu committed
671
672
673
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
674
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
675
{
Chao Liu's avatar
Chao Liu committed
676
    return Sequence<(Xs - Y)...>{};
Chao Liu's avatar
Chao Liu committed
677
678
679
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
680
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
681
{
Chao Liu's avatar
Chao Liu committed
682
    return Sequence<(Xs * Y)...>{};
Chao Liu's avatar
Chao Liu committed
683
684
685
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
686
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
687
{
Chao Liu's avatar
Chao Liu committed
688
    return Sequence<(Xs / Y)...>{};
Chao Liu's avatar
Chao Liu committed
689
690
691
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
692
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
693
{
Chao Liu's avatar
Chao Liu committed
694
    return Sequence<(Xs % Y)...>{};
Chao Liu's avatar
Chao Liu committed
695
696
}

Chao Liu's avatar
Chao Liu committed
697
698
// Builds a Sequence whose elements are Y + x for each element x of the operand.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator+(Number<Y>, Sequence<Xs...>)
{
    using shifted_t = Sequence<(Y + Xs)...>;
    return shifted_t{};
}

Chao Liu's avatar
Chao Liu committed
703
704
// Builds a Sequence whose elements are Y - x for each element x of the operand.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
{
    using shifted_t = Sequence<(Y - Xs)...>;
    return shifted_t{};
}

Chao Liu's avatar
Chao Liu committed
709
710
// Builds a Sequence whose elements are Y * x for each element x of the operand.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator*(Number<Y>, Sequence<Xs...>)
{
    using scaled_t = Sequence<(Y * Xs)...>;
    return scaled_t{};
}

Chao Liu's avatar
Chao Liu committed
715
716
// Builds a Sequence whose elements are Y / x for each element x of the operand.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator/(Number<Y>, Sequence<Xs...>)
{
    using quotient_t = Sequence<(Y / Xs)...>;
    return quotient_t{};
}

Chao Liu's avatar
Chao Liu committed
721
722
// Builds a Sequence whose elements are Y % x for each element x of the operand.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator%(Number<Y>, Sequence<Xs...>)
{
    using remainder_t = Sequence<(Y % Xs)...>;
    return remainder_t{};
}

727
728
729
730
731
732
// Drops the first element of a non-empty Sequence and returns the rest.
// An empty Sequence does not match this overload's parameter pattern.
template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
{
    using tail_t = Sequence<Is...>;
    return tail_t{};
}

Chao Liu's avatar
Chao Liu committed
733
template <typename Seq>
Chao Liu's avatar
Chao Liu committed
734
__host__ __device__ constexpr auto sequence_pop_back(Seq)
735
{
Chao Liu's avatar
Chao Liu committed
736
    static_assert(Seq::Size() > 0, "wrong! cannot pop an empty Sequence!");
737
    return sequence_pop_front(Seq::Reverse()).Reverse();
738
}
739

Chao Liu's avatar
Chao Liu committed
740
741
742
743
744
745
// Value-level wrapper around sequence_merge: returns an instance of
// sequence_merge<Seqs...>::type for the given Sequence arguments.
template <typename... Seqs>
__host__ __device__ constexpr auto merge_sequences(Seqs...)
{
    using merged_t = typename sequence_merge<Seqs...>::type;
    return merged_t{};
}

746
747
748
749
750
751
// Applies the unary functor f to every element of a Sequence and returns the
// Sequence of results. f(x) is used as a template argument, so it must be
// usable in a constant expression.
template <typename F, index_t... Xs>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
{
    using mapped_t = Sequence<f(Xs)...>;
    return mapped_t{};
}

Chao Liu's avatar
Chao Liu committed
752
template <typename F, index_t... Xs, index_t... Ys>
753
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
754
{
755
    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize, "Dim not the same");
756
757
758
759

    return Sequence<f(Xs, Ys)...>{};
}

Chao Liu's avatar
Chao Liu committed
760
template <typename F, index_t... Xs, index_t... Ys, index_t... Zs>
761
762
763
764
765
766
767
768
769
770
__host__ __device__ constexpr auto
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
{
    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize &&
                      Sequence<Xs...>::mSize == Sequence<Zs...>::mSize,
                  "Dim not the same");

    return Sequence<f(Xs, Ys, Zs)...>{};
}

Chao Liu's avatar
Chao Liu committed
771
template <typename Seq, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
772
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number<Init>)
773
{
Chao Liu's avatar
Chao Liu committed
774
    return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::type{};
775
776
}

Chao Liu's avatar
Chao Liu committed
777
778
779
780
781
782
783
// Reverse exclusive scan, built from the reverse inclusive scan:
// the inclusive scan is run on Seq with its front element removed, and Init is
// appended at the back of the result.
template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr auto reverse_exclusive_scan_sequence(Seq, Reduce, Number<Init>)
{
    auto scan_of_tail =
        reverse_inclusive_scan_sequence(Seq::PopFront(), Reduce{}, Number<Init>{});

    return scan_of_tail.PushBack(Number<Init>{});
}

Chao Liu's avatar
Chao Liu committed
784
template <typename Seq, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
785
__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<Init>)
786
{
Chao Liu's avatar
Chao Liu committed
787
    return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
788
}
789

Chao Liu's avatar
Chao Liu committed
790
template <typename Seq, index_t... Is>
791
__host__ __device__ constexpr auto pick_sequence_elements_by_ids(Seq, Sequence<Is...> /* ids */)
Chao Liu's avatar
Chao Liu committed
792
793
794
795
{
    return Sequence<Seq::At(Number<Is>{})...>{};
}

Chao Liu's avatar
Chao Liu committed
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
#if 1
namespace detail {

// Recursive step of the mask-based pick.
// WorkSeq holds the elements accepted so far; RemainSeq / RemainMask are the
// not-yet-visited elements and their mask entries. When the front mask entry is
// true, the front element of RemainSeq is pushed onto WorkSeq; otherwise
// WorkSeq is carried over unchanged. The recursion then continues on both
// remainders with their front removed.
template <typename WorkSeq, typename RemainSeq, typename RemainMask>
struct pick_sequence_elements_by_mask_impl
{
    using new_work_seq = typename conditional<RemainMask::Front(),
                                              decltype(WorkSeq::PushBack(RemainSeq::Front())),
                                              WorkSeq>::type;

    using remain_seq_tail  = decltype(RemainSeq::PopFront());
    using remain_mask_tail = decltype(RemainMask::PopFront());

    using type = typename pick_sequence_elements_by_mask_impl<new_work_seq,
                                                              remain_seq_tail,
                                                              remain_mask_tail>::type;
};

// Terminal case: both remainders are empty, the accumulated WorkSeq is the result.
template <typename WorkSeq>
struct pick_sequence_elements_by_mask_impl<WorkSeq, Sequence<>, Sequence<>>
{
    using type = WorkSeq;
};

} // namespace detail

819
820
821
// Keeps the elements of Seq whose corresponding Mask entry is true, preserving
// their order. Seq and Mask must have the same size.
template <typename Seq, typename Mask>
__host__ __device__ constexpr auto pick_sequence_elements_by_mask(Seq, Mask)
{
    static_assert(Seq::Size() == Mask::Size(), "wrong!");

    // Start from an empty working Sequence and let the impl walk Seq and Mask.
    using picked_t =
        typename detail::pick_sequence_elements_by_mask_impl<Sequence<>, Seq, Mask>::type;

    return picked_t{};
}

Chao Liu's avatar
Chao Liu committed
827
828
829
namespace detail {
template <typename WorkSeq, typename RemainValues, typename RemainIds>
struct modify_sequence_elements_by_ids_impl
Chao Liu's avatar
Chao Liu committed
830
{
Chao Liu's avatar
Chao Liu committed
831
    using new_work_seq = decltype(WorkSeq::Modify(RemainIds::Front(), RemainValues::Front()));
Chao Liu's avatar
Chao Liu committed
832

Chao Liu's avatar
Chao Liu committed
833
834
835
836
837
    using type =
        typename modify_sequence_elements_by_ids_impl<new_work_seq,
                                                      decltype(RemainValues::PopFront()),
                                                      decltype(RemainIds::PopFront())>::type;
};
Chao Liu's avatar
Chao Liu committed
838

Chao Liu's avatar
Chao Liu committed
839
840
841
842
template <typename WorkSeq>
struct modify_sequence_elements_by_ids_impl<WorkSeq, Sequence<>, Sequence<>>
{
    using type = WorkSeq;
Chao Liu's avatar
Chao Liu committed
843
};
Chao Liu's avatar
Chao Liu committed
844
845
846
847
848
849
850
851
852
853
} // namespace detail

// Returns Seq with the entries at positions Ids replaced by the corresponding
// entries of Values. Values and Ids must have the same size, and that size must
// not exceed the size of Seq.
template <typename Seq, typename Values, typename Ids>
__host__ __device__ constexpr auto modify_sequence_elements_by_ids(Seq, Values, Ids)
{
    static_assert(Values::Size() == Ids::Size() && Seq::Size() >= Values::Size(), "wrong!");

    using modified_t =
        typename detail::modify_sequence_elements_by_ids_impl<Seq, Values, Ids>::type;

    return modified_t{};
}
#endif
Chao Liu's avatar
Chao Liu committed
854

Chao Liu's avatar
Chao Liu committed
855
template <typename Seq, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
856
__host__ __device__ constexpr index_t
Chao Liu's avatar
Chao Liu committed
857
reduce_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
Chao Liu's avatar
Chao Liu committed
858
859
860
{
    index_t result = Init;

Chao Liu's avatar
Chao Liu committed
861
862
863
864
    for(index_t i = 0; i < Seq::Size(); ++i)
    {
        result = f(result, Seq::At(i));
    }
Chao Liu's avatar
Chao Liu committed
865
866
867
868

    return result;
}

Chao Liu's avatar
Chao Liu committed
869
870
// TODO: a generic any_of for any container
template <typename Seq, typename F>
Chao Liu's avatar
Chao Liu committed
871
__host__ __device__ constexpr bool sequence_any_of(Seq, F f)
Chao Liu's avatar
Chao Liu committed
872
873
874
875
876
877
878
879
880
881
882
883
884
{
    bool flag = false;

    for(index_t i = 0; i < Seq::Size(); ++i)
    {
        flag = flag || f(Seq::At(i));
    }

    return flag;
}

// TODO: a generic all_of for any container
template <typename Seq, typename F>
Chao Liu's avatar
Chao Liu committed
885
__host__ __device__ constexpr bool sequence_all_of(Seq, F f)
Chao Liu's avatar
Chao Liu committed
886
887
888
889
890
891
892
893
894
895
896
{
    bool flag = true;

    for(index_t i = 0; i < Seq::Size(); ++i)
    {
        flag = flag && f(Seq::At(i));
    }

    return flag;
}

897
898
899
900
901
902
// Convenience alias for the nested result type of sequence_merge<Sx, Sy>.
template <typename Sx, typename Sy>
using sequence_merge_t = typename sequence_merge<Sx, Sy>::type;

// Convenience alias for the nested result type of uniform_sequence_gen<NSize, I>.
template <index_t NSize, index_t I>
using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;

903
} // namespace ck
904

arai713's avatar
arai713 committed
905
#ifndef CK_CODE_GEN_RTC
906
907
908
909
910
911
912
913
914
915
// Streams a Sequence as "{a, b, c}".
// An empty Sequence is written as "{}". The elements are emitted directly from
// the template parameter pack, so no element lookup is instantiated and the
// empty case needs no out-of-range index.
//
// os: destination stream.
// Returns: os, to allow chaining.
template <ck::index_t... Is>
std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
{
    os << "{";

    if constexpr(sizeof...(Is) > 0)
    {
        // The separator is empty before the first element and ", " afterwards.
        const char* separator = "";
        ((os << separator << Is, separator = ", "), ...);
    }

    os << "}";
    return os;
}
arai713's avatar
arai713 committed
916
#endif