common.hpp 13.9 KB
Newer Older
Po-Yen, Chen's avatar
Po-Yen, Chen committed
1
2
3
4
5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

6
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <iterator>
#include <numeric>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_permute.hpp"
#include "ck/utility/type.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

26
// Short name for ck::half_t.
using F16 = ck::half_t;
27
// Short name for float.
using F32 = float;
28
// Short name for double.
using F64 = double;
29

Po-Yen, Chen's avatar
Po-Yen, Chen committed
30
31
32
33
34
35
36
37
// Run-time switches filled in by parse_cmd_args() from the first two
// command-line arguments.
struct ExecutionConfig final
{
    // arg1: verification (0=no, 1=yes); enabled unless overridden.
    bool do_verification{true};
    // arg2: time kernel (0=no, 1=yes); disabled unless overridden.
    bool time_kernel{false};
};

// One permutation problem: the lengths of a NumDim-dimensional tensor plus the
// axis order to apply to it.
struct Problem final
{
    // Number of dimensions stored in a problem description.
    static constexpr std::size_t NumDim = 3;

    // Per-dimension lengths.
    using Shape = std::array<std::size_t, NumDim>;
    // Axis order; stored with the same representation as Shape.
    using Axes = Shape;

    // A problem always has to be created from an explicit shape and axes.
    Problem() = delete;

    explicit Problem(const Shape& default_shape, const Axes& default_axes)
        : shape{default_shape}, axes{default_axes}
    {
    }

    Shape shape;
    Axes axes;
};

54
55
56
57
58
// Shorthand for spelling a ck::Sequence of compile-time indices.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

// Short local name for ck's element-wise PassThrough operation type.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

