tuple.hpp 9 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
// SPDX-License-Identifier: MIT
Illia Silin's avatar
Illia Silin committed
2
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
Chao Liu's avatar
Chao Liu committed
3

4
#pragma once
Chao Liu's avatar
Chao Liu committed
5

Chao Liu's avatar
Chao Liu committed
6
7
#include "ck/utility/is_static.hpp"
#include "ck/utility/print.hpp"
8
9
10
11
#include "ck/utility/integral_constant.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/enable_if.hpp"
Chao Liu's avatar
Chao Liu committed
12
13
14

namespace ck {

Chao Liu's avatar
Chao Liu committed
15
namespace detail {
Chao Liu's avatar
Chao Liu committed
16

Chao Liu's avatar
Chao Liu committed
17
18
19
template <index_t>
struct TupleElementKey
{
Chao Liu's avatar
Chao Liu committed
20
    __host__ __device__ constexpr TupleElementKey() = default;
Chao Liu's avatar
Chao Liu committed
21
};
Chao Liu's avatar
Chao Liu committed
22

Chao Liu's avatar
Chao Liu committed
23
template <typename Key, typename Data>
24
struct TupleElementKeyData
Chao Liu's avatar
Chao Liu committed
25
{
26
27
    using DataType = Data;

28
29
30
31
32
#if 0 // workaround compiler complaint about implicitly-deleted default constructor
    __host__ __device__ constexpr TupleElementKeyData() = default;
#else
    __host__ __device__ constexpr TupleElementKeyData() : mData{} {}
#endif
Chao Liu's avatar
Chao Liu committed
33

34
35
36
37
    template <typename T,
              typename enable_if<!is_same<remove_cvref_t<T>, TupleElementKeyData>::value,
                                 bool>::type = false>
    __host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(std::forward<T>(v))
Chao Liu's avatar
Chao Liu committed
38
39
40
    {
    }

41
    DataType mData;
Chao Liu's avatar
Chao Liu committed
42
43
};

44
// for read access of tuple element
Chao Liu's avatar
Chao Liu committed
45
template <typename Key, typename Data>
46
__host__ __device__ constexpr const Data&
47
get_tuple_element_data_reference(const TupleElementKeyData<Key, Data>& x)
Chao Liu's avatar
Chao Liu committed
48
{
Chao Liu's avatar
Chao Liu committed
49
    return static_cast<const Data&>(x.mData);
Chao Liu's avatar
Chao Liu committed
50
}
Chao Liu's avatar
Chao Liu committed
51

52
// for write access of tuple element
Chao Liu's avatar
Chao Liu committed
53
template <typename Key, typename Data>
54
55
__host__ __device__ constexpr Data&
get_tuple_element_data_reference(TupleElementKeyData<Key, Data>& x)
Chao Liu's avatar
Chao Liu committed
56
{
Chao Liu's avatar
Chao Liu committed
57
58
    return x.mData;
}
Chao Liu's avatar
Chao Liu committed
59

Chao Liu's avatar
Chao Liu committed
60
// TODO: not sure the use of reference is correct
Chao Liu's avatar
Chao Liu committed
61
template <typename Key, typename Data>
62
63
__host__ __device__ constexpr Data&&
get_tuple_element_data_reference(TupleElementKeyData<Key, Data>&& x)
Chao Liu's avatar
Chao Liu committed
64
{
Chao Liu's avatar
Chao Liu committed
65
66
    return static_cast<Data&&>(x.mData);
}
Chao Liu's avatar
Chao Liu committed
67

68
69
70
71
72
73
74
// For inferring the type of a tuple element: the declared return type is exactly Data,
// so decltype of a call yields the stored element type, including a reference
// qualifier (e.g. T& for elements of a Tuple<T&...> produced by tie()).
//
// The body returns the member directly. A bare `std::forward(x.mData)` cannot deduce
// its template argument and would fail to compile as soon as this function is
// instantiated (e.g. called outside an unevaluated decltype context).
template <typename Key, typename Data>
__host__ __device__ constexpr Data get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
{
    return x.mData;
}

Chao Liu's avatar
Chao Liu committed
75
76
77
78
template <typename Indices, typename... Xs>
struct TupleImpl;

template <index_t... Is, typename... Xs>
79
struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<Is>, Xs>...
Chao Liu's avatar
Chao Liu committed
80
{
Chao Liu's avatar
Chao Liu committed
81
82
    __host__ __device__ constexpr TupleImpl() = default;

Chao Liu's avatar
Chao Liu committed
83
84
    template <typename Y,
              typename enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 &&
Chao Liu's avatar
Chao Liu committed
85
                                     !is_same<remove_cvref_t<Y>, TupleImpl>::value,
Chao Liu's avatar
Chao Liu committed
86
                                 bool>::type = false>
Chao Liu's avatar
Chao Liu committed
87
    __host__ __device__ constexpr TupleImpl(Y&& y)
88
        : TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
Chao Liu's avatar
Chao Liu committed
89
90
91
    {
    }

Chao Liu's avatar
Chao Liu committed
92
    template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
Chao Liu's avatar
Chao Liu committed
93
    __host__ __device__ constexpr TupleImpl(Ys&&... ys)
94
        : TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
Chao Liu's avatar
Chao Liu committed
95
    {
Chao Liu's avatar
Chao Liu committed
96
97
        static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys),
                      "wrong! inconsistent size");
Chao Liu's avatar
Chao Liu committed
98
99
100
101
102
    }

