Sequence.hip.hpp 12.2 KB
Newer Older
1
2
3
4
#pragma once
#include <type_traits>

#include "constant_integral.hip.hpp"
#include "functional.hip.hpp"

Chao Liu's avatar
Chao Liu committed
5
template <index_t... Is>
6
7
struct Sequence
{
Chao Liu's avatar
Chao Liu committed
8
    using Type = Sequence;
9

10
    static constexpr index_t mSize = sizeof...(Is);
11

12
13
14
    const index_t mData[mSize] = {Is...};

    __host__ __device__ static constexpr index_t GetSize() { return mSize; }
15

Chao Liu's avatar
Chao Liu committed
16
17
    template <index_t I>
    __host__ __device__ constexpr index_t Get(Number<I>) const
18
19
20
21
    {
        return mData[I];
    }

22
23
    __host__ __device__ index_t operator[](index_t i) const { return mData[i]; }

24
25
    template <index_t... IRs>
    __host__ __device__ constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/) const
26
    {
27
        static_assert(mSize == sizeof...(IRs), "mSize not consistent");
28

29
        constexpr auto old = Type{};
30

31
        return Sequence<old.Get(Number<IRs>{})...>{};
32
33
    }

34
35
    template <index_t... IRs>
    __host__ __device__ constexpr auto ReorderGivenOld2New(Sequence<IRs...> /*old2new*/) const
36
    {
37
        // TODO: don't know how to implement this
38
        printf("Sequence::ReorderGivenOld2New not implemented");
39
40
41
        assert(false);
    }

Chao Liu's avatar
Chao Liu committed
42
43
44
45
46
    __host__ __device__ constexpr auto Reverse() const
    {
        // not implemented
    }

47
48
49
50
    __host__ __device__ constexpr index_t Front() const { return mData[0]; }

    __host__ __device__ constexpr index_t Back() const { return mData[mSize - 1]; }

51
52
53
54
55
56
    template <index_t I>
    __host__ __device__ constexpr auto PushFront(Number<I>) const
    {
        return Sequence<I, Is...>{};
    }

Chao Liu's avatar
Chao Liu committed
57
    template <index_t I>
58
59
60
61
62
    __host__ __device__ constexpr auto PushBack(Number<I>) const
    {
        return Sequence<Is..., I>{};
    }

63
64
    __host__ __device__ constexpr auto PopFront() const;

65
66
    __host__ __device__ constexpr auto PopBack() const;

Chao Liu's avatar
Chao Liu committed
67
    template <index_t... Xs>
Chao Liu's avatar
Chao Liu committed
68
    __host__ __device__ constexpr auto Append(Sequence<Xs...>) const
69
    {
Chao Liu's avatar
Chao Liu committed
70
71
        return Sequence<Is..., Xs...>{};
    }
Chao Liu's avatar
Chao Liu committed
72

