Sequence.hip.hpp 14.1 KB
Newer Older
1
#pragma once
Chao Liu's avatar
Chao Liu committed
2
#include "integral_constant.hip.hpp"
3
4
#include "functional.hip.hpp"

Chao Liu's avatar
Chao Liu committed
5
template <index_t... Is>
6
7
struct Sequence
{
Chao Liu's avatar
Chao Liu committed
8
    using Type = Sequence;
9

10
    static constexpr index_t mSize = sizeof...(Is);
11

12
    __host__ __device__ static constexpr index_t GetSize() { return mSize; }
13

Chao Liu's avatar
Chao Liu committed
14
    template <index_t I>
15
    __host__ __device__ static constexpr index_t Get(Number<I>)
16
    {
17
18
19
20
        static_assert(I < mSize, "wrong! I too large");

        // the last dummy element is to prevent compiler complain about empty Sequence
        const index_t mData[mSize + 1] = {Is..., 0};
21
22
23
        return mData[I];
    }

24
    template <index_t... IRs>
25
    __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
26
    {
27
28
29
30
31
#if 0 // require sequence_sort, which is not implemented yet
        static_assert(is_same<sequence_sort<Sequence<IRs...>>::SortedSeqType,
                              arithmetic_sequence_gen<0, mSize, 1>::SeqType>::value,
                      "wrong! invalid new2old map");
#endif
32

33
        return Sequence<Type{}.Get(Number<IRs>{})...>{};
34
35
    }

36
37
38
#if 0 // require sequence_sort, which is not implemented yet
    template <class MapOld2New>
    __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New /*old2new*/)
39
    {
40
41
42
43
44
45
46
        static_assert(is_same<sequence_sort<MapOld2New>::SortedSeqType,
                              arithmetic_sequence_gen<0, mSize, 1>::SeqType>::value,
                      "wrong! invalid old2new map");

        constexpr auto map_new2old = typename sequence_map_inverse<MapOld2New>::SeqMapType{};

        return ReorderGivenNew2Old(map_new2old);
47
    }
48
#endif
49

50
    __host__ __device__ static constexpr auto Reverse();
Chao Liu's avatar
Chao Liu committed
51

52
53
54
55
56
    __host__ __device__ static constexpr index_t Front()
    {
        const index_t mData[mSize + 1] = {Is..., 0};
        return mData[0];
    }
57

58
59
60
61
62
    __host__ __device__ static constexpr index_t Back()
    {
        const index_t mData[mSize + 1] = {Is..., 0};
        return mData[mSize - 1];
    }
63

64
    template <index_t I>
65
    __host__ __device__ static constexpr auto PushFront(Number<I>)
66
67
68
69
    {
        return Sequence<I, Is...>{};
    }

Chao Liu's avatar
Chao Liu committed
70
    template <index_t I>
71
    __host__ __device__ static constexpr auto PushBack(Number<I>)
72
73
74
75
    {
        return Sequence<Is..., I>{};
    }

76
    __host__ __device__ static constexpr auto PopFront();
77

78
    __host__ __device__ static constexpr auto PopBack();
79

Chao Liu's avatar
Chao Liu committed
80
    template <index_t... Xs>
81
    __host__ __device__ static constexpr auto Append(Sequence<Xs...>)
82
    {
Chao Liu's avatar
Chao Liu committed
83
84
        return Sequence<Is..., Xs...>{};
    }
Chao Liu's avatar
Chao Liu committed
85

Chao Liu's avatar
Chao Liu committed
86
    template <index_t... Ns>
87
    __host__ __device__ static constexpr auto Extract(Number<Ns>...)