    __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }

    template <index_t I>
103
    __host__ __device__ constexpr const auto& GetElementDataByKey(TupleElementKey<I>) const
Chao Liu's avatar
Chao Liu committed
104
    {
105
        return get_tuple_element_data_reference<TupleElementKey<I>>(*this);
Chao Liu's avatar
Chao Liu committed
106
107
108
    }

    template <index_t I>
109
    __host__ __device__ constexpr auto& GetElementDataByKey(TupleElementKey<I>)
Chao Liu's avatar
Chao Liu committed
110
    {
111
        return get_tuple_element_data_reference<TupleElementKey<I>>(*this);
Chao Liu's avatar
Chao Liu committed
112
    }
Chao Liu's avatar
Chao Liu committed
113
114
};

Chao Liu's avatar
Chao Liu committed
115
116
117
118
} // namespace detail

// Fixed-size heterogeneous tuple usable from host and device code.
// Element I (0 <= I < sizeof...(Xs)) is stored in the detail::TupleImpl base under
// detail::TupleElementKey<I>. It is reached via At<I>(), At(Number<I>{}),
// operator[](Number<I>) (read) or operator()(Number<I>) (write).
template <typename... Xs>
struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(Xs), 1>::type, Xs...>
{
    using base =
        detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(Xs), 1>::type, Xs...>;

    __host__ __device__ constexpr Tuple() = default;

    // One-element construction from an arbitrary value. It is disabled when Y is this
    // Tuple type so copy/move construction keep using the implicit special members.
    template <typename Y,
              typename enable_if<sizeof...(Xs) == 1 && !is_same<remove_cvref_t<Y>, Tuple>::value,
                                 bool>::type = false>
    __host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y))
    {
    }

    // Element-wise construction: exactly one argument per element (two or more elements).
    template <typename... Ys,
              typename enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2, bool>::type =
                  false>
    __host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...)
    {
    }

    __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }

    // read access
    template <index_t I>
    __host__ __device__ constexpr const auto& At() const
    {
        // index_t is signed: reject negative indices here with a clear message rather
        // than an obscure base-lookup failure.
        static_assert(0 <= I && I < base::Size(), "wrong! out of range");
        return base::GetElementDataByKey(detail::TupleElementKey<I>{});
    }

    // write access
    template <index_t I>
    __host__ __device__ constexpr auto& At()
    {
        static_assert(0 <= I && I < base::Size(), "wrong! out of range");
        return base::GetElementDataByKey(detail::TupleElementKey<I>{});
    }

    // read access
    template <index_t I>
    __host__ __device__ constexpr const auto& At(Number<I>) const
    {
        static_assert(0 <= I && I < base::Size(), "wrong! out of range");
        return base::GetElementDataByKey(detail::TupleElementKey<I>{});
    }

    // write access
    template <index_t I>
    __host__ __device__ constexpr auto& At(Number<I>)
    {
        static_assert(0 <= I && I < base::Size(), "wrong! out of range");
        return base::GetElementDataByKey(detail::TupleElementKey<I>{});
    }

    // read access
    template <index_t I>
    __host__ __device__ constexpr const auto& operator[](Number<I> i) const
    {
        return At(i);
    }

    // write access
    template <index_t I>
    __host__ __device__ constexpr auto& operator()(Number<I> i)
    {
        return At(i);
    }

    // WARNING: needed by compiler for C++ structured binding support only, don't use this function!
    template <std::size_t I>
    __host__ __device__ constexpr const auto& get() const
    {
        return this->template At<I>();
    }

    // WARNING: needed by compiler for C++ structured binding support only, don't use this function!
    template <std::size_t I>
    __host__ __device__ constexpr auto& get()
    {
        return this->template At<I>();
    }

    // Element-wise assignment from any T with a static Size() equal to this tuple's and an
    // operator[](Number<i>). Returns *this by reference (as std::tuple::operator= does), so
    // chained or nested assignments act on this object instead of on a temporary copy.
    template <typename T>
    __host__ __device__ constexpr Tuple& operator=(const T& a)
    {
        static_assert(T::Size() == Size(), "wrong! size not the same");

        static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; });

        return *this;
    }

    // True when every element type, with cv/ref removed, satisfies is_static_v.
    __host__ __device__ static constexpr bool IsStatic()
    {
        bool flag = true;

        static_for<0, sizeof...(Xs), 1>{}([&flag](auto i) {
            flag &= is_static_v<remove_cvref_t<type_pack_element<i.value, Xs...>>>;
        });

        return flag;
    }

    // FIXME: remove
    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }

    // Prints "Tuple{size: N, data: [e0, e1, ...]}", formatting each element with print().
    __host__ __device__ void Print() const
    {
        printf("Tuple{size: %d, data: [", static_cast<index_t>(Size()));

        static_for<0, Size(), 1>{}([&](auto i) {
            print(At(i));

            if(i < Size() - 1)
            {
                printf(", ");
            }
        });

        printf("]}");
    }
};
241

