sequence.hpp 28.6 KB
Newer Older
Umang Yadav's avatar
Umang Yadav committed
1
2
3

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
Chao Liu's avatar
Chao Liu committed
4
// SPDX-License-Identifier: MIT
Illia Silin's avatar
Illia Silin committed
5
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
Chao Liu's avatar
Chao Liu committed
6

7
#pragma once
8

9
10
11
12
#include "ck/utility/integral_constant.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/functional.hpp"
#include "ck/utility/math.hpp"
13

14
15
namespace ck {

Chao Liu's avatar
Chao Liu committed
16
17
18
// Forward declarations of templates that Sequence refers to before they are defined.

// Used by Sequence::Print(); defined outside this part of the file.
template <index_t, index_t, index_t>
struct static_for;

template <index_t... Is>
struct Sequence;

template <typename Seq, index_t I>
struct sequence_split;

template <typename Seq>
struct sequence_reverse;

template <typename SeqMap>
struct sequence_map_inverse;

template <typename SeqMap>
struct is_valid_sequence_map;

// Backs Sequence::PopFront(); only declared here.
template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>);

// Backs Sequence::PopBack(); only declared here.
template <typename Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq);

Chao Liu's avatar
Chao Liu committed
40
template <index_t... Is>
41
42
struct Sequence
{
Chao Liu's avatar
Chao Liu committed
43
44
    using Type      = Sequence;
    using data_type = index_t;
45

46
    static constexpr index_t mSize = sizeof...(Is);
47

Chao Liu's avatar
Chao Liu committed
48
    __host__ __device__ static constexpr auto Size() { return Number<mSize>{}; }
49

Chao Liu's avatar
Chao Liu committed
50
51
52
    __host__ __device__ static constexpr auto GetSize() { return Size(); }

    __host__ __device__ static constexpr index_t At(index_t I)
53
    {
Chao Liu's avatar
Chao Liu committed
54
55
56
57
58
59
        // the last dummy element is to prevent compiler complain about empty array, when mSize = 0
        const index_t mData[mSize + 1] = {Is..., 0};
        return mData[I];
    }

    template <index_t I>
Chao Liu's avatar
Chao Liu committed
60
    __host__ __device__ static constexpr auto At(Number<I>)
Chao Liu's avatar
Chao Liu committed
61
    {
Chao Liu's avatar
Chao Liu committed
62
63
        static_assert(I < mSize, "wrong! I too large");

Chao Liu's avatar
Chao Liu committed
64
        return Number<At(I)>{};
Chao Liu's avatar
Chao Liu committed
65
66
    }

Chao Liu's avatar
Chao Liu committed
67
    template <index_t I>
Chao Liu's avatar
Chao Liu committed
68
    __host__ __device__ static constexpr auto Get(Number<I>)
Chao Liu's avatar
Chao Liu committed
69
    {
Chao Liu's avatar
Chao Liu committed
70
        return At(Number<I>{});
71
72
    }

Chao Liu's avatar
Chao Liu committed
73
74
75
76
77
    template <typename I>
    __host__ __device__ constexpr auto operator[](I i) const
    {
        return At(i);
    }
Chao Liu's avatar
Chao Liu committed
78

79
    template <index_t... IRs>
80
    __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
81
    {
Chao Liu's avatar
Chao Liu committed
82
        static_assert(sizeof...(Is) == sizeof...(IRs),
Chao Liu's avatar
Chao Liu committed
83
                      "wrong! reorder map should have the same size as Sequence to be rerodered");
Chao Liu's avatar
Chao Liu committed
84

Chao Liu's avatar
Chao Liu committed
85
86
        static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");

Chao Liu's avatar
Chao Liu committed
87
        return Sequence<Type::At(Number<IRs>{})...>{};
88
89
    }

Chao Liu's avatar
Chao Liu committed
90
    // MapOld2New is Sequence<...>
Chao Liu's avatar
Chao Liu committed
91
    template <typename MapOld2New>
Chao Liu's avatar
Chao Liu committed
92
93
    __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
    {
Chao Liu's avatar
Chao Liu committed
94
        static_assert(MapOld2New::Size() == Size(),
Chao Liu's avatar
Chao Liu committed
95
96
97
98
99
100
101
                      "wrong! reorder map should have the same size as Sequence to be rerodered");

        static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");

        return ReorderGivenNew2Old(typename sequence_map_inverse<MapOld2New>::type{});
    }

102
103
104
105
    __host__ __device__ static constexpr auto Reverse()
    {
        return typename sequence_reverse<Type>::type{};
    }
Chao Liu's avatar
Chao Liu committed
106

Chao Liu's avatar
Chao Liu committed
107
    __host__ __device__ static constexpr auto Front()
108
    {
Chao Liu's avatar
Chao Liu committed
109
        static_assert(mSize > 0, "wrong!");
Chao Liu's avatar
Chao Liu committed
110
        return At(Number<0>{});
111
    }
112

Chao Liu's avatar
Chao Liu committed
113
    __host__ __device__ static constexpr auto Back()
114
    {
Chao Liu's avatar
Chao Liu committed
115
        static_assert(mSize > 0, "wrong!");
Chao Liu's avatar
Chao Liu committed
116
        return At(Number<mSize - 1>{});
117
    }
118

119
    __host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); }
Chao Liu's avatar
Chao Liu committed
120

121
    __host__ __device__ static constexpr auto PopBack() { return sequence_pop_back(Type{}); }
