run_permute_bundle_example.inc 3.98 KB
Newer Older
1
2
3
4
5
6
7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

bool run_permute_bundle(const Problem& problem)
{
Po-Yen, Chen's avatar
Po-Yen, Chen committed
8
    constexpr std::size_t NumElemsInBundle = sizeof(BundleType) / sizeof(DataType);
9
10
11

    using std::begin, std::end;

12
13
    const auto& input_bundle_shape = problem.shape;
    const auto& input_bundle_axes  = problem.axes;
14

15
16
17
18
19
    ck::remove_cvref_t<decltype(input_bundle_shape)> output_bundle_shape;
    transpose_shape(input_bundle_shape, input_bundle_axes, begin(output_bundle_shape));

    Tensor<BundleType> input_bundle_tensor(input_bundle_shape);
    Tensor<BundleType> output_bundle_tensor(output_bundle_shape);
20

Po-Yen, Chen's avatar
Po-Yen, Chen committed
21
    // initialize tensor by assigning DataType values
22
    using std::data, std::size;
23
24
25
    ck::utils::FillUniformDistribution<DataType>{-1.f, 1.f}(
        ck::span<DataType>{reinterpret_cast<DataType*>(data(input_bundle_tensor)),
                           input_bundle_tensor.GetElementSpaceSize() * NumElemsInBundle});
26

27
28
    DeviceMem input_device_buf(input_bundle_tensor.GetElementSpaceSizeInBytes());
    DeviceMem output_device_buf(output_bundle_tensor.GetElementSpaceSizeInBytes());
29

30
    input_device_buf.ToDevice(data(input_bundle_tensor));
31

32
33
    std::array<ck::index_t, Problem::NumDim> input_bundle_lengths, output_bundle_lengths;
    std::array<ck::index_t, Problem::NumDim> input_bundle_strides, output_bundle_strides;
34

35
36
    const void* input_bundle_data = input_device_buf.GetDeviceBuffer();
    void* output_bundle_data      = output_device_buf.GetDeviceBuffer();
37

38
39
40
41
    ranges::copy(input_bundle_shape, begin(input_bundle_lengths));
    ranges::copy(input_bundle_tensor.GetStrides(), begin(input_bundle_strides));
    ranges::copy(output_bundle_shape, begin(output_bundle_lengths));
    ranges::copy(output_bundle_tensor.GetStrides(), begin(output_bundle_strides));
42
43
44
45

    static_assert(std::is_default_constructible_v<DevicePermuteInstance>);

    auto permute  = DevicePermuteInstance{};
46
47
48
49
50
51
52
    auto argument = permute.MakeArgument(input_bundle_lengths,
                                         input_bundle_strides,
                                         output_bundle_lengths,
                                         output_bundle_strides,
                                         input_bundle_data,
                                         output_bundle_data,
                                         PassThrough{});
53
54
55
56
57
58
59
60
61
62
63
64
65

    if(!permute.IsSupportedArgument(argument))
    {
        std::cerr << "The runtime parameters seems not supported by the device instance, exiting!"
                  << std::endl;
        return false;
    };

    auto invoker   = permute.MakeInvoker();
    float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});

    std::cout << "Perf: " << ave_time << " ms" << std::endl;

66
    output_device_buf.FromDevice(data(output_bundle_tensor));
67

Po-Yen, Chen's avatar
Po-Yen, Chen committed
68
    // extend tensor shape from [N, H, W] to [N, H, W, NumElemsInBundle]
69
70
    const auto input_shape = extend_shape(input_bundle_shape, NumElemsInBundle);
    const auto input_axes  = extend_axes(input_bundle_axes);
71

72
73
    ck::remove_cvref_t<decltype(input_shape)> output_shape;
    transpose_shape(input_shape, input_axes, begin(output_shape));
74

75
76
77
78
    Tensor<DataType> input_tensor(input_shape);
    std::memcpy(data(input_tensor),
                data(input_bundle_tensor),
                input_bundle_tensor.GetElementSpaceSizeInBytes());
79

80
81
    Tensor<DataType> output_tensor(output_shape);
    if(!host_permute(input_tensor, input_axes, PassThrough{}, output_tensor))
82
83
84
85
86
    {
        return false;
    }

    return ck::utils::check_err(
87
88
89
        ck::span<const DataType>{reinterpret_cast<DataType*>(data(output_bundle_tensor)),
                                 output_bundle_tensor.GetElementSpaceSize() * NumElemsInBundle},
        ck::span<const DataType>{output_tensor},
90
91
92
93
94
        "Error: incorrect results in output tensor",
        1e-6,
        1e-6);
}

95
bool run_permute_bundle_example(const Problem::Shape& shape, const Problem::Axes& axes)
96
{
97
    return run_permute_bundle(Problem{shape, axes});
98
}