"test/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "3854a5e11c896831cc76e3a8a50541eb970ee8a7"
kernel.hpp 1.6 KB
Newer Older
Paul Fultz II's avatar
Paul Fultz II committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_KERNEL
#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_KERNEL

#include <hip/hip_runtime_api.h>
#include <cstddef>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>

namespace rtc {

/// Type-erased, non-owning reference to a single kernel argument.
///
/// Records the address, size and alignment of the referenced object; the
/// object itself is NOT copied, so it must outlive every use of this
/// kernel_argument. NOTE(review): size/align are presumably what pack_args
/// uses to lay the arguments out in one buffer -- confirm.
///
/// The converting constructor is intentionally implicit so that a
/// std::vector<kernel_argument> can be brace-initialized directly from
/// arbitrary argument values.
struct kernel_argument
{
    // U is the referenced type with the reference stripped. The constraint
    // must test U, not T: for a non-const lvalue kernel_argument, T deduces to
    // `kernel_argument&`, and std::is_base_of is always false for reference
    // types, so testing T let this template win over the copy constructor and
    // wrap the kernel_argument object itself instead of copying it.
    template <class T,
              class U = std::remove_reference_t<T>,
              class   = std::enable_if_t<not std::is_base_of<kernel_argument, U>{}>>
    kernel_argument(T&& x) : size(sizeof(U)), align(alignof(U)), data(std::addressof(x)) // NOLINT
    {
    }
    std::size_t size;  ///< sizeof the referenced object, in bytes.
    std::size_t align; ///< alignof the referenced object.
    void* data;        ///< Address of the referenced object (non-owning).
};

// Builds a byte buffer from `args` (definition is not in this header).
// NOTE(review): presumably each argument's `size` bytes are copied from its
// `data` pointer at an offset aligned to its `align` -- confirm against the
// implementation.
std::vector<char> pack_args(const std::vector<kernel_argument>& args);

// Opaque implementation type held by `kernel` through a std::shared_ptr; it is
// only forward-declared in this header.
struct kernel_impl;

/// Handle to a runtime-compiled GPU kernel.
///
/// The state lives in a kernel_impl owned through a std::shared_ptr, so copies
/// of a kernel are cheap and share the same underlying object. A
/// default-constructed kernel holds no kernel_impl.
struct kernel
{
    kernel() = default;

    /// Constructs a kernel from the raw image bytes starting at `image`.
    /// Defined out of line. NOTE(review): `name` is presumably the symbol of
    /// the kernel entry point inside the image -- confirm.
    kernel(const char* image, const std::string& name);

    /// Convenience overload accepting the image as a std::vector of a
    /// single-byte element type (e.g. char); wider element types are
    /// rejected at compile time.
    template <class T>
    kernel(const std::vector<T>& image, const std::string& name)
        : kernel(reinterpret_cast<const char*>(image.data()), name)
    {
        static_assert(sizeof(T) == 1, "Only byte types");
    }

    /// Launches the kernel on `stream` with the given arguments. Defined out
    /// of line. NOTE(review): `global`/`local` are presumably the total and
    /// per-block work-item counts -- confirm against the definition.
    void launch(hipStream_t stream,
                std::size_t global,
                std::size_t local,
                const std::vector<kernel_argument>& args) const;

    /// Same as above, with the arguments given as raw pointers. Defined out
    /// of line. NOTE(review): presumably each pointer addresses one argument
    /// value -- confirm.
    void launch(hipStream_t stream,
                std::size_t global,
                std::size_t local,
                std::vector<void*> args) const;

    /// Curried launch: `k.launch(stream, global, local)(arg0, arg1, ...)`.
    ///
    /// Returns a callable that wraps each of its arguments in a
    /// kernel_argument (which only records the argument's address, size and
    /// alignment) and forwards them to the kernel_argument overload above.
    /// The callable stores its own copy of this handle rather than `this`, so
    /// it stays valid after the originating kernel object is destroyed.
    ///
    /// No extra parameters after `local` are supported; passing any is a
    /// compile-time error.
    template <class... Ts>
    auto launch(hipStream_t stream, std::size_t global, std::size_t local, Ts... zs) const
    {
        // Neither non-template overload accepts anything after the argument
        // vector, so forwarding trailing `zs` would re-select this template,
        // build a new callable and discard it: the kernel would silently never
        // launch. Reject that at compile time instead.
        static_assert(sizeof...(zs) == 0,
                      "rtc::kernel::launch: trailing arguments are not supported; pass kernel "
                      "arguments to the returned callable");
        // Capture a copy of the (shared) handle instead of `this` so the
        // returned callable cannot dangle.
        return [self = *this, stream, global, local](auto&&... xs) {
            self.launch(stream, global, local, std::vector<kernel_argument>{xs...});
        };
    }

    private:
    std::shared_ptr<kernel_impl> impl; // null for a default-constructed kernel
};
} // namespace rtc

#endif