tuple_helper.hpp 6.78 KB
Newer Older
aska-0096's avatar
aska-0096 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "functional4.hpp"
#include "tuple.hpp"

namespace ck {

// Builds a Tuple by calling f once for each index 0, 1, ..., N-1:
//     make_tuple(f(i_0), f(i_1), ..., f(i_{N-1}))
// The indices are the elements of arithmetic_sequence_gen<0, N, 1>::type, handed to f
// through unpack (presumably as Number<I> compile-time constants -- confirm against unpack).
// f is captured by reference and invoked as an lvalue; it is never copied.
template <typename F, index_t N>
__host__ __device__ constexpr auto generate_tuple(F&& f, Number<N>)
{
    using IndexSeq = typename arithmetic_sequence_gen<0, N, 1>::type;

    auto apply_to_all = [&f](auto&&... idx) { return make_tuple(f(idx)...); };

    return unpack(apply_to_all, IndexSeq{});
}

// Same index walk as generate_tuple, but the results are combined with tie(...) instead
// of make_tuple(...):
//     tie(f(i_0), f(i_1), ..., f(i_{N-1}))
// NOTE(review): presumably tie builds a Tuple of references (as std::tie does), so f
// should return lvalues that outlive the result -- confirm against tie's definition.
template <typename F, index_t N>
__host__ __device__ constexpr auto generate_tie(F&& f, Number<N>)
{
    using IndexSeq = typename arithmetic_sequence_gen<0, N, 1>::type;

    auto tie_all = [&f](auto&&... idx) { return tie(f(idx)...); };

    return unpack(tie_all, IndexSeq{});
}

// Concatenates two tuples of lvalue references (tx's elements first, then ty's).
// The result is again a tuple of references, not of values or rvalue references: each
// element type is decltype of the forwarded element, so (assuming unpack2 passes the
// elements as lvalues) the result refers to the same objects tx and ty refer to.
template <typename... X, typename... Y>
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>& tx,
                                                             const Tuple<Y&...>& ty)
{
    auto join_refs = [](auto&&... refs) {
        return Tuple<decltype(refs)...>{std::forward<decltype(refs)>(refs)...};
    };

    return unpack2(join_refs, tx, ty);
}

namespace detail {

// Unary worker: the result's k-th element is fn applied to the element of `xs` at the
// k-th index in Is... (the Sequence argument only carries the indices).
template <typename F, typename X, index_t... Is>
__host__ __device__ constexpr auto transform_tuples_impl(F fn, const X& xs, Sequence<Is...>)
{
    return make_tuple(fn(xs.At(Number<Is>{}))...);
}

// Binary worker: pairs up elements of `xs` and `ys` at the same index and applies fn.
template <typename F, typename X, typename Y, index_t... Is>
__host__ __device__ constexpr auto
transform_tuples_impl(F fn, const X& xs, const Y& ys, Sequence<Is...>)
{
    return make_tuple(fn(xs.At(Number<Is>{}), ys.At(Number<Is>{}))...);
}

// Ternary worker: same as the binary case with a third input `zs`.
template <typename F, typename X, typename Y, typename Z, index_t... Is>
__host__ __device__ constexpr auto
transform_tuples_impl(F fn, const X& xs, const Y& ys, const Z& zs, Sequence<Is...>)
{
    return make_tuple(fn(xs.At(Number<Is>{}), ys.At(Number<Is>{}), zs.At(Number<Is>{}))...);
}

} // namespace detail

// Maps f over every element of x and returns the results as a new Tuple:
//     result element I = f(x.At(Number<I>{}))  for I in [0, X::Size())
// f is taken by value and passed on by value (copied) to the detail worker.
template <typename F, typename X>
__host__ __device__ constexpr auto transform_tuples(F f, const X& x)
{
    using Indices = typename arithmetic_sequence_gen<0, X::Size(), 1>::type;

    return detail::transform_tuples_impl(f, x, Indices{});
}

// Zips x and y element-wise through f:
//     result element I = f(x.At(Number<I>{}), y.At(Number<I>{}))  for I in [0, X::Size())
// The index range comes from X alone. y must provide at least X::Size() elements, and
// any extra trailing elements of y are ignored.
template <typename F, typename X, typename Y>
__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y)
{
    using Indices = typename arithmetic_sequence_gen<0, X::Size(), 1>::type;

    return detail::transform_tuples_impl(f, x, y, Indices{});
}

