functional3.hpp 4.45 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

Chao Liu's avatar
Chao Liu committed
4
#pragma once
5

Chao Liu's avatar
Chao Liu committed
6
7
8
9
10
#include "ck/ck.hpp"
#include "ck/utility/functional.hpp"
#include "ck/utility/functional2.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/multi_index.hpp"
Chao Liu's avatar
Chao Liu committed
11

12
13
namespace ck {

Chao Liu's avatar
Chao Liu committed
14
namespace detail {
Chao Liu's avatar
tweak  
Chao Liu committed
15

Chao Liu's avatar
Chao Liu committed
16
// RemainLengths: Sequence<...>
17
18
// Orders: Sequence<...>
template <class RemainLengths, class Orders>
Chao Liu's avatar
Chao Liu committed
19
20
struct static_ford_impl
{
21
    __host__ __device__ constexpr static_ford_impl()
Chao Liu's avatar
Chao Liu committed
22
23
    {
        static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
24
    }
Chao Liu's avatar
Chao Liu committed
25

26
27
28
29
30
    // F signature: F(Sequence<...>)
    // CurrentOrderedId: Sequence<...>
    template <class F, class CurrentOrderedId>
    __host__ __device__ constexpr void operator()(F f, CurrentOrderedId) const
    {
Chao Liu's avatar
Chao Liu committed
31
        static_for<0, RemainLengths::Front(), 1>{}([=](auto I) {
32
33
            static_ford_impl<decltype(RemainLengths::PopFront()), Orders>{}(
                f, CurrentOrderedId::PushBack(I));
Chao Liu's avatar
Chao Liu committed
34
35
36
37
        });
    }
};

38
39
// Terminal case of static_ford_impl: every dimension has been assigned an index.
template <class Orders>
struct static_ford_impl<Sequence<>, Orders>
{
    // Map the index, which was collected in loop order, back to the caller's original
    // dimension order, then invoke f with it.
    //
    // F signature: F(Sequence<...>)
    // LoopOrderedId: Sequence<...> holding one index per dimension, in loop order
    template <class F, class LoopOrderedId>
    __host__ __device__ constexpr void operator()(F f, LoopOrderedId) const
    {
        // undo the loop-order permutation to retrieve the unordered multi-index
        f(LoopOrderedId::ReorderGivenOld2New(Orders{}));
    }
};

51
52
53
// RemainLengths: Sequence<...>
// Orders: Sequence<...>
template <class RemainLengths, class Orders>
Chao Liu's avatar
Chao Liu committed
54
55
struct ford_impl
{
56
    __host__ __device__ constexpr ford_impl()
Chao Liu's avatar
Chao Liu committed
57
    {
58
59
        static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
    }
Chao Liu's avatar
Chao Liu committed
60

61
62
63
64
65
66
    // F signature: F(Array<...> multi_id)
    // CurrentOrderdId: Array<...>
    template <class F, class CurrentOrderedId>
    __host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const
    {
        for(index_t i = 0; i < RemainLengths::Front(); ++i)
Chao Liu's avatar
Chao Liu committed
67
        {
68
            ford_impl<decltype(RemainLengths::PopFront()), Orders>{}(
Chao Liu's avatar
Chao Liu committed
69
                f, container_push_back(current_ordered_id, i));
Chao Liu's avatar
Chao Liu committed
70
71
72
73
        }
    }
};

74
75
// Terminal case of ford_impl: every dimension has been assigned an index.
template <class Orders>
struct ford_impl<Sequence<>, Orders>
{
    // Map the index, which was collected in loop order, back to the caller's original
    // dimension order, then invoke f with it.
    //
    // F signature: F(Array<...> multi_id)
    // CurrentOrderedId: Array<...> holding one index per dimension, in loop order
    template <class F, class CurrentOrderedId>
    __host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const
    {
        // undo the loop-order permutation to retrieve the unordered multi-index
        f(container_reorder_given_old2new(current_ordered_id, Orders{}));
    }
};

Chao Liu's avatar
Chao Liu committed
87
88
} // namespace detail

Chao Liu's avatar
Chao Liu committed
89
90
91
92
// Lengths is Sequence<...>, it is the length of each dimension for
// N-dimensional loop
// Orders is Sequence<...>, it is the order of dimension in which static_ford
// will loop over each
Chao Liu's avatar
Chao Liu committed
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
// dimension
template <class Lengths,
          class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
struct static_ford
{
    __host__ __device__ constexpr static_ford()
    {
        static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
        static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
    }

    // F signature: F(Sequence<...> multi_id)
    // multi_id is the unordered multi-index
    template <class F>
    __host__ __device__ constexpr void operator()(F f) const
    {
        constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{});
        detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, Sequence<>{});
    }
};

Chao Liu's avatar
Chao Liu committed
114
115
116
117
// Lengths is Sequence<...>, it is the length of each dimension for
// N-dimensional loop
// Orders is Sequence<...>, it is the order of dimension in which ford will loop
// over each
118
119
120
// dimension
template <class Lengths,
          class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
Chao Liu's avatar
Chao Liu committed
121
122
struct ford
{
123
124
125
126
127
128
    __host__ __device__ constexpr ford()
    {
        static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
        static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
    }

Chao Liu's avatar
Chao Liu committed
129
    // F signature: F(Array<...> multi_id)
130
    // multi_id is the unordered multi-index
Chao Liu's avatar
Chao Liu committed
131
    template <class F>
Chao Liu's avatar
Chao Liu committed
132
    __host__ __device__ constexpr void operator()(F f) const
Chao Liu's avatar
Chao Liu committed
133
    {
134
135
136
        constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{});

        for(index_t i = 0; i < ordered_lengths.Front(); ++i)
Chao Liu's avatar
Chao Liu committed
137
        {
Chao Liu's avatar
Chao Liu committed
138
            detail::ford_impl<decltype(ordered_lengths.PopFront()), Orders>{}(f,
Chao Liu's avatar
Chao Liu committed
139
                                                                              make_multi_index(i));
Chao Liu's avatar
Chao Liu committed
140
141
142
        }
    }
};
143
144

} // namespace ck