// Sequence.hip.hpp
#pragma once
#include "constant_integral.hip.hpp"
#include "functional.hip.hpp"

// Compile-time list of index_t values carried in the template parameter pack Is.
// The same values are also stored in mData so they can be read with a run-time
// index through operator[].
template <index_t... Is>
struct Sequence
{
    using Type = Sequence<Is...>;

    // number of elements in the pack
    static constexpr index_t mSize = sizeof...(Is);

    // copy of the pack, filled from Is... by the default member initializer
    // NOTE(review): for an empty pack this declares a zero-length array, which
    // standard C++ does not allow.
    const index_t mData[mSize] = {Is...};

    __host__ __device__ static constexpr index_t GetSize() { return mSize; }

    // Element at compile-time position I.
    template <index_t I>
    __host__ __device__ constexpr index_t Get(Number<I>) const
    {
        // catch out-of-range positions at compile time instead of reading past mData
        static_assert(I < mSize, "Sequence::Get: index out of range");
        return mData[I];
    }

    // Element at run-time position i. No bounds check is performed.
    // constexpr so it can be used in constant expressions just like Get().
    __host__ __device__ constexpr index_t operator[](index_t i) const { return mData[i]; }

    // Permutes this Sequence: element j of the result is element IRs[j] of this
    // Sequence (IRs is the new->old index map).
    template <index_t... IRs>
    __host__ __device__ constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/) const
    {
        static_assert(mSize == sizeof...(IRs), "mSize not consistent");

        constexpr auto old = Type{};

        return Sequence<old.Get(Number<IRs>{})...>{};
    }

    // NOTE(review): not implemented. When executed it prints a message and fails an
    // assert. It has no return statement, so its deduced return type is void.
    template <index_t... IRs>
    __host__ __device__ constexpr auto ReorderGivenOld2New(Sequence<IRs...> /*old2new*/) const
    {
        // don't know how to implement this
        printf("Sequence::ReorderGivenOld2New not implemented");

        assert(false);
    }

    // Returns Sequence<I, Is...>.
    template <index_t I>
    __host__ __device__ constexpr auto PushFront(Number<I>) const
    {
        return Sequence<I, Is...>{};
    }

    // Returns Sequence<Is..., I>.
    template <index_t I>
    __host__ __device__ constexpr auto PushBack(Number<I>) const
    {
        return Sequence<Is..., I>{};
    }

    // Remove the first / last element; both are defined out of class further below.
    __host__ __device__ constexpr auto PopFront() const;

    __host__ __device__ constexpr auto PopBack() const;

    // Returns Sequence<f(Is)...>. Each f(Is) becomes a template argument, so f must be
    // callable in a constant expression.
    template <class F>
    __host__ __device__ constexpr auto Transform(F f) const
    {
        return Sequence<f(Is)...>{};
    }
};

// Returns the Sequence without its first element.
// The input must have at least two elements so that the result is non-empty.
template <index_t IHead, index_t... ITail>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<IHead, ITail...>)
{
    static_assert(sizeof...(ITail) > 0, "empty Sequence!");

    using Tail = Sequence<ITail...>;
    return Tail{};
}

// Helpers for sequence_pop_back.
// The previous form, sequence_pop_back(Sequence<Is..., I>), placed a pack expansion
// before the last template argument. That makes the whole argument list a
// non-deduced context, so the function could never be called.
// These overloads peel elements off the front instead, which is deducible.

// Base case: exactly two elements -> keep the first.
template <index_t I0, index_t I1>
__host__ __device__ constexpr auto sequence_pop_back_impl(Sequence<I0, I1>)
{
    return Sequence<I0>{};
}

// Three or more elements: drop the last element of the tail, then re-attach the head.
template <index_t I0, index_t I1, index_t I2, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_back_impl(Sequence<I0, I1, I2, Is...>)
{
    return sequence_pop_back_impl(Sequence<I1, I2, Is...>{}).PushFront(Number<I0>{});
}