Chao Liu's avatar
Chao Liu committed
88
    {
Chao Liu's avatar
Chao Liu committed
89
        return Sequence<Type{}.Get(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
90
    }
Chao Liu's avatar
Chao Liu committed
91

Chao Liu's avatar
Chao Liu committed
92
    template <index_t... Ns>
93
    __host__ __device__ static constexpr auto Extract(Sequence<Ns...>)
Chao Liu's avatar
Chao Liu committed
94
    {
Chao Liu's avatar
Chao Liu committed
95
        return Sequence<Type{}.Get(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
96
    }
97
98
99

    template <index_t I, index_t X>
    __host__ __device__ static constexpr auto Modify(Number<I>, Number<X>);
100
101
};

Chao Liu's avatar
Chao Liu committed
102
103
template <class, class>
struct sequence_merge;
Chao Liu's avatar
Chao Liu committed
104

Chao Liu's avatar
Chao Liu committed
105
106
107
template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
{
Chao Liu's avatar
Chao Liu committed
108
    using SeqType = Sequence<Xs..., Ys...>;
Chao Liu's avatar
Chao Liu committed
109
};
Chao Liu's avatar
Chao Liu committed
110

Chao Liu's avatar
Chao Liu committed
111
template <index_t IBegin, index_t NSize, index_t Increment>
112
struct arithmetic_sequence_gen_impl
Chao Liu's avatar
Chao Liu committed
113
114
{
    static constexpr index_t NSizeLeft = NSize / 2;
Chao Liu's avatar
Chao Liu committed
115

Chao Liu's avatar
Chao Liu committed
116
    using SeqType = typename sequence_merge<
117
118
        typename arithmetic_sequence_gen_impl<IBegin, NSizeLeft, Increment>::SeqType,
        typename arithmetic_sequence_gen_impl<IBegin + NSizeLeft * Increment,
Chao Liu's avatar
Chao Liu committed
119
120
                                              NSize - NSizeLeft,
                                              Increment>::SeqType>::SeqType;
Chao Liu's avatar
Chao Liu committed
121
122
};

Chao Liu's avatar
Chao Liu committed
123
template <index_t IBegin, index_t Increment>
124
struct arithmetic_sequence_gen_impl<IBegin, 1, Increment>
Chao Liu's avatar
Chao Liu committed
125
{
Chao Liu's avatar
Chao Liu committed
126
    using SeqType = Sequence<IBegin>;
Chao Liu's avatar
Chao Liu committed
127
};
Chao Liu's avatar
Chao Liu committed
128

Chao Liu's avatar
Chao Liu committed
129
template <index_t IBegin, index_t Increment>
130
struct arithmetic_sequence_gen_impl<IBegin, 0, Increment>
Chao Liu's avatar
Chao Liu committed
131
{
Chao Liu's avatar
Chao Liu committed
132
133
134
135
    using SeqType = Sequence<>;
};

template <index_t IBegin, index_t IEnd, index_t Increment>
136
struct arithmetic_sequence_gen
Chao Liu's avatar
Chao Liu committed
137
138
{
    using SeqType =
139
        typename arithmetic_sequence_gen_impl<IBegin, IEnd - IBegin, Increment>::SeqType;
Chao Liu's avatar
Chao Liu committed
140
};
Chao Liu's avatar
Chao Liu committed
141

Chao Liu's avatar
Chao Liu committed
142
143
template <class, class>
struct sequence_reverse_inclusive_scan;
Chao Liu's avatar
Chao Liu committed
144

Chao Liu's avatar
Chao Liu committed
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
template <index_t I, index_t... Is, class Reduce>
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce>
{
    using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce>::SeqType;

    static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());

    using SeqType = typename sequence_merge<Sequence<new_reduce>, old_scan>::SeqType;
};

template <index_t I, class Reduce>
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce>
{
    using SeqType = Sequence<I>;
};

Chao Liu's avatar
Chao Liu committed
161
162
163
164
165
166
template <class Reduce>
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce>
{
    using SeqType = Sequence<>;
};

Chao Liu's avatar
Chao Liu committed
167
168
169
170
171
172
173
174
175
176
177
178
179
180
template <class, class>
struct sequence_extract;

template <class Seq, index_t... Is>
struct sequence_extract<Seq, Sequence<Is...>>
{
    using SeqType = Sequence<Seq{}.Get(Number<Is>{})...>;
};

template <class Seq, index_t I>
struct sequence_split
{
    static constexpr index_t NSize = Seq{}.GetSize();

181
182
    using range0 = typename arithmetic_sequence_gen<0, I, 1>::SeqType;
    using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::SeqType;
Chao Liu's avatar
Chao Liu committed
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209

    using SeqType0 = typename sequence_extract<Seq, range0>::SeqType;
    using SeqType1 = typename sequence_extract<Seq, range1>::SeqType;
};

