"sims/net/vscode:/vscode.git/clone" did not exist on "32a74b8eb0183e6911e87292795f8a8968d214af"
common.hpp 6.51 KB
Newer Older
Po-Yen, Chen's avatar
Po-Yen, Chen committed
1
2
3
4
5
6
7
8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <exception>
#include <iostream>
#include <iterator>
#include <numeric>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
Po-Yen, Chen's avatar
Po-Yen, Chen committed
12
13
14
15

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
16
#include "ck/utility/type.hpp"
Po-Yen, Chen's avatar
Po-Yen, Chen committed
17
18
19
20
21
22

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

23
24
// Shorthand for ck::half_t.
// NOTE(review): presumably CK's 16-bit floating-point type -- confirm against the CK headers.
using F16 = ck::half_t;

Po-Yen, Chen's avatar
Po-Yen, Chen committed
25
26
27
28
29
30
31
32
33
34
35
36
// Run-time switches for the example; parse_cmd_args() overwrites them from arg1 / arg2.
struct ExecutionConfig final
{
    bool do_verification{true}; // verification is on unless the command line says otherwise
    bool time_kernel{false};    // kernel timing is off unless the command line says otherwise
};

// Problem description: the extents of a 4-D tensor and the axis order used to permute it.
// Both members can be overridden from the command line (see the usage text in parse_cmd_args).
struct Problem final
{
    std::array<std::size_t, 4> shape{4, 16, 32, 32}; // default extents
    std::array<std::size_t, 4> axes{0, 2, 3, 1};     // default axis order
};

37
38
39
40
41
// Shorthand for ck::Sequence holding a compile-time list of ck::index_t values.
template <ck::index_t... Values>
using S = ck::Sequence<Values...>;

// Shorthand for CK's PassThrough element-wise operation type.
// NOTE(review): presumably an identity functor (output = input) -- confirm in the CK headers.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

42
43
44
45
46
47
48
49
50
51
52
// Detection traits used to constrain the helper templates in this header.
// Every trait is SFINAE-friendly: an unsupported type yields `false`, never a hard error.
namespace detail {

// Primary template: T is not treated as an iterator unless the specialization below matches.
template <typename T, typename = void>
struct is_iterator : std::false_type
{
};

// Matches when a T value can be dereferenced and an lvalue of T supports both pre- and
// post-increment.
template <typename T>
struct is_iterator<T,
                   std::void_t<decltype(*std::declval<T>()),
                               decltype(++std::declval<std::add_lvalue_reference_t<T>>()),
                               decltype(std::declval<std::add_lvalue_reference_t<T>>()++)>>
    : std::true_type
{
};

template <typename T>
inline constexpr bool is_iterator_v = is_iterator<T>::value;

// Stand-in value that is implicitly convertible to any type. The conversion operator is only
// declared, never defined, so it may be used in unevaluated contexts (decltype) only.
struct Placeholder final
{
    template <typename T>
    constexpr inline operator T() const noexcept;
};

// Primary template: not an output iterator.
template <typename T, typename = void>
struct is_output_iterator : std::false_type
{
};

// Matches when `*it = value` is well-formed for a value convertible to anything; the result is
// true only if Iterator also satisfies is_iterator.
template <typename Iterator>
struct is_output_iterator<
    Iterator,
    std::void_t<decltype(*std::declval<Iterator>() = std::declval<Placeholder>())>>
    : std::bool_constant<is_iterator_v<Iterator>>
{
};

template <typename T>
inline constexpr bool is_output_iterator_v = is_output_iterator<T>::value;

// Primary template: not a random-access iterator.
template <typename Iterator, typename = void>
struct is_random_access_iterator : std::false_type
{
};

// Matches when `it + 1`, `it - 1` and `it[1]` are all well-formed; the result is true only if
// Iterator also satisfies is_iterator.
template <typename Iterator>
struct is_random_access_iterator<Iterator,
                                 std::void_t<decltype(std::declval<Iterator>() + 1),
                                             decltype(std::declval<Iterator>() - 1),
                                             decltype(std::declval<Iterator>()[1])>>
    : std::bool_constant<is_iterator_v<Iterator>>
{
};

template <typename Iterator>
inline constexpr bool is_random_access_iterator_v = is_random_access_iterator<Iterator>::value;

// Type produced by an unqualified begin() call (found through argument-dependent lookup) on a
// Range value, with reference and cv-qualifiers removed. Ill-formed (and therefore usable for
// SFINAE) when no such begin() exists.
template <typename Range>
using begin_iterator_t =
    std::remove_cv_t<std::remove_reference_t<decltype(begin(std::declval<Range>()))>>;

// Primary template: not a range.
template <typename T, typename = void>
struct is_range : std::false_type
{
};

// Matches when unqualified begin() and end() can be called on a T value; the result is true
// only if begin() returns something that satisfies is_iterator.
template <typename T>
struct is_range<T,
                std::void_t<decltype(begin(std::declval<T>())), decltype(end(std::declval<T>()))>>
    : std::bool_constant<is_iterator_v<begin_iterator_t<T>>>
{
};

template <typename T>
inline constexpr bool is_range_v = is_range<T>::value;

// Primary template: not a sized range.
template <typename Range, typename = void>
struct is_sized_range : std::false_type
{
};

// Matches when unqualified size() can be called on a Range value; the result is true only if
// Range also satisfies is_range.
template <typename Range>
struct is_sized_range<Range, std::void_t<decltype(size(std::declval<Range>()))>>
    : std::bool_constant<is_range_v<Range>>
{
};

template <typename Range>
inline constexpr bool is_sized_range_v = is_sized_range<Range>::value;

// Primary template: not a random-access range.
template <typename Range, typename = void>
struct is_random_access_range : std::false_type
{
};

// Matches when begin() is callable on a Range value. The begin() expression is named in the
// void_t argument (not only in the base clause) so that a type without begin() falls back to
// the primary template instead of triggering a hard compile error.
template <typename Range>
struct is_random_access_range<Range, std::void_t<begin_iterator_t<Range>>>
    : std::bool_constant<is_range_v<Range> &&
                         is_random_access_iterator_v<begin_iterator_t<Range>>>
{
};

template <typename Range>
inline constexpr bool is_random_access_range_v = is_random_access_range<Range>::value;

} // namespace detail

