// Sequence.hip.hpp
#pragma once
#include <type_traits>

#include "constant_integral.hip.hpp"
#include "functional.hip.hpp"

Chao Liu's avatar
Chao Liu committed
5
template <index_t... Is>
6
7
struct Sequence
{
Chao Liu's avatar
Chao Liu committed
8
    using Type = Sequence;
9

10
    static constexpr index_t mSize = sizeof...(Is);
11

Chao Liu's avatar
Chao Liu committed
12
13
    const index_t mData[mSize + 1] = {
        Is..., 0}; // the last element is dummy, to prevent compiler complain on empty Sequence
14
15

    __host__ __device__ static constexpr index_t GetSize() { return mSize; }
16

Chao Liu's avatar
Chao Liu committed
17
18
    template <index_t I>
    __host__ __device__ constexpr index_t Get(Number<I>) const
19
20
21
22
    {
        return mData[I];
    }

23
24
    __host__ __device__ index_t operator[](index_t i) const { return mData[i]; }

25
26
    template <index_t... IRs>
    __host__ __device__ constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/) const
27
    {
28
        static_assert(mSize == sizeof...(IRs), "mSize not consistent");
29

30
        constexpr auto old = Type{};
31

32
        return Sequence<old.Get(Number<IRs>{})...>{};
33
34
    }

35
36
    template <index_t... IRs>
    __host__ __device__ constexpr auto ReorderGivenOld2New(Sequence<IRs...> /*old2new*/) const
37
    {
38
        // TODO: don't know how to implement this
39
        printf("Sequence::ReorderGivenOld2New not implemented");
40
41
42
        assert(false);
    }

Chao Liu's avatar
Chao Liu committed
43
    __host__ __device__ constexpr auto Reverse() const;
Chao Liu's avatar
Chao Liu committed
44

45
46
47
48
    __host__ __device__ constexpr index_t Front() const { return mData[0]; }

    __host__ __device__ constexpr index_t Back() const { return mData[mSize - 1]; }

49
50
51
52
53
54
    template <index_t I>
    __host__ __device__ constexpr auto PushFront(Number<I>) const
    {
        return Sequence<I, Is...>{};
    }

Chao Liu's avatar
Chao Liu committed
55
    template <index_t I>
56
57
58
59
60
    __host__ __device__ constexpr auto PushBack(Number<I>) const
    {
        return Sequence<Is..., I>{};
    }

61
62
    __host__ __device__ constexpr auto PopFront() const;

63
64
    __host__ __device__ constexpr auto PopBack() const;

Chao Liu's avatar
Chao Liu committed
65
    template <index_t... Xs>
Chao Liu's avatar
Chao Liu committed
66
    __host__ __device__ constexpr auto Append(Sequence<Xs...>) const
67
    {
Chao Liu's avatar
Chao Liu committed
68
69
        return Sequence<Is..., Xs...>{};
    }
Chao Liu's avatar
Chao Liu committed
70

Chao Liu's avatar
Chao Liu committed
71
72
73
    template <index_t... Ns>
    __host__ __device__ constexpr auto Extract(Number<Ns>...) const
    {
Chao Liu's avatar
Chao Liu committed
74
        return Sequence<Type{}.Get(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
75
    }
Chao Liu's avatar
Chao Liu committed
76

Chao Liu's avatar
Chao Liu committed
77
78
    template <index_t... Ns>
    __host__ __device__ constexpr auto Extract(Sequence<Ns...>) const
Chao Liu's avatar
Chao Liu committed
79
    {
Chao Liu's avatar
Chao Liu committed
80
        return Sequence<Type{}.Get(Number<Ns>{})...>{};
Chao Liu's avatar
Chao Liu committed
81
    }
82
83
};

Chao Liu's avatar
Chao Liu committed
84
85
template <class, class>
struct sequence_merge;
Chao Liu's avatar
Chao Liu committed
86

Chao Liu's avatar
Chao Liu committed
87
88
89
template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
{
Chao Liu's avatar
Chao Liu committed
90
    using SeqType = Sequence<Xs..., Ys...>;
Chao Liu's avatar
Chao Liu committed
91
};
Chao Liu's avatar
Chao Liu committed
92

