"src/include/threadwise_gemm.hpp" did not exist on "b2888adfbe103ae3d9006af87d5871b69cbf00ba"
Sequence.hpp 15 KB
Newer Older
1
2
3
#ifndef CK_SEQUENCE_HPP
#define CK_SEQUENCE_HPP

Chao Liu's avatar
Chao Liu committed
4
5
#include "integral_constant.hpp"
#include "functional.hpp"
6

7
8
namespace ck {

9
10
11
12
13
14
// Forward declarations of the utilities that Sequence's member functions name
// before their definitions appear further down in this file.

// Compile-time list of index_t values.
template <index_t...>
struct Sequence;

// Splits Seq at index I and exposes the two parts as SeqType0 / SeqType1.
template <class Seq, index_t I>
struct sequence_split;

// Exposes the reversed sequence as ::type.
template <class>
struct sequence_reverse;

// Exposes the inverse of an index map as ::type.
template <class>
struct sequence_map_inverse;

// Exposes ::value, used in static_asserts that guard reorder maps.
template <class>
struct is_valid_sequence_map;

// Returns the sequence without its first element.
template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>);

// Returns the sequence without its last element.
template <class Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq);

Chao Liu's avatar
Chao Liu committed
30
template <index_t... Is>
31
32
struct Sequence
{
Chao Liu's avatar
Chao Liu committed
33
34
    using Type      = Sequence;
    using data_type = index_t;
35

36
    static constexpr index_t mSize = sizeof...(Is);
37

Chao Liu's avatar
Chao Liu committed
38
    __host__ __device__ static constexpr auto GetSize() { return Number<mSize>{}; }
39

Chao Liu's avatar
Chao Liu committed
40
    __host__ __device__ static constexpr index_t GetImpl(index_t I)
41
    {
Chao Liu's avatar
Chao Liu committed
42
43
44
45
46
47
        // the last dummy element is to prevent compiler complain about empty array, when mSize = 0
        const index_t mData[mSize + 1] = {Is..., 0};
        return mData[I];
    }

    template <index_t I>
Chao Liu's avatar
Chao Liu committed
48
    __host__ __device__ static constexpr auto Get(Number<I>)
Chao Liu's avatar
Chao Liu committed
49
    {
Chao Liu's avatar
Chao Liu committed
50
51
52
        static_assert(I < mSize, "wrong! I too large");

        return Number<GetImpl(Number<I>{})>{};
Chao Liu's avatar
Chao Liu committed
53
54
    }

Chao Liu's avatar
Chao Liu committed
55
56
    __host__ __device__ static constexpr auto Get(index_t I) { return GetImpl(I); }

Chao Liu's avatar
Chao Liu committed
57
58
    template <index_t I>
    __host__ __device__ constexpr auto operator[](Number<I>) const
Chao Liu's avatar
Chao Liu committed
59
    {
Chao Liu's avatar
Chao Liu committed
60
        return Get(Number<I>{});
61
62
    }

Chao Liu's avatar
Chao Liu committed
63
64
65
    // make sure I is constepxr if you want a constexpr return type
    __host__ __device__ constexpr index_t operator[](index_t I) const { return GetImpl(I); }

66
    template <index_t... IRs>
67
    __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
68
    {
Chao Liu's avatar
Chao Liu committed
69
        static_assert(sizeof...(Is) == sizeof...(IRs),
Chao Liu's avatar
Chao Liu committed
70
                      "wrong! reorder map should have the same size as Sequence to be rerodered");
Chao Liu's avatar
Chao Liu committed
71

Chao Liu's avatar
Chao Liu committed
72
73
74
        static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");

        return Sequence<Type::Get(Number<IRs>{})...>{};
75
76
    }

Chao Liu's avatar
Chao Liu committed
77
78
79
80
81
82
83
84
85
86
87
88
    // MapOld2New is Sequence<...>
    template <class MapOld2New>
    __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
    {
        static_assert(MapOld2New::GetSize() == GetSize(),
                      "wrong! reorder map should have the same size as Sequence to be rerodered");

        static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");

        return ReorderGivenNew2Old(typename sequence_map_inverse<MapOld2New>::type{});
    }

89
90
91
92
    __host__ __device__ static constexpr auto Reverse()
    {
        return typename sequence_reverse<Type>::type{};
    }
Chao Liu's avatar
Chao Liu committed
93

Chao Liu's avatar
Chao Liu committed
94
    __host__ __device__ static constexpr auto Front()
95
    {
Chao Liu's avatar
Chao Liu committed
96
97
        static_assert(mSize > 0, "wrong!");
        return Get(Number<0>{});
98
    }
99

Chao Liu's avatar
Chao Liu committed
100
    __host__ __device__ static constexpr auto Back()
101
    {
Chao Liu's avatar
Chao Liu committed
102
103
        static_assert(mSize > 0, "wrong!");
        return Get(Number<mSize - 1>{});
104
    }
105

106
    __host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); }
