// functional2.hip.hpp
#pragma once
#include "Sequence.hip.hpp"

// RemainLengths: Sequence<...>
template <class RemainLengths>
6
7
8
9
struct static_ford_impl
{
    // F signature: F(Sequence<...> multi_id)
    // CurrentMultiIndex: Sequence<...>
Chao Liu's avatar
Chao Liu committed
10
11
    template <class F, class CurrentMultiIndex>
    __host__ __device__ void operator()(F f, CurrentMultiIndex) const
12
    {
Chao Liu's avatar
Chao Liu committed
13
        static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
14

Chao Liu's avatar
Chao Liu committed
15
16
17
        static_for<0, RemainLengths::Front(), 1>{}([=](auto I) {
            static_ford_impl<decltype(RemainLengths::PopFront())>{}(f,
                                                                    CurrentMultiIndex::PushBack(I));
18
19
20
21
22
        });
    }
};

template <>
Chao Liu's avatar
Chao Liu committed
23
struct static_ford_impl<Sequence<>>
24
{
Chao Liu's avatar
Chao Liu committed
25
    // F signature: F(Sequence<...> multi_id)
26
    // CurrentMultiIndex: Sequence<...>
Chao Liu's avatar
Chao Liu committed
27
28
    template <class F, class CurrentMultiIndex>
    __host__ __device__ void operator()(F f, CurrentMultiIndex) const
29
    {
Chao Liu's avatar
Chao Liu committed
30
        f(CurrentMultiIndex{});
31
32
33
34
35
36
37
    }
};

// Lengths is Sequence<...>
template <class Lengths>
struct static_ford
{
Chao Liu's avatar
Chao Liu committed
38
    // F signature: F(Sequence<...> multi_id)
39
40
41
    template <class F>
    __host__ __device__ void operator()(F f) const
    {
Chao Liu's avatar
Chao Liu committed
42
        static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
43

Chao Liu's avatar
Chao Liu committed
44
        static_ford_impl<Lengths>{}(f, Sequence<>{});
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
    }
};

// One level of the run-time nested loop used by ford: iterates over the first
// of the remaining lengths and hands every extended multi-index to the next
// level. RemainDim is the number of lengths still to be looped over.
template <index_t RemainDim>
struct ford_impl
{
    // F signature: F(Array<...> multi_id)
    // CurrentMultiIndex: Array<...> with the indices fixed by the outer loops
    // RemainLengths: Sequence<...>; passed only to carry the type
    template <class F, class CurrentMultiIndex, class RemainLengths>
    __host__ __device__ void
    operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
    {
        static_assert(RemainLengths::GetSize() == RemainDim, "wrong!");
        // A single remaining dimension is handled by the ford_impl<1> specialization.
        static_assert(RemainDim > 1, "wrong!");

        using next_level = ford_impl<RemainDim - 1>;

        constexpr auto this_extent = RemainLengths{}.Front();

        for(index_t idx = 0; idx < this_extent; ++idx)
        {
            // PushBack yields a new multi-index; current_multi_id itself is
            // reused unchanged by the next iteration.
            const auto extended_id = current_multi_id.PushBack(idx);

            next_level{}(f, extended_id, RemainLengths{}.PopFront());
        }
    }
};

// Innermost level of the run-time nested loop: exactly one length is left, so
// each iteration completes the multi-index and calls the functor with it.
template <>
struct ford_impl<1>
{
    // F signature: F(Array<...> multi_id)
    // CurrentMultiIndex: Array<...> with the indices fixed by the outer loops
    // RemainLengths: Sequence<...> holding the single remaining length
    template <class F, class CurrentMultiIndex, class RemainLengths>
    __host__ __device__ void
    operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
    {
        static_assert(RemainLengths::GetSize() == 1, "wrong!");

        constexpr index_t innermost_extent = RemainLengths{}.Front();

        index_t idx = 0;
        while(idx < innermost_extent)
        {
            f(current_multi_id.PushBack(idx));
            ++idx;
        }
    }
};

// Lengths is Sequence<...>
template <class Lengths>
struct ford
{
    // F signature: F(Array<...> multi_id)
    template <class F>
    __host__ __device__ void operator()(F f) const
    {
        constexpr index_t first_length = Lengths{}.Front();

        for(index_t i = 0; i < first_length; ++i)
        {
            ford_impl<Lengths::GetSize() - 1>{}(f, Array<index_t, 1>{i}, Lengths{}.PopFront());
        }
    }
};