kernel.cpp 4.33 KB
Newer Older
arai713's avatar
arai713 committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

Paul Fultz II's avatar
Paul Fultz II committed
4
5
6
#include <rtc/kernel.hpp>
#include <rtc/manage_ptr.hpp>
#include <rtc/hip.hpp>
arai713's avatar
arai713 committed
7
#include <stdexcept>
Paul Fultz II's avatar
Paul Fultz II committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#include <cassert>

// extern declare the function since hip/hip_ext.h header is broken
//
// Forward declaration of the HIP extended module-launch entry point, used here
// instead of including hip/hip_ext.h. The signature must stay in sync with the
// HIP runtime this file is linked against.
//
// Argument order as used by the caller in this file: kernel function, three
// global-size values, three local-size values, one size_t (passed as 0 here;
// presumably dynamic shared memory bytes -- confirm against HIP docs), stream,
// two void** argument channels (this file passes nullptr for the first and a
// HIP_LAUNCH_PARAM_* config array for the second), then two optional events
// and a flags word that all have defaults.
extern hipError_t hipExtModuleLaunchKernel(hipFunction_t, // NOLINT
                                           uint32_t,
                                           uint32_t,
                                           uint32_t,
                                           uint32_t,
                                           uint32_t,
                                           uint32_t,
                                           size_t,
                                           hipStream_t,
                                           void**,
                                           void**,
                                           hipEvent_t = nullptr,
                                           hipEvent_t = nullptr,
                                           uint32_t   = 0);

namespace rtc {

// Flatten a list of kernel arguments into one contiguous byte buffer.
//
// Arguments are appended in order. Before each one, zero bytes are added so
// that its offset within the buffer is a multiple of that argument's `align`.
// Each argument contributes exactly `size` bytes copied from `data`.
std::vector<char> pack_args(const std::vector<kernel_argument>& args)
{
    std::vector<char> buffer;
    for(const auto& item : args)
    {
        // Distance from the current write offset up to the next aligned offset.
        const std::size_t misalign = buffer.size() % item.align;
        const std::size_t gap      = (misalign == 0) ? 0 : item.align - misalign;
        buffer.resize(buffer.size() + gap, 0);

        // Append the raw bytes of the argument itself.
        const auto* first = static_cast<const char*>(item.data);
        buffer.insert(buffer.end(), first, first + item.size);
    }
    return buffer;
}

// Managed handle type for a hipModule_t, built by RTC_MANAGE_PTR with
// hipModuleUnload as the release function.
using hip_module_ptr = RTC_MANAGE_PTR(hipModule_t, hipModuleUnload);

// State behind a `kernel`: the loaded module and the function looked up in it.
struct kernel_impl
{
    // Managed handle to the loaded HIP module; null until a module is loaded.
    hip_module_ptr module{nullptr};
    // Function handle obtained from `module`; null until resolved.
    hipFunction_t fun{nullptr};
};

// Load a HIP module from an in-memory code object image.
//
// image: pointer to the code object data handed to hipModuleLoadData.
// Returns a managed handle that owns the loaded module.
// Throws std::runtime_error if hipModuleLoadData reports a failure.
hip_module_ptr load_module(const char* image)
{
    // Start from a null handle so it is never indeterminate, even if the
    // runtime leaves it untouched on failure.
    hipModule_t raw_m = nullptr;
    const auto status = hipModuleLoadData(&raw_m, image);
    // Check the status before taking ownership: wrapping first would make the
    // managed pointer unload a handle that was never successfully loaded while
    // the exception unwinds.
    if(status != hipSuccess)
        throw std::runtime_error("Failed to load module: " + hip_error(status));
    return hip_module_ptr{raw_m};
}

// Build a kernel from a code object image and the name of a function in it.
//
// Loads the module via load_module, then looks up `name` in that module.
// Throws std::runtime_error if the function lookup fails (load_module may
// throw as well).
kernel::kernel(const char* image, const std::string& name) : impl(std::make_shared<kernel_impl>())
{
    impl->module = load_module(image);

    const auto lookup = hipModuleGetFunction(&impl->fun, impl->module.get(), name.c_str());
    if(lookup == hipSuccess)
        return;

    throw std::runtime_error("Failed to get function: " + name + ": " + hip_error(lookup));
}

// Launch `fun` on `stream` with a one-dimensional configuration.
//
// fun:      kernel function handle to launch.
// stream:   stream the launch is submitted to.
// global:   value passed as the first (x) global-size argument; y and z are 1.
// local:    value passed as the first (x) local-size argument; y and z are 1.
// kernargs: pointer to the packed kernel argument buffer.
// size:     size in bytes of the buffer at `kernargs`.
//
// Throws std::runtime_error if `global` or `local` does not fit in the 32-bit
// launch parameters, or if the HIP launch call reports a failure.
void launch_kernel(hipFunction_t fun,
                   hipStream_t stream,
                   std::size_t global,
                   std::size_t local,
                   void* kernargs,
                   std::size_t size)
{
    assert(global > 0);
    assert(local > 0);

    // The launch API takes 32-bit dimensions. Reject values that would be
    // silently truncated instead of launching a wrong-sized grid.
    const auto global_x = static_cast<uint32_t>(global);
    const auto local_x  = static_cast<uint32_t>(local);
    if(global_x != global or local_x != local)
        throw std::runtime_error(
            "Failed to launch kernel: launch dimensions do not fit in 32 bits");

    // Arguments are handed over as one packed buffer through the "extra"
    // channel: a list of (tag, value) pairs closed by the END tag.
    void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER,
                      kernargs,
                      HIP_LAUNCH_PARAM_BUFFER_SIZE,
                      &size,
                      HIP_LAUNCH_PARAM_END};

    auto status = hipExtModuleLaunchKernel(fun,
                                           global_x,
                                           1,
                                           1,
                                           local_x,
                                           1,
                                           1,
                                           0,
                                           stream,
                                           nullptr,
                                           config, // array decays to void**
                                           nullptr,
                                           nullptr);
    if(status != hipSuccess)
        throw std::runtime_error("Failed to launch kernel: " + hip_error(status));
}

// Launch this kernel, using the pointer array in `args` itself as the kernel
// argument buffer (args.size() * sizeof(void*) bytes).
void kernel::launch(hipStream_t stream,
                    std::size_t global,
                    std::size_t local,
                    std::vector<void*> args) const
{
    assert(impl != nullptr);

    const std::size_t nbytes = sizeof(void*) * args.size();
    launch_kernel(impl->fun, stream, global, local, args.data(), nbytes);
}

// Launch this kernel with typed arguments: they are first packed into one
// aligned byte buffer by pack_args, and that buffer is passed to the launch.
void kernel::launch(hipStream_t stream,
                    std::size_t global,
                    std::size_t local,
                    const std::vector<kernel_argument>& args) const
{
    assert(impl != nullptr);

    auto packed = pack_args(args);
    launch_kernel(impl->fun, stream, global, local, packed.data(), packed.size());
}

arai713's avatar
arai713 committed
125
} // namespace rtc