// sequence_reverse<Seq>::SeqType is Seq with its elements in reverse order.
// The Sequence is split in the middle, both halves are reversed recursively
// and then merged in swapped order.
template <class Seq>
struct sequence_reverse
{
    static constexpr index_t NSize = Seq{}.GetSize();

    using seq_split = sequence_split<Seq, NSize / 2>;
    using SeqType   = typename sequence_merge<
        typename sequence_reverse<typename seq_split::SeqType1>::SeqType,
        typename sequence_reverse<typename seq_split::SeqType0>::SeqType>::SeqType;
};

// recursion end: empty Sequence. Without this case the primary template would
// split Sequence<> into two empty halves and recurse on itself forever.
template <>
struct sequence_reverse<Sequence<>>
{
    using SeqType = Sequence<>;
};

// recursion end: one element
template <index_t I>
struct sequence_reverse<Sequence<I>>
{
    using SeqType = Sequence<I>;
};

// recursion end: two elements
template <index_t I0, index_t I1>
struct sequence_reverse<Sequence<I0, I1>>
{
    using SeqType = Sequence<I1, I0>;
};
Chao Liu's avatar
Chao Liu committed
210

211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
#if 0 // not fully implemented
// Sketch of a key/value merge sort on Sequences and of the map inversion built
// on top of it. Kept disabled: sequence_sort still has placeholder members.
template <class KeySeq0, class ValSeq0, class KeySeq1, class ValSeq1>
struct sequence_sort_merge_impl;

template <index_t Key0,
          index_t... Keys0,
          index_t Val0,
          index_t... Vals0,
          index_t Key1,
          index_t... Keys1,
          index_t Val1,
          index_t... Vals1>
struct sequence_sort_merge_impl<Sequence<Key0, Keys0...>,
                                Sequence<Val0, Vals0...>,
                                Sequence<Key1, Keys1...>,
                                Sequence<Val1, Vals1...>>
{
    // TODO: merge step is not implemented
};

template <class>
struct sequence_sort;

template <index_t... Is>
struct sequence_sort<Sequence<Is...>>
{
    using OriginalSeqType        = Sequence<Is...>;
    using SortedSeqType          = xxxxx; // TODO: placeholder, not implemented
    using MapSorted2OriginalType = xxx;   // TODO: placeholder, not implemented
};

template <class Seq, bool IsValidSeqMap>
struct sequence_map_inverse_impl;

// impl for valid map, no impl for invalid map
template <index_t... Is>
struct sequence_map_inverse_impl<Sequence<Is...>, true>
{
    using SeqMapType = typename sequence_sort<Sequence<Is...>>::MapSorted2OriginalType;
};

template <class>
struct sequence_map_inverse;

template <index_t... Is>
struct sequence_map_inverse<Sequence<Is...>>
{
    // TODO: make sure the map to be inversed is valid: [0, sizeof...(Is))
    static constexpr bool is_valid_sequence_map =
        is_same<typename sequence_sort<Sequence<Is...>>::SortedSeqType,
                typename arithmetic_sequence_gen<0, sizeof...(Is), 1>::SeqType>::value;

    // make compiler fails, if is_valid_sequence_map != true
    using SeqMapType =
        typename sequence_map_inverse_impl<Sequence<Is...>, is_valid_sequence_map>::SeqMapType;
};

#endif
Chao Liu's avatar
Chao Liu committed
268
269
270
271
272
273
274
275
276
277
278
// Tells whether Seq is a valid map, i.e. (per the disabled check) a permutation
// of 0 .. GetSize() - 1. The real check needs sequence_sort, so for now every
// Sequence is reported as valid.
template <class Seq>
struct is_valid_sequence_map
{
#if 0 // sequence_sort is not implemented yet
    static constexpr bool value =
        is_same<typename arithmetic_sequence_gen<0, Seq::GetSize(), 1>::SeqType,
                typename sequence_sort<Seq>::SortedSeqType>::value;
#else
    static constexpr bool value = true;
#endif
};
279

Chao Liu's avatar
Chao Liu committed
280
template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
281
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
282
283
284
285
286
287
288
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs + Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
289
__host__ __device__ constexpr auto operator-(Sequence<Xs...> seq_x, Sequence<Ys...> seq_y)
Chao Liu's avatar
Chao Liu committed
290
291
292
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