Chao Liu's avatar
Chao Liu committed
122
123
124

    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushFront(Sequence<Xs...>)
125
    {
Chao Liu's avatar
Chao Liu committed
126
        return Sequence<Xs..., Is...>{};
127
128
    }

Chao Liu's avatar
Chao Liu committed
129
130
    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushFront(Number<Xs>...)
131
    {
Chao Liu's avatar
Chao Liu committed
132
        return Sequence<Xs..., Is...>{};
133
134
    }

Chao Liu's avatar
Chao Liu committed
135
136
137
138
139
    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushBack(Sequence<Xs...>)
    {
        return Sequence<Is..., Xs...>{};
    }
140

Chao Liu's avatar
Chao Liu committed
141
    template <index_t... Xs>
Chao Liu's avatar
Chao Liu committed
142
    __host__ __device__ static constexpr auto PushBack(Number<Xs>...)
143
    {
Chao Liu's avatar
Chao Liu committed
144
145
        return Sequence<Is..., Xs...>{};
    }
Chao Liu's avatar
Chao Liu committed
146

Chao Liu's avatar
Chao Liu committed
147
    template <index_t... Ns>
148
    __host__ __device__ static constexpr auto Extract(Number<Ns>...)
Chao Liu's avatar
Chao Liu committed
149
    {
Chao Liu's avatar
Chao Liu committed
150
        return Sequence<Type::At(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
151
    }
Chao Liu's avatar
Chao Liu committed
152

Chao Liu's avatar
Chao Liu committed
153
    template <index_t... Ns>
154
    __host__ __device__ static constexpr auto Extract(Sequence<Ns...>)
Chao Liu's avatar
Chao Liu committed
155
    {
Chao Liu's avatar
Chao Liu committed
156
        return Sequence<Type::At(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
157
    }
158
159

    template <index_t I, index_t X>
160
161
    __host__ __device__ static constexpr auto Modify(Number<I>, Number<X>)
    {
Chao Liu's avatar
Chao Liu committed
162
        static_assert(I < Size(), "wrong!");
163
164

        using seq_split          = sequence_split<Type, I>;
Chao Liu's avatar
Chao Liu committed
165
166
        constexpr auto seq_left  = typename seq_split::left_type{};
        constexpr auto seq_right = typename seq_split::right_type{}.PopFront();
167
168
169

        return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
    }
Chao Liu's avatar
Chao Liu committed
170

Chao Liu's avatar
Chao Liu committed
171
    template <typename F>
Chao Liu's avatar
Chao Liu committed
172
173
174
175
    __host__ __device__ static constexpr auto Transform(F f)
    {
        return Sequence<f(Is)...>{};
    }
Chao Liu's avatar
Chao Liu committed
176
177
178
179
180
181
182
183

    __host__ __device__ static void Print()
    {
        printf("{");
        printf("size %d, ", index_t{Size()});
        static_for<0, Size(), 1>{}([&](auto i) { printf("%d ", At(i).value); });
        printf("}");
    }
184
185
};

Chao Liu's avatar
Chao Liu committed
186
// merge sequence
Chao Liu's avatar
Chao Liu committed
187
188
189
190
191
template <typename Seq, typename... Seqs>
struct sequence_merge
{
    using type = typename sequence_merge<Seq, typename sequence_merge<Seqs...>::type>::type;
};
Chao Liu's avatar
Chao Liu committed
192

Chao Liu's avatar
Chao Liu committed
193
194
195
template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
{
Chao Liu's avatar
Chao Liu committed
196
    using type = Sequence<Xs..., Ys...>;
Chao Liu's avatar
Chao Liu committed
197
};
Chao Liu's avatar
Chao Liu committed
198

Chao Liu's avatar
Chao Liu committed
199
200
201
202
203
204
template <typename Seq>
struct sequence_merge<Seq>
{
    using type = Seq;
};

Chao Liu's avatar
Chao Liu committed
205
// generate sequence
Chao Liu's avatar
Chao Liu committed
206
207
template <index_t NSize, typename F>
struct sequence_gen
Chao Liu's avatar
Chao Liu committed
208
{
Chao Liu's avatar
Chao Liu committed
209
210
211
212
213
214
    template <index_t IBegin, index_t NRemain, typename G>
    struct sequence_gen_impl
    {
        static constexpr index_t NRemainLeft  = NRemain / 2;
        static constexpr index_t NRemainRight = NRemain - NRemainLeft;
        static constexpr index_t IMiddle      = IBegin + NRemainLeft;
Chao Liu's avatar
Chao Liu committed
215

Chao Liu's avatar
Chao Liu committed
216
217
218
219
        using type = typename sequence_merge<
            typename sequence_gen_impl<IBegin, NRemainLeft, G>::type,
            typename sequence_gen_impl<IMiddle, NRemainRight, G>::type>::type;
    };
Chao Liu's avatar
Chao Liu committed
220

Chao Liu's avatar
Chao Liu committed
221
222
223
224
225
226
    template <index_t I, typename G>
    struct sequence_gen_impl<I, 1, G>
    {
        static constexpr index_t Is = G{}(Number<I>{});
        using type                  = Sequence<Is>;
    };
Chao Liu's avatar
Chao Liu committed
227

Chao Liu's avatar
Chao Liu committed
228
229
230
231
232
    template <index_t I, typename G>
    struct sequence_gen_impl<I, 0, G>
    {
        using type = Sequence<>;
    };
Chao Liu's avatar
Chao Liu committed
233

Chao Liu's avatar
Chao Liu committed
234
235
236
237
    using type = typename sequence_gen_impl<0, NSize, F>::type;
};

// arithmetic sequence
Chao Liu's avatar
Chao Liu committed
238
template <index_t IBegin, index_t IEnd, index_t Increment>
239
struct arithmetic_sequence_gen
Chao Liu's avatar
Chao Liu committed
240
{
Chao Liu's avatar
Chao Liu committed
241
242
243
244
245
246
247
248
    struct F
    {
        __host__ __device__ constexpr index_t operator()(index_t i) const
        {
            return i * Increment + IBegin;
        }
    };

249
250
251
252
253
254
255
    using type0 = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type;
    using type1 = Sequence<>;

    static constexpr bool kHasContent =
        (Increment > 0 && IBegin < IEnd) || (Increment < 0 && IBegin > IEnd);

    using type = typename conditional<kHasContent, type0, type1>::type;
Chao Liu's avatar
Chao Liu committed
256
257
258
259
260
261
};

// uniform sequence
template <index_t NSize, index_t I>
struct uniform_sequence_gen
{
Chao Liu's avatar
Chao Liu committed
262
    struct F
Chao Liu's avatar
Chao Liu committed
263
264
265
266
    {
        __host__ __device__ constexpr index_t operator()(index_t) const { return I; }
    };

Chao Liu's avatar
Chao Liu committed
267
    using type = typename sequence_gen<NSize, F>::type;
Chao Liu's avatar
Chao Liu committed
268
269
270
};

// reverse inclusive scan (with init) sequence
Chao Liu's avatar
Chao Liu committed
271
template <typename, typename, index_t>
Chao Liu's avatar
Chao Liu committed
272
struct sequence_reverse_inclusive_scan;
Chao Liu's avatar
Chao Liu committed
273

Chao Liu's avatar
Chao Liu committed
274
template <index_t I, index_t... Is, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
275
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
276
{
Chao Liu's avatar
Chao Liu committed
277
    using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::type;
Chao Liu's avatar
Chao Liu committed
278
279
280

    static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());

Chao Liu's avatar
Chao Liu committed
281
    using type = typename sequence_merge<Sequence<new_reduce>, old_scan>::type;
Chao Liu's avatar
Chao Liu committed
282
283
};