// Three-way element-wise map:
//     result element I = f(x.At(Number<I>{}), y.At(Number<I>{}), z.At(Number<I>{}))
// for I in [0, X::Size()). As in the binary overload, X alone fixes the length: y and z
// must have at least that many elements, and their extra elements are ignored.
template <typename F, typename X, typename Y, typename Z>
__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
{
    using Indices = typename arithmetic_sequence_gen<0, X::Size(), 1>::type;

    return detail::transform_tuples_impl(f, x, y, z, Indices{});
}

} // namespace ck

// Macro function
// TO_TUPLE_OF_NUMBER(arr, n): convert the first n entries of a constexpr Array into a
// Tuple of compile-time Numbers, i.e.
//     ck::Tuple<ck::Number<arr[0]>, ck::Number<arr[1]>, ..., ck::Number<arr[n - 1]>>{}
//
// Requirements:
//   - `arr` must be usable in a constant expression and provide Size() and operator[].
//   - `n` must be an integral constant expression with n <= arr.Size() and n < 7.
//     Only sizes 0..6 are implemented; both limits are enforced by static_assert.
//   - The expansion is an immediately-invoked lambda with a capture-default, so it must
//     appear at block scope (e.g. inside a function body).
//
// The capture is [&] rather than an explicit [&arr, &n]. An explicit capture list only
// accepts names of local variables, so it rejected a literal size (e.g. 3) and arrays
// that are namespace-scope variables or static data members (e.g. Foo::kLens).
// Number is spelled ck::Number (like ck::Tuple) because the macro is defined outside
// namespace ck and may be expanded where Number is not visible unqualified.
//
// NOTE(review): the lambda is not a template. Where `arr` is not dependent, the untaken
// `if constexpr` branches are still checked, and arr[k] with k >= arr.Size() in such a
// branch may fail to compile. A generate_tuple-based formulation may avoid both this
// and the n < 7 limit -- confirm usage before changing.
#define TO_TUPLE_OF_NUMBER(arr, n)                                                              \
    [&] {                                                                                       \
        static_assert((arr).Size() >= (n), "wrong! out of bound");                              \
                                                                                                \
        static_assert((n) < 7, "not implemented");                                              \
                                                                                                \
        if constexpr((n) == 0)                                                                  \
        {                                                                                       \
            return ck::Tuple<>{};                                                               \
        }                                                                                       \
        else if constexpr((n) == 1)                                                             \
        {                                                                                       \
            return ck::Tuple<ck::Number<(arr)[0]>>{};                                           \
        }                                                                                       \
        else if constexpr((n) == 2)                                                             \
        {                                                                                       \
            return ck::Tuple<ck::Number<(arr)[0]>, ck::Number<(arr)[1]>>{};                     \
        }                                                                                       \
        else if constexpr((n) == 3)                                                             \
        {                                                                                       \
            return ck::Tuple<ck::Number<(arr)[0]>,                                              \
                             ck::Number<(arr)[1]>,                                              \
                             ck::Number<(arr)[2]>>{};                                           \
        }                                                                                       \
        else if constexpr((n) == 4)                                                             \
        {                                                                                       \
            return ck::Tuple<ck::Number<(arr)[0]>,                                              \
                             ck::Number<(arr)[1]>,                                              \
                             ck::Number<(arr)[2]>,                                              \
                             ck::Number<(arr)[3]>>{};                                           \
        }                                                                                       \
        else if constexpr((n) == 5)                                                             \
        {                                                                                       \
            return ck::Tuple<ck::Number<(arr)[0]>,                                              \
                             ck::Number<(arr)[1]>,                                              \
                             ck::Number<(arr)[2]>,                                              \
                             ck::Number<(arr)[3]>,                                              \
                             ck::Number<(arr)[4]>>{};                                           \
        }                                                                                       \
        else if constexpr((n) == 6)                                                             \
        {                                                                                       \
            return ck::Tuple<ck::Number<(arr)[0]>,                                              \
                             ck::Number<(arr)[1]>,                                              \
                             ck::Number<(arr)[2]>,                                              \
                             ck::Number<(arr)[3]>,                                              \
                             ck::Number<(arr)[4]>,                                              \
                             ck::Number<(arr)[5]>>{};                                           \
        }                                                                                       \
    }()