// put_element_fp16.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

// Element types used by this example: half-precision source and destination values,
// addressed through 32-bit signed indices.
using XDataType     = ck::half_t;
using IndexDataType = int32_t;
using YDataType     = ck::half_t;

// Element-wise functor handed to the device operation.
using YElementwiseOp = ck::tensor_operation::element_wise::PassThrough;

// Device operation instance exercised by main().
using DeviceInstance = ck::tensor_operation::device::DevicePutElementImpl<
    XDataType,
    IndexDataType,
    YDataType,
    YElementwiseOp,
    ck::InMemoryDataOperationEnum::Set,
    1>; // NOTE(review): meaning of this trailing argument is not visible here
        // (presumably a vector width) - confirm in device_put_element_impl.hpp

int main()
{
    bool do_verification = true;
    bool time_kernel     = false;

    int N = 1024;

37
38
39
    Tensor<XDataType> x(HostTensorDescriptor{N});
    Tensor<IndexDataType> indices(HostTensorDescriptor{N});
    Tensor<YDataType> y(HostTensorDescriptor{N});
rocking's avatar
rocking committed
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74

    x.GenerateTensorValue(GeneratorTensor_3<XDataType>{-1.0, 1.0});
    for(int i = 0; i < N; ++i)
        indices(i) = i;

    DeviceMem x_device_buf(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
    DeviceMem y_device_buf(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
    DeviceMem indices_device_buf(sizeof(IndexDataType) * indices.mDesc.GetElementSpaceSize());

    x_device_buf.ToDevice(x.mData.data());
    indices_device_buf.ToDevice(indices.mData.data());

    auto put_instance     = DeviceInstance{};
    auto put_invoker_ptr  = put_instance.MakeInvokerPointer();
    auto put_argument_ptr = put_instance.MakeArgumentPointer(
        static_cast<XDataType*>(x_device_buf.GetDeviceBuffer()),
        static_cast<IndexDataType*>(indices_device_buf.GetDeviceBuffer()),
        static_cast<YDataType*>(y_device_buf.GetDeviceBuffer()),
        N,
        N,
        YElementwiseOp{});

    if(!put_instance.IsSupportedArgument(put_argument_ptr.get()))
    {
        throw std::runtime_error("argument is not supported!");
    }

    float ave_time =
        put_invoker_ptr->Run(put_argument_ptr.get(), StreamConfig{nullptr, time_kernel});

    std::cout << "perf: " << ave_time << " ms" << std::endl;

    bool pass = true;
    if(do_verification)
    {
75
        Tensor<YDataType> y_host(HostTensorDescriptor{N});
rocking's avatar
rocking committed
76
77
78
79
80
81
82
83
84
85
86
87
88

        for(int i = 0; i < N; ++i)
        {
            IndexDataType idx = indices(i);
            y_host(idx)       = x(i);
        }

        y_device_buf.FromDevice(y.mData.data());
        pass = ck::utils::check_err(y, y_host);
    }

    return (pass ? 0 : 1);
}