Chao Liu's avatar
Chao Liu committed
284
template <index_t I, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
285
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
286
{
Chao Liu's avatar
Chao Liu committed
287
    using type = Sequence<Reduce{}(I, Init)>;
Chao Liu's avatar
Chao Liu committed
288
289
};

Chao Liu's avatar
Chao Liu committed
290
template <typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
291
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
292
{
Chao Liu's avatar
Chao Liu committed
293
    using type = Sequence<>;
Chao Liu's avatar
Chao Liu committed
294
295
};

Chao Liu's avatar
Chao Liu committed
296
// split sequence
Chao Liu's avatar
Chao Liu committed
297
template <typename Seq, index_t I>
Chao Liu's avatar
Chao Liu committed
298
299
struct sequence_split
{
Chao Liu's avatar
Chao Liu committed
300
    static constexpr index_t NSize = Seq{}.Size();
Chao Liu's avatar
Chao Liu committed
301

Chao Liu's avatar
Chao Liu committed
302
303
    using range0 = typename arithmetic_sequence_gen<0, I, 1>::type;
    using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type;
Chao Liu's avatar
Chao Liu committed
304

Chao Liu's avatar
Chao Liu committed
305
306
    using left_type  = decltype(Seq::Extract(range0{}));
    using right_type = decltype(Seq::Extract(range1{}));
Chao Liu's avatar
Chao Liu committed
307
308
};

Chao Liu's avatar
Chao Liu committed
309
// reverse sequence
Chao Liu's avatar
Chao Liu committed
310
template <typename Seq>
Chao Liu's avatar
Chao Liu committed
311
312
struct sequence_reverse
{
Chao Liu's avatar
Chao Liu committed
313
    static constexpr index_t NSize = Seq{}.Size();
Chao Liu's avatar
Chao Liu committed
314
315

    using seq_split = sequence_split<Seq, NSize / 2>;
Chao Liu's avatar
Chao Liu committed
316
    using type      = typename sequence_merge<
Chao Liu's avatar
Chao Liu committed
317
318
        typename sequence_reverse<typename seq_split::right_type>::type,
        typename sequence_reverse<typename seq_split::left_type>::type>::type;
Chao Liu's avatar
Chao Liu committed
319
320
321
322
323
};

template <index_t I>
struct sequence_reverse<Sequence<I>>
{
Chao Liu's avatar
Chao Liu committed
324
    using type = Sequence<I>;
Chao Liu's avatar
Chao Liu committed
325
326
327
328
329
};

template <index_t I0, index_t I1>
struct sequence_reverse<Sequence<I0, I1>>
{
Chao Liu's avatar
Chao Liu committed
330
    using type = Sequence<I1, I0>;
Chao Liu's avatar
Chao Liu committed
331
};
Chao Liu's avatar
Chao Liu committed
332

