// tuple.hpp
#ifndef CK_TUPLE_HPP
#define CK_TUPLE_HPP

#include "integral_constant.hpp"
#include "type.hpp"
#include "sequence.hpp"

#include <type_traits>

namespace ck {

namespace detail {
Chao Liu's avatar
Chao Liu committed
11

Chao Liu's avatar
Chao Liu committed
12
13
14
15
template <index_t>
struct TupleElementKey
{
};
Chao Liu's avatar
Chao Liu committed
16

Chao Liu's avatar
Chao Liu committed
17
18
19
template <typename Key, typename Data>
struct TupleElement
{
Chao Liu's avatar
Chao Liu committed
20
21
    __host__ __device__ explicit constexpr TupleElement() : mData() {}

Chao Liu's avatar
Chao Liu committed
22
23
    template <typename T>
    __host__ __device__ explicit constexpr TupleElement(T&& v) : mData(static_cast<T&&>(v))
Chao Liu's avatar
Chao Liu committed
24
25
26
    {
    }

Chao Liu's avatar
Chao Liu committed
27
    Data mData;
Chao Liu's avatar
Chao Liu committed
28
29
};

Chao Liu's avatar
Chao Liu committed
30
31
template <typename Key, typename Data>
__host__ __device__ constexpr const Data& get_tuple_element(const TupleElement<Key, Data>& x)
Chao Liu's avatar
Chao Liu committed
32
{
Chao Liu's avatar
Chao Liu committed
33
34
    return x.mData;
}
Chao Liu's avatar
Chao Liu committed
35

Chao Liu's avatar
Chao Liu committed
36
37
template <typename Key, typename Data>
__host__ __device__ constexpr Data& get_tuple_element(TupleElement<Key, Data>& x)
Chao Liu's avatar
Chao Liu committed
38
{
Chao Liu's avatar
Chao Liu committed
39
40
    return x.mData;
}
Chao Liu's avatar
Chao Liu committed
41

Chao Liu's avatar
Chao Liu committed
42
43
template <typename Key, typename Data>
__host__ __device__ constexpr Data&& get_tuple_element(TupleElement<Key, Data>&& x)
Chao Liu's avatar
Chao Liu committed
44
{
Chao Liu's avatar
Chao Liu committed
45
46
    return static_cast<Data&&>(x.mData);
}
Chao Liu's avatar
Chao Liu committed
47

Chao Liu's avatar
Chao Liu committed
48
49
50
51
52
template <typename Indices, typename... Xs>
struct TupleImpl;

template <index_t... Is, typename... Xs>
struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>...
Chao Liu's avatar
Chao Liu committed
53
{
Chao Liu's avatar
Chao Liu committed
54
55
56
57
    __host__ __device__ explicit constexpr TupleImpl() : TupleElement<TupleElementKey<Is>, Xs>()...
    {
    }

Chao Liu's avatar
Chao Liu committed
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
    template <typename... Ys>
    __host__ __device__ explicit constexpr TupleImpl(Ys&&... ys)
        : TupleElement<TupleElementKey<Is>, Xs>(static_cast<Ys&&>(ys))...
    {
    }

    __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }

    template <index_t I>
    __host__ __device__ constexpr const auto& GetElementByKey(TupleElementKey<I>) const
    {
        return get_tuple_element<TupleElementKey<I>>(*this);
    }

    template <index_t I>
    __host__ __device__ constexpr auto& GetElementByKey(TupleElementKey<I>)
    {
        return get_tuple_element<TupleElementKey<I>>(*this);
    }
Chao Liu's avatar
Chao Liu committed
77
78
};

Chao Liu's avatar
Chao Liu committed
79
80
81
82
} // namespace detail

// Fixed-size heterogeneous tuple callable from __host__ and __device__ code.
// Element I has type Xs[I] and is accessed with At(Number<I>{}); storage is
// provided by detail::TupleImpl.
template <typename... Xs>
struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(Xs), 1>::type, Xs...>
{
    private:
    // IsSelfArg<Ys...>::value is true iff Ys is a single type that decays to this
    // very Tuple, i.e. the call has the shape of a copy or move of a Tuple.
    template <typename... Ys>
    struct IsSelfArg : std::false_type
    {
    };

    template <typename Y>
    struct IsSelfArg<Y> : std::is_same<typename std::decay<Y>::type, Tuple>
    {
    };

    public:
    using base =
        detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(Xs), 1>::type, Xs...>;

    // Element-wise constructor. With no arguments every element is
    // value-initialized; otherwise element k is initialized from ys[k],
    // forwarded with its original value category.
    //
    // The constructor only takes part in overload resolution when both of these hold:
    //  * zero arguments, or exactly one argument per element, are supplied; and
    //  * the call is not a copy/move from a Tuple of this same type.
    // Without the constraint, this template is an exact match for a non-const
    // lvalue Tuple and beats the implicit copy constructor (which takes const&).
    // `Tuple<int> b(a);` would then try to build the int element from a Tuple and
    // fail to compile. A wrong argument count would likewise only show up as a
    // pack-length mismatch deep inside TupleImpl.
    template <typename... Ys,
              typename std::enable_if<(sizeof...(Ys) == 0 || sizeof...(Ys) == sizeof...(Xs)) &&
                                          !IsSelfArg<Ys...>::value,
                                      bool>::type = false>
    __host__ __device__ explicit constexpr Tuple(Ys&&... ys) : base(static_cast<Ys&&>(ys)...)
    {
    }

    // Read-only access to element I; an out-of-range I is rejected at compile time.
    template <index_t I>
    __host__ __device__ constexpr const auto& At(Number<I>) const
    {
        static_assert(I < base::Size(), "wrong! out of range");
        return base::GetElementByKey(detail::TupleElementKey<I>{});
    }

    // Mutable access to element I; an out-of-range I is rejected at compile time.
    template <index_t I>
    __host__ __device__ constexpr auto& At(Number<I>)
    {
        static_assert(I < base::Size(), "wrong! out of range");
        return base::GetElementByKey(detail::TupleElementKey<I>{});
    }
};

Chao Liu's avatar
Chao Liu committed
107
108
109
110
111
112
113
114
template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{
    return Tuple<remove_cv_t<remove_reference_t<Xs>>...>(std::forward<Xs>(xs)...);
}

namespace detail {

Chao Liu's avatar
Chao Liu committed
115
template <typename F, typename X, index_t... Is>
116
__host__ __device__ constexpr auto transform_tuples_impl(F f, const X& x, Sequence<Is...>)
Chao Liu's avatar
Chao Liu committed
117
118
119
120
{
    return make_tuple(f(x.At(Number<Is>{}))...);
}

121
122
123
124
125
126
127
template <typename F, typename X, typename Y, index_t... Is>
__host__ __device__ constexpr auto
transform_tuples_impl(F f, const X& x, const Y& y, Sequence<Is...>)
{
    return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}))...);
}

Chao Liu's avatar
Chao Liu committed
128
129
} // namespace detail

Chao Liu's avatar
Chao Liu committed
130
template <typename F, typename X>
131
__host__ __device__ constexpr auto transform_tuples(F f, const X& x)
Chao Liu's avatar
Chao Liu committed
132
{
133
    return detail::transform_tuples_impl(
Chao Liu's avatar
Chao Liu committed
134
        f, x, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
Chao Liu's avatar
Chao Liu committed
135
136
}

137
138
139
140
141
142
143
template <typename F, typename X, typename Y>
__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y)
{
    return detail::transform_tuples_impl(
        f, x, y, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
}

Chao Liu's avatar
Chao Liu committed
144
145
} // namespace ck
#endif