// functional.hip.hpp
#pragma once
#include "constant_integral.hip.hpp"

template <index_t Iter, index_t Remaining, index_t Increment>
struct static_for_impl
6
7
8
9
{
    template <class F>
    __host__ __device__ void operator()(F f) const
    {
10
11
        static_assert(Remaining % Increment == 0, "wrong! Remaining % Increment != 0");
        static_assert(Increment <= Remaining, "will go out-of-range");
12

13
14
        f(Number<Iter>{});
        static_for_impl<Iter + Increment, Remaining - Increment, Increment>{}(f);
15
16
17
    }
};

// Recursion terminator for static_for_impl: no indices remain, so the functor
// is accepted but never invoked.
template <index_t Iter, index_t Increment>
struct static_for_impl<Iter, 0, Increment>
{
    template <class Func>
    __host__ __device__ void operator()(Func) const
    {
    }
};

template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for
31
32
33
34
{
    template <class F>
    __host__ __device__ void operator()(F f) const
    {
35
36
37
38
        static_assert(NBegin < NEnd, "Wrong! we should have NBegin < NEnd");
        static_assert((NEnd - NBegin) % Increment == 0,
                      "Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
        static_for_impl<NBegin, NEnd - NBegin, Increment>{}(f);
39
40
41
    }
};

template <index_t NLoop>
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
struct static_const_reduce_n
{
    template <class F, class Reduce>
    __host__ __device__ constexpr auto operator()(F f, Reduce r) const
    {
        static_assert(NLoop > 1, "out-of-range");

        constexpr auto a = f(Number<NLoop - 1>{});
        auto b = static_const_reduce_n<NLoop - 1>{}(f, r); // cannot use constexpr here, weird
        return r(a, b);
    }
};

// Base case of static_const_reduce_n: a single term, so the result is just
// f(Number<0>{}) and the reducer is never called.
template <>
struct static_const_reduce_n<1>
{
    template <class Func, class Reducer>
    __host__ __device__ constexpr auto operator()(Func func, Reducer) const
    {
        using first_index = Number<0>;
        return func(first_index{});
    }
};

// NOTE(review): disabled (#if 0) draft of an "unpacker" helper that wraps f in a
// lambda. It would not compile if enabled: the lambda's parameter is named
// `xs_array`, but the body expands a pack `xs` that is never declared. The lambda
// also discards f's result. The "72" line below appears to be stray text, not code.
#if 0
template<class F>
__host__ __device__ constexpr auto unpacker(F f)
{
    return [=](auto xs_array){ f(xs...); };
}
72
#endif