Chao Liu's avatar
Chao Liu committed
73
74
75
    template <index_t... Ns>
    __host__ __device__ constexpr auto Extract(Number<Ns>...) const
    {
Chao Liu's avatar
Chao Liu committed
76
        return Sequence<Get(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
77
    }
Chao Liu's avatar
Chao Liu committed
78

Chao Liu's avatar
Chao Liu committed
79
80
    template <index_t... Ns>
    __host__ __device__ constexpr auto Extract(Sequence<Ns...>) const
Chao Liu's avatar
Chao Liu committed
81
    {
Chao Liu's avatar
Chao Liu committed
82
        return Sequence<Get(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
83
    }
84
85
};

Chao Liu's avatar
Chao Liu committed
86
87
template <class, class>
struct sequence_merge;
Chao Liu's avatar
Chao Liu committed
88

Chao Liu's avatar
Chao Liu committed
89
90
91
92
93
template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
{
    using Type = Sequence<Xs..., Ys...>;
};
Chao Liu's avatar
Chao Liu committed
94

Chao Liu's avatar
Chao Liu committed
95
96
97
98
template <index_t IBegin, index_t NSize, index_t Increment>
struct increasing_sequence_gen
{
    static constexpr index_t NSizeLeft = NSize / 2;
Chao Liu's avatar
Chao Liu committed
99

Chao Liu's avatar
Chao Liu committed
100
101
102
103
104
    using Type =
        sequence_merge<typename increasing_sequence_gen<IBegin, NSizeLeft, Increment>::Type,
                       typename increasing_sequence_gen<IBegin + NSizeLeft * Increment,
                                                        NSize - NSizeLeft,
                                                        Increment>::Type>;
Chao Liu's avatar
Chao Liu committed
105
106
};

Chao Liu's avatar
Chao Liu committed
107
108
template <index_t IBegin, index_t Increment>
struct increasing_sequence_gen<IBegin, 1, Increment>
Chao Liu's avatar
Chao Liu committed
109
{
Chao Liu's avatar
Chao Liu committed
110
111
    using Type = Sequence<IBegin>;
};
Chao Liu's avatar
Chao Liu committed
112

Chao Liu's avatar
Chao Liu committed
113
114
template <index_t IBegin, index_t Increment>
struct increasing_sequence_gen<IBegin, 0, Increment>
Chao Liu's avatar
Chao Liu committed
115
{
Chao Liu's avatar
Chao Liu committed
116
117
    using Type = Sequence<>;
};
Chao Liu's avatar
Chao Liu committed
118

Chao Liu's avatar
Chao Liu committed
119
template <index_t IBegin, index_t IEnd, index_t Increment>
Chao Liu's avatar
Chao Liu committed
120
121
__host__ __device__ constexpr auto
    make_increasing_sequence(Number<IBegin>, Number<IEnd>, Number<Increment>)
Chao Liu's avatar
Chao Liu committed
122
{
Chao Liu's avatar
Chao Liu committed
123
124
125
    static_assert(IBegin <= IEnd && Increment > 0, "wrong!");

    constexpr index_t NSize = (IEnd - IBegin) / Increment;
Chao Liu's avatar
Chao Liu committed
126

Chao Liu's avatar
Chao Liu committed
127
    return increasing_sequence_gen<IBegin, NSize, Increment>{};
Chao Liu's avatar
Chao Liu committed
128
129
130
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
131
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
132
133
134
135
136
137
138
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs + Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
139
__host__ __device__ constexpr auto operator-(Sequence<Xs...> seq_x, Sequence<Ys...> seq_y)
Chao Liu's avatar
Chao Liu committed
140
141
142
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

Chao Liu's avatar
Chao Liu committed
143
144
    static_for<0, seq_x.GetSize(), 1>{}(
        [&](auto I) { static_assert(seq_x.Get(I) >= seq_y.Get(I), "wrong! going to undeflow"); });
Chao Liu's avatar
Chao Liu committed
145
146
147
148
149

    return Sequence<(Xs - Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
150
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
151
152
153
154
155
156
157
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs * Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
158
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
159
160
161
162
163
164
165
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs / Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
166
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
167
168
169
170
171
172
173
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs % Ys)...>{};
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
174
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
175
{
Chao Liu's avatar
Chao Liu committed
176
    return Sequence<(Xs + Y)...>{};
Chao Liu's avatar
Chao Liu committed
177
178
179
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
180
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
181
{
Chao Liu's avatar
Chao Liu committed
182
183
    constexpr auto seq_x = Sequence<Xs...>{};

Chao Liu's avatar
Chao Liu committed
184
#if 0 // doesn't compile
Chao Liu's avatar
Chao Liu committed
185
186
187
188
189
190
191
    static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) {
        constexpr auto I = decltype(Iter){};
        static_assert(seq_x.Get(I) >= Y, "wrong! going to underflow");
    });
#endif

    return Sequence<(Xs - Y)...>{};
Chao Liu's avatar
Chao Liu committed
192
193
194
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
195
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
196
{
Chao Liu's avatar
Chao Liu committed
197
    return Sequence<(Xs * Y)...>{};
Chao Liu's avatar
Chao Liu committed
198
199
200
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
201
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
202
{
Chao Liu's avatar
Chao Liu committed
203
    return Sequence<(Xs / Y)...>{};
Chao Liu's avatar
Chao Liu committed
204
205
206
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
207
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
208
{
Chao Liu's avatar
Chao Liu committed
209
    return Sequence<(Xs % Y)...>{};
Chao Liu's avatar
Chao Liu committed
210
211
}

Chao Liu's avatar
Chao Liu committed
212
213
// Adds every element of the Sequence to the compile-time constant Y (Y + Xs).
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator+(Number<Y>, Sequence<Xs...>)
{
    using Shifted = Sequence<(Y + Xs)...>;
    return Shifted{};
}

Chao Liu's avatar
Chao Liu committed
218
219
// Subtracts every element of the Sequence from the compile-time constant Y (Y - Xs).
// Every element must be <= Y; this is checked at compile time.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
{
    // All-true check done purely on types: the two Sequences below are the same type
    // only if every (Xs <= Y) is 1. Unlike a static_assert inside a static_for lambda,
    // this needs no constexpr evaluation of captured objects.
    static_assert(std::is_same<Sequence<index_t(1), index_t(Xs <= Y)...>,
                               Sequence<index_t(Xs <= Y)..., index_t(1)>>::value,
                  "wrong! going to underflow");

    return Sequence<(Y - Xs)...>{};
}

Chao Liu's avatar
Chao Liu committed
231
232
// Multiplies the compile-time constant Y with every element of the Sequence (Y * Xs).
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator*(Number<Y>, Sequence<Xs...>)
{
    using Scaled = Sequence<(Y * Xs)...>;
    return Scaled{};
}

Chao Liu's avatar
Chao Liu committed
237
238
// Divides the compile-time constant Y by every element of the Sequence (Y / Xs).
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator/(Number<Y>, Sequence<Xs...>)
{
    using Quotient = Sequence<(Y / Xs)...>;
    return Quotient{};
}

Chao Liu's avatar
Chao Liu committed
243
244
// Takes the compile-time constant Y modulo every element of the Sequence (Y % Xs).
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator%(Number<Y>, Sequence<Xs...>)
{
    using Remainder = Sequence<(Y % Xs)...>;
    return Remainder{};
}

249
250
251
252
253
254
255
// Returns the Sequence without its first element.
// The remaining part must not be empty, so the input needs at least two elements.
template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
{
    static_assert(sizeof...(Is) > 0, "empty Sequence!");

    using Tail = Sequence<Is...>;
    return Tail{};
}

256
257
#if 0
// TODO: for some reason, compiler cannot instantiate this template
Chao Liu's avatar
Chao Liu committed
258
template <index_t... Is, index_t I>
259
260
__host__ __device__ constexpr auto sequence_pop_back(Sequence<Is..., I>)
{
261
    static_assert(sizeof...(Is) > 0, "empty Sequence!");
262
263
    return Sequence<Is...>{};
}
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
#else
// TODO: delete these very ugly mess
template <index_t I0, index_t I1>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1>)
{
    return Sequence<I0>{};
}

template <index_t I0, index_t I1, index_t I2>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2>)
{
    return Sequence<I0, I1>{};
}

template <index_t I0, index_t I1, index_t I2, index_t I3>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3>)
{
    return Sequence<I0, I1, I2>{};
}

template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4>)
{
    return Sequence<I0, I1, I2, I3>{};
}