Chao Liu's avatar
Chao Liu committed
293
294
    static_for<0, seq_x.GetSize(), 1>{}(
        [&](auto I) { static_assert(seq_x.Get(I) >= seq_y.Get(I), "wrong! going to undeflow"); });
Chao Liu's avatar
Chao Liu committed
295
296
297
298
299

    return Sequence<(Xs - Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
300
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
301
302
303
304
305
306
307
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs * Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
308
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
309
310
311
312
313
314
315
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs / Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
316
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
317
318
319
320
321
322
323
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs % Ys)...>{};
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
324
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
325
{
Chao Liu's avatar
Chao Liu committed
326
    return Sequence<(Xs + Y)...>{};
Chao Liu's avatar
Chao Liu committed
327
328
329
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
330
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
331
{
332
#if 0 // TODO: turn it on. Doesn't compile
Chao Liu's avatar
Chao Liu committed
333
334
335
336
337
338
339
340
341
    constexpr auto seq_x = Sequence<Xs...>{};

    static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) {
        constexpr auto I = decltype(Iter){};
        static_assert(seq_x.Get(I) >= Y, "wrong! going to underflow");
    });
#endif

    return Sequence<(Xs - Y)...>{};
Chao Liu's avatar
Chao Liu committed
342
343
344
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
345
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
346
{
Chao Liu's avatar
Chao Liu committed
347
    return Sequence<(Xs * Y)...>{};
Chao Liu's avatar
Chao Liu committed
348
349
350
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
351
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
352
{
Chao Liu's avatar
Chao Liu committed
353
    return Sequence<(Xs / Y)...>{};
Chao Liu's avatar
Chao Liu committed
354
355
356
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
357
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
358
{
Chao Liu's avatar
Chao Liu committed
359
    return Sequence<(Xs % Y)...>{};
Chao Liu's avatar
Chao Liu committed
360
361
}

Chao Liu's avatar
Chao Liu committed
362
363
// Adds every element of the Sequence to the constant Y (constant on the left).
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator+(Number<Y>, Sequence<Xs...>)
{
    using Result = Sequence<(Y + Xs)...>;
    return Result{};
}

Chao Liu's avatar
Chao Liu committed
368
369
// Subtracts every element of the Sequence from the constant Y.
// Every element must be <= Y; otherwise compilation fails.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
{
    using SeqX = Sequence<Xs...>;

    // check each element at compile time
    static_for<0, sizeof...(Xs), 1>{}([](auto Iter) {
        using I = decltype(Iter);
        static_assert(SeqX::Get(I{}) <= Y, "wrong! going to underflow");
    });

    using Result = Sequence<(Y - Xs)...>;
    return Result{};
}

Chao Liu's avatar
Chao Liu committed
381
382
// Multiplies the constant Y by every element of the Sequence (constant on the left).
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator*(Number<Y>, Sequence<Xs...>)
{
    using Result = Sequence<(Y * Xs)...>;
    return Result{};
}

Chao Liu's avatar
Chao Liu committed
387
388
// Divides the constant Y by every element of the Sequence.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator/(Number<Y>, Sequence<Xs...>)
{
    using Result = Sequence<(Y / Xs)...>;
    return Result{};
}

Chao Liu's avatar
Chao Liu committed
393
394
// Takes the remainder of the constant Y modulo every element of the Sequence.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator%(Number<Y>, Sequence<Xs...>)
{
    using Result = Sequence<(Y % Xs)...>;
    return Result{};
}

399
400
401
402
403
404
// Returns the Sequence without its first element.
// Only matches Sequences that have at least one element.
template <index_t Head, index_t... Tail>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<Head, Tail...>)
{
    using Result = Sequence<Tail...>;
    return Result{};
}

Chao Liu's avatar
Chao Liu committed
405
406
// Returns the Sequence without its last element:
// reverse, drop the (now) first element, reverse back.
template <class Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq)
{
    static_assert(Seq::GetSize() > 0, "wrong! cannot pop an empty Sequence!");

    constexpr auto reversed = Seq{}.Reverse();
    constexpr auto trimmed  = sequence_pop_front(reversed);
    return trimmed.Reverse();
}
411

