functional2.hip.hpp 3.42 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#pragma once
#include "Sequence.hip.hpp"

// Recursive step of static_ford.
// static_ford_impl<RemainDim> takes the leading length of RemainLengths and
// iterates over [0, that length) with static_for. For each index it appends the
// index to the multi-index accumulated so far. It then recurses into
// static_ford_impl<RemainDim - 1> with the remaining lengths. The recursion ends
// in the static_ford_impl<1> specialization.
template <index_t RemainDim>
struct static_ford_impl
{
    // f:                 callable invoked as f(Sequence<...> multi_id)
    // CurrentMultiIndex: Sequence<...> of indices already fixed by outer levels
    // RemainLengths:     Sequence<...> of lengths still to be iterated
    template <class F, class CurrentMultiIndex, class RemainLengths>
    __host__ __device__ void operator()(F f, CurrentMultiIndex, RemainLengths) const
    {
        static_assert(RemainLengths::GetSize() == RemainDim, "wrong!");
        static_assert(RemainDim > 1, "wrong!");

        constexpr auto outer_len = RemainLengths{}.Front();

        static_for<0, outer_len, 1>{}([=](auto Iter) {
            // Extend the prefix by this level's index; drop this level's length.
            const auto prefix = CurrentMultiIndex{}.PushBack(Iter);
            const auto inner  = RemainLengths{}.PopFront();

            static_ford_impl<RemainDim - 1>{}(f, prefix, inner);
        });
    }
};

// Base case of the static_ford recursion: exactly one dimension is left.
// Each index of that last dimension is appended to the accumulated multi-index.
// The completed Sequence is then handed to the user callable.
template <>
struct static_ford_impl<1>
{
    // f:                 callable invoked as f(Sequence<Is...> multi_id)
    // CurrentMultiIndex: Sequence<...> of indices fixed by the outer levels
    // RemainLengths:     Sequence<...> holding only the innermost length
    template <class F, class CurrentMultiIndex, class RemainLengths>
    __host__ __device__ void operator()(F f, CurrentMultiIndex, RemainLengths) const
    {
        static_assert(RemainLengths::GetSize() == 1, "wrong!");

        constexpr index_t innermost_len = RemainLengths{}.Front();

        static_for<0, innermost_len, 1>{}(
            [=](auto Iter) { f(CurrentMultiIndex{}.PushBack(Iter)); });
    }
};

// Lengths is Sequence<...>
template <class Lengths>
struct static_ford
{
    // F signature: F(Sequence<Is...> multi_id)
    template <class F>
    __host__ __device__ void operator()(F f) const
    {
        constexpr index_t first_length = Lengths{}.Front();

        static_for<0, first_length, 1>{}([=](auto I) {
            static_ford_impl<Lengths::GetSize() - 1>{}(
                f, Sequence<I.Get()>{}, Lengths{}.PopFront());
        });
    }
};

// Recursive step of ford.
// ford_impl<RemainDim> takes the leading length of RemainLengths and loops over
// it with an ordinary run-time index. It appends that index to the Array
// multi-index accumulated so far and recurses with one fewer remaining
// dimension, until the ford_impl<1> specialization is reached.
template <index_t RemainDim>
struct ford_impl
{
    // f:                callable invoked as f(Array<...> multi_id)
    // current_multi_id: Array<...> of indices already fixed by outer levels
    // RemainLengths:    Sequence<...> of lengths still to be iterated
    template <class F, class CurrentMultiIndex, class RemainLengths>
    __host__ __device__ void
    operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
    {
        static_assert(RemainLengths::GetSize() == RemainDim, "wrong!");
        static_assert(RemainDim > 1, "wrong!");

        constexpr auto outer_len = RemainLengths{}.Front();

        // The lengths handed to the next level are the same on every iteration.
        const auto inner_lengths = RemainLengths{}.PopFront();

        index_t idx = 0;
        while(idx < outer_len)
        {
            ford_impl<RemainDim - 1>{}(f, current_multi_id.PushBack(idx), inner_lengths);
            ++idx;
        }
    }
};

// Base case of the ford recursion: a single dimension remains.
// Each index of that dimension completes a multi-index, which is passed
// directly to the user callable.
template <>
struct ford_impl<1>
{
    // f:                callable invoked as f(Array<...> multi_id)
    // current_multi_id: Array<...> of indices fixed by the outer levels
    // RemainLengths:    Sequence<...> holding only the innermost length
    template <class F, class CurrentMultiIndex, class RemainLengths>
    __host__ __device__ void
    operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
    {
        static_assert(RemainLengths::GetSize() == 1, "wrong!");

        constexpr index_t innermost_len = RemainLengths{}.Front();

        for(index_t idx = 0; idx < innermost_len; ++idx)
            f(current_multi_id.PushBack(idx));
    }
};

// Lengths is Sequence<...>
template <class Lengths>
struct ford
{
    // F signature: F(Array<...> multi_id)
    template <class F>
    __host__ __device__ void operator()(F f) const
    {
        constexpr index_t first_length = Lengths{}.Front();

        for(index_t i = 0; i < first_length; ++i)
        {
            ford_impl<Lengths::GetSize() - 1>{}(f, Array<index_t, 1>{i}, Lengths{}.PopFront());
        }
    }
};