run_permute_element_example.inc 2.42 KB
Newer Older
Po-Yen, Chen's avatar
Po-Yen, Chen committed
1
2
3
4
5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Po-Yen, Chen's avatar
Po-Yen, Chen committed
6
bool run_permute_element(const Problem& problem)
Po-Yen, Chen's avatar
Po-Yen, Chen committed
7
{
8
    using std::begin, std::end;
9

10
11
12
13
14
15
    const auto& shape = problem.shape;
    ck::remove_cvref_t<decltype(shape)> transposed_shape;
    transpose_shape(problem.shape, problem.axes, begin(transposed_shape));

    Tensor<ADataType> a(shape);
    Tensor<BDataType> b(transposed_shape);
Po-Yen, Chen's avatar
Po-Yen, Chen committed
16

17
    ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(begin(a.mData), end(a.mData));
Po-Yen, Chen's avatar
Po-Yen, Chen committed
18
19
20
21

    DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());

22
    using std::data;
23
    a_device_buf.ToDevice(data(a.mData));
Po-Yen, Chen's avatar
Po-Yen, Chen committed
24

25
26
    std::array<ck::index_t, 3> a_lengths, b_lengths;
    std::array<ck::index_t, 3> a_strides, b_strides;
27

28
29
    const void* input = a_device_buf.GetDeviceBuffer();
    void* output      = b_device_buf.GetDeviceBuffer();
Po-Yen, Chen's avatar
Po-Yen, Chen committed
30

31
    std::copy(begin(shape), end(shape), begin(a_lengths));
32
    std::copy(begin(a.mDesc.GetStrides()), end(a.mDesc.GetStrides()), begin(a_strides));
33
    std::copy(begin(transposed_shape), end(transposed_shape), begin(b_lengths));
34
    std::copy(begin(b.mDesc.GetStrides()), end(b.mDesc.GetStrides()), begin(b_strides));
Po-Yen, Chen's avatar
Po-Yen, Chen committed
35

36
37
    static_assert(std::is_default_constructible_v<DevicePermuteInstance>);

38
39
40
    auto permute  = DevicePermuteInstance{};
    auto argument = permute.MakeArgument(
        a_lengths, a_strides, b_lengths, b_strides, input, output, PassThrough{});
Po-Yen, Chen's avatar
Po-Yen, Chen committed
41

42
    if(!permute.IsSupportedArgument(argument))
Po-Yen, Chen's avatar
Po-Yen, Chen committed
43
    {
44
45
46
        std::cerr << "The runtime parameters seems not supported by the device instance, exiting!"
                  << std::endl;
        return false;
Po-Yen, Chen's avatar
Po-Yen, Chen committed
47
48
    };

49
    auto invoker   = permute.MakeInvoker();
Po-Yen, Chen's avatar
Po-Yen, Chen committed
50
    float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
Po-Yen, Chen's avatar
Po-Yen, Chen committed
51
52
53

    std::cout << "Perf: " << ave_time << " ms" << std::endl;

Po-Yen, Chen's avatar
Po-Yen, Chen committed
54
    b_device_buf.FromDevice(data(b.mData));
Po-Yen, Chen's avatar
Po-Yen, Chen committed
55

Po-Yen, Chen's avatar
Po-Yen, Chen committed
56
57
58
59
60
61
62
63
    Tensor<BDataType> host_b(transposed_shape);
    if(!host_permute(a, problem.axes, PassThrough{}, host_b))
    {
        return false;
    }

    return ck::utils::check_err(
        b.mData, host_b.mData, "Error: incorrect results in output tensor", 1e-6, 1e-6);
Po-Yen, Chen's avatar
Po-Yen, Chen committed
64
65
}

Po-Yen, Chen's avatar
Po-Yen, Chen committed
66
67
// Example entry point: wraps the given shape and axes in a Problem and forwards it to
// run_permute_element(), returning that function's result unchanged.
bool run_permute_element_example(const Problem::Shape& default_shape,
                                 const Problem::Axes& default_axes)
{
    const Problem problem{default_shape, default_axes};
    return run_permute_element(problem);
}