functional3.hpp 4.29 KB
Newer Older
1
2
3
#ifndef CK_FUNCTIONAL3_HPP
#define CK_FUNCTIONAL3_HPP

Chao Liu's avatar
Chao Liu committed
4
5
6
7
#include "functional.hpp"
#include "functional2.hpp"
#include "Sequence.hpp"
#include "Array.hpp"
Chao Liu's avatar
Chao Liu committed
8

9
10
namespace ck {

Chao Liu's avatar
Chao Liu committed
11
namespace detail {
Chao Liu's avatar
tweak  
Chao Liu committed
12

Chao Liu's avatar
Chao Liu committed
13
// Recursive case of the compile-time N-dimensional loop driven by static_ford.
//   RemainLengths: Sequence<...> of loop lengths still to be iterated, already in loop order.
//                  Must be non-empty; the empty case is the Sequence<> specialization.
//   Orders:        Sequence<...> of dimension orders, passed through unchanged so the base
//                  case can recover the unordered multi-index.
template <class RemainLengths, class Orders>
struct static_ford_impl
{
    __host__ __device__ constexpr static_ford_impl()
    {
        static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
    }

    // Peels off the outermost remaining dimension. For each compile-time index produced by
    // static_for<0, RemainLengths::Front(), 1>, the partial id is extended via PushBack and
    // the remaining lengths are handled recursively.
    //   F:                callable taking a Sequence<...> (the final, unordered multi-index)
    //   CurrentOrderedId: Sequence<...> of indices fixed so far, in loop order
    template <class F, class CurrentOrderedId>
    __host__ __device__ constexpr void operator()(F f, CurrentOrderedId) const
    {
        // Lengths left after this dimension is consumed.
        using InnerLengths = decltype(RemainLengths::PopFront());

        constexpr auto outer_length = RemainLengths::Front();

        static_for<0, outer_length, 1>{}([=](auto iter) {
            static_ford_impl<InnerLengths, Orders>{}(f, CurrentOrderedId::PushBack(iter));
        });
    }
};

35
36
// Base case of static_ford_impl: no lengths remain, so every dimension has been fixed.
template <class Orders>
struct static_ford_impl<Sequence<>, Orders>
{
    // Retrieves the unordered multi-index from the fully built ordered id by calling
    // ReorderGivenOld2New(Orders), then hands it to fn.
    //   Fn:             callable taking a Sequence<...> (the unordered multi-index)
    //   FinalOrderedId: Sequence<...> holding one index per dimension, in loop order
    template <class Fn, class FinalOrderedId>
    __host__ __device__ constexpr void operator()(Fn fn, FinalOrderedId) const
    {
        fn(FinalOrderedId::ReorderGivenOld2New(Orders{}));
    }
};

48
49
50
// Recursive case of the runtime N-dimensional loop driven by ford.
//   RemainLengths: Sequence<...> of loop lengths still to be iterated, already in loop order.
//                  Must be non-empty; the empty case is the Sequence<> specialization.
//   Orders:        Sequence<...> of dimension orders, passed through unchanged so the base
//                  case can recover the unordered multi-index.
template <class RemainLengths, class Orders>
struct ford_impl
{
    __host__ __device__ constexpr ford_impl()
    {
        static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
    }

    // Runs an ordinary for-loop over the outermost remaining dimension. Each iteration
    // extends the partial ordered id with the loop index (Array::PushBack) and recurses on
    // the remaining lengths.
    //   F:                callable taking an Array<...> (the final, unordered multi-index)
    //   CurrentOrderedId: Array<...> of indices fixed so far, in loop order
    template <class F, class CurrentOrderedId>
    __host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const
    {
        // Implementation that handles the dimensions after this one.
        using Inner = ford_impl<decltype(RemainLengths::PopFront()), Orders>;

        constexpr auto outer_length = RemainLengths::Front();

        for(index_t idx = 0; idx < outer_length; ++idx)
        {
            Inner{}(f, current_ordered_id.PushBack(idx));
        }
    }
};

71
72
// Base case of ford_impl: no lengths remain, so every dimension has been fixed.
template <class Orders>
struct ford_impl<Sequence<>, Orders>
{
    // Retrieves the unordered multi-index from the fully built ordered id with
    // reorder_array_given_old2new(..., Orders), then hands it to fn.
    //   Fn:             callable taking an Array<...> (the unordered multi-index)
    //   FinalOrderedId: Array<...> holding one index per dimension, in loop order
    template <class Fn, class FinalOrderedId>
    __host__ __device__ constexpr void operator()(Fn fn, FinalOrderedId final_ordered_id) const
    {
        fn(reorder_array_given_old2new(final_ordered_id, Orders{}));
    }
};

Chao Liu's avatar
Chao Liu committed
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
} // namespace detail

// Compile-time N-dimensional loop.
//   Lengths: Sequence<...>, the length of each dimension of the N-dimensional loop.
//   Orders:  Sequence<...>, the order of dimensions in which static_ford loops over each
//            dimension. Defaults to arithmetic_sequence_gen<0, N, 1>, i.e. the identity
//            order 0, 1, ..., N-1.
// Lengths must be non-empty and Orders must have the same size (checked by static_assert).
template <class Lengths,
          class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
struct static_ford
{
    __host__ __device__ constexpr static_ford()
    {
        static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
        static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
    }

    // Calls f with a compile-time Sequence<...> multi-index for every point of the
    // iteration space. The multi-index given to f is the unordered one: its entries follow
    // the dimension order of Lengths, not the loop order chosen by Orders.
    template <class F>
    __host__ __device__ constexpr void operator()(F f) const
    {
        // Lengths permuted into the order in which the loop nest is built.
        using LoopOrderLengths = decltype(Lengths::ReorderGivenNew2Old(Orders{}));

        // Start the recursion with an empty partial id.
        detail::static_ford_impl<LoopOrderLengths, Orders>{}(f, Sequence<>{});
    }
};

109
110
111
112
113
// Runtime N-dimensional loop (ordinary for-loops, index values are index_t).
//   Lengths: Sequence<...>, the length of each dimension of the N-dimensional loop.
//   Orders:  Sequence<...>, the order of dimensions in which ford loops over each
//            dimension. Defaults to arithmetic_sequence_gen<0, N, 1>, i.e. the identity
//            order 0, 1, ..., N-1.
// Lengths must be non-empty and Orders must have the same size (checked by static_assert).
template <class Lengths,
          class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
struct ford
{
    __host__ __device__ constexpr ford()
    {
        static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
        static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
    }

    // Calls f with an Array<...> multi-index for every point of the iteration space.
    // The multi-index given to f is the unordered one: its entries follow the dimension
    // order of Lengths, not the loop order chosen by Orders.
    template <class F>
    __host__ __device__ constexpr void operator()(F f) const
    {
        // Lengths permuted into the order in which the loop nest is built.
        constexpr auto loop_lengths = Lengths::ReorderGivenNew2Old(Orders{});

        // Handles every dimension after the outermost one.
        using InnerImpl = detail::ford_impl<decltype(loop_lengths.PopFront()), Orders>;

        // The outermost level is unrolled here, seeding the partial id with a one-element
        // Array, so the recursion never has to start from an empty Array.
        for(index_t outer = 0; outer < loop_lengths.Front(); ++outer)
        {
            InnerImpl{}(f, Array<index_t, 1>{outer});
        }
    }
};
136
137
138

} // namespace ck
#endif