// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

bool run_permute(const ExecutionConfig& config, const Problem& problem)
Po-Yen, Chen's avatar
Po-Yen, Chen committed
7
{
8
    using std::begin, std::end;
9

10
11
12
13
14
15
    const auto& shape = problem.shape;
    ck::remove_cvref_t<decltype(shape)> transposed_shape;
    transpose_shape(problem.shape, problem.axes, begin(transposed_shape));

    Tensor<ADataType> a(shape);
    Tensor<BDataType> b(transposed_shape);
Po-Yen, Chen's avatar
Po-Yen, Chen committed
16

Po-Yen, Chen's avatar
Po-Yen, Chen committed
17
    std::iota(begin(a.mData), end(a.mData), 0);
Po-Yen, Chen's avatar
Po-Yen, Chen committed
18
19
20
21
22
23

    DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());

    a_device_buf.ToDevice(a.mData.data());

24
25
26
    std::array<ck::index_t, 4> ab_lengths;
    std::array<std::array<ck::index_t, 4>, 1> a_strides, b_strides;

Po-Yen, Chen's avatar
Po-Yen, Chen committed
27
28
29
    std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
    std::array<void*, 1> output      = {b_device_buf.GetDeviceBuffer()};

30
    std::copy(begin(shape), end(shape), begin(ab_lengths));
31
32
    std::copy(begin(a.mDesc.GetStrides()), end(a.mDesc.GetStrides()), begin(front(a_strides)));
    std::copy(begin(b.mDesc.GetStrides()), end(b.mDesc.GetStrides()), begin(front(b_strides)));
Po-Yen, Chen's avatar
Po-Yen, Chen committed
33

34
35
    static_assert(std::is_default_constructible_v<DevicePermuteInstance>);

36
    auto permute = DevicePermuteInstance{};
37
    auto argument =
38
        permute.MakeArgument(ab_lengths, a_strides, b_strides, input, output, PassThrough{});
Po-Yen, Chen's avatar
Po-Yen, Chen committed
39

40
    if(!permute.IsSupportedArgument(argument))
Po-Yen, Chen's avatar
Po-Yen, Chen committed
41
    {
42
43
44
        std::cerr << "The runtime parameters seems not supported by the device instance, exiting!"
                  << std::endl;
        return false;
Po-Yen, Chen's avatar
Po-Yen, Chen committed
45
46
    };

47
48
    auto invoker   = permute.MakeInvoker();
    float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
Po-Yen, Chen's avatar
Po-Yen, Chen committed
49
50
51
52
53

    std::cout << "Perf: " << ave_time << " ms" << std::endl;

    if(config.do_verification)
    {
54
        Tensor<BDataType> host_b(transposed_shape);
55
        host_permute(a, problem.axes, PassThrough{}, host_b);
56
57

        b_device_buf.FromDevice(b.mData.data());
Po-Yen, Chen's avatar
Po-Yen, Chen committed
58
59

        return ck::utils::check_err(
60
            b.mData, host_b.mData, "Error: incorrect results in output tensor", 1e-10, 1e-10);
Po-Yen, Chen's avatar
Po-Yen, Chen committed
61
62
63
64
65
    }

    return true;
}

bool run_permute_example(int argc, char* argv[])
Po-Yen, Chen's avatar
Po-Yen, Chen committed
67
68
69
70
{
    ExecutionConfig config;
    Problem problem;

71
    return parse_cmd_args(argc, argv, config, problem) && run_permute(config, problem);
Po-Yen, Chen's avatar
Po-Yen, Chen committed
72
}