functional.hip.hpp 3.41 KB
Newer Older
1
2
3
#pragma once
#include "constant_integral.hip.hpp"

Chao Liu's avatar
Chao Liu committed
4
5
6
7
8
9
10
11
12
// Identity functor: returns its argument with the value category it was deduced with
// (an lvalue argument comes back as a reference, an rvalue argument comes back by value).
struct forwarder
{
    template <typename Arg>
    __host__ __device__ constexpr Arg operator()(Arg&& arg) const
    {
        // the same cast that std::forward<Arg> performs
        return static_cast<Arg&&>(arg);
    }
};

Chao Liu's avatar
Chao Liu committed
13
14
15
16
17
18
19
20
#if 0
// Disabled draft: the "#if 0" guard keeps this out of the build.
// Wraps "f" in a generic lambda taking "xs_array". The lambda body expands "xs...",
// which is not declared anywhere in this snippet, so it would not compile as written.
template<class F>
__host__ __device__ constexpr auto unpacker(F f)
{
    return [=](auto xs_array){ f(xs...); };
}
#endif

Chao Liu's avatar
Chao Liu committed
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
// Emulate compile time if statement for C++14
//   Get the idea from
//   "https://baptiste-wicht.com/posts/2015/07/simulate-static_if-with-c11c14.html"
// TODO: use if constexpr, when C++17 is supported
// Primary template: intentionally empty. The callable interface is provided by the
// static_if<true> and static_if<false> specializations.
template <bool Predicate>
struct static_if
{
};

// Predicate holds: operator() runs its callable, else_() ignores its callable.
template <>
struct static_if<true>
{
    using Type = static_if<true>;

    // Calls "body" with a forwarder instance and returns a static_if<true> object.
    template <class Body>
    __host__ __device__ constexpr auto operator()(Body body) const
    {
        // Compiler trick:
        //   "body" receives a forwarder through an "auto" parameter and is expected to use it.
        //   That makes "body" a generic lambda, so it is not compiled until it is called here.
        body(forwarder{});
        return static_if<true>{};
    }

    // The "else" callable is accepted but never invoked on this branch.
    template <class Body>
    __host__ __device__ static constexpr auto else_(Body)
    {
        return static_if<true>{};
    }
};

// Predicate does not hold: operator() ignores its callable, else_() runs its callable.
template <>
struct static_if<false>
{
    using Type = static_if<false>;

    // The "then" callable is accepted but never invoked on this branch.
    template <class Body>
    __host__ __device__ constexpr auto operator()(Body) const
    {
        return static_if<false>{};
    }

    // Calls "body" with a forwarder instance and returns a static_if<false> object.
    template <class Body>
    __host__ __device__ static constexpr auto else_(Body body)
    {
        // Compiler trick:
        //   "body" receives a forwarder through an "auto" parameter and is expected to use it.
        //   That makes "body" a generic lambda, so it is not compiled until it is called here.
        body(forwarder{});
        return static_if<false>{};
    }
};
73
74
template <index_t Iter, index_t Remaining, index_t Increment>
struct static_for_impl
75
76
77
78
{
    template <class F>
    __host__ __device__ void operator()(F f) const
    {
79
80
        static_assert(Remaining % Increment == 0, "wrong! Remaining % Increment != 0");
        static_assert(Increment <= Remaining, "will go out-of-range");
81

82
83
        f(Number<Iter>{});
        static_for_impl<Iter + Increment, Remaining - Increment, Increment>{}(f);
84
85
86
    }
};

87
88
89
90
91
92
// Recursion terminator: Remaining == 0, so the callable is accepted but not invoked.
template <index_t Iter, index_t Increment>
struct static_for_impl<Iter, 0, Increment>
{
    template <class F>
    __host__ __device__ void operator()(F) const
    {
        // no work left, just return
    }
};

Chao Liu's avatar
Chao Liu committed
98
// F signature: F(Number<Iter>)
99
100
template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for
101
102
103
104
{
    template <class F>
    __host__ __device__ void operator()(F f) const
    {
105
106
        static_assert((NEnd - NBegin) % Increment == 0,
                      "Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
Chao Liu's avatar
Chao Liu committed
107

Chao Liu's avatar
Chao Liu committed
108
109
        static_if<(NBegin < End)>{}(
            [&](auto fwd) { static_for_impl<NBegin, NEnd - NBegin, fwd(Increment)>{}(f); });
110
111
112
    }
};

Chao Liu's avatar
Chao Liu committed
113
template <index_t NLoop>
114
115
struct static_const_reduce_n
{
Chao Liu's avatar
Chao Liu committed
116
    // signature of F: F(Number<I>)
117
118
119
120
121
122
    template <class F, class Reduce>
    __host__ __device__ constexpr auto operator()(F f, Reduce r) const
    {
        static_assert(NLoop > 1, "out-of-range");

        constexpr auto a = f(Number<NLoop - 1>{});
123
        auto b = static_const_reduce_n<NLoop - 1>{}(f, r); // TODO: cannot use constexpr here, weird
124
125
126
127
128
129
130
131
132
133
134
135
136
        return r(a, b);
    }
};

// Base case of the reduction: a single term, so the reducer is not used.
template <>
struct static_const_reduce_n<1>
{
    // Returns f(Number<0>{}) unchanged; the second argument is ignored.
    template <class Functor, class Reducer>
    __host__ __device__ constexpr auto operator()(Functor func, Reducer) const
    {
        return func(Number<0>{});
    }
};