Chao Liu's avatar
Chao Liu committed
333
#if 1
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
// Element-wise reduction across several equally sized Sequences.
// With more than two inputs the reduction nests to the right:
//   reduce(A, B, C) = reduce(A, reduce(B, C))
template <typename Reduce, typename Seq, typename... Seqs>
struct sequence_reduce
{
    using reduced_tail = typename sequence_reduce<Reduce, Seqs...>::type;

    using type = typename sequence_reduce<Reduce, Seq, reduced_tail>::type;
};

// Two Sequences: apply Reduce to each pair of corresponding elements.
template <typename Reduce, index_t... Xs, index_t... Ys>
struct sequence_reduce<Reduce, Sequence<Xs...>, Sequence<Ys...>>
{
    using type = Sequence<Reduce{}(Xs, Ys)...>;
};

// A single Sequence is returned unchanged.
template <typename Reduce, typename Seq>
struct sequence_reduce<Reduce, Seq>
{
    using type = Seq;
};
#endif

Chao Liu's avatar
Chao Liu committed
355
356
template <typename Values, typename Ids, typename Compare>
struct sequence_sort_impl
Chao Liu's avatar
Chao Liu committed
357
{
Chao Liu's avatar
Chao Liu committed
358
359
360
361
362
363
364
    template <typename LeftValues,
              typename LeftIds,
              typename RightValues,
              typename RightIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
Chao Liu's avatar
Chao Liu committed
365
366
    struct sorted_sequence_merge_impl
    {
Chao Liu's avatar
Chao Liu committed
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
        static constexpr bool choose_left = LeftValues::Front() < RightValues::Front();

        static constexpr index_t chosen_value =
            choose_left ? LeftValues::Front() : RightValues::Front();
        static constexpr index_t chosen_id = choose_left ? LeftIds::Front() : RightIds::Front();

        using new_merged_values = decltype(MergedValues::PushBack(Number<chosen_value>{}));
        using new_merged_ids    = decltype(MergedIds::PushBack(Number<chosen_id>{}));

        using new_left_values =
            typename conditional<choose_left, decltype(LeftValues::PopFront()), LeftValues>::type;
        using new_left_ids =
            typename conditional<choose_left, decltype(LeftIds::PopFront()), LeftIds>::type;

        using new_right_values =
            typename conditional<choose_left, RightValues, decltype(RightValues::PopFront())>::type;
        using new_right_ids =
            typename conditional<choose_left, RightIds, decltype(RightIds::PopFront())>::type;

        using merge = sorted_sequence_merge_impl<new_left_values,
                                                 new_left_ids,
                                                 new_right_values,
                                                 new_right_ids,
                                                 new_merged_values,
                                                 new_merged_ids,
                                                 Comp>;
        // this is output
        using merged_values = typename merge::merged_values;
        using merged_ids    = typename merge::merged_ids;
Chao Liu's avatar
Chao Liu committed
396
397
    };

Chao Liu's avatar
Chao Liu committed
398
399
400
401
402
403
404
405
406
407
408
409
    template <typename LeftValues,
              typename LeftIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
    struct sorted_sequence_merge_impl<LeftValues,
                                      LeftIds,
                                      Sequence<>,
                                      Sequence<>,
                                      MergedValues,
                                      MergedIds,
                                      Comp>
Chao Liu's avatar
Chao Liu committed
410
    {
Chao Liu's avatar
Chao Liu committed
411
412
        using merged_values = typename sequence_merge<MergedValues, LeftValues>::type;
        using merged_ids    = typename sequence_merge<MergedIds, LeftIds>::type;
Chao Liu's avatar
Chao Liu committed
413
414
    };

Chao Liu's avatar
Chao Liu committed
415
416
417
418
419
420
421
422
423
424
425
426
    template <typename RightValues,
              typename RightIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
    struct sorted_sequence_merge_impl<Sequence<>,
                                      Sequence<>,
                                      RightValues,
                                      RightIds,
                                      MergedValues,
                                      MergedIds,
                                      Comp>
Chao Liu's avatar
Chao Liu committed
427
    {
Chao Liu's avatar
Chao Liu committed
428
429
        using merged_values = typename sequence_merge<MergedValues, RightValues>::type;
        using merged_ids    = typename sequence_merge<MergedIds, RightIds>::type;
Chao Liu's avatar
Chao Liu committed
430
431
    };

Chao Liu's avatar
Chao Liu committed
432
433
434
435
436
    template <typename LeftValues,
              typename LeftIds,
              typename RightValues,
              typename RightIds,
              typename Comp>
Chao Liu's avatar
Chao Liu committed
437
438
    struct sorted_sequence_merge
    {
Chao Liu's avatar
Chao Liu committed
439
440
441
442
443
444
445
446
447
448
        using merge = sorted_sequence_merge_impl<LeftValues,
                                                 LeftIds,
                                                 RightValues,
                                                 RightIds,
                                                 Sequence<>,
                                                 Sequence<>,
                                                 Comp>;

        using merged_values = typename merge::merged_values;
        using merged_ids    = typename merge::merged_ids;
Chao Liu's avatar
Chao Liu committed
449
450
    };

Chao Liu's avatar
Chao Liu committed
451
452
453
454
    static constexpr index_t nsize = Values::Size();

    using split_unsorted_values = sequence_split<Values, nsize / 2>;
    using split_unsorted_ids    = sequence_split<Ids, nsize / 2>;
Chao Liu's avatar
Chao Liu committed
455

Chao Liu's avatar
Chao Liu committed
456
457
458
459
460
    using left_unsorted_values = typename split_unsorted_values::left_type;
    using left_unsorted_ids    = typename split_unsorted_ids::left_type;
    using left_sort          = sequence_sort_impl<left_unsorted_values, left_unsorted_ids, Compare>;
    using left_sorted_values = typename left_sort::sorted_values;
    using left_sorted_ids    = typename left_sort::sorted_ids;
Chao Liu's avatar
Chao Liu committed
461

Chao Liu's avatar
Chao Liu committed
462
463
464
465
466
467
468
469
470
471
472
473
474
475
    using right_unsorted_values = typename split_unsorted_values::right_type;
    using right_unsorted_ids    = typename split_unsorted_ids::right_type;
    using right_sort = sequence_sort_impl<right_unsorted_values, right_unsorted_ids, Compare>;
    using right_sorted_values = typename right_sort::sorted_values;
    using right_sorted_ids    = typename right_sort::sorted_ids;

    using merged_sorted = sorted_sequence_merge<left_sorted_values,
                                                left_sorted_ids,
                                                right_sorted_values,
                                                right_sorted_ids,
                                                Compare>;

    using sorted_values = typename merged_sorted::merged_values;
    using sorted_ids    = typename merged_sorted::merged_ids;
Chao Liu's avatar
Chao Liu committed
476
477
};

