Sequence.hip.hpp 13.8 KB
Newer Older
1
#pragma once
Chao Liu's avatar
Chao Liu committed
2
#include "integral_constant.hip.hpp"
3
4
#include "functional.hip.hpp"

Chao Liu's avatar
Chao Liu committed
5
template <index_t... Is>
6
7
struct Sequence
{
Chao Liu's avatar
Chao Liu committed
8
    using Type = Sequence;
9

10
    static constexpr index_t mSize = sizeof...(Is);
11

12
    __host__ __device__ static constexpr index_t GetSize() { return mSize; }
13

Chao Liu's avatar
Chao Liu committed
14
    template <index_t I>
15
    __host__ __device__ static constexpr index_t Get(Number<I>)
16
    {
17
18
19
20
        static_assert(I < mSize, "wrong! I too large");

        // the last dummy element is to prevent compiler complain about empty Sequence
        const index_t mData[mSize + 1] = {Is..., 0};
21
22
23
        return mData[I];
    }

24
    template <index_t... IRs>
25
    __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
26
    {
27
28
29
30
31
#if 0 // require sequence_sort, which is not implemented yet
        static_assert(is_same<sequence_sort<Sequence<IRs...>>::SortedSeqType,
                              arithmetic_sequence_gen<0, mSize, 1>::SeqType>::value,
                      "wrong! invalid new2old map");
#endif
32

33
        return Sequence<Type{}.Get(Number<IRs>{})...>{};
34
35
    }

36
37
38
#if 0 // require sequence_sort, which is not implemented yet
    template <class MapOld2New>
    __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New /*old2new*/)
39
    {
40
41
42
43
44
45
46
        static_assert(is_same<sequence_sort<MapOld2New>::SortedSeqType,
                              arithmetic_sequence_gen<0, mSize, 1>::SeqType>::value,
                      "wrong! invalid old2new map");

        constexpr auto map_new2old = typename sequence_map_inverse<MapOld2New>::SeqMapType{};

        return ReorderGivenNew2Old(map_new2old);
47
    }
48
#endif
49

50
    __host__ __device__ static constexpr auto Reverse();
Chao Liu's avatar
Chao Liu committed
51

52
53
54
55
56
    __host__ __device__ static constexpr index_t Front()
    {
        const index_t mData[mSize + 1] = {Is..., 0};
        return mData[0];
    }
57

58
59
60
61
62
    __host__ __device__ static constexpr index_t Back()
    {
        const index_t mData[mSize + 1] = {Is..., 0};
        return mData[mSize - 1];
    }
63

64
    template <index_t I>
65
    __host__ __device__ static constexpr auto PushFront(Number<I>)
66
67
68
69
    {
        return Sequence<I, Is...>{};
    }

Chao Liu's avatar
Chao Liu committed
70
    template <index_t I>
71
    __host__ __device__ static constexpr auto PushBack(Number<I>)
72
73
74
75
    {
        return Sequence<Is..., I>{};
    }

76
    __host__ __device__ static constexpr auto PopFront();
77

78
    __host__ __device__ static constexpr auto PopBack();
79

Chao Liu's avatar
Chao Liu committed
80
    template <index_t... Xs>
81
    __host__ __device__ static constexpr auto Append(Sequence<Xs...>)
82
    {
Chao Liu's avatar
Chao Liu committed
83
84
        return Sequence<Is..., Xs...>{};
    }
Chao Liu's avatar
Chao Liu committed
85

Chao Liu's avatar
Chao Liu committed
86
    template <index_t... Ns>
87
    __host__ __device__ static constexpr auto Extract(Number<Ns>...)