Chao Liu's avatar
Chao Liu committed
93
template <index_t IBegin, index_t NSize, index_t Increment>
Chao Liu's avatar
Chao Liu committed
94
struct increasing_sequence_gen_impl
Chao Liu's avatar
Chao Liu committed
95
96
{
    static constexpr index_t NSizeLeft = NSize / 2;
Chao Liu's avatar
Chao Liu committed
97

Chao Liu's avatar
Chao Liu committed
98
99
100
101
102
    using SeqType = typename sequence_merge<
        typename increasing_sequence_gen_impl<IBegin, NSizeLeft, Increment>::SeqType,
        typename increasing_sequence_gen_impl<IBegin + NSizeLeft * Increment,
                                              NSize - NSizeLeft,
                                              Increment>::SeqType>::SeqType;
Chao Liu's avatar
Chao Liu committed
103
104
};

Chao Liu's avatar
Chao Liu committed
105
template <index_t IBegin, index_t Increment>
Chao Liu's avatar
Chao Liu committed
106
struct increasing_sequence_gen_impl<IBegin, 1, Increment>
Chao Liu's avatar
Chao Liu committed
107
{
Chao Liu's avatar
Chao Liu committed
108
    using SeqType = Sequence<IBegin>;
Chao Liu's avatar
Chao Liu committed
109
};
Chao Liu's avatar
Chao Liu committed
110

Chao Liu's avatar
Chao Liu committed
111
template <index_t IBegin, index_t Increment>
Chao Liu's avatar
Chao Liu committed
112
struct increasing_sequence_gen_impl<IBegin, 0, Increment>
Chao Liu's avatar
Chao Liu committed
113
{
Chao Liu's avatar
Chao Liu committed
114
115
116
117
118
119
120
121
    using SeqType = Sequence<>;
};

template <index_t IBegin, index_t IEnd, index_t Increment>
struct increasing_sequence_gen
{
    using SeqType =
        typename increasing_sequence_gen_impl<IBegin, IEnd - IBegin, Increment>::SeqType;
Chao Liu's avatar
Chao Liu committed
122
};
Chao Liu's avatar
Chao Liu committed
123

Chao Liu's avatar
Chao Liu committed
124
template <index_t IBegin, index_t IEnd, index_t Increment>
Chao Liu's avatar
Chao Liu committed
125
126
__host__ __device__ constexpr auto
    make_increasing_sequence(Number<IBegin>, Number<IEnd>, Number<Increment>)
Chao Liu's avatar
Chao Liu committed
127
{
Chao Liu's avatar
Chao Liu committed
128
129
    return typename increasing_sequence_gen<IBegin, IEnd, Increment>::SeqType{};
}
Chao Liu's avatar
Chao Liu committed
130

Chao Liu's avatar
Chao Liu committed
131
132
template <class, class>
struct sequence_reverse_inclusive_scan;
Chao Liu's avatar
Chao Liu committed
133

Chao Liu's avatar
Chao Liu committed
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
template <index_t I, index_t... Is, class Reduce>
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce>
{
    using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce>::SeqType;

    static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());

    using SeqType = typename sequence_merge<Sequence<new_reduce>, old_scan>::SeqType;
};

template <index_t I, class Reduce>
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce>
{
    using SeqType = Sequence<I>;
};

template <class, class>
struct sequence_extract;

template <class Seq, index_t... Is>
struct sequence_extract<Seq, Sequence<Is...>>
{
    using SeqType = Sequence<Seq{}.Get(Number<Is>{})...>;
};

template <class Seq, index_t I>
struct sequence_split
{
    static constexpr index_t NSize = Seq{}.GetSize();

    using range0 = typename increasing_sequence_gen<0, I, 1>::SeqType;
    using range1 = typename increasing_sequence_gen<I, NSize, 1>::SeqType;

    using SeqType0 = typename sequence_extract<Seq, range0>::SeqType;
    using SeqType1 = typename sequence_extract<Seq, range1>::SeqType;
};

// Reverses the element order of Seq. The general case splits the sequence in half and
// concatenates the reversed halves in swapped order.
template <class Seq>
struct sequence_reverse
{
    static constexpr index_t NSize = Seq{}.GetSize();

    using seq_split = sequence_split<Seq, NSize / 2>;
    using SeqType   = typename sequence_merge<
        typename sequence_reverse<typename seq_split::SeqType1>::SeqType,
        typename sequence_reverse<typename seq_split::SeqType0>::SeqType>::SeqType;
};

// Empty sequence. Without this, the general case splits Sequence<> into Sequence<> and
// Sequence<> and recurses on itself without end.
template <>
struct sequence_reverse<Sequence<>>
{
    using SeqType = Sequence<>;
};