Chao Liu's avatar
Chao Liu committed
478
479
template <index_t ValueX, index_t ValueY, index_t IdX, index_t IdY, typename Compare>
struct sequence_sort_impl<Sequence<ValueX, ValueY>, Sequence<IdX, IdY>, Compare>
Chao Liu's avatar
Chao Liu committed
480
{
Chao Liu's avatar
Chao Liu committed
481
482
483
484
485
486
    static constexpr bool choose_x = Compare{}(ValueX, ValueY);

    using sorted_values =
        typename conditional<choose_x, Sequence<ValueX, ValueY>, Sequence<ValueY, ValueX>>::type;
    using sorted_ids = typename conditional<choose_x, Sequence<IdX, IdY>, Sequence<IdY, IdX>>::type;
};
Chao Liu's avatar
Chao Liu committed
487

Chao Liu's avatar
Chao Liu committed
488
489
490
491
492
template <index_t Value, index_t Id, typename Compare>
struct sequence_sort_impl<Sequence<Value>, Sequence<Id>, Compare>
{
    using sorted_values = Sequence<Value>;
    using sorted_ids    = Sequence<Id>;
Chao Liu's avatar
Chao Liu committed
493
494
};

495
496
497
498
499
500
501
template <typename Compare>
struct sequence_sort_impl<Sequence<>, Sequence<>, Compare>
{
    using sorted_values = Sequence<>;
    using sorted_ids    = Sequence<>;
};

Chao Liu's avatar
Chao Liu committed
502
503
template <typename Values, typename Compare>
struct sequence_sort
Chao Liu's avatar
Chao Liu committed
504
{
Chao Liu's avatar
Chao Liu committed
505
506
507
508
509
510
    using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type;
    using sort         = sequence_sort_impl<Values, unsorted_ids, Compare>;

    // this is output
    using type                = typename sort::sorted_values;
    using sorted2unsorted_map = typename sort::sorted_ids;
Chao Liu's avatar
Chao Liu committed
511
512
};

Chao Liu's avatar
Chao Liu committed
513
template <typename Values, typename Less, typename Equal>
Chao Liu's avatar
Chao Liu committed
514
515
struct sequence_unique_sort
{
Chao Liu's avatar
Chao Liu committed
516
517
518
519
520
    template <typename RemainValues,
              typename RemainIds,
              typename UniquifiedValues,
              typename UniquifiedIds,
              typename Eq>
Chao Liu's avatar
Chao Liu committed
521
522
    struct sorted_sequence_uniquify_impl
    {
Chao Liu's avatar
Chao Liu committed
523
524
525
526
527
528
529
530
531
532
533
534
        static constexpr index_t current_value = RemainValues::Front();
        static constexpr index_t current_id    = RemainIds::Front();

        static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back());

        using new_remain_values = decltype(RemainValues::PopFront());
        using new_remain_ids    = decltype(RemainIds::PopFront());

        using new_uniquified_values =
            typename conditional<is_unique_value,
                                 decltype(UniquifiedValues::PushBack(Number<current_value>{})),
                                 UniquifiedValues>::type;
Chao Liu's avatar
Chao Liu committed
535

Chao Liu's avatar
Chao Liu committed
536
537
538
539
540
541
542
543
544
545
546
547
548
549
        using new_uniquified_ids =
            typename conditional<is_unique_value,
                                 decltype(UniquifiedIds::PushBack(Number<current_id>{})),
                                 UniquifiedIds>::type;

        using uniquify = sorted_sequence_uniquify_impl<new_remain_values,
                                                       new_remain_ids,
                                                       new_uniquified_values,
                                                       new_uniquified_ids,
                                                       Eq>;

        // this is output
        using uniquified_values = typename uniquify::uniquified_values;
        using uniquified_ids    = typename uniquify::uniquified_ids;
Chao Liu's avatar
Chao Liu committed
550
551
    };

Chao Liu's avatar
Chao Liu committed
552
553
554
555
556
557
    template <typename UniquifiedValues, typename UniquifiedIds, typename Eq>
    struct sorted_sequence_uniquify_impl<Sequence<>,
                                         Sequence<>,
                                         UniquifiedValues,
                                         UniquifiedIds,
                                         Eq>
