run_permute_example.inc 4.49 KB
Newer Older
Po-Yen, Chen's avatar
Po-Yen, Chen committed
1
2
3
4
5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

6
7
8
9
// Number of sub-elements packed into one ADataType/BDataType value. The including
// translation unit may define it before including this file; otherwise it defaults to 1,
// which selects the plain (non-bundled) verification path in run_permute().
#ifndef NUM_ELEMS_IN_BUNDLE
#define NUM_ELEMS_IN_BUNDLE 1
#endif

10
bool run_permute(const ExecutionConfig& config, const Problem& problem)
Po-Yen, Chen's avatar
Po-Yen, Chen committed
11
{
12
13
14
15
#if 1 < NUM_ELEMS_IN_BUNDLE
    static_assert(std::is_same_v<ADataType, BDataType>);
#endif

16
    using std::begin, std::end;
17

18
19
20
21
22
23
    const auto& shape = problem.shape;
    ck::remove_cvref_t<decltype(shape)> transposed_shape;
    transpose_shape(problem.shape, problem.axes, begin(transposed_shape));

    Tensor<ADataType> a(shape);
    Tensor<BDataType> b(transposed_shape);
Po-Yen, Chen's avatar
Po-Yen, Chen committed
24

25
26
27
28
29
30
    using std::data, std::size;
    {
        auto* const elems =
            reinterpret_cast<detail::get_bundled_t<ADataType, NUM_ELEMS_IN_BUNDLE>*>(data(a.mData));
        std::iota(elems, elems + (size(a.mData) * NUM_ELEMS_IN_BUNDLE), 1);
    }
Po-Yen, Chen's avatar
Po-Yen, Chen committed
31
32
33
34

    DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());

35
    a_device_buf.ToDevice(data(a.mData));
Po-Yen, Chen's avatar
Po-Yen, Chen committed
36

37
38
    std::array<ck::index_t, 3> a_lengths, b_lengths;
    std::array<ck::index_t, 3> a_strides, b_strides;
39

40
41
    const void* input = a_device_buf.GetDeviceBuffer();
    void* output      = b_device_buf.GetDeviceBuffer();
Po-Yen, Chen's avatar
Po-Yen, Chen committed
42

43
    std::copy(begin(shape), end(shape), begin(a_lengths));
44
    std::copy(begin(a.mDesc.GetStrides()), end(a.mDesc.GetStrides()), begin(a_strides));
45
    std::copy(begin(transposed_shape), end(transposed_shape), begin(b_lengths));
46
    std::copy(begin(b.mDesc.GetStrides()), end(b.mDesc.GetStrides()), begin(b_strides));
Po-Yen, Chen's avatar
Po-Yen, Chen committed
47

48
49
    static_assert(std::is_default_constructible_v<DevicePermuteInstance>);

50
51
52
    auto permute  = DevicePermuteInstance{};
    auto argument = permute.MakeArgument(
        a_lengths, a_strides, b_lengths, b_strides, input, output, PassThrough{});
Po-Yen, Chen's avatar
Po-Yen, Chen committed
53

54
    if(!permute.IsSupportedArgument(argument))
Po-Yen, Chen's avatar
Po-Yen, Chen committed
55
    {
56
57
58
        std::cerr << "The runtime parameters seems not supported by the device instance, exiting!"
                  << std::endl;
        return false;
Po-Yen, Chen's avatar
Po-Yen, Chen committed
59
60
    };

61
62
    auto invoker   = permute.MakeInvoker();
    float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
Po-Yen, Chen's avatar
Po-Yen, Chen committed
63
64
65
66
67

    std::cout << "Perf: " << ave_time << " ms" << std::endl;

    if(config.do_verification)
    {
68
        b_device_buf.FromDevice(data(b.mData));
Po-Yen, Chen's avatar
Po-Yen, Chen committed
69

70
71
72
73
74
75
76
#if NUM_ELEMS_IN_BUNDLE == 1
        Tensor<BDataType> host_b(transposed_shape);
        if(!host_permute(a, problem.axes, PassThrough{}, host_b))
        {
            return false;
        }

Po-Yen, Chen's avatar
Po-Yen, Chen committed
77
        return ck::utils::check_err(
78
            b.mData, host_b.mData, "Error: incorrect results in output tensor", 1e-10, 1e-10);
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#else
        // extend tensor shape from [N, H, W] to [N, H, W, NUM_ELEMS_IN_BUNDLE]
        std::array<std::size_t, Problem::NumDim + 1> extended_shape;
        std::copy(begin(shape), end(shape), begin(extended_shape));
        extended_shape.back() = NUM_ELEMS_IN_BUNDLE;

        using DataType = detail::get_bundled_t<ADataType, NUM_ELEMS_IN_BUNDLE>;

        Tensor<DataType> extended_a(extended_shape);
        std::memcpy(data(extended_a.mData),
                    data(a.mData),
                    sizeof(ADataType) * a.mDesc.GetElementSpaceSize());

        std::array<std::size_t, Problem::NumDim + 1> extended_axes;
        std::copy(begin(problem.axes), end(problem.axes), begin(extended_axes));
        extended_axes.back() = Problem::NumDim;

        std::array<std::size_t, Problem::NumDim + 1> transposed_extended_shape;
        transpose_shape(extended_shape, extended_axes, begin(transposed_extended_shape));

        Tensor<DataType> extended_host_b(transposed_extended_shape);
        if(!host_permute(extended_a, extended_axes, PassThrough{}, extended_host_b))
        {
            return false;
        }

Po-Yen, Chen's avatar
Po-Yen, Chen committed
105
106
107
108
109
110
111
        return ck::utils::check_err(
            ck::span<const DataType>{reinterpret_cast<DataType*>(data(b.mData)),
                                     b.mDesc.GetElementSpaceSize() * NUM_ELEMS_IN_BUNDLE},
            ck::span<const DataType>{extended_host_b.mData},
            "Error: incorrect results in output tensor",
            1e-10,
            1e-10);
112
#endif
Po-Yen, Chen's avatar
Po-Yen, Chen committed
113
114
115
116
117
    }

    return true;
}

118
119
120
121
/// Example entry point: builds a Problem from the given defaults, applies any
/// command-line arguments to it and to the execution config, then runs the permute.
///
/// @param argc           Argument count forwarded to parse_cmd_args().
/// @param argv           Argument vector forwarded to parse_cmd_args().
/// @param default_shape  Shape used to construct the Problem before parsing.
/// @param default_axes   Axes used to construct the Problem before parsing.
/// @return false when parse_cmd_args() reports failure (run_permute() is then not
///         called); otherwise the result of run_permute().
bool run_permute_example(int argc,
                         char* argv[],
                         const Problem::Shape& default_shape,
                         const Problem::Axes& default_axes)
{
    ExecutionConfig config;
    Problem problem(default_shape, default_axes);

    // Guard clause: skip the run entirely when argument parsing fails.
    if(!parse_cmd_args(argc, argv, config, problem))
    {
        return false;
    }

    return run_permute(config, problem);
}