"driver/vscode:/vscode.git/clone" did not exist on "e4b77dcf21d7588aeacb1739c932fc1053870c1c"
// put_element_fp16.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <stdexcept>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

// Element types used by this example: X is the source tensor, Y the destination
// tensor, and the index tensor holds 32-bit signed integers.
using XDataType     = ck::half_t;
using YDataType     = ck::half_t;
using IndexDataType = int32_t;

// Elementwise operation passed to the device instance for Y.
// NOTE(review): presumably an identity op, judging by the name — confirm in
// element_wise_operation.hpp.
using YElementwiseOp = ck::tensor_operation::element_wise::PassThrough;

// Concrete device operation instantiated and run by main().
// NOTE(review): the meaning of InMemoryDataOperationEnum::Set and of the trailing
// template argument `1` is defined by DevicePutElementImpl, which is not visible
// in this file — confirm against device_put_element_impl.hpp.
using DeviceInstance =
    ck::tensor_operation::device::DevicePutElementImpl<XDataType,     // XDataType
                                                       IndexDataType, // IndexDataType
                                                       YDataType,     // YDataType
                                                       YElementwiseOp,
                                                       ck::InMemoryDataOperationEnum::Set,
                                                       1>;

int main()
{
    bool do_verification = true;
    bool time_kernel     = false;

    int N = 1024;

37
38
39
    Tensor<XDataType> x(HostTensorDescriptor{N});
    Tensor<IndexDataType> indices(HostTensorDescriptor{N});
    Tensor<YDataType> y(HostTensorDescriptor{N});
rocking's avatar
rocking committed
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71

    x.GenerateTensorValue(GeneratorTensor_3<XDataType>{-1.0, 1.0});
    for(int i = 0; i < N; ++i)
        indices(i) = i;

    DeviceMem x_device_buf(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
    DeviceMem y_device_buf(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
    DeviceMem indices_device_buf(sizeof(IndexDataType) * indices.mDesc.GetElementSpaceSize());

    x_device_buf.ToDevice(x.mData.data());
    indices_device_buf.ToDevice(indices.mData.data());

    auto put_instance     = DeviceInstance{};
    auto put_invoker_ptr  = put_instance.MakeInvokerPointer();
    auto put_argument_ptr = put_instance.MakeArgumentPointer(
        static_cast<XDataType*>(x_device_buf.GetDeviceBuffer()),
        static_cast<IndexDataType*>(indices_device_buf.GetDeviceBuffer()),
        static_cast<YDataType*>(y_device_buf.GetDeviceBuffer()),
        N,
        N,
        YElementwiseOp{});

    if(!put_instance.IsSupportedArgument(put_argument_ptr.get()))
    {
        throw std::runtime_error("argument is not supported!");
    }

    float ave_time =
        put_invoker_ptr->Run(put_argument_ptr.get(), StreamConfig{nullptr, time_kernel});

    std::cout << "perf: " << ave_time << " ms" << std::endl;

Bartlomiej Kocot's avatar
tmp  
Bartlomiej Kocot committed
72
ck::utils::CorrectnessValidator validator;
rocking's avatar
rocking committed
73
74
    if(do_verification)
    {
75
        Tensor<YDataType> y_host(HostTensorDescriptor{N});
rocking's avatar
rocking committed
76
77
78
79
80
81
82
83

        for(int i = 0; i < N; ++i)
        {
            IndexDataType idx = indices(i);
            y_host(idx)       = x(i);
        }

        y_device_buf.FromDevice(y.mData.data());
Bartlomiej Kocot's avatar
tmp  
Bartlomiej Kocot committed
84
        validator.check_err(y, y_host);
rocking's avatar
rocking committed
85
86
    }

Bartlomiej Kocot's avatar
tmp  
Bartlomiej Kocot committed
87
    return !validator.is_success();
rocking's avatar
rocking committed
88
}