common.hpp 12.6 KB
Newer Older
Po-Yen, Chen's avatar
Po-Yen, Chen committed
1
2
3
4
5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

6
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <iostream>
#include <iterator>
#include <numeric>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_permute.hpp"
#include "ck/utility/type.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

25
// Short names for the floating-point element types referred to in this file.
using F16 = ck::half_t;
using F32 = float;
using F64 = double;
28

Po-Yen, Chen's avatar
Po-Yen, Chen committed
29
30
31
32
33
34
35
36
// Execution switches; parse_cmd_args() overwrites them from the command line.
struct ExecutionConfig final
{
    // Verification flag, on unless overridden.
    bool do_verification{true};
    // Kernel-timing flag, off unless overridden.
    bool time_kernel{false};
};

// One permute problem: the lengths of the tensor and the order of its axes.
struct Problem final
{
    // Number of dimensions described by `shape` and `axes`.
    static constexpr std::size_t NumDim = 3;

    using Shape = std::array<std::size_t, NumDim>;
    using Axes  = Shape;

    // A Problem always starts from explicitly supplied defaults.
    Problem() = delete;

    // Stores copies of the given default shape and axes.
    explicit Problem(const Shape& default_shape, const Axes& default_axes)
        : shape(default_shape), axes(default_axes)
    {
    }

    Shape shape;
    Axes axes;
};

53
54
55
56
57
// Shorthand for ck::Sequence instantiated with the given ck::index_t values.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

// Shorthand for the PassThrough type from ck::tensor_operation::element_wise.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