Chao Liu's avatar
Chao Liu committed
88
    {
Chao Liu's avatar
Chao Liu committed
89
        return Sequence<Type{}.Get(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
90
    }
Chao Liu's avatar
Chao Liu committed
91

Chao Liu's avatar
Chao Liu committed
92
    template <index_t... Ns>
93
    __host__ __device__ static constexpr auto Extract(Sequence<Ns...>)
Chao Liu's avatar
Chao Liu committed
94
    {
Chao Liu's avatar
Chao Liu committed
95
        return Sequence<Type{}.Get(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
96
    }
97
98
99

    template <index_t I, index_t X>
    __host__ __device__ static constexpr auto Modify(Number<I>, Number<X>);
100
101
};

Chao Liu's avatar
Chao Liu committed
102
103
template <class, class>
struct sequence_merge;
Chao Liu's avatar
Chao Liu committed
104

Chao Liu's avatar
Chao Liu committed
105
106
107
template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
{
Chao Liu's avatar
Chao Liu committed
108
    using SeqType = Sequence<Xs..., Ys...>;
Chao Liu's avatar
Chao Liu committed
109
};
Chao Liu's avatar
Chao Liu committed
110

Chao Liu's avatar
Chao Liu committed
111
template <index_t IBegin, index_t NSize, index_t Increment>
112
struct arithmetic_sequence_gen_impl
Chao Liu's avatar
Chao Liu committed
113
114
{
    static constexpr index_t NSizeLeft = NSize / 2;
Chao Liu's avatar
Chao Liu committed
115

Chao Liu's avatar
Chao Liu committed
116
    using SeqType = typename sequence_merge<
117
118
        typename arithmetic_sequence_gen_impl<IBegin, NSizeLeft, Increment>::SeqType,
        typename arithmetic_sequence_gen_impl<IBegin + NSizeLeft * Increment,
Chao Liu's avatar
Chao Liu committed
119
120
                                              NSize - NSizeLeft,
                                              Increment>::SeqType>::SeqType;
Chao Liu's avatar
Chao Liu committed
121
122
};

Chao Liu's avatar
Chao Liu committed
123
template <index_t IBegin, index_t Increment>
124
struct arithmetic_sequence_gen_impl<IBegin, 1, Increment>
Chao Liu's avatar
Chao Liu committed
125
{
Chao Liu's avatar
Chao Liu committed
126
    using SeqType = Sequence<IBegin>;
Chao Liu's avatar
Chao Liu committed
127
};
Chao Liu's avatar
Chao Liu committed
128

Chao Liu's avatar
Chao Liu committed
129
template <index_t IBegin, index_t Increment>
130
struct arithmetic_sequence_gen_impl<IBegin, 0, Increment>
Chao Liu's avatar
Chao Liu committed
131
{
Chao Liu's avatar
Chao Liu committed
132
133
134
135
    using SeqType = Sequence<>;
};

template <index_t IBegin, index_t IEnd, index_t Increment>
136
struct arithmetic_sequence_gen
Chao Liu's avatar
Chao Liu committed
137
138
{
    using SeqType =
139
        typename arithmetic_sequence_gen_impl<IBegin, IEnd - IBegin, Increment>::SeqType;
Chao Liu's avatar
Chao Liu committed
140
};
Chao Liu's avatar
Chao Liu committed
141

Chao Liu's avatar
Chao Liu committed
142
143
template <class, class>
struct sequence_reverse_inclusive_scan;
Chao Liu's avatar
Chao Liu committed
144

Chao Liu's avatar
Chao Liu committed
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
template <index_t I, index_t... Is, class Reduce>
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce>
{
    using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce>::SeqType;

    static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());

    using SeqType = typename sequence_merge<Sequence<new_reduce>, old_scan>::SeqType;
};

template <index_t I, class Reduce>
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce>
{
    using SeqType = Sequence<I>;
};

Chao Liu's avatar
Chao Liu committed
161
162
163
164
165
166
// Scanning an empty Sequence yields an empty Sequence.
template <class Reduce>
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce>
{
    using SeqType = Sequence<>;
};

Chao Liu's avatar
Chao Liu committed
167
168
169
170
171
172
173
174
175
176
177
178
179
180
template <class, class>
struct sequence_extract;

template <class Seq, index_t... Is>
struct sequence_extract<Seq, Sequence<Is...>>
{
    using SeqType = Sequence<Seq{}.Get(Number<Is>{})...>;
};

