functional.hip.hpp 3.35 KB
Newer Older
1
2
3
#pragma once
#include "constant_integral.hip.hpp"

4
5
// Recursive worker behind static_for.
// Calls f(Number<Iter>{}), then recurses with Iter advanced by Increment and
// Remaining reduced by Increment. The Remaining == 0 specialization ends the
// recursion.
template <index_t Iter, index_t Remaining, index_t Increment>
struct static_for_impl
{
    template <class F>
    __host__ __device__ void operator()(F f) const
    {
        // A step of 0 never shrinks Remaining. A negative step (possible if index_t
        // is signed) makes Remaining grow. Reject both explicitly.
        static_assert(Increment > 0, "wrong! Increment should be positive");
        // The `Increment <= 0 ||` guard keeps a zero step from also triggering a
        // division-by-zero error in this constant expression; the assertion above
        // already reports it.
        static_assert(Increment <= 0 || Remaining % Increment == 0,
                      "wrong! Remaining % Increment != 0");
        static_assert(Increment <= Remaining, "will go out-of-range");

        f(Number<Iter>{});

        // For a valid (positive) step this is exactly Remaining - Increment. For an
        // invalid step, jump straight to the Remaining == 0 terminator. The compiler
        // then stops after the static_assert above instead of recursing without bound.
        static_for_impl<Iter + Increment,
                        (Increment > 0 ? Remaining - Increment : 0),
                        Increment>{}(f);
    }
};

18
19
20
21
22
23
// Terminating case of static_for_impl.
// With Remaining == 0 every index has already been visited, so the call does
// nothing and the recursion stops here.
template <index_t Iter, index_t Increment>
struct static_for_impl<Iter, 0, Increment>
{
    template <class F>
    __host__ __device__ void operator()(F) const
    {
        // no work left: the functor is intentionally not invoked
    }
};

// Compile-time loop.
// Calls f(Number<i>{}) for i = NBegin, NBegin + Increment, ..., stopping before NEnd,
// in increasing order. f is passed by value, so a copy of the caller's functor is used.
// Preconditions (checked at compile time):
//   - NBegin < NEnd (an empty range is rejected);
//   - Increment > 0;
//   - (NEnd - NBegin) is a multiple of Increment.
template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for
{
    template <class F>
    __host__ __device__ void operator()(F f) const
    {
        static_assert(NBegin < NEnd, "Wrong! we should have NBegin < NEnd");
        // A non-positive step would never reach NEnd.
        static_assert(Increment > 0, "Wrong! should have Increment > 0");
        // The `Increment <= 0 ||` guard avoids a division-by-zero error in this
        // constant expression when Increment == 0; the assertion above reports that case.
        static_assert(Increment <= 0 || (NEnd - NBegin) % Increment == 0,
                      "Wrong! should satisfy (NEnd - NBegin) % Increment == 0");

        // For a valid step this is static_for_impl<NBegin, NEnd - NBegin, Increment>.
        // For an invalid step, the Remaining == 0 terminator is selected, so no
        // runaway recursion is instantiated after the assertion fires.
        static_for_impl<NBegin, (Increment > 0 ? NEnd - NBegin : 0), Increment>{}(f);
    }
};

Chao Liu's avatar
Chao Liu committed
42
// Compile-time right fold over the indices 0 .. NLoop-1:
//   static_const_reduce_n<N>{}(f, r)
//     == r(f(Number<N-1>{}), r(f(Number<N-2>{}), ... f(Number<0>{})))
// f is evaluated once per index. r combines two partial results: the current
// index's value on the left, the fold of the lower indices on the right.
// NLoop == 1 is handled by the explicit specialization (returns f(Number<0>{})).
template <index_t NLoop>
struct static_const_reduce_n
{
    template <class F, class Reduce>
    __host__ __device__ constexpr auto operator()(F f, Reduce r) const
    {
        // NLoop == 1 is served by the specialization, so reaching the primary
        // template with NLoop <= 1 means NLoop < 1.
        static_assert(NLoop > 1, "out-of-range");

        // `a` is declared constexpr, so f(Number<NLoop - 1>{}) must be a constant expression.
        constexpr auto a = f(Number<NLoop - 1>{});
        auto b = static_const_reduce_n<NLoop - 1>{}(f, r); // TODO: cannot use constexpr here, weird

        return r(a, b);
    }
};

