run_permute_bundle_example.inc 3.45 KB
Newer Older
1
2
3
4
5
6
7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

bool run_permute_bundle(const Problem& problem)
{
Po-Yen, Chen's avatar
Po-Yen, Chen committed
8
    constexpr std::size_t NumElemsInBundle = sizeof(BundleType) / sizeof(DataType);
9
10
11
12
13
14
15

    using std::begin, std::end;

    const auto& shape = problem.shape;
    ck::remove_cvref_t<decltype(shape)> transposed_shape;
    transpose_shape(problem.shape, problem.axes, begin(transposed_shape));

Po-Yen, Chen's avatar
Po-Yen, Chen committed
16
17
    Tensor<BundleType> a(shape);
    Tensor<BundleType> b(transposed_shape);
18

Po-Yen, Chen's avatar
Po-Yen, Chen committed
19
    // initialize tensor by assigning DataType values
20
21
    using std::data, std::size;
    {
Po-Yen, Chen's avatar
Po-Yen, Chen committed
22
23
24
        auto* const elems = reinterpret_cast<DataType*>(data(a.mData));
        ck::utils::FillUniformDistribution<DataType>{-1.f, 1.f}(
            elems, elems + (size(a.mData) * NumElemsInBundle));
25
26
    }

Po-Yen, Chen's avatar
Po-Yen, Chen committed
27
28
    DeviceMem a_device_buf(sizeof(BundleType) * a.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf(sizeof(BundleType) * b.mDesc.GetElementSpaceSize());
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62

    a_device_buf.ToDevice(data(a.mData));

    std::array<ck::index_t, 3> a_lengths, b_lengths;
    std::array<ck::index_t, 3> a_strides, b_strides;

    const void* input = a_device_buf.GetDeviceBuffer();
    void* output      = b_device_buf.GetDeviceBuffer();

    std::copy(begin(shape), end(shape), begin(a_lengths));
    std::copy(begin(a.mDesc.GetStrides()), end(a.mDesc.GetStrides()), begin(a_strides));
    std::copy(begin(transposed_shape), end(transposed_shape), begin(b_lengths));
    std::copy(begin(b.mDesc.GetStrides()), end(b.mDesc.GetStrides()), begin(b_strides));

    static_assert(std::is_default_constructible_v<DevicePermuteInstance>);

    auto permute  = DevicePermuteInstance{};
    auto argument = permute.MakeArgument(
        a_lengths, a_strides, b_lengths, b_strides, input, output, PassThrough{});

    if(!permute.IsSupportedArgument(argument))
    {
        std::cerr << "The runtime parameters seems not supported by the device instance, exiting!"
                  << std::endl;
        return false;
    };

    auto invoker   = permute.MakeInvoker();
    float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});

    std::cout << "Perf: " << ave_time << " ms" << std::endl;

    b_device_buf.FromDevice(data(b.mData));

Po-Yen, Chen's avatar
Po-Yen, Chen committed
63
64
    // extend tensor shape from [N, H, W] to [N, H, W, NumElemsInBundle]
    const auto extended_shape = extend_shape(shape, NumElemsInBundle);
65
66
67
68
69
70
71
    const auto extended_axes  = extend_axes(problem.axes);

    ck::remove_cvref_t<decltype(extended_shape)> transposed_extended_shape;
    transpose_shape(extended_shape, extended_axes, begin(transposed_extended_shape));

    Tensor<DataType> extended_a(extended_shape);
    std::memcpy(
Po-Yen, Chen's avatar
Po-Yen, Chen committed
72
        data(extended_a.mData), data(a.mData), sizeof(BundleType) * a.mDesc.GetElementSpaceSize());
73
74
75
76
77
78
79
80
81

    Tensor<DataType> extended_host_b(transposed_extended_shape);
    if(!host_permute(extended_a, extended_axes, PassThrough{}, extended_host_b))
    {
        return false;
    }

    return ck::utils::check_err(
        ck::span<const DataType>{reinterpret_cast<DataType*>(data(b.mData)),
Po-Yen, Chen's avatar
Po-Yen, Chen committed
82
                                 b.mDesc.GetElementSpaceSize() * NumElemsInBundle},
83
84
85
86
87
88
89
90
91
92
93
        ck::span<const DataType>{extended_host_b.mData},
        "Error: incorrect results in output tensor",
        1e-6,
        1e-6);
}

/// Example entry point: wraps the default shape and axes in a Problem and forwards it to
/// run_permute_bundle, returning that function's verdict unchanged.
bool run_permute_bundle_example(const Problem::Shape& default_shape,
                                const Problem::Axes& default_axes)
{
    const Problem problem{default_shape, default_axes};
    return run_permute_bundle(problem);
}