template <index_t I>
struct sequence_reverse<Sequence<I>>
{
    using SeqType = Sequence<I>;
};

template <index_t I0, index_t I1>
struct sequence_reverse<Sequence<I0, I1>>
{
    using SeqType = Sequence<I1, I0>;
};

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
195
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
196
197
198
199
200
201
202
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs + Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
203
__host__ __device__ constexpr auto operator-(Sequence<Xs...> seq_x, Sequence<Ys...> seq_y)
Chao Liu's avatar
Chao Liu committed
204
205
206
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

Chao Liu's avatar
Chao Liu committed
207
208
    static_for<0, seq_x.GetSize(), 1>{}(
        [&](auto I) { static_assert(seq_x.Get(I) >= seq_y.Get(I), "wrong! going to undeflow"); });
Chao Liu's avatar
Chao Liu committed
209
210
211
212
213

    return Sequence<(Xs - Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
214
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
215
216
217
218
219
220
221
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs * Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
222
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
223
224
225
226
227
228
229
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs / Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
Chao Liu's avatar
Chao Liu committed
230
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Sequence<Ys...>)
Chao Liu's avatar
Chao Liu committed
231
232
233
234
235
236
237
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return Sequence<(Xs % Ys)...>{};
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
238
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
239
{
Chao Liu's avatar
Chao Liu committed
240
    return Sequence<(Xs + Y)...>{};
Chao Liu's avatar
Chao Liu committed
241
242
243
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
244
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
245
{
Chao Liu's avatar
Chao Liu committed
246
#if 0 // doesn't compile
Chao Liu's avatar
Chao Liu committed
247
248
249
250
251
252
253
254
255
    constexpr auto seq_x = Sequence<Xs...>{};

    static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) {
        constexpr auto I = decltype(Iter){};
        static_assert(seq_x.Get(I) >= Y, "wrong! going to underflow");
    });
#endif

    return Sequence<(Xs - Y)...>{};
Chao Liu's avatar
Chao Liu committed
256
257
258
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
259
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
260
{
Chao Liu's avatar
Chao Liu committed
261
    return Sequence<(Xs * Y)...>{};
Chao Liu's avatar
Chao Liu committed
262
263
264
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
265
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
266
{
Chao Liu's avatar
Chao Liu committed
267
    return Sequence<(Xs / Y)...>{};
Chao Liu's avatar
Chao Liu committed
268
269
270
}

template <index_t... Xs, index_t Y>
Chao Liu's avatar
Chao Liu committed
271
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
272
{
Chao Liu's avatar
Chao Liu committed
273
    return Sequence<(Xs % Y)...>{};
Chao Liu's avatar
Chao Liu committed
274
275
}

Chao Liu's avatar
Chao Liu committed
276
277
// Scalar-on-the-left sum: Sequence<(Y + x)...>.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator+(Number<Y>, Sequence<Xs...>)
{
    using sum_t = Sequence<(Y + Xs)...>;
    return sum_t{};
}

// Scalar-on-the-left difference: Sequence<(Y - x)...>. Every element must be <= Y.
// This is checked at compile time so that the subtraction cannot underflow.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
{
    // All-true test on the pack: Sequence<1, b0, ..., bn-1> and Sequence<b0, ..., bn-1, 1>
    // are the same type only if every b is 1. It uses only template arguments, so it needs
    // no captured locals and no lambda call. Lambdas are not constexpr before C++17, and
    // reading locals captured by a lambda inside static_assert is the pattern noted as not
    // compiling in operator-(Sequence, Number).
    static_assert(std::is_same<Sequence<1, index_t(Xs <= Y)...>,
                               Sequence<index_t(Xs <= Y)..., 1>>::value,
                  "wrong! going to underflow");

    return Sequence<(Y - Xs)...>{};
}

// Scalar-on-the-left product: Sequence<(Y * x)...>.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator*(Number<Y>, Sequence<Xs...>)
{
    using product_t = Sequence<(Y * Xs)...>;
    return product_t{};
}

// Scalar-on-the-left integer quotient: Sequence<(Y / x)...>.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator/(Number<Y>, Sequence<Xs...>)
{
    typedef Sequence<(Y / Xs)...> quotient_t;
    return quotient_t{};
}

// Scalar-on-the-left remainder: Sequence<(Y % x)...>.
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator%(Number<Y>, Sequence<Xs...>)
{
    typedef Sequence<(Y % Xs)...> remainder_t;
    return remainder_t{};
}

// Drops the first element: Sequence<I, Is...> -> Sequence<Is...>.
// Popping a one-element Sequence yields Sequence<>. An empty Sequence does not match the
// parameter pattern, so it is already rejected by overload resolution. The former
// static_assert(sizeof...(Is) > 0) checked the result rather than the input, and wrongly
// rejected one-element inputs.
template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
{
    return Sequence<Is...>{};
}

// Drops the last element of Seq and keeps the first N-1 elements in order.
// A one-element input yields Sequence<>.
// Implemented with sequence_split rather than reverse / pop_front / reverse, which failed
// for one-element inputs.
template <class Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq)
{
    static_assert(Seq{}.GetSize() > 0, "empty Sequence!");

    // Clamp so that an empty Seq reports only the static_assert above, instead of also
    // instantiating sequence_split with a negative / wrapped position.
    constexpr index_t n_keep = Seq{}.GetSize() > 0 ? Seq{}.GetSize() - 1 : 0;

    return typename sequence_split<Seq, n_keep>::SeqType0{};
}