// Returns the Sequence without its last element.
// The input must have at least two elements so that the result is non-empty.
template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I, Is...>)
{
    static_assert(sizeof...(Is) > 0, "empty Sequence!");
    return sequence_pop_back_impl(Sequence<I, Is...>{});
}

#if 1
// Fixed-arity overload for exactly two Sequences of equal length: element k of the
// result is f(Xs[k], Ys[k]). f is called where a template argument is required, so
// it must be usable in a constant expression.
template <class F, index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "Dim not the same");

    using Result = Sequence<f(Xs, Ys)...>;
    return Result{};
}

// Fixed-arity overload for exactly three Sequences of equal length: element k of the
// result is f(Xs[k], Ys[k], Zs[k]).
template <class F, index_t... Xs, index_t... Ys, index_t... Zs>
__host__ __device__ constexpr auto
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys) && sizeof...(Ys) == sizeof...(Zs),
                  "Dim not the same");

    using Result = Sequence<f(Xs, Ys, Zs)...>;
    return Result{};
}
#else
// Variadic alternative to the fixed-arity overloads, compiled only if the `#if 1`
// guard is switched off.
// The result is built front to back. Each step applies f to the first element of
// every remaining input Sequence, appends that value to the accumulated Sequence y,
// and pops the front of every input.
template <index_t NRemain>
struct transform_sequences_impl
{
    template <class F, class Y, class... Xs>
    __host__ __device__ constexpr auto operator()(F f, Y y, Xs... xs) const
    {
        static_assert(NRemain > 1, "wrong! should have NRemain > 1");

        constexpr index_t head = f(Xs{}.Get(Number<0>{})...);

        return transform_sequences_impl<NRemain - 1>{}(
            f, y.PushBack(Number<head>{}), xs.PopFront()...);
    }
};

// Final step: each input has one element left; append f of those and stop.
template <>
struct transform_sequences_impl<1>
{
    template <class F, class Y, class... Xs>
    __host__ __device__ constexpr auto operator()(F f, Y, Xs...) const
    {
        constexpr index_t head = f(Xs{}.Get(Number<0>{})...);
        return Y{}.PushBack(Number<head>{});
    }
};

template <class F, class X, class... Xs>
__host__ __device__ constexpr auto transform_sequences(F f, X x, Xs... xs)
{
    constexpr auto first = Number<0>{};

    // seed the result with f applied to the first element of every input
    using Head = Sequence<f(X{}.Get(first), Xs{}.Get(first)...)>;

    return transform_sequences_impl<X::GetSize() - 1>{}(
        f, Head{}, x.PopFront(), xs.PopFront()...);
}
#endif

// Out-of-class definition: forwards to the free function sequence_pop_front.
template <index_t... Is>
__host__ __device__ constexpr auto Sequence<Is...>::PopFront() const
{
    using Self = Sequence<Is...>;
    return sequence_pop_front(Self{});
}

// Out-of-class definition: forwards to the free function sequence_pop_back.
template <index_t... Is>
__host__ __device__ constexpr auto Sequence<Is...>::PopBack() const
{
    using Self = Sequence<Is...>;
    return sequence_pop_back(Self{});
}
// Functor passed to static_const_reduce_n by accumulate_on_sequence. It maps a
// compile-time position IDim to the element of Seq at that position.
template <class Seq>
struct accumulate_on_sequence_f
{
    template <class IDim>
    __host__ __device__ constexpr index_t operator()(IDim) const
    {
        constexpr auto seq = Seq{};
        return seq.Get(IDim{});
    }
};

// Reduces the elements of Seq with Reduce, then combines that result with the
// initial value I through one more Reduce call.
// NOTE(review): static_const_reduce_n is defined elsewhere. It is presumably a fold
// over positions 0..Seq::mSize-1 — confirm its semantics there.
template <class Seq, class Reduce, index_t I>
__host__ __device__ constexpr index_t accumulate_on_sequence(Seq, Reduce, Number<I>)
{
    using Reducer = static_const_reduce_n<Seq::mSize>;

    constexpr index_t folded = Reducer{}(accumulate_on_sequence_f<Seq>{}, Reduce{});

    return Reduce{}(folded, I);
}