58
59
namespace detail {

60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
template <typename Bundle, std::size_t Divisor>
struct get_bundled;

template <typename Bundle>
struct get_bundled<Bundle, 1>
{
    using type = Bundle;
};

template <>
struct get_bundled<F64, 2>
{
    using type = F32;
};

template <>
struct get_bundled<F64, 4>
{
    using type = F16;
};

template <>
struct get_bundled<F32, 2>
{
    using type = F16;
};

template <typename Bundle, std::size_t Divisor>
using get_bundled_t = typename get_bundled<Bundle, Divisor>::type;

90
91
92
93
94
95
96
97
98
// Detects iterator-like types. The primary template answers false.
template <typename T, typename = void>
struct is_iterator : std::false_type
{
};

// Chosen when a T can be dereferenced and an lvalue of T supports both pre-increment and
// post-increment; in that case the trait is true.
template <typename T>
struct is_iterator<T,
                   std::void_t<decltype(*std::declval<T>()),
                               decltype(++std::declval<std::add_lvalue_reference_t<T>>()),
                               decltype(std::declval<std::add_lvalue_reference_t<T>>()++)>>
    : std::true_type
{
};

// Value shorthand for is_iterator<T>::value.
template <typename T>
inline constexpr bool is_iterator_v = is_iterator<T>::value;

107
108
109
110
111
112
// Stand-in object that declares a conversion to every type. The conversion operator is
// declared but not defined here, so it is only suitable for unevaluated contexts such as
// decltype-based detection.
struct Placeholder final
{
    template <typename Target>
    constexpr operator Target() const noexcept;
};

113
// Detects iterators that can be written through. The primary template answers false.
template <typename Iterator, typename = void>
struct is_output_iterator : std::false_type
{
};

// Chosen when the result of dereferencing an Iterator can be assigned from a Placeholder
// (a value convertible to anything); the trait is then true exactly when Iterator also
// satisfies is_iterator.
// NOTE(review): iterators whose dereference result has several assignment overloads
// (std::back_insert_iterator, for example) may be rejected, because the conversion from
// Placeholder can become ambiguous -- confirm before relying on the trait for them.
template <typename Iterator>
struct is_output_iterator<
    Iterator,
    std::void_t<decltype(*std::declval<Iterator>() = std::declval<Placeholder>())>>
    : std::bool_constant<is_iterator_v<Iterator>>
{
};

// Value shorthand for is_output_iterator<T>::value.
template <typename T>
inline constexpr bool is_output_iterator_v = is_output_iterator<T>::value;

129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
template <typename Iterator, typename = void>
struct is_bidirectional_iterator : std::false_type
{
};

template <typename Iterator>
struct is_bidirectional_iterator<
    Iterator,
    std::void_t<decltype(--std::declval<std::add_lvalue_reference_t<Iterator>>()),
                decltype(std::declval<std::add_lvalue_reference_t<Iterator>>()--)>>
    : std::bool_constant<is_iterator_v<Iterator>>
{
};

template <typename Iterator>
inline constexpr bool is_bidirectional_iterator_v = is_bidirectional_iterator<Iterator>::value;

146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
template <typename Iterator, typename = void>
struct is_random_access_iterator : std::false_type
{
};

template <typename Iterator>
struct is_random_access_iterator<Iterator,
                                 std::void_t<decltype(std::declval<Iterator>() + 1),
                                             decltype(std::declval<Iterator>() - 1),
                                             decltype(std::declval<Iterator>()[1])>>
    : std::bool_constant<is_iterator_v<Iterator>>
{
};

template <typename Iterator>
inline constexpr bool is_random_access_iterator_v = is_random_access_iterator<Iterator>::value;

template <typename T, typename = void>
struct is_range : std::false_type
{
};

template <typename T>
struct is_range<T,
                std::void_t<decltype(begin(std::declval<T>())), decltype(end(std::declval<T>()))>>
    : std::bool_constant<is_iterator_v<ck::remove_cvref_t<decltype(begin(std::declval<T>()))>>>
{
};

template <typename T>
inline constexpr bool is_range_v = is_range<T>::value;

178
179
180
181
182
183
184
185
186
187
188
189
190
191
template <typename Range, typename = void>
struct is_sized_range : std::false_type
{
};

template <typename Range>
struct is_sized_range<Range, std::void_t<decltype(size(std::declval<Range>()))>>
    : std::bool_constant<is_range_v<Range>>
{
};

template <typename Range>
inline constexpr bool is_sized_range_v = is_sized_range<Range>::value;

192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
// Detects ranges whose iterators can step backwards. Types for which begin(x) is not
// callable fall back to this primary template and yield false.
template <typename Range, typename = void>
struct is_bidirectional_range : std::false_type
{
};

// Selected only when unqualified begin(x) is well-formed for a value of type Range. The
// check lives in the std::void_t argument so that a non-range type is rejected by SFINAE;
// with an empty std::void_t<> this specialization would match every type and the
// begin() call in the base clause would be a hard compile error for non-ranges.
// The result is true when Range satisfies is_range and its begin() iterator type (cv and
// reference qualifiers removed) satisfies is_bidirectional_iterator.
template <typename Range>
struct is_bidirectional_range<Range, std::void_t<decltype(begin(std::declval<Range>()))>>
    : std::bool_constant<
          is_range_v<Range> &&
          is_bidirectional_iterator_v<ck::remove_cvref_t<decltype(begin(std::declval<Range>()))>>>
{
};

// Value shorthand for is_bidirectional_range<Range>::value.
template <typename Range>
inline constexpr bool is_bidirectional_range_v = is_bidirectional_range<Range>::value;

208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
// Detects ranges whose iterators support random access. Types for which begin(x) is not
// callable fall back to this primary template and yield false.
template <typename Range, typename = void>
struct is_random_access_range : std::false_type
{
};

// Selected only when unqualified begin(x) is well-formed for a value of type Range. The
// check lives in the std::void_t argument so that a non-range type is rejected by SFINAE;
// with an empty std::void_t<> this specialization would match every type and the
// begin() call in the base clause would be a hard compile error for non-ranges.
// The result is true when Range satisfies is_range and its begin() iterator type (cv and
// reference qualifiers removed) satisfies is_random_access_iterator.
template <typename Range>
struct is_random_access_range<Range, std::void_t<decltype(begin(std::declval<Range>()))>>
    : std::bool_constant<
          is_range_v<Range> &&
          is_random_access_iterator_v<ck::remove_cvref_t<decltype(begin(std::declval<Range>()))>>>
{
};

// Value shorthand for is_random_access_range<Range>::value.
template <typename Range>
inline constexpr bool is_random_access_range_v = is_random_access_range<Range>::value;

} // namespace detail

226
227
228
229
230
231
// Free-function form of the member front(): forwards `range` with its original value
// category and returns exactly what its front() member returns.
template <typename Range>
auto front(Range&& range) -> decltype(static_cast<Range&&>(range).front())
{
    return static_cast<Range&&>(range).front();
}

232
template <typename Axes>
233
inline std::enable_if_t<detail::is_random_access_range_v<Axes>, bool>
234
is_valid_axes(const Axes& axes)
235
236
237
238
239
240
241
242
{
    using std::empty;
    if(empty(axes))
    {
        return false;
    }

    using std::begin, std::end;
243
    std::vector<std::size_t> sorted_axes(begin(axes), end(axes));
244

245
246
    std::sort(begin(sorted_axes), end(sorted_axes));
    const auto last = std::unique(begin(sorted_axes), end(sorted_axes));
247

248
249
    return (last == end(sorted_axes)) && (*begin(sorted_axes) == 0) &&
           (*std::prev(last) == size(axes) - 1);
250
251
}