Chao Liu's avatar
Chao Liu committed
558
    {
Chao Liu's avatar
Chao Liu committed
559
560
        using uniquified_values = UniquifiedValues;
        using uniquified_ids    = UniquifiedIds;
Chao Liu's avatar
Chao Liu committed
561
562
    };

Chao Liu's avatar
Chao Liu committed
563
    template <typename SortedValues, typename SortedIds, typename Eq>
Chao Liu's avatar
Chao Liu committed
564
565
    struct sorted_sequence_uniquify
    {
Chao Liu's avatar
Chao Liu committed
566
567
568
569
570
571
572
573
        using uniquify = sorted_sequence_uniquify_impl<decltype(SortedValues::PopFront()),
                                                       decltype(SortedIds::PopFront()),
                                                       Sequence<SortedValues::Front()>,
                                                       Sequence<SortedIds::Front()>,
                                                       Eq>;

        using uniquified_values = typename uniquify::uniquified_values;
        using uniquified_ids    = typename uniquify::uniquified_ids;
Chao Liu's avatar
Chao Liu committed
574
575
    };

Chao Liu's avatar
Chao Liu committed
576
577
578
    using sort          = sequence_sort<Values, Less>;
    using sorted_values = typename sort::type;
    using sorted_ids    = typename sort::sorted2unsorted_map;
Chao Liu's avatar
Chao Liu committed
579

Chao Liu's avatar
Chao Liu committed
580
581
582
583
584
    using uniquify = sorted_sequence_uniquify<sorted_values, sorted_ids, Equal>;

    // this is output
    using type                = typename uniquify::uniquified_values;
    using sorted2unsorted_map = typename uniquify::uniquified_ids;
Chao Liu's avatar
Chao Liu committed
585
586
};

Chao Liu's avatar
Chao Liu committed
587
template <typename SeqMap>
Chao Liu's avatar
Chao Liu committed
588
589
struct is_valid_sequence_map : is_same<typename arithmetic_sequence_gen<0, SeqMap::Size(), 1>::type,
                                       typename sequence_sort<SeqMap, math::less<index_t>>::type>
Chao Liu's avatar
Chao Liu committed
590
591
{
};
592

Chao Liu's avatar
Chao Liu committed
593
594
template <typename SeqMap>
struct sequence_map_inverse
Chao Liu's avatar
Chao Liu committed
595
{
Chao Liu's avatar
Chao Liu committed
596
597
598
599
600
    template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
    struct sequence_map_inverse_impl
    {
        static constexpr auto new_y2x =
            WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
Chao Liu's avatar
Chao Liu committed
601

Chao Liu's avatar
Chao Liu committed
602
603
604
605
        using type =
            typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::
                type;
    };
Chao Liu's avatar
Chao Liu committed
606

Chao Liu's avatar
Chao Liu committed
607
608
609
610
611
    template <typename X2Y, typename WorkingY2X, index_t XBegin>
    struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
    {
        using type = WorkingY2X;
    };
Chao Liu's avatar
Chao Liu committed
612
613

    using type =
Chao Liu's avatar
Chao Liu committed
614
615
        typename sequence_map_inverse_impl<SeqMap,
                                           typename uniform_sequence_gen<SeqMap::Size(), 0>::type,
Chao Liu's avatar
Chao Liu committed
616
                                           0,
Chao Liu's avatar
Chao Liu committed
617
                                           SeqMap::Size()>::type;
Chao Liu's avatar
Chao Liu committed
618
619
};

620
621
622
623
624
625
// Two Sequences compare equal when they have the same length and every pair of corresponding
// elements is equal. Two empty Sequences are equal.
template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr bool operator==(Sequence<Xs...>, Sequence<Ys...>)
{
    // The lengths are checked first: expanding two packs of different length together is
    // ill-formed, so that branch must be discarded rather than instantiated.
    if constexpr(sizeof...(Xs) != sizeof...(Ys))
    {
        return false;
    }
    else
    {
        return ((Xs == Ys) && ...);
    }
}

