run_permute_bundle_example.inc 3.09 KB
Newer Older
1
2
3
4
5
6
7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

bool run_permute_bundle(const Problem& problem)
{
Po-Yen, Chen's avatar
Po-Yen, Chen committed
8
    constexpr std::size_t NumElemsInBundle = sizeof(BundleType) / sizeof(DataType);
9

10
11
    const auto& input_bundle_shape = problem.shape;
    const auto& input_bundle_axes  = problem.axes;
12

13
    const auto output_bundle_shape = transpose(input_bundle_shape, input_bundle_axes);
14
15
16

    Tensor<BundleType> input_bundle_tensor(input_bundle_shape);
    Tensor<BundleType> output_bundle_tensor(output_bundle_shape);
17

Po-Yen, Chen's avatar
Po-Yen, Chen committed
18
    // initialize tensor by assigning DataType values
19
    ck::utils::FillUniformDistribution<DataType>{-1.f, 1.f}(input_bundle_tensor.AsSpan<DataType>());
20

21
22
    DeviceMem input_device_buf(input_bundle_tensor.GetElementSpaceSizeInBytes());
    DeviceMem output_device_buf(output_bundle_tensor.GetElementSpaceSizeInBytes());
23

24
    using std::data;
25
    input_device_buf.ToDevice(data(input_bundle_tensor));
26
27
28
29

    static_assert(std::is_default_constructible_v<DevicePermuteInstance>);

    auto permute  = DevicePermuteInstance{};
30
31
32
33
    auto argument = permute.MakeArgument(to_array(input_bundle_shape),
                                         to_array(input_bundle_tensor.GetStrides()),
                                         to_array(output_bundle_shape),
                                         to_array(output_bundle_tensor.GetStrides()),
Po-Yen, Chen's avatar
Po-Yen, Chen committed
34
35
                                         input_device_buf.GetDeviceBuffer(),
                                         output_device_buf.GetDeviceBuffer(),
36
                                         PassThrough{});
37
38
39
40
41
42
43
44
45
46
47
48
49

    if(!permute.IsSupportedArgument(argument))
    {
        std::cerr << "The runtime parameters seems not supported by the device instance, exiting!"
                  << std::endl;
        return false;
    };

    auto invoker   = permute.MakeInvoker();
    float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});

    std::cout << "Perf: " << ave_time << " ms" << std::endl;

50
    output_device_buf.FromDevice(data(output_bundle_tensor));
51

Po-Yen, Chen's avatar
Po-Yen, Chen committed
52
    // extend tensor shape from [N, H, W] to [N, H, W, NumElemsInBundle]
53
    //               axes  from [0, 2, 1] to [0, 2, 1, 3]
54
55
    const auto input_shape = extend_shape(input_bundle_shape, NumElemsInBundle);
    const auto input_axes  = extend_axes(input_bundle_axes);
56

Po-Yen, Chen's avatar
Po-Yen, Chen committed
57
58
    using std::begin;

59
    Tensor<DataType> input_tensor(input_shape);
Po-Yen, Chen's avatar
Po-Yen, Chen committed
60
    ranges::copy(input_bundle_tensor.AsSpan<const DataType>(), begin(input_tensor));
61

62
    Tensor<DataType> output_tensor(transpose(input_shape, input_axes));
63
    if(!host_permute(input_tensor, input_axes, PassThrough{}, output_tensor))
64
65
66
67
    {
        return false;
    }

68
69
70
71
72
    return ck::utils::check_err(output_bundle_tensor.AsSpan<const DataType>(),
                                output_tensor.AsSpan<const DataType>(),
                                "Error: incorrect results in output tensor",
                                1e-6,
                                1e-6);
73
74
}

75
// Convenience entry point: packs the given shape and axes into a Problem and
// forwards it to run_permute_bundle(), returning that function's result.
bool run_permute_bundle_example(const Problem::Shape& shape, const Problem::Axes& axes)
{
    return run_permute_bundle(Problem{shape, axes});
}