Chao Liu's avatar
Chao Liu committed
107

108
    __host__ __device__ static constexpr auto PopBack() { return sequence_pop_back(Type{}); }
Chao Liu's avatar
Chao Liu committed
109
110
111

    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushFront(Sequence<Xs...>)
112
    {
Chao Liu's avatar
Chao Liu committed
113
        return Sequence<Xs..., Is...>{};
114
115
    }

Chao Liu's avatar
Chao Liu committed
116
117
    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushFront(Number<Xs>...)
118
    {
Chao Liu's avatar
Chao Liu committed
119
        return Sequence<Xs..., Is...>{};
120
121
    }

Chao Liu's avatar
Chao Liu committed
122
123
124
125
126
    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushBack(Sequence<Xs...>)
    {
        return Sequence<Is..., Xs...>{};
    }
127

Chao Liu's avatar
Chao Liu committed
128
    template <index_t... Xs>
Chao Liu's avatar
Chao Liu committed
129
    __host__ __device__ static constexpr auto PushBack(Number<Xs>...)
130
    {
Chao Liu's avatar
Chao Liu committed
131
132
        return Sequence<Is..., Xs...>{};
    }
Chao Liu's avatar
Chao Liu committed
133

Chao Liu's avatar
Chao Liu committed
134
    template <index_t... Ns>
135
    __host__ __device__ static constexpr auto Extract(Number<Ns>...)