Chao Liu's avatar
Chao Liu committed
626
template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
627
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
628
629
630
631
632
633
634
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs + Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
635
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
636
637
638
639
640
641
642
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs - Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
643
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
644
645
646
647
648
649
650
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs * Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
651
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
652
653
654
655
656
657
658
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs / Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
659
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
660
661
662
663
664
665
666
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs % Ys)...>{};
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
667
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
668
{
Chao Liu's avatar
Chao Liu committed
669
    return Sequence<(Xs + Y)...>{};
Chao Liu's avatar
Chao Liu committed
670
671
672
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
673
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
674
{
Chao Liu's avatar
Chao Liu committed
675
    return Sequence<(Xs - Y)...>{};
Chao Liu's avatar
Chao Liu committed
676
677
678
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
679
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
680
{
Chao Liu's avatar
Chao Liu committed
681
    return Sequence<(Xs * Y)...>{};
Chao Liu's avatar
Chao Liu committed
682
683
684
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
685
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
686
{
Chao Liu's avatar
Chao Liu committed
687
    return Sequence<(Xs / Y)...>{};
Chao Liu's avatar
Chao Liu committed
688
689
690
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
691
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
692
{
Chao Liu's avatar
Chao Liu committed
693
    return Sequence<(Xs % Y)...>{};
Chao Liu's avatar
Chao Liu committed
694
695
}

Chao Liu's avatar
Chao Liu committed
696
697
// Scalar-on-the-left addition: the result's i-th element is Y + Xs[i].
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator+(Number<Y>, Sequence<Xs...>)
{
    using Shifted = Sequence<(Y + Xs)...>;
    return Shifted{};
}

Chao Liu's avatar
Chao Liu committed
702
703
// Scalar-on-the-left subtraction: the result's i-th element is Y - Xs[i].
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
{
    using Shifted = Sequence<(Y - Xs)...>;
    return Shifted{};
}

Chao Liu's avatar
Chao Liu committed
708
709
// Scalar-on-the-left multiplication: the result's i-th element is Y * Xs[i].
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator*(Number<Y>, Sequence<Xs...>)
{
    using Scaled = Sequence<(Y * Xs)...>;
    return Scaled{};
}

Chao Liu's avatar
Chao Liu committed
714
715
// Scalar-on-the-left division: the result's i-th element is Y / Xs[i].
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator/(Number<Y>, Sequence<Xs...>)
{
    using Divided = Sequence<(Y / Xs)...>;
    return Divided{};
}

Chao Liu's avatar
Chao Liu committed
720
721
// Scalar-on-the-left modulo: the result's i-th element is Y % Xs[i].
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator%(Number<Y>, Sequence<Xs...>)
{
    using Remainders = Sequence<(Y % Xs)...>;
    return Remainders{};
}

726
727
728
729
730
731
// Drops the first element I of a non-empty Sequence and returns the remaining
// elements Is... as a new Sequence. An empty Sequence does not match this overload.
template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
{
    using Tail = Sequence<Is...>;
    return Tail{};
}

Chao Liu's avatar
Chao Liu committed
732
template <typename Seq>
Chao Liu's avatar
Chao Liu committed
733
__host__ __device__ constexpr auto sequence_pop_back(Seq)
734
{
Chao Liu's avatar
Chao Liu committed
735
    static_assert(Seq::Size() > 0, "wrong! cannot pop an empty Sequence!");
736
    return sequence_pop_front(Seq::Reverse()).Reverse();
737
}
738

Chao Liu's avatar
Chao Liu committed
739
740
741
742
743
744
// Function-style wrapper around the sequence_merge metafunction: returns a
// default-constructed instance of sequence_merge<Seqs...>::type.
template <typename... Seqs>
__host__ __device__ constexpr auto merge_sequences(Seqs...)
{
    using Merged = typename sequence_merge<Seqs...>::type;
    return Merged{};
}

745
746
747
748
749
750
// Applies the unary functor f to every element of a Sequence at compile time.
// The result is a Sequence whose i-th element is f(Xs[i]); f(Xs) is used as a
// template argument, so it must be usable in a constant expression.
template <typename F, index_t... Xs>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
{
    using Transformed = Sequence<f(Xs)...>;
    return Transformed{};
}

Chao Liu's avatar
Chao Liu committed
751
template <typename F, index_t... Xs, index_t... Ys>
752
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
753
{
754
    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize, "Dim not the same");
755
756
757
758

    return Sequence<f(Xs, Ys)...>{};
}

Chao Liu's avatar
Chao Liu committed
759
template <typename F, index_t... Xs, index_t... Ys, index_t... Zs>
760
761
762
763
764
765
766
767
768
769
__host__ __device__ constexpr auto
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
{
    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize &&
                      Sequence<Xs...>::mSize == Sequence<Zs...>::mSize,
                  "Dim not the same");

    return Sequence<f(Xs, Ys, Zs)...>{};
}

Chao Liu's avatar
Chao Liu committed
770
template <typename Seq, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
771
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number<Init>)
772
{
Chao Liu's avatar
Chao Liu committed
773
    return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::type{};
774
775
}

Chao Liu's avatar
Chao Liu committed
776
777
778
779
780
781
782
// Reverse exclusive scan built from the inclusive one: the front element of Seq is
// dropped, the remainder is reverse-inclusive-scanned with Reduce starting from Init,
// and Init is then appended at the back of the scanned sequence.
template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr auto reverse_exclusive_scan_sequence(Seq, Reduce, Number<Init>)
{
    using InitNumber = Number<Init>;

    auto scanned_tail =
        reverse_inclusive_scan_sequence(Seq::PopFront(), Reduce{}, InitNumber{});

    return scanned_tail.PushBack(InitNumber{});
}

Chao Liu's avatar
Chao Liu committed
783
template <typename Seq, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
784
__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<Init>)
785
{
Chao Liu's avatar
Chao Liu committed
786
    return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
787
}
788

Chao Liu's avatar
Chao Liu committed
789
template <typename Seq, index_t... Is>
790
__host__ __device__ constexpr auto pick_sequence_elements_by_ids(Seq, Sequence<Is...> /* ids */)
Chao Liu's avatar
Chao Liu committed
791
792
793
794
{
    return Sequence<Seq::At(Number<Is>{})...>{};
}

Chao Liu's avatar
Chao Liu committed
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
#if 1
namespace detail {
// Recursive type-level filter. At each step the front of PendingSeq is appended to
// Picked when the front of PendingMask is truthy, otherwise Picked is left as is;
// the recursion then continues on both tails.
template <typename Picked, typename PendingSeq, typename PendingMask>
struct pick_sequence_elements_by_mask_impl
{
    // Picked with the current front element appended.
    using with_front = decltype(Picked::PushBack(PendingSeq::Front()));