template <typename Axes>
148
inline std::enable_if_t<detail::is_random_access_range_v<Axes>, bool>
149
is_valid_axes(const Axes& axes)
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
{
    using std::empty;
    if(empty(axes))
    {
        return false;
    }

    using std::begin, std::end;
    std::vector<std::size_t> copy(begin(axes), end(axes));

    std::sort(begin(copy), end(copy));
    const auto last = std::unique(begin(copy), end(copy));

    return (last == end(copy)) && (*begin(copy) == 0) && (*std::prev(last) == size(axes) - 1);
}

166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
// Writes the entries of `shape` to `iter` in the order given by `axes`: the k-th value written
// is shape[axes[k]].
//
// shape:   random-access, sized range to read from.
// axes:    sized range of indices into `shape`; asserted (debug builds only) to have the same
//          length as `shape` and to pass is_valid_axes().
// iter:    output iterator receiving one value per element of `axes`.
// returns: the output iterator advanced past the last value written.
template <typename Shape, typename Axes, typename OutputIterator>
inline std::enable_if_t<detail::is_random_access_range_v<Shape> &&
                            detail::is_sized_range_v<Shape> && detail::is_sized_range_v<Axes> &&
                            detail::is_output_iterator_v<OutputIterator>,
                        OutputIterator>
transpose_shape(const Shape& shape, const Axes& axes, OutputIterator iter)
{
    using std::begin, std::end, std::size;
    assert(size(shape) == size(axes) && is_valid_axes(axes));

    // decltype(auto) forwards exactly what shape[axis] yields, as the plain assignment would.
    const auto pick_extent = [&shape](const auto axis) -> decltype(auto) { return shape[axis]; };

    return std::transform(begin(axes), end(axes), iter, pick_extent);
}

Po-Yen, Chen's avatar
Po-Yen, Chen committed
184
185
inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Problem& problem)
{
186
187
188
    constexpr int num_execution_config_args = 2;
    constexpr int num_problem_args          = 8;

189
    assert(num_problem_args == size(problem.shape) + size(problem.axes));
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 1 + num_execution_config_args)
    {
        config.do_verification = std::stoi(argv[1]);
        config.time_kernel     = std::stoi(argv[2]);
    }
    else if(argc == 1 + num_execution_config_args + num_problem_args)
    {
        config.do_verification = std::stoi(argv[1]);
        config.time_kernel     = std::stoi(argv[2]);

        // read shape
206
        for(std::size_t idx = 0; idx < size(problem.shape); ++idx)
207
208
209
210
211
        {
            problem.shape[idx] = std::stoi(argv[idx + 3]);
        }

        // read axes
212
213
214
215
216
217
        for(std::size_t idx = 0; idx < size(problem.axes); ++idx)
        {
            problem.axes[idx] = std::stoi(argv[idx + size(problem.shape) + 3]);
        }

        if(!is_valid_axes(problem.axes))
218
        {
219
220
221
222
223
224
            std::cerr << "invalid axes: ";
            std::copy(begin(problem.axes),
                      end(problem.axes),
                      std::ostream_iterator<std::size_t>(std::cerr, " "));
            std::cerr << std::endl;
            return false;
225
226
227
228
229
230
231
232
233
234
235
        }
    }
    else
    {
        std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
                  << "arg2: time kernel (0=no, 1=yes)" << std::endl
                  << "arg3 ~ arg6: shape for 4D tensor" << std::endl
                  << "arg7 ~ arg10: axes to permute" << std::endl;
        return false;
    }

Po-Yen, Chen's avatar
Po-Yen, Chen committed
236
237
    return true;
}