// Applies f to every element: Sequence<f(x)...>.
// f must be callable in a constant expression, because its results are template arguments.
template <class F, index_t... Xs>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
{
    using mapped_t = Sequence<f(Xs)...>;
    return mapped_t{};
}

template <class F, index_t... Xs, index_t... Ys>
334
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
335
{
336
    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize, "Dim not the same");
337
338
339
340

    return Sequence<f(Xs, Ys)...>{};
}

341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
// Applies the ternary f element-wise: Sequence<f(x, y, z)...> for three Sequences of
// equal length.
template <class F, index_t... Xs, index_t... Ys, index_t... Zs>
__host__ __device__ constexpr auto
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
{
    constexpr bool xy_same = (sizeof...(Xs) == sizeof...(Ys));
    constexpr bool xz_same = (sizeof...(Xs) == sizeof...(Zs));
    static_assert(xy_same && xz_same, "Dim not the same");

    using mapped_t = Sequence<f(Xs, Ys, Zs)...>;
    return mapped_t{};
}

// Member form of sequence_pop_front.
template <index_t... Is>
__host__ __device__ constexpr auto Sequence<Is...>::PopFront() const
{
    return sequence_pop_front(Sequence<Is...>{});
}

template <index_t... Is>
359
360
361
362
__host__ __device__ constexpr auto Sequence<Is...>::PopBack() const
{
    return sequence_pop_back(Type{});
}
363
364

template <class Seq>
Chao Liu's avatar
Chao Liu committed
365
struct accumulate_on_sequence_impl
366
367
368
369
370
371
372
373
374
{
    template <class IDim>
    __host__ __device__ constexpr index_t operator()(IDim) const
    {
        return Seq{}.Get(IDim{});
    }
};

template <class Seq, class Reduce, index_t I>
Chao Liu's avatar
Chao Liu committed
375
376
__host__ __device__ constexpr index_t
    accumulate_on_sequence(Seq, Reduce, Number<I> /*initial_value*/)
377
378
{
    constexpr index_t a =
Chao Liu's avatar
Chao Liu committed
379
        static_const_reduce_n<Seq::mSize>{}(accumulate_on_sequence_impl<Seq>{}, Reduce{});
380
381
    return Reduce{}(a, I);
}
Chao Liu's avatar
Chao Liu committed
382

Chao Liu's avatar
Chao Liu committed
383
384
// Member form of sequence_reverse: the same elements in reverse order.
template <index_t... Is>
__host__ __device__ constexpr auto Sequence<Is...>::Reverse() const
{
    using reversed_t = typename sequence_reverse<Type>::SeqType;
    return reversed_t{};
}

template <class Seq, class Reduce>
Chao Liu's avatar
Chao Liu committed
390
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce)
Chao Liu's avatar
Chao Liu committed
391
{
Chao Liu's avatar
Chao Liu committed
392
    return typename sequence_reverse_inclusive_scan<Seq, Reduce>::SeqType{};
Chao Liu's avatar
Chao Liu committed
393
394
395
}

template <class Seq, class Reduce>
Chao Liu's avatar
Chao Liu committed
396
__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce)
Chao Liu's avatar
Chao Liu committed
397
{
Chao Liu's avatar
Chao Liu committed
398
    return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}).Reverse();
Chao Liu's avatar
Chao Liu committed
399
}