59
60
namespace detail {

61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
template <typename Bundle, std::size_t Divisor>
struct get_bundled;

template <typename Bundle>
struct get_bundled<Bundle, 1>
{
    using type = Bundle;
};

template <>
struct get_bundled<F64, 2>
{
    using type = F32;
};

template <>
struct get_bundled<F64, 4>
{
    using type = F16;
};

template <>
struct get_bundled<F32, 2>
{
    using type = F16;
};

template <typename Bundle, std::size_t Divisor>
using get_bundled_t = typename get_bundled<Bundle, Divisor>::type;

91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
// Yields the std::array type that has the same element type as `Array` but
// `Difference` additional elements. Only defined when `Array` is a std::array.
template <typename Array, std::size_t Difference>
struct enlarge_array_size;

template <typename Element, std::size_t Length, std::size_t Extra>
struct enlarge_array_size<std::array<Element, Length>, Extra>
{
    using type = std::array<Element, Length + Extra>;
};

// Convenience alias for enlarge_array_size<...>::type.
template <typename Array, std::size_t Difference>
using enlarge_array_size_t = typename enlarge_array_size<Array, Difference>::type;

// Exposes the element count of a std::array type as a compile-time constant.
// Only defined when `Array` is a std::array.
template <typename Array>
struct get_array_size;

template <typename Element, std::size_t Length>
struct get_array_size<std::array<Element, Length>>
    : std::integral_constant<std::size_t, Length>
{
};

// Convenience variable for get_array_size<Array>::value.
template <typename Array>
inline constexpr std::size_t get_array_size_v = get_array_size<Array>::value;

114
115
116
117
118
119
120
121
122
// Detects iterator-like types: true when a value of T can be dereferenced and
// an lvalue of T supports both pre- and post-increment; false otherwise.
template <typename T, typename = void>
struct is_iterator : std::false_type
{
};

template <typename T>
struct is_iterator<T,
                   std::void_t<decltype(*std::declval<T>()),
                               decltype(++std::declval<T&>()),
                               decltype(std::declval<T&>()++)>> : std::true_type
{
};

// Convenience variable for is_iterator<T>::value.
template <typename T>
inline constexpr bool is_iterator_v = is_iterator<T>::value;

131
132
133
134
135
136
// Stand-in value that declares a conversion to every type. The conversion is
// only declared, never defined, so it is usable in unevaluated contexts only
// (e.g. inside decltype when probing whether an assignment is well-formed).
struct Placeholder final
{
    template <typename Target>
    constexpr operator Target() const noexcept;
};

137
// Detects iterators that can be written through: the result of dereferencing
// must accept assignment from a Placeholder, and the type must also satisfy
// is_iterator.
template <typename Iterator, typename = void>
struct is_output_iterator : std::false_type
{
};

template <typename Iterator>
struct is_output_iterator<
    Iterator,
    std::void_t<decltype(*std::declval<Iterator>() = std::declval<Placeholder>())>>
    : std::integral_constant<bool, is_iterator_v<Iterator>>
{
};

// Convenience variable for is_output_iterator<T>::value.
template <typename T>
inline constexpr bool is_output_iterator_v = is_output_iterator<T>::value;

153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
// Detects iterators that can step backwards: an lvalue of the type must
// support pre- and post-decrement, and the type must also satisfy is_iterator.
template <typename Iterator, typename = void>
struct is_bidirectional_iterator : std::false_type
{
};

template <typename Iterator>
struct is_bidirectional_iterator<Iterator,
                                 std::void_t<decltype(--std::declval<Iterator&>()),
                                             decltype(std::declval<Iterator&>()--)>>
    : std::integral_constant<bool, is_iterator_v<Iterator>>
{
};

// Convenience variable for is_bidirectional_iterator<Iterator>::value.
template <typename Iterator>
inline constexpr bool is_bidirectional_iterator_v = is_bidirectional_iterator<Iterator>::value;

170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
// Detects iterators that support `it + 1`, `it - 1` and `it[1]` and that also
// satisfy is_iterator.
template <typename Iterator, typename = void>
struct is_random_access_iterator : std::false_type
{
};

template <typename Iterator>
struct is_random_access_iterator<Iterator,
                                 std::void_t<decltype(std::declval<Iterator>() + 1),
                                             decltype(std::declval<Iterator>() - 1),
                                             decltype(std::declval<Iterator>()[1])>>
    : std::integral_constant<bool, is_iterator_v<Iterator>>
{
};

// Convenience variable for is_random_access_iterator<Iterator>::value.
template <typename Iterator>
inline constexpr bool is_random_access_iterator_v = is_random_access_iterator<Iterator>::value;

// Detects range-like types: unqualified begin() and end() must both be callable
// on a value of the type, and what begin() returns must satisfy is_iterator.
// The calls are unqualified, so begin/end are looked up via ADL.
template <typename T, typename = void>
struct is_range : std::false_type
{
};

template <typename Range>
struct is_range<Range,
                std::void_t<decltype(begin(std::declval<Range>())),
                            decltype(end(std::declval<Range>()))>>
    : std::integral_constant<
          bool,
          is_iterator_v<ck::remove_cvref_t<decltype(begin(std::declval<Range>()))>>>
{
};

// Convenience variable for is_range<T>::value.
template <typename T>
inline constexpr bool is_range_v = is_range<T>::value;

202
203
204
205
206
207
208
209
210
211
212
213
214
215
// Detects ranges whose element count can be queried: unqualified size() must be
// callable on a value of the type, and the type must also satisfy is_range.
template <typename Range, typename = void>
struct is_sized_range : std::false_type
{
};

template <typename Candidate>
struct is_sized_range<Candidate, std::void_t<decltype(size(std::declval<Candidate>()))>>
    : std::integral_constant<bool, is_range_v<Candidate>>
{
};

// Convenience variable for is_sized_range<Range>::value.
template <typename Range>
inline constexpr bool is_sized_range_v = is_sized_range<Range>::value;

216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
// Detects ranges whose begin() iterator satisfies is_bidirectional_iterator.
// Types without a usable begin() yield false.
template <typename Range, typename = void>
struct is_bidirectional_range : std::false_type
{
};

// The begin() expression is repeated inside std::void_t on purpose: that puts it
// in a SFINAE context, so a type without begin() simply fails to match this
// specialization and falls back to the primary template. Naming it only in the
// base clause would make such a type a hard compile error.
template <typename Range>
struct is_bidirectional_range<Range, std::void_t<decltype(begin(std::declval<Range>()))>>
    : std::bool_constant<
          is_range_v<Range> &&
          is_bidirectional_iterator_v<ck::remove_cvref_t<decltype(begin(std::declval<Range>()))>>>
{
};

// Convenience variable for is_bidirectional_range<Range>::value.
template <typename Range>
inline constexpr bool is_bidirectional_range_v = is_bidirectional_range<Range>::value;

232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
// Detects ranges whose begin() iterator satisfies is_random_access_iterator.
// Types without a usable begin() yield false.
template <typename Range, typename = void>
struct is_random_access_range : std::false_type
{
};

// The begin() expression is repeated inside std::void_t on purpose: that puts it
// in a SFINAE context, so a type without begin() simply fails to match this
// specialization and falls back to the primary template. Naming it only in the
// base clause would make such a type a hard compile error.
template <typename Range>
struct is_random_access_range<Range, std::void_t<decltype(begin(std::declval<Range>()))>>
    : std::bool_constant<
          is_range_v<Range> &&
          is_random_access_iterator_v<ck::remove_cvref_t<decltype(begin(std::declval<Range>()))>>>
{
};

// Convenience variable for is_random_access_range<Range>::value.
template <typename Range>
inline constexpr bool is_random_access_range_v = is_random_access_range<Range>::value;

} // namespace detail