template <class Seq, index_t I>
struct sequence_split
{
    static constexpr index_t NSize = Seq{}.GetSize();

181
182
    using range0 = typename arithmetic_sequence_gen<0, I, 1>::SeqType;
    using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::SeqType;
Chao Liu's avatar
Chao Liu committed
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209

    using SeqType0 = typename sequence_extract<Seq, range0>::SeqType;
    using SeqType1 = typename sequence_extract<Seq, range1>::SeqType;
};

template <class Seq>
struct sequence_reverse
{
    static constexpr index_t NSize = Seq{}.GetSize();

    using seq_split = sequence_split<Seq, NSize / 2>;
    using SeqType   = typename sequence_merge<
        typename sequence_reverse<typename seq_split::SeqType1>::SeqType,
        typename sequence_reverse<typename seq_split::SeqType0>::SeqType>::SeqType;
};

template <index_t I>
struct sequence_reverse<Sequence<I>>
{
    using SeqType = Sequence<I>;
};

template <index_t I0, index_t I1>
struct sequence_reverse<Sequence<I0, I1>>
{
    using SeqType = Sequence<I1, I0>;
};
Chao Liu's avatar
Chao Liu committed
210

211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
#if 0 // not fully implemented
// Sketch of a key/value merge sort on Sequences and of the permutation
// inverse built on top of it. Kept disabled: the merge step has no body and
// sequence_sort still contains placeholders.
template <class KeySeq0, class ValSeq0, class KeySeq1, class ValSeq1>
struct sequence_sort_merge_impl;

template <index_t Key0,
          index_t... Keys0,
          index_t Val0,
          index_t... Vals0,
          index_t Key1,
          index_t... Keys1,
          index_t Val1,
          index_t... Vals1>
struct sequence_sort_merge_impl<Sequence<Key0, Keys0...>,
                                Sequence<Val0, Vals0...>,
                                Sequence<Key1, Keys1...>,
                                Sequence<Val1, Vals1...>>
{
};

template <class>
struct sequence_sort;

template <index_t... Is>
struct sequence_sort<Sequence<Is...>>
{
    using OriginalSeqType        = Sequence<Is...>;
    using SortedSeqType          = xxxxx;
    using MapSorted2OriginalType = xxx;
};

// second parameter is a flag, so it has to be a non-type (bool) parameter
template <class Seq, bool IsValidSeqMap>
struct sequence_map_inverse_impl;

// impl for valid map, no impl for invalid map
template <index_t... Is>
struct sequence_map_inverse_impl<Sequence<Is...>, true>
{
    using SeqMapType = typename sequence_sort<Sequence<Is...>>::MapSorted2OriginalType;
};

template <class>
struct sequence_map_inverse;

template <index_t... Is>
struct sequence_map_inverse<Sequence<Is...>>
{
    // TODO: make sure the map to be inversed is valid: [0, sizeof...(Is))
    static constexpr bool is_valid_sequence_map =
        is_same<typename sequence_sort<Sequence<Is...>>::SortedSeqType,
                typename arithmetic_sequence_gen<0, sizeof...(Is), 1>::SeqType>::value;

    // make compiler fails, if is_valid_sequence_map != true
    using SeqMapType =
        typename sequence_map_inverse_impl<Sequence<Is...>, is_valid_sequence_map>::SeqMapType;
};
#endif

Chao Liu's avatar
Chao Liu committed
268
template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
269
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
270
271
272
273
274
275
276
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs + Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
277
__host__ __device__ constexpr auto operator-(Sequence<Xs...> seq_x, Sequence<Ys...> seq_y)
Chao Liu's avatar
Chao Liu committed
278
279
280
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

Chao Liu's avatar
Chao Liu committed
281
282
    static_for<0, seq_x.GetSize(), 1>{}(
        [&](auto I) { static_assert(seq_x.Get(I) >= seq_y.Get(I), "wrong! going to undeflow"); });
