// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

Chao Liu's avatar
Chao Liu committed
4
#pragma once
5

Chao Liu's avatar
Chao Liu committed
6
7
8
9
10
#include "ck/ck.hpp"
#include "ck/utility/functional.hpp"
#include "ck/utility/functional2.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/multi_index.hpp"
Chao Liu's avatar
Chao Liu committed
11

12
13
namespace ck {

Chao Liu's avatar
Chao Liu committed
14
namespace detail {
Chao Liu's avatar
tweak  
Chao Liu committed
15

Chao Liu's avatar
Chao Liu committed
16
// RemainLengths: Sequence<...>
17
18
// Orders: Sequence<...>
template <class RemainLengths, class Orders>
Chao Liu's avatar
Chao Liu committed
19
20
struct static_ford_impl
{
21
    __host__ __device__ constexpr static_ford_impl()
Chao Liu's avatar
Chao Liu committed
22
    {
Chao Liu's avatar
Chao Liu committed
23
        static_assert(RemainLengths::Size() > 0, "wrong! should not get here");
24
    }
Chao Liu's avatar
Chao Liu committed
25

26
27
28
29
30
    // F signature: F(Sequence<...>)
    // CurrentOrderedId: Sequence<...>
    template <class F, class CurrentOrderedId>
    __host__ __device__ constexpr void operator()(F f, CurrentOrderedId) const
    {
Chao Liu's avatar
Chao Liu committed
31
        static_for<0, RemainLengths::Front(), 1>{}([=](auto I) {
32
33
            static_ford_impl<decltype(RemainLengths::PopFront()), Orders>{}(
                f, CurrentOrderedId::PushBack(I));
Chao Liu's avatar
Chao Liu committed
34
35
36
37
        });
    }
};

38
39
// Terminal case of the compile-time N-dimensional loop: no dimension is left
// to iterate, so the functor is invoked with the completed index.
template <typename Orders>
struct static_ford_impl<Sequence<>, Orders>
{
    // F signature: F(Sequence<...>)
    // OrderedId: Sequence<...>, the complete index in loop order
    template <typename F, typename OrderedId>
    __host__ __device__ constexpr void operator()(F f, OrderedId) const
    {
        // retrieve the unordered Id by reordering with Orders (old-to-new map)
        // and hand it to the user functor
        f(OrderedId::ReorderGivenOld2New(Orders{}));
    }
};

51
52
53
// RemainLengths: Sequence<...>
// Orders: Sequence<...>
template <class RemainLengths, class Orders>
Chao Liu's avatar
Chao Liu committed
54
55
struct ford_impl
{
56
    __host__ __device__ constexpr ford_impl()
Chao Liu's avatar
Chao Liu committed
57
    {
Chao Liu's avatar
Chao Liu committed
58
        static_assert(RemainLengths::Size() > 0, "wrong! should not get here");
59
    }
Chao Liu's avatar
Chao Liu committed
60

61
62
63
64
65
66
    // F signature: F(Array<...> multi_id)
    // CurrentOrderdId: Array<...>
    template <class F, class CurrentOrderedId>
    __host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const
    {
        for(index_t i = 0; i < RemainLengths::Front(); ++i)
Chao Liu's avatar
Chao Liu committed
67
        {
68
            ford_impl<decltype(RemainLengths::PopFront()), Orders>{}(
Chao Liu's avatar
Chao Liu committed
69
                f, container_push_back(current_ordered_id, i));
Chao Liu's avatar
Chao Liu committed
70
71
72
73
        }
    }
};

74
75
// Terminal case of the run-time N-dimensional loop: no dimension is left to
// iterate, so the functor is invoked with the completed index.
template <typename Orders>
struct ford_impl<Sequence<>, Orders>
{
    // F signature: F(Array<...> multi_id)
    // CurrentOrderedId: Array<...>, the complete index in loop order
    template <typename F, typename CurrentOrderedId>
    __host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const
    {
        // retrieve the unordered Id by reordering with Orders (old-to-new map)
        // and hand it to the user functor
        f(container_reorder_given_old2new(current_ordered_id, Orders{}));
    }
};

Chao Liu's avatar
Chao Liu committed
87
88
} // namespace detail

Chao Liu's avatar
Chao Liu committed
89
90
91
92
// Lengths is Sequence<...>, it is the length of each dimension for
// N-dimensional loop
// Orders is Sequence<...>, it is the order of dimension in which static_ford
// will loop over each
Chao Liu's avatar
Chao Liu committed
93
94
// dimension
template <class Lengths,
Chao Liu's avatar
Chao Liu committed
95
          class Orders = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
Chao Liu's avatar
Chao Liu committed
96
97
98
99
struct static_ford
{
    __host__ __device__ constexpr static_ford()
    {
Chao Liu's avatar
Chao Liu committed
100
101
        static_assert(Lengths::Size() > 0, "wrong! Lengths is empty");
        static_assert(Lengths::Size() == Orders::Size(), "wrong! inconsistent size");
Chao Liu's avatar
Chao Liu committed
102
103
104
105
106
107
108
109
110
111
112
113
    }

    // F signature: F(Sequence<...> multi_id)
    // multi_id is the unordered multi-index
    template <class F>
    __host__ __device__ constexpr void operator()(F f) const
    {
        constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{});
        detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, Sequence<>{});
    }
};

Chao Liu's avatar
Chao Liu committed
114
115
116
117
// Lengths is Sequence<...>, it is the length of each dimension for
// N-dimensional loop
// Orders is Sequence<...>, it is the order of dimension in which ford will loop
// over each
118
119
// dimension
template <class Lengths,
Chao Liu's avatar
Chao Liu committed
120
          class Orders = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
Chao Liu's avatar
Chao Liu committed
121
122
struct ford
{
123
124
    __host__ __device__ constexpr ford()
    {
Chao Liu's avatar
Chao Liu committed
125
126
        static_assert(Lengths::Size() > 0, "wrong! Lengths is empty");
        static_assert(Lengths::Size() == Orders::Size(), "wrong! inconsistent size");
127
128
    }

Chao Liu's avatar
Chao Liu committed
129
    // F signature: F(Array<...> multi_id)
130
    // multi_id is the unordered multi-index
Chao Liu's avatar
Chao Liu committed
131
    template <class F>
Chao Liu's avatar
Chao Liu committed
132
    __host__ __device__ constexpr void operator()(F f) const
Chao Liu's avatar
Chao Liu committed
133
    {
134
135
136
        constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{});

        for(index_t i = 0; i < ordered_lengths.Front(); ++i)
Chao Liu's avatar
Chao Liu committed
137
        {
Chao Liu's avatar
Chao Liu committed
138
            detail::ford_impl<decltype(ordered_lengths.PopFront()), Orders>{}(f,
Chao Liu's avatar
Chao Liu committed
139
                                                                              make_multi_index(i));
Chao Liu's avatar
Chao Liu committed
140
141
142
        }
    }
};
143
144

} // namespace ck