242
243
244
245
246
247
248
249
250
251
252
253
254
// Empty-tuple specialization: no storage, Size() == 0, and assignment is a no-op.
template <>
struct Tuple<>
{
    __host__ __device__ constexpr Tuple() = default;

    __host__ __device__ static constexpr index_t Size() { return 0; }

    // Assigning from anything does nothing. Returns *this by reference, as
    // std::tuple::operator= does, instead of a by-value copy.
    template <typename T>
    __host__ __device__ constexpr Tuple& operator=(const T&)
    {
        return *this;
    }

    __host__ __device__ static constexpr bool IsStatic() { return true; }

    // FIXME: remove
    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }

    // Prints the empty tuple in the "Tuple{size: N, data: [...]}" format, so generic code
    // calling Print() also works for zero elements.
    __host__ __device__ void Print() const { printf("Tuple{size: 0, data: []}"); }
};

Chao Liu's avatar
Chao Liu committed
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
// Element-wise equality of two tuples with identical element types.
// Every element pair is compared; static_for visits all indices and never stops early.
template <typename... Xs>
__host__ __device__ constexpr bool operator==(const Tuple<Xs...>& a, const Tuple<Xs...>& b)
{
    bool all_equal = true;

    static_for<0, sizeof...(Xs), 1>{}([&](auto idx) {
        const bool differs = (a[idx] != b[idx]) ? true : false;
        all_equal          = differs ? false : all_equal;
    });

    return all_equal;
}

// Tuple inequality: the negation of element-wise operator==.
template <typename... Xs>
__host__ __device__ constexpr bool operator!=(const Tuple<Xs...>& a, const Tuple<Xs...>& b)
{
    const bool equal = (a == b);
    return equal == false;
}

282
283
284
template <index_t I, typename TTuple>
struct tuple_element
{
285
286
    // type should keep the cv/ref qualifier of original tuple element
    using type = decltype(detail::get_tuple_element_data<detail::TupleElementKey<I>>(TTuple{}));
287
288
289
290
291
};

template <index_t I, typename TTuple>
using tuple_element_t = typename tuple_element<I, TTuple>::type;

Chao Liu's avatar
Chao Liu committed
292
293
// Builds a Tuple whose element types are the argument types with cv/ref qualifiers
// removed, perfect-forwarding each argument into its element.
template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{
    using result_t = Tuple<remove_cvref_t<Xs>...>;
    return result_t(std::forward<Xs>(xs)...);
}

298
299
300
301
302
303
304
// https://en.cppreference.com/w/cpp/utility/tuple/tie
template <typename... Args>
constexpr Tuple<Args&...> tie(Args&... args) noexcept
{
    return {args...};
}

Chao Liu's avatar
Chao Liu committed
305
} // namespace ck
Chao Liu's avatar
Chao Liu committed
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321

namespace std {

// Structured-binding hook: reports the element count of a ck::Tuple.
// WARNING: needed by compiler for C++ structured binding support only, don't use this
template <typename... Elems>
struct tuple_size<ck::Tuple<Elems...>>
    : std::integral_constant<std::size_t, sizeof...(Elems)>
{
};

// Structured-binding hook: element type lookup, forwarded to ck::tuple_element.
// WARNING: needed by compiler for C++ structured binding support only, don't use this
template <std::size_t Idx, typename... Elems>
struct tuple_element<Idx, ck::Tuple<Elems...>>
    : ck::tuple_element<Idx, ck::Tuple<Elems...>>
{
};

} // namespace std