tuple.hpp 6.47 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
// SPDX-License-Identifier: MIT
arai713's avatar
arai713 committed
2
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
Chao Liu's avatar
Chao Liu committed
3

4
#pragma once
Chao Liu's avatar
Chao Liu committed
5

6
7
8
9
#include "ck/utility/integral_constant.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/enable_if.hpp"
Chao Liu's avatar
Chao Liu committed
10
11
12

namespace ck {

Chao Liu's avatar
Chao Liu committed
13
namespace detail {
Chao Liu's avatar
Chao Liu committed
14

Chao Liu's avatar
Chao Liu committed
15
16
17
template <index_t>
struct TupleElementKey
{
Chao Liu's avatar
Chao Liu committed
18
    __host__ __device__ constexpr TupleElementKey() = default;
Chao Liu's avatar
Chao Liu committed
19
};
Chao Liu's avatar
Chao Liu committed
20

Chao Liu's avatar
Chao Liu committed
21
template <typename Key, typename Data>
22
struct TupleElementKeyData
Chao Liu's avatar
Chao Liu committed
23
{
24
25
    using DataType = Data;

26
27
28
29
30
#if 0 // workaround compiler complaint about implicitly-deleted default constructor
    __host__ __device__ constexpr TupleElementKeyData() = default;
#else
    __host__ __device__ constexpr TupleElementKeyData() : mData{} {}
#endif
Chao Liu's avatar
Chao Liu committed
31

32
33
34
    template <typename T,
              typename enable_if<!is_same<remove_cvref_t<T>, TupleElementKeyData>::value,
                                 bool>::type = false>
arai713's avatar
arai713 committed
35
    __host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(ck::forward<T>(v))
Chao Liu's avatar
Chao Liu committed
36
37
38
    {
    }

39
    DataType mData;
Chao Liu's avatar
Chao Liu committed
40
41
};

42
// for read access of tuple element
Chao Liu's avatar
Chao Liu committed
43
template <typename Key, typename Data>
44
__host__ __device__ constexpr const Data&
45
get_tuple_element_data_reference(const TupleElementKeyData<Key, Data>& x)
Chao Liu's avatar
Chao Liu committed
46
{
Chao Liu's avatar
Chao Liu committed
47
    return static_cast<const Data&>(x.mData);
Chao Liu's avatar
Chao Liu committed
48
}
Chao Liu's avatar
Chao Liu committed
49

50
// for write access of tuple element
Chao Liu's avatar
Chao Liu committed
51
template <typename Key, typename Data>
52
53
__host__ __device__ constexpr Data&
get_tuple_element_data_reference(TupleElementKeyData<Key, Data>& x)
Chao Liu's avatar
Chao Liu committed
54
{
Chao Liu's avatar
Chao Liu committed
55
56
    return x.mData;
}
Chao Liu's avatar
Chao Liu committed
57

Chao Liu's avatar
Chao Liu committed
58
// TODO: not sure the use of reference is correct
Chao Liu's avatar
Chao Liu committed
59
template <typename Key, typename Data>
60
61
__host__ __device__ constexpr Data&&
get_tuple_element_data_reference(TupleElementKeyData<Key, Data>&& x)
Chao Liu's avatar
Chao Liu committed
62
{
Chao Liu's avatar
Chao Liu committed
63
64
    return static_cast<Data&&>(x.mData);
}
Chao Liu's avatar
Chao Liu committed
65

66
67
68
69
// For inferring the type of a tuple element.
// Key is given explicitly by the caller; Data is deduced from the matching
// TupleElementKeyData base of the argument. The declared return type is exactly
// Data, so `decltype(get_tuple_element_data<Key>(tuple))` names the element type
// as it was declared in the tuple.
template <typename Key, typename Data>
__host__ __device__ constexpr Data get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
{
    // An explicit cast to Data is valid for value, lvalue-reference and
    // rvalue-reference element types alike. A forward call without an explicit
    // template argument cannot be relied on here: assuming ck::forward mirrors
    // std::forward, its parameter is a non-deduced context, so the body would
    // fail to compile as soon as it was instantiated.
    return static_cast<Data>(x.mData);
}

Chao Liu's avatar
Chao Liu committed
73
74
75
76
// Primary template of the tuple storage, declared but not defined here.
//   Indices - type carrying the element indices.
//   Xs      - element types.
template <typename Indices, typename... Xs>
struct TupleImpl;

template <index_t... Is, typename... Xs>
77
struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<Is>, Xs>...
Chao Liu's avatar
Chao Liu committed
78
{
Chao Liu's avatar
Chao Liu committed
79
80
    __host__ __device__ constexpr TupleImpl() = default;

Chao Liu's avatar
Chao Liu committed
81
82
    template <typename Y,
              typename enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 &&
Chao Liu's avatar
Chao Liu committed
83
                                     !is_same<remove_cvref_t<Y>, TupleImpl>::value,
Chao Liu's avatar
Chao Liu committed
84
                                 bool>::type = false>
Chao Liu's avatar
Chao Liu committed
85
    __host__ __device__ constexpr TupleImpl(Y&& y)
arai713's avatar
arai713 committed
86
        : TupleElementKeyData<TupleElementKey<Is>, Xs>(ck::forward<Y>(y))...
Chao Liu's avatar
Chao Liu committed
87
88
89
    {
    }

Chao Liu's avatar
Chao Liu committed
90
    template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
Chao Liu's avatar
Chao Liu committed
91
    __host__ __device__ constexpr TupleImpl(Ys&&... ys)
arai713's avatar
arai713 committed
92
        : TupleElementKeyData<TupleElementKey<Is>, Xs>(ck::forward<Ys>(ys))...
Chao Liu's avatar
Chao Liu committed
93
    {
Chao Liu's avatar
Chao Liu committed
94
95
        static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys),
                      "wrong! inconsistent size");
Chao Liu's avatar
Chao Liu committed
96
97
98
99
100
    }