// Base case of the fold: a single index needs no combining.
// The reducer is accepted only to keep the call signature uniform and is ignored.
template <>
struct static_const_reduce_n<1>
{
    template <class F, class Reduce>
    __host__ __device__ constexpr auto operator()(F f, Reduce /* unused */) const
    {
        using first_index = Number<0>;
        return f(first_index{});
    }
};
65
66
67
68
69
70
71

// Disabled draft of an `unpacker` helper (compiled out by `#if 0`).
// NOTE(review): it would not compile if re-enabled, for these reasons:
//   - The lambda's parameter is `xs_array`, but its body expands a pack named `xs`,
//     which is not declared anywhere visible here.
//   - Before C++17, a lambda expression may not appear in a constexpr function.
// Presumably the intent was to call f with the elements of xs_array as separate
// arguments. TODO confirm before re-enabling.
// The bare `72` line inside looks like pasted line-number residue. It is harmless
// while compiled out, but should be removed before re-enabling.
#if 0
template<class F>
__host__ __device__ constexpr auto unpacker(F f)
{
    return [=](auto xs_array){ f(xs...); };
}
72
#endif
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134

// Identity functor handed to static_if branches.
// Returns its argument unchanged while preserving value category:
//   - an lvalue argument deduces T as U&, so an lvalue reference comes back;
//   - an rvalue argument deduces T as U, so the value is moved into the
//     by-value result.
struct forwarder
{
    template <typename T>
    __host__ __device__ constexpr auto operator()(T&& x) const -> T
    {
        // Same cast that std::forward<T>(x) performs.
        return static_cast<T&&>(x);
    }
};

// Compile-time "if" for C++14, emulating `if constexpr`.
// The technique is adapted from
//   "https://baptiste-wicht.com/posts/2015/07/simulate-static_if-with-c11c14.html"
// Call shape: static_if<cond>{}([&](auto fwd) { ... }).else_([&](auto fwd) { ... });
// TODO: use if constexpr, when C++17 is supported
//
// The primary template is intentionally empty. The behavior for Predicate == true
// and Predicate == false comes from the explicit specializations.
template <bool Predicate>
struct static_if {};

template <>
struct static_if<true>
{
    using Type = static_if<true>;

    // Predicate holds: the "else" callable is accepted but never invoked.
    template <class F>
    __host__ __device__ static constexpr auto else_(F)
    {
        return Type{};
    }

    // Predicate holds: run the "then" branch.
    // The callable receives a forwarder object. If it is a generic lambda
    // (`[&](auto fwd) {...}`) and routes its branch-specific expressions through
    // `fwd`, those expressions stay dependent. They are then only instantiated when
    // the lambda is actually called, which happens here.
    template <class F>
    __host__ __device__ constexpr auto operator()(F then_branch) const
    {
        then_branch(forwarder{});
        const Type self{};
        return self;
    }
};

template <>
struct static_if<false>
{
    using Type = static_if<false>;

    // Predicate is false: run the "else" branch.
    // The callable receives a forwarder object. When it is a generic lambda that
    // routes its branch-specific expressions through that argument, those expressions
    // are only instantiated by this call.
    template <class F>
    __host__ __device__ static constexpr auto else_(F else_branch)
    {
        else_branch(forwarder{});
        const Type self{};
        return self;
    }

    // Predicate is false: the "then" callable is accepted but never invoked.
    template <class F>
    __host__ __device__ constexpr auto operator()(F) const
    {
        return Type{};
    }
};