Chao Liu's avatar
Chao Liu committed
412
413
414
415
416
417
// Applies f to every element of one Sequence and returns the Sequence of results.
// f(x) has to be usable in a constant expression.
template <class F, index_t... Xs>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
{
    using Result = Sequence<f(Xs)...>;
    return Result{};
}

Chao Liu's avatar
Chao Liu committed
418
template <class F, index_t... Xs, index_t... Ys>
419
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
420
{
421
    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize, "Dim not the same");
422
423
424
425

    return Sequence<f(Xs, Ys)...>{};
}

426
427
428
429
430
431
432
433
434
435
436
// Applies f to matching elements of three Sequences of the same length and
// returns the Sequence of results. f(x, y, z) has to be usable in a constant expression.
template <class F, index_t... Xs, index_t... Ys, index_t... Zs>
__host__ __device__ constexpr auto
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys) && sizeof...(Xs) == sizeof...(Zs),
                  "Dim not the same");

    using Result = Sequence<f(Xs, Ys, Zs)...>;
    return Result{};
}

437
438
// Function form of sequence_reverse_inclusive_scan: returns an object of the
// scanned Sequence type for the given Sequence and reduction functor.
template <class Seq, class Reduce>
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce)
{
    using Scanned = typename sequence_reverse_inclusive_scan<Seq, Reduce>::SeqType;
    return Scanned{};
}

443
444
// Inclusive scan running from the front of the Sequence to the back.
// Implemented as: reverse, scan from the back, reverse again.
template <class Seq, class Reduce>
__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce)
{
    constexpr auto reversed_input = Seq{}.Reverse();
    constexpr auto reversed_scan  = reverse_inclusive_scan_sequence(reversed_input, Reduce{});
    return reversed_scan.Reverse();
}
448
449

template <class Seq>
Chao Liu's avatar
Chao Liu committed
450
struct accumulate_on_sequence_impl
451
452
453
454
455
456
457
458
459
{
    template <class IDim>
    __host__ __device__ constexpr index_t operator()(IDim) const
    {
        return Seq{}.Get(IDim{});
    }
};

template <class Seq, class Reduce, index_t I>
Chao Liu's avatar
Chao Liu committed
460
461
__host__ __device__ constexpr index_t
    accumulate_on_sequence(Seq, Reduce, Number<I> /*initial_value*/)
462
463
{
    constexpr index_t a =
Chao Liu's avatar
Chao Liu committed
464
        static_const_reduce_n<Seq::mSize>{}(accumulate_on_sequence_impl<Seq>{}, Reduce{});
465
466
    return Reduce{}(a, I);
}
Chao Liu's avatar
Chao Liu committed
467

Chao Liu's avatar
Chao Liu committed
468
template <index_t... Is>
469
__host__ __device__ constexpr auto Sequence<Is...>::PopFront()
Chao Liu's avatar
Chao Liu committed
470
{
471
    return sequence_pop_front(Type{});
Chao Liu's avatar
Chao Liu committed
472
}
Chao Liu's avatar
Chao Liu committed
473

474
475
// Out-of-class definition: returns this Sequence without its last element.
template <index_t... Is>
__host__ __device__ constexpr auto Sequence<Is...>::PopBack()
{
    return sequence_pop_back(Sequence<Is...>{});
}

480
481
// Out-of-class definition: returns this Sequence with its elements in reverse order.
template <index_t... Is>
__host__ __device__ constexpr auto Sequence<Is...>::Reverse()
{
    using Reversed = typename sequence_reverse<Type>::SeqType;
    return Reversed{};
}

// Out-of-class definition: returns a copy of this Sequence in which the element
// at position I is replaced by X.
template <index_t... Is>
template <index_t I, index_t X>
__host__ __device__ constexpr auto Sequence<Is...>::Modify(Number<I>, Number<X>)
{
    static_assert(I < mSize, "wrong!");

    // cut at I: [0, I) stays, the element at I is dropped from the right part
    using halves = sequence_split<Type, I>;
    using head   = typename halves::SeqType0;

    constexpr auto tail = typename halves::SeqType1{}.PopFront();

    return head{}.PushBack(Number<X>{}).Append(tail);
}