template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4, index_t I5>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5>)
{
    return Sequence<I0, I1, I2, I3, I4>{};
}

template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4, index_t I5, index_t I6>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6>)
{
    return Sequence<I0, I1, I2, I3, I4, I5>{};
}

template <index_t I0,
          index_t I1,
          index_t I2,
          index_t I3,
          index_t I4,
          index_t I5,
          index_t I6,
          index_t I7>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6, I7>)
{
    return Sequence<I0, I1, I2, I3, I4, I5, I6>{};
}

template <index_t I0,
          index_t I1,
          index_t I2,
          index_t I3,
          index_t I4,
          index_t I5,
          index_t I6,
          index_t I7,
          index_t I8>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6, I7, I8>)
{
    return Sequence<I0, I1, I2, I3, I4, I5, I6, I7>{};
}

template <index_t I0,
          index_t I1,
          index_t I2,
          index_t I3,
          index_t I4,
          index_t I5,
          index_t I6,
          index_t I7,
          index_t I8,
          index_t I9>
__host__ __device__ constexpr auto
    sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6, I7, I8, I9>)
{
    return Sequence<I0, I1, I2, I3, I4, I5, I6, I7, I8>{};
}
#endif
345

Chao Liu's avatar
Chao Liu committed
346
347
348
349
350
351
// Applies f to every element of the Sequence; the results form the returned Sequence.
// f(x) has to be usable in a constant expression.
template <class F, index_t... Xs>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
{
    using Transformed = Sequence<f(Xs)...>;
    return Transformed{};
}

Chao Liu's avatar
Chao Liu committed
352
template <class F, index_t... Xs, index_t... Ys>
353
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
354
{
355
    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize, "Dim not the same");
356
357
358
359

    return Sequence<f(Xs, Ys)...>{};
}

360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
// Applies f to the matching elements of three Sequences of equal size; the results
// form the returned Sequence. f(x, y, z) has to be usable in a constant expression.
template <class F, index_t... Xs, index_t... Ys, index_t... Zs>
__host__ __device__ constexpr auto
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
{
    static_assert(Sequence<Xs...>::GetSize() == Sequence<Ys...>::GetSize() &&
                      Sequence<Xs...>::GetSize() == Sequence<Zs...>::GetSize(),
                  "Dim not the same");

    using Transformed = Sequence<f(Xs, Ys, Zs)...>;
    return Transformed{};
}