Po-Yen, Chen's avatar
Po-Yen, Chen committed
252
253
inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Problem& problem)
{
254
    constexpr int num_execution_config_args = 2;
255
    constexpr int num_problem_args          = 2 * Problem::NumDim;
256

Po-Yen, Chen's avatar
Po-Yen, Chen committed
257
258
259
260
    if(!(num_problem_args == size(problem.shape) + size(problem.axes)))
    {
        return false;
    }
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 1 + num_execution_config_args)
    {
        config.do_verification = std::stoi(argv[1]);
        config.time_kernel     = std::stoi(argv[2]);
    }
    else if(argc == 1 + num_execution_config_args + num_problem_args)
    {
        config.do_verification = std::stoi(argv[1]);
        config.time_kernel     = std::stoi(argv[2]);

        // read shape
277
        for(std::size_t idx = 0; idx < size(problem.shape); ++idx)
278
        {
279
            problem.shape[idx] = std::stoi(argv[idx + (1 + num_execution_config_args)]);
280
281
282
        }

        // read axes
283
284
        for(std::size_t idx = 0; idx < size(problem.axes); ++idx)
        {
285
286
            problem.axes[idx] =
                std::stoi(argv[idx + (1 + num_execution_config_args + size(problem.shape))]);
287
288
289
        }

        if(!is_valid_axes(problem.axes))
290
        {
291
292
293
294
295
296
            std::cerr << "invalid axes: ";
            std::copy(begin(problem.axes),
                      end(problem.axes),
                      std::ostream_iterator<std::size_t>(std::cerr, " "));
            std::cerr << std::endl;
            return false;
297
298
299
300
301
302
        }
    }
    else
    {
        std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
                  << "arg2: time kernel (0=no, 1=yes)" << std::endl
303
304
                  << "arg3 ~ arg5: shape for 3D tensor" << std::endl
                  << "arg6 ~ arg8: axes to permute" << std::endl;
305
306
307
        return false;
    }

Po-Yen, Chen's avatar
Po-Yen, Chen committed
308
309
    return true;
}
310
311
312
313
314
315
316
317
318
319
320
321
322

// Reports whether `shape` is non-empty and every one of its entries is greater than zero.
// Only participates in overload resolution when Shape satisfies detail::is_range_v.
template <typename Shape>
inline std::enable_if_t<detail::is_range_v<Shape>, bool> is_valid_shape(const Shape& shape)
{
    using std::begin, std::end;
    using std::empty;

    if(empty(shape))
    {
        return false;
    }

    const auto stop = end(shape);
    for(auto it = begin(shape); it != stop; ++it)
    {
        const auto dim = *it;
        if(!(0 < dim))
        {
            return false;
        }
    }

    return true;
}

// Reports whether `indices` addresses an element inside a tensor of the given `shape`.
//
// Returns false when `shape` is not a valid shape, when `indices` is empty, when the two
// ranges differ in length, or when any index is not strictly smaller than the matching
// dimension. Only participates in overload resolution for sized ranges.
template <typename Shape, typename Indices>
inline std::enable_if_t<detail::is_sized_range_v<Shape> && detail::is_sized_range_v<Indices>, bool>
is_valid_indices(const Shape& shape, const Indices& indices)
{
    if(!is_valid_shape(shape))
    {
        return false;
    }

    using std::empty;
    if(empty(indices))
    {
        return false;
    }

    using std::size;
    if(size(shape) != size(indices))
    {
        return false;
    }

    using std::begin, std::end;

    // Walk both ranges in lock step; every index must stay below its dimension.
    auto dim = begin(shape);
    auto idx = begin(indices);
    for(; dim != end(shape) && idx != end(indices); ++dim, ++idx)
    {
        if(*dim <= *idx)
        {
            return false;
        }
    }

    return true;
}

// Writes the permuted shape to `iter`: for each entry `axis` of `axes`, in order,
// shape[axis] is assigned through the iterator.
//
// The preconditions (equal lengths, valid shape, valid axes) are checked with assert
// only, so they are not enforced when NDEBUG is defined.
// Returns the iterator positioned after the last element written.
template <typename Shape, typename Axes, typename OutputIterator>
inline std::enable_if_t<detail::is_random_access_range_v<Shape> &&
                            detail::is_sized_range_v<Shape> && detail::is_sized_range_v<Axes> &&
                            detail::is_output_iterator_v<OutputIterator>,
                        OutputIterator>
