functional3.hpp 4.33 KB
Newer Older
1
2
3
#ifndef CK_FUNCTIONAL3_HPP
#define CK_FUNCTIONAL3_HPP

Chao Liu's avatar
Chao Liu committed
4
5
6
7
#include "functional.hpp"
#include "functional2.hpp"
#include "Sequence.hpp"
#include "Array.hpp"
Chao Liu's avatar
Chao Liu committed
8

9
10
namespace ck {

Chao Liu's avatar
tweak  
Chao Liu committed
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
// is_static<T>: type trait reporting whether T carries its value(s) entirely in its type
// (presumably used to tell compile-time-known indices apart from run-time ones -- confirm with
// callers). The primary template answers false; only the specializations below answer true.
// Note: cv-qualified types (e.g. the type of decltype on a constexpr variable) do not match the
// specializations and therefore fall through to false.
template <typename T>
struct is_static : integral_constant<bool, false>
{
};

// integral_constant<T, Value> encodes Value as a template argument.
template <typename T, T Value>
struct is_static<integral_constant<T, Value>> : integral_constant<bool, true>
{
};

// Sequence<Ns...> encodes every element as a template argument.
template <index_t... Ns>
struct is_static<Sequence<Ns...>> : integral_constant<bool, true>
{
};

Chao Liu's avatar
Chao Liu committed
26
// Recursive worker behind static_ford.
// RemainLengths: Sequence<...> with the extents of the dimensions that are still to be looped
//                over, already arranged in loop order (outermost first).
// Orders: Sequence<...>, passed through untouched to the terminal specialization, which uses it
//         to map the finished index back to the original dimension order.
template <class RemainLengths, class Orders>
struct static_ford_impl
{
    __host__ __device__ constexpr static_ford_impl()
    {
        static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
    }

    // F signature: F(Sequence<...>)
    // CurrentOrderedId: Sequence<...> holding the indices fixed by the enclosing loop levels,
    //                   in loop order.
    template <class F, class CurrentOrderedId>
    __host__ __device__ constexpr void operator()(F f, CurrentOrderedId) const
    {
        // extents of the levels nested inside the one handled here
        using InnerLengths = decltype(RemainLengths::PopFront());

        // walk the outermost remaining dimension at compile time, appending each index value
        // to the partial multi-index before descending one level
        static_for<0, RemainLengths::Front(), 1>{}([=](auto iter) {
            static_ford_impl<InnerLengths, Orders>{}(f, CurrentOrderedId::PushBack(iter));
        });
    }
};

48
49
// Terminal case of static_ford_impl: no dimensions are left, so the incoming id is a complete
// multi-index expressed in loop order.
template <class Orders>
struct static_ford_impl<Sequence<>, Orders>
{
    // F signature: F(Sequence<...>)
    // OrderedId: Sequence<...> with one index per dimension, in loop order.
    template <class F, class OrderedId>
    __host__ __device__ constexpr void operator()(F f, OrderedId) const
    {
        // undo the loop-order permutation so f sees the multi-index in the original
        // (unordered) dimension order
        f(OrderedId::ReorderGivenOld2New(Orders{}));
    }
};

61
62
63
64
65
// Compile-time N-dimensional loop.
// Lengths: Sequence<...> with the extent of every dimension, in the original dimension order.
// Orders: Sequence<...> choosing the order in which the dimensions are looped over; the default
//         is arithmetic_sequence_gen<0, N, 1> (presumably the identity 0, 1, ..., N-1).
template <class Lengths,
          class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
struct static_ford
{
    __host__ __device__ constexpr static_ford()
    {
        static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
        static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
    }

    // Calls f once for every point of the index space.
    // F signature: F(Sequence<...> multi_id), where multi_id is the unordered multi-index
    // (i.e. laid out in the original dimension order of Lengths).
    template <class F>
    __host__ __device__ constexpr void operator()(F f) const
    {
        // extents permuted into loop order; the terminal impl undoes the permutation
        constexpr auto lengths_in_loop_order = Lengths::ReorderGivenNew2Old(Orders{});

        using Impl = static_ford_impl<decltype(lengths_in_loop_order), Orders>;

        // start the recursion with an empty partial index
        Impl{}(f, Sequence<>{});
    }
};

84
85
86
// Recursive worker behind ford (run-time counterpart of static_ford_impl).
// RemainLengths: Sequence<...> with the extents of the dimensions that are still to be looped
//                over, in loop order (outermost first).
// Orders: Sequence<...>, passed through to the terminal specialization, which uses it to map
//         the finished index back to the original dimension order.
template <class RemainLengths, class Orders>
struct ford_impl
{
    __host__ __device__ constexpr ford_impl()
    {
        static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
    }

    // F signature: F(Array<...> multi_id)
    // CurrentOrderedId: Array<...> holding the indices fixed by the enclosing loop levels,
    //                   in loop order.
    template <class F, class CurrentOrderedId>
    __host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const
    {
        // extents of the levels nested inside the one handled here
        using InnerLengths = decltype(RemainLengths::PopFront());

        // trip count of this loop level
        constexpr index_t extent = RemainLengths::Front();

        for(index_t idx = 0; idx < extent; ++idx)
        {
            // extend the partial index with this level's value and descend one level
            ford_impl<InnerLengths, Orders>{}(f, current_ordered_id.PushBack(idx));
        }
    }
};

107
108
// Terminal case of ford_impl: no dimensions are left, so the incoming id is a complete
// multi-index expressed in loop order.
template <class Orders>
struct ford_impl<Sequence<>, Orders>
{
    // F signature: F(Array<...> multi_id)
    // CurrentOrderedId: Array<...> with one index per dimension, in loop order.
    template <class F, class CurrentOrderedId>
    __host__ __device__ constexpr void operator()(F f, CurrentOrderedId ordered_id) const
    {
        // undo the loop-order permutation so f sees the multi-index in the original
        // (unordered) dimension order
        f(reorder_array_given_old2new(ordered_id, Orders{}));
    }
};

120
121
122
123
124
// Run-time N-dimensional loop (counterpart of static_ford, with the index held in an Array).
// Lengths: Sequence<...> with the extent of every dimension, in the original dimension order.
// Orders: Sequence<...> choosing the order in which the dimensions are looped over: the k-th loop
//         level (0 = outermost) walks dimension Orders[k] (assuming ReorderGivenNew2Old has the
//         usual new2old meaning, as static_ford relies on). The default is
//         arithmetic_sequence_gen<0, N, 1> (presumably the identity 0, 1, ..., N-1).
template <class Lengths,
          class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
struct ford
{
    __host__ __device__ constexpr ford()
    {
        static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
        static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
    }

    // Calls f once for every point of the index space.
    // F signature: F(Array<...> multi_id)
    // multi_id is the unordered multi-index (laid out in the original dimension order).
    template <class F>
    __host__ __device__ constexpr void operator()(F f) const
    {
        // Permute the extents into loop order first, exactly as static_ford does.
        // ford_impl<Sequence<>, Orders> maps the ordered index back with Orders at the end,
        // so iterating the unpermuted Lengths here would give every loop level the extent of
        // the wrong dimension whenever Orders is not the identity (skipping points or running
        // past a dimension's bound).
        using OrderedLengths = decltype(Lengths::ReorderGivenNew2Old(Orders{}));

        for(index_t i = 0; i < OrderedLengths::Front(); ++i)
        {
            ford_impl<decltype(OrderedLengths::PopFront()), Orders>{}(f, Array<index_t, 1>{i});
        }
    }
};
144
145
146

} // namespace ck
#endif