250
251
252
253
254
255
// Free-function form of `range.front()`. The range is perfectly forwarded, so
// the value category and constness of the argument are preserved when calling
// the member; only participates in overload resolution when that call is valid.
template <typename Range>
auto front(Range&& range) -> decltype(std::forward<Range>(range).front())
{
    // Same cast that std::forward<Range> performs.
    return static_cast<Range&&>(range).front();
}

256
template <typename Axes>
257
inline std::enable_if_t<detail::is_random_access_range_v<Axes>, bool>
258
is_valid_axes(const Axes& axes)
259
260
261
262
263
264
265
266
{
    using std::empty;
    if(empty(axes))
    {
        return false;
    }

    using std::begin, std::end;
267
    std::vector<std::size_t> sorted_axes(begin(axes), end(axes));
268

269
270
    std::sort(begin(sorted_axes), end(sorted_axes));
    const auto last = std::unique(begin(sorted_axes), end(sorted_axes));
271

272
273
    return (last == end(sorted_axes)) && (*begin(sorted_axes) == 0) &&
           (*std::prev(last) == size(axes) - 1);
274
275
}

Po-Yen, Chen's avatar
Po-Yen, Chen committed
276
277
inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Problem& problem)
{
278
    constexpr int num_execution_config_args = 2;
279
    constexpr int num_problem_args          = 2 * Problem::NumDim;
280

Po-Yen, Chen's avatar
Po-Yen, Chen committed
281
282
283
284
    if(!(num_problem_args == size(problem.shape) + size(problem.axes)))
    {
        return false;
    }
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 1 + num_execution_config_args)
    {
        config.do_verification = std::stoi(argv[1]);
        config.time_kernel     = std::stoi(argv[2]);
    }
    else if(argc == 1 + num_execution_config_args + num_problem_args)
    {
        config.do_verification = std::stoi(argv[1]);
        config.time_kernel     = std::stoi(argv[2]);

        // read shape
301
        for(std::size_t idx = 0; idx < size(problem.shape); ++idx)
302
        {
303
            problem.shape[idx] = std::stoi(argv[idx + (1 + num_execution_config_args)]);
304
305
306
        }

        // read axes
307
308
        for(std::size_t idx = 0; idx < size(problem.axes); ++idx)
        {
309
310
            problem.axes[idx] =
                std::stoi(argv[idx + (1 + num_execution_config_args + size(problem.shape))]);
311
312
313
        }

        if(!is_valid_axes(problem.axes))
314
        {
315
316
317
318
319
320
            std::cerr << "invalid axes: ";
            std::copy(begin(problem.axes),
                      end(problem.axes),
                      std::ostream_iterator<std::size_t>(std::cerr, " "));
            std::cerr << std::endl;
            return false;
321
322
323
324
325
326
        }
    }
    else
    {
        std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
                  << "arg2: time kernel (0=no, 1=yes)" << std::endl
327
328
                  << "arg3 ~ arg5: shape for 3D tensor" << std::endl
                  << "arg6 ~ arg8: axes to permute" << std::endl;
329
330
331
        return false;
    }