transpose_shape(const Shape& shape, const Axes& axes, OutputIterator iter)
{
    using std::size;
    assert(size(shape) == size(axes));
    assert(is_valid_shape(shape) && is_valid_axes(axes));

    for(const auto axis : axes)
    {
        *iter++ = shape[axis];
    }

    return iter;
}

// Advances `indices` to the next position inside `shape`, treating the last entry as the
// fastest-varying one.
//
// Returns true when `indices` now holds a new position. Returns false when the inputs are
// not valid (see is_valid_shape / is_valid_indices) or when the position wrapped around;
// after a wrap-around every entry of `indices` is 0.
template <typename Shape, typename Indices>
std::enable_if_t<detail::is_bidirectional_range_v<Shape> && detail::is_sized_range_v<Shape> &&
                     detail::is_bidirectional_range_v<Indices> && detail::is_sized_range_v<Indices>,
                 bool>
advance_indices(const Shape& shape, Indices& indices)
{
    using std::size;
    if(!(is_valid_shape(shape) && is_valid_indices(shape, indices) && size(shape) == size(indices)))
    {
        return false;
    }

    // A carry of one enters at the last dimension and moves towards the first.
    bool carry = true;

    using std::rbegin, std::rend;
    auto dim = rbegin(shape);
    auto idx = rbegin(indices);
    for(; carry && dim != rend(shape) && idx != rend(indices); ++dim, ++idx)
    {
        assert(*idx < *dim);

        *idx = (*idx + carry);
        // Reaching the dimension's length resets this entry and passes the carry on.
        carry = ((*idx == *dim) ? (*idx = 0, true) : false);
    }

    // A carry that survives the first dimension means the whole index space was covered.
    return !carry;
}

Po-Yen, Chen's avatar
Po-Yen, Chen committed
402
403
404
405
406
407
template <typename Src, typename Axes, typename Functor, typename Dest>
std::enable_if_t<detail::is_random_access_range_v<Axes> && detail::is_sized_range_v<Axes> &&
                     std::is_invocable_v<Functor,
                                         std::add_lvalue_reference_t<Dest>,
                                         std::add_lvalue_reference_t<Src>>,
                 bool>
408
host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<Dest>& dest)
409
410
411
{
    const auto& shape            = src.mDesc.GetLengths();
    const auto& transposed_shape = dest.mDesc.GetLengths();
Po-Yen, Chen's avatar
Po-Yen, Chen committed
412
413
414
415
416
417
    if(!(is_valid_shape(shape) && is_valid_shape(transposed_shape)))
    {
        return false;
    }

    using std::size;
418
    if(!is_valid_axes(axes))
Po-Yen, Chen's avatar
Po-Yen, Chen committed
419
420
421
    {
        return false;
    }
422
423
424
425

    static_assert(detail::is_sized_range_v<ck::remove_cvref_t<decltype(shape)>> &&
                  detail::is_sized_range_v<ck::remove_cvref_t<decltype(transposed_shape)>>);

426
    if(size(shape) != size(transposed_shape))
Po-Yen, Chen's avatar
Po-Yen, Chen committed
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
    {
        return false;
    }

    static_assert(detail::is_random_access_range_v<ck::remove_cvref_t<decltype(shape)>> &&
                  detail::is_random_access_range_v<ck::remove_cvref_t<decltype(transposed_shape)>>);
    {
        for(std::size_t idx = 0; idx < size(shape); ++idx)
        {
            if(transposed_shape[idx] != shape[axes[idx]])
            {
                return false;
            }
        }
    }
442

443
    std::vector<std::size_t> indices(size(shape), 0);
Po-Yen, Chen's avatar
Po-Yen, Chen committed
444
445
446
447
    if(!is_valid_indices(shape, indices))
    {
        return false;
    }
448

449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
    if(size(shape) == 3)
    {
        do
        {
            Dest output = 0;
            functor(output, src(indices[0], indices[1], indices[2]));
            dest(indices[axes[0]], indices[axes[1]], indices[axes[2]]) = output;
        } while(advance_indices(shape, indices));
    }
    else if(size(shape) == 4)
    {
        do
        {
            Dest output = 0;
            functor(output, src(indices[0], indices[1], indices[2], indices[3]));
            dest(indices[axes[0]], indices[axes[1]], indices[axes[2]], indices[axes[3]]) = output;
        } while(advance_indices(shape, indices));
    }
    else
468
    {
469
470
        return false;
    }
Po-Yen, Chen's avatar
Po-Yen, Chen committed
471
472

    return true;
473
}