Chao Liu's avatar
Chao Liu committed
283
284
285
286
287

    return Sequence<(Xs - Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
288
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
289
290
291
292
293
294
295
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs * Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
296
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
297
298
299
300
301
302
303
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs / Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
304
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
305
306
307
308
309
310
311
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs % Ys)...>{};
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
312
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
313
{
Chao Liu's avatar
Chao Liu committed
314
    return Sequence<(Xs + Y)...>{};
Chao Liu's avatar
Chao Liu committed
315
316
317
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
318
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
319
{
320
#if 0 // TODO: turn it on. Doesn't compile
Chao Liu's avatar
Chao Liu committed
321
322
323
324
325
326
327
328
329
    constexpr auto seq_x = Sequence<Xs...>{};

    static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) {
        constexpr auto I = decltype(Iter){};
        static_assert(seq_x.Get(I) >= Y, "wrong! going to underflow");
    });
#endif

    return Sequence<(Xs - Y)...>{};
Chao Liu's avatar
Chao Liu committed
330
331
332
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
333
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
334
{
Chao Liu's avatar
Chao Liu committed
335
    return Sequence<(Xs * Y)...>{};
Chao Liu's avatar
Chao Liu committed
336
337
338
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
339
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
340
{
Chao Liu's avatar
Chao Liu committed
341
    return Sequence<(Xs / Y)...>{};
Chao Liu's avatar
Chao Liu committed
342
343
344
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
345
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
346
{
Chao Liu's avatar
Chao Liu committed
347
    return Sequence<(Xs % Y)...>{};
Chao Liu's avatar
Chao Liu committed
348
349
}

Chao Liu's avatar
Chao Liu committed
350
351
// Adds every element of the Sequence to the constant Y (constant on the left).
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator+(Number<Y>, Sequence<Xs...>)
{
    using sum_type = Sequence<(Y + Xs)...>;
    return sum_type{};
}

Chao Liu's avatar
Chao Liu committed
356
357
// Subtracts every element of the Sequence from the constant Y.
// Compilation fails if any element is larger than Y.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
{
    constexpr auto subtrahends = Sequence<Xs...>{};

    // one check per position; static_for is declared outside this file view
    static_for<0, sizeof...(Xs), 1>{}([&](auto IterTag) {
        constexpr auto Pos = decltype(IterTag){};
        static_assert(subtrahends.Get(Pos) <= Y, "wrong! going to underflow");
    });

    using difference_type = Sequence<(Y - Xs)...>;
    return difference_type{};
}

Chao Liu's avatar
Chao Liu committed
369
370
// Multiplies the constant Y by every element of the Sequence (constant on the left).
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator*(Number<Y>, Sequence<Xs...>)
{
    using product_type = Sequence<(Y * Xs)...>;
    return product_type{};
}

Chao Liu's avatar
Chao Liu committed
375
376
// Divides the constant Y by every element of the Sequence.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator/(Number<Y>, Sequence<Xs...>)
{
    using quotient_type = Sequence<(Y / Xs)...>;
    return quotient_type{};
}

Chao Liu's avatar
Chao Liu committed
381
382
// Takes the remainder of the constant Y modulo every element of the Sequence.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator%(Number<Y>, Sequence<Xs...>)
{
    using remainder_type = Sequence<(Y % Xs)...>;
    return remainder_type{};
}

387
388
389
390
391
392
// Drops the first element of a Sequence. An empty Sequence does not match
// the parameter pattern, so it is rejected at overload resolution.
template <index_t Head, index_t... Tail>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<Head, Tail...>)
{
    using tail_type = Sequence<Tail...>;
    return tail_type{};
}

Chao Liu's avatar
Chao Liu committed
393
394
// Drops the last element of a Sequence: reverse it, drop the front, and
// reverse the remainder back into the original order.
template <class Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq)
{
    static_assert(Seq{}.GetSize() > 0, "wrong! cannot pop an empty Sequence!");

    const auto backwards         = Seq{}.Reverse();
    const auto backwards_trimmed = sequence_pop_front(backwards);
    return backwards_trimmed.Reverse();
}
399

Chao Liu's avatar
Chao Liu committed
400
401
402
403
404
405
// Applies f to every element of one Sequence and collects the results in a
// new Sequence. f has to be usable in a constant expression.
template <class F, index_t... Xs>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
{
    using mapped_type = Sequence<f(Xs)...>;
    return mapped_type{};
}