Po-Yen, Chen's avatar
Po-Yen, Chen committed
332
333
    return true;
}
334
335
336
337
338
339
340
341
342
343
344
345
346

// A shape is valid when it has at least one dimension and every dimension is
// strictly greater than zero.
template <typename Shape>
inline std::enable_if_t<detail::is_range_v<Shape>, bool> is_valid_shape(const Shape& shape)
{
    using std::empty;
    if(empty(shape))
    {
        return false;
    }

    for(const auto dim : shape)
    {
        if(!(0 < dim))
        {
            return false;
        }
    }

    return true;
}

// Checks that `indices` addresses an element inside a tensor of the given
// `shape`: the shape itself must be valid, `indices` must be non-empty and have
// as many entries as the shape, and every index must be strictly below the
// length of its dimension.
template <typename Shape, typename Indices>
inline std::enable_if_t<detail::is_sized_range_v<Shape> && detail::is_sized_range_v<Indices>, bool>
is_valid_indices(const Shape& shape, const Indices& indices)
{
    using std::empty;
    using std::size;

    if(!is_valid_shape(shape) || empty(indices) || size(shape) != size(indices))
    {
        return false;
    }

    using std::begin, std::end;

    // Both ranges have the same length here, so they can be walked pairwise.
    return std::equal(begin(shape),
                      end(shape),
                      begin(indices),
                      [](const auto& dim, const auto& idx) { return !(dim <= idx); });
}

// Writes the permuted shape to `iter`: for every entry `axis` of `axes`, in
// order, `shape[axis]` is emitted. Returns the output iterator advanced past
// the last written element.
//
// Preconditions (checked with assert only): `shape` and `axes` have the same
// number of entries, `shape` is a valid shape and `axes` is a valid permutation.
template <typename Shape, typename Axes, typename OutputIterator>
inline std::enable_if_t<detail::is_random_access_range_v<Shape> &&
                            detail::is_sized_range_v<Shape> && detail::is_sized_range_v<Axes> &&
                            detail::is_output_iterator_v<OutputIterator>,
                        OutputIterator>
transpose_shape(const Shape& shape, const Axes& axes, OutputIterator iter)
{
    using std::size;
    assert(size(shape) == size(axes));
    assert(is_valid_shape(shape) && is_valid_axes(axes));

    using std::begin, std::end;
    return std::transform(begin(axes), end(axes), iter, [&shape](const auto axis) -> decltype(auto) {
        return shape[axis];
    });
}

398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
// Returns a copy of `shape` with one extra trailing dimension of length
// `new_dim`.
//
// Declared inline because this header defines the function; without it, every
// translation unit including the header would emit its own external
// definition and violate the one-definition rule at link time.
inline auto extend_shape(const Problem::Shape& shape, std::size_t new_dim)
{
    detail::enlarge_array_size_t<Problem::Shape, 1> extended_shape;

    using std::begin, std::end;

    std::copy(begin(shape), end(shape), begin(extended_shape));
    extended_shape.back() = new_dim;

    return extended_shape;
}

// Returns a copy of `axes` with one extra trailing entry whose value is the
// original number of axes, i.e. the appended dimension keeps its position.
//
// Declared inline because this header defines the function; without it, every
// translation unit including the header would emit its own external
// definition and violate the one-definition rule at link time.
inline auto extend_axes(const Problem::Axes& axes)
{
    detail::enlarge_array_size_t<Problem::Axes, 1> extended_axes;

    using std::begin, std::end;

    std::copy(begin(axes), end(axes), begin(extended_axes));
    extended_axes.back() = detail::get_array_size_v<Problem::Axes>;

    return extended_axes;
}