Chao Liu's avatar
Chao Liu committed
136
    {
Chao Liu's avatar
Chao Liu committed
137
        return Sequence<Type::Get(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
138
    }
Chao Liu's avatar
Chao Liu committed
139

Chao Liu's avatar
Chao Liu committed
140
    template <index_t... Ns>
141
    __host__ __device__ static constexpr auto Extract(Sequence<Ns...>)
Chao Liu's avatar
Chao Liu committed
142
    {
Chao Liu's avatar
Chao Liu committed
143
        return Sequence<Type::Get(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
144
    }
145
146

    template <index_t I, index_t X>
147
148
149
150
151
152
153
154
155
156
    __host__ __device__ static constexpr auto Modify(Number<I>, Number<X>)
    {
        static_assert(I < GetSize(), "wrong!");

        using seq_split          = sequence_split<Type, I>;
        constexpr auto seq_left  = typename seq_split::SeqType0{};
        constexpr auto seq_right = typename seq_split::SeqType1{}.PopFront();

        return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
    }
Chao Liu's avatar
Chao Liu committed
157
158
159
160
161
162

    template <class F>
    __host__ __device__ static constexpr auto Transform(F f)
    {
        return Sequence<f(Is)...>{};
    }
163
164
};

Chao Liu's avatar
Chao Liu committed
165
// merge sequence
Chao Liu's avatar
Chao Liu committed
166
167
template <class, class>
struct sequence_merge;
Chao Liu's avatar
Chao Liu committed
168

Chao Liu's avatar
Chao Liu committed
169
170
171
template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
{
Chao Liu's avatar
Chao Liu committed
172
    using type = Sequence<Xs..., Ys...>;
Chao Liu's avatar
Chao Liu committed
173
};
Chao Liu's avatar
Chao Liu committed
174

Chao Liu's avatar
Chao Liu committed
175
176
177
// generate sequence
template <index_t IBegin, index_t NRemain, class F>
struct sequence_gen_impl
Chao Liu's avatar
Chao Liu committed
178
{
Chao Liu's avatar
Chao Liu committed
179
180
181
    static constexpr index_t NRemainLeft  = NRemain / 2;
    static constexpr index_t NRemainRight = NRemain - NRemainLeft;
    static constexpr index_t IMiddle      = IBegin + NRemainLeft;
Chao Liu's avatar
Chao Liu committed
182

Chao Liu's avatar
Chao Liu committed
183
184
185
    using type =
        typename sequence_merge<typename sequence_gen_impl<IBegin, NRemainLeft, F>::type,
                                typename sequence_gen_impl<IMiddle, NRemainRight, F>::type>::type;
Chao Liu's avatar
Chao Liu committed
186
187
};

Chao Liu's avatar
Chao Liu committed
188
189
template <index_t I, class F>
struct sequence_gen_impl<I, 1, F>
Chao Liu's avatar
Chao Liu committed
190
{
Chao Liu's avatar
Chao Liu committed
191
192
    static constexpr index_t Is = F{}(Number<I>{});
    using type                  = Sequence<Is>;
Chao Liu's avatar
Chao Liu committed
193
};
Chao Liu's avatar
Chao Liu committed
194

Chao Liu's avatar
Chao Liu committed
195
196
template <index_t I, class F>
struct sequence_gen_impl<I, 0, F>
Chao Liu's avatar
Chao Liu committed
197
{
Chao Liu's avatar
Chao Liu committed
198
    using type = Sequence<>;
Chao Liu's avatar
Chao Liu committed
199
200
};

Chao Liu's avatar
Chao Liu committed
201
202
203
204
205
206
207
template <index_t NSize, class F>
struct sequence_gen
{
    using type = typename sequence_gen_impl<0, NSize, F>::type;
};

// arithmetic sequence
Chao Liu's avatar
Chao Liu committed
208
template <index_t IBegin, index_t IEnd, index_t Increment>
209
struct arithmetic_sequence_gen
Chao Liu's avatar
Chao Liu committed
210
{
Chao Liu's avatar
Chao Liu committed
211
212
213
214
215
216
217
218
219
    struct F
    {
        __host__ __device__ constexpr index_t operator()(index_t i) const
        {
            return i * Increment + IBegin;
        }
    };

    using type = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type;
Chao Liu's avatar
Chao Liu committed
220
221
222
223
224
225
};

// uniform sequence
template <index_t NSize, index_t I>
struct uniform_sequence_gen
{
Chao Liu's avatar
Chao Liu committed
226
    struct F
Chao Liu's avatar
Chao Liu committed
227
228
229
230
    {
        __host__ __device__ constexpr index_t operator()(index_t) const { return I; }
    };

Chao Liu's avatar
Chao Liu committed
231
    using type = typename sequence_gen<NSize, F>::type;
Chao Liu's avatar
Chao Liu committed
232
233
234
};

// reverse inclusive scan (with init) sequence
Chao Liu's avatar
Chao Liu committed
235
template <class, class, index_t>
Chao Liu's avatar
Chao Liu committed
236
struct sequence_reverse_inclusive_scan;
Chao Liu's avatar
Chao Liu committed
237

Chao Liu's avatar
Chao Liu committed
238
239
template <index_t I, index_t... Is, class Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
Chao Liu's avatar
Chao Liu committed
240
{
Chao Liu's avatar
Chao Liu committed
241
    using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::type;
Chao Liu's avatar
Chao Liu committed
242
243
244

    static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());

Chao Liu's avatar
Chao Liu committed
245
    using type = typename sequence_merge<Sequence<new_reduce>, old_scan>::type;
Chao Liu's avatar
Chao Liu committed
246
247
};

Chao Liu's avatar
Chao Liu committed
248
249
// Base case: the right-most element is combined with Init.
template <index_t I, class Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init>
{
    using type = Sequence<Reduce{}(I, Init)>;
};

Chao Liu's avatar
Chao Liu committed
254
255
// Scanning an empty sequence yields an empty sequence; Init is not used.
template <class Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
{
    using type = Sequence<>;
};

Chao Liu's avatar
Chao Liu committed
260
// split sequence
Chao Liu's avatar
Chao Liu committed
261
262
263
264
265
template <class Seq, index_t I>
struct sequence_split
{
    static constexpr index_t NSize = Seq{}.GetSize();

Chao Liu's avatar
Chao Liu committed
266
267
    using range0 = typename arithmetic_sequence_gen<0, I, 1>::type;
    using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type;
Chao Liu's avatar
Chao Liu committed
268

Chao Liu's avatar
Chao Liu committed
269
270
    using SeqType0 = decltype(Seq::Extract(range0{}));
    using SeqType1 = decltype(Seq::Extract(range1{}));
Chao Liu's avatar
Chao Liu committed
271
272
};

Chao Liu's avatar
Chao Liu committed
273
// reverse sequence
Chao Liu's avatar
Chao Liu committed
274
275
276
277
278
279
template <class Seq>
struct sequence_reverse
{
    static constexpr index_t NSize = Seq{}.GetSize();