// Out-of-class definition: forwards to the free function sequence_pop_front.
template <index_t... Is>
__host__ __device__ constexpr auto Sequence<Is...>::PopFront() const
{
    return sequence_pop_front(Sequence<Is...>{});
}

Chao Liu's avatar
Chao Liu committed
377
template <index_t... Is>
378
379
380
381
__host__ __device__ constexpr auto Sequence<Is...>::PopBack() const
{
    return sequence_pop_back(Type{});
}
382
383

template <class Seq>
Chao Liu's avatar
Chao Liu committed
384
struct accumulate_on_sequence_impl
385
386
387
388
389
390
391
392
393
{
    template <class IDim>
    __host__ __device__ constexpr index_t operator()(IDim) const
    {
        return Seq{}.Get(IDim{});
    }
};

template <class Seq, class Reduce, index_t I>
Chao Liu's avatar
Chao Liu committed
394
395
__host__ __device__ constexpr index_t
    accumulate_on_sequence(Seq, Reduce, Number<I> /*initial_value*/)
396
397
{
    constexpr index_t a =
Chao Liu's avatar
Chao Liu committed
398
        static_const_reduce_n<Seq::mSize>{}(accumulate_on_sequence_impl<Seq>{}, Reduce{});
399
400
    return Reduce{}(a, I);
}
Chao Liu's avatar
Chao Liu committed
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436

// One step of a compile-time scan.
// ScanedSeq holds the results produced so far, RemainSeq the NRemain elements that
// are still to be processed. Each step appends
//     Reduce{}(ScanedSeq.Back(), RemainSeq.Front())
// to ScanedSeq and continues with the rest of RemainSeq. The finished Sequence is
// returned once RemainSeq is exhausted.
template <index_t NRemain>
struct scan_sequence_impl
{
    template <class ScanedSeq, class RemainSeq, class Reduce>
    __host__ __device__ constexpr auto operator()(ScanedSeq, RemainSeq, Reduce) const
    {
        static_assert(RemainSeq{}.GetSize() == NRemain,
                      "wrong! RemainSeq and NRemain not consistent!");

        constexpr index_t a       = Reduce{}(ScanedSeq{}.Back(), RemainSeq{}.Front());
        constexpr auto scaned_seq = ScanedSeq{}.PushBack(Number<a>{});

        // more elements left: recurse, and hand the result back to the caller
        return scan_sequence_impl<NRemain - 1>{}(scaned_seq, RemainSeq{}.PopFront(), Reduce{});
    }
};

// Last step: RemainSeq holds a single element, so nothing is left to recurse on.
template <>
struct scan_sequence_impl<1>
{
    template <class ScanedSeq, class RemainSeq, class Reduce>
    __host__ __device__ constexpr auto operator()(ScanedSeq, RemainSeq, Reduce) const
    {
        static_assert(RemainSeq{}.GetSize() == 1,
                      "wrong! RemainSeq and NRemain not consistent!");

        constexpr index_t a = Reduce{}(ScanedSeq{}.Back(), RemainSeq{}.Front());

        return ScanedSeq{}.PushBack(Number<a>{});
    }
};

// Compile-time scan of Seq with Reduce, driven by scan_sequence_impl:
// the scan starts from the first element of Seq and processes the remaining
// elements from front to back.
// Seq needs at least two elements, because PopFront() refuses to produce an empty
// Sequence.
template <class Seq, class Reduce>
__host__ __device__ constexpr auto scan_sequence(Seq, Reduce)
{
    // Sequence exposes Front(), not front()
    constexpr auto scaned_seq = Sequence<Seq{}.Front()>{};
    constexpr auto remain_seq = Seq{}.PopFront();

    constexpr index_t remain_size = Seq::GetSize() - 1;

    return scan_sequence_impl<remain_size>{}(scaned_seq, remain_seq, Reduce{});
}

// Scan that runs from the back of Seq to the front: reverses Seq, scans it with
// scan_sequence, and reverses the result again.
// Relies on Seq::Reverse() returning the reversed Sequence.
template <class Seq, class Reduce>
__host__ __device__ constexpr auto reverse_scan_sequence(Seq, Reduce)
{
    return scan_sequence(Seq{}.Reverse(), Reduce{}).Reverse();
}