    using new_work_seq = typename conditional<PendingMask::Front(), with_front, Picked>::type;

    // Remaining elements and mask entries after consuming the fronts.
    using seq_tail  = decltype(PendingSeq::PopFront());
    using mask_tail = decltype(PendingMask::PopFront());

    using type =
        typename pick_sequence_elements_by_mask_impl<new_work_seq, seq_tail, mask_tail>::type;
};

// Recursion terminator: both the sequence and the mask are exhausted.
template <typename Picked>
struct pick_sequence_elements_by_mask_impl<Picked, Sequence<>, Sequence<>>
{
    using type = Picked;
};

} // namespace detail

818
819
820
// Filters Seq by Mask (both must have the same size): returns a default-constructed
// instance of the type computed by detail::pick_sequence_elements_by_mask_impl,
// starting from an empty Sequence.
template <typename Seq, typename Mask>
__host__ __device__ constexpr auto pick_sequence_elements_by_mask(Seq, Mask)
{
    static_assert(Seq::Size() == Mask::Size(), "wrong!");

    using Picked =
        typename detail::pick_sequence_elements_by_mask_impl<Sequence<>, Seq, Mask>::type;
    return Picked{};
}

Chao Liu's avatar
Chao Liu committed
826
827
828
namespace detail {
template <typename WorkSeq, typename RemainValues, typename RemainIds>
struct modify_sequence_elements_by_ids_impl
Chao Liu's avatar
Chao Liu committed
829
{
Chao Liu's avatar
Chao Liu committed
830
    using new_work_seq = decltype(WorkSeq::Modify(RemainIds::Front(), RemainValues::Front()));
Chao Liu's avatar
Chao Liu committed
831

Chao Liu's avatar
Chao Liu committed
832
833
834
835
836
    using type =
        typename modify_sequence_elements_by_ids_impl<new_work_seq,
                                                      decltype(RemainValues::PopFront()),
                                                      decltype(RemainIds::PopFront())>::type;
};
Chao Liu's avatar
Chao Liu committed
837

Chao Liu's avatar
Chao Liu committed
838
839
840
841
template <typename WorkSeq>
struct modify_sequence_elements_by_ids_impl<WorkSeq, Sequence<>, Sequence<>>
{
    using type = WorkSeq;
Chao Liu's avatar
Chao Liu committed
842
};
Chao Liu's avatar
Chao Liu committed
843
844
845
846
847
848
849
850
851
852
} // namespace detail

// Overwrites elements of Seq at the positions listed in Ids with the corresponding
// entries of Values. Values and Ids must have the same size, which must not exceed
// the size of Seq. The resulting type is computed by
// detail::modify_sequence_elements_by_ids_impl.
template <typename Seq, typename Values, typename Ids>
__host__ __device__ constexpr auto modify_sequence_elements_by_ids(Seq, Values, Ids)
{
    static_assert(Values::Size() == Ids::Size() && Seq::Size() >= Values::Size(), "wrong!");

    using Modified = typename detail::modify_sequence_elements_by_ids_impl<Seq, Values, Ids>::type;
    return Modified{};
}
#endif
Chao Liu's avatar
Chao Liu committed
853

Chao Liu's avatar
Chao Liu committed
854
template <typename Seq, typename Reduce, index_t Init>
Chao Liu's avatar
Chao Liu committed
855
__host__ __device__ constexpr index_t
Chao Liu's avatar
Chao Liu committed
856
reduce_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
Chao Liu's avatar
Chao Liu committed
857
858
859
{
    index_t result = Init;

Chao Liu's avatar
Chao Liu committed
860
861
862
863
    for(index_t i = 0; i < Seq::Size(); ++i)
    {
        result = f(result, Seq::At(i));
    }
Chao Liu's avatar
Chao Liu committed
864
865
866
867

    return result;
}

Chao Liu's avatar
Chao Liu committed
868
869
// TODO: a generic any_of for any container
template <typename Seq, typename F>
Chao Liu's avatar
Chao Liu committed
870
__host__ __device__ constexpr bool sequence_any_of(Seq, F f)
Chao Liu's avatar
Chao Liu committed
871
872
873
874
875
876
877
878
879
880
881
882
883
{
    bool flag = false;

    for(index_t i = 0; i < Seq::Size(); ++i)
    {
        flag = flag || f(Seq::At(i));
    }

    return flag;
}

// TODO: a generic all_of for any container
template <typename Seq, typename F>
Chao Liu's avatar
Chao Liu committed
884
__host__ __device__ constexpr bool sequence_all_of(Seq, F f)
Chao Liu's avatar
Chao Liu committed
885
886
887
888
889
890
891
892
893
894
895
{
    bool flag = true;

    for(index_t i = 0; i < Seq::Size(); ++i)
    {
        flag = flag && f(Seq::At(i));
    }

    return flag;
}

896
897
898
899
900
901
// Shorthand alias for the nested ::type of sequence_merge applied to two sequences.
template <typename Sx, typename Sy>
using sequence_merge_t = typename sequence_merge<Sx, Sy>::type;

// Shorthand alias for the nested ::type of uniform_sequence_gen<NSize, I>.
template <index_t NSize, index_t I>
using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;

902
} // namespace ck
Umang Yadav's avatar
Umang Yadav committed
903
904

#pragma clang diagnostic pop