    using seq_split = sequence_split<Seq, NSize / 2>;
Chao Liu's avatar
Chao Liu committed
280
281
282
    using type      = typename sequence_merge<
        typename sequence_reverse<typename seq_split::SeqType1>::type,
        typename sequence_reverse<typename seq_split::SeqType0>::type>::type;
Chao Liu's avatar
Chao Liu committed
283
284
285
286
287
};

// Base case: a one-element sequence is its own reverse.
template <index_t I>
struct sequence_reverse<Sequence<I>>
{
    using type = Sequence<I>;
};

// Base case: a two-element sequence is reversed by swapping its elements.
template <index_t I0, index_t I1>
struct sequence_reverse<Sequence<I0, I1>>
{
    using type = Sequence<I1, I0>;
};
Chao Liu's avatar
Chao Liu committed
296

Chao Liu's avatar
Chao Liu committed
297
298
299
template <class Seq>
struct is_valid_sequence_map
{
Chao Liu's avatar
Chao Liu committed
300
    // not implemented yet, always return true
Chao Liu's avatar
Chao Liu committed
301
    static constexpr integral_constant<bool, true> value = integral_constant<bool, true>{};
Chao Liu's avatar
Chao Liu committed
302
303
304

    // TODO: add proper check for is_valid, something like:
    // static constexpr bool value =
Chao Liu's avatar
Chao Liu committed
305
    //     is_same<typename arithmetic_sequence_gen<0, Seq::GetSize(), 1>::type,
Chao Liu's avatar
Chao Liu committed
306
    //             typename sequence_sort<Seq>::SortedSeqType>{};
Chao Liu's avatar
Chao Liu committed
307
};
308

Chao Liu's avatar
Chao Liu committed
309
310
311
312
template <class X2Y, class WorkingY2X, index_t XBegin, index_t XRemain>
struct sequence_map_inverse_impl
{
    private:
313
314
    static constexpr auto new_y2x =
        WorkingY2X::Modify(X2Y::Get(Number<XBegin>{}), Number<XBegin>{});
Chao Liu's avatar
Chao Liu committed
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336

    public:
    using type =
        typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::type;
};

template <class X2Y, class WorkingY2X, index_t XBegin>
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
{
    using type = WorkingY2X;
};

template <class X2Y>
struct sequence_map_inverse
{
    using type =
        typename sequence_map_inverse_impl<X2Y,
                                           typename uniform_sequence_gen<X2Y::GetSize(), 0>::type,
                                           0,
                                           X2Y::GetSize()>::type;
};