Chao Liu's avatar
Chao Liu committed
406
template <class F, index_t... Xs, index_t... Ys>
407
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
408
{
409
    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize, "Dim not the same");
410
411
412
413

    return Sequence<f(Xs, Ys)...>{};
}

414
415
416
417
418
419
420
421
422
423
424
// Applies f to matching elements of three Sequences of equal length and
// collects the results in a new Sequence.
template <class F, index_t... Xs, index_t... Ys, index_t... Zs>
__host__ __device__ constexpr auto
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
{
    using first_type  = Sequence<Xs...>;
    using second_type = Sequence<Ys...>;
    using third_type  = Sequence<Zs...>;

    static_assert(first_type::mSize == second_type::mSize &&
                      first_type::mSize == third_type::mSize,
                  "Dim not the same");

    using mapped_type = Sequence<f(Xs, Ys, Zs)...>;
    return mapped_type{};
}

425
426
// Function form of sequence_reverse_inclusive_scan: returns an object of the
// scanned Sequence type. Both arguments are used for their types only.
template <class Seq, class Reduce>
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce)
{
    using scanned_type = typename sequence_reverse_inclusive_scan<Seq, Reduce>::SeqType;
    return scanned_type{};
}

431
432
// Inclusive scan running from the first element towards the last, built on
// the reverse scan: reverse the input, scan it, reverse the result.
template <class Seq, class Reduce>
__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce)
{
    const auto backwards         = Seq{}.Reverse();
    const auto backwards_scanned = reverse_inclusive_scan_sequence(backwards, Reduce{});
    return backwards_scanned.Reverse();
}
436
437

template <class Seq>
Chao Liu's avatar
Chao Liu committed
438
struct accumulate_on_sequence_impl
439
440
441
442
443
444
445
446
447
{
    template <class IDim>
    __host__ __device__ constexpr index_t operator()(IDim) const
    {
        return Seq{}.Get(IDim{});
    }
};

template <class Seq, class Reduce, index_t I>
Chao Liu's avatar
Chao Liu committed
448
449
__host__ __device__ constexpr index_t
    accumulate_on_sequence(Seq, Reduce, Number<I> /*initial_value*/)
450
451
{
    constexpr index_t a =
Chao Liu's avatar
Chao Liu committed
452
        static_const_reduce_n<Seq::mSize>{}(accumulate_on_sequence_impl<Seq>{}, Reduce{});
453
454
    return Reduce{}(a, I);
}
Chao Liu's avatar
Chao Liu committed
455

Chao Liu's avatar
Chao Liu committed
456
template <index_t... Is>
457
__host__ __device__ constexpr auto Sequence<Is...>::PopFront()
Chao Liu's avatar
Chao Liu committed
458
{
459
    return sequence_pop_front(Type{});
Chao Liu's avatar
Chao Liu committed
460
}
Chao Liu's avatar
Chao Liu committed
461

462
463
// Out-of-class definition: this Sequence without its last element.
template <index_t... Is>
__host__ __device__ constexpr auto Sequence<Is...>::PopBack()
{
    return sequence_pop_back(Sequence<Is...>{});
}

468
469
// Out-of-class definition: this Sequence with its elements in reverse order.
template <index_t... Is>
__host__ __device__ constexpr auto Sequence<Is...>::Reverse()
{
    using reversed_type = typename sequence_reverse<Type>::SeqType;
    return reversed_type{};
}

// Out-of-class definition: this Sequence with the element at position I
// replaced by X. The Sequence is cut at I, the first element of the second
// part (the old value) is dropped, and X is placed between the two parts.
template <index_t... Is>
template <index_t I, index_t X>
__host__ __device__ constexpr auto Sequence<Is...>::Modify(Number<I>, Number<X>)
{
    static_assert(I < GetSize(), "wrong!");

    using cut_at_I = sequence_split<Type, I>;

    // elements before position I
    constexpr auto before = typename cut_at_I::SeqType0{};
    // elements after position I (the element at I itself is popped)
    constexpr auto after = typename cut_at_I::SeqType1{}.PopFront();

    return before.PushBack(Number<X>{}).Append(after);
}