// Sequence.hip.hpp: compile-time sequences of index_t values and helpers operating on them.
#pragma once
#include "constant_integral.hip.hpp"
#include "functional.hip.hpp"

Chao Liu's avatar
Chao Liu committed
5
template <index_t... Is>
6
7
8
9
struct Sequence
{
    using Type = Sequence<Is...>;

10
    static constexpr index_t mSize = sizeof...(Is);
11

12
13
14
    const index_t mData[mSize] = {Is...};

    __host__ __device__ static constexpr index_t GetSize() { return mSize; }
15

Chao Liu's avatar
Chao Liu committed
16
17
    template <index_t I>
    __host__ __device__ constexpr index_t Get(Number<I>) const
18
19
20
21
    {
        return mData[I];
    }

22
23
    __host__ __device__ index_t operator[](index_t i) const { return mData[i]; }

24
25
    template <index_t... IRs>
    __host__ __device__ constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/) const
26
    {
27
        static_assert(mSize == sizeof...(IRs), "mSize not consistent");
28

29
        constexpr auto old = Type{};
30

31
        return Sequence<old.Get(Number<IRs>{})...>{};
32
33
    }

34
35
    template <index_t... IRs>
    __host__ __device__ constexpr auto ReorderGivenOld2New(Sequence<IRs...> /*old2new*/) const
36
    {
37
        // TODO: don't know how to implement this
38
        printf("Sequence::ReorderGivenOld2New not implemented");
39
40
41
        assert(false);
    }

42
43
44
45
    __host__ __device__ constexpr index_t Front() const { return mData[0]; }

    __host__ __device__ constexpr index_t Back() const { return mData[mSize - 1]; }

46
47
48
49
50
51
    template <index_t I>
    __host__ __device__ constexpr auto PushFront(Number<I>) const
    {
        return Sequence<I, Is...>{};
    }

Chao Liu's avatar
Chao Liu committed
52
    template <index_t I>
53
54
55
56
57
    __host__ __device__ constexpr auto PushBack(Number<I>) const
    {
        return Sequence<Is..., I>{};
    }

58
59
    __host__ __device__ constexpr auto PopFront() const;

60
61
62
63
64
65
66
67
68
    __host__ __device__ constexpr auto PopBack() const;

    template <class F>
    __host__ __device__ constexpr auto Transform(F f) const
    {
        return Sequence<f(Is)...>{};
    }
};

69
70
71
72
73
74
75
// Returns the Sequence with its first element removed.
// Producing an empty Sequence is rejected at compile time, so the argument
// must hold at least two elements.
template <index_t IHead, index_t... ITail>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<IHead, ITail...>)
{
    using Tail = Sequence<ITail...>;

    static_assert(sizeof...(ITail) > 0, "empty Sequence!");

    return Tail{};
}
// sequence_pop_back: returns the Sequence with its last element removed.
//
// The natural signature `sequence_pop_back(Sequence<Is..., I>)` cannot be
// used: a parameter pack that is not the last template argument is a
// non-deduced context, so neither `Is...` nor `I` could be deduced. Instead
// the last element is removed recursively: peel off the first element, pop
// the back of the remaining tail, then push the peeled element back on the
// front. This works for any length >= 2 (previously limited to 10).
//
// A single-element Sequence has no matching overload, so popping it to an
// empty Sequence is a compile error.

// Base case: exactly two elements.
template <index_t I0, index_t I1>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1>)
{
    return Sequence<I0>{};
}

// Recursive case: three or more elements.
template <index_t I0, index_t I1, index_t I2, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, Is...>)
{
    return sequence_pop_back(Sequence<I1, I2, Is...>{}).PushFront(Number<I0>{});
}
#if 1
167
// TODO: fix these mess
Chao Liu's avatar
Chao Liu committed
168
template <class F, index_t... Xs, index_t... Ys>
169
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
170
{
171
    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize, "Dim not the same");
172
173
174
175

    return Sequence<f(Xs, Ys)...>{};
}

176
177
178
179
180
181
182
183
184
185
186
template <class F, index_t... Xs, index_t... Ys, index_t... Zs>
__host__ __device__ constexpr auto
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
{
    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize &&
                      Sequence<Xs...>::mSize == Sequence<Zs...>::mSize,
                  "Dim not the same");

    return Sequence<f(Xs, Ys, Zs)...>{};
}
#else
187
// TODO:: these doesn't compile
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
template <index_t NRemain>
struct transform_sequences_impl
{
    template <class F, class Y, class... Xs>
    __host__ __device__ constexpr auto operator()(F f, Y y, Xs... xs) const
    {
        static_assert(NRemain > 1, "wrong! should have NRemain > 1");

        constexpr index_t N  = f(Xs{}.Get(Number<0>{})...);
        constexpr auto y_new = y.PushBack(Number<N>{});

        return transform_sequences_impl<NRemain - 1>{}(f, y_new, xs.PopFront()...);
    }
};

template <>
struct transform_sequences_impl<1>
205
{
206
207
    template <class F, class Y, class... Xs>
    __host__ __device__ constexpr auto operator()(F f, Y, Xs...) const
208
    {
209
210
211
212
        constexpr index_t N = f(Xs{}.Get(Number<0>{})...);
        return Y{}.PushBack(Number<N>{});
    }
};
213

214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
template <class F, class X, class... Xs>
__host__ __device__ constexpr auto transform_sequences(F f, X x, Xs... xs)
{
    constexpr index_t nSize = X::GetSize();
    constexpr auto I0       = Number<0>{};

    constexpr auto y0 = Sequence<f(X{}.Get(I0), Xs{}.Get(I0)...)>{};

    return transform_sequences_impl<nSize - 1>{}(f, y0, x.PopFront(), xs.PopFront()...);
}
#endif

// Out-of-class definition of Sequence::PopFront: the Sequence without its
// first element, delegated to sequence_pop_front.
template <index_t... Is>
__host__ __device__ constexpr auto Sequence<Is...>::PopFront() const
{
    return sequence_pop_front(Sequence<Is...>{});
}
template <index_t... Is>
233
234
235
236
__host__ __device__ constexpr auto Sequence<Is...>::PopBack() const
{
    return sequence_pop_back(Type{});
}
237
238
239
240
241
242
243
244
245
246
247
248

// Element accessor passed to static_const_reduce_n by accumulate_on_sequence.
// Maps a compile-time index object (presumably Number<i>, which is what
// Sequence::Get accepts) to the corresponding element of Seq.
template <class Seq>
struct accumulate_on_sequence_f
{
    template <class IDim>
    __host__ __device__ constexpr index_t operator()(IDim) const
    {
        const auto seq = Seq{};
        return seq.Get(IDim{});
    }
};

template <class Seq, class Reduce, index_t I>
Chao Liu's avatar
Chao Liu committed
249
250
__host__ __device__ constexpr index_t
    accumulate_on_sequence(Seq, Reduce, Number<I> /*initial_value*/)
251
252
{
    constexpr index_t a =
253
        static_const_reduce_n<Seq::mSize>{}(accumulate_on_sequence_f<Seq>{}, Reduce{});
254
255
    return Reduce{}(a, I);
}