Chao Liu's avatar
Chao Liu committed
337
template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
338
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
339
340
341
342
343
344
345
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs + Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
346
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
347
348
349
350
351
352
353
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs - Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
354
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
355
356
357
358
359
360
361
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs * Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
362
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
363
364
365
366
367
368
369
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs / Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
370
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
371
372
373
374
375
376
377
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs % Ys)...>{};
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
378
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
379
{
Chao Liu's avatar
Chao Liu committed
380
    return Sequence<(Xs + Y)...>{};
Chao Liu's avatar
Chao Liu committed
381
382
383
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
384
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
385
{
Chao Liu's avatar
Chao Liu committed
386
    return Sequence<(Xs - Y)...>{};
Chao Liu's avatar
Chao Liu committed
387
388
389
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
390
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
391
{
Chao Liu's avatar
Chao Liu committed
392
    return Sequence<(Xs * Y)...>{};
Chao Liu's avatar
Chao Liu committed
393
394
395
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
396
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
397
{
Chao Liu's avatar
Chao Liu committed
398
    return Sequence<(Xs / Y)...>{};
Chao Liu's avatar
Chao Liu committed
399
400
401
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
402
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
403
{
Chao Liu's avatar
Chao Liu committed
404
    return Sequence<(Xs % Y)...>{};
Chao Liu's avatar
Chao Liu committed
405
406
}

Chao Liu's avatar
Chao Liu committed
407
408
// Adds every element of the sequence to the constant Y.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator+(Number<Y>, Sequence<Xs...>)
{
    return Sequence<(Y + Xs)...>{};
}

Chao Liu's avatar
Chao Liu committed
413
414
// Subtracts every element of the sequence from the constant Y.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
{
    return Sequence<(Y - Xs)...>{};
}

Chao Liu's avatar
Chao Liu committed
421
422
// Multiplies the constant Y by every element of the sequence.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator*(Number<Y>, Sequence<Xs...>)
{
    return Sequence<(Y * Xs)...>{};
}

Chao Liu's avatar
Chao Liu committed
427
428
// Divides the constant Y by every element of the sequence (integer division).
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator/(Number<Y>, Sequence<Xs...>)
{
    return Sequence<(Y / Xs)...>{};
}

Chao Liu's avatar
Chao Liu committed
433
434
// Takes the remainder of the constant Y modulo every element of the sequence.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator%(Number<Y>, Sequence<Xs...>)
{
    return Sequence<(Y % Xs)...>{};
}

439
440
441
442
443
444
// Drops the first element of a non-empty sequence and returns the rest.
// An empty sequence does not match the parameter pattern, so calling this on
// one is a compile-time error.
template <index_t Head, index_t... Tail>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<Head, Tail...>)
{
    using remaining = Sequence<Tail...>;
    return remaining{};
}

Chao Liu's avatar
Chao Liu committed
445
446
// Drops the last element of a non-empty sequence and returns the rest.
// The result is the left part of Seq split just before its last element. The
// function does not go through Reverse(), so a one-element sequence (whose
// result is empty) is handled as well.
template <class Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq)
{
    static_assert(Seq::GetSize() > 0, "wrong! cannot pop an empty Sequence!");

    constexpr index_t nsize = Seq::GetSize();

    // Clamp at 0 so that an empty Seq only triggers the assertion above
    // instead of also asking for a split at index -1.
    constexpr index_t nkeep = nsize > 0 ? nsize - 1 : 0;

    return typename sequence_split<Seq, nkeep>::SeqType0{};
}
451

Chao Liu's avatar
Chao Liu committed
452
453
454
455
456
457
// Applies the unary functor f to every element of a sequence and returns the
// sequence of results. f(Xs) is used as a template argument, so the call has
// to be usable in a constant expression.
template <class F, index_t... Xs>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
{
    using result = Sequence<f(Xs)...>;
    return result{};
}

Chao Liu's avatar
Chao Liu committed
458
template <class F, index_t... Xs, index_t... Ys>
459
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
460
{
461
    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize, "Dim not the same");
462
463
464
465

    return Sequence<f(Xs, Ys)...>{};
}

466
467
468
469
470
471
472
473
474
475
476
// Applies the ternary functor f to the elements of three sequences position by
// position and returns the sequence of results. All three sequences must have
// the same length.
template <class F, index_t... Xs, index_t... Ys, index_t... Zs>
__host__ __device__ constexpr auto
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys) && sizeof...(Xs) == sizeof...(Zs),
                  "Dim not the same");

    using result = Sequence<f(Xs, Ys, Zs)...>;
    return result{};
}

Chao Liu's avatar
Chao Liu committed
477
478
// Function-style front end of sequence_reverse_inclusive_scan: scans Seq from
// right to left with Reduce, seeded with Init. Only the types of the arguments
// are used.
template <class Seq, class Reduce, index_t Init>
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number<Init>)
{
    return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::type{};
}

Chao Liu's avatar
Chao Liu committed
483
484
// Left-to-right inclusive scan, implemented by reversing Seq, running the
// right-to-left scan on it and reversing the result back.
template <class Seq, class Reduce, index_t Init>
__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<Init>)
{
    return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
}
488

489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
// Prints the label s, the element count and the elements of a Sequence.
// Sequences with more than 10 elements are rejected at compile time.
// One static_if per supported size selects a printf whose format string has
// exactly as many element placeholders as the sequence has elements, so each
// sequence is emitted with a single printf call.
// NOTE(review): every value is printed with %u, which assumes index_t is an
// unsigned 32-bit type; confirm against the definition of index_t.
template <index_t... Xs>
__host__ __device__ void print_Sequence(const char* s, Sequence<Xs...>)
{
    constexpr index_t nsize = Sequence<Xs...>::GetSize();

    static_assert(nsize <= 10, "wrong!");

    static_if<nsize == 0>{}([&](auto) { printf("%s size %u, {}\n", s, nsize, Xs...); });

    static_if<nsize == 1>{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, Xs...); });

    static_if<nsize == 2>{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, Xs...); });

    static_if<nsize == 3>{}([&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, Xs...); });

    static_if<nsize == 4>{}([&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, Xs...); });

    static_if<nsize == 5>{}(
        [&](auto) { printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, Xs...); });

    static_if<nsize == 6>{}(
        [&](auto) { printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, Xs...); });

    static_if<nsize == 7>{}(
        [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u}\n", s, nsize, Xs...); });

    static_if<nsize == 8>{}(
        [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });

    static_if<nsize == 9>{}(
        [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });

    static_if<nsize == 10>{}(
        [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
}
524
525
526

} // namespace ck
#endif