422
423
424
425
426
427
// Steps `indices` to the next position inside `shape`, with the last dimension
// varying fastest. An index that reaches the length of its dimension is reset
// to zero and the increment carries into the preceding dimension.
//
// Returns true when `indices` now names another valid position. Returns false
// when the inputs are invalid (in which case `indices` is untouched) or when
// the increment carried out of the first dimension, leaving every index at 0.
template <typename Shape, typename Indices>
std::enable_if_t<detail::is_bidirectional_range_v<Shape> && detail::is_sized_range_v<Shape> &&
                     detail::is_bidirectional_range_v<Indices> && detail::is_sized_range_v<Indices>,
                 bool>
advance_indices(const Shape& shape, Indices& indices)
{
    using std::size;
    const bool inputs_ok = is_valid_shape(shape) && is_valid_indices(shape, indices) &&
                           size(shape) == size(indices);
    if(!inputs_ok)
    {
        return false;
    }

    using std::rbegin, std::rend;
    auto extent = rbegin(shape);
    auto digit  = rbegin(indices);
    for(; extent != rend(shape) && digit != rend(indices); ++extent, ++digit)
    {
        assert(*digit < *extent);

        *digit = (*digit + 1);
        if(*digit != *extent)
        {
            // No overflow in this dimension: the carry stops here.
            return true;
        }

        // Overflow: wrap this dimension and carry into the next one.
        *digit = 0;
    }

    // The carry ran past the first dimension.
    return false;
}

Po-Yen, Chen's avatar
Po-Yen, Chen committed
450
451
452
453
454
455
template <typename Src, typename Axes, typename Functor, typename Dest>
std::enable_if_t<detail::is_random_access_range_v<Axes> && detail::is_sized_range_v<Axes> &&
                     std::is_invocable_v<Functor,
                                         std::add_lvalue_reference_t<Dest>,
                                         std::add_lvalue_reference_t<Src>>,
                 bool>
456
host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<Dest>& dest)
457
458
459
{
    const auto& shape            = src.mDesc.GetLengths();
    const auto& transposed_shape = dest.mDesc.GetLengths();
Po-Yen, Chen's avatar
Po-Yen, Chen committed
460
461
462
463
464
465
    if(!(is_valid_shape(shape) && is_valid_shape(transposed_shape)))
    {
        return false;
    }

    using std::size;
466
    if(!is_valid_axes(axes))
Po-Yen, Chen's avatar
Po-Yen, Chen committed
467
468
469
    {
        return false;
    }
470
471
472
473

    static_assert(detail::is_sized_range_v<ck::remove_cvref_t<decltype(shape)>> &&
                  detail::is_sized_range_v<ck::remove_cvref_t<decltype(transposed_shape)>>);

474
    if(size(shape) != size(transposed_shape))
Po-Yen, Chen's avatar
Po-Yen, Chen committed
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
    {
        return false;
    }

    static_assert(detail::is_random_access_range_v<ck::remove_cvref_t<decltype(shape)>> &&
                  detail::is_random_access_range_v<ck::remove_cvref_t<decltype(transposed_shape)>>);
    {
        for(std::size_t idx = 0; idx < size(shape); ++idx)
        {
            if(transposed_shape[idx] != shape[axes[idx]])
            {
                return false;
            }
        }
    }
490

491
    std::vector<std::size_t> indices(size(shape), 0);
Po-Yen, Chen's avatar
Po-Yen, Chen committed
492
493
494
495
    if(!is_valid_indices(shape, indices))
    {
        return false;
    }
496

497
    switch(size(shape))
498
    {
499
    case 3: {
500
501
502
503
504
505
506
        do
        {
            Dest output = 0;
            functor(output, src(indices[0], indices[1], indices[2]));
            dest(indices[axes[0]], indices[axes[1]], indices[axes[2]]) = output;
        } while(advance_indices(shape, indices));
    }
507
508
    break;
    case 4: {
509
510
511
512
513
514
515
        do
        {
            Dest output = 0;
            functor(output, src(indices[0], indices[1], indices[2], indices[3]));
            dest(indices[axes[0]], indices[axes[1]], indices[axes[2]], indices[axes[3]]) = output;
        } while(advance_indices(shape, indices));
    }
516
517
    break;
    default: return false;
518
    }
Po-Yen, Chen's avatar
Po-Yen, Chen committed
519
520

    return true;
521
}