functional3.hpp 4.55 KB
Newer Older
Umang Yadav's avatar
Umang Yadav committed
1
2
3

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
Chao Liu's avatar
Chao Liu committed
4
// SPDX-License-Identifier: MIT
Illia Silin's avatar
Illia Silin committed
5
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
Chao Liu's avatar
Chao Liu committed
6

Chao Liu's avatar
Chao Liu committed
7
#pragma once
8

Chao Liu's avatar
Chao Liu committed
9
10
11
12
13
#include "ck/ck.hpp"
#include "ck/utility/functional.hpp"
#include "ck/utility/functional2.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/multi_index.hpp"
Chao Liu's avatar
Chao Liu committed
14

15
16
namespace ck {

Chao Liu's avatar
Chao Liu committed
17
namespace detail {
Chao Liu's avatar
tweak  
Chao Liu committed
18

Chao Liu's avatar
Chao Liu committed
19
// RemainLengths: Sequence<...>
20
21
// Orders: Sequence<...>
template <class RemainLengths, class Orders>
Chao Liu's avatar
Chao Liu committed
22
23
struct static_ford_impl
{
24
    __host__ __device__ constexpr static_ford_impl()
Chao Liu's avatar
Chao Liu committed
25
26
    {
        static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
27
    }
Chao Liu's avatar
Chao Liu committed
28

29
30
31
32
33
    // F signature: F(Sequence<...>)
    // CurrentOrderedId: Sequence<...>
    template <class F, class CurrentOrderedId>
    __host__ __device__ constexpr void operator()(F f, CurrentOrderedId) const
    {
Chao Liu's avatar
Chao Liu committed
34
        static_for<0, RemainLengths::Front(), 1>{}([=](auto I) {
35
36
            static_ford_impl<decltype(RemainLengths::PopFront()), Orders>{}(
                f, CurrentOrderedId::PushBack(I));
Chao Liu's avatar
Chao Liu committed
37
38
39
40
        });
    }
};

41
42
// Compile-time N-dimensional loop helper (terminal case).
//
// Every dimension has been given a compile-time index. OrderedId holds those
// indices in loop order. It is mapped back through Orders before the user
// functor is invoked.
template <class Orders>
struct static_ford_impl<Sequence<>, Orders>
{
    // F signature: F(Sequence<...>)
    // OrderedId: Sequence<...> with one compile-time index per dimension, in loop order
    template <class F, class OrderedId>
    __host__ __device__ constexpr void operator()(F f, OrderedId) const
    {
        // retrieve the unordered id
        f(OrderedId::ReorderGivenOld2New(Orders{}));
    }
};

54
55
56
// RemainLengths: Sequence<...>
// Orders: Sequence<...>
template <class RemainLengths, class Orders>
Chao Liu's avatar
Chao Liu committed
57
58
struct ford_impl
{
59
    __host__ __device__ constexpr ford_impl()
Chao Liu's avatar
Chao Liu committed
60
    {
61
62
        static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
    }
Chao Liu's avatar
Chao Liu committed
63

64
65
66
67
68
69
    // F signature: F(Array<...> multi_id)
    // CurrentOrderdId: Array<...>
    template <class F, class CurrentOrderedId>
    __host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const
    {
        for(index_t i = 0; i < RemainLengths::Front(); ++i)
Chao Liu's avatar
Chao Liu committed
70
        {
71
            ford_impl<decltype(RemainLengths::PopFront()), Orders>{}(
Chao Liu's avatar
Chao Liu committed
72
                f, container_push_back(current_ordered_id, i));
Chao Liu's avatar
Chao Liu committed
73
74
75
76
        }
    }
};

77
78
// Runtime N-dimensional loop helper (terminal case).
//
// Every dimension now has an index in current_ordered_id, stored in loop order.
// It is mapped back through Orders before the user functor is invoked.
template <class Orders>
struct ford_impl<Sequence<>, Orders>
{
    // F signature: F(Array<...> multi_id)
    // CurrentOrderedId: Array<...> with one runtime index per dimension, in loop order
    template <class F, class CurrentOrderedId>
    __host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const
    {
        // retrieve the unordered id
        f(container_reorder_given_old2new(current_ordered_id, Orders{}));
    }
};

Chao Liu's avatar
Chao Liu committed
90
91
} // namespace detail

Chao Liu's avatar
Chao Liu committed
92
93
94
95
// Lengths is Sequence<...>, it is the length of each dimension for
// N-dimensional loop
// Orders is Sequence<...>, it is the order of dimension in which static_ford
// will loop over each
Chao Liu's avatar
Chao Liu committed
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
// dimension
template <class Lengths,
          class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
struct static_ford
{
    __host__ __device__ constexpr static_ford()
    {
        static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
        static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
    }

    // F signature: F(Sequence<...> multi_id)
    // multi_id is the unordered multi-index
    template <class F>
    __host__ __device__ constexpr void operator()(F f) const
    {
        constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{});
        detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, Sequence<>{});
    }
};

Chao Liu's avatar
Chao Liu committed
117
118
119
120
// Lengths is Sequence<...>, it is the length of each dimension for
// N-dimensional loop
// Orders is Sequence<...>, it is the order of dimension in which ford will loop
// over each
121
122
123
// dimension
template <class Lengths,
          class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
Chao Liu's avatar
Chao Liu committed
124
125
struct ford
{
126
127
128
129
130
131
    __host__ __device__ constexpr ford()
    {
        static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
        static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
    }

Chao Liu's avatar
Chao Liu committed
132
    // F signature: F(Array<...> multi_id)
133
    // multi_id is the unordered multi-index
Chao Liu's avatar
Chao Liu committed
134
    template <class F>
Chao Liu's avatar
Chao Liu committed
135
    __host__ __device__ constexpr void operator()(F f) const
Chao Liu's avatar
Chao Liu committed
136
    {
137
138
139
        constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{});

        for(index_t i = 0; i < ordered_lengths.Front(); ++i)
Chao Liu's avatar
Chao Liu committed
140
        {
Chao Liu's avatar
Chao Liu committed
141
            detail::ford_impl<decltype(ordered_lengths.PopFront()), Orders>{}(f,
Chao Liu's avatar
Chao Liu committed
142
                                                                              make_multi_index(i));
Chao Liu's avatar
Chao Liu committed
143
144
145
        }
    }
};
146
147

} // namespace ck
Umang Yadav's avatar
Umang Yadav committed
148
149

#pragma clang diagnostic pop