// functional.hip.hpp
#pragma once
#include "constant_integral.hip.hpp"

// Identity function object: hands its argument straight back.
// The return type is the deduced template parameter, so an lvalue argument
// comes back as an lvalue reference and an rvalue argument comes back by value.
// static_if passes an instance of this into user lambdas so that those lambdas
// can be written as generic ("auto" parameter) lambdas.
struct forwarder
{
    template <typename Arg>
    __host__ __device__ constexpr Arg operator()(Arg&& value) const
    {
        // Same as std::forward<Arg>(value).
        return static_cast<Arg&&>(value);
    }
};

// NOTE(review): disabled draft of an "unpacker" helper. As written it cannot
// compile: the lambda's parameter is named xs_array, but the body expands a
// pack "xs...", which is not declared here. Presumably it was meant to call f
// with the elements of xs_array as separate arguments. Confirm the intent,
// then finish or remove it.
#if 0
template<class F>
__host__ __device__ constexpr auto unpacker(F f)
{
    return [=](auto xs_array){ f(xs...); };
}
#endif

// Emulate a compile-time if statement for C++14.
//   Idea taken from
//   "https://baptiste-wicht.com/posts/2015/07/simulate-static_if-with-c11c14.html"
// TODO: use if constexpr, when C++17 is supported
//
// Usage:
//   static_if<Predicate>{}([&](auto fwd) { /* runs when Predicate is true  */ })
//       .else_([&](auto fwd) { /* runs when Predicate is false */ });
//
// Only the selected functor is called; its return value is discarded. Each
// functor receives a forwarder as its only argument. When the functor is a
// generic lambda, its body is instantiated only if it is actually called.
// Wrapping an expression in fwd(...) makes that expression depend on the
// lambda's parameter type, so code in the unselected branch is not compiled.
template <bool Predicate>
struct static_if
{
};

template <>
struct static_if<true>
{
    using Type = static_if<true>;

    // Predicate is true: call the "then" functor.
    template <class F>
    __host__ __device__ constexpr auto operator()(F f) const
    {
        // This is a trick for the compiler:
        //   Pass a forwarder to lambda "f" as its "auto" argument, and make sure "f" uses it.
        //   This makes "f" a generic lambda, so that "f" won't be compiled until being
        //   instantiated here.
        f(forwarder{});
        return Type{};
    }

    // Predicate is true: the "else" functor is dropped without being called.
    template <class F>
    __host__ __device__ static constexpr auto else_(F)
    {
        return Type{};
    }
};

template <>
struct static_if<false>
{
    using Type = static_if<false>;

    // Predicate is false: the "then" functor is dropped without being called.
    template <class F>
    __host__ __device__ constexpr auto operator()(F) const
    {
        return Type{};
    }

    // Predicate is false: call the "else" functor.
    template <class F>
    __host__ __device__ static constexpr auto else_(F f)
    {
        // This is a trick for the compiler:
        //   Pass a forwarder to lambda "f" as its "auto" argument, and make sure "f" uses it.
        //   This makes "f" a generic lambda, so that "f" won't be compiled until being
        //   instantiated here.
        f(forwarder{});
        return Type{};
    }
};
template <index_t Iter, index_t Remaining, index_t Increment>
struct static_for_impl
77
78
79
80
{
    template <class F>
    __host__ __device__ void operator()(F f) const
    {
81
82
        static_assert(Remaining % Increment == 0, "wrong! Remaining % Increment != 0");
        static_assert(Increment <= Remaining, "will go out-of-range");
83

84
85
        f(Number<Iter>{});
        static_for_impl<Iter + Increment, Remaining - Increment, Increment>{}(f);
86
87
88
    }
};

// Termination case of static_for_impl: no iterations remain, so f is never called.
template <index_t Iter, index_t Increment>
struct static_for_impl<Iter, 0, Increment>
{
    template <class F>
    __host__ __device__ void operator()(F) const
    {
        // no work left, just return
    }
};

// F signature: F(Number<Iter>)
101
102
template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for
103
104
105
106
{
    template <class F>
    __host__ __device__ void operator()(F f) const
    {
107
108
        static_assert((NEnd - NBegin) % Increment == 0,
                      "Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
Chao Liu's avatar
Chao Liu committed
109

Chao Liu's avatar
Chao Liu committed
110
        static_if<(NBegin < NEnd)>{}(
Chao Liu's avatar
Chao Liu committed
111
            [&](auto fwd) { static_for_impl<NBegin, NEnd - NBegin, fwd(Increment)>{}(f); });
112
113
114
    }
};

template <index_t NLoop>
116
117
struct static_const_reduce_n
{
Chao Liu's avatar
Chao Liu committed
118
    // signature of F: F(Number<I>)
119
120
121
122
123
124
    template <class F, class Reduce>
    __host__ __device__ constexpr auto operator()(F f, Reduce r) const
    {
        static_assert(NLoop > 1, "out-of-range");

        constexpr auto a = f(Number<NLoop - 1>{});
125
        auto b = static_const_reduce_n<NLoop - 1>{}(f, r); // TODO: cannot use constexpr here, weird
126
127
128
129
130
131
132
133
134
135
136
137
138
        return r(a, b);
    }
};

template <>
struct static_const_reduce_n<1>
{
    // Base case of the reduction: a single term has nothing to combine with,
    // so the reducer is ignored and f(Number<0>{}) is returned unchanged.
    template <class F, class Reduce>
    __host__ __device__ constexpr auto operator()(F f, Reduce) const
    {
        using FirstIndex = Number<0>;
        return f(FirstIndex{});
    }
};