// run_permute_element_example.inc
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Po-Yen, Chen's avatar
Po-Yen, Chen committed
6
bool run_permute_element(const Problem& problem)
Po-Yen, Chen's avatar
Po-Yen, Chen committed
7
{
8
    using std::begin, std::end;
9

10
11
    const auto& input_shape = problem.shape;
    const auto& input_axes  = problem.axes;
12

13
14
    ck::remove_cvref_t<decltype(input_shape)> output_shape;
    transpose_shape(input_shape, input_axes, begin(output_shape));
Po-Yen, Chen's avatar
Po-Yen, Chen committed
15

16
17
    Tensor<InDataType> input_tensor(input_shape);
    Tensor<OutDataType> output_tensor(output_shape);
Po-Yen, Chen's avatar
Po-Yen, Chen committed
18

19
20
21
22
    ck::utils::FillUniformDistribution<InDataType>{-1.f, 1.f}(input_tensor);

    DeviceMem input_device_buf(input_tensor.GetElementSpaceSizeInBytes());
    DeviceMem output_device_buf(output_tensor.GetElementSpaceSizeInBytes());
Po-Yen, Chen's avatar
Po-Yen, Chen committed
23

24
    using std::data;
25
    input_device_buf.ToDevice(data(input_tensor));
Po-Yen, Chen's avatar
Po-Yen, Chen committed
26

27
28
    std::array<ck::index_t, Problem::NumDim> input_lengths, output_lengths;
    std::array<ck::index_t, Problem::NumDim> input_strides, output_strides;
29

30
31
    const void* input_data = input_device_buf.GetDeviceBuffer();
    void* output_data      = output_device_buf.GetDeviceBuffer();
Po-Yen, Chen's avatar
Po-Yen, Chen committed
32

33
34
35
36
    ranges::copy(input_shape, begin(input_lengths));
    ranges::copy(input_tensor.GetStrides(), begin(input_strides));
    ranges::copy(output_shape, begin(output_lengths));
    ranges::copy(output_tensor.GetStrides(), begin(output_strides));
Po-Yen, Chen's avatar
Po-Yen, Chen committed
37

38
39
    static_assert(std::is_default_constructible_v<DevicePermuteInstance>);

40
    auto permute  = DevicePermuteInstance{};
41
42
43
44
45
46
47
    auto argument = permute.MakeArgument(input_lengths,
                                         input_strides,
                                         output_lengths,
                                         output_strides,
                                         input_data,
                                         output_data,
                                         PassThrough{});
Po-Yen, Chen's avatar
Po-Yen, Chen committed
48

49
    if(!permute.IsSupportedArgument(argument))
Po-Yen, Chen's avatar
Po-Yen, Chen committed
50
    {
51
52
53
        std::cerr << "The runtime parameters seems not supported by the device instance, exiting!"
                  << std::endl;
        return false;
Po-Yen, Chen's avatar
Po-Yen, Chen committed
54
55
    };

56
    auto invoker   = permute.MakeInvoker();
Po-Yen, Chen's avatar
Po-Yen, Chen committed
57
    float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
Po-Yen, Chen's avatar
Po-Yen, Chen committed
58
59
60

    std::cout << "Perf: " << ave_time << " ms" << std::endl;

61
    output_device_buf.FromDevice(data(output_tensor));
Po-Yen, Chen's avatar
Po-Yen, Chen committed
62

63
64
    Tensor<OutDataType> output_tensor_host(output_shape);
    if(!host_permute(input_tensor, input_axes, PassThrough{}, output_tensor_host))
Po-Yen, Chen's avatar
Po-Yen, Chen committed
65
66
67
68
    {
        return false;
    }

69
70
71
72
73
    return ck::utils::check_err(output_tensor.mData,
                                output_tensor_host.mData,
                                "Error: incorrect results in output tensor",
                                1e-6,
                                1e-6);
Po-Yen, Chen's avatar
Po-Yen, Chen committed
74
75
}

76
bool run_permute_element_example(const Problem::Shape& shape, const Problem::Axes& axes)
Po-Yen, Chen's avatar
Po-Yen, Chen committed
77
{
78
    return run_permute_element(Problem{shape, axes});
Po-Yen, Chen's avatar
Po-Yen, Chen committed
79
}