    __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }

    template <index_t I>
101
    __host__ __device__ constexpr const auto& GetElementDataByKey(TupleElementKey<I>) const
Chao Liu's avatar
Chao Liu committed
102
    {
103
        return get_tuple_element_data_reference<TupleElementKey<I>>(*this);
Chao Liu's avatar
Chao Liu committed
104
105
106
    }

    template <index_t I>
107
    __host__ __device__ constexpr auto& GetElementDataByKey(TupleElementKey<I>)
Chao Liu's avatar
Chao Liu committed
108
    {
109
        return get_tuple_element_data_reference<TupleElementKey<I>>(*this);
Chao Liu's avatar
Chao Liu committed
110
    }
Chao Liu's avatar
Chao Liu committed
111
112
};

Chao Liu's avatar
Chao Liu committed
113
114
115
116
} // namespace detail

// Heterogeneous container indexed at compile time with Number<I>.
// Storage comes from detail::TupleImpl instantiated with the index sequence
// produced by arithmetic_sequence_gen<0, sizeof...(Xs), 1> and the element types Xs.
template <typename... Xs>
struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(Xs), 1>::type, Xs...>
{
    using base =
        detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(Xs), 1>::type, Xs...>;

    __host__ __device__ constexpr Tuple() = default;

    // Construction of a one-element tuple from a single argument. Disabled when
    // the argument is a Tuple of this same type, so it does not compete with
    // copy/move construction.
    template <typename Arg,
              typename enable_if<sizeof...(Xs) == 1 && !is_same<remove_cvref_t<Arg>, Tuple>::value,
                                 bool>::type = false>
    __host__ __device__ constexpr Tuple(Arg&& arg) : base(ck::forward<Arg>(arg))
    {
    }

    // Element-wise construction; requires one argument per element and at
    // least two elements.
    template <typename... Args,
              typename enable_if<sizeof...(Args) == sizeof...(Xs) && sizeof...(Args) >= 2,
                                 bool>::type = false>
    __host__ __device__ constexpr Tuple(Args&&... args) : base(ck::forward<Args>(args)...)
    {
    }

    // Number of elements.
    __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }

    // read access: const reference to element I (I must be less than Size())
    template <index_t I>
    __host__ __device__ constexpr const auto& At(Number<I>) const
    {
        static_assert(I < Size(), "wrong! out of range");
        return this->GetElementDataByKey(detail::TupleElementKey<I>{});
    }

    // write access: mutable reference to element I (I must be less than Size())
    template <index_t I>
    __host__ __device__ constexpr auto& At(Number<I>)
    {
        static_assert(I < Size(), "wrong! out of range");
        return this->GetElementDataByKey(detail::TupleElementKey<I>{});
    }

    // read access: tuple[Number<I>{}] is the same as the const At
    template <index_t I>
    __host__ __device__ constexpr const auto& operator[](Number<I> idx) const
    {
        return this->At(idx);
    }

    // write access: tuple(Number<I>{}) is the same as the mutable At
    template <index_t I>
    __host__ __device__ constexpr auto& operator()(Number<I> idx)
    {
        return this->At(idx);
    }

    // Element-wise assignment from any type that provides a static Size() equal
    // to this tuple's and an operator[] accepting the index passed by static_for.
    // NOTE(review): the return type is deduced by plain `auto`, so this returns
    // a copy of *this rather than a reference.
    template <typename T>
    __host__ __device__ constexpr auto operator=(const T& rhs)
    {
        static_assert(T::Size() == Size(), "wrong! size not the same");

        static_for<0, Size(), 1>{}([&](auto idx) { (*this)(idx) = rhs[idx]; });

        return *this;
    }

    // Always reports true.
    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }

    // Always reports true.
    __host__ __device__ static constexpr bool IsTuple() { return true; }
};
183

184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
// Specialization for the empty tuple: holds no elements and has no element accessors.
// NOTE(review): unlike the primary template, this specialization does not
// declare IsTuple() - confirm whether that is intentional.
template <>
struct Tuple<>
{
    __host__ __device__ constexpr Tuple() = default;

    // An empty tuple has no elements.
    __host__ __device__ static constexpr index_t Size() { return 0; }

    // Always reports true.
    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }

    // Assignment from anything is a no-op; the argument is ignored.
    // The return type is deduced by plain `auto`, i.e. an (empty) Tuple<> by value.
    template <typename T>
    __host__ __device__ constexpr auto operator=(const T&)
    {
        return *this;
    }
};

template <index_t I, typename TTuple>
struct tuple_element
{
203
204
    // type should keep the cv/ref qualifier of original tuple element
    using type = decltype(detail::get_tuple_element_data<detail::TupleElementKey<I>>(TTuple{}));
205
206
207
208
209
};

// Shorthand for tuple_element<I, TTuple>::type.
template <index_t I, typename TTuple>
using tuple_element_t = typename tuple_element<I, TTuple>::type;

Chao Liu's avatar
Chao Liu committed
210
211
// Builds a Tuple from the given arguments.
// Each element type is the argument type with cv-qualifiers and references
// removed (remove_cvref_t), and each argument is perfect-forwarded into its element.
template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{
    using result_type = Tuple<remove_cvref_t<Xs>...>;

    return result_type(ck::forward<Xs>(xs)...);
}

216
217
218
219
220
221
222
// https://en.cppreference.com/w/cpp/utility/tuple/tie
template <typename... Args>
constexpr Tuple<Args&...> tie(Args&... args) noexcept
{
    return {args...};
}

Chao Liu's avatar
Chao Liu committed
223
} // namespace ck